File size: 7,608 Bytes
4a1fd53
 
 
721f5f5
a57357b
 
 
20852a7
 
 
decfb95
20852a7
decfb95
a57357b
4a1fd53
20852a7
 
 
 
 
 
 
4a1fd53
 
 
 
 
 
 
 
 
 
 
 
 
 
20852a7
 
 
 
 
decfb95
20852a7
decfb95
 
20852a7
4a1fd53
decfb95
a57357b
20852a7
 
4a1fd53
20852a7
4a1fd53
 
20852a7
4a1fd53
 
 
 
 
 
20852a7
4a1fd53
 
 
 
 
 
20852a7
4a1fd53
 
 
 
20852a7
4a1fd53
 
 
 
 
 
 
 
 
 
a7d1f2a
 
 
 
4a1fd53
 
 
20852a7
4a1fd53
 
a7d1f2a
4a1fd53
 
 
a57357b
4a1fd53
 
 
 
 
20852a7
a57357b
4a1fd53
a57357b
20852a7
 
a57357b
ae57ea2
 
 
4a1fd53
 
 
a57357b
4a1fd53
20852a7
a57357b
4a1fd53
 
a57357b
4a1fd53
20852a7
4a1fd53
a57357b
 
4a1fd53
 
 
a57357b
4a1fd53
 
 
a57357b
4a1fd53
 
 
 
 
 
20852a7
4a1fd53
 
 
 
 
 
 
 
 
 
 
 
a7d1f2a
4a1fd53
 
a7d1f2a
4a1fd53
 
 
 
 
 
 
 
 
 
a57357b
4a1fd53
a57357b
 
decfb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a1fd53
 
fbbcd99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
# coding=utf-8

import unsloth
import os
import sys
import json
import logging
import subprocess
import time
import traceback
from datetime import datetime
from pathlib import Path

# Route log records to stdout with a timestamped format so they appear in the
# Hugging Face Space log viewer alongside the app's own output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Raise chatty third-party libraries to WARNING so their INFO-level output
# does not drown out this module's progress messages.
for _library_logger_name in (
    "transformers",
    "datasets",
    "accelerate",
    "torch",
    "bitsandbytes",
):
    logging.getLogger(_library_logger_name).setLevel(logging.WARNING)
del _library_logger_name

# Define a clean logging function for HF Space compatibility
def log_info(message):
    """Emit *message* at INFO level through the module logger, then flush stdout.

    The explicit flush pushes each line out immediately, so output streamed to
    the Hugging Face Space log view shows up without waiting for the buffer.
    """
    logger.log(logging.INFO, message)
    sys.stdout.flush()

# Location of the consolidated training configuration. CONFIG_DIR is ".", so
# the path is resolved against the process's current working directory.
CONFIG_DIR = "."
_CONFIG_FILENAME = "transformers_config.json"
TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, _CONFIG_FILENAME)

