Upload model.py
Browse files
model.py
CHANGED
@@ -66,7 +66,7 @@ class ML_BART(nn.Module):
|
|
66 |
super().__init__()
|
67 |
d_model = bartconfig.d_model
|
68 |
|
69 |
-
self.
|
70 |
nn.Embedding(class_num[0] + 1, d_model // 4),
|
71 |
nn.Embedding(class_num[1] + 1, d_model // 4)
|
72 |
])
|
@@ -91,7 +91,7 @@ class ML_BART(nn.Module):
|
|
91 |
emb_decoder = self.encoder(x_decoder)
|
92 |
else:
|
93 |
emb_decoder = torch.concatenate(
|
94 |
-
[self.
|
95 |
self.decoder(x_encoder)], dim=-1)
|
96 |
|
97 |
y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
|
|
|
66 |
super().__init__()
|
67 |
d_model = bartconfig.d_model
|
68 |
|
69 |
+
self.decoder_emb2 = nn.ModuleList([
|
70 |
nn.Embedding(class_num[0] + 1, d_model // 4),
|
71 |
nn.Embedding(class_num[1] + 1, d_model // 4)
|
72 |
])
|
|
|
91 |
emb_decoder = self.encoder(x_decoder)
|
92 |
else:
|
93 |
emb_decoder = torch.concatenate(
|
94 |
+
[self.decoder_emb2[0](x_decoder[..., 0]), self.decoder_emb2[1](x_decoder[..., 1]),
|
95 |
self.decoder(x_encoder)], dim=-1)
|
96 |
|
97 |
y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
|