Upload 11 files
- caller.py +100 -0
- caller_penalty.py +151 -0
- config.yaml +93 -0
- math.py +49 -0
- math_format.jinja +1 -0
- persona.jinja +1 -0
- questioner.jinja +1 -0
- r1v.py +47 -0
- r1v_format.jinja +1 -0
- runtime_env.yaml +9 -0
- solver.jinja +1 -0
caller.py
ADDED
@@ -0,0 +1,100 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import regex as re
from typing import Dict, List
import json
from mathruler.grader import extract_boxed_content, grade_answer
import os
import time
import random
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

STORAGE_PATH = os.getenv("STORAGE_PATH")


def generate_temp_filename(prefix="temp", suffix=".json"):
    """Build a unique temp-file path under $STORAGE_PATH/temp_results."""
    timestamp = int(time.time() * 1000)
    rand_part = random.randint(0, 99999)
    return f"{STORAGE_PATH}/temp_results/{prefix}_{timestamp}_{rand_part}{suffix}"


def split_list(lst, n=4):
    """Split lst into n nearly equal contiguous shards."""
    k, m = divmod(len(lst), n)
    return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]


os.environ["NO_PROXY"] = "0.0.0.0,127.0.0.1"


def fetch(index, i):
    # Ask the local scorer worker on port 5000+index to process the shard file `i`.
    response = requests.get(f"http://0.0.0.0:{5000 + index}/hello?name={i}")
    print(response)
    return True


def generate_results(data):
    # Shard the extracted (question, answer) pairs, hand each shard to one of the
    # four local workers, then collect the per-shard *_results.json files.
    datas = split_list(data, 4)
    random_names = [generate_temp_filename(prefix=f"temp_{i}", suffix=".json") for i in range(4)]
    for i in range(4):
        with open(random_names[i], "w") as f:
            json.dump(datas[i], f, indent=4)

    final_results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(fetch, i, random_names[i]) for i in range(4)]
        for future in as_completed(futures):
            print(future.result())

    for i in range(4):  # read every shard's result file (not just the last one)
        with open(random_names[i].replace(".json", "_results.json"), "r") as f:
            final_results.extend(json.load(f))

    return final_results


def format_reward(predict: str) -> float:
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, predict)
    return 1.0 if format_match else 0.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict)
    return 1.0 if grade_answer(answer, ground_truth) else 0.0


def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1, file_path: str = "") -> List[Dict[str, float]]:
    # Batch reward for the questioner: extract the proposed question and its boxed
    # answer from each rollout, send them to the local workers for scoring, and use
    # min(score, 1 - score) as the reward (peaks at score = 0.5; -1 when no question
    # was produced).
    results = []
    with open("test.json", "w") as f:
        json.dump(predicts, f, indent=4)
    for i in range(len(predicts)):
        questions = re.findall(r"<question>(.*?)</question>", predicts[i], re.DOTALL)
        answers = extract_boxed_content(predicts[i])
        if questions and answers:
            try:
                question = questions[-1].strip()
                answer = answers[-1].strip()
                results.append({"question": question, "answer": answer})
            except Exception:
                results.append({"question": "", "answer": ""})
        else:
            results.append({"question": "", "answer": ""})

    final_results = generate_results(results)
    scores = [
        {
            "overall": min(item["score"], 1 - item["score"]) if item["question"] else -1,
            "format": 1 if item["question"] else 0,
            "accuracy": 1 if item["answer"] else 0,
        }
        for item in final_results
    ]
    return scores
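Note: caller.py does not score anything itself. It assumes four local HTTP scorer workers on ports 5000-5003 that read the shard file passed via /hello?name=<path> and write back <path minus .json>_results.json with a "score" field added per item; those workers are not part of this upload. A minimal sketch of such a worker, purely as an assumption about the interface (Flask, the placeholder scoring, and the CLI port argument are illustrative, not confirmed):

import json
import sys

from flask import Flask, request

app = Flask(__name__)


@app.route("/hello")
def hello():
    # caller.py passes the shard path as ?name=<path>
    shard_path = request.args.get("name")
    with open(shard_path) as f:
        items = json.load(f)
    for item in items:
        # Placeholder: the real worker presumably runs a solver model on
        # item["question"] and stores its success rate here.
        item["score"] = 0.5 if item["question"] else 0.0
    with open(shard_path.replace(".json", "_results.json"), "w") as f:
        json.dump(items, f, indent=4)
    return "ok"


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(sys.argv[1]))  # one process per port, 5000..5003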
caller_penalty.py
ADDED
@@ -0,0 +1,151 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import regex as re
from typing import Dict, List
import json
from mathruler.grader import extract_boxed_content, grade_answer
import os
import time
import random
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

from collections import Counter
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.cluster import AgglomerativeClustering
import numpy as np

STORAGE_PATH = os.getenv("STORAGE_PATH", "/apdcephfs_sh2/share_300000800/user/chengchuang")


def _bleu_distance_matrix(sentences):
    # Pairwise distance = 1 - BLEU, so near-duplicate questions end up close together.
    n = len(sentences)
    dist = np.zeros((n, n))
    smoother = SmoothingFunction().method1
    for i in range(n):
        for j in range(i, n):
            if i == j:
                score = 1.0
            else:
                ref = [sentences[j].split()]
                hyp = sentences[i].split()
                score = sentence_bleu(ref, hyp, smoothing_function=smoother)
            dist[i, j] = dist[j, i] = 1 - score
    return dist


def cluster_share_per_problem(
    problems,
    distance_threshold: float = 0.5,
    linkage: str = "average",
):
    # Cluster the generated questions and return, for each one, the fraction of the
    # batch that falls into its cluster (used below as a duplication penalty).
    if not problems:
        return []
    print("start clustering")
    start_time = time.time()
    dist_mat = _bleu_distance_matrix(problems)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        metric="precomputed",
        linkage=linkage,
    )
    labels = clustering.fit_predict(dist_mat)
    print(f"end clustering, time: {time.time() - start_time}")
    total = len(problems)
    cluster_size = Counter(labels)
    cluster_ratio = {lab: sz / total for lab, sz in cluster_size.items()}

    proportions = [cluster_ratio[lab] for lab in labels]
    return proportions


def generate_temp_filename(prefix="temp", suffix=".json"):
    timestamp = int(time.time() * 1000)
    rand_part = random.randint(0, 99999)
    return f"{STORAGE_PATH}/temp_results/{prefix}_{timestamp}_{rand_part}{suffix}"


def split_list(lst, n=4):
    k, m = divmod(len(lst), n)
    return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]


os.environ["NO_PROXY"] = "0.0.0.0,127.0.0.1"


def fetch(index, i):
    response = requests.get(f"http://0.0.0.0:{5000 + index}/hello?name={i}")
    print(response)
    return True


def generate_results(data):
    datas = split_list(data, 4)
    random_names = [generate_temp_filename(prefix=f"temp_{i}", suffix=".json") for i in range(4)]
    for i in range(4):
        with open(random_names[i], "w") as f:
            json.dump(datas[i], f, indent=4)

    final_results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(fetch, i, random_names[i]) for i in range(4)]
        for future in as_completed(futures):
            print(future.result())

    for i in range(4):
        with open(random_names[i].replace(".json", "_results.json"), "r") as f:
            final_results.extend(json.load(f))
            # os.remove(random_names[i].replace('.json','_results.json'))
    for i in range(4):
        os.remove(random_names[i].replace(".json", "_results.json"))
    return final_results


def format_reward(predict: str) -> float:
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, predict)
    return 1.0 if format_match else 0.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict)
    return 1.0 if grade_answer(answer, ground_truth) else 0.0


def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1, file_path: str = "") -> List[Dict[str, float]]:
    results = []
    with open("test.json", "w") as f:
        json.dump(predicts, f, indent=4)
    for i in range(len(predicts)):
        questions = re.findall(r"<question>(.*?)</question>", predicts[i], re.DOTALL)
        answers = extract_boxed_content(predicts[i])
        if questions and answers:
            try:
                question = questions[-1].strip()
                answer = answers[-1].strip()
                results.append({"question": question, "answer": answer})
            except Exception:
                results.append({"question": "", "answer": ""})
        else:
            results.append({"question": "", "answer": ""})

    final_results = generate_results(results)
    # Same reward as caller.py, minus each question's cluster share as a diversity penalty.
    penalty = cluster_share_per_problem([result["question"] for result in final_results], distance_threshold=0.5)
    # print(penalty)
    assert len(penalty) == len(final_results)
    scores = []
    for i in range(len(final_results)):
        final_score = (min(final_results[i]["score"], 1 - final_results[i]["score"]) if final_results[i]["question"] else -1) - penalty[i]
        scores.append({
            "overall": final_score,
            "format": 1 if final_results[i]["question"] else 0,
            "accuracy": penalty[i],
        })
    return scores
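The penalty in caller_penalty.py is each question's cluster share: questions are clustered on 1 - BLEU distance, and every member of a cluster is penalised by that cluster's fraction of the batch, so duplicated questions lose more reward than diverse ones. A toy illustration (assumes this file is importable as caller_penalty and that nltk, scikit-learn and numpy are installed):

from caller_penalty import cluster_share_per_problem

questions = [
    "What is the sum of 3 and 5?",
    "What is the sum of 3 and 5 ?",
    "What is the sum of 3 and 5?",
    "A train travels 60 km in 1.5 hours. What is its average speed in km/h?",
]
print(cluster_share_per_problem(questions, distance_threshold=0.5))
# Expected shape of the output: the three near-duplicates share one cluster
# (penalty around 0.75 each) while the distinct question sits alone (around 0.25);
# exact values depend on BLEU smoothing and the threshold.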
config.yaml
ADDED
@@ -0,0 +1,93 @@
data:
  train_files: hiyouga/math12k@train
  val_files: hiyouga/math12k@test
  prompt_key: problem
  answer_key: answer
  image_key: images
  max_prompt_length: 2048
  max_response_length: 2048
  rollout_batch_size: 512
  val_batch_size: 1024
  format_prompt: ./examples/format_prompt/math_format.jinja
  override_chat_template: null
  shuffle: true
  seed: 1
  max_pixels: 4194304
  min_pixels: 262144
  filter_overlong_prompts: true

algorithm:
  adv_estimator: grpo
  disable_kl: false
  use_kl_loss: true
  kl_penalty: low_var_kl
  kl_coef: 1.0e-2
  mock_data: test

worker:
  actor:
    global_batch_size: 128
    micro_batch_size_per_device_for_update: 2
    micro_batch_size_per_device_for_experience: 8
    max_grad_norm: 1.0
    padding_free: true
    ulysses_sequence_parallel_size: 1
    model:
      model_path: Qwen/Qwen2.5-7B-Instruct
      enable_gradient_checkpointing: true
      trust_remote_code: false
      freeze_vision_tower: false
    optim:
      lr: 1.0e-6
      weight_decay: 1.0e-2
      strategy: adamw  # {adamw, adamw_bf16}
      lr_warmup_ratio: 0.0
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: false
      enable_rank0_init: true
    offload:
      offload_params: true  # true: more CPU memory; false: more GPU memory
      offload_optimizer: true  # true: more CPU memory; false: more GPU memory

  rollout:
    n: 5
    temperature: 1.0
    top_p: 0.99
    gpu_memory_utilization: 0.7
    enforce_eager: false
    enable_chunked_prefill: false
    tensor_parallel_size: 2
    limit_images: 0
    val_override_config:
      temperature: 1.0
      n: 1

  ref:
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: true  # true: more CPU memory; false: more GPU memory
      enable_rank0_init: true
    offload:
      offload_params: true

  reward:
    reward_type: batch
    reward_function: ./examples/reward_function/math.py:compute_score

trainer:
  total_epochs: 2
  max_steps: null
  project_name: easy_r1
  experiment_name: qwen2_5_7b_math_grpo
  logger: ["console", "wandb"]
  nnodes: 1
  n_gpus_per_node: 8
  val_freq: 3  # -1 to disable
  val_before_train: true
  val_only: false
  val_generations_to_log: 3
  save_freq: 5  # -1 to disable
  save_limit: 3  # -1 to disable
  save_checkpoint_path: your_checkpoint_path
  load_checkpoint_path: null
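config.yaml follows EasyR1's example config and currently points worker.reward.reward_function at math.py; to train the questioner with the batch rewards from this upload, that field would presumably be switched to caller.py or caller_penalty.py. A hedged sketch of such an override (the target path is a guess at the repository layout, not confirmed by this upload):

import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Swap the batch reward from the plain math scorer to the questioner reward
# with the duplication penalty.
cfg["worker"]["reward"]["reward_type"] = "batch"
cfg["worker"]["reward"]["reward_function"] = "./examples/reward_function/caller_penalty.py:compute_score"

with open("config_questioner.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)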
math.py
ADDED
@@ -0,0 +1,49 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, List

from mathruler.grader import extract_boxed_content, grade_answer


def format_reward(predict: str) -> float:
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, predict)
    return 1.0 if format_match else 0.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict)
    try:
        return 1.0 if grade_answer(answer, ground_truth) else 0.0
    except Exception:
        return 0.0


def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
    scores = []
    for predict, ground_truth in zip(predicts, ground_truths):
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # handle qwen2.5vl-32b format
        format_score = format_reward(predict)
        accuracy_score = accuracy_reward(predict, ground_truth)
        scores.append(
            {
                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
                "format": format_score,
                "accuracy": accuracy_score,
            }
        )

    return scores
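A quick sanity check of the math reward. Because the file is named math.py, importing it as math would shadow the standard library, so this sketch loads it by path; the path and the expected numbers are illustrative assumptions:

import importlib.util

spec = importlib.util.spec_from_file_location("math_reward", "./examples/reward_function/math.py")
math_reward = importlib.util.module_from_spec(spec)
spec.loader.exec_module(math_reward)

predicts = [
    "<think>2 + 2 = 4.</think> The answer is \\boxed{4}.",
    "The answer is 4.",  # no <think> block and no \boxed{} -> format 0, accuracy 0
]
print(math_reward.compute_score(predicts, ["4", "4"], format_weight=0.1))
# e.g. [{'overall': 1.0, 'format': 1.0, 'accuracy': 1.0},
#       {'overall': 0.0, 'format': 0.0, 'accuracy': 0.0}]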
math_format.jinja
ADDED
@@ -0,0 +1 @@
{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
persona.jinja
ADDED
@@ -0,0 +1 @@
questioner_format_with_persona
questioner.jinja
ADDED
@@ -0,0 +1 @@
questioner_format
r1v.py
ADDED
@@ -0,0 +1,47 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict

from mathruler.grader import grade_answer


def format_reward(predict: str) -> float:
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    format_match = re.fullmatch(pattern, predict)
    return 1.0 if format_match else 0.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    try:
        content_match = re.search(r"<answer>(.*?)</answer>", predict)
        given_answer = content_match.group(1).strip() if content_match else predict.strip()
        if grade_answer(given_answer, ground_truth.strip()):
            return 1.0

    except Exception:
        pass

    return 0.0


def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
    format_score = format_reward(predict)
    accuracy_score = accuracy_reward(predict, ground_truth)
    return {
        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
        "format": format_score,
        "accuracy": accuracy_score,
    }
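Unlike the batch rewards above, r1v.py scores a single prediction at a time. A minimal usage sketch (assumes the module is importable as r1v; the printed values follow from the formulas above and are shown for illustration):

from r1v import compute_score

good = "<think>The area is 6 * 7 = 42.</think> <answer>42</answer>"
print(compute_score(good, "42"))
# e.g. {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}

bare = "42"  # no <think>/<answer> tags: format 0.0, accuracy 1.0, overall 0.5 with format_weight=0.5
print(compute_score(bare, "42"))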
r1v_format.jinja
ADDED
@@ -0,0 +1 @@
{{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
runtime_env.yaml
ADDED
@@ -0,0 +1,9 @@
working_dir: ./
excludes: ["/.git/"]
env_vars:
  TOKENIZERS_PARALLELISM: "true"
  NCCL_DEBUG: "WARN"
  VLLM_LOGGING_LEVEL: "WARN"
  TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
  PYTHONUNBUFFERED: "1"
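runtime_env.yaml is a Ray runtime environment spec (working directory, excluded paths, and environment variables propagated to every worker). The training launcher normally consumes it; if one wanted to apply it by hand, a sketch like the following should work (the manual ray.init call is an assumption, not how the trainer is started):

import yaml
import ray

# Load the spec and hand it to Ray so every worker inherits the env vars above.
with open("runtime_env.yaml") as f:
    runtime_env = yaml.safe_load(f)

ray.init(runtime_env=runtime_env)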
solver.jinja
ADDED
@@ -0,0 +1 @@
solver_format