# LivLoRA / train.py
import os
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import torch
# 1. Verify CUDA and setup
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.current_device()}")
# 2. Load dataset (using tiny subset for testing)
dataset = load_dataset("imdb", split="train[:1%]") # Replace with your dataset
# 3. Configure 4-bit quantization (reduces VRAM usage)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
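# NF4 stores the base weights as 4-bit NormalFloat while matmuls run in fp16,
# so weight memory drops to roughly a quarter of the fp16 footprint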
# 4. Load model (using TinyLlama for quick testing)
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
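# Llama-style tokenizers ship without a pad token, so reuse EOS for padding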
tokenizer.pad_token = tokenizer.eos_token
# 5. Configure LoRA
lora_config = LoraConfig(
    r=8,                                  # Rank
    lora_alpha=32,                        # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Modules to target
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
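# Adapter updates are scaled by lora_alpha / r = 32 / 8 = 4; raise r for more adapter capacity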
# 6. Prepare the quantized model for training and wrap it with LoRA adapters
model = prepare_model_for_kbit_training(model)  # Casts norm layers to fp32, enables input grads
model = get_peft_model(model, lora_config)
model.print_trainable_parameters() # Should show ~0.1% of parameters
# 7. Training arguments (optimized for Spaces' A10G GPU)
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=2,   # Reduce if OOM
    gradient_accumulation_steps=4,   # Effective batch size = 8
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=100,                   # Short demo run
    save_steps=50,
    fp16=True,                       # Use mixed precision
    push_to_hub=True,
    hub_model_id="lyricolivia20/LivLoRA",  # Your HF repo
)
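# push_to_hub requires Hub credentials (e.g. `huggingface-cli login` or an HF_TOKEN env var)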
# 8. Create trainer
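# Note: dataset_text_field / max_seq_length / tokenizer are direct SFTTrainer kwargs in older
# trl releases; newer versions move the first two into SFTConfig, so pin trl accordingly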
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",  # Update with your dataset's text column
    args=training_args,
    tokenizer=tokenizer,
    max_seq_length=512,
)
# 9. Train and push
print("Starting training...")
trainer.train()
trainer.push_to_hub()
print("βœ… Training complete! Adapter pushed to Hub.")