GurgenGulay committed on
Commit
8cda88c
·
verified ·
1 Parent(s): 06a6d14

Update fine_tuning.py

Browse files
Files changed (1) hide show
  1. fine_tuning.py +47 -59
fine_tuning.py CHANGED
@@ -41,62 +41,50 @@ def prepare_data(input_texts, target_texts):
41
  targets = tokenizer(target_texts, max_length=512, truncation=True, padding="max_length")
42
  return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "labels": targets["input_ids"]}
43
 
44
- def paraphrase_with_model(text, model, tokenizer):
45
- prompt = "Create a detailed, structured teaching transcript from the following text: " + text
46
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
47
- output_ids = model.generate(
48
- inputs["input_ids"],
49
- do_sample=True,
50
- top_k=40,
51
- top_p=0.9,
52
- temperature=0.8,
53
- max_length=300,
54
- no_repeat_ngram_size=3,
55
- early_stopping=True
56
- )
57
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
58
-
59
-
60
- model_name = "t5-base"
61
- tokenizer = T5Tokenizer.from_pretrained(model_name)
62
- model = T5ForConditionalGeneration.from_pretrained(model_name)
63
-
64
- try:
65
- logger.info("Reading and cleaning prompts.")
66
- input_texts, target_texts = read_prompts("prompts.txt")
67
- input_texts_cleaned = [clean_text(text) for text in input_texts]
68
- target_texts_cleaned = [clean_text(text) for text in target_texts]
69
-
70
- logger.info("Splitting dataset into training and validation sets.")
71
- train_texts, val_texts, train_labels, val_labels = train_test_split(input_texts_cleaned, target_texts_cleaned, test_size=0.1)
72
-
73
- logger.info("Preparing datasets for training.")
74
- train_dataset = Dataset.from_dict(prepare_data(train_texts, train_labels))
75
- val_dataset = Dataset.from_dict(prepare_data(val_texts, val_labels))
76
-
77
- training_args = TrainingArguments(
78
- output_dir="./results",
79
- evaluation_strategy="steps",
80
- learning_rate=5e-5,
81
- per_device_train_batch_size=4,
82
- num_train_epochs=3,
83
- save_steps=500,
84
- logging_dir="./logs",
85
- logging_steps=10
86
- )
87
-
88
- logger.info("Starting model training.")
89
- trainer = Trainer(
90
- model=model,
91
- args=training_args,
92
- train_dataset=train_dataset,
93
- eval_dataset=val_dataset
94
- )
95
- trainer.train()
96
-
97
- logger.info("Saving fine-tuned model.")
98
- model.save_pretrained("./fine_tuned_model")
99
- tokenizer.save_pretrained("./fine_tuned_model")
100
-
101
- except Exception as e:
102
- logger.error(f"An error occurred during fine-tuning: {str(e)}")
 
41
  targets = tokenizer(target_texts, max_length=512, truncation=True, padding="max_length")
42
  return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "labels": targets["input_ids"]}
43
 
44
+ # Fine-tuning
45
+ def fine_tune_model():
46
+ model_name = "t5-base"
47
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
48
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
49
+
50
+ try:
51
+ logger.info("Reading and cleaning prompts.")
52
+ input_texts, target_texts = read_prompts("prompts.txt")
53
+ input_texts_cleaned = [clean_text(text) for text in input_texts]
54
+ target_texts_cleaned = [clean_text(text) for text in target_texts]
55
+
56
+ logger.info("Splitting dataset into training and validation sets.")
57
+ train_texts, val_texts, train_labels, val_labels = train_test_split(input_texts_cleaned, target_texts_cleaned, test_size=0.1)
58
+
59
+ logger.info("Preparing datasets for training.")
60
+ train_dataset = Dataset.from_dict(prepare_data(train_texts, train_labels, tokenizer))
61
+ val_dataset = Dataset.from_dict(prepare_data(val_texts, val_labels, tokenizer))
62
+
63
+ training_args = TrainingArguments(
64
+ output_dir="./results",
65
+ evaluation_strategy="steps",
66
+ learning_rate=5e-5,
67
+ per_device_train_batch_size=4,
68
+ num_train_epochs=3,
69
+ save_steps=500,
70
+ logging_dir="./logs",
71
+ logging_steps=10
72
+ )
73
+
74
+ logger.info("Starting model training.")
75
+ trainer = Trainer(
76
+ model=model,
77
+ args=training_args,
78
+ train_dataset=train_dataset,
79
+ eval_dataset=val_dataset
80
+ )
81
+ trainer.train()
82
+
83
+ logger.info("Saving fine-tuned model.")
84
+ model.save_pretrained("./fine_tuned_model")
85
+ tokenizer.save_pretrained("./fine_tuned_model")
86
+
87
+ except Exception as e:
88
+ logger.error(f"An error occurred during fine-tuning: {str(e)}")
89
+
90
+ fine_tune_model()