bhavinjawade committed on
Commit 9564ed2 · verified · 1 Parent(s): 906f62f

Model save

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +103 -0
  2. =0.41.0 +0 -0
  3. =0.6.0 +0 -0
  4. DSPy_Optimization.ipynb +415 -0
  5. InstructionFinetuning.ipynb +1277 -0
  6. README.md +58 -0
  7. SFT_Expert.py +229 -0
  8. TQ_template.py +37 -0
  9. TextGrad_Optimization.ipynb +544 -0
  10. adapter_config.json +45 -0
  11. adapter_model.safetensors +3 -0
  12. added_tokens.json +3 -0
  13. data_prep.py +380 -0
  14. gemma-12b-tq-model/README.md +58 -0
  15. gemma-12b-tq-model/adapter_config.json +45 -0
  16. gemma-12b-tq-model/adapter_model.safetensors +3 -0
  17. gemma-12b-tq-model/added_tokens.json +3 -0
  18. gemma-12b-tq-model/checkpoint-2/README.md +202 -0
  19. gemma-12b-tq-model/checkpoint-2/adapter_config.json +45 -0
  20. gemma-12b-tq-model/checkpoint-2/adapter_model.safetensors +3 -0
  21. gemma-12b-tq-model/checkpoint-2/added_tokens.json +3 -0
  22. gemma-12b-tq-model/checkpoint-2/optimizer.pt +3 -0
  23. gemma-12b-tq-model/checkpoint-2/rng_state.pth +3 -0
  24. gemma-12b-tq-model/checkpoint-2/scheduler.pt +3 -0
  25. gemma-12b-tq-model/checkpoint-2/special_tokens_map.json +33 -0
  26. gemma-12b-tq-model/checkpoint-2/tokenizer.json +3 -0
  27. gemma-12b-tq-model/checkpoint-2/tokenizer.model +3 -0
  28. gemma-12b-tq-model/checkpoint-2/tokenizer_config.json +0 -0
  29. gemma-12b-tq-model/checkpoint-2/trainer_state.json +51 -0
  30. gemma-12b-tq-model/checkpoint-2/training_args.bin +3 -0
  31. gemma-12b-tq-model/runs/Apr25_08-39-59_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570448.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
  32. gemma-12b-tq-model/runs/Apr25_08-42-29_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570563.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
  33. gemma-12b-tq-model/runs/Apr25_09-19-39_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745572788.9945b53f-579e-4565-94fc-5fbe73c83cc2 +3 -0
  34. gemma-12b-tq-model/special_tokens_map.json +33 -0
  35. gemma-12b-tq-model/tokenizer.json +3 -0
  36. gemma-12b-tq-model/tokenizer.model +3 -0
  37. gemma-12b-tq-model/tokenizer_config.json +0 -0
  38. gemma-12b-tq-model/training_args.bin +3 -0
  39. gemma-1b-tq-model/README.md +58 -0
  40. gemma-1b-tq-model/adapter_config.json +42 -0
  41. gemma-1b-tq-model/adapter_model.safetensors +3 -0
  42. gemma-1b-tq-model/added_tokens.json +3 -0
  43. gemma-1b-tq-model/checkpoint-10/README.md +202 -0
  44. gemma-1b-tq-model/checkpoint-10/adapter_config.json +42 -0
  45. gemma-1b-tq-model/checkpoint-10/adapter_model.safetensors +3 -0
  46. gemma-1b-tq-model/checkpoint-10/added_tokens.json +3 -0
  47. gemma-1b-tq-model/checkpoint-10/optimizer.pt +3 -0
  48. gemma-1b-tq-model/checkpoint-10/rng_state.pth +3 -0
  49. gemma-1b-tq-model/checkpoint-10/scheduler.pt +3 -0
  50. gemma-1b-tq-model/checkpoint-10/special_tokens_map.json +33 -0
.gitattributes CHANGED
@@ -33,3 +33,106 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ gemma-12b-tq-model/checkpoint-2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-12b-tq-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-100/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-11/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-12/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-14/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-16/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-18/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-20/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-22/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-24/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-26/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-28/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-30/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-32/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-33/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-34/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-36/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-38/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-4/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-40/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-42/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-44/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-46/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-48/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-5/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-50/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-52/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-54/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-56/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-58/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-6/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-60/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-62/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-64/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-66/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-68/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-70/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-72/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-74/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-76/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-78/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-8/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-80/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-82/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-84/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-86/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-88/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-90/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-92/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-94/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-96/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/checkpoint-98/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-1b-tq-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/checkpoint-129/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/checkpoint-131/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/checkpoint-133/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/checkpoint-256/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/checkpoint-264/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-27b-tq_sft_finetuned-model-full/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-1040/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-116/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-1170/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-124/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-125/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-130/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-1300/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-140/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-186/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-248/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-250/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-260/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-29/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-3/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-310/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-372/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-375/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-390/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-434/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-496/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-500/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-520/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-558/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-58/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-6/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-61/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-610/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-62/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-625/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-650/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-780/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-87/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-9/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/checkpoint-910/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ gemma-4b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ may13-gemma-27b-tq_sft_finetuned-model/checkpoint-80/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ may13-gemma-27b-tq_sft_finetuned-model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ merged_model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
=0.41.0 ADDED
File without changes
=0.6.0 ADDED
File without changes
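The empty files named `=0.41.0` and `=0.6.0` are almost certainly shell artifacts: running a command like `pip install peft>=0.6.0` without quoting the requirement lets the shell interpret `>=0.6.0` as an output redirection, which creates an empty file of that name. They carry no content and could be removed in a follow-up commit.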
DSPy_Optimization.ipynb ADDED
@@ -0,0 +1,415 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "8b3ee6e2-ca9c-40fa-b4c6-a9596f075f79",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-22T23:03:20.101831Z",
+ "iopub.status.busy": "2025-04-22T23:03:20.101435Z",
+ "iopub.status.idle": "2025-04-22T23:03:20.105088Z",
+ "shell.execute_reply": "2025-04-22T23:03:20.104580Z",
+ "shell.execute_reply.started": "2025-04-22T23:03:20.101804Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import dspy\n",
+ "from dspy.teleprompt import MIPROv2\n",
+ "from typing import List, Dict\n",
+ "import json\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import random\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "4ec9a29b-9162-4fe3-b32d-4de4397c6483",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-22T23:00:21.439753Z",
+ "iopub.status.busy": "2025-04-22T23:00:21.439342Z",
+ "iopub.status.idle": "2025-04-22T23:00:21.526091Z",
+ "shell.execute_reply": "2025-04-22T23:00:21.525575Z",
+ "shell.execute_reply.started": "2025-04-22T23:00:21.439727Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4/4 [00:00<00:00, 77.75it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'input': {'src_text': 'Ma io che ne so, comandà? Io stavo a casa di mia madre, lo sapete.\\n\\nLo so.',\n",
+ " 'tgt_text': \"What do I know, Commander? I was at my mom's house, you know it.\\n\\nI knows.\",\n",
+ " 'src_prev': \"Questa è una linea. Qua faccio quello che voglio, è terra mia, la legge è mia. Dall'altro lato c'è un mondo fatto di spazzatura. Questa linea non l'ho mai oltrepassata. Impara chi è tua madre una volta per tutte. Tieni, questo era per te. Mà… Mà! Secondo me non è stata lei. Come al solito ti sei fatto prendere per il culo. Comandà, credo che non è stata lei. Carmine, sei uno stronzo. Robè, portalo via. Andiamocene.\",\n",
+ " 'src_next': 'E allora che altro vi devo dire? Tu non devi dire niente. Devi tenere la bocca chiusa. E non dire a nessuno quello che ti ho detto. Ma a nessuno però. Ho capito. Però devi tenere le orecchie aperte e ascoltare tutto quello che si dice qua dentro. Perché prima o poi, chi fa queste cose parla. Si deve atteggiare, si deve fare grosso. Che si è divertito con la moglie del comandante. Secondo me vi sbagliate, comandà. Non può essere stato nessuno che sta qua dentro. Lo so.',\n",
+ " 'tgt_prev': \"This is a line. Here I do whatever I want, it's my territory, it's my law. On the other side there's a world of trash. I've never crossed that line. Learn who your mother is, once and for all. Here, this was for you. Ma… Ma! I don't think it was her. As usual you let her fuck you around. Commander, I think she didn't do it. Carmine, you're an asshole. Robè, take him away. Let's go.\",\n",
+ " 'tgt_next': \"So what else should I say? You don't have to say anything. You have to keep your mouth shut. And don't tell anybody what I told you. To nobody. Got it. But keep your ears open and listen to what they say in here. Because sooner or later, guys who do such things talk. They need to swagger, act like big guys. Bragging they had fun with the Commander's wife. I think you're wrong, Commander. It can't have been anyone who's in here. I know.\",\n",
+ " 'src_lang': 'it',\n",
+ " 'tgt_lang': 'en'},\n",
+ " 'evaluation': {'Accuracy Issues': [],\n",
+ " 'Readability Issues': [],\n",
+ " 'Accuracy Score': '4',\n",
+ " 'Readability Score': '4',\n",
+ " 'Confidence Level': 'the_translation_is_excellent_without_any_error_spans_and_no_creative_liberties_were_taken',\n",
+ " 'Main Vs Alternate': 'Alternate Translated Text has marginally better quality',\n",
+ " 'Score': 32}}"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data_path = \"/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/\"\n",
+ "json_files = [os.path.join(root, file) for root, _, files in os.walk(data_path) for file in files if file.endswith('.json') and 'PLDL' in file]\n",
+ "\n",
+ "training_samples = []\n",
+ "for json_file in tqdm(json_files):\n",
+ " with open(json_file, 'r') as file:\n",
+ " data = json.load(file)\n",
+ " sampled_items = random.sample(data[\"data\"], 20)\n",
+ " training_samples.extend(sampled_items)\n",
+ "\n",
+ "datapoints = []\n",
+ "\n",
+ "for sample in training_samples:\n",
+ " datapoint = {\"input\":{}}\n",
+ " datapoint[\"input\"][\"src_text\"] = sample[\"main_src_text\"]\n",
+ " datapoint[\"input\"][\"tgt_text\"] = sample[\"tgt_text\"]\n",
+ " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
+ " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
+ " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
+ " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
+ " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
+ " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
+ " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
+ " datapoints.append(datapoint)\n",
+ "\n",
+ "datapoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "bde34303-2f52-415f-b117-264e266b84f0",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-04-22T23:04:16.302953Z",
+ "iopub.status.busy": "2025-04-22T23:04:16.302402Z",
+ "iopub.status.idle": "2025-04-22T23:04:16.334330Z",
+ "shell.execute_reply": "2025-04-22T23:04:16.333644Z",
+ "shell.execute_reply.started": "2025-04-22T23:04:16.302928Z"
+ }
+ },
+ "outputs": [
+ {
+ "ename": "AttributeError",
+ "evalue": "module 'dspy' has no attribute 'Predictor'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[35], line 28\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prediction\u001b[38;5;241m.\u001b[39mevaluation\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# Create a custom predictor using your Netflix model\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mNetflixPredictor\u001b[39;00m(\u001b[43mdspy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPredictor\u001b[49m):\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, model):\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m model\n",
+ "\u001b[0;31mAttributeError\u001b[0m: module 'dspy' has no attribute 'Predictor'"
+ ]
+ }
+ ],
+ "source": [
+ "class TranslationQualityChecker(dspy.Signature):\n",
+ " \"\"\"Evaluate the quality of translation.\"\"\"\n",
+ " \n",
+ " context = dspy.InputField(desc=\"Source and target text with context\")\n",
+ " evaluation = dspy.OutputField(desc=\"Detailed evaluation of the translation quality\")\n",
+ "\n",
+ "class TranslationQualityModule(dspy.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.checker = dspy.Predict(TranslationQualityChecker)\n",
+ " \n",
+ " def forward(self, src_text, tgt_text, src_prev, tgt_prev, src_next, tgt_next, src_lang, tgt_lang):\n",
+ " context = {\n",
+ " \"source_text\": src_text,\n",
+ " \"target_text\": tgt_text,\n",
+ " \"source_previous\": src_prev,\n",
+ " \"target_previous\": tgt_prev,\n",
+ " \"source_next\": src_next,\n",
+ " \"target_next\": tgt_next,\n",
+ " \"source_language\": src_lang,\n",
+ " \"target_language\": tgt_lang\n",
+ " }\n",
+ " \n",
+ " prediction = self.checker(context=context)\n",
+ " return prediction.evaluation\n",
+ "\n",
+ "# Create a custom backend using your Netflix model\n",
+ "class NetflixBackend(dspy.BackendBase):\n",
+ " def __init__(self, model):\n",
+ " super().__init__()\n",
+ " self.model = model\n",
+ " \n",
+ " def complete(self, prompt, **kwargs):\n",
+ " messages = [{\"role\": \"user\", \"content\": prompt}]\n",
+ " response = self.model.generate(messages)\n",
+ " return response\n",
+ "\n",
+ " def completions(self, prompts, **kwargs):\n",
+ " return [self.complete(prompt, **kwargs) for prompt in prompts]\n",
+ "\n",
+ "# Prepare training data\n",
+ "def prepare_training_data(data_points):\n",
+ " compiled_data = []\n",
+ " for dp in data_points:\n",
+ " input_data = dp['input']\n",
+ " train_example = dspy.Example(\n",
+ " context={\n",
+ " \"source_text\": input_data['src_text'],\n",
+ " \"target_text\": input_data['tgt_text'],\n",
+ " \"source_previous\": input_data['src_prev'],\n",
+ " \"target_previous\": input_data['tgt_prev'],\n",
+ " \"source_next\": input_data['src_next'],\n",
+ " \"target_next\": input_data['tgt_next'],\n",
+ " \"source_language\": input_data['src_lang'],\n",
+ " \"target_language\": input_data['tgt_lang']\n",
+ " },\n",
+ " evaluation=dp['evaluation']\n",
+ " )\n",
+ " compiled_data.append(train_example)\n",
+ " return compiled_data\n",
+ "\n",
+ "def optimize_prompt(model, training_data, validation_data):\n",
+ " # Initialize DSPy with your custom backend\n",
+ " backend = NetflixBackend(model)\n",
+ " dspy.settings.configure(lm=backend)\n",
+ " \n",
+ " # Create the optimizer\n",
+ " optimizer = MIPROv2(\n",
+ " metric=\"exact_match\", # or another appropriate metric\n",
+ " max_rounds=5,\n",
+ " max_prompts=3,\n",
+ " temp=0.7\n",
+ " )\n",
+ " \n",
+ " # Compile the module\n",
+ " translation_module = TranslationQualityModule()\n",
+ " \n",
+ " # Optimize the prompt\n",
+ " optimized_module = optimizer.optimize(\n",
+ " module=translation_module,\n",
+ " trainset=training_data,\n",
+ " valset=validation_data,\n",
+ " metric=dspy.evaluate.answer_exact_match\n",
+ " )\n",
+ " \n",
+ " return optimized_module"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "67a4583f-162c-4e2d-b061-798f6c676a28",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TranslationQualityAssessor(dspy.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n",
+ "\n",
+ " def forward(self, src_lang, tgt_lang, src_text, translation, src_prev=\"\", tgt_prev=\"\", src_next=\"\", tgt_next=\"\"):\n",
+ " context = f\"\"\"Previous Context:\n",
+ " Source: {src_prev}\n",
+ " Translation: {tgt_prev}\n",
+ " \n",
+ " Next Context:\n",
+ " Source: {src_next}\n",
+ " Translation: {tgt_next}\"\"\"\n",
+ "\n",
+ " result = self.assess(\n",
+ " context=context,\n",
+ " source=f\"Source ({src_lang}): {src_text}\",\n",
+ " translation=f\"Translation ({tgt_lang}): {translation}\"\n",
+ " )\n",
+ " \n",
+ " return result.evaluation\n",
+ "\n",
+ "class TranslationMetrics:\n",
+ " @staticmethod\n",
+ " def exact_match_score(pred, gold):\n",
+ " try:\n",
+ " pred_json = json.loads(pred)\n",
+ " gold_json = gold\n",
+ " \n",
+ " accuracy_match = (str(pred_json.get('Accuracy Score')) == str(gold_json.get('Accuracy Score')))\n",
+ " readability_match = (str(pred_json.get('Readability Score')) == str(gold_json.get('Readability Score')))\n",
+ " \n",
+ " return (accuracy_match and readability_match)\n",
+ " except:\n",
+ " return False\n",
+ " \n",
+ " @staticmethod\n",
+ " def partial_match_score(pred, gold):\n",
+ " try:\n",
+ " pred_json = json.loads(pred)\n",
+ " gold_json = gold\n",
+ " \n",
+ " # Score comparison\n",
+ " accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))\n",
+ " readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))\n",
+ " \n",
+ " # Issues comparison\n",
+ " pred_accuracy_issues = set(str(issue) for issue in pred_json.get('Accuracy Issues', []))\n",
+ " gold_accuracy_issues = set(str(issue) for issue in gold_json.get('Accuracy Issues', []))\n",
+ " pred_readability_issues = set(str(issue) for issue in pred_json.get('Readability Issues', []))\n",
+ " gold_readability_issues = set(str(issue) for issue in gold_json.get('Readability Issues', []))\n",
+ " \n",
+ " # Calculate Jaccard similarity for issues\n",
+ " accuracy_issues_sim = len(pred_accuracy_issues & gold_accuracy_issues) / max(1, len(pred_accuracy_issues | gold_accuracy_issues))\n",
+ " readability_issues_sim = len(pred_readability_issues & gold_readability_issues) / max(1, len(pred_readability_issues | gold_readability_issues))\n",
+ " \n",
+ " # Combine scores (0.6 weight to scores, 0.4 to issues similarity)\n",
+ " score_component = 1 - ((accuracy_diff + readability_diff) / 8)\n",
+ " issues_component = (accuracy_issues_sim + readability_issues_sim) / 2\n",
+ " \n",
+ " final_score = 0.6 * score_component + 0.4 * issues_component\n",
+ " return max(0, final_score)\n",
+ " except:\n",
+ " return 0\n",
+ "\n",
+ "def prepare_dataset(file_path):\n",
+ " with open(file_path, 'r') as f:\n",
+ " data = json.load(f)\n",
+ " \n",
+ " prepared_data = []\n",
+ " \n",
+ " for item in data:\n",
+ " example = dspy.Example(\n",
+ " context=f\"\"\"Previous Context:\n",
+ " Source: {item['src_prev']}\n",
+ " Translation: {item['tgt_prev']}\n",
+ " \n",
+ " Next Context:\n",
+ " Source: {item['src_next']}\n",
+ " Translation: {item['tgt_next']}\"\"\",\n",
+ " source=f\"Source ({item['src_lang']}): {item['src_text']}\",\n",
+ " translation=f\"Translation ({item['tgt_lang']}): {item['main_text']}\",\n",
+ " evaluation=json.dumps(item['evaluation'], ensure_ascii=False)\n",
+ " ).with_inputs(\"context\", \"source\", \"translation\")\n",
+ " \n",
+ " prepared_data.append(example)\n",
+ " \n",
+ " # Split data: 70% train, 15% dev, 15% test\n",
+ " train_size = int(0.7 * len(prepared_data))\n",
+ " dev_size = int(0.15 * len(prepared_data))\n",
+ " \n",
+ " train_data = prepared_data[:train_size]\n",
+ " dev_data = prepared_data[train_size:train_size + dev_size]\n",
+ " test_data = prepared_data[train_size + dev_size:]\n",
+ " \n",
+ " return train_data, dev_data, test_data\n",
+ "\n",
+ "def optimize_translation_quality_assessment():\n",
+ " # Initialize DSPy\n",
+ " lm = TranslationQualityLM()\n",
+ " dspy.settings.configure(lm=lm)\n",
+ " \n",
+ " # Load and prepare dataset\n",
+ " train_data, dev_data, test_data = prepare_dataset('translation_quality_dataset.json')\n",
+ " \n",
+ " # Create evaluator\n",
+ " evaluator = Evaluate(\n",
+ " metrics={\n",
+ " 'exact_match': TranslationMetrics.exact_match_score,\n",
+ " 'partial_match': TranslationMetrics.partial_match_score\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " # Initialize module\n",
+ " assessor = TranslationQualityAssessor()\n",
+ " \n",
+ " # Initialize MIPROv2 optimizer\n",
+ " optimizer = dspy.MIPROv2(\n",
+ " metric=lambda x: x['partial_match'],\n",
+ " max_rounds=5, # Number of optimization rounds\n",
+ " max_traces=10, # Number of traces per round\n",
+ " max_depth=3, # Maximum depth of reasoning chains\n",
+ " num_candidate_prompts=5, # Number of candidate prompts to generate\n",
+ " num_rounds_per_prompt=3, # Number of rounds per candidate prompt\n",
+ " temperature=0.7,\n",
+ " verbose=True\n",
+ " )\n",
+ " \n",
+ " # Compile the module with optimization\n",
+ " compiled_assessor = optimizer.compile(\n",
+ " assessor,\n",
+ " trainset=train_data,\n",
+ " devset=dev_data,\n",
+ " eval_kwargs={\n",
+ " 'metric': 'partial_match',\n",
+ " 'num_threads': 4,\n",
+ " 'batch_size': 8\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " # Evaluate on test set\n",
+ " results = []\n",
+ " for example in test_data:\n",
+ " pred = compiled_assessor(\n",
+ " context=example.context,\n",
+ " source=example.source,\n",
+ " translation=example.translation\n",
+ " )\n",
+ " \n",
+ " result = evaluator.evaluate(\n",
+ " predictions=[pred],\n",
+ " ground_truth=[example.evaluation]\n",
+ " )\n",
+ " results.append(result)\n",
+ " \n",
+ " # Calculate and print final metrics\n",
+ " avg_exact_match = np.mean([r['exact_match'] for r in results])\n",
+ " avg_partial_match = np.mean([r['partial_match'] for r in results])\n",
+ " \n",
+ " print(f\"Average Exact Match Score: {avg_exact_match:.3f}\")\n",
+ " print(f\"Average Partial Match Score: {avg_partial_match:.3f}\")\n",
+ " \n",
+ " return compiled_assessor\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " optimized_assessor = optimize_translation_quality_assessment()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "timedlibs",
+ "language": "python",
+ "name": "timedlibs"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
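A note on the AttributeError captured in this notebook: recent DSPy releases expose `dspy.Predict` and `dspy.Module` but no `dspy.Predictor`, and custom models are wired in through `dspy.LM` rather than by subclassing a backend base class (there is no `dspy.BackendBase`). A minimal sketch of how the quality-checker signature above could be driven with the current API, assuming dspy >= 2.5; the model name is a placeholder, not anything used in this repo:

```python
import dspy

# Minimal sketch, assuming dspy >= 2.5. "openai/gpt-4o-mini" is a
# placeholder for whatever endpoint actually backs these experiments,
# and this provider requires OPENAI_API_KEY in the environment.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.7)
dspy.configure(lm=lm)

class TranslationQualityChecker(dspy.Signature):
    """Evaluate the quality of a translation."""
    context = dspy.InputField(desc="Source and target text with context")
    evaluation = dspy.OutputField(desc="Detailed evaluation of the translation quality")

# dspy.Predict (not dspy.Predictor) builds the prompt from the signature.
checker = dspy.Predict(TranslationQualityChecker)
pred = checker(context="Source (it): Lo so. | Translation (en): I know.")
print(pred.evaluation)
```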
InstructionFinetuning.ipynb ADDED
@@ -0,0 +1,1277 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e6d20008-a91c-4618-baa0-5991e031f1bd",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:48:57.985184Z",
+ "iopub.status.busy": "2025-05-13T21:48:57.984795Z",
+ "iopub.status.idle": "2025-05-13T21:51:48.369715Z",
+ "shell.execute_reply": "2025-05-13T21:51:48.368907Z",
+ "shell.execute_reply.started": "2025-05-13T21:48:57.985144Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/root/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq\n",
+ "import torch\n",
+ "from peft import LoraConfig, get_peft_model\n",
+ "\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "import json\n",
+ "\n",
+ "import random\n",
+ "from datasets import load_dataset\n",
+ "from datasets import Dataset, DatasetDict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "67f95fc8-a9d8-48cf-a551-7d30781cdb55",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:53:58.075473Z",
+ "iopub.status.busy": "2025-05-13T21:53:58.074767Z",
+ "iopub.status.idle": "2025-05-13T21:53:58.767860Z",
+ "shell.execute_reply": "2025-05-13T21:53:58.767319Z",
+ "shell.execute_reply.started": "2025-05-13T21:53:58.075446Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 8/8 [00:00<00:00, 22.76it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['messages'],\n",
+ " num_rows: 309\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['messages'],\n",
+ " num_rows: 343\n",
+ " })\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_path = (\n",
+ " \"/root/notebooks/MT_TQ/Caches/May2025/tquality.annotated.data/parsed/pldl/\"\n",
+ ")\n",
+ "\n",
+ "json_files = [\n",
+ " os.path.join(root, file)\n",
+ " for root, _, files in os.walk(data_path)\n",
+ " for file in files\n",
+ " if file.endswith(\".json\")\n",
+ "]\n",
+ "\n",
+ "training_samples = []\n",
+ "testing_samples = []\n",
+ "\n",
+ "for json_file in tqdm(json_files):\n",
+ " with open(json_file, \"r\") as file:\n",
+ " data = json.load(file)\n",
+ " sampled_items = data[\"data\"]\n",
+ " if \"test\" in json_file:\n",
+ " testing_samples.extend(sampled_items)\n",
+ " if \"train\" in json_file:\n",
+ " training_samples.extend(sampled_items)\n",
+ "\n",
+ "training_datapoints = []\n",
+ "testing_datapoints = []\n",
+ "\n",
+ "for idx, sample in enumerate(training_samples):\n",
+ " datapoint = {\"input\": {}}\n",
+ " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n",
+ " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n",
+ " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
+ " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
+ " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
+ " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
+ " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
+ " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
+ " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n",
+ " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n",
+ " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n",
+ " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n",
+ " datapoint[\"input\"][\"id\"] = idx\n",
+ " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
+ " training_datapoints.append(datapoint)\n",
+ "\n",
+ "for idx, sample in enumerate(testing_samples):\n",
+ " datapoint = {\"input\": {}}\n",
+ " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n",
+ " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n",
+ " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
+ " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
+ " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
+ " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
+ " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
+ " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
+ " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n",
+ " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n",
+ " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n",
+ " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n",
+ " datapoint[\"input\"][\"id\"] = idx\n",
+ " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
+ " testing_datapoints.append(datapoint)\n",
+ "\n",
+ "system_message = \"You are a helpful assistant who is an expert in estimating quality of translations.\"\n",
+ "\n",
+ "output_template = '''\n",
+ "{\n",
+ " \"Accuracy Issues\": [\n",
+ " {\n",
+ " \"Error Span\": \"\",\n",
+ " \"Error Explanation\": \"\",\n",
+ " \"Error Quality Category\": \"\",\n",
+ " \"Error Quality Tags\": [],\n",
+ " \"Error Severity\": \"\"\n",
+ " }\n",
+ " ],\n",
+ " \"Accuracy Score\": \"\",\n",
+ " \"Readability Issues\": [\n",
+ " {\n",
+ " \"Error Span\": \"\",\n",
+ " \"Error Explanation\": \"\",\n",
+ " \"Error Quality Category\": \"\",\n",
+ " \"Error Quality Tags\": [],\n",
+ " \"Error Severity\": \"\"\n",
+ " }\n",
+ " ],\n",
+ " \"Readability Score\": \"\"\n",
+ "}'''\n",
+ "\n",
+ "def create_conversation(input_sample, output_sample):\n",
+ " return {\n",
+ " \"messages\": [\n",
+ " # {\"role\": \"system\", \"content\": system_message},\n",
+ " {\"role\": \"user\", \"content\": input_sample},\n",
+ " {\"role\": \"assistant\", \"content\": output_sample}\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "def create_dataset(datapoints, template_string):\n",
+ " dataset = []\n",
+ " meta = []\n",
+ " for datapoint in datapoints:\n",
+ " src_text = datapoint['input']['src_text']\n",
+ " tgt_text = datapoint['input']['tgt_text']\n",
+ " src_prev = datapoint['input']['src_prev']\n",
+ " src_next = datapoint['input']['src_next'] \n",
+ " tgt_prev = datapoint['input']['tgt_prev']\n",
+ " tgt_next = datapoint['input']['tgt_next']\n",
+ " src_lang = datapoint['input']['src_lang']\n",
+ " tgt_lang = datapoint['input']['tgt_lang']\n",
+ " \n",
+ " start_frame = datapoint['input']['start_frame']\n",
+ " end_frame = datapoint['input']['end_frame']\n",
+ " title_id = datapoint['input']['title_id']\n",
+ " output = datapoint['evaluation']\n",
+ " idx = datapoint['input']['id']\n",
+ " if len(output['Accuracy Issues']) != 0 or len(output['Readability Issues']) != 0:\n",
+ " item = template_string.format(src_text=src_text, tgt_text=tgt_text, \n",
+ " src_prev=src_prev, src_next=src_next, \n",
+ " tgt_prev=tgt_prev, tgt_next=tgt_next, \n",
+ " src_lang=src_lang, tgt_lang=tgt_lang,\n",
+ " template=output_template)\n",
+ " \n",
+ " dataset.append(create_conversation(item, json.dumps(output)))\n",
+ " meta.append({\"id\": idx, \"start_frame\": start_frame, \"end_frame\": end_frame, \"title_id\": title_id})\n",
+ " \n",
+ " return dataset, meta\n",
+ " \n",
+ "def dataset_prep(datapoints):\n",
+ " with open(\"prompts.txt\") as file:\n",
+ " template_string = file.read()\n",
+ " dataset, meta = create_dataset(datapoints, template_string)\n",
+ " return dataset, meta\n",
+ "\n",
+ "train_dataset, train_meta = dataset_prep(training_datapoints)\n",
+ "test_dataset, test_meta = dataset_prep(testing_datapoints)\n",
+ "\n",
+ "dataset = {\"train\": train_dataset, \"test\": test_dataset}\n",
+ "\n",
+ "def convert_to_hf_dataset(dataset):\n",
+ " train_dataset = Dataset.from_list(dataset['train'])\n",
+ " test_dataset = Dataset.from_list(dataset['test'])\n",
+ " \n",
+ " hf_dataset = DatasetDict({\n",
+ " 'train': train_dataset,\n",
+ " 'test': test_dataset\n",
+ " })\n",
+ " \n",
+ " return hf_dataset\n",
+ "\n",
+ "hf_dataset = convert_to_hf_dataset(dataset)\n",
+ "print(hf_dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "8b52f143-1077-4da6-ac92-b1dce5cdc17c",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:54:12.568533Z",
+ "iopub.status.busy": "2025-05-13T21:54:12.568078Z",
+ "iopub.status.idle": "2025-05-13T21:54:49.724121Z",
+ "shell.execute_reply": "2025-05-13T21:54:49.723481Z",
+ "shell.execute_reply.started": "2025-05-13T21:54:12.568507Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint shards: 100%|██████████| 12/12 [00:18<00:00, 1.58s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n",
+ "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n",
+ "device = torch.device(\"cuda:0\")\n",
+ "\n",
+ "# Hugging Face model id\n",
+ "model_id = \"google/gemma-3-27b-it\" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`\n",
+ "\n",
+ "# Select model class based on id\n",
+ "if model_id == \"google/gemma-3-27b-it\":\n",
+ " model_class = Gemma3ForConditionalGeneration\n",
+ "else:\n",
+ " model_class = AutoModelForImageTextToText\n",
+ "\n",
+ "torch_dtype = torch.bfloat16\n",
+ "\n",
+ "model_kwargs = dict(\n",
+ " attn_implementation=\"eager\",\n",
+ " torch_dtype=torch_dtype,\n",
+ " device_map=\"auto\",\n",
+ ")\n",
+ "\n",
+ "model = model_class.from_pretrained(model_id, **model_kwargs)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-3-27b-it\") # Load the Instruction Tokenizer to use the official Gemma template"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "8443dfd8-6193-480c-9937-f6e0c43a9f56",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:55:12.713958Z",
+ "iopub.status.busy": "2025-05-13T21:55:12.713495Z",
+ "iopub.status.idle": "2025-05-13T21:55:12.717707Z",
+ "shell.execute_reply": "2025-05-13T21:55:12.717199Z",
+ "shell.execute_reply.started": "2025-05-13T21:55:12.713930Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from peft import LoraConfig\n",
+ "\n",
+ "peft_config = LoraConfig(\n",
+ " lora_alpha=128,\n",
+ " lora_dropout=0.05,\n",
+ " r=16,\n",
+ " bias=\"none\",\n",
+ " target_modules=\"all-linear\",\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ " modules_to_save=[\"lm_head\", \"embed_tokens\"] # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "8f2b8371-ba1b-44ff-9462-d0c90335f82a",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:55:22.076515Z",
+ "iopub.status.busy": "2025-05-13T21:55:22.076029Z",
+ "iopub.status.idle": "2025-05-13T21:55:22.783524Z",
+ "shell.execute_reply": "2025-05-13T21:55:22.782937Z",
+ "shell.execute_reply.started": "2025-05-13T21:55:22.076489Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from trl import SFTConfig\n",
+ "\n",
+ "args = SFTConfig(\n",
+ " output_dir=\"may13-gemma-27b-tq_sft_finetuned-model\",\n",
+ " max_seq_length=2048,\n",
+ " packing=True,\n",
+ " num_train_epochs=1,\n",
+ " per_device_train_batch_size=1,\n",
+ " gradient_accumulation_steps=4,\n",
+ " gradient_checkpointing=True,\n",
+ " optim=\"adamw_torch_fused\",\n",
+ " logging_steps=1,\n",
+ " save_strategy=\"epoch\",\n",
+ " learning_rate=1e-4,\n",
+ " fp16=True if torch_dtype == torch.float16 else False,\n",
+ " bf16=True if torch_dtype == torch.bfloat16 else False,\n",
+ " max_grad_norm=0.3,\n",
+ " warmup_ratio=0.03,\n",
+ " lr_scheduler_type=\"constant\",\n",
+ " push_to_hub=True,\n",
+ " report_to=\"tensorboard\",\n",
+ " dataset_kwargs={\n",
+ " \"add_special_tokens\": False,\n",
+ " \"append_concat_token\": True,\n",
+ " },\n",
+ " no_cuda=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "2be55b87-70c9-4973-b0db-33154c272e47",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T21:55:25.765385Z",
+ "iopub.status.busy": "2025-05-13T21:55:25.764949Z",
+ "iopub.status.idle": "2025-05-13T21:55:36.592163Z",
+ "shell.execute_reply": "2025-05-13T21:55:36.591614Z",
+ "shell.execute_reply.started": "2025-05-13T21:55:25.765360Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Converting train dataset to ChatML: 100%|██████████| 309/309 [00:00<00:00, 9533.70 examples/s]\n",
+ "Applying chat template to train dataset: 100%|██████████| 309/309 [00:00<00:00, 4443.06 examples/s]\n",
+ "Tokenizing train dataset: 100%|██████████| 309/309 [00:01<00:00, 226.22 examples/s]\n",
+ "Packing train dataset: 100%|██████████| 309/309 [00:00<00:00, 102364.74 examples/s]\n",
+ "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from trl import SFTTrainer\n",
+ "\n",
+ "# Create Trainer object\n",
+ "trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " train_dataset=hf_dataset[\"train\"],\n",
+ " peft_config=peft_config,\n",
+ " processing_class=tokenizer\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d8d82767-27ed-48ed-ad22-3f3cf2dff15e",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T22:00:25.107226Z",
+ "iopub.status.busy": "2025-05-13T22:00:25.106569Z",
+ "iopub.status.idle": "2025-05-13T22:27:35.945604Z",
+ "shell.execute_reply": "2025-05-13T22:27:35.944775Z",
+ "shell.execute_reply.started": "2025-05-13T22:00:25.107196Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='80' max='80' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [80/80 22:44, Epoch 0/1]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Step</th>\n",
+ " <th>Training Loss</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>10.801900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>8.381400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>6.970200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>5.784300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>4.970800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>4.389700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>4.325000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>3.557000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>3.357700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>3.092500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>3.170300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>2.648500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>3.067800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>2.377100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>2.847700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>2.628800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>2.630800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>2.820900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>2.596700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>2.675300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>2.846300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>2.706700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>2.645100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>2.214600</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>2.790700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>2.640700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>2.908900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>28</td>\n",
+ " <td>2.690400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>29</td>\n",
+ " <td>2.807200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>30</td>\n",
+ " <td>2.713600</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>31</td>\n",
+ " <td>2.563200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>32</td>\n",
+ " <td>2.412700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>33</td>\n",
+ " <td>2.627700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>34</td>\n",
+ " <td>2.431800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>35</td>\n",
+ " <td>2.240600</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>36</td>\n",
+ " <td>2.650300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>37</td>\n",
+ " <td>2.014900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>38</td>\n",
+ " <td>2.463100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>39</td>\n",
+ " <td>2.283300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>40</td>\n",
+ " <td>2.450500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>41</td>\n",
+ " <td>2.570400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>42</td>\n",
+ " <td>2.550500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>43</td>\n",
+ " <td>2.530600</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>44</td>\n",
+ " <td>2.551400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>45</td>\n",
+ " <td>2.383000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>46</td>\n",
+ " <td>2.550500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>47</td>\n",
+ " <td>2.575900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>48</td>\n",
+ " <td>2.494300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>49</td>\n",
+ " <td>2.387200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>50</td>\n",
+ " <td>2.318800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>51</td>\n",
+ " <td>2.365200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>52</td>\n",
+ " <td>2.190100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>53</td>\n",
+ " <td>2.419100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>54</td>\n",
+ " <td>2.290900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>55</td>\n",
+ " <td>2.152500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>56</td>\n",
+ " <td>2.398700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>57</td>\n",
+ " <td>2.982500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>58</td>\n",
+ " <td>2.380200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>59</td>\n",
+ " <td>2.357500</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>60</td>\n",
+ " <td>2.386300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>61</td>\n",
+ " <td>2.741300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>62</td>\n",
+ " <td>2.850300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>63</td>\n",
+ " <td>2.682100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>64</td>\n",
+ " <td>2.972100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>65</td>\n",
+ " <td>2.237800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>66</td>\n",
+ " <td>2.518300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>67</td>\n",
+ " <td>2.520700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>68</td>\n",
+ " <td>2.122700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>69</td>\n",
+ " <td>2.210200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>70</td>\n",
+ " <td>2.414000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>71</td>\n",
+ " <td>2.348200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>72</td>\n",
+ " <td>2.470800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>73</td>\n",
+ " <td>2.417400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>74</td>\n",
+ " <td>2.562900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>75</td>\n",
+ " <td>2.286800</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>76</td>\n",
+ " <td>2.671400</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>77</td>\n",
+ " <td>2.176200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>78</td>\n",
+ " <td>2.284200</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>79</td>\n",
+ " <td>2.354700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>80</td>\n",
+ " <td>2.363400</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.train()\n",
+ "trainer.save_model()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "2398696f-eeb8-45d1-8dee-ed88a7ac140b",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-05-13T22:34:47.172016Z",
+ "iopub.status.busy": "2025-05-13T22:34:47.171574Z",
+ "iopub.status.idle": "2025-05-13T22:39:06.055171Z",
+ "shell.execute_reply": "2025-05-13T22:39:06.054429Z",
+ "shell.execute_reply.started": "2025-05-13T22:34:47.171989Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer_config.json',\n",
+ " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/special_tokens_map.json',\n",
+ " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.model',\n",
+ " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/added_tokens.json',\n",
+ " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.json')"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lora_model = trainer.model\n",
+ "merged_model = lora_model.merge_and_unload()\n",
+ "# Save the model with fused weights\n",
+ "merged_model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n",
808
+ "trainer.tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": 1,
814
+ "id": "8b811a84-0cdb-4b40-bb96-d6e6f27d41d3",
815
+ "metadata": {
816
+ "execution": {
817
+ "iopub.execute_input": "2025-05-08T21:17:00.794785Z",
818
+ "iopub.status.busy": "2025-05-08T21:17:00.794339Z",
819
+ "iopub.status.idle": "2025-05-08T21:17:18.309148Z",
820
+ "shell.execute_reply": "2025-05-08T21:17:18.308319Z",
821
+ "shell.execute_reply.started": "2025-05-08T21:17:00.794761Z"
822
+ }
823
+ },
824
+ "outputs": [
825
+ {
826
+ "ename": "NameError",
827
+ "evalue": "name 'model' is not defined",
828
+ "output_type": "error",
829
+ "traceback": [
830
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
831
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
832
+ "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n",
833
+ "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
834
+ ]
835
+ }
836
+ ],
837
+ "source": [
838
+ "# Merge LoRA weights into the base model\n",
839
+ "for name, param in model.named_parameters():\n",
840
+ " if name in trainer.peft_model.lora_weights:\n",
841
+ " param.data += trainer.peft_model.lora_weights[name]\n",
842
+ "\n",
843
+ "# Save the model with fused weights\n",
844
+ "model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n",
845
+ "tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')"
846
+ ]
847
+ },
848
+ {
849
+ "cell_type": "code",
850
+ "execution_count": 9,
851
+ "id": "e5b4930d-92c5-46e8-9163-6e7f722e0c99",
852
+ "metadata": {
853
+ "execution": {
854
+ "iopub.execute_input": "2025-05-08T19:13:24.762234Z",
855
+ "iopub.status.busy": "2025-05-08T19:13:24.761972Z",
856
+ "iopub.status.idle": "2025-05-08T19:13:50.993002Z",
857
+ "shell.execute_reply": "2025-05-08T19:13:50.992329Z",
858
+ "shell.execute_reply.started": "2025-05-08T19:13:24.762215Z"
859
+ }
860
+ },
861
+ "outputs": [
862
+ {
863
+ "name": "stderr",
864
+ "output_type": "stream",
865
+ "text": [
866
+ "Loading checkpoint shards: 100%|██████████| 12/12 [00:19<00:00, 1.60s/it]\n"
867
+ ]
868
+ }
869
+ ],
870
+ "source": [
871
+ "import torch\n",
872
+ "from transformers import pipeline\n",
873
+ "from random import randint\n",
874
+ "import re\n",
875
+ "\n",
876
+ "model_id = \"google/gemma-3-27b-it\"\n",
877
+ "model = model_class.from_pretrained(\n",
878
+ " model_id,\n",
879
+ " device_map=\"auto\",\n",
880
+ " torch_dtype=torch_dtype,\n",
881
+ " attn_implementation=\"eager\",\n",
882
+ ")\n",
883
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n"
884
+ ]
885
+ },
886
+ {
887
+ "cell_type": "code",
888
+ "execution_count": 10,
889
+ "id": "5a428dea-261a-4c74-89a8-1b62d7ade5ab",
890
+ "metadata": {
891
+ "execution": {
892
+ "iopub.execute_input": "2025-05-08T19:13:50.999539Z",
893
+ "iopub.status.busy": "2025-05-08T19:13:50.999160Z",
894
+ "iopub.status.idle": "2025-05-08T19:15:04.024652Z",
895
+ "shell.execute_reply": "2025-05-08T19:15:04.022626Z",
896
+ "shell.execute_reply.started": "2025-05-08T19:13:50.999517Z"
897
+ }
898
+ },
899
+ "outputs": [
900
+ {
901
+ "ename": "NameError",
902
+ "evalue": "name 'trainer' is not defined",
903
+ "output_type": "error",
904
+ "traceback": [
905
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
906
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
907
+ "Cell \u001b[0;32mIn[10], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m model\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# Save the model with fused weights\u001b[39;00m\n",
908
+ "\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined"
909
+ ]
910
+ }
911
+ ],
912
+ "source": []
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": null,
917
+ "id": "c7e7172e-db49-40f3-a0d6-9a87a3b2cf80",
918
+ "metadata": {
919
+ "execution": {
920
+ "iopub.status.busy": "2025-05-08T19:15:04.026597Z",
921
+ "iopub.status.idle": "2025-05-08T19:15:04.026984Z",
922
+ "shell.execute_reply": "2025-05-08T19:15:04.026875Z",
923
+ "shell.execute_reply.started": "2025-05-08T19:15:04.026863Z"
924
+ },
925
+ "scrolled": true
926
+ },
927
+ "outputs": [],
928
+ "source": [
929
+ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
930
+ "rand_idx = randint(0, len(dataset[\"test\"]))\n",
931
+ "test_sample = hf_dataset[\"test\"][rand_idx]\n",
932
+ "stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")]\n",
933
+ "prompt = pipe.tokenizer.apply_chat_template(test_sample[\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n",
934
+ "\n",
935
+ "outputs = pipe(prompt, max_new_tokens=1024, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": null,
941
+ "id": "de05b438-ca77-4b95-b2b1-32ea7ae033a5",
942
+ "metadata": {
943
+ "execution": {
944
+ "iopub.status.busy": "2025-05-08T19:15:04.028819Z",
945
+ "iopub.status.idle": "2025-05-08T19:15:04.029072Z",
946
+ "shell.execute_reply": "2025-05-08T19:15:04.028971Z",
947
+ "shell.execute_reply.started": "2025-05-08T19:15:04.028960Z"
948
+ }
949
+ },
950
+ "outputs": [],
951
+ "source": [
952
+ "start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
953
+ "end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
954
+ "print(start, end)\n",
955
+ "print(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
956
+ "json.loads(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
957
+ "rand_idx"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": null,
963
+ "id": "60b3da99-0edc-4ef6-b0e0-be7d046eaa02",
964
+ "metadata": {
965
+ "execution": {
966
+ "iopub.status.busy": "2025-05-08T19:15:04.030913Z",
967
+ "iopub.status.idle": "2025-05-08T19:15:04.031227Z",
968
+ "shell.execute_reply": "2025-05-08T19:15:04.031122Z",
969
+ "shell.execute_reply.started": "2025-05-08T19:15:04.031111Z"
970
+ }
971
+ },
972
+ "outputs": [],
973
+ "source": [
974
+ "json.loads(hf_dataset[\"test\"][81][\"messages\"][1]['content'])"
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": null,
980
+ "id": "cdc44250-e3b9-4870-bce5-23f475023962",
981
+ "metadata": {
982
+ "execution": {
983
+ "iopub.status.busy": "2025-05-08T19:15:04.032999Z",
984
+ "iopub.status.idle": "2025-05-08T19:15:04.033327Z",
985
+ "shell.execute_reply": "2025-05-08T19:15:04.033207Z",
986
+ "shell.execute_reply.started": "2025-05-08T19:15:04.033196Z"
987
+ },
988
+ "scrolled": true
989
+ },
990
+ "outputs": [],
991
+ "source": [
992
+ "import torch\n",
993
+ "from transformers import pipeline\n",
994
+ "from random import randint\n",
995
+ "import re\n",
996
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n",
997
+ "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n",
998
+ "device = torch.device(\"cuda:0\")\n",
999
+ "\n",
1000
+ "model_class = Gemma3ForConditionalGeneration\n",
1001
+ "torch_dtype = torch.bfloat16\n",
1002
+ "\n",
1003
+ "model_id = \"gemma-27b-tq_sft_finetuned-model\"\n",
1004
+ "model = model_class.from_pretrained(\n",
1005
+ " model_id,\n",
1006
+ " device_map=\"auto\",\n",
1007
+ " torch_dtype=torch_dtype,\n",
1008
+ " attn_implementation=\"eager\",\n",
1009
+ ")\n",
1010
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
1011
+ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": null,
1017
+ "id": "ce4070a9-5291-477a-bb3f-867b7971e391",
1018
+ "metadata": {
1019
+ "execution": {
1020
+ "iopub.status.busy": "2025-05-08T19:15:04.035085Z",
1021
+ "iopub.status.idle": "2025-05-08T19:15:04.035416Z",
1022
+ "shell.execute_reply": "2025-05-08T19:15:04.035307Z",
1023
+ "shell.execute_reply.started": "2025-05-08T19:15:04.035295Z"
1024
+ }
1025
+ },
1026
+ "outputs": [],
1027
+ "source": [
1028
+ "def extract_json_data(json_string):\n",
1029
+ " key_pattern = r'\"(.*?)\"\\s*:\\s*'\n",
1030
+ " value_pattern = r'(?:\"(.*?)\"|(\\d+)|$$(.*?)$$|\\{(.*?)\\})'\n",
1031
+ " matches = re.finditer(key_pattern + value_pattern, json_string, re.DOTALL) \n",
1032
+ " data = {}\n",
1033
+ " for match in matches:\n",
1034
+ " key = match.group(1)\n",
1035
+ " value = match.group(2) or match.group(3) or match.group(4) or match.group(5) \n",
1036
+ " if value:\n",
1037
+ " try:\n",
1038
+ " value = json.loads(value)\n",
1039
+ " except (json.JSONDecodeError, TypeError):\n",
1040
+ " pass\n",
1041
+ " data[key] = value\n",
1042
+ " return data"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": null,
1048
+ "id": "4940ab0c-ff5a-4c1e-a543-b0e8be91a4cb",
1049
+ "metadata": {
1050
+ "execution": {
1051
+ "iopub.status.busy": "2025-05-08T19:15:04.037234Z",
1052
+ "iopub.status.idle": "2025-05-08T19:15:04.037745Z",
1053
+ "shell.execute_reply": "2025-05-08T19:15:04.037637Z",
1054
+ "shell.execute_reply.started": "2025-05-08T19:15:04.037626Z"
1055
+ }
1056
+ },
1057
+ "outputs": [],
1058
+ "source": [
1059
+ "rand_idx = randint(0, len(dataset[\"test\"]))\n",
1060
+ "test_predictions = []\n",
1061
+ "\n",
1062
+ "index = 9\n",
1063
+ "\n",
1064
+ "meta_data = test_meta[index]\n",
1065
+ "stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")]\n",
1066
+ "prompt = pipe.tokenizer.apply_chat_template(hf_dataset[\"test\"][index][\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n",
1067
+ "outputs = pipe(prompt, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)\n",
1068
+ "start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
1069
+ "end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
1070
+ "try:\n",
1071
+ " pred_dict = json.loads(outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
1072
+ "except:\n",
1073
+ " start = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
1074
+ " end = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
1075
+ " pred_dict = outputs[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1]"
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "execution_count": null,
1081
+ "id": "fdf03584-7cd0-40cc-af95-87279a2dc05e",
1082
+ "metadata": {
1083
+ "execution": {
1084
+ "iopub.status.busy": "2025-05-08T19:15:04.039492Z",
1085
+ "iopub.status.idle": "2025-05-08T19:15:04.039810Z",
1086
+ "shell.execute_reply": "2025-05-08T19:15:04.039704Z",
1087
+ "shell.execute_reply.started": "2025-05-08T19:15:04.039693Z"
1088
+ }
1089
+ },
1090
+ "outputs": [],
1091
+ "source": [
1092
+ "pred_dict"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "code",
1097
+ "execution_count": null,
1098
+ "id": "80603718-a168-4e4c-aa55-842dfb20f265",
1099
+ "metadata": {
1100
+ "execution": {
1101
+ "iopub.status.busy": "2025-05-08T19:15:04.041594Z",
1102
+ "iopub.status.idle": "2025-05-08T19:15:04.041970Z",
1103
+ "shell.execute_reply": "2025-05-08T19:15:04.041865Z",
1104
+ "shell.execute_reply.started": "2025-05-08T19:15:04.041854Z"
1105
+ }
1106
+ },
1107
+ "outputs": [],
1108
+ "source": [
1109
+ "hf_dataset[\"test\"][index][\"messages\"][1]"
1110
+ ]
1111
+ },
1112
+ {
1113
+ "cell_type": "code",
1114
+ "execution_count": null,
1115
+ "id": "6d3731c9-4686-453f-8c91-e9477fe5541c",
1116
+ "metadata": {
1117
+ "execution": {
1118
+ "iopub.status.busy": "2025-05-08T19:15:04.043675Z",
1119
+ "iopub.status.idle": "2025-05-08T19:15:04.043977Z",
1120
+ "shell.execute_reply": "2025-05-08T19:15:04.043872Z",
1121
+ "shell.execute_reply.started": "2025-05-08T19:15:04.043861Z"
1122
+ }
1123
+ },
1124
+ "outputs": [],
1125
+ "source": [
1126
+ "batch_size = 8\n",
1127
+ "test_predictions = []\n",
1128
+ "\n",
1129
+ "for i in tqdm(range(0, len(hf_dataset[\"test\"]), batch_size)):\n",
1130
+ " batch_samples = hf_dataset[\"test\"][i:i + batch_size][\"messages\"]\n",
1131
+ " batch_meta = test_meta[i:i + batch_size]\n",
1132
+ " prompts = [\n",
1133
+ " pipe.tokenizer.apply_chat_template(sample[:1], tokenize=False, add_generation_prompt=True)\n",
1134
+ " for sample in batch_samples\n",
1135
+ " ]\n",
1136
+ " outputs = pipe(prompts, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<end_of_turn>\")], disable_compile=True)\n",
1137
+ "\n",
1138
+ " for index, output in tqdm(enumerate(tqdm(outputs))):\n",
1139
+ " output_dict = {}\n",
1140
+ " start = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().find(\"{\")\n",
1141
+ " end = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip().rfind(\"}\")\n",
1142
+ " try:\n",
1143
+ " pred_dict = json.loads(output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1])\n",
1144
+ " except:\n",
1145
+ " pred_dict = output[0]['generated_text'].split(r\"<start_of_turn>model\")[1].strip()[start:end + 1]\n",
1146
+ " \n",
1147
+ " output_dict.update(batch_meta[index])\n",
1148
+ " output_dict[\"predictions\"] = pred_dict\n",
1149
+ " output_dict[\"human-annotation\"] = batch_samples[index][1]['content']\n",
1150
+ " output_dict[\"prompt\"] = batch_samples[index][0]['content']\n",
1151
+ " test_predictions.append(output_dict)"
1152
+ ]
1153
+ },
1154
+ {
1155
+ "cell_type": "code",
1156
+ "execution_count": null,
1157
+ "id": "616eb30a-eac2-4229-b86c-24eca7534cc6",
1158
+ "metadata": {
1159
+ "execution": {
1160
+ "iopub.status.busy": "2025-05-08T19:15:04.045755Z",
1161
+ "iopub.status.idle": "2025-05-08T19:15:04.046057Z",
1162
+ "shell.execute_reply": "2025-05-08T19:15:04.045954Z",
1163
+ "shell.execute_reply.started": "2025-05-08T19:15:04.045943Z"
1164
+ }
1165
+ },
1166
+ "outputs": [],
1167
+ "source": [
1168
+ "with open(\"/root/notebooks/trashspace/gemma_finetuned_expertdata/test_pred.json\", 'w') as json_file:\n",
1169
+ " json.dump(test_predictions, json_file)"
1170
+ ]
1171
+ },
1172
+ {
1173
+ "cell_type": "code",
1174
+ "execution_count": null,
1175
+ "id": "71480057-d6b9-4499-a8d7-26bf3f3f9342",
1176
+ "metadata": {},
1177
+ "outputs": [],
1178
+ "source": []
1179
+ },
1180
+ {
1181
+ "cell_type": "code",
1182
+ "execution_count": null,
1183
+ "id": "ce73604c-4bbb-46c7-8433-d957b0e10405",
1184
+ "metadata": {},
1185
+ "outputs": [],
1186
+ "source": []
1187
+ },
1188
+ {
1189
+ "cell_type": "code",
1190
+ "execution_count": null,
1191
+ "id": "4741a722-a772-44bf-949e-e77671a4ef03",
1192
+ "metadata": {},
1193
+ "outputs": [],
1194
+ "source": []
1195
+ },
1196
+ {
1197
+ "cell_type": "code",
1198
+ "execution_count": null,
1199
+ "id": "07c2adef-c2f6-4ba9-98f7-277cce2701d0",
1200
+ "metadata": {},
1201
+ "outputs": [],
1202
+ "source": []
1203
+ },
1204
+ {
1205
+ "cell_type": "code",
1206
+ "execution_count": null,
1207
+ "id": "1adfa27b-9bfa-4479-be3f-5149a2237c1f",
1208
+ "metadata": {
1209
+ "execution": {
1210
+ "iopub.status.busy": "2025-05-08T19:15:04.047823Z",
1211
+ "iopub.status.idle": "2025-05-08T19:15:04.048130Z",
1212
+ "shell.execute_reply": "2025-05-08T19:15:04.048026Z",
1213
+ "shell.execute_reply.started": "2025-05-08T19:15:04.048015Z"
1214
+ }
1215
+ },
1216
+ "outputs": [],
1217
+ "source": [
1218
+ "data = json.loads(test_sample['messages'][1]['content'])\n",
1219
+ "data"
1220
+ ]
1221
+ },
1222
+ {
1223
+ "cell_type": "code",
1224
+ "execution_count": null,
1225
+ "id": "750d3454-6300-469b-bdc3-77cce45a00ce",
1226
+ "metadata": {
1227
+ "execution": {
1228
+ "iopub.status.busy": "2025-05-08T19:15:04.049897Z",
1229
+ "iopub.status.idle": "2025-05-08T19:15:04.050203Z",
1230
+ "shell.execute_reply": "2025-05-08T19:15:04.050099Z",
1231
+ "shell.execute_reply.started": "2025-05-08T19:15:04.050088Z"
1232
+ }
1233
+ },
1234
+ "outputs": [],
1235
+ "source": [
1236
+ "print(len(hf_dataset[\"test\"]))"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": null,
1242
+ "id": "3be45a2d-336f-4899-a8e9-e000437fab8c",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": []
1246
+ },
1247
+ {
1248
+ "cell_type": "code",
1249
+ "execution_count": null,
1250
+ "id": "248182ff-bec8-46ff-bc34-14b523d877bf",
1251
+ "metadata": {},
1252
+ "outputs": [],
1253
+ "source": []
1254
+ }
1255
+ ],
1256
+ "metadata": {
1257
+ "kernelspec": {
1258
+ "display_name": "timedlibs",
1259
+ "language": "python",
1260
+ "name": "timedlibs"
1261
+ },
1262
+ "language_info": {
1263
+ "codemirror_mode": {
1264
+ "name": "ipython",
1265
+ "version": 3
1266
+ },
1267
+ "file_extension": ".py",
1268
+ "mimetype": "text/x-python",
1269
+ "name": "python",
1270
+ "nbconvert_exporter": "python",
1271
+ "pygments_lexer": "ipython3",
1272
+ "version": "3.10.16"
1273
+ }
1274
+ },
1275
+ "nbformat": 4,
1276
+ "nbformat_minor": 5
1277
+ }
README.md ADDED
@@ -0,0 +1,58 @@
1
+ ---
2
+ base_model: google/gemma-3-4b-it
3
+ library_name: transformers
4
+ model_name: TQTune
5
+ tags:
6
+ - generated_from_trainer
7
+ - trl
8
+ - sft
9
+ licence: license
10
+ ---
11
+
12
+ # Model Card for TQTune
13
+
14
+ This model is a fine-tuned version of [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it).
15
+ It has been trained using [TRL](https://github.com/huggingface/trl).
16
+
17
+ ## Quick start
18
+
19
+ ```python
20
+ from transformers import pipeline
21
+
22
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
23
+ generator = pipeline("text-generation", model="bhavinjawade/TQTune", device="cuda")
24
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
25
+ print(output["generated_text"])
26
+ ```
27
+
28
+ ## Training procedure
29
+
30
+
31
+
32
+
33
+ This model was trained with SFT.
34
+
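+ A minimal sketch of the SFT setup with TRL's `SFTTrainer` (illustrative only: the dataset id and hyperparameters below are placeholders, not the actual training data):
+ 
+ ```python
+ from datasets import load_dataset
+ from trl import SFTConfig, SFTTrainer
+ 
+ # Hypothetical chat-formatted dataset with a "messages" column
+ train_dataset = load_dataset("trl-lib/Capybara", split="train")
+ 
+ args = SFTConfig(output_dir="TQTune", num_train_epochs=1)
+ trainer = SFTTrainer(model="google/gemma-3-4b-it", args=args, train_dataset=train_dataset)
+ trainer.train()
+ ```
+ 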
35
+ ### Framework versions
36
+
37
+ - TRL: 0.16.1
38
+ - Transformers: 4.50.0.dev0
39
+ - Pytorch: 2.6.0+cu124
40
+ - Datasets: 3.3.2
41
+ - Tokenizers: 0.21.0
42
+
43
+ ## Citations
44
+
45
+
46
+
47
+ Cite TRL as:
48
+
49
+ ```bibtex
50
+ @misc{vonwerra2022trl,
51
+ title = {{TRL: Transformer Reinforcement Learning}},
52
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
53
+ year = 2020,
54
+ journal = {GitHub repository},
55
+ publisher = {GitHub},
56
+ howpublished = {\url{https://github.com/huggingface/trl}}
57
+ }
58
+ ```
SFT_Expert.py ADDED
@@ -0,0 +1,229 @@
1
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
2
+ import torch
3
+ from peft import LoraConfig, get_peft_model
4
+ import os
5
+ from tqdm import tqdm
6
+ import json
7
+ import random
8
+ from datasets import load_dataset, Dataset, DatasetDict
10
+
11
+ system_message = "You are a helpful assistant who is an expert in estimating quality of translations."
12
+
13
+ output_template = '''
14
+ {
15
+ "Accuracy Issues": [
16
+ {
17
+ "Error Span": "",
18
+ "Error Explanation": "",
19
+ "Error Quality Category": "",
20
+ "Error Quality Tags": [],
21
+ "Error Severity": ""
22
+ }
23
+ ],
24
+ "Accuracy Score": "",
25
+ "Readability Issues": [
26
+ {
27
+ "Error Span": "",
28
+ "Error Explanation": "",
29
+ "Error Quality Category": "",
30
+ "Error Quality Tags": [],
31
+ "Error Severity": ""
32
+ }
33
+ ],
34
+ "Readability Score": ""
35
+ }'''
36
+
37
+ def create_conversation(input_sample, output_sample):
38
+ return {
39
+ "messages": [
40
+ # {"role": "system", "content": system_message},
41
+ {"role": "user", "content": input_sample},
42
+ {"role": "assistant", "content": output_sample}
43
+ ]
44
+ }
45
+
46
+ data_path = (
47
+ "/root/notebooks/MT_TQ/TQ/TQTune/labeled_data/parsed/"
48
+ )
49
+
50
+ json_files = [
51
+ os.path.join(root, file)
52
+ for root, _, files in os.walk(data_path)
53
+ for file in files
54
+ if file.endswith(".json") and "PLDL" in file
55
+ ]
56
+
57
+ training_samples = []
58
+ for json_file in tqdm(json_files):
59
+ with open(json_file, "r") as file:
60
+ data = json.load(file)
61
+ sampled_items = random.sample(data["data"], 20)
62
+ training_samples.extend(sampled_items)
63
+
64
+ datapoints = []
65
+
66
+ for sample in training_samples:
67
+ datapoint = {"input": {}}
68
+ datapoint["input"]["src_text"] = sample["main_src_text"]
69
+ datapoint["input"]["tgt_text"] = sample["tgt_text"]
70
+ datapoint["input"]["src_prev"] = sample["tt_src_prev"]
71
+ datapoint["input"]["src_next"] = sample["tt_src_next"]
72
+ datapoint["input"]["tgt_prev"] = sample["tt_tgt_prev"]
73
+ datapoint["input"]["tgt_next"] = sample["tt_tgt_next"]
74
+ datapoint["input"]["src_lang"] = sample["src_lang"]
75
+ datapoint["input"]["tgt_lang"] = sample["tgt_lang"]
76
+ datapoint["evaluation"] = sample["labelers"][0]["annotation"]
77
+ datapoints.append(datapoint)
78
+
79
+ def dataset_prep(datapoints, test_size=0.2):
80
+ with open("prompts.txt") as file:
81
+ template_string = file.read()
82
+
83
+ random.shuffle(datapoints)
84
+
85
+ split_index = int(len(datapoints) * (1 - test_size))
86
+ train_datapoints = datapoints[:split_index]
87
+ test_datapoints = datapoints[split_index:]
88
+
89
+ def create_dataset(datapoints):
90
+ dataset = []
91
+ for datapoint in datapoints:
92
+ src_text = datapoint['input']['src_text']
93
+ tgt_text = datapoint['input']['tgt_text']
94
+ src_prev = datapoint['input']['src_prev']
95
+ src_next = datapoint['input']['src_next']
96
+ tgt_prev = datapoint['input']['tgt_prev']
97
+ tgt_next = datapoint['input']['tgt_next']
98
+ src_lang = datapoint['input']['src_lang']
99
+ tgt_lang = datapoint['input']['tgt_lang']
100
+ output = datapoint['evaluation']
101
+ del output["Confidence Level"]
102
+ del output["Main Vs Alternate"]
103
+ del output["Score"]
104
+
105
+ if len(output['Accuracy Issues']) != 0 and len(output['Readability Issues']) != 0:
106
+ item = template_string.format(src_text=src_text, tgt_text=tgt_text,
107
+ src_prev=src_prev, src_next=src_next,
108
+ tgt_prev=tgt_prev, tgt_next=tgt_next,
109
+ src_lang=src_lang, tgt_lang=tgt_lang,
110
+ template=output_template)
111
+
112
+ dataset.append(create_conversation(item, json.dumps(output)))
113
+
114
+ return dataset
115
+
116
+ train_set = create_dataset(train_datapoints)
117
+ test_set = create_dataset(test_datapoints)
118
+
119
+ return train_set, test_set
120
+
121
+ train_dataset, test_dataset = dataset_prep(datapoints)
122
+ dataset = {"train": train_dataset, "test": test_dataset}
123
+
124
+ def convert_to_hf_dataset(dataset):
125
+ # Convert the train and test datasets into Hugging Face Dataset objects
126
+ train_dataset = Dataset.from_list(dataset['train'])
127
+ test_dataset = Dataset.from_list(dataset['test'])
128
+
129
+ # Combine them into a DatasetDict
130
+ hf_dataset = DatasetDict({
131
+ 'train': train_dataset,
132
+ 'test': test_dataset
133
+ })
134
+
135
+ return hf_dataset
136
+
137
+ # Convert your dataset into a Hugging Face Dataset object
138
+ hf_dataset = convert_to_hf_dataset(dataset)
139
+
140
+ # Now you can use hf_dataset for your machine learning tasks
141
+ print(hf_dataset)
142
+
143
+ import torch
144
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
145
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
146
+ device = torch.device("cuda:0")
147
+
148
+ # Hugging Face model id
149
+ model_id = "google/gemma-3-12b-it" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`
150
+
151
+ # Select model class based on id
152
+ if model_id == "google/gemma-3-12b-it":
153
+ model_class = Gemma3ForConditionalGeneration
154
+ else:
155
+ model_class = AutoModelForImageTextToText
156
+
157
+ torch_dtype = torch.bfloat16
158
+
159
+ model_kwargs = dict(
160
+ attn_implementation="eager",
161
+ torch_dtype=torch_dtype,
162
+ device_map="auto", # Change from {'': 0} to "auto"
163
+ )
164
+
165
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
166
+ load_in_8bit=True,
167
+ bnb_8bit_use_double_quant=True,
168
+ bnb_8bit_quant_type='nf8',
169
+ bnb_8bit_compute_dtype=model_kwargs['torch_dtype'],
170
+ bnb_8bit_quant_storage=model_kwargs['torch_dtype'],
171
+ )
172
+
173
+ model = model_class.from_pretrained(model_id, **model_kwargs)
174
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it") # Load the Instruction Tokenizer to use the official Gemma template
175
+
176
+ from peft import LoraConfig
177
+
178
+ peft_config = LoraConfig(
179
+ lora_alpha=128,
180
+ lora_dropout=0.05,
181
+ r=16,
182
+ bias="none",
183
+ target_modules="all-linear",
184
+ task_type="CAUSAL_LM",
185
+ modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
186
+ )
187
+
188
+ from trl import SFTConfig
189
+
190
+ args = SFTConfig(
191
+ output_dir="gemma-12b-tq-model",
192
+ max_seq_length=512,
193
+ packing=True,
194
+ num_train_epochs=1,
195
+ per_device_train_batch_size=1,
196
+ gradient_accumulation_steps=4,
197
+ gradient_checkpointing=True,
198
+ optim="adamw_torch_fused",
199
+ logging_steps=1,
200
+ save_strategy="epoch",
201
+ learning_rate=2e-4,
202
+ fp16=True if torch_dtype == torch.float16 else False,
203
+ bf16=True if torch_dtype == torch.bfloat16 else False,
204
+ max_grad_norm=0.3,
205
+ warmup_ratio=0.03,
206
+ lr_scheduler_type="constant",
207
+ push_to_hub=True,
208
+ report_to="tensorboard",
209
+ dataset_kwargs={
210
+ "add_special_tokens": False,
211
+ "append_concat_token": True,
212
+ },
213
+ ddp_find_unused_parameters=False,
214
+ no_cuda=False,
215
+ )
216
+
217
+ from trl import SFTTrainer
218
+
219
+ # Create Trainer object
220
+ trainer = SFTTrainer(
221
+ model=model,
222
+ args=args,
223
+ train_dataset=hf_dataset["train"],
224
+ peft_config=peft_config,
225
+ processing_class=tokenizer
226
+ )
227
+
228
+ trainer.train()
229
+ trainer.save_model()
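+ 
+ # Optional follow-up (sketch only; mirrors the merge step used in the notebook,
+ # not something this script currently runs): fuse the LoRA adapters into the
+ # base weights so the result can be served without PEFT.
+ # merged_model = trainer.model.merge_and_unload()
+ # merged_model.save_pretrained("gemma-12b-tq-model-merged")  # hypothetical output dir
+ # tokenizer.save_pretrained("gemma-12b-tq-model-merged")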
TQ_template.py ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "src_lang": "en",
3
+ "tgt_lang": "de",
4
+ "src_prev": "Alice: Hi Bob, how are you?\nBob: I'm good, thanks!",
5
+ "tgt_prev": "Alice: Hallo Bob, wie geht's?\nBob: Mir geht's gut, danke!",
6
+ "src_next": "Alice: Want to grab coffee later?\nBob: Sure, sounds good.",
7
+ "tgt_next": "Alice: Möchtest du später einen Kaffee trinken?\nBob: Klar, klingt gut.",
8
+ "src_text": "Bob: I just got back from Paris.",
9
+ "main_text": "This is the main text",
10
+ "alternate_text": "This is the alternate text",
11
+ "evaluation": {
12
+ "Accuracy Issues": [
13
+ {
14
+ "Error Span": [5,8],
15
+ "Error Explanation": "Incorrect translation of 'just got back' as 'gerade aus' instead of 'gerade zurück'.",
16
+ "Error Quality Category": "Fidelity",
17
+ "Error Quality Tags": ["terminology", "accuracy"],
18
+ "Error Severity": "Major"
19
+ }
20
+ ],
21
+ "Accuracy Score": "4", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["classifications"][1]["radio_answer"]["name"]
22
+ "Readability Issues": [
23
+ {
24
+ "Error Location": "Src", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["conversational_location"]["message_id"]
25
+ "Error Span": [0,2], # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["conversational_location"]["location"]["start" and "end" use them to make this list (start, end)]
26
+ "Error Explanation": "Sentence structure is awkward in German translation.", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][2]["text_answer"]["content"]
27
+ "Error Quality Category": "Style", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["name"] - here if the name is "Style" - put it under Readability Issues else put it under Accuracy Issues.
28
+ "Error Quality Tags": ["awkward", "structure"], # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][1]["checklist_answers"][list of dicts, take name keys for all and make list of it]
29
+ "Error Severity": "Minor" # ["projects"]["What ever key is there"]["labels"][index here of the labels]["annotations"]["objects"][index of it]["classification"][0]["radio_answer"]["name"]
30
+ }
31
+ ],
32
+ "Readability Score": "3", # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][2]["radio_answer"]["name"]
33
+ "Confidence Level": "the_translation_is_excellent_without_any_error_spans_and_no_creative_liberties_were_taken", # # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][3]["radio_answer"]["name"]
34
+ "Main Vs Alternate": "Both of them have roughly the same quality" # ["projects"]["What ever key is there"]["labels"][index here of the labels]["classifications"][0]["radio_answer"]["name"]
35
+ },
36
+ "Score": "26"
37
+ }
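+ 
+ # Hedged sketch (a hypothetical helper, not part of the template above): one way the
+ # annotation paths noted in the comments could be walked for a single label. The key
+ # names and classification indices are taken verbatim from those notes and may need
+ # adjusting against a real export.
+ def parse_label(export, project_key, label_idx):
+     label = export["projects"][project_key]["labels"][label_idx]
+     issues = []
+     for obj in label["annotations"]["objects"]:
+         loc = obj["conversational_location"]["location"]
+         issues.append({
+             "Error Span": [loc["start"], loc["end"]],
+             "Error Quality Category": obj["name"],  # "Style" -> Readability, else Accuracy
+             "Error Severity": obj["classification"][0]["radio_answer"]["name"],
+             "Error Quality Tags": [t["name"] for t in obj["classification"][1]["checklist_answers"]],
+             "Error Explanation": obj["classification"][2]["text_answer"]["content"],
+         })
+     return {
+         "Accuracy Score": label["annotations"]["classifications"][1]["radio_answer"]["name"],
+         "Readability Score": label["classifications"][2]["radio_answer"]["name"],
+         "Issues": issues,
+     }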
TextGrad_Optimization.ipynb ADDED
@@ -0,0 +1,544 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 10,
6
+ "id": "8b3ee6e2-ca9c-40fa-b4c6-a9596f075f79",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2025-05-09T17:36:47.763713Z",
10
+ "iopub.status.busy": "2025-05-09T17:36:47.763339Z",
11
+ "iopub.status.idle": "2025-05-09T17:36:47.768648Z",
12
+ "shell.execute_reply": "2025-05-09T17:36:47.768166Z",
13
+ "shell.execute_reply.started": "2025-05-09T17:36:47.763676Z"
14
+ }
15
+ },
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "env: OPENAI_API_KEY=\"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\"\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "import argparse\n",
27
+ "import concurrent\n",
28
+ "from dotenv import load_dotenv\n",
29
+ "from tqdm import tqdm\n",
30
+ "import textgrad as tg\n",
31
+ "from textgrad.tasks import load_task\n",
32
+ "import numpy as np\n",
33
+ "import random\n",
34
+ "load_dotenv(override=True)\n",
35
+ "import os\n",
36
+ "import json\n",
37
+ "\n",
38
+ "%env OPENAI_API_KEY=\"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\""
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 4,
44
+ "id": "4ec9a29b-9162-4fe3-b32d-4de4397c6483",
45
+ "metadata": {
46
+ "execution": {
47
+ "iopub.execute_input": "2025-05-09T17:33:04.417822Z",
48
+ "iopub.status.busy": "2025-05-09T17:33:04.417437Z",
49
+ "iopub.status.idle": "2025-05-09T17:33:04.429505Z",
50
+ "shell.execute_reply": "2025-05-09T17:33:04.429029Z",
51
+ "shell.execute_reply.started": "2025-05-09T17:33:04.417795Z"
52
+ }
53
+ },
54
+ "outputs": [
55
+ {
56
+ "name": "stderr",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "0it [00:00, ?it/s]\n"
60
+ ]
61
+ }
62
+ ],
63
+ "source": [
64
+ "data_path = \"/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/\"\n",
65
+ "json_files = [os.path.join(root, file) for root, _, files in os.walk(data_path) for file in files if file.endswith('.json') and 'PLDL' in file]\n",
66
+ "\n",
67
+ "training_samples = []\n",
68
+ "for json_file in tqdm(json_files):\n",
69
+ " with open(json_file, 'r') as file:\n",
70
+ " data = json.load(file)\n",
71
+ " sampled_items = random.sample(data[\"data\"], 20)\n",
72
+ " training_samples.extend(sampled_items)\n",
73
+ "\n",
74
+ "datapoints = []\n",
75
+ "\n",
76
+ "for sample in training_samples:\n",
77
+ " datapoint = {\"input\":{}}\n",
78
+ " datapoint[\"input\"][\"src_text\"] = sample[\"main_src_text\"]\n",
79
+ " datapoint[\"input\"][\"tgt_text\"] = sample[\"tgt_text\"]\n",
80
+ " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n",
81
+ " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n",
82
+ " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n",
83
+ " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n",
84
+ " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n",
85
+ " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n",
86
+ " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n",
87
+ " datapoints.append(datapoint)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 5,
93
+ "id": "a894ce72-d451-44fa-aaa5-85bf8e6dc9da",
94
+ "metadata": {
95
+ "execution": {
96
+ "iopub.execute_input": "2025-05-09T17:33:40.240759Z",
97
+ "iopub.status.busy": "2025-05-09T17:33:40.240243Z",
98
+ "iopub.status.idle": "2025-05-09T17:33:40.244435Z",
99
+ "shell.execute_reply": "2025-05-09T17:33:40.243818Z",
100
+ "shell.execute_reply.started": "2025-05-09T17:33:40.240720Z"
101
+ }
102
+ },
103
+ "outputs": [],
104
+ "source": [
105
+ "def set_seed(seed):\n",
106
+ " np.random.seed(seed)\n",
107
+ " random.seed(seed)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 6,
113
+ "id": "4eeaa266-3ca2-4360-b80b-b38aa3bbdb70",
114
+ "metadata": {
115
+ "execution": {
116
+ "iopub.execute_input": "2025-05-09T17:33:55.982807Z",
117
+ "iopub.status.busy": "2025-05-09T17:33:55.982080Z",
118
+ "iopub.status.idle": "2025-05-09T17:33:55.988522Z",
119
+ "shell.execute_reply": "2025-05-09T17:33:55.987924Z",
120
+ "shell.execute_reply.started": "2025-05-09T17:33:55.982770Z"
121
+ }
122
+ },
123
+ "outputs": [],
124
+ "source": [
125
+ "def eval_sample(item, eval_fn, model):\n",
126
+ " \"\"\"\n",
127
+ " This function allows us to evaluate if an answer to a question in the prompt is a good answer.\n",
128
+ "\n",
129
+ " \"\"\"\n",
130
+ " x, y = item\n",
131
+ " x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
132
+ " y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
133
+ " response = model(x)\n",
134
+ " try:\n",
135
+ " eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
136
+ " return int(eval_output_variable.value)\n",
137
+ " except:\n",
138
+ " eval_output_variable = eval_fn([x, y, response])\n",
139
+ " eval_output_parsed = eval_fn.parse_output(eval_output_variable)\n",
140
+ " return int(eval_output_parsed)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 7,
146
+ "id": "c7e57f9d-c0ff-4139-9e61-b93510599353",
147
+ "metadata": {
148
+ "execution": {
149
+ "iopub.execute_input": "2025-05-09T17:34:08.606301Z",
150
+ "iopub.status.busy": "2025-05-09T17:34:08.605538Z",
151
+ "iopub.status.idle": "2025-05-09T17:34:08.612515Z",
152
+ "shell.execute_reply": "2025-05-09T17:34:08.611911Z",
153
+ "shell.execute_reply.started": "2025-05-09T17:34:08.606262Z"
154
+ }
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "def eval_dataset(test_set, eval_fn, model, max_samples: int=None):\n",
159
+ " if max_samples is None:\n",
160
+ " max_samples = len(test_set)\n",
161
+ " accuracy_list = []\n",
162
+ " with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:\n",
163
+ " futures = []\n",
164
+ " for _, sample in enumerate(test_set):\n",
165
+ " \n",
166
+ " future = executor.submit(eval_sample, sample, eval_fn, model)\n",
167
+ " futures.append(future)\n",
168
+ " if len(futures) >= max_samples:\n",
169
+ " break\n",
170
+ " tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)\n",
171
+ " for future in tqdm_loader:\n",
172
+ " acc_item = future.result()\n",
173
+ " accuracy_list.append(acc_item)\n",
174
+ " tqdm_loader.set_description(f\"Accuracy: {np.mean(accuracy_list)}\")\n",
175
+ " return accuracy_list "
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": 8,
181
+ "id": "039af9f3-a124-4a50-98a7-e728a913c069",
182
+ "metadata": {
183
+ "execution": {
184
+ "iopub.execute_input": "2025-05-09T17:34:22.703336Z",
185
+ "iopub.status.busy": "2025-05-09T17:34:22.702980Z",
186
+ "iopub.status.idle": "2025-05-09T17:34:22.707253Z",
187
+ "shell.execute_reply": "2025-05-09T17:34:22.706781Z",
188
+ "shell.execute_reply.started": "2025-05-09T17:34:22.703313Z"
189
+ }
190
+ },
191
+ "outputs": [],
192
+ "source": [
193
+ "def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):\n",
194
+ " val_performance = np.mean(eval_dataset(val_set, eval_fn, model))\n",
195
+ " previous_performance = np.mean(results[\"validation_acc\"][-1])\n",
196
+ " print(\"val_performance: \", val_performance)\n",
197
+ " print(\"previous_performance: \", previous_performance)\n",
198
+ " previous_prompt = results[\"prompt\"][-1]\n",
199
+ " \n",
200
+ " if val_performance < previous_performance:\n",
201
+ " print(f\"rejected prompt: {system_prompt.value}\")\n",
202
+ " system_prompt.set_value(previous_prompt)\n",
203
+ " val_performance = previous_performance\n",
204
+ "\n",
205
+ " results[\"validation_acc\"].append(val_performance)"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": 14,
211
+ "id": "031ebb6e-f5ff-45b0-a810-d1bd81ef6d2a",
212
+ "metadata": {
213
+ "execution": {
214
+ "iopub.execute_input": "2025-05-09T17:40:38.476352Z",
215
+ "iopub.status.busy": "2025-05-09T17:40:38.475979Z",
216
+ "iopub.status.idle": "2025-05-09T17:40:38.701947Z",
217
+ "shell.execute_reply": "2025-05-09T17:40:38.701394Z",
218
+ "shell.execute_reply.started": "2025-05-09T17:40:38.476327Z"
219
+ }
220
+ },
221
+ "outputs": [
222
+ {
223
+ "name": "stdout",
224
+ "output_type": "stream",
225
+ "text": [
226
+ "Train/Val/Test Set Lengths: 50 100 100\n"
227
+ ]
228
+ }
229
+ ],
230
+ "source": [
231
+ "set_seed(12)\n",
232
+ "llm_api_eval = tg.get_engine(engine_name=\"gpt-4o\")\n",
233
+ "llm_api_test = tg.get_engine(engine_name=\"gpt-3.5-turbo-0125\")\n",
234
+ "tg.set_backward_engine(llm_api_eval, override=True)\n",
235
+ "\n",
236
+ "# Load the data and the evaluation function\n",
237
+ "train_set, val_set, test_set, eval_fn = load_task(\"BBH_object_counting\", evaluation_api=llm_api_eval)\n",
238
+ "print(\"Train/Val/Test Set Lengths: \", len(train_set), len(val_set), len(test_set))\n",
239
+ "STARTING_SYSTEM_PROMPT = train_set.get_task_description()"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 15,
245
+ "id": "bde34303-2f52-415f-b117-264e266b84f0",
246
+ "metadata": {
247
+ "execution": {
248
+ "iopub.execute_input": "2025-05-09T17:40:39.330651Z",
249
+ "iopub.status.busy": "2025-05-09T17:40:39.330285Z",
250
+ "iopub.status.idle": "2025-05-09T17:40:39.398820Z",
251
+ "shell.execute_reply": "2025-05-09T17:40:39.398116Z",
252
+ "shell.execute_reply.started": "2025-05-09T17:40:39.330626Z"
253
+ }
254
+ },
255
+ "outputs": [
256
+ {
257
+ "name": "stderr",
258
+ "output_type": "stream",
259
+ "text": [
260
+ " 0%| | 0/100 [00:00<?, ?it/s]\n"
261
+ ]
262
+ },
263
+ {
264
+ "ename": "AssertionError",
265
+ "evalue": "Value must be a string, int, or image (bytes). Got: <class 'numpy.int64'>",
266
+ "output_type": "error",
267
+ "traceback": [
268
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
269
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
270
+ "Cell \u001b[0;32mIn[15], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mTextualGradientDescent(engine\u001b[38;5;241m=\u001b[39mllm_api_eval, parameters\u001b[38;5;241m=\u001b[39m[system_prompt])\n\u001b[1;32m 17\u001b[0m results \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n\u001b[0;32m---> 18\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(\u001b[43meval_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(eval_dataset(val_set, eval_fn, model))\n\u001b[1;32m 20\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(system_prompt\u001b[38;5;241m.\u001b[39mget_value())\n",
271
+ "Cell \u001b[0;32mIn[7], line 15\u001b[0m, in \u001b[0;36meval_dataset\u001b[0;34m(test_set, eval_fn, model, max_samples)\u001b[0m\n\u001b[1;32m 13\u001b[0m tqdm_loader \u001b[38;5;241m=\u001b[39m tqdm(concurrent\u001b[38;5;241m.\u001b[39mfutures\u001b[38;5;241m.\u001b[39mas_completed(futures), total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(futures), position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m future \u001b[38;5;129;01min\u001b[39;00m tqdm_loader:\n\u001b[0;32m---> 15\u001b[0m acc_item \u001b[38;5;241m=\u001b[39m \u001b[43mfuture\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m accuracy_list\u001b[38;5;241m.\u001b[39mappend(acc_item)\n\u001b[1;32m 17\u001b[0m tqdm_loader\u001b[38;5;241m.\u001b[39mset_description(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAccuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(accuracy_list)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
272
+ "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:451\u001b[0m, in \u001b[0;36mFuture.result\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;241m==\u001b[39m FINISHED:\n\u001b[0;32m--> 451\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition\u001b[38;5;241m.\u001b[39mwait(timeout)\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
273
+ "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:403\u001b[0m, in \u001b[0;36mFuture.__get_result\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception:\n\u001b[1;32m 402\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 403\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 405\u001b[0m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[1;32m 406\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
274
+ "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/thread.py:58\u001b[0m, in \u001b[0;36m_WorkItem.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuture\u001b[38;5;241m.\u001b[39mset_exception(exc)\n",
275
+ "Cell \u001b[0;32mIn[6], line 8\u001b[0m, in \u001b[0;36meval_sample\u001b[0;34m(item, eval_fn, model)\u001b[0m\n\u001b[1;32m 6\u001b[0m x, y \u001b[38;5;241m=\u001b[39m item\n\u001b[1;32m 7\u001b[0m x \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mVariable(x, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, role_description\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery to the language model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mtg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mVariable\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequires_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrole_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcorrect answer for the query\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m response \u001b[38;5;241m=\u001b[39m model(x)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
276
+ "File \u001b[0;32m~/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/textgrad/variable.py:43\u001b[0m, in \u001b[0;36mVariable.__init__\u001b[0;34m(self, value, image_path, predecessors, requires_grad, role_description)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m requires_grad) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(_predecessor_requires_grad) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf the variable does not require grad, none of its predecessors should require grad.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn this case, following predecessors require grad: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m_predecessor_requires_grad\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(value) \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbytes\u001b[39m, \u001b[38;5;28mint\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValue must be a string, int, or image (bytes). Got: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(value))\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 45\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(value)\n",
277
+ "\u001b[0;31mAssertionError\u001b[0m: Value must be a string, int, or image (bytes). Got: <class 'numpy.int64'>"
278
+ ]
279
+ }
280
+ ],
281
+ "source": [
282
+ "train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)\n",
283
+ "\n",
284
+ "\n",
285
+ "# Testing the 0-shot performance of the evaluation engine\n",
286
+ "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
287
+ " requires_grad=True, \n",
288
+ " role_description=\"system prompt to the language model\")\n",
289
+ "model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)\n",
290
+ "\n",
291
+ "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n",
292
+ " requires_grad=True,\n",
293
+ " role_description=\"structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task\")\n",
294
+ "model = tg.BlackboxLLM(llm_api_test, system_prompt)\n",
295
+ "\n",
296
+ "optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt])\n",
297
+ "\n",
298
+ "results = {\"test_acc\": [], \"prompt\": [], \"validation_acc\": []}\n",
299
+ "results[\"test_acc\"].append(eval_dataset(test_set, eval_fn, model))\n",
300
+ "results[\"validation_acc\"].append(eval_dataset(val_set, eval_fn, model))\n",
301
+ "results[\"prompt\"].append(system_prompt.get_value())"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": null,
307
+ "id": "47c15231-22ff-459b-b5cc-ca32aaa62332",
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "for epoch in range(3):\n",
312
+ " for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):\n",
313
+ " pbar.set_description(f\"Training step {steps}. Epoch {epoch}\")\n",
314
+ " optimizer.zero_grad()\n",
315
+ " losses = []\n",
316
+ " for (x, y) in zip(batch_x, batch_y):\n",
317
+ " x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n",
318
+ " y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n",
319
+ " response = model(x)\n",
320
+ " try:\n",
321
+ " eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n",
322
+ " except:\n",
323
+ " eval_output_variable = eval_fn([x, y, response])\n",
324
+ " losses.append(eval_output_variable)\n",
325
+ " total_loss = tg.sum(losses)\n",
326
+ " total_loss.backward()\n",
327
+ " optimizer.step()\n",
328
+ " \n",
329
+ " run_validation_revert(system_prompt, results, model, eval_fn, val_set)\n",
330
+ " \n",
331
+ " print(\"sys prompt: \", system_prompt)\n",
332
+ " test_acc = eval_dataset(test_set, eval_fn, model)\n",
333
+ " results[\"test_acc\"].append(test_acc)\n",
334
+ " results[\"prompt\"].append(system_prompt.get_value())\n",
335
+ " if steps == 3:\n",
336
+ " break"
337
+ ]
338
+ },
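+ {
+ "cell_type": "markdown",
+ "id": "1e8a7b6c-persist-prompt-note",
+ "metadata": {},
+ "source": [
+ "After training, `results[\"prompt\"]` holds the system prompt recorded after each epoch, alongside the matching `results[\"validation_acc\"]` and `results[\"test_acc\"]` entries. A small follow-up sketch (the output filename is arbitrary) persists the run so the optimized prompt can be reused later:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e8a7b6c-persist-prompt",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# run_validation_revert reverts system_prompt when validation accuracy drops,\n",
+ "# so the final entry in results[\"prompt\"] is the best validated prompt.\n",
+ "with open(\"textgrad_prompt_results.json\", \"w\") as f:\n",
+ "    json.dump({k: [str(v) for v in vs] for k, vs in results.items()}, f, indent=2)"
+ ]
+ },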
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "id": "3c5e93f5-8d1c-4b87-a6d1-811714982d47",
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": []
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "id": "67a4583f-162c-4e2d-b061-798f6c676a28",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "class TranslationQualityAssessor(dspy.Module):\n",
355
+ " def __init__(self):\n",
356
+ " super().__init__()\n",
357
+ " self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n",
358
+ "\n",
359
+ " def forward(self, src_lang, tgt_lang, src_text, translation, src_prev=\"\", tgt_prev=\"\", src_next=\"\", tgt_next=\"\"):\n",
360
+ " context = f\"\"\"Previous Context:\n",
361
+ " Source: {src_prev}\n",
362
+ " Translation: {tgt_prev}\n",
363
+ " \n",
364
+ " Next Context:\n",
365
+ " Source: {src_next}\n",
366
+ " Translation: {tgt_next}\"\"\"\n",
367
+ "\n",
368
+ " result = self.assess(\n",
369
+ " context=context,\n",
370
+ " source=f\"Source ({src_lang}): {src_text}\",\n",
371
+ " translation=f\"Translation ({tgt_lang}): {translation}\"\n",
372
+ " )\n",
373
+ " \n",
374
+ " return result.evaluation\n",
375
+ "\n",
376
+ "class TranslationMetrics:\n",
377
+ " @staticmethod\n",
378
+ " def exact_match_score(pred, gold):\n",
379
+ " try:\n",
380
+ " pred_json = json.loads(pred)\n",
381
+ " gold_json = gold\n",
382
+ " \n",
383
+ " accuracy_match = (str(pred_json.get('Accuracy Score')) == str(gold_json.get('Accuracy Score')))\n",
384
+ " readability_match = (str(pred_json.get('Readability Score')) == str(gold_json.get('Readability Score')))\n",
385
+ " \n",
386
+ " return (accuracy_match and readability_match)\n",
387
+ " except:\n",
388
+ " return False\n",
389
+ " \n",
390
+ " @staticmethod\n",
391
+ " def partial_match_score(pred, gold):\n",
392
+ " try:\n",
393
+ " pred_json = json.loads(pred)\n",
394
+ " gold_json = gold\n",
395
+ " \n",
396
+ " # Score comparison\n",
397
+ " accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))\n",
398
+ " readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))\n",
399
+ " \n",
400
+ " # Issues comparison\n",
401
+ " pred_accuracy_issues = set(str(issue) for issue in pred_json.get('Accuracy Issues', []))\n",
402
+ " gold_accuracy_issues = set(str(issue) for issue in gold_json.get('Accuracy Issues', []))\n",
403
+ " pred_readability_issues = set(str(issue) for issue in pred_json.get('Readability Issues', []))\n",
404
+ " gold_readability_issues = set(str(issue) for issue in gold_json.get('Readability Issues', []))\n",
405
+ " \n",
406
+ " # Calculate Jaccard similarity for issues\n",
407
+ " accuracy_issues_sim = len(pred_accuracy_issues & gold_accuracy_issues) / max(1, len(pred_accuracy_issues | gold_accuracy_issues))\n",
408
+ " readability_issues_sim = len(pred_readability_issues & gold_readability_issues) / max(1, len(pred_readability_issues | gold_readability_issues))\n",
409
+ " \n",
410
+ " # Combine scores (0.6 weight to scores, 0.4 to issues similarity)\n",
411
+ " score_component = 1 - ((accuracy_diff + readability_diff) / 8)\n",
412
+ " issues_component = (accuracy_issues_sim + readability_issues_sim) / 2\n",
413
+ " \n",
414
+ " final_score = 0.6 * score_component + 0.4 * issues_component\n",
415
+ " return max(0, final_score)\n",
416
+ " except:\n",
417
+ " return 0\n",
418
+ "\n",
419
+ "def prepare_dataset(file_path):\n",
420
+ " with open(file_path, 'r') as f:\n",
421
+ " data = json.load(f)\n",
422
+ " \n",
423
+ " prepared_data = []\n",
424
+ " \n",
425
+ " for item in data:\n",
426
+ " example = dspy.Example(\n",
427
+ " context=f\"\"\"Previous Context:\n",
428
+ " Source: {item['src_prev']}\n",
429
+ " Translation: {item['tgt_prev']}\n",
430
+ " \n",
431
+ " Next Context:\n",
432
+ " Source: {item['src_next']}\n",
433
+ " Translation: {item['tgt_next']}\"\"\",\n",
434
+ " source=f\"Source ({item['src_lang']}): {item['src_text']}\",\n",
435
+ " translation=f\"Translation ({item['tgt_lang']}): {item['main_text']}\",\n",
436
+ " evaluation=json.dumps(item['evaluation'], ensure_ascii=False)\n",
437
+ " ).with_inputs(\"context\", \"source\", \"translation\")\n",
438
+ " \n",
439
+ " prepared_data.append(example)\n",
440
+ " \n",
441
+ " # Split data: 70% train, 15% dev, 15% test\n",
442
+ " train_size = int(0.7 * len(prepared_data))\n",
443
+ " dev_size = int(0.15 * len(prepared_data))\n",
444
+ " \n",
445
+ " train_data = prepared_data[:train_size]\n",
446
+ " dev_data = prepared_data[train_size:train_size + dev_size]\n",
447
+ " test_data = prepared_data[train_size + dev_size:]\n",
448
+ " \n",
449
+ " return train_data, dev_data, test_data\n",
450
+ "\n",
451
+ "def optimize_translation_quality_assessment():\n",
452
+ " # Initialize DSPy\n",
453
+ " lm = TranslationQualityLM()\n",
454
+ " dspy.settings.configure(lm=lm)\n",
455
+ " \n",
456
+ " # Load and prepare dataset\n",
457
+ " train_data, dev_data, test_data = prepare_dataset('translation_quality_dataset.json')\n",
458
+ " \n",
459
+ " # Create evaluator\n",
460
+ " evaluator = Evaluate(\n",
461
+ " metrics={\n",
462
+ " 'exact_match': TranslationMetrics.exact_match_score,\n",
463
+ " 'partial_match': TranslationMetrics.partial_match_score\n",
464
+ " }\n",
465
+ " )\n",
466
+ " \n",
467
+ " # Initialize module\n",
468
+ " assessor = TranslationQualityAssessor()\n",
469
+ " \n",
470
+ " # Initialize MIPROv2 optimizer\n",
471
+ " optimizer = dspy.MIPROv2(\n",
472
+ " metric=lambda x: x['partial_match'],\n",
473
+ " max_rounds=5, # Number of optimization rounds\n",
474
+ " max_traces=10, # Number of traces per round\n",
475
+ " max_depth=3, # Maximum depth of reasoning chains\n",
476
+ " num_candidate_prompts=5, # Number of candidate prompts to generate\n",
477
+ " num_rounds_per_prompt=3, # Number of rounds per candidate prompt\n",
478
+ " temperature=0.7,\n",
479
+ " verbose=True\n",
480
+ " )\n",
481
+ " \n",
482
+ " # Compile the module with optimization\n",
483
+ " compiled_assessor = optimizer.compile(\n",
484
+ " assessor,\n",
485
+ " trainset=train_data,\n",
486
+ " devset=dev_data,\n",
487
+ " eval_kwargs={\n",
488
+ " 'metric': 'partial_match',\n",
489
+ " 'num_threads': 4,\n",
490
+ " 'batch_size': 8\n",
491
+ " }\n",
492
+ " )\n",
493
+ " \n",
494
+ " # Evaluate on test set\n",
495
+ " results = []\n",
496
+ " for example in test_data:\n",
497
+ " pred = compiled_assessor(\n",
498
+ " context=example.context,\n",
499
+ " source=example.source,\n",
500
+ " translation=example.translation\n",
501
+ " )\n",
502
+ " \n",
503
+ " result = evaluator.evaluate(\n",
504
+ " predictions=[pred],\n",
505
+ " ground_truth=[example.evaluation]\n",
506
+ " )\n",
507
+ " results.append(result)\n",
508
+ " \n",
509
+ " # Calculate and print final metrics\n",
510
+ " avg_exact_match = np.mean([r['exact_match'] for r in results])\n",
511
+ " avg_partial_match = np.mean([r['partial_match'] for r in results])\n",
512
+ " \n",
513
+ " print(f\"Average Exact Match Score: {avg_exact_match:.3f}\")\n",
514
+ " print(f\"Average Partial Match Score: {avg_partial_match:.3f}\")\n",
515
+ " \n",
516
+ " return compiled_assessor\n",
517
+ "\n",
518
+ "if __name__ == \"__main__\":\n",
519
+ " optimized_assessor = optimize_translation_quality_assessment()"
520
+ ]
521
+ }
522
+ ],
523
+ "metadata": {
524
+ "kernelspec": {
525
+ "display_name": "timedlibs",
526
+ "language": "python",
527
+ "name": "timedlibs"
528
+ },
529
+ "language_info": {
530
+ "codemirror_mode": {
531
+ "name": "ipython",
532
+ "version": 3
533
+ },
534
+ "file_extension": ".py",
535
+ "mimetype": "text/x-python",
536
+ "name": "python",
537
+ "nbconvert_exporter": "python",
538
+ "pygments_lexer": "ipython3",
539
+ "version": "3.10.16"
540
+ }
541
+ },
542
+ "nbformat": 4,
543
+ "nbformat_minor": 5
544
+ }
adapter_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-4b-it",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 128,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "out_proj",
31
+ "v_proj",
32
+ "k_proj",
33
+ "fc1",
34
+ "down_proj",
35
+ "up_proj",
36
+ "fc2",
37
+ "o_proj",
38
+ "q_proj",
39
+ "gate_proj"
40
+ ],
41
+ "task_type": "CAUSAL_LM",
42
+ "trainable_token_indices": null,
43
+ "use_dora": false,
44
+ "use_rslora": false
45
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:913b7aae1196ba3282bd45e04a65bf705d67f933d758540a06902f63054ad6e7
3
+ size 2839124552
added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
data_prep.py ADDED
@@ -0,0 +1,380 @@
1
+ import typing as T
2
+ import os
3
+ import sys
4
+ import argparse
5
+ import json
6
+ import nflx_copilot as ncp
7
+ import pandas as pd
8
+ import re
9
+
10
+ sys.path.append("/root/workspace")
11
+
12
+ from timedtext.adapters.translation.generation.pldl import TimedTextAdapter, ConverterDialogContext
13
+ from timedtext.manager import TimedTextManager
14
+ from timedtext.handlers import OriginalLanguagePivotLanguageHandler, EnglishTemplateSubtitleHandler
15
+ from timedprompts.evaluation.pldl_prompt_one.prompt import (
16
+ ReferenceFreeFeedbackTransform,
17
+ ContextFreeFeedbackTransform,
18
+ ReferenceFreeDirectTransform,
19
+ ReferenceBasedFeedbackTransform,
20
+ ReferenceFreeExampleTransform,
21
+ )
22
+ from tqdm import tqdm
23
+ from timedtune.convert.tq_for_pldl.pldl_train_one import PldlTrainOneReferenceFreeTransform
24
+ from timedtext.adapters.translation.evaluation import compute_score_delta
25
+
26
+ def compute_32_point_score(response, generation):
27
+ parsed, score = {}, -1
28
+ try:
29
+ score = (
30
+ int(response["Accuracy Score"])
31
+ + int(response["Readability Score"])
32
+ + compute_score_delta(response, "Accuracy Issues", generation)
33
+ + compute_score_delta(response, "Readability Issues", generation)
34
+ )
35
+ score = score * 4
36
+ except Exception:
37
+ score = -1
38
+ return parsed, score
39
+
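+ # Usage sketch (as called from main() below): pass one labeler's annotation dict
+ # and the generated target text; score is -1 whenever the response cannot be scored.
+ # _, score = compute_32_point_score(labeler["annotation"], annotation_result["main_tgt_text"])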
40
+ # Your existing TimedTextAdapter and helper classes
41
+ class TimedTextAdapterFromCache_PLDL(TimedTextAdapter):
42
+ def __init__(
43
+ self,
44
+ data_dir: str,
45
+ cache_size: int = 0,
46
+ ol_dialog_list_version: str = "",
47
+ pl_dialog_list_version: str = "",
48
+ ol_dialog_list_pl_dialog_list_version: str = "",
49
+ num_prev_events: int = 16,
50
+ num_next_events: int = 16,
51
+ ) -> None:
52
+ super().__init__(num_prev_events, num_next_events)
53
+ self.timed_text_manager = TimedTextManager(
54
+ data_dir,
55
+ cache_size=cache_size,
56
+ ol_dialog_list_version=ol_dialog_list_version,
57
+ pl_dialog_list_version=pl_dialog_list_version,
58
+ ol_dialog_list_pl_dialog_list_version=ol_dialog_list_pl_dialog_list_version,
59
+ )
60
+
61
+ def _get_timed_text(
62
+ self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
63
+ ) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
64
+ results = self.timed_text_manager.match_and_get_timed_text(
65
+ handler_class=OriginalLanguagePivotLanguageHandler,
66
+ movie_id=movie_id,
67
+ start_frame=start_frame,
68
+ end_frame=end_frame,
69
+ src_lang=src_lang,
70
+ tgt_lang=tgt_lang,
71
+ mid_lang="",
72
+ **self.timed_text_kwargs,
73
+ )
74
+
75
+ curr_srcs = [result["curr"]["src"]["txt"] for result in results]
76
+ curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]
77
+
78
+ return {
79
+ "curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
80
+ "prev": results[0]["prev"],
81
+ "next": results[-1]["next"],
82
+ }
83
+
84
+ class TimedTextAdapterFromCache_SUBS(TimedTextAdapter):
85
+ def __init__(
86
+ self,
87
+ data_dir: str,
88
+ cache_size: int = 0,
89
+ num_prev_events: int = 16,
90
+ num_next_events: int = 16,
91
+ ) -> None:
92
+ super().__init__(num_prev_events, num_next_events)
93
+ self.timed_text_manager = TimedTextManager(
94
+ data_dir,
95
+ cache_size=cache_size,
96
+ )
97
+
98
+ def _get_timed_text(
99
+ self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
100
+ ) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
101
+ results = self.timed_text_manager.match_and_get_timed_text(
102
+ handler_class=EnglishTemplateSubtitleHandler,
103
+ movie_id=movie_id,
104
+ start_frame=start_frame,
105
+ end_frame=end_frame,
106
+ src_lang=src_lang,
107
+ tgt_lang=tgt_lang,
108
+ mid_lang="",
109
+ **self.timed_text_kwargs,
110
+ )
111
+
112
+ curr_srcs = [result["curr"]["src"]["txt"] for result in results]
113
+ curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]
114
+
115
+ return {
116
+ "curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
117
+ "prev": results[0]["prev"],
118
+ "next": results[-1]["next"],
119
+ }
120
+
121
+
122
+ # Function to fetch contextual information using TimedTextAdapter
123
+ def fetch_contextual_information(timed_text_adapter, row):
124
+ """
125
+ Fetches the required context information for each sample using timed_text_adapter.
126
+
127
+ Args:
128
+ timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch data from.
129
+ row (dict): Row containing the necessary information to fetch the context.
130
+
131
+ Returns:
132
+ dict: Contextual information with keys tt_src_text, tt_tgt_text, tt_src_prev, tt_src_next, tt_tgt_prev, tt_tgt_next.
133
+ """
134
+ # Fetching the actual translation context
135
+ src_text, tgt_text, prev_context, next_context = timed_text_adapter.get_timed_text(
136
+ movie_id=row["movie_id"],
137
+ start_frame=row["start_frame"],
138
+ end_frame=row["end_frame"],
139
+ src_lang=row["src_lang"],
140
+ tgt_lang=row["tgt_lang"],
141
+ )
142
+
143
+ timed_text_converter = ConverterDialogContext(timed_text_adapter)
144
+
145
+ # Converting context to the format expected by the prompt
146
+ src_prev, src_next, tgt_prev, tgt_next, _ = timed_text_converter.__context__(
147
+ row["src_lang"], row["tgt_lang"], prev_context, next_context, None
148
+ )
149
+
150
+ return {
151
+ "tt_src_text": src_text,
152
+ "tt_tgt_text": tgt_text,
153
+ "tt_src_prev": src_prev,
154
+ "tt_src_next": src_next,
155
+ "tt_tgt_prev": tgt_prev,
156
+ "tt_tgt_next": tgt_next,
157
+ }
158
+
159
+ def transform_json(input_json):
160
+ # Get the first project key
161
+ project_key = list(input_json['projects'].keys())[0]
162
+ project = input_json['projects'][project_key]
163
+
164
+ final_output = {"labelers": []}
165
+ # Process each label
166
+ for index, label in enumerate(project['labels']):
167
+ # Initialize output structure
168
+ output = {
169
+ "annotation": {
170
+ "Accuracy Issues": [],
171
+ "Readability Issues": [],
172
+ "Accuracy Score": "",
173
+ "Readability Score": "",
174
+ "Confidence Level": "",
175
+ "Main Vs Alternate": "",
176
+ "Score": "-1" # initalized -1, will be updated in next steps
177
+ },
178
+ }
179
+ # Process annotations/objects (issues)
180
+ if 'objects' in label['annotations']:
181
+ for obj in label['annotations']['objects']:
182
+ issue = {
183
+ "Error Location": obj['conversational_location']['message_id'],
184
+ "Error Span": [
185
+ obj['conversational_location']['location']['start'],
186
+ obj['conversational_location']['location']['end']
187
+ ],
188
+ "Error Explanation": "",
189
+ "Error Quality Category": obj['name'],
190
+ "Error Quality Tags": [],
191
+ "Error Severity": ""
192
+ }
193
+
194
+ # Process classifications within object
195
+ for classification in obj['classifications']:
196
+ if classification['name'] == 'Explanation':
197
+ issue["Error Explanation"] = classification['text_answer']['content']
198
+ elif classification['name'] == 'Quality Tag':
199
+ issue["Error Quality Tags"] = [ans['name'].lower() for ans in classification['checklist_answers']]
200
+ elif classification['name'] == 'Quality SubCategory':
201
+ severity = classification['radio_answer']['name']
202
+ if 'Major' in severity:
203
+ issue["Error Severity"] = "Major"
204
+ else:
205
+ issue["Error Severity"] = "Minor"
206
+
207
+ # Add to appropriate issues list
208
+ if obj['name'] == 'Style':
209
+ output['annotation']['Readability Issues'].append(issue)
210
+ else:
211
+ output['annotation']['Accuracy Issues'].append(issue)
212
+
213
+ # Process classifications
214
+ for classification in label['annotations']['classifications']:
215
+ if classification['name'] == 'Accuracy Score':
216
+ output['annotation']['Accuracy Score'] = classification['radio_answer']['name'].split(' - ')[0]
217
+ elif classification['name'] == 'Readability Score':
218
+ output['annotation']['Readability Score'] = classification['radio_answer']['name'].split(' - ')[0]
219
+ elif classification['name'] == 'Confidence Level':
220
+ output['annotation']['Confidence Level'] = classification['radio_answer']['value']
221
+ elif classification['name'] == 'Main vs Alternate':
222
+ output['annotation']['Main Vs Alternate'] = classification['radio_answer']['name']
223
+ final_output["labelers"].append(output)
224
+ return final_output
225
+
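+ # Shape sketch of the value returned above (field values illustrative only):
+ # {"labelers": [{"annotation": {"Accuracy Issues": [...], "Readability Issues": [...],
+ #   "Accuracy Score": "4", "Readability Score": "3",
+ #   "Confidence Level": "high", "Main Vs Alternate": "Main", "Score": "-1"}}]}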
226
+ # Function to load the relevant meta json for a given key
227
+ def load_meta_json(priority_key, data_row_key, meta_path):
228
+ """
229
+ Loads and validates metadata json from the specified path based on the priority key and data row key.
230
+
231
+ Args:
232
+ priority_key (str): Priority key from the label metadata.
233
+ data_row_key (str): Data row key to find the relevant file.
234
+ meta_path (str): Path to the metadata folder.
235
+
236
+ Returns:
237
+ dict: Loaded metadata.
238
+ """
239
+ with open(os.path.join(meta_path, f'{priority_key}.json')) as fread:
240
+ meta_dict = json.load(fread)
241
+
242
+ _, movie_id, start_end_frame, _, _, _, _ = data_row_key.split('.')
243
+ start_frame, end_frame = start_end_frame.split('_')
244
+
245
+ if int(meta_dict['movie_id']) != int(movie_id):
246
+ print("Movie Ids didn't match:", int(meta_dict['movie_id']), int(movie_id), os.path.join(meta_path, f'{priority_key}.json'), data_row_key)
247
+ sys.exit(1)
248
+ assert int(meta_dict['start_frame']) == int(start_frame)
249
+ assert int(meta_dict['end_frame']) == int(end_frame)
250
+
251
+ return meta_dict
252
+
253
+ # Main function that processes the data
254
+ def process_json(timed_text_adapter, example_row, meta_path, conv_path):
255
+ """
256
+ Takes the full input json, converts it to the required format, and adds context using metadata.
257
+
258
+ Args:
259
+ timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch context.
260
+ example_row (dict): One full labeled record from the raw export JSON.
261
+ meta_path (str): Path to the metadata folder to fetch meta json.
+ conv_path (str): Path to the conversation folder holding the main/src/alt texts.
262
+
263
+ Returns:
264
+ dict: The enriched annotation format with context and annotation data.
265
+ """
266
+ # Step 1: Convert the full input JSON to the required annotation format
267
+ annotation_result = transform_json(example_row)
268
+
269
+ # Extracting the necessary data_row_key and priority_key
270
+ data_row_key = example_row['data_row']['global_key']
271
+ priority_key = example_row['projects'][list(example_row["projects"].keys())[0]]['project_details']['priority']
272
+
273
+ annotation_result["Data_Row_Key"] = data_row_key
274
+ key = ".".join(data_row_key.split(".")[:3])
275
+ with open(conv_path + "/" + key + ".json") as file:
276
+ data = json.load(file)
277
+ annotation_result["main_tgt_text"] = data["messages"][0]["content"]
278
+ annotation_result["src_text"] = data["messages"][1]["content"]
279
+ annotation_result["alt_tgt_text"] = data["messages"][2]["content"]
280
+
281
+ # Load the metadata using the keys from the json
282
+ meta_dict = load_meta_json(priority_key, data_row_key, meta_path)
283
+
284
+ # Step 2: Add the metadata fields (e.g., title_id, start_frame, end_frame, src_lang, tgt_lang)
285
+ annotation_result.update({
286
+ "title_id": meta_dict['movie_id'],
287
+ "start_frame": meta_dict['start_frame'],
288
+ "end_frame": meta_dict['end_frame'],
289
+ "src_lang": meta_dict['src_lang'],
290
+ "tgt_lang": meta_dict['tgt_lang'],
291
+ })
292
+
293
+ # Step 3: Fetch contextual information using the given timed_text_adapter
294
+ context_info = fetch_contextual_information(timed_text_adapter, meta_dict)
295
+
296
+ annotation_result.update(context_info)
297
+
298
+ # Update error spans with actual text for each labeler
299
+ for labeler in annotation_result["labelers"]:
300
+ # Process Accuracy Issues
301
+ for issue in labeler["annotation"]["Accuracy Issues"]:
302
+ error_location = issue["Error Location"]
303
+ start, end = issue["Error Span"][0], issue["Error Span"][1]
304
+
305
+ # Get the actual text based on error location
306
+ if error_location == "src":
307
+ actual_text = annotation_result["src_text"][start:end]
308
+ else: # tgt
309
+ actual_text = annotation_result["main_tgt_text"][start:end]
310
+
311
+ # Update the error span with actual text
312
+ issue["Error Span"] = actual_text
313
+
314
+ # Process Readability Issues
315
+ for issue in labeler["annotation"]["Readability Issues"]:
316
+ error_location = issue["Error Location"]
317
+ start, end = issue["Error Span"]
318
+
319
+ # Get the actual text based on error location
320
+ if error_location == "src":
321
+ actual_text = annotation_result["src_text"][start:end]
322
+ else: # tgt
323
+ actual_text = annotation_result["main_tgt_text"][start:end]
324
+
325
+ # Update the error span with actual text
326
+ issue["Error Span"] = actual_text
327
+
328
+ return annotation_result
329
+
330
+ # Example usage
331
+ def main():
332
+ base_path = "MT_TQ/Caches/May2025/tquality.annotated.data/"
333
+ json_files = [base_path + "raw/" + f for f in os.listdir(base_path + "raw/") if f.endswith('.json')]
334
+
335
+ for json_file in tqdm(json_files):
336
+ if "calibration" in json_file:
337
+ print("Warning: Skipping Calibration Data, Remove this if you want to use Calibration data")
338
+ continue
339
+
340
+ if "PLDL" in json_file:
341
+ folder = "pldl"
342
+ timed_text_adapter = TimedTextAdapterFromCache_PLDL(
343
+ data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
344
+ )
345
+ elif "SUBS" in json_file:
346
+ folder = "subs"
347
+ timed_text_adapter = TimedTextAdapterFromCache_SUBS(
348
+ data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
349
+ )
350
+ else:
351
+ folder = ""
352
+ assert "invalid json file"
353
+
354
+ langs_type = json_file.split("/")[-1].split("-")[1].replace("_",".")
355
+ phase = json_file.split("/")[-1].split("-")[3]
356
+ phase_number = int(''.join(re.findall(r'\d+', phase))) if re.findall(r'\d+', phase) else None
357
+ phase_date = json_file.split("/")[-1].split("-")[4].replace(".json", "")
358
+
359
+ zzmetapath = f"/root/notebooks/MT_TQ/Caches/labelspace/tquality.zzmeta.data/{folder}/{langs_type}/phase {phase_number} - {phase_date}"
360
+
361
+ meta_path = zzmetapath + "/meta"
362
+ conv_path = zzmetapath + "/conv"
363
+
364
+ with open(json_file) as file:
365
+ data = json.load(file)
366
+
367
+ output_data = []
368
+ for data_point in tqdm(data):
369
+ annotation_result = process_json(timed_text_adapter, data_point, meta_path, conv_path)
370
+ for labeler in annotation_result["labelers"]:
371
+ _, score = compute_32_point_score(labeler["annotation"], annotation_result["main_tgt_text"])
372
+ labeler["annotation"]["Score"] = score
373
+
374
+ output_data.append(annotation_result)
375
+
376
+ with open(base_path + "parsed/" + json_file.split("/")[-1], 'w') as fout:
377
+ json.dump({"data": output_data}, fout, indent=4)
378
+
379
+ if __name__ == "__main__":
380
+ main()
gemma-12b-tq-model/README.md ADDED
@@ -0,0 +1,58 @@
1
+ ---
2
+ base_model: google/gemma-3-4b-it
3
+ library_name: transformers
4
+ model_name: gemma-12b-tq-model
5
+ tags:
6
+ - generated_from_trainer
7
+ - trl
8
+ - sft
9
+ licence: license
10
+ ---
11
+
12
+ # Model Card for gemma-12b-tq-model
13
+
14
+ This model is a fine-tuned version of [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it).
15
+ It has been trained using [TRL](https://github.com/huggingface/trl).
16
+
17
+ ## Quick start
18
+
19
+ ```python
20
+ from transformers import pipeline
21
+
22
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
23
+ generator = pipeline("text-generation", model="bhavinjawade/gemma-12b-tq-model", device="cuda")
24
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
25
+ print(output["generated_text"])
26
+ ```
27
+
28
+ ## Training procedure
29
+
30
+
31
+
32
+
33
+ This model was trained with SFT.
34
+
35
+ ### Framework versions
36
+
37
+ - TRL: 0.16.1
38
+ - Transformers: 4.50.0.dev0
39
+ - Pytorch: 2.7.0
40
+ - Datasets: 3.3.2
41
+ - Tokenizers: 0.21.0
42
+
43
+ ## Citations
44
+
45
+
46
+
47
+ Cite TRL as:
48
+
49
+ ```bibtex
50
+ @misc{vonwerra2022trl,
51
+ title = {{TRL: Transformer Reinforcement Learning}},
52
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
53
+ year = 2020,
54
+ journal = {GitHub repository},
55
+ publisher = {GitHub},
56
+ howpublished = {\url{https://github.com/huggingface/trl}}
57
+ }
58
+ ```
gemma-12b-tq-model/adapter_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-4b-it",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 128,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "up_proj",
31
+ "gate_proj",
32
+ "fc2",
33
+ "out_proj",
34
+ "fc1",
35
+ "down_proj",
36
+ "o_proj",
37
+ "k_proj",
38
+ "q_proj",
39
+ "v_proj"
40
+ ],
41
+ "task_type": "CAUSAL_LM",
42
+ "trainable_token_indices": null,
43
+ "use_dora": false,
44
+ "use_rslora": false
45
+ }
gemma-12b-tq-model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44481365420dd4297edab0bc7d76dc43d4e6d7f38e393cce87c2fabdbea96661
3
+ size 2839124552
gemma-12b-tq-model/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
gemma-12b-tq-model/checkpoint-2/README.md ADDED
@@ -0,0 +1,202 @@
1
+ ---
2
+ base_model: google/gemma-3-4b-it
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.15.2
gemma-12b-tq-model/checkpoint-2/adapter_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-4b-it",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 128,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "up_proj",
31
+ "gate_proj",
32
+ "fc2",
33
+ "out_proj",
34
+ "fc1",
35
+ "down_proj",
36
+ "o_proj",
37
+ "k_proj",
38
+ "q_proj",
39
+ "v_proj"
40
+ ],
41
+ "task_type": "CAUSAL_LM",
42
+ "trainable_token_indices": null,
43
+ "use_dora": false,
44
+ "use_rslora": false
45
+ }
gemma-12b-tq-model/checkpoint-2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44481365420dd4297edab0bc7d76dc43d4e6d7f38e393cce87c2fabdbea96661
3
+ size 2839124552
gemma-12b-tq-model/checkpoint-2/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
gemma-12b-tq-model/checkpoint-2/optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:898abdc006b804547b999529ece7d0ca106ab09c0a2352337e1809b5041573ee
3
+ size 5608850589
gemma-12b-tq-model/checkpoint-2/rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:250560ab3d528161ab3659b120def6e4a9ab4b457e3399603bbcfa40db3efc90
3
+ size 14645
gemma-12b-tq-model/checkpoint-2/scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29847f084360f67920e16e2780978c5b4908b1a69433f50a755d4db1e0c11563
3
+ size 1401
gemma-12b-tq-model/checkpoint-2/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
gemma-12b-tq-model/checkpoint-2/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
gemma-12b-tq-model/checkpoint-2/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
gemma-12b-tq-model/checkpoint-2/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
gemma-12b-tq-model/checkpoint-2/trainer_state.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 0.7272727272727273,
5
+ "eval_steps": 500,
6
+ "global_step": 2,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.36363636363636365,
13
+ "grad_norm": 269.2862243652344,
14
+ "learning_rate": 0.0002,
15
+ "loss": 12.5019,
16
+ "mean_token_accuracy": 0.4838709682226181,
17
+ "num_tokens": 4096.0,
18
+ "step": 1
19
+ },
20
+ {
21
+ "epoch": 0.7272727272727273,
22
+ "grad_norm": 232.57679748535156,
23
+ "learning_rate": 0.0002,
24
+ "loss": 9.4112,
25
+ "mean_token_accuracy": 0.5447482168674469,
26
+ "num_tokens": 7561.0,
27
+ "step": 2
28
+ }
29
+ ],
30
+ "logging_steps": 1,
31
+ "max_steps": 2,
32
+ "num_input_tokens_seen": 0,
33
+ "num_train_epochs": 1,
34
+ "save_steps": 500,
35
+ "stateful_callbacks": {
36
+ "TrainerControl": {
37
+ "args": {
38
+ "should_epoch_stop": false,
39
+ "should_evaluate": false,
40
+ "should_log": false,
41
+ "should_save": true,
42
+ "should_training_stop": true
43
+ },
44
+ "attributes": {}
45
+ }
46
+ },
47
+ "total_flos": 196609832513952.0,
48
+ "train_batch_size": 1,
49
+ "trial_name": null,
50
+ "trial_params": null
51
+ }
gemma-12b-tq-model/checkpoint-2/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a6ae5bb0171c4fcfd72d4d46496b122da48e4ebf65ff31d5812d7d8dba26a8e
3
+ size 6161
gemma-12b-tq-model/runs/Apr25_08-39-59_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570448.9945b53f-579e-4565-94fc-5fbe73c83cc2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1316da37ce84d8a16f9542d45cf4ab5617be2069029895d1e700955c546f6c26
3
+ size 6460
gemma-12b-tq-model/runs/Apr25_08-42-29_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745570563.9945b53f-579e-4565-94fc-5fbe73c83cc2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a216b92a09ed293f852daa5eac2575a729347e101cd1e58fb41bce94b4a0c3d
3
+ size 6460
gemma-12b-tq-model/runs/Apr25_09-19-39_9945b53f-579e-4565-94fc-5fbe73c83cc2/events.out.tfevents.1745572788.9945b53f-579e-4565-94fc-5fbe73c83cc2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac3b5c4eb7654e90a80565f05650922f5b8fe9c0c0e8f02134b18287d5ef32db
3
+ size 7455
gemma-12b-tq-model/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
gemma-12b-tq-model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
gemma-12b-tq-model/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
gemma-12b-tq-model/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
gemma-12b-tq-model/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a6ae5bb0171c4fcfd72d4d46496b122da48e4ebf65ff31d5812d7d8dba26a8e
3
+ size 6161
gemma-1b-tq-model/README.md ADDED
@@ -0,0 +1,58 @@
1
+ ---
2
+ base_model: google/gemma-3-1b-pt
3
+ library_name: transformers
4
+ model_name: gemma-1b-tq-model
5
+ tags:
6
+ - generated_from_trainer
7
+ - trl
8
+ - sft
9
+ licence: license
10
+ ---
11
+
12
+ # Model Card for gemma-1b-tq-model
13
+
14
+ This model is a fine-tuned version of [google/gemma-3-1b-pt](https://huggingface.co/google/gemma-3-1b-pt).
15
+ It has been trained using [TRL](https://github.com/huggingface/trl).
16
+
17
+ ## Quick start
18
+
19
+ ```python
20
+ from transformers import pipeline
21
+
22
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
23
+ generator = pipeline("text-generation", model="bhavinjawade/gemma-1b-tq-model", device="cuda")
24
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
25
+ print(output["generated_text"])
26
+ ```
27
+
28
+ ## Training procedure
29
+
30
+
31
+
32
+
33
+ This model was trained with SFT.
34
+
35
+ ### Framework versions
36
+
37
+ - TRL: 0.16.1
38
+ - Transformers: 4.50.0.dev0
39
+ - Pytorch: 2.7.0
40
+ - Datasets: 3.3.2
41
+ - Tokenizers: 0.21.0
42
+
43
+ ## Citations
44
+
45
+
46
+
47
+ Cite TRL as:
48
+
49
+ ```bibtex
50
+ @misc{vonwerra2022trl,
51
+ title = {{TRL: Transformer Reinforcement Learning}},
52
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
53
+ year = 2020,
54
+ journal = {GitHub repository},
55
+ publisher = {GitHub},
56
+ howpublished = {\url{https://github.com/huggingface/trl}}
57
+ }
58
+ ```
gemma-1b-tq-model/adapter_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-1b-pt",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 16,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "o_proj",
31
+ "k_proj",
32
+ "v_proj",
33
+ "q_proj",
34
+ "up_proj",
35
+ "down_proj",
36
+ "gate_proj"
37
+ ],
38
+ "task_type": "CAUSAL_LM",
39
+ "trainable_token_indices": null,
40
+ "use_dora": false,
41
+ "use_rslora": false
42
+ }
gemma-1b-tq-model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5ccfa8bc91f4f5e80a03ca2a73036523a47df7550df49b4cd8296c486ed37de
3
+ size 1260191096
gemma-1b-tq-model/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
gemma-1b-tq-model/checkpoint-10/README.md ADDED
@@ -0,0 +1,202 @@
1
+ ---
2
+ base_model: google/gemma-3-1b-pt
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.15.2
gemma-1b-tq-model/checkpoint-10/adapter_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-1b-pt",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 16,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "o_proj",
31
+ "k_proj",
32
+ "v_proj",
33
+ "q_proj",
34
+ "up_proj",
35
+ "down_proj",
36
+ "gate_proj"
37
+ ],
38
+ "task_type": "CAUSAL_LM",
39
+ "trainable_token_indices": null,
40
+ "use_dora": false,
41
+ "use_rslora": false
42
+ }
gemma-1b-tq-model/checkpoint-10/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b5ac34a75c3a7a06c3c9fda32b95d9870267691dacb15af9c0c2c08dc4e7934
3
+ size 1260191096
gemma-1b-tq-model/checkpoint-10/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
gemma-1b-tq-model/checkpoint-10/optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f990dd49302e243f6510cfc5976f8dc28049c45c9af22f8424f89bc8b3d89b2
3
+ size 2520598381
gemma-1b-tq-model/checkpoint-10/rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67a697233a108d806598e97819e22cb699651bb7e046c04cc47db386d7540306
3
+ size 14645
gemma-1b-tq-model/checkpoint-10/scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591c3072a024fe1a8043b72e8e5366699aec4a9d0c3da5bde546eb445034a199
3
+ size 1401
gemma-1b-tq-model/checkpoint-10/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }