import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the tokenizer; Llama-style tokenizers may ship without a pad token,
# so fall back to the EOS token to make padding work.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model in full precision and let fp16=True in TrainingArguments handle
# mixed precision; loading the weights in torch.float16 while also enabling fp16
# training can make the gradient scaler fail ("Attempting to unscale FP16 gradients").
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
)

# Load the preprocessed prompt/response pairs
dataset = load_dataset("json", data_files="processed_dataset.json")

def tokenize_function(examples):
    # Join each prompt with its response into one training sequence;
    # padding is left to the data collator.
    texts = [p + r for p, r in zip(examples["prompt"], examples["response"])]
    tokens = tokenizer(texts, truncation=True)
    tokens["labels"] = [ids.copy() for ids in tokens["input_ids"]]  # loss targets for causal LM
    return tokens

dataset = dataset.map(tokenize_function, batched=True)
dataset = dataset.remove_columns(["prompt", "response"])

# Dynamically pad input_ids/attention_mask to the longest sequence in each batch,
# and pad labels with -100 so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt",
)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=200,
    remove_unused_columns=False,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=data_collator,
)

# Run fine-tuning
trainer.train()
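
# Optional: persist the fine-tuned weights and tokenizer (reusing output_dir from
# TrainingArguments) so they can be reloaded later with from_pretrained.
trainer.save_model("./results")
tokenizer.save_pretrained("./results")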