Aditi commited on
Commit
ac1fe86
·
1 Parent(s): 3847e49

fine-tune & evaluation

Browse files
Files changed (3) hide show
  1. fine_tune.py +0 -0
  2. fine_tune_and_evaluation.py +129 -0
  3. requirements.txt +25 -57
fine_tune.py DELETED
File without changes
fine_tune_and_evaluation.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer
2
+ import pandas as pd
3
+ from datasets import Dataset, Features, Value
4
+ import evaluate
5
+ import nltk
6
+ import json
7
+ import os
8
+ import random
9
+
10
+ nltk.download('punkt')
11
+
12
+ # === CONFIGURATION ===
13
+ train_file = r"C:/Users/aditi/OneDrive/Desktop/train_v0.2 QuaC.json"
14
+ model_name = "voidful/bart-eqg-question-generator"
15
+ output_dir = "./bart-eqg-finetuned-500"
16
+
17
+ # === FILE CHECK ===
18
+ if not os.path.exists(train_file):
19
+ raise FileNotFoundError(f"File not found at: {train_file}")
20
+
21
+ # === LOAD DATA ===
22
+ with open(train_file, 'r', encoding='utf-8') as f:
23
+ quac_data = json.load(f)
24
+
25
+ # === EXTRACT 500 Q&A PAIRS ===
26
+ data = []
27
+ for item in quac_data.get("data", []):
28
+ for paragraph in item.get("paragraphs", []):
29
+ context = paragraph.get("context", "")
30
+ for qa in paragraph.get("qas", []):
31
+ question = qa.get("question", "")
32
+ answer = qa.get("answers", [{}])[0].get("text", "") if qa.get("answers") else ""
33
+ if context and question and answer:
34
+ data.append({"context": context, "question": question, "answer": answer})
35
+
36
+ random.seed(42)
37
+ random.shuffle(data)
38
+ data = data[:500]
39
+
40
+ # === CREATE DATASET ===
41
+ df = pd.DataFrame(data)[["context", "question", "answer"]]
42
+ features = Features({
43
+ "context": Value("string"),
44
+ "question": Value("string"),
45
+ "answer": Value("string")
46
+ })
47
+ dataset = Dataset.from_pandas(df, features=features)
48
+ train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
49
+ train_dataset = train_test_split["train"]
50
+ eval_dataset = train_test_split["test"]
51
+
52
+ print(f"Train size: {len(train_dataset)} | Eval size: {len(eval_dataset)}")
53
+
54
+ # === LOAD MODEL AND TOKENIZER ===
55
+ try:
56
+ tokenizer = BartTokenizer.from_pretrained(model_name)
57
+ model = BartForConditionalGeneration.from_pretrained(model_name)
58
+ except Exception as e:
59
+ raise RuntimeError(f"Could not load model or tokenizer: {e}")
60
+
61
+ # === PREPROCESS FUNCTION ===
62
+ def preprocess(example):
63
+ input_text = example['context']
64
+ target_text = example['question']
65
+ model_inputs = tokenizer(input_text, max_length=512, truncation=True, padding="max_length")
66
+ labels = tokenizer(target_text, max_length=64, truncation=True, padding="max_length")["input_ids"]
67
+ model_inputs["labels"] = labels
68
+ return model_inputs
69
+
70
+ tokenized_train_dataset = train_dataset.map(preprocess, remove_columns=train_dataset.column_names, batched=True)
71
+ tokenized_eval_dataset = eval_dataset.map(preprocess, remove_columns=eval_dataset.column_names, batched=True)
72
+
73
+ # === METRIC COMPUTATION ===
74
+ def compute_metrics(eval_pred):
75
+ preds, labels = eval_pred
76
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
77
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
78
+
79
+ bleu = evaluate.load("bleu")
80
+ rouge = evaluate.load("rouge")
81
+
82
+ bleu_score = bleu.compute(predictions=decoded_preds, references=decoded_labels)
83
+ rouge_score = rouge.compute(predictions=decoded_preds, references=decoded_labels)
84
+
85
+ return {
86
+ "bleu": bleu_score["bleu"],
87
+ "rouge1": rouge_score["rouge1"],
88
+ "rougeL": rouge_score["rougeL"]
89
+ }
90
+
91
+ # === TRAINING ARGS === (no evaluation_strategy used)
92
+ training_args = TrainingArguments(
93
+ output_dir=output_dir,
94
+ per_device_train_batch_size=2,
95
+ per_device_eval_batch_size=2,
96
+ num_train_epochs=3,
97
+ save_strategy="epoch",
98
+ save_total_limit=1,
99
+ logging_dir="./logs",
100
+ logging_steps=10,
101
+ fp16=False,
102
+ report_to="none"
103
+ )
104
+
105
+ # === TRAINER ===
106
+ trainer = Trainer(
107
+ model=model,
108
+ args=training_args,
109
+ train_dataset=tokenized_train_dataset,
110
+ eval_dataset=tokenized_eval_dataset,
111
+ compute_metrics=compute_metrics
112
+ )
113
+
114
+ # === TRAIN & EVALUATE ===
115
+ print("Fine-tuning started...")
116
+ #trainer.train()
117
+ trainer.train(resume_from_checkpoint=True)
118
+
119
+ print("Running final evaluation...")
120
+ results = trainer.evaluate()
121
+ print("Final Evaluation Results:")
122
+ for metric, score in results.items():
123
+ print(f" {metric}: {score}")
124
+
125
+ # === SAVE MODEL ===
126
+ model.save_pretrained(os.path.join(output_dir, "final"))
127
+ tokenizer.save_pretrained(os.path.join(output_dir, "final"))
128
+ print("Fine-tuned model and tokenizer saved!")
129
+
requirements.txt CHANGED
@@ -1,58 +1,26 @@
1
- altair==5.5.0
2
- attrs==25.3.0
3
- blinker==1.9.0
4
- cachetools==6.1.0
5
- certifi==2025.6.15
6
- charset-normalizer==3.4.2
7
- click==8.2.1
8
- colorama==0.4.6
9
- filelock==3.18.0
10
- fsspec==2025.5.1
11
- gitdb==4.0.12
12
- GitPython==3.1.44
13
- huggingface-hub==0.33.2
14
- idna==3.10
15
- Jinja2==3.1.6
16
- joblib==1.5.1
17
- jsonschema==4.24.0
18
- jsonschema-specifications==2025.4.1
19
- MarkupSafe==3.0.2
20
- mpmath==1.3.0
21
- narwhals==1.45.0
22
- networkx==3.5
23
  nltk==3.9.1
24
- numpy==2.3.1
25
- packaging==25.0
26
- pandas==2.3.0
27
- pillow==11.3.0
28
- protobuf==6.31.1
29
- pyarrow==20.0.0
30
- pydeck==0.9.1
31
- python-dateutil==2.9.0.post0
32
- pytz==2025.2
33
- PyYAML==6.0.2
34
- referencing==0.36.2
35
- regex==2024.11.6
36
- requests==2.32.4
37
- rpds-py==0.26.0
38
- safetensors==0.5.3
39
- scikit-learn==1.7.0
40
- scipy==1.16.0
41
- sentence-transformers==3.1.1
42
- setuptools==80.9.0
43
- six==1.17.0
44
- smmap==5.0.2
45
- streamlit==1.46.1
46
- sympy==1.14.0
47
- tenacity==9.1.2
48
- threadpoolctl==3.6.0
49
- tokenizers==0.15.2
50
- toml==0.10.2
51
- torch==2.7.1
52
- tornado==6.5.1
53
- tqdm==4.67.1
54
- transformers==4.39.3
55
- typing_extensions==4.14.0
56
- tzdata==2025.2
57
- urllib3==2.5.0
58
- watchdog==6.0.0
 
1
+ # Core libraries
2
+ transformers==4.53.2
3
+ datasets==4.0.0
4
+ evaluate==0.4.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  nltk==3.9.1
6
+ pandas==2.3.1
7
+ numpy>=1.17
8
+ tqdm>=4.27
9
+ scipy
10
+
11
+ # PyTorch (CPU version)
12
+ torch==2.3.0
13
+ torchaudio==2.3.0
14
+ torchvision==0.18.0
15
+
16
+ # Hugging Face Hub
17
+ huggingface-hub>=0.16.4
18
+ safetensors>=0.4.3
19
+
20
+ # Optional but useful
21
+ pyarrow>=15.0.0
22
+ regex
23
+ filelock
24
+ fsspec
25
+
26
+ accelerate>=0.26.0