# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "datasets>=2.14.0",
# "trackio",
# "torch",
# "bitsandbytes",
# ]
# ///
import os
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
from huggingface_hub import login
# Authenticate against the Hugging Face Hub using the HF_TOKEN env var.
# hf_token stays bound either way — it is reused below as SFTConfig.hub_token.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not found in environment")
else:
    login(token=hf_token)
    print("Logged in to Hugging Face Hub")
# Pull the "solutions" configuration of open-r1/codeforces-cots (chat format)
# and trim it to a small subset for the demo run.
print("Loading open-r1/codeforces-cots dataset...")
dataset = load_dataset("open-r1/codeforces-cots", "solutions", split="train")
print(f"Full dataset loaded: {len(dataset)} examples")
demo_size = min(1000, len(dataset))  # cap at 1000 examples for the demo
dataset = dataset.select(range(demo_size))
print(f"Using {len(dataset)} examples for demo training")
# The dataset ships both a 'prompt' string column and a 'messages' chat
# column; TRL gets confused when both are present, so only 'messages' is
# kept for chat-based SFT (column pruning happens after filtering below).
print("Preparing dataset for chat-based SFT...")
# Predicate used to drop samples whose chat transcript is unusable.
def filter_valid_messages(example):
    """Return True iff the example holds a usable chat transcript.

    A transcript is usable when 'messages' contains at least two entries
    and every entry carries non-empty 'content'.
    """
    messages = example.get("messages") or []
    if len(messages) < 2:
        return False
    return all(msg.get("content") for msg in messages)
dataset = dataset.filter(filter_valid_messages)
print(f"After filtering: {len(dataset)} examples")
# Strip every column except 'messages' so TRL sees a pure chat dataset.
extra_columns = [name for name in dataset.column_names if name != "messages"]
dataset = dataset.remove_columns(extra_columns)
print(f"Dataset columns: {dataset.column_names}")
# Hold out 10% for evaluation; fixed seed keeps the split reproducible.
print("Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"  Train: {len(train_dataset)} examples")
print(f"  Eval: {len(eval_dataset)} examples")
# Load the tokenizer that provides the chat template for formatting
# 'messages' during SFT.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
# Some tokenizers ship without a pad token; fall back to EOS so batches
# can be padded during collation.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Training configuration
config = SFTConfig(
    # Hub settings: upload every saved checkpoint so an interrupted run
    # still leaves the latest checkpoint recoverable on the Hub.
    output_dir="qwen3-0.6b-codeforces-sft",
    push_to_hub=True,
    hub_model_id="Godsonntungi2/qwen3-0.6b-codeforces-sft",
    hub_strategy="every_save",
    hub_token=hf_token,  # explicitly pass token (None if HF_TOKEN was unset)
    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,  # smaller eval batch to prevent OOM
    gradient_accumulation_steps=8,  # effective train batch = 2 * 8 = 16 per device
    learning_rate=2e-5,
    max_length=1024,  # reduced from 2048 to save memory
    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    # Evaluation disabled to save memory and time (eval_dataset goes unused)
    eval_strategy="no",
    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,  # trade compute for activation memory
    bf16=True,
    # Monitoring via trackio
    report_to="trackio",
    project="qwen3-codeforces-sft",
    run_name="demo-1k-v2",
)
# LoRA configuration for parameter-efficient training: only small adapter
# matrices are trained instead of the full model.
peft_config = LoraConfig(
    r=16,  # adapter rank
    lora_alpha=32,  # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    bias="none",  # leave bias terms frozen
    task_type="CAUSAL_LM",
    # Adapt all attention projections plus the MLP projections.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# Initialize and train
print("Initializing trainer with Qwen/Qwen3-0.6B...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",  # model id string; SFTTrainer loads it internally
    train_dataset=train_dataset,
    # NOTE(review): eval_strategy="no" in the config above, so this eval set
    # is not evaluated during training — confirm whether it is needed here.
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    args=config,
    peft_config=peft_config,  # wraps the model with LoRA adapters
)
print("Starting training...")
trainer.train()
# Final explicit push on top of the every-save uploads during training.
print("Pushing to Hub...")
trainer.push_to_hub()
print("Complete! Model at: https://huggingface.co/Godsonntungi2/qwen3-0.6b-codeforces-sft")
print("View metrics at: https://huggingface.co/spaces/Godsonntungi2/trackio")