File size: 749 Bytes
11020de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# train_model.py
import torch
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
# Fine-tune a causal language model on a dataset that was tokenized and saved
# to disk beforehand, then export the resulting weights.

# Dataset previously written with `save_to_disk` into ./tokenized_dataset.
tokenized_dataset = load_from_disk("tokenized_dataset")

# Base checkpoint to fine-tune.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# fp16 mixed precision needs a GPU; requesting it on a CPU-only machine makes
# the training setup fail, so only enable it when CUDA is actually available.
use_fp16 = torch.cuda.is_available()

# Evaluation is left at the library default ("no"): no eval dataset is passed
# to the Trainer below. The keyword is intentionally not spelled out because
# its name differs between transformers releases
# (`evaluation_strategy` -> `eval_strategy`).
training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=1,
    # 1 sample per device x 8 accumulation steps = 8 samples per optimizer
    # step on each device.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    save_strategy="epoch",
    fp16=use_fp16,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

trainer.train()

# Export the fine-tuned weights in .safetensors format.
model.save_pretrained("./my_ai_assistant", safe_serialization=True)
|