File size: 4,245 Bytes
e3d5fb5
 
 
 
 
 
 
 
 
 
 
 
 
ec1229f
e3d5fb5
 
 
 
31cc30e
ec1229f
 
14b014c
ec1229f
 
 
 
 
 
e3d5fb5
 
 
 
 
 
 
 
 
 
cf93ccc
 
 
 
 
31cc30e
 
 
 
 
 
 
 
 
 
 
 
 
cf93ccc
 
 
 
 
e3d5fb5
 
 
 
 
 
 
 
31cc30e
 
 
 
 
 
e3d5fb5
 
 
 
 
 
 
ec1229f
e3d5fb5
 
 
 
27e355f
e3d5fb5
 
27e355f
e3d5fb5
 
 
 
 
 
 
27e355f
 
e3d5fb5
 
 
 
 
 
 
 
 
 
27e355f
e3d5fb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31cc30e
e3d5fb5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "datasets>=2.14.0",
#     "trackio",
#     "torch",
#     "bitsandbytes",
# ]
# ///

import os
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
from huggingface_hub import login

# Authenticate against the Hugging Face Hub when a token is present in the
# environment; otherwise only warn and carry on.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not found in environment")
else:
    login(token=hf_token)
    print("Logged in to Hugging Face Hub")

# Fetch the "solutions" configuration, which carries chat-style messages
print("Loading open-r1/codeforces-cots dataset...")
dataset = load_dataset("open-r1/codeforces-cots", "solutions", split="train")
print(f"Full dataset loaded: {len(dataset)} examples")

# Cap the demo at the first 1000 rows (fewer if the split is smaller)
demo_size = min(1000, len(dataset))
dataset = dataset.select(range(demo_size))
print(f"Using {len(dataset)} examples for demo training")

# The dataset ships a plain-string 'prompt' column alongside the chat-format
# 'messages' column. TRL gets confused when both are present, so only
# 'messages' is kept for chat-based SFT.
print("Preparing dataset for chat-based SFT...")

# Filter for valid messages and keep only the messages column
def filter_valid_messages(example):
    """Return True only for samples holding a usable conversation.

    A sample passes when its "messages" entry contains at least two
    messages and every message has a non-empty "content" value. A missing
    or empty "messages" entry fails the check.
    """
    conversation = example.get("messages", [])
    has_enough_turns = bool(conversation) and len(conversation) >= 2
    # all() stops at the first message with missing/empty content,
    # mirroring an early-exit loop.
    return has_enough_turns and all(turn.get("content") for turn in conversation)

# Drop every sample rejected by filter_valid_messages
dataset = dataset.filter(filter_valid_messages)
print(f"After filtering: {len(dataset)} examples")

# Strip every column other than 'messages' so only the chat data remains
extra_columns = [name for name in dataset.column_names if name != "messages"]
dataset = dataset.remove_columns(extra_columns)
print(f"Dataset columns: {dataset.column_names}")

# Hold out 10% of the rows; the fixed seed keeps the split reproducible
print("Creating train/eval split...")
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = splits["train"], splits["test"]
print(f"   Train: {len(train_dataset)} examples")
print(f"   Eval: {len(eval_dataset)} examples")

# Load tokenizer for chat template
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Training configuration
# NOTE(review): the keyword names used below (max_length, eval_strategy,
# project) may require newer trl/transformers releases than the lower bounds
# declared in the script header — confirm against the installed versions.
config = SFTConfig(
    # CRITICAL: Hub settings
    # The repo id here is also hard-coded in the final "Complete!" message
    # at the bottom of the script; keep the two in sync.
    output_dir="qwen3-0.6b-codeforces-sft",
    push_to_hub=True,
    hub_model_id="Godsonntungi2/qwen3-0.6b-codeforces-sft",
    hub_strategy="every_save",
    hub_token=hf_token,  # Explicitly pass token (None when HF_TOKEN is unset)

    # Training parameters
    # per_device_train_batch_size x gradient_accumulation_steps = 2 x 8,
    # i.e. 16 examples per optimizer step on each device.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,  # Smaller eval batch to prevent OOM
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    max_length=1024,  # Reduced from 2048 to save memory

    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Evaluation - disable to save memory and time
    # NOTE: eval_dataset is still handed to the trainer below, but with
    # eval_strategy="no" no evaluation is scheduled during training, so the
    # held-out split goes unused in this run.
    eval_strategy="no",

    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    bf16=True,  # NOTE(review): presumably needs bf16-capable hardware — confirm

    # Monitoring
    report_to="trackio",
    project="qwen3-codeforces-sft",
    run_name="demo-1k-v2",
)

# LoRA configuration for efficient training
# Adapters target the attention projections (q/k/v/o) and the MLP
# projections (gate/up/down) by module name.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Initialize and train
# The model is passed as a Hub id string (the same id the tokenizer was
# loaded from) rather than as a pre-loaded model object.
print("Initializing trainer with Qwen/Qwen3-0.6B...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    args=config,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Explicit final push once training has finished.
print("Pushing to Hub...")
trainer.push_to_hub()

# The URLs below are hard-coded; update them if hub_model_id changes.
# NOTE(review): the metrics URL presumably points at the trackio dashboard
# Space for this account — confirm.
print("Complete! Model at: https://huggingface.co/Godsonntungi2/qwen3-0.6b-codeforces-sft")
print("View metrics at: https://huggingface.co/spaces/Godsonntungi2/trackio")