def load_config(config_path):
    """Load a JSON configuration file and return its top-level object.

    Args:
        config_path: Path to the JSON file to read.

    Returns:
        The parsed top-level JSON object as a dict. If the file cannot be
        read or parsed, or its top-level value is not a JSON object, the
        problem is logged via ``log_info`` and an empty dict is returned.
        Callers can therefore treat "no usable config" uniformly with
        ``if not config``.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError: invalid JSON
        # (json.JSONDecodeError) or bad UTF-8 (UnicodeDecodeError).
        log_info(f"Error loading config from {config_path}: {e}")
        return {}

    # Callers use dict methods (.get) on the result, so a JSON array or
    # scalar at the top level would crash them later; reject it here.
    if not isinstance(config, dict):
        log_info(
            f"Error loading config from {config_path}: expected a JSON object "
            f"at the top level, got {type(config).__name__}"
        )
        return {}

    return config

def display_config():
    """Render the current training configuration as an HTML summary.

    Loads ``TRANSFORMERS_CONFIG`` with ``load_config`` and reads model,
    training, hardware and dataset settings from it. Any absent key falls
    back to the default shown in the corresponding ``.get`` call below.

    Returns:
        An HTML string for a ``gr.HTML`` component. If the loaded config is
        empty, returns the plain message "Error loading configuration file."
    """
    config = load_config(TRANSFORMERS_CONFIG)

    if not config:
        return "Error loading configuration file."

    # Extract sub-configurations
    transformers_config = config
    hardware_config = config.get("hardware", {})
    dataset_config = config.get("dataset", {})

    # Prefer the nested model.name; fall back to a flat model_name_or_path key.
    model_name = transformers_config.get("model", {}).get("name") or transformers_config.get("model_name_or_path", "")

    # Training parameters (single source of truth for what the HTML shows)
    training_config = transformers_config.get("training", {})
    batch_size = training_config.get("per_device_train_batch_size", 16)
    grad_accum = training_config.get("gradient_accumulation_steps", 3)
    epochs = training_config.get("num_train_epochs", 3)
    learning_rate = training_config.get("learning_rate", 2e-5)
    max_seq_length = transformers_config.get("tokenizer", {}).get("max_seq_length", 2048)

    # bf16 defaults to on, so FP16/FP32 are only reported when bf16 is
    # explicitly disabled.
    # NOTE(review): these flags are read from the top level rather than from
    # "training" like the other hyperparameters - confirm against the schema.
    if transformers_config.get("bf16", True):
        precision = "BF16"
    elif transformers_config.get("fp16", False):
        precision = "FP16"
    else:
        precision = "FP32"

    # Hardware settings
    specs = hardware_config.get("specs", {})
    gpu_count = specs.get("gpu_count", 4)
    gpu_type = specs.get("gpu_type", "L4")
    vram = specs.get("vram_per_gpu", 24)
    optimizations = hardware_config.get("training_optimizations", {})
    multi_gpu_strategy = optimizations.get("multi_gpu_strategy", "data_parallel")
    use_gradient_checkpointing = optimizations.get("memory_optimizations", {}).get("use_gradient_checkpointing", True)
    memory_optimizations = "Gradient Checkpointing" if use_gradient_checkpointing else "None"

    # Dataset info
    # NOTE(review): this reads config["dataset"]["dataset"][...] (two levels
    # of "dataset") - confirm that nesting matches transformers_config.json.
    dataset_info = dataset_config.get("dataset", {})
    dataset_name = dataset_info.get("name", "")
    dataset_split = dataset_info.get("split", "train")

    # Global batch = per-device batch x number of GPUs x accumulation steps.
    effective_batch_size = batch_size * gpu_count * grad_accum
    total_vram = vram * gpu_count

    # Format response as HTML for better display
    html = f"""
    <h2>Training Configuration</h2>
    <h3>Model</h3>
    <ul>
        <li><b>Model:</b> {model_name}</li>
        <li><b>Learning Rate:</b> {learning_rate}</li>
        <li><b>Per-Device Batch Size:</b> {batch_size}</li>
        <li><b>Gradient Accumulation:</b> {grad_accum}</li>
        <li><b>Total Effective Batch Size:</b> {batch_size} × {gpu_count} × {grad_accum} = {effective_batch_size}</li>
        <li><b>Epochs:</b> {epochs}</li>
        <li><b>Precision:</b> {precision}</li>
        <li><b>Max Sequence Length:</b> {max_seq_length}</li>
    </ul>
    
    <h3>Hardware</h3>
    <ul>
        <li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB VRAM per GPU, total: {total_vram} GB)</li>
        <li><b>Multi-GPU Strategy:</b> {multi_gpu_strategy}</li>
        <li><b>Memory Optimizations:</b> {memory_optimizations}</li>
    </ul>
    
    <h3>Dataset</h3>
    <ul>
        <li><b>Dataset:</b> {dataset_name}</li>
        <li><b>Dataset Split:</b> {dataset_split}</li>
    </ul>
    """

    return html

# Handle of the most recently launched training subprocess (None until the
# first launch). Used to refuse starting a second run while one is alive.
_training_process = None

def start_training():
    """Launch ``run_transformers_training.py`` as a background subprocess.

    The child runs under the same interpreter as this app (``sys.executable``),
    without a shell. Its stdout/stderr are this process's ``sys.stdout`` /
    ``sys.stderr``, so its output shows up in the Space logs. Only one run is
    allowed at a time: while a previously launched child is still alive,
    nothing new is started.

    The script path is relative, so it is resolved against the current
    working directory.

    Returns:
        A status string for the UI. It is one of: a success message, a
        notice that training is already running, or an error message if the
        launch failed.
    """
    global _training_process
    try:
        # Log configuration check
        log_info("Preparing to start training process...")
        log_info("Using consolidated configuration from transformers_config.json")

        # Another click while a run is active would start a second job
        # competing for the same GPUs; poll() is None while the child runs.
        if _training_process is not None and _training_process.poll() is None:
            msg = f"Training is already running (PID {_training_process.pid})."
            log_info(msg)
            return msg

        script = "run_transformers_training.py"
        # Fail fast with a clear message instead of reporting success for a
        # child process that would exit immediately with "file not found".
        if not os.path.isfile(script):
            error_msg = f"Error starting training: {script} not found in {os.getcwd()}"
            log_info(error_msg)
            return error_msg

        # Start training
        log_info("Starting training process...")

        # Argument list + sys.executable: no shell parsing, and no dependence
        # on whichever "python" happens to be first on PATH.
        _training_process = subprocess.Popen(
            [sys.executable, script],
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

        log_info("Training process has been started. You can monitor progress in the logs.")

        return "Training started successfully. Monitor progress in the Hugging Face Space logs."

    except Exception as e:
        # Top-level UI boundary: report any launch failure in the status box.
        error_msg = f"Error starting training: {str(e)}"
        log_info(error_msg)
        return error_msg

# Interface setup for gradio
def create_interface():
    """Build (but do not launch) the Gradio Blocks UI for the training panel.

    Left column: the HTML rendered by ``display_config``, a button that
    re-renders it, a button wired to ``start_training``, and a read-only
    status textbox that receives ``start_training``'s return value.
    Right column: static notes about the hardware setup.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    # Imported lazily, so the module can be imported without gradio loaded.
    import gradio as gr

    with gr.Blocks(title="Phi-4 Training Center") as demo:
        gr.Markdown("# Phi-4 Research Assistant Training")

        with gr.Row():
            # Left column: live configuration view and training controls.
            with gr.Column():
                gr.Markdown("## Control Panel")

                config_view = gr.HTML(display_config())
                reload_config_button = gr.Button("Refresh Configuration")

                launch_button = gr.Button("Start Training", variant="primary")
                status_box = gr.Textbox(label="Status", interactive=False)

            # Right column: static reference notes. NOTE: this text is
            # hard-coded and mirrors display_config's fallback values
            # (16 per device, 4 GPUs, 3 accumulation steps); it does not
            # follow edits to the config file.
            with gr.Column():
                gr.Markdown("## Training Information")
                gr.Markdown("""
                ### Hardware:
                - 4× NVIDIA L4 GPUs (24GB VRAM per GPU, 96GB total)
                - Training with BF16 precision
                - Using Data Parallel for multi-GPU
                - Effective batch size: 16 (per device) × 4 (GPUs) × 3 (gradient accumulation) = 192
                
                ### Notes:
                - Training may take several hours depending on dataset size
                - Check the Space logs for real-time progress
                - Model checkpoints will be saved to ./results directory
                """)

        # Event wiring: the refresh button re-reads the config file on each click.
        reload_config_button.click(lambda: gr.update(value=display_config()), outputs=config_view)
        launch_button.click(start_training, outputs=status_box)

    return demo

if __name__ == "__main__":
    # Best-effort environment report to make Space logs easier to debug.
    # Any failure here is downgraded to a warning so the UI still launches.
    try:
        import torch

        python_version = sys.version.split()[0]
        logger.info(f"Python: {python_version}")
        logger.info(f"PyTorch: {torch.__version__}")
        cuda_ok = torch.cuda.is_available()
        logger.info(f"CUDA available: {cuda_ok}")
        if cuda_ok:
            logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")

        from transformers import __version__ as tf_version

        logger.info(f"Transformers: {tf_version}")

        from unsloth import __version__ as un_version

        logger.info(f"Unsloth: {un_version}")
    except Exception as err:
        logger.warning(f"Error printing system info: {err}")

    # Build the Gradio app, enable the request queue, and start serving.
    demo = create_interface()
    demo.queue()
    demo.launch(share=True)