Gatsby767 committed
Commit 74a5f2b · verified · 1 parent: 842120c

Upload 2 files

Files changed (2)
  1. question_generate.bash +15 -0
  2. question_generate.py +134 -0
question_generate.bash ADDED
#!/bin/bash
# Read the model name, number of samples, and output name from the command line.
model_name=$1
num_samples=$2
save_name=$3
export VLLM_DISABLE_COMPILE_CACHE=1
# Launch one generation worker per GPU in the background; each writes its own JSON shard.
CUDA_VISIBLE_DEVICES=0 python question_generate/question_generate.py --model "$model_name" --suffix 0 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=1 python question_generate/question_generate.py --model "$model_name" --suffix 1 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=2 python question_generate/question_generate.py --model "$model_name" --suffix 2 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=3 python question_generate/question_generate.py --model "$model_name" --suffix 3 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=4 python question_generate/question_generate.py --model "$model_name" --suffix 4 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=5 python question_generate/question_generate.py --model "$model_name" --suffix 5 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=6 python question_generate/question_generate.py --model "$model_name" --suffix 6 --num_samples "$num_samples" --save_name "$save_name" &
CUDA_VISIBLE_DEVICES=7 python question_generate/question_generate.py --model "$model_name" --suffix 7 --num_samples "$num_samples" --save_name "$save_name" &

# Block until all eight workers have finished.
wait
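
# Example invocation (illustrative values; the model and per-worker sample count
# mirror the Python script's defaults, so 8 workers yield 8 x 1250 completions):
#   bash question_generate.bash Qwen/Qwen3-4B 1250 my_run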
question_generate.py ADDED
import argparse
import json
import os
from typing import List

import regex as re
import torch
import vllm
from transformers import AutoTokenizer
from vllm.outputs import RequestOutput

from evaluation.datasets_loader import get_dataset_handler

STORAGE_PATH = os.getenv("STORAGE_PATH")

def extract_boxed(text):
    r"""Return the contents of every \boxed{...} in `text`, honoring nested braces."""
    results, i = [], 0
    prefix = r'\boxed{'
    plen = len(prefix)

    while True:
        start = text.find(prefix, i)
        if start == -1:
            break  # no more \boxed{...}

        # Scan forward with a brace-depth counter so nested {} stay inside one match.
        j = start + plen
        depth = 1
        while j < len(text) and depth:
            if text[j] == '{':
                depth += 1
            elif text[j] == '}':
                depth -= 1
            j += 1

        results.append(text[start + plen : j - 1])
        i = j

    return results
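
# Illustrative example: extract_boxed(r"hence \boxed{42} and \boxed{x^{2}}")
# returns ['42', 'x^{2}']; the depth counter keeps the nested braces of
# 'x^{2}' inside a single match, where a non-greedy regex would stop early.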

def get_response_mask(response_ids, eos_token_id, dtype):
    """Zero out every position from the first EOS token onward, row by row."""
    batch_size, seq_len = response_ids.shape
    mask = torch.ones((batch_size, seq_len), dtype=dtype)
    for i in range(batch_size):
        for j in range(seq_len):
            if response_ids[i][j] == eos_token_id:
                mask[i][j:] = 0
                break
    return mask
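
# Example: get_response_mask(torch.tensor([[5, 2, 7]]), eos_token_id=2, dtype=torch.long)
# returns tensor([[1, 0, 0]]): the first EOS position and everything after it are zeroed.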

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = vllm.LLM(
        model=args.model,
        tokenizer=args.model,
        # gpu_memory_utilization=0.8,
        seed=int(args.suffix),  # per-worker seed so the eight GPUs sample different questions
    )
    dataset_handler = get_dataset_handler("math")
    questions, answers = dataset_handler.load_data()
    # NOTE: these two values are currently unused; the parsing loop below reassigns them.
    question = questions[0]
    answer = answers[0]
    chat = [
        {
            "role": "system",
            "content": (
                "You are an expert competition-math problem setter.\n"
                "FIRST, in your private scratch-pad, think step-by-step to design a brand-new, non-trivial problem. "
                "The problem could come from any field of mathematics, including but not limited to algebra, geometry, number theory, combinatorics, prealgebra, probability, statistics, and calculus. "
                "Aim for a difficulty such that fewer than 30% of advanced high-school students could solve it. "
                "Avoid re-using textbook clichés or famous contest problems.\n"
                "THEN, without revealing any of your private thoughts, output **exactly** the following two blocks:\n\n"
                "<question>\n"
                "{The full problem statement on one or more lines}\n"
                "</question>\n\n"
                r"\boxed{final_answer}"
                "\n\n"
                "Do NOT output anything else—no explanations, no extra markup."
            )
        },
        {
            "role": "user",
            "content": (
                "Generate one new, challenging reasoning question now. "
                "Remember to format the output exactly as instructed."
            )
        }
    ]

    if tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
            add_special_tokens=True
        )
    else:
        # Fall back to a plain-text prompt for tokenizers without a chat template.
        prompt = "system: " + chat[0]["content"] + '\n' + "user: " + chat[1]["content"]
    sample_params = vllm.SamplingParams(
        max_tokens=4096,
        temperature=1.0,
        top_p=0.95,
        n=1,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    completions: List[RequestOutput] = model.generate([prompt] * args.num_samples, sampling_params=sample_params)
    results = []
    for completion in completions:
        response = completion.outputs[0].text
        try:
            questions = re.findall(r"<question>(.*?)</question>", response, re.DOTALL)
            answers = extract_boxed(response)

            if questions and answers:
                question = questions[-1].strip()
                answer = answers[-1].strip()
                results.append({"question": question, "answer": answer, "score": 0})
            else:
                # Malformed output: keep the raw response and flag it with score -1.
                results.append({"question": response, "answer": "", "score": -1})
        except Exception:
            results.append({"question": response, "answer": "", "score": -1})
    os.makedirs(f"{STORAGE_PATH}/generated_question", exist_ok=True)
    with open(f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json", "w") as f:
        json.dump(results, f, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B")
    parser.add_argument("--num_samples", type=int, default=1250, help="Number of samples to generate")
    parser.add_argument("--suffix", type=str, default="", help="Suffix for the output file; also used as the vLLM seed, so it must parse as an integer")
    parser.add_argument("--save_name", type=str, default="", help="Base name for the output JSON file")
    args = parser.parse_args()

    main(args)
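
Each worker writes its shard to $STORAGE_PATH/generated_question/{save_name}_{suffix}.json, so a full run of the bash launcher leaves eight shards behind. The commit does not include a merge step; a minimal sketch of one, assuming the shard layout above (the merged filename is illustrative):

import json
import os

STORAGE_PATH = os.getenv("STORAGE_PATH")
save_name = "my_run"  # illustrative; must match the --save_name passed to the workers

# Concatenate the eight per-GPU shards produced by question_generate.bash.
merged = []
for suffix in range(8):
    with open(f"{STORAGE_PATH}/generated_question/{save_name}_{suffix}.json") as f:
        merged.extend(json.load(f))

with open(f"{STORAGE_PATH}/generated_question/{save_name}_merged.json", "w") as f:
    json.dump(merged, f, indent=4)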