Commit: Update model
Files changed:
- __init__.py (+4 lines)
- configuration_reward_model.py (+8 lines)
- modeling_reward_model.py (+29 lines)
__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
# Package surface for the custom reward-model components: re-export the
# config and model classes so users can import them from the package root.
from .configuration_reward_model import RewardModelConfig
from .modeling_reward_model import RewardModel

# Explicitly declare the public API of this package.
__all__ = ["RewardModelConfig", "RewardModel"]
configuration_reward_model.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
class RewardModelConfig(PretrainedConfig):
    """Configuration for :class:`RewardModel`.

    Stores ``attributes`` — the list of reward-attribute names, one per
    regression head of the model. All other keyword arguments are
    forwarded unchanged to ``PretrainedConfig``.
    """

    # Key under which this config is registered with AutoConfig.
    model_type = "reward_model"

    def __init__(self, attributes: list | None = None, **kwargs):
        # Defensive copy: the original stored the caller's list by
        # reference, so mutating it after construction silently changed
        # the config. Falsy inputs (None, empty list) still yield [].
        self.attributes = list(attributes) if attributes else []
        super().__init__(**kwargs)
|
modeling_reward_model.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import AutoConfig, PreTrainedModel
|
4 |
+
from .configuration_reward_model import RewardModelConfig
|
5 |
+
|
6 |
+
# Register the config under its model_type so AutoConfig can resolve
# "reward_model" checkpoints to RewardModelConfig.
# NOTE(review): this runs at import time as a module-level side effect.
AutoConfig.register("reward_model", RewardModelConfig)
|
8 |
+
|
9 |
+
class RewardModel(PreTrainedModel):
    """Multi-head reward model: a transformer backbone followed by a
    bias-free linear layer producing one score per reward attribute.
    """

    config_class = RewardModelConfig

    def __init__(self, config, transformer, regression_weights):
        """
        Args:
            config: a ``RewardModelConfig`` instance.
            transformer: backbone module whose output exposes
                ``last_hidden_state`` (batch, seq, hidden_dim).
            regression_weights: array of shape (num_heads, hidden_dim)
                that initializes the regression layer. Accepts numpy
                arrays (as before) and now also torch tensors.
        """
        super().__init__(config)
        self.transformer = transformer
        num_heads, hidden_dim = regression_weights.shape
        self.regression_heads = nn.Linear(hidden_dim, num_heads, bias=False)
        # torch.as_tensor accepts numpy arrays AND torch tensors, whereas
        # the previous torch.from_numpy call required numpy. Copying under
        # no_grad() is the supported way to overwrite a Parameter in
        # place; assigning to .weight.data bypasses autograd bookkeeping.
        with torch.no_grad():
            self.regression_heads.weight.copy_(
                torch.as_tensor(regression_weights, dtype=torch.float32)
            )

    def forward(self, input_ids, attention_mask=None):
        """Return per-attribute scores of shape (batch, num_heads)."""
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the hidden state at the final sequence position.
        # NOTE(review): with right-padded batches, position -1 is a padding
        # token — confirm callers left-pad or pass unpadded sequences.
        last_hidden = outputs.last_hidden_state[:, -1, :]
        return self.regression_heads(last_hidden)
|
26 |
+
|
27 |
+
# Register the model with the Auto* machinery so AutoModel.from_pretrained
# can instantiate RewardModel for a "reward_model" config.
# NOTE(review): import placed at the bottom mirrors the original layout;
# conventionally it belongs in the import block at the top of the file.
from transformers import AutoModel
AutoModel.register(RewardModelConfig, RewardModel)
|