Huhujingjing
committed on
Commit
•
40fcca6
1
Parent(s):
aea6f73
Upload model
Browse files- config.json +17 -0
- configuration_gcn.py +27 -0
- modeling_gcn.py +103 -0
- pytorch_model.bin +3 -0
config.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GCNModel"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_gcn.GCNConfig",
|
7 |
+
"AutoModel": "modeling_gcn.GCNModel"
|
8 |
+
},
|
9 |
+
"emb_input": 20,
|
10 |
+
"hidden_size": 64,
|
11 |
+
"input_feature": 64,
|
12 |
+
"model_type": "gcn",
|
13 |
+
"n_layers": 6,
|
14 |
+
"num_classes": 1,
|
15 |
+
"torch_dtype": "float32",
|
16 |
+
"transformers_version": "4.29.2"
|
17 |
+
}
|
configuration_gcn.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
class GCNConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of the GCN model.

    Attributes:
        input_feature: the dimension of the input feature.
        emb_input: the embedding dimension of the input feature.
        hidden_size: the hidden size of the GCN.
        n_layers: the number of GCN layers.
        num_classes: the number of output classes.
    """

    model_type = "gcn"

    def __init__(
        self,
        input_feature: int = 64,
        emb_input: int = 20,
        hidden_size: int = 64,
        n_layers: int = 6,
        num_classes: int = 1,
        **kwargs,
    ):
        # Model-specific fields are stored first; every remaining keyword
        # argument is forwarded untouched to ``PretrainedConfig``.
        self.num_classes = num_classes
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.emb_input = emb_input
        self.input_feature = input_feature

        super().__init__(**kwargs)
|
23 |
+
|
24 |
+
|
25 |
+
if __name__ == "__main__":
    # Build a configuration with explicit hyper-parameters and write it to
    # the "custom-gcn" directory.
    default_config = GCNConfig(
        input_feature=64,
        emb_input=20,
        hidden_size=64,
        n_layers=6,
        num_classes=1,
    )
    default_config.save_pretrained("custom-gcn")
|
modeling_gcn.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch_geometric.nn import GCNConv
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch_scatter import scatter
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from gcn_model.configuration_gcn import GCNConfig
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
"""
|
11 |
+
MLP Layer used after graph vector representation
|
12 |
+
"""
|
13 |
+
class MLPReadout(nn.Module):
    """Stack of linear layers whose width halves at each hidden layer.

    With ``L`` hidden layers the widths are ``input_dim``,
    ``input_dim // 2``, ..., ``input_dim // 2 ** L`` followed by a final
    projection to ``output_dim``. ReLU is applied after every hidden layer
    but not after the final projection.
    """

    def __init__(self, input_dim, output_dim, L=2):  # L = number of hidden layers
        super().__init__()
        # Widths seen by consecutive layers, ending with the output width.
        widths = [input_dim // 2 ** depth for depth in range(L + 1)]
        widths.append(output_dim)

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))

        self.FC_layers = nn.ModuleList(layers)
        self.L = L

    def forward(self, x):
        """Apply the ``L`` hidden layers (each followed by ReLU), then the output layer."""
        out = x
        for idx in range(self.L):
            out = F.relu(self.FC_layers[idx](out))
        return self.FC_layers[self.L](out)
|
29 |
+
|
30 |
+
class GCNNet(torch.nn.Module):
    """Graph-level predictor: embedding -> GCN convolutions -> mean pooling -> MLP.

    Node indices are embedded, passed ``n_layers`` times through the same
    ``conv1`` layer (its weights are shared across iterations), projected to
    32 channels by ``conv2``, averaged per graph and fed to an ``MLPReadout``
    head.
    """

    def __init__(self, input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1):
        """
        Args:
            input_feature: stored on the instance; not otherwise used by this class.
            emb_input: number of rows of the embedding table (index 0 is the
                padding index).
            hidden_size: embedding width and channel width of ``conv1``.
            n_layers: number of times ``conv1`` is applied in ``forward``.
            num_classes: output width of the readout MLP.
        """
        super(GCNNet, self).__init__()

        self.embedding = torch.nn.Embedding(emb_input, hidden_size, padding_idx=0)
        self.input_feature = input_feature
        self.n_layers = n_layers  # number of conv1 applications in forward()
        self.num_classes = num_classes

        # Attribute names are kept as-is so existing checkpoints still load.
        self.conv1 = GCNConv(hidden_size, hidden_size)

        self.conv2 = GCNConv(hidden_size, 32)
        self.mlp = MLPReadout(32, num_classes)

    def forward(self, data):
        """Compute one prediction per graph.

        Args:
            data: object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric batch -- confirm
                against callers).

        Returns:
            The readout output, with the last dimension removed when it has
            size 1.
        """
        x, edge_index, batch = data.x.long(), data.edge_index, data.batch
        # Flatten the node indices before the embedding lookup.
        x = self.embedding(x.reshape(-1))

        for _ in range(self.n_layers):
            x = F.relu(self.conv1(x, edge_index))

        x = F.relu(self.conv2(x, edge_index))
        # Mean-pool node features into one vector per entry of ``batch``.
        x = scatter(x, batch, dim=-2, reduce='mean')
        x = self.mlp(x)

        # ``squeeze`` (previously misspelled ``sequeeze``, which raised
        # AttributeError on every call).
        return x.squeeze(-1)
|
56 |
+
|
57 |
+
class GCNModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds a ``GCNNet`` from a ``GCNConfig``."""

    config_class = GCNConfig

    def __init__(self, config):
        """Instantiate the wrapped network from the fields of ``config``."""
        super().__init__(config)

        self.model = GCNNet(
            input_feature=config.input_feature,
            emb_input=config.emb_input,
            hidden_size=config.hidden_size,
            n_layers=config.n_layers,
            num_classes=config.num_classes,
        )

    def forward(self, tensor):
        """Run the wrapped ``GCNNet`` on ``tensor`` and return its output.

        ``GCNNet`` defines only ``forward`` (there is no ``forward_features``
        method), so the module is called directly.
        """
        return self.model(tensor)
