Schrieffer2sy committed on
Commit
1748050
·
1 Parent(s): 05a9ebf
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +43 -19
  3. assets/framework-v4.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -3,16 +3,16 @@ import torch
3
  from transformers import AutoTokenizer
4
  from sarm_llama import LlamaSARM
5
 
6
- # --- 1. 加载模型和Tokenizer ---
7
- # 这一步会自动从Hugging Face Hub下载你的模型文件
8
- # 确保你的模型仓库是公开的
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  MODEL_ID = "schrieffer/SARM-4B"
12
 
13
  print(f"Loading model: {MODEL_ID} on {DEVICE}...")
14
 
15
- # 加载模型时必须信任远程代码,因为SARM有自定义架构
16
  model = LlamaSARM.from_pretrained(
17
  MODEL_ID,
18
  sae_hidden_state_source_layer=16,
@@ -26,18 +26,18 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
26
 
27
  print("Model loaded successfully!")
28
 
29
- # --- 2. 定义推理函数 ---
30
- # 这个函数会被Gradio调用
31
 
32
  def get_reward_score(prompt: str, response: str) -> float:
33
  """
34
- 接收prompt和response,返回SARM模型计算出的奖励分数。
35
  """
36
  if not prompt or not response:
37
  return 0.0
38
 
39
  try:
40
- # 使用与模型训练时相同的聊天模板
41
  messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
42
  input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(DEVICE)
43
 
@@ -47,20 +47,43 @@ def get_reward_score(prompt: str, response: str) -> float:
47
  return round(score, 4)
48
  except Exception as e:
49
  print(f"Error: {e}")
50
- # 在界面上返回一个错误提示可能更好,但这里我们简单返回0
51
  return 0.0
52
 
53
- # --- 3. 创建并启动Gradio界面 ---
54
 
55
- # 使用gr.Blocks()可以获得更灵活的布局
56
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
57
  gr.Markdown(
58
  """
59
- # SARM-4B: Interpretable Reward Model Demo
60
- This is an interactive demo for the SARM-4B model, an interpretable reward model enhanced by a Sparse Autoencoder.
61
- Enter a prompt (question) and a corresponding response below to get a reward score. A higher score indicates a better quality response according to the model.
62
 
63
- For more details, check out our [Tech Report](https://arxiv.org/abs/submit/6699218) and [Model Card](https://huggingface.co/schrieffer/SARM-4B).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  """
65
  )
66
 
@@ -71,7 +94,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
71
  calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
72
  score_output = gr.Number(label="Reward Score", info="A higher score is better.")
73
 
74
- # 定义按钮点击时的行为
75
  calculate_btn.click(
76
  fn=get_reward_score,
77
  inputs=[prompt_input, response_input],
@@ -88,8 +111,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
88
  inputs=[prompt_input, response_input],
89
  outputs=score_output,
90
  fn=get_reward_score,
91
- cache_examples=True # 缓存示例结果,加快加载速度
92
  )
93
 
94
- # 启动应用
95
- demo.launch()
 
 
3
  from transformers import AutoTokenizer
4
  from sarm_llama import LlamaSARM
5
 
6
+ # --- 1. Load Model and Tokenizer ---
7
+ # This step automatically downloads your model files from the Hugging Face Hub.
8
+ # Ensure your model repository is public.
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  MODEL_ID = "schrieffer/SARM-4B"
12
 
13
  print(f"Loading model: {MODEL_ID} on {DEVICE}...")
14
 
15
+ # trust_remote_code=True is required because SARM has a custom architecture.
16
  model = LlamaSARM.from_pretrained(
17
  MODEL_ID,
18
  sae_hidden_state_source_layer=16,
 
26
 
27
  print("Model loaded successfully!")
28
 
29
+ # --- 2. Define the Inference Function ---
30
+ # This function will be called by Gradio.
31
 
32
  def get_reward_score(prompt: str, response: str) -> float:
33
  """
34
+ Receives a prompt and a response, and returns the reward score calculated by the SARM model.
35
  """
36
  if not prompt or not response:
37
  return 0.0
38
 
39
  try:
40
+ # Use the same chat template as used during model training.
41
  messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
42
  input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(DEVICE)
43
 
 
47
  return round(score, 4)
48
  except Exception as e:
49
  print(f"Error: {e}")
50
+ # It might be better to return an error message on the UI, but here we simply return 0.
51
  return 0.0
52
 
53
+ # --- 3. Create and Launch the Gradio Interface ---
54
 
55
+ # Use gr.Blocks() for a more flexible layout.
56
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
57
  gr.Markdown(
58
  """
59
+ # SARM: Interpretable Reward Model Demo
 
 
60
 
61
+ This is an interactive demo for the **SARM-4B** model (Sparse Autoencoder-enhanced Reward Model).
62
+
63
+ SARM is a novel reward model architecture that enhances interpretability by integrating a pretrained Sparse Autoencoder (SAE). It maps the internal hidden states of a large language model into a sparse and human-understandable feature space, making the resulting reward scores transparent and conceptually meaningful.
64
+
65
+ **How to use this Demo:**
66
+ 1. Enter a **Prompt** (e.g., a question) in the left textbox below.
67
+ 2. Enter a corresponding **Response** in the right textbox.
68
+ 3. Click the "Calculate Reward Score" button.
69
+
70
+ The model will output a scalar score that evaluates the quality of the response. **A higher score indicates that the SARM model considers the response to be of better quality.**
71
+
72
+ ---
73
+
74
+ *SARM Architecture*
75
+ ![](https://huggingface.co/schrieffer/SARM-4B/resolve/main/sarm-framework.png?raw=true)
76
+
77
+ + **Authors** (* indicates equal contribution)
78
+
79
+ Shuyi Zhang\*, Wei Shi\*, Sihang Li\*, Jiayi Liao, Tao Liang, Hengxing Cai, Xiang Wang
80
+ + **Paper**: [Interpretable Reward Model via Sparse Autoencoder](https://arxiv.org/abs/2508.08746)
81
+
82
+ + **Model**: [schrieffer/SARM-4B](https://huggingface.co/schrieffer/SARM-4B)
83
+
84
+ + Finetuned from model: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
85
+
86
+ + **Code Repository:** [https://github.com/schrieffer-z/sarm](https://github.com/schrieffer-z/sarm)
87
  """
88
  )
89
 
 
94
  calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
95
  score_output = gr.Number(label="Reward Score", info="A higher score is better.")
96
 
97
+ # Define the button's click behavior.
98
  calculate_btn.click(
99
  fn=get_reward_score,
100
  inputs=[prompt_input, response_input],
 
111
  inputs=[prompt_input, response_input],
112
  outputs=score_output,
113
  fn=get_reward_score,
114
+ cache_examples=True # Cache the results of the examples to speed up loading.
115
  )
116
 
117
+ # Launch the application.
118
+ demo.launch()
119
+
assets/framework-v4.png ADDED

Git LFS Details

  • SHA256: 60a2e71aff1390a841c34a3f0c17290388251a61ff5395bed240047105cbed40
  • Pointer size: 131 Bytes
  • Size of remote file: 712 kB