il-pugin committed on
Commit
5e746c5
·
verified ·
1 Parent(s): e055f5e

Update model

Browse files
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .configuration_reward_model import RewardModelConfig
2
+ from .modeling_reward_model import RewardModel
3
+
4
+ __all__ = ["RewardModelConfig", "RewardModel"]
configuration_reward_model.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class RewardModelConfig(PretrainedConfig):
    """Configuration for the custom reward model.

    Args:
        attributes: Iterable of attribute names, one per regression head,
            in head order. Defaults to an empty list.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "reward_model"

    def __init__(self, attributes=None, **kwargs):
        # Copy into a fresh list so the config owns its data: avoids
        # aliasing (later mutation of) the caller's list and accepts any
        # iterable (tuple, generator), not only a list.
        self.attributes = list(attributes) if attributes is not None else []
        super().__init__(**kwargs)
modeling_reward_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoConfig, PreTrainedModel
4
+ from .configuration_reward_model import RewardModelConfig
5
+
6
# Register the custom config under model_type "reward_model" so that
# AutoConfig.from_pretrained can resolve it (module-level side effect).
AutoConfig.register("reward_model", RewardModelConfig)
8
+
9
class RewardModel(PreTrainedModel):
    """Transformer backbone plus bias-free linear regression heads that map
    the final-position hidden state to one score per attribute."""

    config_class = RewardModelConfig

    def __init__(self, config, transformer, regression_weights):
        """
        Args:
            config: ``RewardModelConfig`` instance.
            transformer: Backbone module whose output exposes
                ``last_hidden_state`` of shape (batch, seq, hidden).
            regression_weights: Array-like of shape
                (num_attributes, hidden_size) holding the head weights.
                Accepts a NumPy array or a torch tensor (the original
                implementation required NumPy).
        """
        super().__init__(config)
        self.transformer = transformer
        # as_tensor handles both NumPy arrays and tensors; float32 matches
        # the previous from_numpy(...).float() conversion.
        weights = torch.as_tensor(regression_weights, dtype=torch.float32)
        self.regression_heads = nn.Linear(
            weights.shape[1],   # in_features = hidden size
            weights.shape[0],   # out_features = number of attributes
            bias=False,
        )
        # copy_ under no_grad validates the shape and keeps the existing
        # Parameter object, unlike the old direct `.data = ...` assignment.
        with torch.no_grad():
            self.regression_heads.weight.copy_(weights)

    def forward(self, input_ids, attention_mask=None):
        """Return per-attribute scores of shape (batch, num_attributes)."""
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): pools the hidden state at the final sequence position;
        # with right-padded batches this position is a padding token — confirm
        # the tokenizer left-pads, or gather the last non-padded position
        # from attention_mask instead.
        last_hidden = outputs.last_hidden_state[:, -1, :]
        return self.regression_heads(last_hidden)
26
+
27
# Register the model class so AutoModel.from_pretrained can instantiate
# RewardModel from a RewardModelConfig (module-level side effect).
# NOTE(review): this import belongs at the top of the file with the other
# transformers imports.
from transformers import AutoModel
AutoModel.register(RewardModelConfig, RewardModel)