#!/usr/bin/env python
# coding=utf-8
import unsloth  # imported first so Unsloth's patches apply before transformers is loaded
import os
import sys
import json
import logging
import subprocess
# Configure logging to match HF Space logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Set other loggers to WARNING to reduce noise and ensure our logs are visible
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("accelerate").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
# Define a clean logging function for HF Space compatibility
def log_info(message):
"""Log information in a format compatible with Hugging Face Spaces"""
logger.info(message)
# Ensure output is flushed immediately for streaming
sys.stdout.flush()
# Configuration paths
CONFIG_DIR = "."
TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
def load_config(config_path):
"""Load configuration from a JSON file."""
try:
with open(config_path, 'r') as f:
return json.load(f)
except Exception as e:
log_info(f"Error loading config: {str(e)}")
return {}
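# Illustrative shape of transformers_config.json, inferred from the keys that
# display_config() below reads; the real file may carry additional fields:
#
# {
#   "model": {"name": "..."},
#   "training": {
#     "per_device_train_batch_size": 16,
#     "gradient_accumulation_steps": 3,
#     "num_train_epochs": 3,
#     "learning_rate": 2e-5
#   },
#   "bf16": true,
#   "tokenizer": {"max_seq_length": 2048},
#   "hardware": {
#     "specs": {"gpu_count": 4, "gpu_type": "L4", "vram_per_gpu": 24},
#     "training_optimizations": {
#       "multi_gpu_strategy": "data_parallel",
#       "memory_optimizations": {"use_gradient_checkpointing": true}
#     }
#   },
#   "dataset": {"dataset": {"name": "...", "split": "train"}}
# }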
def display_config():
"""Display current training configuration."""
config = load_config(TRANSFORMERS_CONFIG)
if not config:
return "Error loading configuration file."
# Extract sub-configurations
    transformers_config = config  # the top-level file doubles as the transformers config
hardware_config = config.get("hardware", {})
dataset_config = config.get("dataset", {})
model_name = transformers_config.get("model", {}).get("name") or transformers_config.get("model_name_or_path", "")
# Training parameters
training_config = transformers_config.get("training", {})
batch_size = training_config.get("per_device_train_batch_size", 16)
grad_accum = training_config.get("gradient_accumulation_steps", 3)
epochs = training_config.get("num_train_epochs", 3)
learning_rate = training_config.get("learning_rate", 2e-5)
# Hardware settings
    specs = hardware_config.get("specs", {})
    gpu_count = specs.get("gpu_count", 4)
    gpu_type = specs.get("gpu_type", "L4")
    vram = specs.get("vram_per_gpu", 24)
# Dataset info
dataset_name = dataset_config.get("dataset", {}).get("name", "")
# Format response as HTML for better display
html = f"""
<h2>Training Configuration</h2>
<h3>Model</h3>
<ul>
<li><b>Model:</b> {model_name}</li>
    <li><b>Learning Rate:</b> {learning_rate}</li>
<li><b>Per-Device Batch Size:</b> {batch_size}</li>
<li><b>Gradient Accumulation:</b> {grad_accum}</li>
<li><b>Total Effective Batch Size:</b> {batch_size} × {gpu_count} × {grad_accum} = {batch_size * gpu_count * grad_accum}</li>
<li><b>Epochs:</b> {epochs}</li>
<li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
<li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
</ul>
<h3>Hardware</h3>
<ul>
<li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB VRAM per GPU, total: {vram * gpu_count} GB)</li>
<li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
<li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
</ul>
<h3>Dataset</h3>
<ul>
<li><b>Dataset:</b> {dataset_name}</li>
<li><b>Dataset Split:</b> {dataset_config.get('dataset', {}).get('split', 'train')}</li>
</ul>
"""
return html
def start_training():
"""Start the training process."""
try:
# Log configuration check
log_info("Preparing to start training process...")
log_info("Using consolidated configuration from transformers_config.json")
# Start training
log_info("Starting training process...")
        # Launch the training script as a background process; Hugging Face
        # Spaces captures its stdout/stderr in the Space logs, so no extra
        # process management is needed here. Using sys.executable avoids
        # relying on "python" being on PATH, and a list argument avoids
        # shell=True.
        subprocess.Popen(
            [sys.executable, "run_transformers_training.py"],
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
log_info("Training process has been started. You can monitor progress in the logs.")
return "Training started successfully. Monitor progress in the Hugging Face Space logs."
except Exception as e:
error_msg = f"Error starting training: {str(e)}"
log_info(error_msg)
return error_msg
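# Optional status sketch (not wired into the UI): if start_training() stored
# its Popen handle in _training_process (e.g. "global _training_process;
# _training_process = subprocess.Popen(...)"), the handle could be polled
# later. Both names here are hypothetical additions, not part of the app.
_training_process = None

def check_training_status():
    """Report whether a previously launched training process is still running."""
    if _training_process is None:
        return "No training process has been started."
    exit_code = _training_process.poll()  # None while the process is alive
    if exit_code is None:
        return "Training is still running."
    return f"Training process exited with code {exit_code}."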
# Interface setup for gradio
def create_interface():
import gradio as gr
with gr.Blocks(title="Phi-4 Training Center") as demo:
gr.Markdown("# Phi-4 Research Assistant Training")
with gr.Row():
with gr.Column():
gr.Markdown("## Control Panel")
# Display current config
config_html = gr.HTML(display_config())
refresh_btn = gr.Button("Refresh Configuration")
# Training controls
train_btn = gr.Button("Start Training", variant="primary")
train_output = gr.Textbox(label="Status", interactive=False)
with gr.Column():
gr.Markdown("## Training Information")
gr.Markdown("""
### Hardware:
- 4× NVIDIA L4 GPUs (24GB VRAM per GPU, 96GB total)
- Training with BF16 precision
- Using Data Parallel for multi-GPU
- Effective batch size: 16 (per device) × 4 (GPUs) × 3 (gradient accumulation) = 192
### Notes:
- Training may take several hours depending on dataset size
- Check the Space logs for real-time progress
- Model checkpoints will be saved to ./results directory
""")
# Connect buttons to functions
        refresh_btn.click(display_config, outputs=config_html)  # re-read the config file and refresh the panel
train_btn.click(start_training, outputs=train_output)
return demo
if __name__ == "__main__":
# Print basic system information to help with debugging
try:
import torch
logger.info(f"Python: {sys.version.split()[0]}")
logger.info(f"PyTorch: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
from transformers import __version__ as tf_version
logger.info(f"Transformers: {tf_version}")
from unsloth import __version__ as un_version
logger.info(f"Unsloth: {un_version}")
except Exception as e:
logger.warning(f"Error printing system info: {e}")
# Create and launch the Gradio interface
demo = create_interface()
    demo.queue()
    # share=True is unnecessary on Hugging Face Spaces (the Space URL is
    # already public), so launch with defaults.
    demo.launch()