Spaces: Build error
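For context, a Space running this script needs the libraries it imports installed at build time, so a failed dependency install is a common cause of this kind of build error. A rough requirements.txt sketch is below; the package set is read straight off the imports in the script, and leaving the versions unpinned is my assumption, not something from the original post:

# requirements.txt (sketch)
torch
transformers
datasets
accelerate
bitsandbytes
peft
trl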
import os
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import torch
# 1. Verify CUDA and setup
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.current_device()}")
# 2. Load dataset (using a tiny subset for testing)
dataset = load_dataset("imdb", split="train[:1%]")  # Replace with your dataset

# 3. Configure 4-bit quantization (reduces VRAM usage)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# 4. Load model (using TinyLlama for quick testing)
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.pad_token = tokenizer.eos_token
# 5. Configure LoRA
lora_config = LoraConfig(
    r=8,                                  # Rank
    lora_alpha=32,                        # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Modules to target
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
# 6. Prepare the quantized model for training, then convert it to LoRA
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Should show ~0.1% of parameters as trainable
# 7. Training arguments (optimized for Spaces' A10G GPU)
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=2,   # Reduce if OOM
    gradient_accumulation_steps=4,   # Effective batch size = 8
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=100,                   # Short demo run
    save_steps=50,
    fp16=True,                       # Use mixed precision
    push_to_hub=True,
    hub_model_id="lyricolivia20/LivLoRA",  # Your HF repo
)
# 8. Create trainer
# Note: passing dataset_text_field, tokenizer and max_seq_length directly to
# SFTTrainer matches older trl releases; newer trl versions expect them via SFTConfig.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",  # Update with your dataset's text column
    args=training_args,
    tokenizer=tokenizer,
    max_seq_length=512,
)
# 9. Train and push
print("Starting training...")
trainer.train()
trainer.push_to_hub()
print("✅ Training complete! Adapter pushed to Hub.")