RS2002 commited on
Commit
77a8fb4
·
verified ·
1 Parent(s): 3988493

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -66,7 +66,7 @@ class ML_BART(nn.Module):
66
  super().__init__()
67
  d_model = bartconfig.d_model
68
 
69
- self.decoder_emb = nn.ModuleList([
70
  nn.Embedding(class_num[0] + 1, d_model // 4),
71
  nn.Embedding(class_num[1] + 1, d_model // 4)
72
  ])
@@ -91,7 +91,7 @@ class ML_BART(nn.Module):
91
  emb_decoder = self.encoder(x_decoder)
92
  else:
93
  emb_decoder = torch.concatenate(
94
- [self.decoder_emb[0](x_decoder[..., 0]), self.decoder_emb[1](x_decoder[..., 1]),
95
  self.decoder(x_encoder)], dim=-1)
96
 
97
  y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
 
66
  super().__init__()
67
  d_model = bartconfig.d_model
68
 
69
+ self.decoder_emb2 = nn.ModuleList([
70
  nn.Embedding(class_num[0] + 1, d_model // 4),
71
  nn.Embedding(class_num[1] + 1, d_model // 4)
72
  ])
 
91
  emb_decoder = self.encoder(x_decoder)
92
  else:
93
  emb_decoder = torch.concatenate(
94
+ [self.decoder_emb2[0](x_decoder[..., 0]), self.decoder_emb2[1](x_decoder[..., 1]),
95
  self.decoder(x_encoder)], dim=-1)
96
 
97
  y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,