Model save
This view is limited to 50 files because it contains too many changes.
- .gitattributes +103 -0
- =0.41.0 +0 -0
- =0.6.0 +0 -0
- DSPy_Optimization.ipynb +415 -0
- InstructionFinetuning.ipynb +1277 -0
- README.md +58 -0
- SFT_Expert.py +229 -0
- TQ_template.py +37 -0
- TextGrad_Optimization.ipynb +544 -0
- adapter_config.json +45 -0
- adapter_model.safetensors +3 -0
- added_tokens.json +3 -0
- data_prep.py +380 -0
- gemma-12b-tq-model/README.md +58 -0
- gemma-12b-tq-model/adapter_config.json +45 -0
- gemma-12b-tq-model/adapter_model.safetensors +3 -0
- gemma-12b-tq-model/added_tokens.json +3 -0
- gemma-12b-tq-model/checkpoint-2/README.md +202 -0
- gemma-12b-tq-model/checkpoint-2/adapter_config.json +45 -0
- gemma-12b-tq-model/checkpoint-2/adapter_model.safetensors +3 -0
- gemma-12b-tq-model/checkpoint-2/added_tokens.json +3 -0
- gemma-12b-tq-model/checkpoint-2/optimizer.pt +3 -0
- gemma-12b-tq-model/checkpoint-2/rng_state.pth +3 -0
- gemma-12b-tq-model/checkpoint-2/scheduler.pt +3 -0
- gemma-12b-tq-model/checkpoint-2/special_tokens_map.json +33 -0
- gemma-12b-tq-model/checkpoint-2/tokenizer.json +3 -0
- gemma-12b-tq-model/checkpoint-2/tokenizer.model +3 -0
- gemma-12b-tq-model/checkpoint-2/tokenizer_config.json +0 -0
- gemma-12b-tq-model/checkpoint-2/trainer_state.json +51 -0
- gemma-12b-tq-model/checkpoint-2/training_args.bin +3 -0
- gemma-12b-tq-model/runs/Apr25_08-39-59_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570448.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
- gemma-12b-tq-model/runs/Apr25_08-42-29_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570563.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
- gemma-12b-tq-model/runs/Apr25_09-19-39_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745572788.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
- gemma-12b-tq-model/special_tokens_map.json +33 -0
- gemma-12b-tq-model/tokenizer.json +3 -0
- gemma-12b-tq-model/tokenizer.model +3 -0
- gemma-12b-tq-model/tokenizer_config.json +0 -0
- gemma-12b-tq-model/training_args.bin +3 -0
- gemma-1b-tq-model/README.md +58 -0
- gemma-1b-tq-model/adapter_config.json +42 -0
- gemma-1b-tq-model/adapter_model.safetensors +3 -0
- gemma-1b-tq-model/added_tokens.json +3 -0
- gemma-1b-tq-model/checkpoint-10/README.md +202 -0
- gemma-1b-tq-model/checkpoint-10/adapter_config.json +42 -0
- gemma-1b-tq-model/checkpoint-10/adapter_model.safetensors +3 -0
- gemma-1b-tq-model/checkpoint-10/added_tokens.json +3 -0
- gemma-1b-tq-model/checkpoint-10/optimizer.pt +3 -0
- gemma-1b-tq-model/checkpoint-10/rng_state.pth +3 -0
- gemma-1b-tq-model/checkpoint-10/scheduler.pt +3 -0
- gemma-1b-tq-model/checkpoint-10/special_tokens_map.json +33 -0
.gitattributes
CHANGED
@@ -33,3 +33,106 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+gemma-12b-tq-model/checkpoint-2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-12b-tq-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-100/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-11/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-12/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-14/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-16/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-18/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-20/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-22/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-24/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-26/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-28/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-30/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-32/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-33/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-34/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-36/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-38/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-4/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-40/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-42/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-44/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-46/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-48/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-5/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-50/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-52/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-54/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-56/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-58/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-6/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-60/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-62/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-64/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-66/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-68/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-70/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-72/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-74/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-76/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-78/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-8/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-80/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-82/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-84/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-86/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-88/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-90/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-92/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-94/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-96/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/checkpoint-98/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-1b-tq-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/checkpoint-129/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/checkpoint-131/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/checkpoint-133/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/checkpoint-256/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/checkpoint-264/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-27b-tq_sft_finetuned-model-full/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-1040/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-116/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-1170/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-124/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-125/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-130/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-1300/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-140/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-186/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-248/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-250/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-260/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-29/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-3/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-310/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-372/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-375/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-390/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-434/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-496/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-500/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-520/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-558/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-58/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-6/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-61/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-610/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-62/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-625/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-650/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-780/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-87/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-9/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/checkpoint-910/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+gemma-4b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+may13-gemma-27b-tq_sft_finetuned-model/checkpoint-80/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+may13-gemma-27b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+merged_model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
=0.41.0
ADDED
File without changes
=0.6.0
ADDED
File without changes
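(The two empty files `=0.41.0` and `=0.6.0` committed above are almost certainly shell artifacts rather than intentional content: running a command like `pip install peft>=0.6.0` without quoting the requirement makes the shell parse `>=0.6.0` as an output redirection and create a file with that name. Quoting the specifier, e.g. `pip install "peft>=0.6.0"`, avoids this.)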
DSPy_Optimization.ipynb
ADDED
@@ -0,0 +1,415 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8b3ee6e2-ca9c-40fa-b4c6-a9596f075f79",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-04-22T23:03:20.101831Z",
     "iopub.status.busy": "2025-04-22T23:03:20.101435Z",
     "iopub.status.idle": "2025-04-22T23:03:20.105088Z",
     "shell.execute_reply": "2025-04-22T23:03:20.104580Z",
     "shell.execute_reply.started": "2025-04-22T23:03:20.101804Z"
    }
   },
   "outputs": [],
   "source": [
    "import dspy\n",
    "from dspy.teleprompt import MIPROv2\n",
    "from typing import List, Dict\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4ec9a29b-9162-4fe3-b32d-4de4397c6483",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-04-22T23:00:21.439753Z",
     "iopub.status.busy": "2025-04-22T23:00:21.439342Z",
     "iopub.status.idle": "2025-04-22T23:00:21.526091Z",
     "shell.execute_reply": "2025-04-22T23:00:21.525575Z",
     "shell.execute_reply.started": "2025-04-22T23:00:21.439727Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:00<00:00, 77.75it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': {'src_text': 'Ma io che ne so, comandà? Io stavo a casa di mia madre, lo sapete.\\n\\nLo so.',\n",
       " 'tgt_text': \"What do I know, Commander? I was at my mom's house, you know it.\\n\\nI knows.\",\n",
       " 'src_prev': \"Questa è una linea. Qua faccio quello che voglio, è terra mia, la legge è mia. Dall'altro lato c'è un mondo fatto di spazzatura. Questa linea non l'ho mai oltrepassata. Impara chi è tua madre una volta per tutte. Tieni, questo era per te. Mà… Mà! Secondo me non è stata lei. Come al solito ti sei fatto prendere per il culo. Comandà, credo che non è stata lei. Carmine, sei uno stronzo. Robè, portalo via. Andiamocene.\",\n",
       " 'src_next': 'E allora che altro vi devo dire? Tu non devi dire niente. Devi tenere la bocca chiusa. E non dire a nessuno quello che ti ho detto. Ma a nessuno però. Ho capito. Però devi tenere le orecchie aperte e ascoltare tutto quello che si dice qua dentro. Perché prima o poi, chi fa queste cose parla. Si deve atteggiare, si deve fare grosso. Che si è divertito con la moglie del comandante. Secondo me vi sbagliate, comandà. Non può essere stato nessuno che sta qua dentro. Lo so.',\n",
       " 'tgt_prev': \"This is a line. Here I do whatever I want, it's my territory, it's my law. On the other side there's a world of trash. I've never crossed that line. Learn who your mother is, once and for all. Here, this was for you. Ma… Ma! I don't think it was her. As usual you let her fuck you around. Commander, I think she didn't do it. Carmine, you're an asshole. Robè, take him away. Let's go.\",\n",
       " 'tgt_next': \"So what else should I say? You don't have to say anything. You have to keep your mouth shut. And don't tell anybody what I told you. To nobody. Got it. But keep your ears open and listen to what they say in here. Because sooner or later, guys who do such things talk. They need to swagger, act like big guys. Bragging they had fun with the Commander's wife. I think you're wrong, Commander. It can't have been anyone who's in here. I know.\",\n",
       " 'src_lang': 'it',\n",
       " 'tgt_lang': 'en'},\n",
       " 'evaluation': {'Accuracy Issues': [],\n",
       " 'Readability Issues': [],\n",
       " 'Accuracy Score': '4',\n",
       " 'Readability Score': '4',\n",
       " 'Confidence Level': 'the_translation_is_excellent_without_any_error_spans_and_no_creative_liberties_were_taken',\n",
       " 'Main Vs Alternate': 'Alternate Translated Text has marginally better quality',\n",
       " 'Score': 32}}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_path = \"/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/\"\n",
    "json_files = [os.path.join(root, file) for root, _, files in os.walk(data_path) for file in files if file.endswith('.json') and 'PLDL' in file]\n",
    "\n",
    "training_samples = []\n",
    "for json_file in tqdm(json_files):\n",
    "    with open(json_file, 'r') as file:\n",
    "        data = json.load(file)\n",
    "        sampled_items = random.sample(data[\"data\"], 20)\n",
    "        training_samples.extend(sampled_items)\n",
    "\n",
    "datapoints = []\n",
    "\n",
    "for sample in training_samples:\n",
    "    datapoint = {\"input\":{}}\n",
    "    datapoint[\"input\"][\"src_text\"] = sample[\"main_src_text\"]\n",
    "    datapoint[\"input\"][\"tgt_text\"] = sample[\"tgt_text\"]\n",
    "    datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
    "    datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
    "    datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
    "    datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
    "    datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
    "    datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
    "    datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
    "    datapoints.append(datapoint)\n",
    "\n",
    "datapoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "bde34303-2f52-415f-b117-264e266b84f0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-04-22T23:04:16.302953Z",
     "iopub.status.busy": "2025-04-22T23:04:16.302402Z",
     "iopub.status.idle": "2025-04-22T23:04:16.334330Z",
     "shell.execute_reply": "2025-04-22T23:04:16.333644Z",
     "shell.execute_reply.started": "2025-04-22T23:04:16.302928Z"
    }
   },
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'dspy' has no attribute 'Predictor'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[35], line 28\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prediction\u001b[38;5;241m.\u001b[39mevaluation\n\u001b[1;32m     27\u001b[0m \u001b[38;5;66;03m# Create a custom predictor using your Netflix model\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mNetflixPredictor\u001b[39;00m(\u001b[43mdspy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPredictor\u001b[49m):\n\u001b[1;32m     29\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, model):\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m model\n",
      "\u001b[0;31mAttributeError\u001b[0m: module 'dspy' has no attribute 'Predictor'"
     ]
    }
   ],
   "source": [
    "class TranslationQualityChecker(dspy.Signature):\n",
    "    \"\"\"Evaluate the quality of translation.\"\"\"\n",
    "    \n",
    "    context = dspy.InputField(desc=\"Source and target text with context\")\n",
    "    evaluation = dspy.OutputField(desc=\"Detailed evaluation of the translation quality\")\n",
    "\n",
    "class TranslationQualityModule(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.checker = dspy.Predict(TranslationQualityChecker)\n",
    "    \n",
    "    def forward(self, src_text, tgt_text, src_prev, tgt_prev, src_next, tgt_next, src_lang, tgt_lang):\n",
    "        context = {\n",
    "            \"source_text\": src_text,\n",
    "            \"target_text\": tgt_text,\n",
    "            \"source_previous\": src_prev,\n",
    "            \"target_previous\": tgt_prev,\n",
    "            \"source_next\": src_next,\n",
    "            \"target_next\": tgt_next,\n",
    "            \"source_language\": src_lang,\n",
    "            \"target_language\": tgt_lang\n",
    "        }\n",
    "        \n",
    "        prediction = self.checker(context=context)\n",
    "        return prediction.evaluation\n",
    "\n",
    "# Create a custom backend using your Netflix model\n",
    "class NetflixBackend(dspy.BackendBase):\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "    \n",
    "    def complete(self, prompt, **kwargs):\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "        response = self.model.generate(messages)\n",
    "        return response\n",
    "\n",
    "    def completions(self, prompts, **kwargs):\n",
    "        return [self.complete(prompt, **kwargs) for prompt in prompts]\n",
    "\n",
    "# Prepare training data\n",
    "def prepare_training_data(data_points):\n",
    "    compiled_data = []\n",
    "    for dp in data_points:\n",
    "        input_data = dp['input']\n",
    "        train_example = dspy.Example(\n",
    "            context={\n",
    "                \"source_text\": input_data['src_text'],\n",
    "                \"target_text\": input_data['tgt_text'],\n",
    "                \"source_previous\": input_data['src_prev'],\n",
    "                \"target_previous\": input_data['tgt_prev'],\n",
    "                \"source_next\": input_data['src_next'],\n",
    "                \"target_next\": input_data['tgt_next'],\n",
    "                \"source_language\": input_data['src_lang'],\n",
    "                \"target_language\": input_data['tgt_lang']\n",
    "            },\n",
    "            evaluation=dp['evaluation']\n",
    "        )\n",
    "        compiled_data.append(train_example)\n",
    "    return compiled_data\n",
    "\n",
    "def optimize_prompt(model, training_data, validation_data):\n",
    "    # Initialize DSPy with your custom backend\n",
    "    backend = NetflixBackend(model)\n",
    "    dspy.settings.configure(lm=backend)\n",
    "    \n",
    "    # Create the optimizer\n",
    "    optimizer = MIPROv2(\n",
    "        metric=\"exact_match\",  # or another appropriate metric\n",
    "        max_rounds=5,\n",
    "        max_prompts=3,\n",
    "        temp=0.7\n",
    "    )\n",
    "    \n",
    "    # Compile the module\n",
    "    translation_module = TranslationQualityModule()\n",
    "    \n",
    "    # Optimize the prompt\n",
    "    optimized_module = optimizer.optimize(\n",
    "        module=translation_module,\n",
    "        trainset=training_data,\n",
    "        valset=validation_data,\n",
    "        metric=dspy.evaluate.answer_exact_match\n",
    "    )\n",
    "    \n",
    "    return optimized_module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a4583f-162c-4e2d-b061-798f6c676a28",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TranslationQualityAssessor(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n",
    "\n",
    "    def forward(self, src_lang, tgt_lang, src_text, translation, src_prev=\"\", tgt_prev=\"\", src_next=\"\", tgt_next=\"\"):\n",
    "        context = f\"\"\"Previous Context:\n",
    "        Source: {src_prev}\n",
    "        Translation: {tgt_prev}\n",
    "        \n",
    "        Next Context:\n",
    "        Source: {src_next}\n",
    "        Translation: {tgt_next}\"\"\"\n",
    "\n",
    "        result = self.assess(\n",
    "            context=context,\n",
    "            source=f\"Source ({src_lang}): {src_text}\",\n",
    "            translation=f\"Translation ({tgt_lang}): {translation}\"\n",
    "        )\n",
    "        \n",
    "        return result.evaluation\n",
    "\n",
    "class TranslationMetrics:\n",
    "    @staticmethod\n",
    "    def exact_match_score(pred, gold):\n",
    "        try:\n",
    "            pred_json = json.loads(pred)\n",
    "            gold_json = gold\n",
    "            \n",
    "            accuracy_match = (str(pred_json.get('Accuracy Score')) == str(gold_json.get('Accuracy Score')))\n",
    "            readability_match = (str(pred_json.get('Readability Score')) == str(gold_json.get('Readability Score')))\n",
    "            \n",
    "            return (accuracy_match and readability_match)\n",
    "        except:\n",
    "            return False\n",
    "    \n",
    "    @staticmethod\n",
    "    def partial_match_score(pred, gold):\n",
    "        try:\n",
    "            pred_json = json.loads(pred)\n",
    "            gold_json = gold\n",
    "            \n",
    "            # Score comparison\n",
    "            accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))\n",
    "            readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))\n",
    "            \n",
    "            # Issues comparison\n",
    "            pred_accuracy_issues = set(str(issue) for issue in pred_json.get('Accuracy Issues', []))\n",
    "            gold_accuracy_issues = set(str(issue) for issue in gold_json.get('Accuracy Issues', []))\n",
    "            pred_readability_issues = set(str(issue) for issue in pred_json.get('Readability Issues', []))\n",
    "            gold_readability_issues = set(str(issue) for issue in gold_json.get('Readability Issues', []))\n",
    "            \n",
    "            # Calculate Jaccard similarity for issues\n",
    "            accuracy_issues_sim = len(pred_accuracy_issues & gold_accuracy_issues) / max(1, len(pred_accuracy_issues | gold_accuracy_issues))\n",
    "            readability_issues_sim = len(pred_readability_issues & gold_readability_issues) / max(1, len(pred_readability_issues | gold_readability_issues))\n",
    "            \n",
    "            # Combine scores (0.6 weight to scores, 0.4 to issues similarity)\n",
    "            score_component = 1 - ((accuracy_diff + readability_diff) / 8)\n",
    "            issues_component = (accuracy_issues_sim + readability_issues_sim) / 2\n",
    "            \n",
    "            final_score = 0.6 * score_component + 0.4 * issues_component\n",
    "            return max(0, final_score)\n",
    "        except:\n",
    "            return 0\n",
    "\n",
    "def prepare_dataset(file_path):\n",
    "    with open(file_path, 'r') as f:\n",
    "        data = json.load(f)\n",
    "    \n",
    "    prepared_data = []\n",
    "    \n",
    "    for item in data:\n",
    "        example = dspy.Example(\n",
    "            context=f\"\"\"Previous Context:\n",
    "            Source: {item['src_prev']}\n",
    "            Translation: {item['tgt_prev']}\n",
    "            \n",
    "            Next Context:\n",
    "            Source: {item['src_next']}\n",
    "            Translation: {item['tgt_next']}\"\"\",\n",
    "            source=f\"Source ({item['src_lang']}): {item['src_text']}\",\n",
    "            translation=f\"Translation ({item['tgt_lang']}): {item['main_text']}\",\n",
    "            evaluation=json.dumps(item['evaluation'], ensure_ascii=False)\n",
    "        ).with_inputs(\"context\", \"source\", \"translation\")\n",
    "        \n",
    "        prepared_data.append(example)\n",
    "    \n",
    "    # Split data: 70% train, 15% dev, 15% test\n",
    "    train_size = int(0.7 * len(prepared_data))\n",
    "    dev_size = int(0.15 * len(prepared_data))\n",
    "    \n",
    "    train_data = prepared_data[:train_size]\n",
    "    dev_data = prepared_data[train_size:train_size + dev_size]\n",
    "    test_data = prepared_data[train_size + dev_size:]\n",
    "    \n",
    "    return train_data, dev_data, test_data\n",
    "\n",
    "def optimize_translation_quality_assessment():\n",
    "    # Initialize DSPy\n",
    "    lm = TranslationQualityLM()\n",
    "    dspy.settings.configure(lm=lm)\n",
    "    \n",
    "    # Load and prepare dataset\n",
    "    train_data, dev_data, test_data = prepare_dataset('translation_quality_dataset.json')\n",
    "    \n",
    "    # Create evaluator\n",
    "    evaluator = Evaluate(\n",
    "        metrics={\n",
    "            'exact_match': TranslationMetrics.exact_match_score,\n",
    "            'partial_match': TranslationMetrics.partial_match_score\n",
    "        }\n",
    "    )\n",
    "    \n",
    "    # Initialize module\n",
    "    assessor = TranslationQualityAssessor()\n",
    "    \n",
    "    # Initialize MIPROv2 optimizer\n",
    "    optimizer = dspy.MIPROv2(\n",
    "        metric=lambda x: x['partial_match'],\n",
    "        max_rounds=5,  # Number of optimization rounds\n",
    "        max_traces=10,  # Number of traces per round\n",
    "        max_depth=3,  # Maximum depth of reasoning chains\n",
    "        num_candidate_prompts=5,  # Number of candidate prompts to generate\n",
    "        num_rounds_per_prompt=3,  # Number of rounds per candidate prompt\n",
    "        temperature=0.7,\n",
    "        verbose=True\n",
    "    )\n",
    "    \n",
    "    # Compile the module with optimization\n",
    "    compiled_assessor = optimizer.compile(\n",
    "        assessor,\n",
    "        trainset=train_data,\n",
    "        devset=dev_data,\n",
    "        eval_kwargs={\n",
    "            'metric': 'partial_match',\n",
    "            'num_threads': 4,\n",
    "            'batch_size': 8\n",
    "        }\n",
    "    )\n",
    "    \n",
    "    # Evaluate on test set\n",
    "    results = []\n",
    "    for example in test_data:\n",
    "        pred = compiled_assessor(\n",
    "            context=example.context,\n",
    "            source=example.source,\n",
    "            translation=example.translation\n",
    "        )\n",
    "        \n",
    "        result = evaluator.evaluate(\n",
    "            predictions=[pred],\n",
    "            ground_truth=[example.evaluation]\n",
    "        )\n",
    "        results.append(result)\n",
    "    \n",
    "    # Calculate and print final metrics\n",
    "    avg_exact_match = np.mean([r['exact_match'] for r in results])\n",
    "    avg_partial_match = np.mean([r['partial_match'] for r in results])\n",
    "    \n",
    "    print(f\"Average Exact Match Score: {avg_exact_match:.3f}\")\n",
    "    print(f\"Average Partial Match Score: {avg_partial_match:.3f}\")\n",
    "    \n",
    "    return compiled_assessor\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    optimized_assessor = optimize_translation_quality_assessment()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "timedlibs",
   "language": "python",
   "name": "timedlibs"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
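A note on the error captured in the cell above: the traceback (`module 'dspy' has no attribute 'Predictor'`) comes from subclassing a class that DSPy does not export; the same notebook's `TranslationQualityModule` already uses the supported pattern. A minimal sketch of the fix, reusing the notebook's signature and its hypothetical `NetflixPredictor` name for continuity (DSPy exposes `dspy.Signature`, `dspy.Predict`, and `dspy.Module`, but no `dspy.Predictor` base class):

import dspy

class TranslationQualityChecker(dspy.Signature):
    """Evaluate the quality of a translation."""
    context = dspy.InputField(desc="Source and target text with context")
    evaluation = dspy.OutputField(desc="Detailed evaluation of the translation quality")

# Custom predictors are dspy.Module subclasses that delegate to dspy.Predict;
# there is no dspy.Predictor base class to inherit from.
class NetflixPredictor(dspy.Module):
    def __init__(self):
        super().__init__()
        self.checker = dspy.Predict(TranslationQualityChecker)

    def forward(self, context):
        return self.checker(context=context).evaluation

Whether `dspy.BackendBase` (used a few lines later) exists depends on the installed DSPy version; wiring the custom language model through `dspy.settings.configure(lm=...)` is the part the optimizer relies on.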
InstructionFinetuning.ipynb
ADDED
@@ -0,0 +1,1277 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e6d20008-a91c-4618-baa0-5991e031f1bd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:48:57.985184Z",
     "iopub.status.busy": "2025-05-13T21:48:57.984795Z",
     "iopub.status.idle": "2025-05-13T21:51:48.369715Z",
     "shell.execute_reply": "2025-05-13T21:51:48.368907Z",
     "shell.execute_reply.started": "2025-05-13T21:48:57.985144Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq\n",
    "import torch\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "from datasets import Dataset, DatasetDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "67f95fc8-a9d8-48cf-a551-7d30781cdb55",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:53:58.075473Z",
     "iopub.status.busy": "2025-05-13T21:53:58.074767Z",
     "iopub.status.idle": "2025-05-13T21:53:58.767860Z",
     "shell.execute_reply": "2025-05-13T21:53:58.767319Z",
     "shell.execute_reply.started": "2025-05-13T21:53:58.075446Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 22.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['messages'],\n",
      "        num_rows: 309\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['messages'],\n",
      "        num_rows: 343\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "data_path = (\n",
    "    \"/root/notebooks/MT_TQ/Caches/May2025/tquality.annotated.data/parsed/pldl/\"\n",
    ")\n",
    "\n",
    "json_files = [\n",
    "    os.path.join(root, file)\n",
    "    for root, _, files in os.walk(data_path)\n",
    "    for file in files\n",
    "    if file.endswith(\".json\")\n",
    "]\n",
    "\n",
    "training_samples = []\n",
    "testing_samples = []\n",
    "\n",
    "for json_file in tqdm(json_files):\n",
    "    with open(json_file, \"r\") as file:\n",
    "        data = json.load(file)\n",
    "        sampled_items = data[\"data\"]\n",
    "        if \"test\" in json_file:\n",
    "            testing_samples.extend(sampled_items)\n",
    "        if \"train\" in json_file:\n",
    "            training_samples.extend(sampled_items)\n",
    "\n",
    "training_datapoints = []\n",
    "testing_datapoints = []\n",
    "\n",
    "for idx, sample in enumerate(training_samples):\n",
    "    datapoint = {\"input\": {}}\n",
    "    datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n",
    "    datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n",
    "    datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
    "    datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
    "    datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
    "    datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
    "    datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
    "    datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
    "    datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n",
    "    datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n",
    "    datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n",
    "    datapoint[\"input\"][\"alt_tgt_text\"] = sample[\"alt_tgt_text\"]\n",
    "    datapoint[\"input\"][\"id\"] = idx\n",
    "    datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
    "    training_datapoints.append(datapoint)\n",
    "\n",
    "for idx, sample in enumerate(testing_samples):\n",
    "    datapoint = {\"input\": {}}\n",
    "    datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n",
    "    datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n",
    "    datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
    "    datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
    "    datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
    "    datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
    "    datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
    "    datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
    "    datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n",
    "    datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n",
    "    datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n",
    "    datapoint[\"input\"][\"alt_tgt_text\"] = sample[\"alt_tgt_text\"]\n",
    "    datapoint[\"input\"][\"id\"] = idx\n",
    "    datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
    "    testing_datapoints.append(datapoint)\n",
    "\n",
    "system_message = \"You are a helpful assistant who is an expert in estimating quality of translations.\"\n",
    "\n",
    "output_template = '''\n",
    "{\n",
    "    \"Accuracy Issues\": [\n",
    "        {\n",
    "            \"Error Span\": \"\",\n",
    "            \"Error Explanation\": \"\",\n",
    "            \"Error Quality Category\": \"\",\n",
    "            \"Error Quality Tags\": [],\n",
    "            \"Error Severity\": \"\"\n",
    "        }\n",
    "    ],\n",
    "    \"Accuracy Score\": \"\",\n",
    "    \"Readability Issues\": [\n",
    "        {\n",
    "            \"Error Span\": \"\",\n",
    "            \"Error Explanation\": \"\",\n",
    "            \"Error Quality Category\": \"\",\n",
    "            \"Error Quality Tags\": [],\n",
    "            \"Error Severity\": \"\"\n",
    "        }\n",
    "    ],\n",
    "    \"Readability Score\": \"\"\n",
    "}'''\n",
    "\n",
    "def create_conversation(input_sample, output_sample):\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            # {\"role\": \"system\", \"content\": system_message},\n",
    "            {\"role\": \"user\", \"content\": input_sample},\n",
    "            {\"role\": \"assistant\", \"content\": output_sample}\n",
    "        ]\n",
    "    }\n",
    "\n",
    "def create_dataset(datapoints, template_string):\n",
    "    dataset = []\n",
    "    meta = []\n",
    "    for datapoint in datapoints:\n",
    "        src_text = datapoint['input']['src_text']\n",
    "        tgt_text = datapoint['input']['tgt_text']\n",
    "        src_prev = datapoint['input']['src_prev']\n",
    "        src_next = datapoint['input']['src_next']\n",
    "        tgt_prev = datapoint['input']['tgt_prev']\n",
    "        tgt_next = datapoint['input']['tgt_next']\n",
    "        src_lang = datapoint['input']['src_lang']\n",
    "        tgt_lang = datapoint['input']['tgt_lang']\n",
    "        \n",
    "        start_frame = datapoint['input']['start_frame']\n",
    "        end_frame = datapoint['input']['end_frame']\n",
    "        title_id = datapoint['input']['title_id']\n",
    "        output = datapoint['evaluation']\n",
    "        idx = datapoint['input']['id']\n",
    "        if len(output['Accuracy Issues']) != 0 or len(output['Readability Issues']) != 0:\n",
    "            item = template_string.format(src_text=src_text, tgt_text=tgt_text,\n",
    "                                          src_prev=src_prev, src_next=src_next,\n",
    "                                          tgt_prev=tgt_prev, tgt_next=tgt_next,\n",
    "                                          src_lang=src_lang, tgt_lang=tgt_lang,\n",
    "                                          template=output_template)\n",
    "            \n",
    "            dataset.append(create_conversation(item, json.dumps(output)))\n",
    "            meta.append({\"id\": idx, \"start_frame\": start_frame, \"end_frame\": end_frame, \"title_id\": title_id})\n",
    "    \n",
    "    return dataset, meta\n",
    "    \n",
    "def dataset_prep(datapoints):\n",
    "    with open(\"prompts.txt\") as file:\n",
    "        template_string = file.read()\n",
    "    dataset, meta = create_dataset(datapoints, template_string)\n",
    "    return dataset, meta\n",
    "\n",
    "train_dataset, train_meta = dataset_prep(training_datapoints)\n",
    "test_dataset, test_meta = dataset_prep(testing_datapoints)\n",
    "\n",
    "dataset = {\"train\": train_dataset, \"test\": test_dataset}\n",
    "\n",
    "def convert_to_hf_dataset(dataset):\n",
    "    train_dataset = Dataset.from_list(dataset['train'])\n",
    "    test_dataset = Dataset.from_list(dataset['test'])\n",
    "    \n",
    "    hf_dataset = DatasetDict({\n",
    "        'train': train_dataset,\n",
    "        'test': test_dataset\n",
    "    })\n",
    "    \n",
    "    return hf_dataset\n",
    "\n",
    "hf_dataset = convert_to_hf_dataset(dataset)\n",
    "print(hf_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8b52f143-1077-4da6-ac92-b1dce5cdc17c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:54:12.568533Z",
     "iopub.status.busy": "2025-05-13T21:54:12.568078Z",
     "iopub.status.idle": "2025-05-13T21:54:49.724121Z",
     "shell.execute_reply": "2025-05-13T21:54:49.723481Z",
     "shell.execute_reply.started": "2025-05-13T21:54:12.568507Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 12/12 [00:18<00:00,  1.58s/it]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n",
    "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Hugging Face model id\n",
    "model_id = \"google/gemma-3-27b-it\"  # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`\n",
    "\n",
    "# Select model class based on id\n",
    "if model_id == \"google/gemma-3-27b-it\":\n",
    "    model_class = Gemma3ForConditionalGeneration\n",
    "else:\n",
    "    model_class = AutoModelForImageTextToText\n",
    "\n",
    "torch_dtype = torch.bfloat16\n",
    "\n",
    "model_kwargs = dict(\n",
    "    attn_implementation=\"eager\",\n",
    "    torch_dtype=torch_dtype,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "model = model_class.from_pretrained(model_id, **model_kwargs)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-3-27b-it\")  # Load the Instruction Tokenizer to use the official Gemma template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8443dfd8-6193-480c-9937-f6e0c43a9f56",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:55:12.713958Z",
     "iopub.status.busy": "2025-05-13T21:55:12.713495Z",
     "iopub.status.idle": "2025-05-13T21:55:12.717707Z",
     "shell.execute_reply": "2025-05-13T21:55:12.717199Z",
     "shell.execute_reply.started": "2025-05-13T21:55:12.713930Z"
    }
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    lora_alpha=128,\n",
    "    lora_dropout=0.05,\n",
    "    r=16,\n",
    "    bias=\"none\",\n",
    "    target_modules=\"all-linear\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    modules_to_save=[\"lm_head\", \"embed_tokens\"]  # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8f2b8371-ba1b-44ff-9462-d0c90335f82a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:55:22.076515Z",
     "iopub.status.busy": "2025-05-13T21:55:22.076029Z",
     "iopub.status.idle": "2025-05-13T21:55:22.783524Z",
     "shell.execute_reply": "2025-05-13T21:55:22.782937Z",
     "shell.execute_reply.started": "2025-05-13T21:55:22.076489Z"
    }
   },
   "outputs": [],
   "source": [
    "from trl import SFTConfig\n",
    "\n",
    "args = SFTConfig(\n",
    "    output_dir=\"may13-gemma-27b-tq_sft_finetuned-model\",\n",
    "    max_seq_length=2048,\n",
    "    packing=True,\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=4,\n",
    "    gradient_checkpointing=True,\n",
    "    optim=\"adamw_torch_fused\",\n",
    "    logging_steps=1,\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=1e-4,\n",
    "    fp16=True if torch_dtype == torch.float16 else False,\n",
    "    bf16=True if torch_dtype == torch.bfloat16 else False,\n",
    "    max_grad_norm=0.3,\n",
    "    warmup_ratio=0.03,\n",
    "    lr_scheduler_type=\"constant\",\n",
    "    push_to_hub=True,\n",
    "    report_to=\"tensorboard\",\n",
    "    dataset_kwargs={\n",
    "        \"add_special_tokens\": False,\n",
    "        \"append_concat_token\": True,\n",
    "    },\n",
    "    no_cuda=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2be55b87-70c9-4973-b0db-33154c272e47",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T21:55:25.765385Z",
     "iopub.status.busy": "2025-05-13T21:55:25.764949Z",
     "iopub.status.idle": "2025-05-13T21:55:36.592163Z",
     "shell.execute_reply": "2025-05-13T21:55:36.591614Z",
     "shell.execute_reply.started": "2025-05-13T21:55:25.765360Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Converting train dataset to ChatML: 100%|██████████| 309/309 [00:00<00:00, 9533.70 examples/s]\n",
      "Applying chat template to train dataset: 100%|██████████| 309/309 [00:00<00:00, 4443.06 examples/s]\n",
      "Tokenizing train dataset: 100%|██████████| 309/309 [00:01<00:00, 226.22 examples/s]\n",
      "Packing train dataset: 100%|██████████| 309/309 [00:00<00:00, 102364.74 examples/s]\n",
      "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
     ]
    }
   ],
   "source": [
    "from trl import SFTTrainer\n",
    "\n",
    "# Create Trainer object\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=hf_dataset[\"train\"],\n",
    "    peft_config=peft_config,\n",
    "    processing_class=tokenizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d8d82767-27ed-48ed-ad22-3f3cf2dff15e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T22:00:25.107226Z",
     "iopub.status.busy": "2025-05-13T22:00:25.106569Z",
     "iopub.status.idle": "2025-05-13T22:27:35.945604Z",
     "shell.execute_reply": "2025-05-13T22:27:35.944775Z",
     "shell.execute_reply.started": "2025-05-13T22:00:25.107196Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='80' max='80' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [80/80 22:44, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>10.801900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>8.381400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>6.970200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>5.784300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>4.970800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>4.389700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>4.325000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.557000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>3.357700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.092500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>3.170300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>2.648500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>3.067800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>2.377100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>2.847700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>2.628800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>2.630800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>2.820900</td>\n",
503 |
+
" <tr>\n",
|
504 |
+
" <td>19</td>\n",
|
505 |
+
" <td>2.596700</td>\n",
|
506 |
+
" </tr>\n",
|
507 |
+
" <tr>\n",
|
508 |
+
" <td>20</td>\n",
|
509 |
+
" <td>2.675300</td>\n",
|
510 |
+
" </tr>\n",
|
511 |
+
" <tr>\n",
|
512 |
+
" <td>21</td>\n",
|
513 |
+
" <td>2.846300</td>\n",
|
514 |
+
" </tr>\n",
|
515 |
+
" <tr>\n",
|
516 |
+
" <td>22</td>\n",
|
517 |
+
" <td>2.706700</td>\n",
|
518 |
+
" </tr>\n",
|
519 |
+
" <tr>\n",
|
520 |
+
" <td>23</td>\n",
|
521 |
+
" <td>2.645100</td>\n",
|
522 |
+
" </tr>\n",
|
523 |
+
" <tr>\n",
|
524 |
+
" <td>24</td>\n",
|
525 |
+
" <td>2.214600</td>\n",
|
526 |
+
" </tr>\n",
|
527 |
+
" <tr>\n",
|
528 |
+
" <td>25</td>\n",
|
529 |
+
" <td>2.790700</td>\n",
|
530 |
+
" </tr>\n",
|
531 |
+
" <tr>\n",
|
532 |
+
" <td>26</td>\n",
|
533 |
+
" <td>2.640700</td>\n",
|
534 |
+
" </tr>\n",
|
535 |
+
" <tr>\n",
|
536 |
+
" <td>27</td>\n",
|
537 |
+
" <td>2.908900</td>\n",
|
538 |
+
" </tr>\n",
|
539 |
+
" <tr>\n",
|
540 |
+
" <td>28</td>\n",
|
541 |
+
" <td>2.690400</td>\n",
|
542 |
+
" </tr>\n",
|
543 |
+
" <tr>\n",
|
544 |
+
" <td>29</td>\n",
|
545 |
+
" <td>2.807200</td>\n",
|
546 |
+
" </tr>\n",
|
547 |
+
" <tr>\n",
|
548 |
+
" <td>30</td>\n",
|
549 |
+
" <td>2.713600</td>\n",
|
550 |
+
" </tr>\n",
|
551 |
+
" <tr>\n",
|
552 |
+
" <td>31</td>\n",
|
553 |
+
" <td>2.563200</td>\n",
|
554 |
+
" </tr>\n",
|
555 |
+
" <tr>\n",
|
556 |
+
" <td>32</td>\n",
|
557 |
+
" <td>2.412700</td>\n",
|
558 |
+
" </tr>\n",
|
559 |
+
" <tr>\n",
|
560 |
+
" <td>33</td>\n",
|
561 |
+
" <td>2.627700</td>\n",
|
562 |
+
" </tr>\n",
|
563 |
+
" <tr>\n",
|
564 |
+
" <td>34</td>\n",
|
565 |
+
" <td>2.431800</td>\n",
|
566 |
+
" </tr>\n",
|
567 |
+
" <tr>\n",
|
568 |
+
" <td>35</td>\n",
|
569 |
+
" <td>2.240600</td>\n",
|
570 |
+
" </tr>\n",
|
571 |
+
" <tr>\n",
|
572 |
+
" <td>36</td>\n",
|
573 |
+
" <td>2.650300</td>\n",
|
574 |
+
" </tr>\n",
|
575 |
+
" <tr>\n",
|
576 |
+
" <td>37</td>\n",
|
577 |
+
" <td>2.014900</td>\n",
|
578 |
+
" </tr>\n",
|
579 |
+
" <tr>\n",
|
580 |
+
" <td>38</td>\n",
|
581 |
+
" <td>2.463100</td>\n",
|
582 |
+
" </tr>\n",
|
583 |
+
" <tr>\n",
|
584 |
+
" <td>39</td>\n",
|
585 |
+
" <td>2.283300</td>\n",
|
586 |
+
" </tr>\n",
|
587 |
+
" <tr>\n",
|
588 |
+
" <td>40</td>\n",
|
589 |
+
" <td>2.450500</td>\n",
|
590 |
+
" </tr>\n",
|
591 |
+
" <tr>\n",
|
592 |
+
" <td>41</td>\n",
|
593 |
+
" <td>2.570400</td>\n",
|
594 |
+
" </tr>\n",
|
595 |
+
" <tr>\n",
|
596 |
+
" <td>42</td>\n",
|
597 |
+
" <td>2.550500</td>\n",
|
598 |
+
" </tr>\n",
|
599 |
+
" <tr>\n",
|
600 |
+
" <td>43</td>\n",
|
601 |
+
" <td>2.530600</td>\n",
|
602 |
+
" </tr>\n",
|
603 |
+
" <tr>\n",
|
604 |
+
" <td>44</td>\n",
|
605 |
+
" <td>2.551400</td>\n",
|
606 |
+
" </tr>\n",
|
607 |
+
" <tr>\n",
|
608 |
+
" <td>45</td>\n",
|
609 |
+
" <td>2.383000</td>\n",
|
610 |
+
" </tr>\n",
|
611 |
+
" <tr>\n",
|
612 |
+
" <td>46</td>\n",
|
613 |
+
" <td>2.550500</td>\n",
|
614 |
+
" </tr>\n",
|
615 |
+
" <tr>\n",
|
616 |
+
" <td>47</td>\n",
|
617 |
+
" <td>2.575900</td>\n",
|
618 |
+
" </tr>\n",
|
619 |
+
" <tr>\n",
|
620 |
+
" <td>48</td>\n",
|
621 |
+
" <td>2.494300</td>\n",
|
622 |
+
" </tr>\n",
|
623 |
+
" <tr>\n",
|
624 |
+
" <td>49</td>\n",
|
625 |
+
" <td>2.387200</td>\n",
|
626 |
+
" </tr>\n",
|
627 |
+
" <tr>\n",
|
628 |
+
" <td>50</td>\n",
|
629 |
+
" <td>2.318800</td>\n",
|
630 |
+
" </tr>\n",
|
631 |
+
" <tr>\n",
|
632 |
+
" <td>51</td>\n",
|
633 |
+
" <td>2.365200</td>\n",
|
634 |
+
" </tr>\n",
|
635 |
+
" <tr>\n",
|
636 |
+
" <td>52</td>\n",
|
637 |
+
" <td>2.190100</td>\n",
|
638 |
+
" </tr>\n",
|
639 |
+
" <tr>\n",
|
640 |
+
" <td>53</td>\n",
|
641 |
+
" <td>2.419100</td>\n",
|
642 |
+
" </tr>\n",
|
643 |
+
" <tr>\n",
|
644 |
+
" <td>54</td>\n",
|
645 |
+
" <td>2.290900</td>\n",
|
646 |
+
" </tr>\n",
|
647 |
+
" <tr>\n",
|
648 |
+
" <td>55</td>\n",
|
649 |
+
" <td>2.152500</td>\n",
|
650 |
+
" </tr>\n",
|
651 |
+
" <tr>\n",
|
652 |
+
" <td>56</td>\n",
|
653 |
+
" <td>2.398700</td>\n",
|
654 |
+
" </tr>\n",
|
655 |
+
" <tr>\n",
|
656 |
+
" <td>57</td>\n",
|
657 |
+
" <td>2.982500</td>\n",
|
658 |
+
" </tr>\n",
|
659 |
+
" <tr>\n",
|
660 |
+
" <td>58</td>\n",
|
661 |
+
" <td>2.380200</td>\n",
|
662 |
+
" </tr>\n",
|
663 |
+
" <tr>\n",
|
664 |
+
" <td>59</td>\n",
|
665 |
+
" <td>2.357500</td>\n",
|
666 |
+
" </tr>\n",
|
667 |
+
" <tr>\n",
|
668 |
+
" <td>60</td>\n",
|
669 |
+
" <td>2.386300</td>\n",
|
670 |
+
" </tr>\n",
|
671 |
+
" <tr>\n",
|
672 |
+
" <td>61</td>\n",
|
673 |
+
" <td>2.741300</td>\n",
|
674 |
+
" </tr>\n",
|
675 |
+
" <tr>\n",
|
676 |
+
" <td>62</td>\n",
|
677 |
+
" <td>2.850300</td>\n",
|
678 |
+
" </tr>\n",
|
679 |
+
" <tr>\n",
|
680 |
+
" <td>63</td>\n",
|
681 |
+
" <td>2.682100</td>\n",
|
682 |
+
" </tr>\n",
|
683 |
+
" <tr>\n",
|
684 |
+
" <td>64</td>\n",
|
685 |
+
" <td>2.972100</td>\n",
|
686 |
+
" </tr>\n",
|
687 |
+
" <tr>\n",
|
688 |
+
" <td>65</td>\n",
|
689 |
+
" <td>2.237800</td>\n",
|
690 |
+
" </tr>\n",
|
691 |
+
" <tr>\n",
|
692 |
+
" <td>66</td>\n",
|
693 |
+
" <td>2.518300</td>\n",
|
694 |
+
" </tr>\n",
|
695 |
+
" <tr>\n",
|
696 |
+
" <td>67</td>\n",
|
697 |
+
" <td>2.520700</td>\n",
|
698 |
+
" </tr>\n",
|
699 |
+
" <tr>\n",
|
700 |
+
" <td>68</td>\n",
|
701 |
+
" <td>2.122700</td>\n",
|
702 |
+
" </tr>\n",
|
703 |
+
" <tr>\n",
|
704 |
+
" <td>69</td>\n",
|
705 |
+
" <td>2.210200</td>\n",
|
706 |
+
" </tr>\n",
|
707 |
+
" <tr>\n",
|
708 |
+
" <td>70</td>\n",
|
709 |
+
" <td>2.414000</td>\n",
|
710 |
+
" </tr>\n",
|
711 |
+
" <tr>\n",
|
712 |
+
" <td>71</td>\n",
|
713 |
+
" <td>2.348200</td>\n",
|
714 |
+
" </tr>\n",
|
715 |
+
" <tr>\n",
|
716 |
+
" <td>72</td>\n",
|
717 |
+
" <td>2.470800</td>\n",
|
718 |
+
" </tr>\n",
|
719 |
+
" <tr>\n",
|
720 |
+
" <td>73</td>\n",
|
721 |
+
" <td>2.417400</td>\n",
|
722 |
+
" </tr>\n",
|
723 |
+
" <tr>\n",
|
724 |
+
" <td>74</td>\n",
|
725 |
+
" <td>2.562900</td>\n",
|
726 |
+
" </tr>\n",
|
727 |
+
" <tr>\n",
|
728 |
+
" <td>75</td>\n",
|
729 |
+
" <td>2.286800</td>\n",
|
730 |
+
" </tr>\n",
|
731 |
+
" <tr>\n",
|
732 |
+
" <td>76</td>\n",
|
733 |
+
" <td>2.671400</td>\n",
|
734 |
+
" </tr>\n",
|
735 |
+
" <tr>\n",
|
736 |
+
" <td>77</td>\n",
|
737 |
+
" <td>2.176200</td>\n",
|
738 |
+
" </tr>\n",
|
739 |
+
" <tr>\n",
|
740 |
+
" <td>78</td>\n",
|
741 |
+
" <td>2.284200</td>\n",
|
742 |
+
" </tr>\n",
|
743 |
+
" <tr>\n",
|
744 |
+
" <td>79</td>\n",
|
745 |
+
" <td>2.354700</td>\n",
|
746 |
+
" </tr>\n",
|
747 |
+
" <tr>\n",
|
748 |
+
" <td>80</td>\n",
|
749 |
+
" <td>2.363400</td>\n",
|
750 |
+
" </tr>\n",
|
751 |
+
" </tbody>\n",
|
752 |
+
"</table><p>"
|
753 |
+
],
|
754 |
+
"text/plain": [
|
755 |
+
"<IPython.core.display.HTML object>"
|
756 |
+
]
|
757 |
+
},
|
758 |
+
"metadata": {},
|
759 |
+
"output_type": "display_data"
|
760 |
+
}
|
761 |
+
],
|
762 |
+
"source": [
|
763 |
+
"trainer.train()\n",
|
764 |
+
"trainer.save_model()"
|
765 |
+
]
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"cell_type": "code",
|
769 |
+
"execution_count": 10,
|
770 |
+
"id": "2398696f-eeb8-45d1-8dee-ed88a7ac140b",
|
771 |
+
"metadata": {
|
772 |
+
"execution": {
|
773 |
+
"iopub.execute_input": "2025-05-13T22:34:47.172016Z",
|
774 |
+
"iopub.status.busy": "2025-05-13T22:34:47.171574Z",
|
775 |
+
"iopub.status.idle": "2025-05-13T22:39:06.055171Z",
|
776 |
+
"shell.execute_reply": "2025-05-13T22:39:06.054429Z",
|
777 |
+
"shell.execute_reply.started": "2025-05-13T22:34:47.171989Z"
|
778 |
+
}
|
779 |
+
},
|
780 |
+
"outputs": [
|
781 |
+
{
|
782 |
+
"name": "stderr",
|
783 |
+
"output_type": "stream",
|
784 |
+
"text": [
|
785 |
+
"Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
|
786 |
+
]
|
787 |
+
},
|
788 |
+
{
|
789 |
+
"data": {
|
790 |
+
"text/plain": [
|
791 |
+
"('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer_config.json',\n",
|
792 |
+
" '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/special_tokens_map.json',\n",
|
793 |
+
" '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.model',\n",
|
794 |
+
" '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/added_tokens.json',\n",
|
795 |
+
" '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.json')"
|
796 |
+
]
|
797 |
+
},
|
798 |
+
"execution_count": 10,
|
799 |
+
"metadata": {},
|
800 |
+
"output_type": "execute_result"
|
801 |
+
}
|
802 |
+
],
|
803 |
+
"source": [
|
804 |
+
"lora_model = trainer.model\n",
|
805 |
+
"merged_model = lora_model.merge_and_unload()\n",
|
806 |
+
"# Save the model with fused weights\n",
|
807 |
+
"merged_model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n",
|
808 |
+
"trainer.tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')"
|
809 |
+
]
|
810 |
+
},
|
811 |
+
{
|
812 |
+
"cell_type": "code",
|
813 |
+
"execution_count": 1,
|
814 |
+
"id": "8b811a84-0cdb-4b40-bb96-d6e6f27d41d3",
|
815 |
+
"metadata": {
|
816 |
+
"execution": {
|
817 |
+
"iopub.execute_input": "2025-05-08T21:17:00.794785Z",
|
818 |
+
"iopub.status.busy": "2025-05-08T21:17:00.794339Z",
|
819 |
+
"iopub.status.idle": "2025-05-08T21:17:18.309148Z",
|
820 |
+
"shell.execute_reply": "2025-05-08T21:17:18.308319Z",
|
821 |
+
"shell.execute_reply.started": "2025-05-08T21:17:00.794761Z"
|
822 |
+
}
|
823 |
+
},
|
824 |
+
"outputs": [
|
825 |
+
{
|
826 |
+
"ename": "NameError",
|
827 |
+
"evalue": "name 'model' is not defined",
|
828 |
+
"output_type": "error",
|
829 |
+
"traceback": [
|
830 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
831 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
832 |
+
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n",
|
833 |
+
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
|
834 |
+
]
|
835 |
+
}
|
836 |
+
],
|
837 |
+
"source": [
|
838 |
+
"# Merge LoRA weights into the base model\n",
|
839 |
+
"for name, param in model.named_parameters():\n",
|
840 |
+
" if name in trainer.peft_model.lora_weights:\n",
|
841 |
+
" param.data += trainer.peft_model.lora_weights[name]\n",
|
842 |
+
"\n",
|
843 |
+
"# Save the model with fused weights\n",
|
844 |
+
"model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n",
|
845 |
+
"tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')"
|
846 |
+
]
|
847 |
+
},
|
848 |
+
{
|
849 |
+
"cell_type": "code",
|
850 |
+
"execution_count": 9,
|
851 |
+
"id": "e5b4930d-92c5-46e8-9163-6e7f722e0c99",
|
852 |
+
"metadata": {
|
853 |
+
"execution": {
|
854 |
+
"iopub.execute_input": "2025-05-08T19:13:24.762234Z",
|
855 |
+
"iopub.status.busy": "2025-05-08T19:13:24.761972Z",
|
856 |
+
"iopub.status.idle": "2025-05-08T19:13:50.993002Z",
|
857 |
+
"shell.execute_reply": "2025-05-08T19:13:50.992329Z",
|
858 |
+
"shell.execute_reply.started": "2025-05-08T19:13:24.762215Z"
|
859 |
+
}
|
860 |
+
},
|
861 |
+
"outputs": [
|
862 |
+
{
|
863 |
+
"name": "stderr",
|
864 |
+
"output_type": "stream",
|
865 |
+
"text": [
|
866 |
+
"Loading checkpoint shards: 100%|██████████| 12/12 [00:19<00:00, 1.60s/it]\n"
|
867 |
+
]
|
868 |
+
}
|
869 |
+
],
|
870 |
+
"source": [
|
871 |
+
"import torch\n",
|
872 |
+
"from transformers import pipeline\n",
|
873 |
+
"from random import randint\n",
|
874 |
+
"import re\n",
|
875 |
+
"\n",
|
876 |
+
"model_id = \"google/gemma-3-27b-it\"\n",
|
877 |
+
"model = model_class.from_pretrained(\n",
|
878 |
+
" model_id,\n",
|
879 |
+
" device_map=\"auto\",\n",
|
880 |
+
" torch_dtype=torch_dtype,\n",
|
881 |
+
" attn_implementation=\"eager\",\n",
|
882 |
+
")\n",
|
883 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": 10,
|
889 |
+
"id": "5a428dea-261a-4c74-89a8-1b62d7ade5ab",
|
890 |
+
"metadata": {
|
891 |
+
"execution": {
|
892 |
+
"iopub.execute_input": "2025-05-08T19:13:50.999539Z",
|
893 |
+
"iopub.status.busy": "2025-05-08T19:13:50.999160Z",
|
894 |
+
"iopub.status.idle": "2025-05-08T19:15:04.024652Z",
|
895 |
+
"shell.execute_reply": "2025-05-08T19:15:04.022626Z",
|
896 |
+
"shell.execute_reply.started": "2025-05-08T19:13:50.999517Z"
|
897 |
+
}
|
898 |
+
},
|
899 |
+
"outputs": [
|
900 |
+
{
|
901 |
+
"ename": "NameError",
|
902 |
+
"evalue": "name 'trainer' is not defined",
|
903 |
+
"output_type": "error",
|
904 |
+
"traceback": [
|
905 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
906 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
907 |
+
"Cell \u001b[0;32mIn[10], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m model\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# Save the model with fused weights\u001b[39;00m\n",
|
908 |
+
"\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined"
|
909 |
+
]
|
910 |
+
}
|
911 |
+
],
|
912 |
+
"source": []
|
913 |
+
},
|
914 |
+
{
|
915 |
+
"cell_type": "code",
|
916 |
+
"execution_count": null,
|
917 |
+
"id": "c7e7172e-db49-40f3-a0d6-9a87a3b2cf80",
|
918 |
+
"metadata": {
|
919 |
+
"execution": {
|
920 |
+
"iopub.status.busy": "2025-05-08T19:15:04.026597Z",
|
921 |
+
"iopub.status.idle": "2025-05-08T19:15:04.026984Z",
|
922 |
+
"shell.execute_reply": "2025-05-08T19:15:04.026875Z",
|
923 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.026863Z"
|
924 |
+
},
|
925 |
+
"scrolled": true
|
926 |
+
},
|
927 |
+
"outputs": [],
|
928 |
+
"source": [
|
929 |
+
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
|
930 |
+
"rand_idx = randint(0, len(dataset[\"test\"]))\n",
|
931 |
+
"test_sample = hf_dataset[\"test\"][rand_idx]\n",
|
932 |
+
"stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")]\n",
|
933 |
+
"prompt = pipe.tokenizer.apply_chat_template(test_sample[\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n",
|
934 |
+
"\n",
|
935 |
+
"outputs = pipe(prompt, max_new_tokens=1024, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)"
|
936 |
+
]
|
937 |
+
},
|
938 |
+
{
|
939 |
+
"cell_type": "code",
|
940 |
+
"execution_count": null,
|
941 |
+
"id": "de05b438-ca77-4b95-b2b1-32ea7ae033a5",
|
942 |
+
"metadata": {
|
943 |
+
"execution": {
|
944 |
+
"iopub.status.busy": "2025-05-08T19:15:04.028819Z",
|
945 |
+
"iopub.status.idle": "2025-05-08T19:15:04.029072Z",
|
946 |
+
"shell.execute_reply": "2025-05-08T19:15:04.028971Z",
|
947 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.028960Z"
|
948 |
+
}
|
949 |
+
},
|
950 |
+
"outputs": [],
|
951 |
+
"source": [
|
952 |
+
"start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
|
953 |
+
"end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
|
954 |
+
"print(start, end)\n",
|
955 |
+
"print(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
|
956 |
+
"json.loads(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
|
957 |
+
"rand_idx"
|
958 |
+
]
|
959 |
+
},
|
960 |
+
{
|
961 |
+
"cell_type": "code",
|
962 |
+
"execution_count": null,
|
963 |
+
"id": "60b3da99-0edc-4ef6-b0e0-be7d046eaa02",
|
964 |
+
"metadata": {
|
965 |
+
"execution": {
|
966 |
+
"iopub.status.busy": "2025-05-08T19:15:04.030913Z",
|
967 |
+
"iopub.status.idle": "2025-05-08T19:15:04.031227Z",
|
968 |
+
"shell.execute_reply": "2025-05-08T19:15:04.031122Z",
|
969 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.031111Z"
|
970 |
+
}
|
971 |
+
},
|
972 |
+
"outputs": [],
|
973 |
+
"source": [
|
974 |
+
"json.loads(hf_dataset[\"test\"][81][\"messages\"][1]['content'])"
|
975 |
+
]
|
976 |
+
},
|
977 |
+
{
|
978 |
+
"cell_type": "code",
|
979 |
+
"execution_count": null,
|
980 |
+
"id": "cdc44250-e3b9-4870-bce5-23f475023962",
|
981 |
+
"metadata": {
|
982 |
+
"execution": {
|
983 |
+
"iopub.status.busy": "2025-05-08T19:15:04.032999Z",
|
984 |
+
"iopub.status.idle": "2025-05-08T19:15:04.033327Z",
|
985 |
+
"shell.execute_reply": "2025-05-08T19:15:04.033207Z",
|
986 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.033196Z"
|
987 |
+
},
|
988 |
+
"scrolled": true
|
989 |
+
},
|
990 |
+
"outputs": [],
|
991 |
+
"source": [
|
992 |
+
"import torch\n",
|
993 |
+
"from transformers import pipeline\n",
|
994 |
+
"from random import randint\n",
|
995 |
+
"import re\n",
|
996 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n",
|
997 |
+
"from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n",
|
998 |
+
"device = torch.device(\"cuda:0\")\n",
|
999 |
+
"\n",
|
1000 |
+
"model_class = Gemma3ForConditionalGeneration\n",
|
1001 |
+
"torch_dtype = torch.bfloat16\n",
|
1002 |
+
"\n",
|
1003 |
+
"model_id = \"gemma-27b-tq_sft_finetuned-model\"\n",
|
1004 |
+
"model = model_class.from_pretrained(\n",
|
1005 |
+
" model_id,\n",
|
1006 |
+
" device_map=\"auto\",\n",
|
1007 |
+
" torch_dtype=torch_dtype,\n",
|
1008 |
+
" attn_implementation=\"eager\",\n",
|
1009 |
+
")\n",
|
1010 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
1011 |
+
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)"
|
1012 |
+
]
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"cell_type": "code",
|
1016 |
+
"execution_count": null,
|
1017 |
+
"id": "ce4070a9-5291-477a-bb3f-867b7971e391",
|
1018 |
+
"metadata": {
|
1019 |
+
"execution": {
|
1020 |
+
"iopub.status.busy": "2025-05-08T19:15:04.035085Z",
|
1021 |
+
"iopub.status.idle": "2025-05-08T19:15:04.035416Z",
|
1022 |
+
"shell.execute_reply": "2025-05-08T19:15:04.035307Z",
|
1023 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.035295Z"
|
1024 |
+
}
|
1025 |
+
},
|
1026 |
+
"outputs": [],
|
1027 |
+
"source": [
|
1028 |
+
"def extract_json_data(json_string):\n",
|
1029 |
+
" key_pattern = r'\"(.*?)\"\\s*:\\s*'\n",
|
1030 |
+
" value_pattern = r'(?:\"(.*?)\"|(\\d+)|$$(.*?)$$|\\{(.*?)\\})'\n",
|
1031 |
+
" matches = re.finditer(key_pattern + value_pattern, json_string, re.DOTALL) \n",
|
1032 |
+
" data = {}\n",
|
1033 |
+
" for match in matches:\n",
|
1034 |
+
" key = match.group(1)\n",
|
1035 |
+
" value = match.group(2) or match.group(3) or match.group(4) or match.group(5) \n",
|
1036 |
+
" if value:\n",
|
1037 |
+
" try:\n",
|
1038 |
+
" value = json.loads(value)\n",
|
1039 |
+
" except (json.JSONDecodeError, TypeError):\n",
|
1040 |
+
" pass\n",
|
1041 |
+
" data[key] = value\n",
|
1042 |
+
" return data"
|
1043 |
+
]
|
1044 |
+
},
|
1045 |
+
{
|
1046 |
+
"cell_type": "code",
|
1047 |
+
"execution_count": null,
|
1048 |
+
"id": "4940ab0c-ff5a-4c1e-a543-b0e8be91a4cb",
|
1049 |
+
"metadata": {
|
1050 |
+
"execution": {
|
1051 |
+
"iopub.status.busy": "2025-05-08T19:15:04.037234Z",
|
1052 |
+
"iopub.status.idle": "2025-05-08T19:15:04.037745Z",
|
1053 |
+
"shell.execute_reply": "2025-05-08T19:15:04.037637Z",
|
1054 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.037626Z"
|
1055 |
+
}
|
1056 |
+
},
|
1057 |
+
"outputs": [],
|
1058 |
+
"source": [
|
1059 |
+
"rand_idx = randint(0, len(dataset[\"test\"]))\n",
|
1060 |
+
"test_predictions = []\n",
|
1061 |
+
"\n",
|
1062 |
+
"index = 9\n",
|
1063 |
+
"\n",
|
1064 |
+
"meta_data = test_meta[index]\n",
|
1065 |
+
"stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")]\n",
|
1066 |
+
"prompt = pipe.tokenizer.apply_chat_template(hf_dataset[\"test\"][index][\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n",
|
1067 |
+
"outputs = pipe(prompt, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)\n",
|
1068 |
+
"start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
|
1069 |
+
"end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
|
1070 |
+
"try:\n",
|
1071 |
+
" pred_dict = json.loads(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
|
1072 |
+
"except:\n",
|
1073 |
+
" start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
|
1074 |
+
" end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
|
1075 |
+
" pred_dict = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1]"
|
1076 |
+
]
|
1077 |
+
},
|
1078 |
+
{
|
1079 |
+
"cell_type": "code",
|
1080 |
+
"execution_count": null,
|
1081 |
+
"id": "fdf03584-7cd0-40cc-af95-87279a2dc05e",
|
1082 |
+
"metadata": {
|
1083 |
+
"execution": {
|
1084 |
+
"iopub.status.busy": "2025-05-08T19:15:04.039492Z",
|
1085 |
+
"iopub.status.idle": "2025-05-08T19:15:04.039810Z",
|
1086 |
+
"shell.execute_reply": "2025-05-08T19:15:04.039704Z",
|
1087 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.039693Z"
|
1088 |
+
}
|
1089 |
+
},
|
1090 |
+
"outputs": [],
|
1091 |
+
"source": [
|
1092 |
+
"pred_dict"
|
1093 |
+
]
|
1094 |
+
},
|
1095 |
+
{
|
1096 |
+
"cell_type": "code",
|
1097 |
+
"execution_count": null,
|
1098 |
+
"id": "80603718-a168-4e4c-aa55-842dfb20f265",
|
1099 |
+
"metadata": {
|
1100 |
+
"execution": {
|
1101 |
+
"iopub.status.busy": "2025-05-08T19:15:04.041594Z",
|
1102 |
+
"iopub.status.idle": "2025-05-08T19:15:04.041970Z",
|
1103 |
+
"shell.execute_reply": "2025-05-08T19:15:04.041865Z",
|
1104 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.041854Z"
|
1105 |
+
}
|
1106 |
+
},
|
1107 |
+
"outputs": [],
|
1108 |
+
"source": [
|
1109 |
+
"hf_dataset[\"test\"][index][\"messages\"][1]"
|
1110 |
+
]
|
1111 |
+
},
|
1112 |
+
{
|
1113 |
+
"cell_type": "code",
|
1114 |
+
"execution_count": null,
|
1115 |
+
"id": "6d3731c9-4686-453f-8c91-e9477fe5541c",
|
1116 |
+
"metadata": {
|
1117 |
+
"execution": {
|
1118 |
+
"iopub.status.busy": "2025-05-08T19:15:04.043675Z",
|
1119 |
+
"iopub.status.idle": "2025-05-08T19:15:04.043977Z",
|
1120 |
+
"shell.execute_reply": "2025-05-08T19:15:04.043872Z",
|
1121 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.043861Z"
|
1122 |
+
}
|
1123 |
+
},
|
1124 |
+
"outputs": [],
|
1125 |
+
"source": [
|
1126 |
+
"batch_size = 8\n",
|
1127 |
+
"test_predictions = []\n",
|
1128 |
+
"\n",
|
1129 |
+
"for i in tqdm(range(0, len(hf_dataset[\"test\"]), batch_size)):\n",
|
1130 |
+
" batch_samples = hf_dataset[\"test\"][i:i + batch_size][\"messages\"]\n",
|
1131 |
+
" batch_meta = test_meta[i:i + batch_size]\n",
|
1132 |
+
" prompts = [\n",
|
1133 |
+
" pipe.tokenizer.apply_chat_template(sample[:1], tokenize=False, add_generation_prompt=True)\n",
|
1134 |
+
" for sample in batch_samples\n",
|
1135 |
+
" ]\n",
|
1136 |
+
" outputs = pipe(prompts, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")], disable_compile=True)\n",
|
1137 |
+
"\n",
|
1138 |
+
" for index, output in tqdm(enumerate(tqdm(outputs))):\n",
|
1139 |
+
" output_dict = {}\n",
|
1140 |
+
" start = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
|
1141 |
+
" end = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
|
1142 |
+
" try:\n",
|
1143 |
+
" pred_dict = json.loads(output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
|
1144 |
+
" except:\n",
|
1145 |
+
" pred_dict = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1]\n",
|
1146 |
+
" \n",
|
1147 |
+
" output_dict.update(batch_meta[index])\n",
|
1148 |
+
" output_dict[\"predictions\"] = pred_dict\n",
|
1149 |
+
" output_dict[\"human-annotation\"] = batch_samples[index][1]['content']\n",
|
1150 |
+
" output_dict[\"prompt\"] = batch_samples[index][0]['content']\n",
|
1151 |
+
" test_predictions.append(output_dict)"
|
1152 |
+
]
|
1153 |
+
},
|
1154 |
+
{
|
1155 |
+
"cell_type": "code",
|
1156 |
+
"execution_count": null,
|
1157 |
+
"id": "616eb30a-eac2-4229-b86c-24eca7534cc6",
|
1158 |
+
"metadata": {
|
1159 |
+
"execution": {
|
1160 |
+
"iopub.status.busy": "2025-05-08T19:15:04.045755Z",
|
1161 |
+
"iopub.status.idle": "2025-05-08T19:15:04.046057Z",
|
1162 |
+
"shell.execute_reply": "2025-05-08T19:15:04.045954Z",
|
1163 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.045943Z"
|
1164 |
+
}
|
1165 |
+
},
|
1166 |
+
"outputs": [],
|
1167 |
+
"source": [
|
1168 |
+
"with open(\"/root/notebooks/trashspace/gemma_finetuned_expertdata/test_pred.json\", 'w') as json_file:\n",
|
1169 |
+
" json.dump(test_predictions, json_file)"
|
1170 |
+
]
|
1171 |
+
},
|
1172 |
+
{
|
1173 |
+
"cell_type": "code",
|
1174 |
+
"execution_count": null,
|
1175 |
+
"id": "71480057-d6b9-4499-a8d7-26bf3f3f9342",
|
1176 |
+
"metadata": {},
|
1177 |
+
"outputs": [],
|
1178 |
+
"source": []
|
1179 |
+
},
|
1180 |
+
{
|
1181 |
+
"cell_type": "code",
|
1182 |
+
"execution_count": null,
|
1183 |
+
"id": "ce73604c-4bbb-46c7-8433-d957b0e10405",
|
1184 |
+
"metadata": {},
|
1185 |
+
"outputs": [],
|
1186 |
+
"source": []
|
1187 |
+
},
|
1188 |
+
{
|
1189 |
+
"cell_type": "code",
|
1190 |
+
"execution_count": null,
|
1191 |
+
"id": "4741a722-a772-44bf-949e-e77671a4ef03",
|
1192 |
+
"metadata": {},
|
1193 |
+
"outputs": [],
|
1194 |
+
"source": []
|
1195 |
+
},
|
1196 |
+
{
|
1197 |
+
"cell_type": "code",
|
1198 |
+
"execution_count": null,
|
1199 |
+
"id": "07c2adef-c2f6-4ba9-98f7-277cce2701d0",
|
1200 |
+
"metadata": {},
|
1201 |
+
"outputs": [],
|
1202 |
+
"source": []
|
1203 |
+
},
|
1204 |
+
{
|
1205 |
+
"cell_type": "code",
|
1206 |
+
"execution_count": null,
|
1207 |
+
"id": "1adfa27b-9bfa-4479-be3f-5149a2237c1f",
|
1208 |
+
"metadata": {
|
1209 |
+
"execution": {
|
1210 |
+
"iopub.status.busy": "2025-05-08T19:15:04.047823Z",
|
1211 |
+
"iopub.status.idle": "2025-05-08T19:15:04.048130Z",
|
1212 |
+
"shell.execute_reply": "2025-05-08T19:15:04.048026Z",
|
1213 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.048015Z"
|
1214 |
+
}
|
1215 |
+
},
|
1216 |
+
"outputs": [],
|
1217 |
+
"source": [
|
1218 |
+
"data = json.loads(test_sample['messages'][1]['content'])\n",
|
1219 |
+
"data"
|
1220 |
+
]
|
1221 |
+
},
|
1222 |
+
{
|
1223 |
+
"cell_type": "code",
|
1224 |
+
"execution_count": null,
|
1225 |
+
"id": "750d3454-6300-469b-bdc3-77cce45a00ce",
|
1226 |
+
"metadata": {
|
1227 |
+
"execution": {
|
1228 |
+
"iopub.status.busy": "2025-05-08T19:15:04.049897Z",
|
1229 |
+
"iopub.status.idle": "2025-05-08T19:15:04.050203Z",
|
1230 |
+
"shell.execute_reply": "2025-05-08T19:15:04.050099Z",
|
1231 |
+
"shell.execute_reply.started": "2025-05-08T19:15:04.050088Z"
|
1232 |
+
}
|
1233 |
+
},
|
1234 |
+
"outputs": [],
|
1235 |
+
"source": [
|
1236 |
+
"print(len(hf_dataset[\"test\"]))"
|
1237 |
+
]
|
1238 |
+
},
|
1239 |
+
{
|
1240 |
+
"cell_type": "code",
|
1241 |
+
"execution_count": null,
|
1242 |
+
"id": "3be45a2d-336f-4899-a8e9-e000437fab8c",
|
1243 |
+
"metadata": {},
|
1244 |
+
"outputs": [],
|
1245 |
+
"source": []
|
1246 |
+
},
|
1247 |
+
{
|
1248 |
+
"cell_type": "code",
|
1249 |
+
"execution_count": null,
|
1250 |
+
"id": "248182ff-bec8-46ff-bc34-14b523d877bf",
|
1251 |
+
"metadata": {},
|
1252 |
+
"outputs": [],
|
1253 |
+
"source": []
|
1254 |
+
}
|
1255 |
+
],
|
1256 |
+
"metadata": {
|
1257 |
+
"kernelspec": {
|
1258 |
+
"display_name": "timedlibs",
|
1259 |
+
"language": "python",
|
1260 |
+
"name": "timedlibs"
|
1261 |
+
},
|
1262 |
+
"language_info": {
|
1263 |
+
"codemirror_mode": {
|
1264 |
+
"name": "ipython",
|
1265 |
+
"version": 3
|
1266 |
+
},
|
1267 |
+
"file_extension": ".py",
|
1268 |
+
"mimetype": "text/x-python",
|
1269 |
+
"name": "python",
|
1270 |
+
"nbconvert_exporter": "python",
|
1271 |
+
"pygments_lexer": "ipython3",
|
1272 |
+
"version": "3.10.16"
|
1273 |
+
}
|
1274 |
+
},
|
1275 |
+
"nbformat": 4,
|
1276 |
+
"nbformat_minor": 5
|
1277 |
+
}
|
README.md
ADDED
@@ -0,0 +1,58 @@
1 |
+
---
|
2 |
+
base_model: google/gemma-3-4b-it
|
3 |
+
library_name: transformers
|
4 |
+
model_name: TQTune
|
5 |
+
tags:
|
6 |
+
- generated_from_trainer
|
7 |
+
- trl
|
8 |
+
- sft
|
9 |
+
licence: license
|
10 |
+
---
|
11 |
+
|
12 |
+
# Model Card for TQTune
|
13 |
+
|
14 |
+
This model is a fine-tuned version of [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it).
|
15 |
+
It has been trained using [TRL](https://github.com/huggingface/trl).
|
16 |
+
|
17 |
+
## Quick start
|
18 |
+
|
19 |
+
```python
|
20 |
+
from transformers import pipeline
|
21 |
+
|
22 |
+
question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
|
23 |
+
generator = pipeline("text-generation", model="bhavinjawade/TQTune", device="cuda")
|
24 |
+
output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
|
25 |
+
print(output["generated_text"])
|
26 |
+
```
|
27 |
+
|
28 |
+
## Training procedure
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
This model was trained with SFT.
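
A minimal sketch of the setup (illustrative values; the dataset below is a public placeholder, not the data this model was actually trained on):

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Placeholder dataset; swap in your own conversational dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

peft_config = LoraConfig(r=16, lora_alpha=128, lora_dropout=0.05, task_type="CAUSAL_LM")
args = SFTConfig(output_dir="TQTune", packing=True, num_train_epochs=1)

trainer = SFTTrainer(
    model="google/gemma-3-4b-it",
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```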
|
34 |
+
|
35 |
+
### Framework versions
|
36 |
+
|
37 |
+
- TRL: 0.16.1
|
38 |
+
- Transformers: 4.50.0.dev0
|
39 |
+
- Pytorch: 2.6.0+cu124
|
40 |
+
- Datasets: 3.3.2
|
41 |
+
- Tokenizers: 0.21.0
|
42 |
+
|
43 |
+
## Citations
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
Cite TRL as:
|
48 |
+
|
49 |
+
```bibtex
|
50 |
+
@misc{vonwerra2022trl,
|
51 |
+
title = {{TRL: Transformer Reinforcement Learning}},
|
52 |
+
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
|
53 |
+
year = 2020,
|
54 |
+
journal = {GitHub repository},
|
55 |
+
publisher = {GitHub},
|
56 |
+
howpublished = {\url{https://github.com/huggingface/trl}}
|
57 |
+
}
|
58 |
+
```
|
SFT_Expert.py
ADDED
@@ -0,0 +1,229 @@
1 |
+
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
|
2 |
+
import torch
|
3 |
+
from peft import LoraConfig, get_peft_model
|
4 |
+
import os
|
5 |
+
from tqdm import tqdm
|
6 |
+
import json
|
7 |
+
import random
|
8 |
+
from datasets import load_dataset
|
9 |
+
from datasets import Dataset, DatasetDict
|
10 |
+
|
11 |
+
system_message = "You are a helpful assistant who is an expert in estimating quality of translations."
|
12 |
+
|
13 |
+
output_template = '''
|
14 |
+
{
|
15 |
+
"Accuracy Issues": [
|
16 |
+
{
|
17 |
+
"Error Span": "",
|
18 |
+
"Error Explanation": "",
|
19 |
+
"Error Quality Category": "",
|
20 |
+
"Error Quality Tags": [],
|
21 |
+
"Error Severity": ""
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"Accuracy Score": "",
|
25 |
+
"Readability Issues": [
|
26 |
+
{
|
27 |
+
"Error Span": "",
|
28 |
+
"Error Explanation": "",
|
29 |
+
"Error Quality Category": "",
|
30 |
+
"Error Quality Tags": [],
|
31 |
+
"Error Severity": ""
|
32 |
+
}
|
33 |
+
],
|
34 |
+
"Readability Score": ""
|
35 |
+
}'''
|
36 |
+
|
37 |
+
def create_conversation(input_sample, output_sample):
|
38 |
+
return {
|
39 |
+
"messages": [
|
40 |
+
# {"role": "system", "content": system_message},
|
41 |
+
{"role": "user", "content": input_sample},
|
42 |
+
{"role": "assistant", "content": output_sample}
|
43 |
+
]
|
44 |
+
}
|
45 |
+
|
46 |
+
data_path = (
|
47 |
+
"/root/notebooks/MT_TQ/TQ/TQTune/labeled_data/parsed/"
|
48 |
+
)
|
49 |
+
|
50 |
+
json_files = [
|
51 |
+
os.path.join(root, file)
|
52 |
+
for root, _, files in os.walk(data_path)
|
53 |
+
for file in files
|
54 |
+
if file.endswith(".json") and "PLDL" in file
|
55 |
+
]
|
56 |
+
|
57 |
+
training_samples = []
|
58 |
+
for json_file in tqdm(json_files):
|
59 |
+
with open(json_file, "r") as file:
|
60 |
+
data = json.load(file)
|
61 |
+
sampled_items = random.sample(data["data"], 20)
|
62 |
+
training_samples.extend(sampled_items)
|
63 |
+
|
64 |
+
datapoints = []
|
65 |
+
|
66 |
+
for sample in training_samples:
|
67 |
+
datapoint = {"input": {}}
|
68 |
+
datapoint["input"]["src_text"] = sample["main_src_text"]
|
69 |
+
datapoint["input"]["tgt_text"] = sample["tgt_text"]
|
70 |
+
datapoint["input"]["src_prev"] = sample["tt_src_prev"]
|
71 |
+
datapoint["input"]["src_next"] = sample["tt_src_next"]
|
72 |
+
datapoint["input"]["tgt_prev"] = sample["tt_tgt_prev"]
|
73 |
+
datapoint["input"]["tgt_next"] = sample["tt_tgt_next"]
|
74 |
+
datapoint["input"]["src_lang"] = sample["src_lang"]
|
75 |
+
datapoint["input"]["tgt_lang"] = sample["tgt_lang"]
|
76 |
+
datapoint["evaluation"] = sample["labelers"][0]["annotation"]
|
77 |
+
datapoints.append(datapoint)
|
78 |
+
|
79 |
+
def dataset_prep(datapoints, test_size=0.2):
|
80 |
+
with open("prompts.txt") as file:
|
81 |
+
template_string = file.read()
|
82 |
+
|
83 |
+
random.shuffle(datapoints)
|
84 |
+
|
85 |
+
split_index = int(len(datapoints) * (1 - test_size))
|
86 |
+
train_datapoints = datapoints[:split_index]
|
87 |
+
test_datapoints = datapoints[split_index:]
|
88 |
+
|
89 |
+
def create_dataset(datapoints):
|
90 |
+
dataset = []
|
91 |
+
for datapoint in datapoints:
|
92 |
+
src_text = datapoint['input']['src_text']
|
93 |
+
tgt_text = datapoint['input']['tgt_text']
|
94 |
+
src_prev = datapoint['input']['src_prev']
|
95 |
+
src_next = datapoint['input']['src_next']
|
96 |
+
tgt_prev = datapoint['input']['tgt_prev']
|
97 |
+
tgt_next = datapoint['input']['tgt_next']
|
98 |
+
src_lang = datapoint['input']['src_lang']
|
99 |
+
tgt_lang = datapoint['input']['tgt_lang']
|
100 |
+
output = datapoint['evaluation']
|
101 |
+
del output["Confidence Level"]
|
102 |
+
del output["Main Vs Alternate"]
|
103 |
+
del output["Score"]
|
104 |
+
|
105 |
+
if len(output['Accuracy Issues']) != 0 and len(output['Readability Issues']) != 0:
|
106 |
+
item = template_string.format(src_text=src_text, tgt_text=tgt_text,
|
107 |
+
src_prev=src_prev, src_next=src_next,
|
108 |
+
tgt_prev=tgt_prev, tgt_next=tgt_next,
|
109 |
+
src_lang=src_lang, tgt_lang=tgt_lang,
|
110 |
+
template=output_template)
|
111 |
+
|
112 |
+
dataset.append(create_conversation(item, json.dumps(output)))
|
113 |
+
|
114 |
+
return dataset
|
115 |
+
|
116 |
+
train_set = create_dataset(train_datapoints)
|
117 |
+
test_set = create_dataset(test_datapoints)
|
118 |
+
|
119 |
+
return train_set, test_set
|
120 |
+
|
121 |
+
train_dataset, test_dataset = dataset_prep(datapoints)
|
122 |
+
dataset = {"train": train_dataset, "test": test_dataset}
|
123 |
+
|
124 |
+
def convert_to_hf_dataset(dataset):
|
125 |
+
# Convert the train and test datasets into Hugging Face Dataset objects
|
126 |
+
train_dataset = Dataset.from_list(dataset['train'])
|
127 |
+
test_dataset = Dataset.from_list(dataset['test'])
|
128 |
+
|
129 |
+
# Combine them into a DatasetDict
|
130 |
+
hf_dataset = DatasetDict({
|
131 |
+
'train': train_dataset,
|
132 |
+
'test': test_dataset
|
133 |
+
})
|
134 |
+
|
135 |
+
return hf_dataset
|
136 |
+
|
137 |
+
# Convert your dataset into a Hugging Face Dataset object
|
138 |
+
hf_dataset = convert_to_hf_dataset(dataset)
|
139 |
+
|
140 |
+
# Now you can use hf_dataset for your machine learning tasks
|
141 |
+
print(hf_dataset)
|
142 |
+
|
143 |
+
import torch
|
144 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
|
145 |
+
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
146 |
+
device = torch.device("cuda:0")
|
147 |
+
|
148 |
+
# Hugging Face model id
|
149 |
+
model_id = "google/gemma-3-12b-it" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`
|
150 |
+
|
151 |
+
# Select model class based on id
|
152 |
+
if model_id == "google/gemma-3-12b-it":
|
153 |
+
model_class = Gemma3ForConditionalGeneration
|
154 |
+
else:
|
155 |
+
model_class = AutoModelForImageTextToText
|
156 |
+
|
157 |
+
torch_dtype = torch.bfloat16
|
158 |
+
|
159 |
+
model_kwargs = dict(
|
160 |
+
attn_implementation="eager",
|
161 |
+
torch_dtype=torch_dtype,
|
162 |
+
device_map="auto", # Change from {'': 0} to "auto"
|
163 |
+
)
|
164 |
+
|
165 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
166 |
+
load_in_8bit=True,
|
167 |
+
bnb_8bit_use_double_quant=True,
|
168 |
+
bnb_8bit_quant_type='nf8',
|
169 |
+
bnb_8bit_compute_dtype=model_kwargs['torch_dtype'],
|
170 |
+
bnb_8bit_quant_storage=model_kwargs['torch_dtype'],
|
171 |
+
)
|
172 |
+
|
173 |
+
model = model_class.from_pretrained(model_id, **model_kwargs)
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it") # Load the Instruction Tokenizer to use the official Gemma template
|
175 |
+
|
176 |
+
from peft import LoraConfig
|
177 |
+
|
178 |
+
peft_config = LoraConfig(
|
179 |
+
lora_alpha=128,
|
180 |
+
lora_dropout=0.05,
|
181 |
+
r=16,
|
182 |
+
bias="none",
|
183 |
+
target_modules="all-linear",
|
184 |
+
task_type="CAUSAL_LM",
|
185 |
+
modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
|
186 |
+
)
|
187 |
+
|
188 |
+
from trl import SFTConfig
|
189 |
+
|
190 |
+
args = SFTConfig(
|
191 |
+
output_dir="gemma-12b-tq-model",
|
192 |
+
max_seq_length=512,
|
193 |
+
packing=True,
|
194 |
+
num_train_epochs=1,
|
195 |
+
per_device_train_batch_size=1,
|
196 |
+
gradient_accumulation_steps=4,
|
197 |
+
gradient_checkpointing=True,
|
198 |
+
optim="adamw_torch_fused",
|
199 |
+
logging_steps=1,
|
200 |
+
save_strategy="epoch",
|
201 |
+
learning_rate=2e-4,
|
202 |
+
fp16=True if torch_dtype == torch.float16 else False,
|
203 |
+
bf16=True if torch_dtype == torch.bfloat16 else False,
|
204 |
+
max_grad_norm=0.3,
|
205 |
+
warmup_ratio=0.03,
|
206 |
+
lr_scheduler_type="constant",
|
207 |
+
push_to_hub=True,
|
208 |
+
report_to="tensorboard",
|
209 |
+
dataset_kwargs={
|
210 |
+
"add_special_tokens": False,
|
211 |
+
"append_concat_token": True,
|
212 |
+
},
|
213 |
+
ddp_find_unused_parameters=False,
|
214 |
+
no_cuda=False,
|
215 |
+
)
|
216 |
+
|
217 |
+
from trl import SFTTrainer
|
218 |
+
|
219 |
+
# Create Trainer object
|
220 |
+
trainer = SFTTrainer(
|
221 |
+
model=model,
|
222 |
+
args=args,
|
223 |
+
train_dataset=hf_dataset["train"],
|
224 |
+
peft_config=peft_config,
|
225 |
+
processing_class=tokenizer
|
226 |
+
)
|
227 |
+
|
228 |
+
trainer.train()
|
229 |
+
trainer.save_model()
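
# --- Optional: reload the trained adapter for inference ---
# A minimal sketch, assuming the run above finished and wrote the adapter to
# "gemma-12b-tq-model" (output_dir); PeftModel.from_pretrained is standard peft API.
from peft import PeftModel

inference_model = model_class.from_pretrained(model_id, **model_kwargs)
inference_model = PeftModel.from_pretrained(inference_model, "gemma-12b-tq-model")
inference_model.eval()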
|
TQ_template.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
{
|
2 |
+
"src_lang": "en",
|
3 |
+
"tgt_lang": "de",
|
4 |
+
"src_prev": "Alice: Hi Bob, how are you?\nBob: I'm good, thanks!",
|
5 |
+
"tgt_prev": "Alice: Hallo Bob, wie geht's?\nBob: Mir geht's gut, danke!",
|
6 |
+
"src_next": "Alice: Want to grab coffee later?\nBob: Sure, sounds good.",
|
7 |
+
"tgt_next": "Alice: Möchtest du später einen Kaffee trinken?\nBob: Klar, klingt gut.",
|
8 |
+
"src_text": "Bob: I just got back from Paris.",
|
9 |
+
"main_text": "This is the main text",
|
10 |
+
"alternate_text": "This is the alternate text",
|
11 |
+
"evaluation": {
|
12 |
+
"Accuracy Issues": [
|
13 |
+
{
|
14 |
+
"Error Span": [5,8],
|
15 |
+
"Error Explanation": "Incorrect translation of 'just got back' as 'gerade aus' instead of 'gerade zurück'.",
|
16 |
+
"Error Quality Category": "Fidelity",
|
17 |
+
"Error Quality Tags": ["terminology", "accuracy"],
|
18 |
+
"Error Severity": "Major"
|
19 |
+
}
|
20 |
+
],
|
21 |
+
"Accuracy Score": "4", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["classifications"][1]["radio_answer"]["name"]
|
22 |
+
"Readability Issues": [
|
23 |
+
{
|
24 |
+
"Error Location": "Src", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["conversational_location"]["message_id"]
|
25 |
+
"Error Span": [0,2], # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["conversational_location"]["location"]["start" and "end" use them to make this list (start, end)]
|
26 |
+
"Error Explanation": "Sentence structure is awkward in German translation.", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][2]["text_answer"]["content"]
|
27 |
+
"Error Quality Category": "Style", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["name"] - here if the name is "Style" - put it under Readability Issues else put it under Accuracy Issues.
|
28 |
+
"Error Quality Tags": ["awkward", "structure"], # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][1]["checklist_answers"][list of dicts, take name keys for all and make list of it]
|
29 |
+
"Error Severity": "Minor" # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][0]["radio_answer"]["name"]
|
30 |
+
}
|
31 |
+
],
|
32 |
+
"Readability Score": "3", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][2]["radio_answer"]["name"]
|
33 |
+
"Confidence Level": "the_translation_is_excellent_without_any_error_spans_and_no_creative_liberties_were_taken", # # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][3]["radio_answer"]["name"]
|
34 |
+
"Main Vs Alternate": "Both of them have roughly the same quality" # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][0]["radio_answer"]["name"]
|
35 |
+
},
|
36 |
+
"Score": "26"
|
37 |
+
}
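
# A minimal sketch of the extraction the path comments above describe, assuming a
# Labelbox-style export with a single (arbitrary) key under "projects" and field
# indices matching the comments; the actual parsing presumably lives in data_prep.py.
def parse_label(export: dict) -> dict:
    project = next(iter(export["projects"].values()))  # "whatever key is there"
    label = project["labels"][0]
    classifications = label["annotations"]["classifications"]
    accuracy_issues, readability_issues = [], []
    for obj in label["annotations"]["objects"]:
        loc = obj["conversational_location"]["location"]
        issue = {
            "Error Span": [loc["start"], loc["end"]],
            "Error Explanation": obj["classification"][2]["text_answer"]["content"],
            "Error Quality Category": obj["name"],
            "Error Quality Tags": [a["name"] for a in obj["classification"][1]["checklist_answers"]],
            "Error Severity": obj["classification"][0]["radio_answer"]["name"],
        }
        # Per the comments above: "Style" objects are readability issues, the rest accuracy.
        (readability_issues if obj["name"] == "Style" else accuracy_issues).append(issue)
    return {
        "Accuracy Issues": accuracy_issues,
        "Accuracy Score": classifications[1]["radio_answer"]["name"],
        "Readability Issues": readability_issues,
        "Readability Score": classifications[2]["radio_answer"]["name"],
        "Confidence Level": classifications[3]["radio_answer"]["name"],
        "Main Vs Alternate": classifications[0]["radio_answer"]["name"],
    }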
|
TextGrad_Optimization.ipynb
ADDED
@@ -0,0 +1,544 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 10,
|
6 |
+
"id": "8b3ee6e2-ca9c-40fa-b4c6-a9596f075f79",
|
7 |
+
"metadata": {
|
8 |
+
"execution": {
|
9 |
+
"iopub.execute_input": "2025-05-09T17:36:47.763713Z",
|
10 |
+
"iopub.status.busy": "2025-05-09T17:36:47.763339Z",
|
11 |
+
"iopub.status.idle": "2025-05-09T17:36:47.768648Z",
|
12 |
+
"shell.execute_reply": "2025-05-09T17:36:47.768166Z",
|
13 |
+
"shell.execute_reply.started": "2025-05-09T17:36:47.763676Z"
|
14 |
+
}
|
15 |
+
},
|
16 |
+
"outputs": [
|
17 |
+
{
|
18 |
+
"name": "stdout",
|
19 |
+
"output_type": "stream",
|
20 |
+
"text": [
|
21 |
+
"env: OPENAI_API_KEY=\"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\"\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import argparse\n",
|
27 |
+
"import concurrent\n",
|
28 |
+
"from dotenv import load_dotenv\n",
|
29 |
+
"from tqdm import tqdm\n",
|
30 |
+
"import textgrad as tg\n",
|
31 |
+
"from textgrad.tasks import load_task\n",
|
32 |
+
"import numpy as np\n",
|
33 |
+
"import random\n",
|
34 |
+
"load_dotenv(override=True)\n",
|
35 |
+
"import os\n",
|
36 |
+
"import json\n",
|
37 |
+
"\n",
|
38 |
+
"%env OPENAI_API_KEY=\"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\""
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": 4,
|
44 |
+
"id": "4ec9a29b-9162-4fe3-b32d-4de4397c6483",
|
45 |
+
"metadata": {
|
46 |
+
"execution": {
|
47 |
+
"iopub.execute_input": "2025-05-09T17:33:04.417822Z",
|
48 |
+
"iopub.status.busy": "2025-05-09T17:33:04.417437Z",
|
49 |
+
"iopub.status.idle": "2025-05-09T17:33:04.429505Z",
|
50 |
+
"shell.execute_reply": "2025-05-09T17:33:04.429029Z",
|
51 |
+
"shell.execute_reply.started": "2025-05-09T17:33:04.417795Z"
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"outputs": [
|
55 |
+
{
|
56 |
+
"name": "stderr",
|
57 |
+
"output_type": "stream",
|
58 |
+
"text": [
|
59 |
+
"0it [00:00, ?it/s]\n"
|
60 |
+
]
|
61 |
+
}
|
62 |
+
],
|
63 |
+
"source": [
|
64 |
+
"data_path = \"/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/\"\n",
|
65 |
+
"json_files = [os.path.join(root, file) for root, _, files in os.walk(data_path) for file in files if file.endswith('.json') and 'PLDL' in file]\n",
|
66 |
+
"\n",
|
67 |
+
"training_samples = []\n",
|
68 |
+
"for json_file in tqdm(json_files):\n",
|
69 |
+
" with open(json_file, 'r') as file:\n",
|
70 |
+
" data = json.load(file)\n",
|
71 |
+
" sampled_items = random.sample(data[\"data\"], 20)\n",
|
72 |
+
" training_samples.extend(sampled_items)\n",
|
73 |
+
"\n",
|
74 |
+
"datapoints = []\n",
|
75 |
+
"\n",
|
76 |
+
"for sample in training_samples:\n",
|
77 |
+
" datapoint = {\"input\":{}}\n",
|
78 |
+
" datapoint[\"input\"][\"src_text\"] = sample[\"main_src_text\"]\n",
|
79 |
+
" datapoint[\"input\"][\"tgt_text\"] = sample[\"tgt_text\"]\n",
|
80 |
+
" datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
|
81 |
+
" datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
|
82 |
+
" datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
|
83 |
+
" datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
|
84 |
+
" datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
|
85 |
+
" datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
|
86 |
+
" datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
|
87 |
+
" datapoints.append(datapoint)"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 5,
|
93 |
+
"id": "a894ce72-d451-44fa-aaa5-85bf8e6dc9da",
|
94 |
+
"metadata": {
|
95 |
+
"execution": {
|
96 |
+
"iopub.execute_input": "2025-05-09T17:33:40.240759Z",
|
97 |
+
"iopub.status.busy": "2025-05-09T17:33:40.240243Z",
|
98 |
+
"iopub.status.idle": "2025-05-09T17:33:40.244435Z",
|
99 |
+
"shell.execute_reply": "2025-05-09T17:33:40.243818Z",
|
100 |
+
"shell.execute_reply.started": "2025-05-09T17:33:40.240720Z"
|
101 |
+
}
|
102 |
+
},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"def set_seed(seed):\n",
|
106 |
+
" np.random.seed(seed)\n",
|
107 |
+
" random.seed(seed)"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "code",
|
112 |
+
"execution_count": 6,
|
113 |
+
"id": "4eeaa266-3ca2-4360-b80b-b38aa3bbdb70",
|
114 |
+
"metadata": {
|
115 |
+
"execution": {
|
116 |
+
"iopub.execute_input": "2025-05-09T17:33:55.982807Z",
|
117 |
+
"iopub.status.busy": "2025-05-09T17:33:55.982080Z",
|
118 |
+
"iopub.status.idle": "2025-05-09T17:33:55.988522Z",
|
119 |
+
"shell.execute_reply": "2025-05-09T17:33:55.987924Z",
|
120 |
+
"shell.execute_reply.started": "2025-05-09T17:33:55.982770Z"
|
121 |
+
}
|
122 |
+
},
|
123 |
+
"outputs": [],
|
124 |
+
"source": [
|
125 |
+
"def eval_sample(item, eval_fn, model):\n",
|
126 |
+
" \"\"\"\n",
|
127 |
+
" This function allows us to evaluate if an answer to a question in the prompt is a good answer.\n",
|
128 |
+
"\n",
|
129 |
+
" \"\"\"\n",
|
130 |
+
" x, y = item\n",
|
131 |
+
" x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
|
132 |
+
" y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
|
133 |
+
" response = model(x)\n",
|
134 |
+
" try:\n",
|
135 |
+
" eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
|
136 |
+
" return int(eval_output_variable.value)\n",
|
137 |
+
" except:\n",
|
138 |
+
" eval_output_variable = eval_fn([x, y, response])\n",
|
139 |
+
" eval_output_parsed = eval_fn.parse_output(eval_output_variable)\n",
|
140 |
+
" return int(eval_output_parsed)"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "code",
|
145 |
+
"execution_count": 7,
|
146 |
+
"id": "c7e57f9d-c0ff-4139-9e61-b93510599353",
|
147 |
+
"metadata": {
|
148 |
+
"execution": {
|
149 |
+
"iopub.execute_input": "2025-05-09T17:34:08.606301Z",
|
150 |
+
"iopub.status.busy": "2025-05-09T17:34:08.605538Z",
|
151 |
+
"iopub.status.idle": "2025-05-09T17:34:08.612515Z",
|
152 |
+
"shell.execute_reply": "2025-05-09T17:34:08.611911Z",
|
153 |
+
"shell.execute_reply.started": "2025-05-09T17:34:08.606262Z"
|
154 |
+
}
|
155 |
+
},
|
156 |
+
"outputs": [],
|
157 |
+
"source": [
|
158 |
+
"def eval_dataset(test_set, eval_fn, model, max_samples: int=None):\n",
|
159 |
+
" if max_samples is None:\n",
|
160 |
+
" max_samples = len(test_set)\n",
|
161 |
+
" accuracy_list = []\n",
|
162 |
+
" with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:\n",
|
163 |
+
" futures = []\n",
|
164 |
+
" for _, sample in enumerate(test_set):\n",
|
165 |
+
" \n",
|
166 |
+
" future = executor.submit(eval_sample, sample, eval_fn, model)\n",
|
167 |
+
" futures.append(future)\n",
|
168 |
+
" if len(futures) >= max_samples:\n",
|
169 |
+
" break\n",
|
170 |
+
" tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)\n",
|
171 |
+
" for future in tqdm_loader:\n",
|
172 |
+
" acc_item = future.result()\n",
|
173 |
+
" accuracy_list.append(acc_item)\n",
|
174 |
+
" tqdm_loader.set_description(f\"Accuracy: {np.mean(accuracy_list)}\")\n",
|
175 |
+
" return accuracy_list "
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": 8,
|
181 |
+
"id": "039af9f3-a124-4a50-98a7-e728a913c069",
|
182 |
+
"metadata": {
|
183 |
+
"execution": {
|
184 |
+
"iopub.execute_input": "2025-05-09T17:34:22.703336Z",
|
185 |
+
"iopub.status.busy": "2025-05-09T17:34:22.702980Z",
|
186 |
+
"iopub.status.idle": "2025-05-09T17:34:22.707253Z",
|
187 |
+
"shell.execute_reply": "2025-05-09T17:34:22.706781Z",
|
188 |
+
"shell.execute_reply.started": "2025-05-09T17:34:22.703313Z"
|
189 |
+
}
|
190 |
+
},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):\n",
|
194 |
+
" val_performance = np.mean(eval_dataset(val_set, eval_fn, model))\n",
|
195 |
+
" previous_performance = np.mean(results[\"validation_acc\"][-1])\n",
|
196 |
+
" print(\"val_performance: \", val_performance)\n",
|
197 |
+
" print(\"previous_performance: \", previous_performance)\n",
|
198 |
+
" previous_prompt = results[\"prompt\"][-1]\n",
|
199 |
+
" \n",
|
200 |
+
" if val_performance < previous_performance:\n",
|
201 |
+
" print(f\"rejected prompt: {system_prompt.value}\")\n",
|
202 |
+
" system_prompt.set_value(previous_prompt)\n",
|
203 |
+
" val_performance = previous_performance\n",
|
204 |
+
"\n",
|
205 |
+
" results[\"validation_acc\"].append(val_performance)"
|
206 |
+
]
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"execution_count": 14,
|
211 |
+
"id": "031ebb6e-f5ff-45b0-a810-d1bd81ef6d2a",
|
212 |
+
"metadata": {
|
213 |
+
"execution": {
|
214 |
+
"iopub.execute_input": "2025-05-09T17:40:38.476352Z",
|
215 |
+
"iopub.status.busy": "2025-05-09T17:40:38.475979Z",
|
216 |
+
"iopub.status.idle": "2025-05-09T17:40:38.701947Z",
|
217 |
+
"shell.execute_reply": "2025-05-09T17:40:38.701394Z",
|
218 |
+
"shell.execute_reply.started": "2025-05-09T17:40:38.476327Z"
|
219 |
+
}
|
220 |
+
},
|
221 |
+
"outputs": [
|
222 |
+
{
|
223 |
+
"name": "stdout",
|
224 |
+
"output_type": "stream",
|
225 |
+
"text": [
|
226 |
+
"Train/Val/Test Set Lengths: 50 100 100\n"
|
227 |
+
]
|
228 |
+
}
|
229 |
+
],
|
230 |
+
"source": [
|
231 |
+
"set_seed(12)\n",
|
232 |
+
"llm_api_eval = tg.get_engine(engine_name=\"gpt-4o\")\n",
|
233 |
+
"llm_api_test = tg.get_engine(engine_name=\"gpt-3.5-turbo-0125\")\n",
|
234 |
+
"tg.set_backward_engine(llm_api_eval, override=True)\n",
|
235 |
+
"\n",
|
236 |
+
"# Load the data and the evaluation function\n",
|
237 |
+
"train_set, val_set, test_set, eval_fn = load_task(\"BBH_object_counting\", evaluation_api=llm_api_eval)\n",
|
238 |
+
"print(\"Train/Val/Test Set Lengths: \", len(train_set), len(val_set), len(test_set))\n",
|
239 |
+
"STARTING_SYSTEM_PROMPT = train_set.get_task_description()"
|
240 |
+
]
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"execution_count": 15,
|
245 |
+
"id": "bde34303-2f52-415f-b117-264e266b84f0",
|
246 |
+
"metadata": {
|
247 |
+
"execution": {
|
248 |
+
"iopub.execute_input": "2025-05-09T17:40:39.330651Z",
|
249 |
+
"iopub.status.busy": "2025-05-09T17:40:39.330285Z",
|
250 |
+
"iopub.status.idle": "2025-05-09T17:40:39.398820Z",
|
251 |
+
"shell.execute_reply": "2025-05-09T17:40:39.398116Z",
|
252 |
+
"shell.execute_reply.started": "2025-05-09T17:40:39.330626Z"
|
253 |
+
}
|
254 |
+
},
|
255 |
+
"outputs": [
|
256 |
+
{
|
257 |
+
"name": "stderr",
|
258 |
+
"output_type": "stream",
|
259 |
+
"text": [
|
260 |
+
" 0%| | 0/100 [00:00<?, ?it/s]\n"
|
261 |
+
]
|
262 |
+
},
|
263 |
+
{
|
264 |
+
"ename": "AssertionError",
|
265 |
+
"evalue": "Value must be a string, int, or image (bytes). Got: <class 'numpy.int64'>",
|
266 |
+
"output_type": "error",
|
267 |
+
"traceback": [
|
268 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
269 |
+
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
|
270 |
+
"Cell \u001b[0;32mIn[15], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mTextualGradientDescent(engine\u001b[38;5;241m=\u001b[39mllm_api_eval, parameters\u001b[38;5;241m=\u001b[39m[system_prompt])\n\u001b[1;32m 17\u001b[0m results \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n\u001b[0;32m---> 18\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(\u001b[43meval_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(eval_dataset(val_set, eval_fn, model))\n\u001b[1;32m 20\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(system_prompt\u001b[38;5;241m.\u001b[39mget_value())\n",
|
271 |
+
"Cell \u001b[0;32mIn[7], line 15\u001b[0m, in \u001b[0;36meval_dataset\u001b[0;34m(test_set, eval_fn, model, max_samples)\u001b[0m\n\u001b[1;32m 13\u001b[0m tqdm_loader \u001b[38;5;241m=\u001b[39m tqdm(concurrent\u001b[38;5;241m.\u001b[39mfutures\u001b[38;5;241m.\u001b[39mas_completed(futures), total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(futures), position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m future \u001b[38;5;129;01min\u001b[39;00m tqdm_loader:\n\u001b[0;32m---> 15\u001b[0m acc_item \u001b[38;5;241m=\u001b[39m \u001b[43mfuture\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m accuracy_list\u001b[38;5;241m.\u001b[39mappend(acc_item)\n\u001b[1;32m 17\u001b[0m tqdm_loader\u001b[38;5;241m.\u001b[39mset_description(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAccuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(accuracy_list)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
272 |
+
"File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:451\u001b[0m, in \u001b[0;36mFuture.result\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;241m==\u001b[39m FINISHED:\n\u001b[0;32m--> 451\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition\u001b[38;5;241m.\u001b[39mwait(timeout)\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
|
273 |
+
"File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:403\u001b[0m, in \u001b[0;36mFuture.__get_result\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception:\n\u001b[1;32m 402\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 403\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 405\u001b[0m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[1;32m 406\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
274 |
+
"File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/thread.py:58\u001b[0m, in \u001b[0;36m_WorkItem.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuture\u001b[38;5;241m.\u001b[39mset_exception(exc)\n",
|
275 |
+
"Cell \u001b[0;32mIn[6], line 8\u001b[0m, in \u001b[0;36meval_sample\u001b[0;34m(item, eval_fn, model)\u001b[0m\n\u001b[1;32m 6\u001b[0m x, y \u001b[38;5;241m=\u001b[39m item\n\u001b[1;32m 7\u001b[0m x \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mVariable(x, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, role_description\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery to the language model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mtg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mVariable\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequires_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrole_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcorrect answer for the query\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m response \u001b[38;5;241m=\u001b[39m model(x)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
|
276 |
+
"File \u001b[0;32m~/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/textgrad/variable.py:43\u001b[0m, in \u001b[0;36mVariable.__init__\u001b[0;34m(self, value, image_path, predecessors, requires_grad, role_description)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m requires_grad) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(_predecessor_requires_grad) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf the variable does not require grad, none of its predecessors should require grad.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn this case, following predecessors require grad: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m_predecessor_requires_grad\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(value) \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbytes\u001b[39m, \u001b[38;5;28mint\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValue must be a string, int, or image (bytes). Got: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(value))\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 45\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(value)\n",
|
277 |
+
"\u001b[0;31mAssertionError\u001b[0m: Value must be a string, int, or image (bytes). Got: <class 'numpy.int64'>"
|
278 |
+
]
|
279 |
+
}
|
280 |
+
],
|
281 |
+
"source": [
|
282 |
+
"train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)\n",
|
283 |
+
"\n",
|
284 |
+
"\n",
|
285 |
+
"# Testing the 0-shot performance of the evaluation engine\n",
|
286 |
+
"system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
|
287 |
+
" requires_grad=True, \n",
|
288 |
+
" role_description=\"system prompt to the language model\")\n",
|
289 |
+
"model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)\n",
|
290 |
+
"\n",
|
291 |
+
"system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
|
292 |
+
" requires_grad=True,\n",
|
293 |
+
" role_description=\"structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task\")\n",
|
294 |
+
"model = tg.BlackboxLLM(llm_api_test, system_prompt)\n",
|
295 |
+
"\n",
|
296 |
+
"optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt])\n",
|
297 |
+
"\n",
|
298 |
+
"results = {\"test_acc\": [], \"prompt\": [], \"validation_acc\": []}\n",
|
299 |
+
"results[\"test_acc\"].append(eval_dataset(test_set, eval_fn, model))\n",
|
300 |
+
"results[\"validation_acc\"].append(eval_dataset(val_set, eval_fn, model))\n",
|
301 |
+
"results[\"prompt\"].append(system_prompt.get_value())"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "code",
|
306 |
+
"execution_count": null,
|
307 |
+
"id": "47c15231-22ff-459b-b5cc-ca32aaa62332",
|
308 |
+
"metadata": {},
|
309 |
+
"outputs": [],
|
310 |
+
"source": [
|
311 |
+
"for epoch in range(3):\n",
|
312 |
+
" for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):\n",
|
313 |
+
" pbar.set_description(f\"Training step {steps}. Epoch {epoch}\")\n",
|
314 |
+
" optimizer.zero_grad()\n",
|
315 |
+
" losses = []\n",
|
316 |
+
" for (x, y) in zip(batch_x, batch_y):\n",
|
317 |
+
" x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
|
318 |
+
" y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
|
319 |
+
" response = model(x)\n",
|
320 |
+
" try:\n",
|
321 |
+
" eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
|
322 |
+
" except:\n",
|
323 |
+
" eval_output_variable = eval_fn([x, y, response])\n",
|
324 |
+
" losses.append(eval_output_variable)\n",
|
325 |
+
" total_loss = tg.sum(losses)\n",
|
326 |
+
" total_loss.backward()\n",
|
327 |
+
" optimizer.step()\n",
|
328 |
+
" \n",
|
329 |
+
" run_validation_revert(system_prompt, results, model, eval_fn, val_set)\n",
|
330 |
+
" \n",
|
331 |
+
" print(\"sys prompt: \", system_prompt)\n",
|
332 |
+
" test_acc = eval_dataset(test_set, eval_fn, model)\n",
|
333 |
+
" results[\"test_acc\"].append(test_acc)\n",
|
334 |
+
" results[\"prompt\"].append(system_prompt.get_value())\n",
|
335 |
+
" if steps == 3:\n",
|
336 |
+
" break"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": null,
|
342 |
+
"id": "3c5e93f5-8d1c-4b87-a6d1-811714982d47",
|
343 |
+
"metadata": {},
|
344 |
+
"outputs": [],
|
345 |
+
"source": []
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "code",
|
349 |
+
"execution_count": null,
|
350 |
+
"id": "67a4583f-162c-4e2d-b061-798f6c676a28",
|
351 |
+
"metadata": {},
|
352 |
+
"outputs": [],
|
353 |
+
"source": [
|
354 |
+
"class TranslationQualityAssessor(dspy.Module):\n",
|
355 |
+
" def __init__(self):\n",
|
356 |
+
" super().__init__()\n",
|
357 |
+
" self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n",
|
358 |
+
"\n",
|
359 |
+
" def forward(self, src_lang, tgt_lang, src_text, translation, src_prev=\"\", tgt_prev=\"\", src_next=\"\", tgt_next=\"\"):\n",
|
360 |
+
" context = f\"\"\"Previous Context:\n",
|
361 |
+
" Source: {src_prev}\n",
|
362 |
+
" Translation: {tgt_prev}\n",
|
363 |
+
" \n",
|
364 |
+
" Next Context:\n",
|
365 |
+
" Source: {src_next}\n",
|
366 |
+
" Translation: {tgt_next}\"\"\"\n",
|
367 |
+
"\n",
|
368 |
+
" result = self.assess(\n",
|
369 |
+
" context=context,\n",
|
370 |
+
" source=f\"Source ({src_lang}): {src_text}\",\n",
|
371 |
+
" translation=f\"Translation ({tgt_lang}): {translation}\"\n",
|
372 |
+
" )\n",
|
373 |
+
" \n",
|
374 |
+
" return result.evaluation\n",
|
375 |
+
"\n",
|
376 |
+
"class TranslationMetrics:\n",
|
377 |
+
" @staticmethod\n",
|
378 |
+
" def exact_match_score(pred, gold):\n",
|
379 |
+
" try:\n",
|
380 |
+
" pred_json = json.loads(pred)\n",
|
381 |
+
" gold_json = gold\n",
|
382 |
+
" \n",
|
383 |
+
" accuracy_match = (str(pred_json.get('Accuracy Score')) == str(gold_json.get('Accuracy Score')))\n",
|
384 |
+
" readability_match = (str(pred_json.get('Readability Score')) == str(gold_json.get('Readability Score')))\n",
|
385 |
+
" \n",
|
386 |
+
" return (accuracy_match and readability_match)\n",
|
387 |
+
" except:\n",
|
388 |
+
" return False\n",
|
389 |
+
" \n",
|
390 |
+
" @staticmethod\n",
|
391 |
+
" def partial_match_score(pred, gold):\n",
|
392 |
+
" try:\n",
|
393 |
+
" pred_json = json.loads(pred)\n",
|
394 |
+
" gold_json = gold\n",
|
395 |
+
" \n",
|
396 |
+
" # Score comparison\n",
|
397 |
+
" accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))\n",
|
398 |
+
" readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))\n",
|
399 |
+
" \n",
|
400 |
+
" # Issues comparison\n",
|
401 |
+
" pred_accuracy_issues = set(str(issue) for issue in pred_json.get('Accuracy Issues', []))\n",
|
402 |
+
" gold_accuracy_issues = set(str(issue) for issue in gold_json.get('Accuracy Issues', []))\n",
|
403 |
+
" pred_readability_issues = set(str(issue) for issue in pred_json.get('Readability Issues', []))\n",
|
404 |
+
" gold_readability_issues = set(str(issue) for issue in gold_json.get('Readability Issues', []))\n",
|
405 |
+
" \n",
|
406 |
+
" # Calculate Jaccard similarity for issues\n",
|
407 |
+
" accuracy_issues_sim = len(pred_accuracy_issues & gold_accuracy_issues) / max(1, len(pred_accuracy_issues | gold_accuracy_issues))\n",
|
408 |
+
" readability_issues_sim = len(pred_readability_issues & gold_readability_issues) / max(1, len(pred_readability_issues | gold_readability_issues))\n",
|
409 |
+
" \n",
|
410 |
+
" # Combine scores (0.6 weight to scores, 0.4 to issues similarity)\n",
|
411 |
+
" score_component = 1 - ((accuracy_diff + readability_diff) / 8)\n",
|
412 |
+
" issues_component = (accuracy_issues_sim + readability_issues_sim) / 2\n",
|
413 |
+
" \n",
|
414 |
+
" final_score = 0.6 * score_component + 0.4 * issues_component\n",
|
415 |
+
" return max(0, final_score)\n",
|
416 |
+
" except:\n",
|
417 |
+
" return 0\n",
|
418 |
+
"\n",
|
419 |
+
"def prepare_dataset(file_path):\n",
|
420 |
+
" with open(file_path, 'r') as f:\n",
|
421 |
+
" data = json.load(f)\n",
|
422 |
+
" \n",
|
423 |
+
" prepared_data = []\n",
|
424 |
+
" \n",
|
425 |
+
" for item in data:\n",
|
426 |
+
" example = dspy.Example(\n",
|
427 |
+
" context=f\"\"\"Previous Context:\n",
|
428 |
+
" Source: {item['src_prev']}\n",
|
429 |
+
" Translation: {item['tgt_prev']}\n",
|
430 |
+
" \n",
|
431 |
+
" Next Context:\n",
|
432 |
+
" Source: {item['src_next']}\n",
|
433 |
+
" Translation: {item['tgt_next']}\"\"\",\n",
|
434 |
+
" source=f\"Source ({item['src_lang']}): {item['src_text']}\",\n",
|
435 |
+
" translation=f\"Translation ({item['tgt_lang']}): {item['main_text']}\",\n",
|
436 |
+
" evaluation=json.dumps(item['evaluation'], ensure_ascii=False)\n",
|
437 |
+
" ).with_inputs(\"context\", \"source\", \"translation\")\n",
|
438 |
+
" \n",
|
439 |
+
" prepared_data.append(example)\n",
|
440 |
+
" \n",
|
441 |
+
" # Split data: 70% train, 15% dev, 15% test\n",
|
442 |
+
" train_size = int(0.7 * len(prepared_data))\n",
|
443 |
+
" dev_size = int(0.15 * len(prepared_data))\n",
|
444 |
+
" \n",
|
445 |
+
" train_data = prepared_data[:train_size]\n",
|
446 |
+
" dev_data = prepared_data[train_size:train_size + dev_size]\n",
|
447 |
+
" test_data = prepared_data[train_size + dev_size:]\n",
|
448 |
+
" \n",
|
449 |
+
" return train_data, dev_data, test_data\n",
|
450 |
+
"\n",
|
451 |
+
"def optimize_translation_quality_assessment():\n",
|
452 |
+
" # Initialize DSPy\n",
|
453 |
+
" lm = TranslationQualityLM()\n",
|
454 |
+
" dspy.settings.configure(lm=lm)\n",
|
455 |
+
" \n",
|
456 |
+
" # Load and prepare dataset\n",
|
457 |
+
" train_data, dev_data, test_data = prepare_dataset('translation_quality_dataset.json')\n",
|
458 |
+
" \n",
|
459 |
+
" # Create evaluator\n",
|
460 |
+
" evaluator = Evaluate(\n",
|
461 |
+
" metrics={\n",
|
462 |
+
" 'exact_match': TranslationMetrics.exact_match_score,\n",
|
463 |
+
" 'partial_match': TranslationMetrics.partial_match_score\n",
|
464 |
+
" }\n",
|
465 |
+
" )\n",
|
466 |
+
" \n",
|
467 |
+
" # Initialize module\n",
|
468 |
+
" assessor = TranslationQualityAssessor()\n",
|
469 |
+
" \n",
|
470 |
+
" # Initialize MIPROv2 optimizer\n",
|
471 |
+
" optimizer = dspy.MIPROv2(\n",
|
472 |
+
" metric=lambda x: x['partial_match'],\n",
|
473 |
+
" max_rounds=5, # Number of optimization rounds\n",
|
474 |
+
" max_traces=10, # Number of traces per round\n",
|
475 |
+
" max_depth=3, # Maximum depth of reasoning chains\n",
|
476 |
+
" num_candidate_prompts=5, # Number of candidate prompts to generate\n",
|
477 |
+
" num_rounds_per_prompt=3, # Number of rounds per candidate prompt\n",
|
478 |
+
" temperature=0.7,\n",
|
479 |
+
" verbose=True\n",
|
480 |
+
" )\n",
|
481 |
+
" \n",
|
482 |
+
" # Compile the module with optimization\n",
|
483 |
+
" compiled_assessor = optimizer.compile(\n",
|
484 |
+
" assessor,\n",
|
485 |
+
" trainset=train_data,\n",
|
486 |
+
" devset=dev_data,\n",
|
487 |
+
" eval_kwargs={\n",
|
488 |
+
" 'metric': 'partial_match',\n",
|
489 |
+
" 'num_threads': 4,\n",
|
490 |
+
" 'batch_size': 8\n",
|
491 |
+
" }\n",
|
492 |
+
" )\n",
|
493 |
+
" \n",
|
494 |
+
" # Evaluate on test set\n",
|
495 |
+
" results = []\n",
|
496 |
+
" for example in test_data:\n",
|
497 |
+
" pred = compiled_assessor(\n",
|
498 |
+
" context=example.context,\n",
|
499 |
+
" source=example.source,\n",
|
500 |
+
" translation=example.translation\n",
|
501 |
+
" )\n",
|
502 |
+
" \n",
|
503 |
+
" result = evaluator.evaluate(\n",
|
504 |
+
" predictions=[pred],\n",
|
505 |
+
" ground_truth=[example.evaluation]\n",
|
506 |
+
" )\n",
|
507 |
+
" results.append(result)\n",
|
508 |
+
" \n",
|
509 |
+
" # Calculate and print final metrics\n",
|
510 |
+
" avg_exact_match = np.mean([r['exact_match'] for r in results])\n",
|
511 |
+
" avg_partial_match = np.mean([r['partial_match'] for r in results])\n",
|
512 |
+
" \n",
|
513 |
+
" print(f\"Average Exact Match Score: {avg_exact_match:.3f}\")\n",
|
514 |
+
" print(f\"Average Partial Match Score: {avg_partial_match:.3f}\")\n",
|
515 |
+
" \n",
|
516 |
+
" return compiled_assessor\n",
|
517 |
+
"\n",
|
518 |
+
"if __name__ == \"__main__\":\n",
|
519 |
+
" optimized_assessor = optimize_translation_quality_assessment()"
|
520 |
+
]
|
521 |
+
}
|
522 |
+
],
|
523 |
+
"metadata": {
|
524 |
+
"kernelspec": {
|
525 |
+
"display_name": "timedlibs",
|
526 |
+
"language": "python",
|
527 |
+
"name": "timedlibs"
|
528 |
+
},
|
529 |
+
"language_info": {
|
530 |
+
"codemirror_mode": {
|
531 |
+
"name": "ipython",
|
532 |
+
"version": 3
|
533 |
+
},
|
534 |
+
"file_extension": ".py",
|
535 |
+
"mimetype": "text/x-python",
|
536 |
+
"name": "python",
|
537 |
+
"nbconvert_exporter": "python",
|
538 |
+
"pygments_lexer": "ipython3",
|
539 |
+
"version": "3.10.16"
|
540 |
+
}
|
541 |
+
},
|
542 |
+
"nbformat": 4,
|
543 |
+
"nbformat_minor": 5
|
544 |
+
}
|
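Note: the final executed cell above (`execution_count: 15`) aborts with the `AssertionError` recorded in its output, because `load_task` hands back labels as `numpy.int64` while `textgrad.Variable` only accepts `str`, `int`, or `bytes` (see the assertion in `variable.py` in the traceback). A minimal sketch of one possible fix, casting labels to built-in types before wrapping them; the helper name is ours, not part of the notebook:

```python
import numpy as np
import textgrad as tg

def to_variable_value(y):
    # Coerce numpy scalars to built-in types accepted by tg.Variable.
    if isinstance(y, np.integer):
        return int(y)
    return y if isinstance(y, (str, bytes, int)) else str(y)

# e.g. inside eval_sample, replacing the failing line 8:
# y = tg.Variable(to_variable_value(y), requires_grad=False,
#                 role_description="correct answer for the query")
```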
adapter_config.json
ADDED
@@ -0,0 +1,45 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "google/gemma-3-4b-it",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 128,
  "lora_bias": false,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "lm_head",
    "embed_tokens"
  ],
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "out_proj",
    "v_proj",
    "k_proj",
    "fc1",
    "down_proj",
    "up_proj",
    "fc2",
    "o_proj",
    "q_proj",
    "gate_proj"
  ],
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_rslora": false
}
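For reference, the LoRA settings above map onto a PEFT `LoraConfig` roughly as follows. This is a sketch reconstructed from the JSON, not the original training code; fields not listed are left at their PEFT defaults:

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=128,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj", "out_proj",
        "gate_proj", "up_proj", "down_proj", "fc1", "fc2",
    ],
    modules_to_save=["lm_head", "embed_tokens"],  # trained in full, not via LoRA
)
```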
adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:913b7aae1196ba3282bd45e04a65bf705d67f933d758540a06902f63054ad6e7
size 2839124552
added_tokens.json
ADDED
@@ -0,0 +1,3 @@
{
  "<image_soft_token>": 262144
}
data_prep.py
ADDED
@@ -0,0 +1,380 @@
import typing as T
import os
import sys
import argparse
import json
import nflx_copilot as ncp
import pandas as pd
import re

sys.path.append("/root/workspace")

from timedtext.adapters.translation.generation.pldl import TimedTextAdapter, ConverterDialogContext
from timedtext.manager import TimedTextManager
from timedtext.handlers import OriginalLanguagePivotLanguageHandler, EnglishTemplateSubtitleHandler
from timedprompts.evaluation.pldl_prompt_one.prompt import (
    ReferenceFreeFeedbackTransform,
    ContextFreeFeedbackTransform,
    ReferenceFreeDirectTransform,
    ReferenceBasedFeedbackTransform,
    ReferenceFreeExampleTransform,
)
from tqdm import tqdm
from timedtune.convert.tq_for_pldl.pldl_train_one import PldlTrainOneReferenceFreeTransform
from timedtext.adapters.translation.evaluation import compute_score_delta

def compute_32_point_score(response, generation):
    parsed, score = {}, -1
    try:
        score = (
            int(response["Accuracy Score"])
            + int(response["Readability Score"])
            + compute_score_delta(response, "Accuracy Issues", generation)
            + compute_score_delta(response, "Readability Issues", generation)
        )
        score = score * 4
    except Exception:
        score = -1
    return parsed, score

# Your existing TimedTextAdapter and helper classes
class TimedTextAdapterFromCache_PLDL(TimedTextAdapter):
    def __init__(
        self,
        data_dir: str,
        cache_size: int = 0,
        ol_dialog_list_version: str = "",
        pl_dialog_list_version: str = "",
        ol_dialog_list_pl_dialog_list_version: str = "",
        num_prev_events: int = 16,
        num_next_events: int = 16,
    ) -> None:
        super().__init__(num_prev_events, num_next_events)
        self.timed_text_manager = TimedTextManager(
            data_dir,
            cache_size=cache_size,
            ol_dialog_list_version=ol_dialog_list_version,
            pl_dialog_list_version=pl_dialog_list_version,
            ol_dialog_list_pl_dialog_list_version=ol_dialog_list_pl_dialog_list_version,
        )

    def _get_timed_text(
        self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
    ) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
        results = self.timed_text_manager.match_and_get_timed_text(
            handler_class=OriginalLanguagePivotLanguageHandler,
            movie_id=movie_id,
            start_frame=start_frame,
            end_frame=end_frame,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            mid_lang="",
            **self.timed_text_kwargs,
        )

        curr_srcs = [result["curr"]["src"]["txt"] for result in results]
        curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]

        return {
            "curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
            "prev": results[0]["prev"],
            "next": results[-1]["next"],
        }

class TimedTextAdapterFromCache_SUBS(TimedTextAdapter):
    def __init__(
        self,
        data_dir: str,
        cache_size: int = 0,
        num_prev_events: int = 16,
        num_next_events: int = 16,
    ) -> None:
        super().__init__(num_prev_events, num_next_events)
        self.timed_text_manager = TimedTextManager(
            data_dir,
            cache_size=cache_size,
        )

    def _get_timed_text(
        self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
    ) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
        results = self.timed_text_manager.match_and_get_timed_text(
            handler_class=EnglishTemplateSubtitleHandler,
            movie_id=movie_id,
            start_frame=start_frame,
            end_frame=end_frame,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            mid_lang="",
            **self.timed_text_kwargs,
        )

        curr_srcs = [result["curr"]["src"]["txt"] for result in results]
        curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]

        return {
            "curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
            "prev": results[0]["prev"],
            "next": results[-1]["next"],
        }


# Function to fetch contextual information using TimedTextAdapter
def fetch_contextual_information(timed_text_adapter, row):
    """
    Fetches the required context information for each sample using timed_text_adapter.

    Args:
        timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch data from.
        row (dict): Row containing the necessary information to fetch the context.

    Returns:
        dict: Contextual information containing tt_src_text, tt_tgt_text, tt_src_prev, tt_src_next, tt_tgt_prev, tt_tgt_next.
    """
    # Fetching the actual translation context
    src_text, tgt_text, prev_context, next_context = timed_text_adapter.get_timed_text(
        movie_id=row["movie_id"],
        start_frame=row["start_frame"],
        end_frame=row["end_frame"],
        src_lang=row["src_lang"],
        tgt_lang=row["tgt_lang"],
    )

    timed_text_converter = ConverterDialogContext(timed_text_adapter)

    # Converting context to the format expected by the prompt
    src_prev, src_next, tgt_prev, tgt_next, _ = timed_text_converter.__context__(
        row["src_lang"], row["tgt_lang"], prev_context, next_context, None
    )

    return {
        "tt_src_text": src_text,
        "tt_tgt_text": tgt_text,
        "tt_src_prev": src_prev,
        "tt_src_next": src_next,
        "tt_tgt_prev": tgt_prev,
        "tt_tgt_next": tgt_next,
    }

def transform_json(input_json):
    # Get the first project key
    project_key = list(input_json['projects'].keys())[0]
    project = input_json['projects'][project_key]

    final_output = {"labelers": []}
    # Process each label
    for label in project['labels']:
        # Initialize output structure
        output = {
            "annotation": {
                "Accuracy Issues": [],
                "Readability Issues": [],
                "Accuracy Score": "",
                "Readability Score": "",
                "Confidence Level": "",
                "Main Vs Alternate": "",
                "Score": "-1"  # initialized to -1, updated in the next steps
            },
        }
        # Process annotations/objects (issues)
        if 'objects' in label['annotations']:
            for obj in label['annotations']['objects']:
                issue = {
                    "Error Location": obj['conversational_location']['message_id'],
                    "Error Span": [
                        obj['conversational_location']['location']['start'],
                        obj['conversational_location']['location']['end']
                    ],
                    "Error Explanation": "",
                    "Error Quality Category": obj['name'],
                    "Error Quality Tags": [],
                    "Error Severity": ""
                }

                # Process classifications within object
                for classification in obj['classifications']:
                    if classification['name'] == 'Explanation':
                        issue["Error Explanation"] = classification['text_answer']['content']
                    elif classification['name'] == 'Quality Tag':
                        issue["Error Quality Tags"] = [ans['name'].lower() for ans in classification['checklist_answers']]
                    elif classification['name'] == 'Quality SubCategory':
                        severity = classification['radio_answer']['name']
                        if 'Major' in severity:
                            issue["Error Severity"] = "Major"
                        else:
                            issue["Error Severity"] = "Minor"

                # Add to appropriate issues list
                if obj['name'] == 'Style':
                    output['annotation']['Readability Issues'].append(issue)
                else:
                    output['annotation']['Accuracy Issues'].append(issue)

        # Process classifications
        for classification in label['annotations']['classifications']:
            if classification['name'] == 'Accuracy Score':
                output['annotation']['Accuracy Score'] = classification['radio_answer']['name'].split(' - ')[0]
            elif classification['name'] == 'Readability Score':
                output['annotation']['Readability Score'] = classification['radio_answer']['name'].split(' - ')[0]
            elif classification['name'] == 'Confidence Level':
                output['annotation']['Confidence Level'] = classification['radio_answer']['value']
            elif classification['name'] == 'Main vs Alternate':
                output['annotation']['Main Vs Alternate'] = classification['radio_answer']['name']
        final_output["labelers"].append(output)
    return final_output

# Function to load the relevant meta json for a given key
def load_meta_json(priority_key, data_row_key, meta_path):
    """
    Loads and validates metadata json from the specified path based on the priority key and data row key.

    Args:
        priority_key (str): Priority key from the label metadata.
        data_row_key (str): Data row key to find the relevant file.
        meta_path (str): Path to the metadata folder.

    Returns:
        dict: Loaded metadata.
    """
    with open(os.path.join(meta_path, f'{priority_key}.json')) as fread:
        meta_dict = json.load(fread)

    _, movie_id, start_end_frame, _, _, _, _ = data_row_key.split('.')
    start_frame, end_frame = start_end_frame.split('_')

    if int(meta_dict['movie_id']) != int(movie_id):
        print("Movie Ids didn't match:", int(meta_dict['movie_id']), int(movie_id), os.path.join(meta_path, f'{priority_key}.json'), data_row_key)
        sys.exit(1)
    assert int(meta_dict['start_frame']) == int(start_frame)
    assert int(meta_dict['end_frame']) == int(end_frame)

    return meta_dict

# Main function that processes the data
def process_json(timed_text_adapter, example_row, meta_path, conv_path):
    """
    Takes the full input json, converts it to the required format, and adds context using metadata.

    Args:
        timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch context.
        example_row (dict): The full input JSON (like the example_row you provided).
        meta_path (str): Path to the metadata folder to fetch meta json.
        conv_path (str): Path to the conversation folder with the main/source/alternate texts.

    Returns:
        dict: The enriched annotation format with context and annotation data.
    """
    # Step 1: Convert the full input JSON to the required annotation format
    annotation_result = transform_json(example_row)

    # Extracting the necessary data_row_key and priority_key
    data_row_key = example_row['data_row']['global_key']
    priority_key = example_row['projects'][list(example_row["projects"].keys())[0]]['project_details']['priority']

    annotation_result["Data_Row_Key"] = data_row_key
    key = ".".join(data_row_key.split(".")[:3])
    with open(conv_path + "/" + key + ".json") as file:
        data = json.load(file)
    annotation_result["main_tgt_text"] = data["messages"][0]["content"]
    annotation_result["src_text"] = data["messages"][1]["content"]
    annotation_result["alt_tgt_text"] = data["messages"][2]["content"]

    # Load the metadata using the keys from the json
    meta_dict = load_meta_json(priority_key, data_row_key, meta_path)

    # Step 2: Add the metadata fields (e.g., title_id, start_frame, end_frame, src_lang, tgt_lang)
    annotation_result.update({
        "title_id": meta_dict['movie_id'],
        "start_frame": meta_dict['start_frame'],
        "end_frame": meta_dict['end_frame'],
        "src_lang": meta_dict['src_lang'],
        "tgt_lang": meta_dict['tgt_lang'],
    })

    # Step 3: Fetch contextual information using the given timed_text_adapter
    context_info = fetch_contextual_information(timed_text_adapter, meta_dict)

    annotation_result.update(context_info)

    # Update error spans with actual text for each labeler
    for labeler in annotation_result["labelers"]:
        # Process Accuracy Issues
        for issue in labeler["annotation"]["Accuracy Issues"]:
            error_location = issue["Error Location"]
            start, end = issue["Error Span"][0], issue["Error Span"][1]

            # Get the actual text based on error location
            if error_location == "src":
                actual_text = annotation_result["src_text"][start:end]
            else:  # tgt
                actual_text = annotation_result["main_tgt_text"][start:end]

            # Update the error span with actual text
            issue["Error Span"] = actual_text

        # Process Readability Issues
        for issue in labeler["annotation"]["Readability Issues"]:
            error_location = issue["Error Location"]
            start, end = issue["Error Span"]

            # Get the actual text based on error location
            if error_location == "src":
                actual_text = annotation_result["src_text"][start:end]
            else:  # tgt
                actual_text = annotation_result["main_tgt_text"][start:end]

            # Update the error span with actual text
            issue["Error Span"] = actual_text

    return annotation_result

# Example usage
def main():
    base_path = "MT_TQ/Caches/May2025/tquality.annotated.data/"
    json_files = [base_path + "raw/" + f for f in os.listdir(base_path + "raw/") if f.endswith('.json')]

    for json_file in tqdm(json_files):
        if "calibration" in json_file:
            print("Warning: Skipping Calibration Data, Remove this if you want to use Calibration data")
            continue

        if "PLDL" in json_file:
            folder = "pldl"
            timed_text_adapter = TimedTextAdapterFromCache_PLDL(
                data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
            )
        elif "SUBS" in json_file:
            folder = "subs"
            timed_text_adapter = TimedTextAdapterFromCache_SUBS(
                data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
            )
        else:
            # Neither PLDL nor SUBS: fail loudly instead of continuing with no adapter
            raise ValueError(f"invalid json file: {json_file}")

        langs_type = json_file.split("/")[-1].split("-")[1].replace("_", ".")
        phase = json_file.split("/")[-1].split("-")[3]
        phase_number = int(''.join(re.findall(r'\d+', phase))) if re.findall(r'\d+', phase) else None
        phase_date = json_file.split("/")[-1].split("-")[4].replace(".json", "")

        zzmetapath = f"/root/notebooks/MT_TQ/Caches/labelspace/tquality.zzmeta.data/{folder}/{langs_type}/phase {phase_number} - {phase_date}"

        meta_path = zzmetapath + "/meta"
        conv_path = zzmetapath + "/conv"

        with open(json_file) as file:
            data = json.load(file)

        output_data = []
        for data_point in tqdm(data):
            annotation_result = process_json(timed_text_adapter, data_point, meta_path, conv_path)
            for labeler in annotation_result["labelers"]:
                _, score = compute_32_point_score(labeler["annotation"], annotation_result["main_tgt_text"])
                labeler["annotation"]["Score"] = score

            output_data.append(annotation_result)

        with open(base_path + "parsed/" + json_file.split("/")[-1], 'w') as fout:
            json.dump({"data": output_data}, fout, indent=4)

if __name__ == "__main__":
    main()
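As a quick sanity check on `compute_32_point_score` above: the two scores and the two issue deltas are summed and multiplied by 4. Assuming both scores max out at 4 and the `compute_score_delta` terms (from the internal `timedtext` package, stubbed to 0 here as an assumption) contribute nothing, the result is the full 32 points:

```python
# (4 + 4 + 0 + 0) * 4 == 32 -- hence the "32-point" name; parse failures return -1.
response = {"Accuracy Score": "4", "Readability Score": "4"}
score = (int(response["Accuracy Score"]) + int(response["Readability Score"]) + 0 + 0) * 4
assert score == 32
```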
gemma-12b-tq-model/README.md
ADDED
@@ -0,0 +1,58 @@
---
base_model: google/gemma-3-4b-it
library_name: transformers
model_name: gemma-12b-tq-model
tags:
- generated_from_trainer
- trl
- sft
licence: license
---

# Model Card for gemma-12b-tq-model

This model is a fine-tuned version of [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it).
It has been trained using [TRL](https://github.com/huggingface/trl).

## Quick start

```python
from transformers import pipeline

question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
generator = pipeline("text-generation", model="bhavinjawade/gemma-12b-tq-model", device="cuda")
output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
print(output["generated_text"])
```

## Training procedure

This model was trained with SFT.

### Framework versions

- TRL: 0.16.1
- Transformers: 4.50.0.dev0
- Pytorch: 2.7.0
- Datasets: 3.3.2
- Tokenizers: 0.21.0

## Citations

Cite TRL as:

```bibtex
@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}
```
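Since this repo ships a LoRA adapter (`adapter_model.safetensors`) rather than merged weights, the pipeline snippet in the Quick start above may require the adapter to be attached to the base model first. A hedged sketch with PEFT, assuming the adapter is loadable directly from the hub id used in the Quick start:

```python
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Loads google/gemma-3-4b-it (per adapter_config.json) and applies the LoRA adapter on top.
model = AutoPeftModelForCausalLM.from_pretrained("bhavinjawade/gemma-12b-tq-model", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("bhavinjawade/gemma-12b-tq-model")
```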
gemma-12b-tq-model/adapter_config.json
ADDED
@@ -0,0 +1,45 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "google/gemma-3-4b-it",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 128,
  "lora_bias": false,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "lm_head",
    "embed_tokens"
  ],
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "up_proj",
    "gate_proj",
    "fc2",
    "out_proj",
    "fc1",
    "down_proj",
    "o_proj",
    "k_proj",
    "q_proj",
    "v_proj"
  ],
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_rslora": false
}
gemma-12b-tq-model/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44481365420dd4297edab0bc7d76dc43d4e6d7f38e393cce87c2fabdbea96661
size 2839124552
gemma-12b-tq-model/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
{
  "<image_soft_token>": 262144
}
gemma-12b-tq-model/checkpoint-2/README.md
ADDED
@@ -0,0 +1,202 @@
---
base_model: google/gemma-3-4b-it
library_name: peft
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->



- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
### Framework versions

- PEFT 0.15.2
gemma-12b-tq-model/checkpoint-2/adapter_config.json
ADDED
@@ -0,0 +1,45 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": null,
+  "base_model_name_or_path": "google/gemma-3-4b-it",
+  "bias": "none",
+  "corda_config": null,
+  "eva_config": null,
+  "exclude_modules": null,
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layer_replication": null,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 128,
+  "lora_bias": false,
+  "lora_dropout": 0.05,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": [
+    "lm_head",
+    "embed_tokens"
+  ],
+  "peft_type": "LORA",
+  "r": 16,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "up_proj",
+    "gate_proj",
+    "fc2",
+    "out_proj",
+    "fc1",
+    "down_proj",
+    "o_proj",
+    "k_proj",
+    "q_proj",
+    "v_proj"
+  ],
+  "task_type": "CAUSAL_LM",
+  "trainable_token_indices": null,
+  "use_dora": false,
+  "use_rslora": false
+}
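This adapter config describes a rank-16 LoRA (lora_alpha 128, dropout 0.05) over ten projection modules, with full copies of lm_head and embed_tokens saved via modules_to_save; note that it records google/gemma-3-4b-it as the base model despite the gemma-12b directory name. A minimal loading sketch, assuming a local clone of this repo (the path is illustrative, not part of this commit):

```python
# Sketch: load the checkpoint-2 LoRA adapter with PEFT.
# The local path is an assumption; it is not part of this commit.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

checkpoint = "gemma-12b-tq-model/checkpoint-2"

# Reads adapter_config.json, fetches the base model it names
# (google/gemma-3-4b-it), and attaches the LoRA weights on top.
model = AutoPeftModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
```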
gemma-12b-tq-model/checkpoint-2/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44481365420dd4297edab0bc7d76dc43d4e6d7f38e393cce87c2fabdbea96661
+size 2839124552
gemma-12b-tq-model/checkpoint-2/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
+{
+  "<image_soft_token>": 262144
+}
gemma-12b-tq-model/checkpoint-2/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:898abdc006b804547b999529ece7d0ca106ab09c0a2352337e1809b5041573ee
+size 5608850589
gemma-12b-tq-model/checkpoint-2/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:250560ab3d528161ab3659b120def6e4a9ab4b457e3399603bbcfa40db3efc90
+size 14645
gemma-12b-tq-model/checkpoint-2/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29847f084360f67920e16e2780978c5b4908b1a69433f50a755d4db1e0c11563
+size 1401
gemma-12b-tq-model/checkpoint-2/special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
+{
+  "boi_token": "<start_of_image>",
+  "bos_token": {
+    "content": "<bos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eoi_token": "<end_of_image>",
+  "eos_token": {
+    "content": "<eos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "image_token": "<image_soft_token>",
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
gemma-12b-tq-model/checkpoint-2/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
+size 33384568
gemma-12b-tq-model/checkpoint-2/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+size 4689074
gemma-12b-tq-model/checkpoint-2/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
gemma-12b-tq-model/checkpoint-2/trainer_state.json
ADDED
@@ -0,0 +1,51 @@
+{
+  "best_metric": null,
+  "best_model_checkpoint": null,
+  "epoch": 0.7272727272727273,
+  "eval_steps": 500,
+  "global_step": 2,
+  "is_hyper_param_search": false,
+  "is_local_process_zero": true,
+  "is_world_process_zero": true,
+  "log_history": [
+    {
+      "epoch": 0.36363636363636365,
+      "grad_norm": 269.2862243652344,
+      "learning_rate": 0.0002,
+      "loss": 12.5019,
+      "mean_token_accuracy": 0.4838709682226181,
+      "num_tokens": 4096.0,
+      "step": 1
+    },
+    {
+      "epoch": 0.7272727272727273,
+      "grad_norm": 232.57679748535156,
+      "learning_rate": 0.0002,
+      "loss": 9.4112,
+      "mean_token_accuracy": 0.5447482168674469,
+      "num_tokens": 7561.0,
+      "step": 2
+    }
+  ],
+  "logging_steps": 1,
+  "max_steps": 2,
+  "num_input_tokens_seen": 0,
+  "num_train_epochs": 1,
+  "save_steps": 500,
+  "stateful_callbacks": {
+    "TrainerControl": {
+      "args": {
+        "should_epoch_stop": false,
+        "should_evaluate": false,
+        "should_log": false,
+        "should_save": true,
+        "should_training_stop": true
+      },
+      "attributes": {}
+    }
+  },
+  "total_flos": 196609832513952.0,
+  "train_batch_size": 1,
+  "trial_name": null,
+  "trial_params": null
+}
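The trainer state records a two-step run: loss falls from 12.50 to 9.41 while mean token accuracy rises from 0.484 to 0.545 at a constant learning rate of 2e-4. A small sketch for pulling those logged metrics out of the file (the path assumes a local clone of this repo):

```python
# Sketch: print the per-step metrics logged in trainer_state.json.
# The path is an assumption based on this repo's layout.
import json

with open("gemma-12b-tq-model/checkpoint-2/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    print(f"step {entry['step']}: loss={entry['loss']:.4f}, "
          f"mean_token_accuracy={entry['mean_token_accuracy']:.3f}")
```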
gemma-12b-tq-model/checkpoint-2/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a6ae5bb0171c4fcfd72d4d46496b122da48e4ebf65ff31d5812d7d8dba26a8e
+size 6161
gemma-12b-tq-model/runs/Apr25_08-39-59_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570448.9945b53f-579e-4565-94fc-5fbe73c83cc2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1316da37ce84d8a16f9542d45cf4ab5617be2069029895d1e700955c546f6c26
+size 6460
gemma-12b-tq-model/runs/Apr25_08-42-29_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570563.9945b53f-579e-4565-94fc-5fbe73c83cc2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a216b92a09ed293f852daa5eac2575a729347e101cd1e58fb41bce94b4a0c3d
+size 6460
gemma-12b-tq-model/runs/Apr25_09-19-39_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745572788.9945b53f-579e-4565-94fc-5fbe73c83cc2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac3b5c4eb7654e90a80565f05650922f5b8fe9c0c0e8f02134b18287d5ef32db
+size 7455
gemma-12b-tq-model/special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
+{
+  "boi_token": "<start_of_image>",
+  "bos_token": {
+    "content": "<bos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eoi_token": "<end_of_image>",
+  "eos_token": {
+    "content": "<eos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "image_token": "<image_soft_token>",
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
gemma-12b-tq-model/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
+size 33384568
gemma-12b-tq-model/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+size 4689074
gemma-12b-tq-model/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
gemma-12b-tq-model/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a6ae5bb0171c4fcfd72d4d46496b122da48e4ebf65ff31d5812d7d8dba26a8e
+size 6161
gemma-1b-tq-model/README.md
ADDED
@@ -0,0 +1,58 @@
+---
+base_model: google/gemma-3-1b-pt
+library_name: transformers
+model_name: gemma-1b-tq-model
+tags:
+- generated_from_trainer
+- trl
+- sft
+licence: license
+---
+
+# Model Card for gemma-1b-tq-model
+
+This model is a fine-tuned version of [google/gemma-3-1b-pt](https://huggingface.co/google/gemma-3-1b-pt).
+It has been trained using [TRL](https://github.com/huggingface/trl).
+
+## Quick start
+
+```python
+from transformers import pipeline
+
+question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
+generator = pipeline("text-generation", model="bhavinjawade/gemma-1b-tq-model", device="cuda")
+output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
+print(output["generated_text"])
+```
+
+## Training procedure
+
+
+
+
+
+This model was trained with SFT.
+
+### Framework versions
+
+- TRL: 0.16.1
+- Transformers: 4.50.0.dev0
+- Pytorch: 2.7.0
+- Datasets: 3.3.2
+- Tokenizers: 0.21.0
+
+## Citations
+
+
+
+Cite TRL as:
+
+```bibtex
+@misc{vonwerra2022trl,
+    title        = {{TRL: Transformer Reinforcement Learning}},
+    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
+    year         = 2020,
+    journal      = {GitHub repository},
+    publisher    = {GitHub},
+    howpublished = {\url{https://github.com/huggingface/trl}}
+}
+```
gemma-1b-tq-model/adapter_config.json
ADDED
@@ -0,0 +1,42 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": null,
+  "base_model_name_or_path": "google/gemma-3-1b-pt",
+  "bias": "none",
+  "corda_config": null,
+  "eva_config": null,
+  "exclude_modules": null,
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layer_replication": null,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 16,
+  "lora_bias": false,
+  "lora_dropout": 0.05,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": [
+    "lm_head",
+    "embed_tokens"
+  ],
+  "peft_type": "LORA",
+  "r": 16,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "o_proj",
+    "k_proj",
+    "v_proj",
+    "q_proj",
+    "up_proj",
+    "down_proj",
+    "gate_proj"
+  ],
+  "task_type": "CAUSAL_LM",
+  "trainable_token_indices": null,
+  "use_dora": false,
+  "use_rslora": false
+}
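Relative to the 12b adapter above, this config keeps r at 16 but drops lora_alpha to 16 (a scaling factor of alpha/r = 1) and targets only the seven standard Gemma projection matrices. A sketch of the peft.LoraConfig that would serialize to roughly this JSON, spelling out only the non-default fields:

```python
# Sketch: a LoraConfig matching the adapter_config.json above.
# Field names follow the peft.LoraConfig API; values are copied from the JSON.
from peft import LoraConfig

config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["o_proj", "k_proj", "v_proj", "q_proj",
                    "up_proj", "down_proj", "gate_proj"],
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)
```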
gemma-1b-tq-model/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5ccfa8bc91f4f5e80a03ca2a73036523a47df7550df49b4cd8296c486ed37de
+size 1260191096
gemma-1b-tq-model/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
+{
+  "<image_soft_token>": 262144
+}
gemma-1b-tq-model/checkpoint-10/README.md
ADDED
@@ -0,0 +1,202 @@
+---
+base_model: google/gemma-3-1b-pt
+library_name: peft
+---
+
+# Model Card for Model ID
+
+<!-- Provide a quick summary of what the model is/does. -->
+
+
+
+## Model Details
+
+### Model Description
+
+<!-- Provide a longer summary of what this model is. -->
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+<!-- Provide the basic links for the model. -->
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+### Direct Use
+
+<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+<!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+[More Information Needed]
+
+### Recommendations
+
+<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+[More Information Needed]
+
+### Training Procedure
+
+<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+#### Speeds, Sizes, Times [optional]
+
+<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+[More Information Needed]
+
+## Evaluation
+
+<!-- This section describes the evaluation protocols and provides the results. -->
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+<!-- This should link to a Dataset Card if possible. -->
+
+[More Information Needed]
+
+#### Factors
+
+<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+[More Information Needed]
+
+#### Metrics
+
+<!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+<!-- Relevant interpretability work for the model goes here -->
+
+[More Information Needed]
+
+## Environmental Impact
+
+<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.15.2
gemma-1b-tq-model/checkpoint-10/adapter_config.json
ADDED
@@ -0,0 +1,42 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": null,
+  "base_model_name_or_path": "google/gemma-3-1b-pt",
+  "bias": "none",
+  "corda_config": null,
+  "eva_config": null,
+  "exclude_modules": null,
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layer_replication": null,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 16,
+  "lora_bias": false,
+  "lora_dropout": 0.05,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": [
+    "lm_head",
+    "embed_tokens"
+  ],
+  "peft_type": "LORA",
+  "r": 16,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "o_proj",
+    "k_proj",
+    "v_proj",
+    "q_proj",
+    "up_proj",
+    "down_proj",
+    "gate_proj"
+  ],
+  "task_type": "CAUSAL_LM",
+  "trainable_token_indices": null,
+  "use_dora": false,
+  "use_rslora": false
+}
|
gemma-1b-tq-model/checkpoint-10/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b5ac34a75c3a7a06c3c9fda32b95d9870267691dacb15af9c0c2c08dc4e7934
|
3 |
+
size 1260191096
|
gemma-1b-tq-model/checkpoint-10/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<image_soft_token>": 262144
|
3 |
+
}
|
gemma-1b-tq-model/checkpoint-10/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f990dd49302e243f6510cfc5976f8dc28049c45c9af22f8424f89bc8b3d89b2
|
3 |
+
size 2520598381
|
gemma-1b-tq-model/checkpoint-10/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67a697233a108d806598e97819e22cb699651bb7e046c04cc47db386d7540306
|
3 |
+
size 14645
|
gemma-1b-tq-model/checkpoint-10/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:591c3072a024fe1a8043b72e8e5366699aec4a9d0c3da5bde546eb445034a199
|
3 |
+
size 1401
|
gemma-1b-tq-model/checkpoint-10/special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"boi_token": "<start_of_image>",
|
3 |
+
"bos_token": {
|
4 |
+
"content": "<bos>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false
|
9 |
+
},
|
10 |
+
"eoi_token": "<end_of_image>",
|
11 |
+
"eos_token": {
|
12 |
+
"content": "<eos>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false
|
17 |
+
},
|
18 |
+
"image_token": "<image_soft_token>",
|
19 |
+
"pad_token": {
|
20 |
+
"content": "<pad>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false
|
25 |
+
},
|
26 |
+
"unk_token": {
|
27 |
+
"content": "<unk>",
|
28 |
+
"lstrip": false,
|
29 |
+
"normalized": false,
|
30 |
+
"rstrip": false,
|
31 |
+
"single_word": false
|
32 |
+
}
|
33 |
+
}
|