|
73 |
+
|
74 |
+
# class GCNModelForMolecularPrediction(PreTrainedModel):
|
75 |
+
# config_class = GCNConfig
|
76 |
+
#
|
77 |
+
# def __init__(self, config):
|
78 |
+
# super().__init__(config)
|
79 |
+
#
|
80 |
+
# self.model = GCNNet(
|
81 |
+
# input_feature=config.input_feature,
|
82 |
+
# emb_input=config.emb_input,
|
83 |
+
# hidden_size=config.hidden_size,
|
84 |
+
# n_layers=config.n_layers,
|
85 |
+
# num_classes=config.num_classes,
|
86 |
+
# )
|
87 |
+
#
|
88 |
+
# def forward(self, tensor):
|
89 |
+
# return self.model.forward_features(tensor)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
    # Load the saved configuration, wrap the trained GCNNet weights in a
    # GCNModel and export everything to the "custom-gcn" directory.
    loaded_config = GCNConfig.from_pretrained("custom-gcn")
    wrapper = GCNModel(loaded_config)

    trained_weights = torch.load(r'G:\Trans_MXM\gcn_model\gcn.pt')
    wrapper.model.load_state_dict(trained_weights)

    wrapper.save_pretrained("custom-gcn")
|
98 |
+
|
99 |
+
# gcnd1 = GCNModelForMolecularPrediction(gcn_config)
|
100 |
+
#
|
101 |
+
# gcnd1.model.load_state_dict(torch.load(r'G:\Trans_MXM\gcn_model\gcn.pt'))
|
102 |
+
# gcnd1.save_pretrained("custom-gcn")
|
103 |
+
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b49c62eec2f337c8c36abe60088848da35409fd74ea35e0027f316ec92c1cc4f
|
3 |
+
size 35716
|