File size: 749 Bytes
11020de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# train_model.py
#
# Fine-tune Mistral-7B-Instruct on a pre-tokenized dataset and save the
# resulting weights as .safetensors.
#
# Expects ./tokenized_dataset to hold a dataset saved with
# `datasets.Dataset.save_to_disk`, already tokenized. NOTE(review): no data
# collator is passed to Trainer, so the dataset is presumed to already carry
# a "labels" column for causal-LM loss — confirm against the tokenization
# script.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import load_from_disk

tokenized_dataset = load_from_disk("tokenized_dataset")

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # effective batch size of 8 with per-device batch size 1
    evaluation_strategy="no",        # NOTE(review): renamed to `eval_strategy` in transformers >= 4.46 — adjust if upgrading
    save_strategy="epoch",
    # fp16 mixed precision requires CUDA; the original hard-coded True,
    # which makes TrainingArguments raise on a CPU-only machine.
    fp16=torch.cuda.is_available(),
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

trainer.train()
model.save_pretrained("./my_ai_assistant", safe_serialization=True)  # writes .safetensors