# taxagent/finetune_tinyllama.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import load_dataset
# Model name
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
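# Llama-family tokenizers often ship without a pad token; fall back to EOS so padding works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token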
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
    torch_dtype=torch.float32,  # load full-precision weights; the Trainer's fp16 flag handles mixed precision (loading in float16 here can break gradient unscaling during full fine-tuning)
device_map="auto" # Use GPU if available
)
# Load dataset from JSON file
dataset = load_dataset("json", data_files="processed_dataset.json")
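# processed_dataset.json is expected to hold records shaped like
# {"prompt": "...", "response": "..."} (assumption based on the column names used below)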
# Tokenization: concatenate prompt and response into one sequence so the causal
# LM learns to continue the prompt with the answer; append EOS so it learns to stop.
def tokenize_function(examples):
    texts = [p + r + tokenizer.eos_token for p, r in zip(examples["prompt"], examples["response"])]
    tokens = tokenizer(texts, truncation=True, max_length=512)  # max_length is an assumption; adjust to your data
    # For causal LM fine-tuning the labels are the input ids; the collator below pads and masks them
    tokens["labels"] = [list(ids) for ids in tokens["input_ids"]]
    return tokens

# Apply tokenization and keep only the tokenized columns
dataset = dataset.map(tokenize_function, batched=True)
dataset = dataset.remove_columns(["prompt", "response"])
# Data collator (for batching and padding)
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model,
padding=True,
return_tensors="pt"
)
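# Note: DataCollatorForSeq2Seq dynamically pads input_ids/attention_mask and pads
# the "labels" field with -100, so padded positions are ignored by the LM loss.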
# Training arguments
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
save_steps=10_000,
save_total_limit=2,
logging_dir="./logs",
logging_steps=200,
    remove_unused_columns=False,  # keep all tokenized columns (prompt/response were already dropped above)
    fp16=torch.cuda.is_available(),  # enable fp16 mixed precision only when a GPU is available
)
# Trainer setup
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
data_collator=data_collator,
)
# Start training
trainer.train()
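# Persist the fine-tuned model and tokenizer (output path is an assumption; adjust as needed)
trainer.save_model("./finetuned_tinyllama")
tokenizer.save_pretrained("./finetuned_tinyllama")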