#!/usr/bin/env python # coding=utf-8 import os import sys import json import argparse import logging from datetime import datetime import time import warnings import torch from importlib.util import find_spec # Global variables for hardware detection CUDA_AVAILABLE = torch.cuda.is_available() NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0 DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu" # Import Unsloth first, before other ML imports try: from unsloth import FastLanguageModel from unsloth.chat_templates import get_chat_template unsloth_available = True except ImportError: unsloth_available = False logger = logging.getLogger(__name__) logger.warning("Unsloth not available. Please install with: pip install unsloth") from datasets import load_dataset from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, set_seed, BitsAndBytesConfig ) # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", handlers=[logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger(__name__) # Set other loggers to WARNING to reduce noise and ensure our logs are visible logging.getLogger("transformers").setLevel(logging.WARNING) logging.getLogger("datasets").setLevel(logging.WARNING) logging.getLogger("accelerate").setLevel(logging.WARNING) logging.getLogger("torch").setLevel(logging.WARNING) logging.getLogger("bitsandbytes").setLevel(logging.WARNING) # Check availability of libraries peft_available = find_spec("peft") is not None # Define a clean logging function for HF Space compatibility def log_info(message): """Log information in a format compatible with Hugging Face Spaces""" # Just use the logger, but ensure consistent formatting logger.info(message) # Also ensure output is flushed immediately for streaming sys.stdout.flush() # Check for BitsAndBytes try: from transformers import BitsAndBytesConfig bitsandbytes_available = True except ImportError: bitsandbytes_available = False logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.") # Check for PEFT try: from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training peft_available = True except ImportError: peft_available = False logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.") def load_env_variables(): """Load environment variables from system, .env file, or Hugging Face Space variables.""" # Check if we're running in a Hugging Face Space if os.environ.get("SPACE_ID"): logging.info("Running in Hugging Face Space") # Log the presence of variables (without revealing values) logging.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}") logging.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}") # If username is not set, try to extract from SPACE_ID if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""): username = os.environ.get("SPACE_ID").split("/")[0] os.environ["HF_USERNAME"] = username logging.info(f"Set HF_USERNAME from SPACE_ID: {username}") else: # Try to load from .env file if not in a Space try: from dotenv import load_dotenv # Updated path to .env file in the new directory structure env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env") if os.path.exists(env_path): load_dotenv(env_path) logging.info(f"Loaded environment variables from {env_path}") logging.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}") logging.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}") logging.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}") else: logging.warning(f"No .env file found at {env_path}") except ImportError: logging.warning("python-dotenv not installed, not loading from .env file") if not os.environ.get("HF_USERNAME"): logger.warning("HF_USERNAME is not set. Using default username.") if not os.environ.get("HF_SPACE_NAME"): logger.warning("HF_SPACE_NAME is not set. Using default space name.") # Set HF_TOKEN for huggingface_hub if os.environ.get("HF_TOKEN"): os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN") def load_configs(base_path): """Load configuration from transformers_config.json file.""" # Using a single consolidated config file config_file = base_path try: with open(config_file, "r") as f: config = json.load(f) logger.info(f"Loaded configuration from {config_file}") return config except Exception as e: logger.error(f"Error loading {config_file}: {e}") raise def parse_args(): parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset") parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file") return parser.parse_args() def load_model_and_tokenizer(config): """Load model and tokenizer with proper error handling and optimizations.""" try: if not unsloth_available: logger.error("Unsloth is required for training with pre-quantized model") logger.error("Please ensure unsloth is in requirements.txt") raise ImportError("Unsloth is required for this training setup") # Get model name correctly from config model_name = config.get("model_name") or config.get("model", {}).get("name") logger.info(f"Loading model: {model_name}") if not model_name: raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.") logger.info("Using Unsloth optimizations with pre-quantized model") # Check for flash attention use_flash_attention = config.get("use_flash_attention", True) if use_flash_attention and not find_spec("flash_attn"): logger.warning("flash-attn not found. Will continue without flash attention.") logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation") use_flash_attention = False # First detect if we have a GPU if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() logger.info(f"CUDA available, found {gpu_count} GPU(s)") # Log GPU info for i in range(gpu_count): logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}") logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB") # Create an optimized device map for better balance if gpu_count > 1: logger.info(f"Creating balanced device map for {gpu_count} GPUs") # Use auto mapping but with memory tracking device_map = "auto" # Set max memory for better balancing max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB" for i in range(gpu_count)} logger.info(f"Max memory settings: {max_memory}") else: device_map = "auto" max_memory = None else: logger.warning("No CUDA available, falling back to CPU") device_map = {"": "cpu"} # Force CPU placement max_memory = None # Set default dtype for better numerics if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: # Use bfloat16 for Ampere or newer dtype = torch.bfloat16 logger.info("Using bfloat16 precision (Ampere+ GPU)") elif torch.cuda.is_available(): # Use float16 for older GPUs dtype = torch.float16 logger.info("Using float16 precision (pre-Ampere GPU)") else: # CPU, use default dtype dtype = None logger.info("Using default precision (CPU)") # Load model with proper error handling for out-of-memory try: # Improved memory settings for multi-GPU setup os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048), dtype=dtype, device_map=device_map, max_memory=max_memory, # Don't explicitly use flash attention config here, let Unsloth handle it ) except RuntimeError as e: if "CUDA out of memory" in str(e): logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.") raise else: # Try again with CPU placement to see if it's a memory issue logger.warning(f"Error loading model on default device: {str(e)}") logger.warning("Attempting to load with device_map='cpu' and no specific dtype") model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048), dtype=None, device_map={"": "cpu"}, ) logger.warning("Model loaded on CPU. Training will be very slow.") # Ensure model and optimizer init is on the same device logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}") # Apply Unsloth's training optimizations with config parameters unsloth_config = config.get("unsloth", {}) model = FastLanguageModel.get_peft_model( model, r=unsloth_config.get("r", 32), target_modules=unsloth_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]), lora_alpha=unsloth_config.get("alpha", 16), lora_dropout=unsloth_config.get("dropout", 0.05), bias="none", use_gradient_checkpointing=config.get("gradient_checkpointing", True) or config.get("training", {}).get("gradient_checkpointing", True), random_state=config.get("seed", 42), ) logger.info("Unsloth optimizations applied successfully") # Set up tokenizer settings chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template") if chat_template: try: template = get_chat_template("phi") tokenizer.chat_template = template logger.info("Set phi chat template") except Exception as e: logger.warning(f"Failed to set chat template: {str(e)}") # Ensure proper token settings if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}") return model, tokenizer except Exception as e: logger.error(f"Error in model/tokenizer loading: {str(e)}") logger.error("If missing dependencies, check the requirements.txt file") raise def load_dataset_with_mapping(dataset_config): """Load dataset and apply appropriate column mappings.""" try: # Load dataset dataset_name = dataset_config.get("dataset", {}).get("name", "") dataset_split = dataset_config.get("dataset", {}).get("split", "train") if not dataset_name: raise ValueError("Dataset name not provided in configuration") logger.info(f"Loading dataset {dataset_name}, split {dataset_split}") dataset = load_dataset(dataset_name, split=dataset_split) # Map columns if specified - with checks to avoid conflicts column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {}) if column_mapping: logger.info(f"Checking column mapping: {column_mapping}") # Only apply mappings for columns that need renaming and don't already exist safe_mappings = {} for target, source in column_mapping.items(): if source in dataset.column_names: # Skip if target already exists and is not the same as source if target in dataset.column_names and target != source: logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists") else: safe_mappings[source] = target # Apply safe renames if safe_mappings: logger.info(f"Applying safe column mapping: {safe_mappings}") for source, target in safe_mappings.items(): if source != target: # Only rename if names are different dataset = dataset.rename_column(source, target) # Add prompt_number field that increments based on original order - simple approach logger.info("Adding prompt_number based on original dataset order (starting at 1)") # Simple approach 1: Add index as a column during dataset creation # Create a list of dicts with indices examples_with_idx = [] for i, example in enumerate(dataset): example = dict(example) # Make a copy to avoid modifying the original example['prompt_number'] = i + 1 # 1-indexed examples_with_idx.append(example) # Recreate dataset with prompt_number included from datasets import Dataset dataset = Dataset.from_list(examples_with_idx) logger.info("Successfully added prompt_number to dataset") # If conversations is missing but text exists, attempt conversion if "conversations" not in dataset.column_names and "text" in dataset.column_names: logger.info("Converting 'text' field to 'conversations' format") def convert_text_to_conversations(example): # Check if text is already a list of conversation turns if isinstance(example.get("text"), list): example["conversations"] = example["text"] # Otherwise, create a simple conversation with the text as user message else: example["conversations"] = [ {"role": "user", "content": str(example.get("text", ""))} ] return example dataset = dataset.map(convert_text_to_conversations) logger.info("Successfully converted 'text' to 'conversations'") # Verify we have the required columns if "conversations" not in dataset.column_names: logger.error("Required 'conversations' column not found in dataset!") raise ValueError("Required 'conversations' column missing from dataset") # Log column names and a sample logger.info(f"Dataset loaded successfully with {len(dataset)} examples") logger.info(f"Dataset columns: {dataset.column_names}") # Log a sample for inspection if len(dataset) > 0: sample = dataset[0] prompt_num = sample.get("prompt_number", "N/A") article_id = sample.get("article_id", sample.get("id", "N/A")) logger.info(f"First sample - Prompt number: {prompt_num}, ID: {article_id}") return dataset except Exception as e: logger.error(f"Error loading dataset: {str(e)}") raise def format_phi_chat(messages, dataset_config): """Format messages according to phi-4's chat template and dataset config.""" formatted_chat = "" # Get role templates from config roles = dataset_config.get("data_formatting", {}).get("roles", { "system": "System: {content}\n\n", "human": "Human: {content}\n\n", "user": "Human: {content}\n\n", "assistant": "Assistant: {content}\n\n" }) # Handle research introduction metadata first metadata = next((msg for msg in messages if isinstance(msg, dict) and "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None) if metadata: system_template = roles.get("system", "System: {content}\n\n") formatted_chat = system_template.format(content=metadata['content']) messages = [msg for msg in messages if msg != metadata] # Process remaining messages for message in messages: if not isinstance(message, dict) or "content" not in message: logger.warning(f"Skipping invalid message format: {message}") continue role = message.get("role", "").lower() content = message.get("content", "") # Format based on role if role == "human" or role == "user": template = roles.get("user", roles.get("human", "Human: {content}\n\n")) formatted_chat += template.format(content=content) elif role == "assistant" or role == "bot": template = roles.get("assistant", "Assistant: {content}\n\n") formatted_chat += template.format(content=content) elif role == "system": # For system messages, prepend them template = roles.get("system", "System: {content}\n\n") formatted_chat = template.format(content=content) + formatted_chat else: # Default to system for unknown roles logger.warning(f"Unknown role '{role}' - treating as system message") template = roles.get("system", "System: {content}\n\n") formatted_chat += template.format(content=content) return formatted_chat.strip() class SimpleDataCollator: def __init__(self, tokenizer, dataset_config): self.tokenizer = tokenizer self.dataset_config = dataset_config self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0} self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048) logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}") logger.info("Using exact dataset structure without reformatting") # Check if we're on GPU self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"SimpleDataCollator using device: {self.device}") def __call__(self, features): """Process examples preserving exact JSONL structure""" batch = {"input_ids": [], "attention_mask": [], "labels": []} for example in features: try: # Get ID paper_id = example.get("id", "") # Get conversations - these should already contain role and content conversations = example.get("conversations", []) if not conversations: self.stats["skipped"] += 1 continue # Directly use the conversations array as input to the model's chat template # This preserves the exact structure with roles and content as they are try: # Let tokenizer handle the content with the model's chat template inputs = self.tokenizer.apply_chat_template( conversations, return_tensors=None, add_generation_prompt=False ) except Exception as chat_error: # Fallback if apply_chat_template fails logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}") # Create a basic representation of the conversation conversation_text = "" for msg in conversations: if isinstance(msg, dict) and 'content' in msg: conversation_text += msg.get('content', '') + "\n\n" # Basic tokenization inputs = self.tokenizer( conversation_text, add_special_tokens=True, return_tensors=None ) # Apply length cap if needed (shouldn't be necessary for pre-audited data) if self.max_seq_length > 0 and len(inputs) > self.max_seq_length: logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})") inputs = inputs[:self.max_seq_length] # Create attention mask (1 for all tokens) attention_mask = [1] * len(inputs) if len(inputs) > 0: # For causal language modeling, labels are the same as inputs labels = inputs.copy() batch["input_ids"].append(inputs) batch["attention_mask"].append(attention_mask) batch["labels"].append(labels) self.stats["processed"] += 1 self.stats["total_tokens"] += len(inputs) # Debug logging for first few examples log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3) if self.stats["processed"] <= log_samples: logger.info(f"Example {self.stats['processed']}:") logger.info(f"Paper ID: {paper_id}") logger.info(f"Token count: {len(inputs)}") logger.info(f"Conversation entries: {len(conversations)}") else: self.stats["skipped"] += 1 except Exception as e: logger.warning(f"Error processing example: {str(e)[:100]}...") logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}") self.stats["skipped"] += 1 continue if not batch["input_ids"]: logger.warning("Empty batch, returning dummy tensors") return { "input_ids": torch.zeros((1, 1), dtype=torch.long), "attention_mask": torch.zeros((1, 1), dtype=torch.long), "labels": torch.zeros((1, 1), dtype=torch.long) } # Pad the batch max_length = max(len(ids) for ids in batch["input_ids"]) for i in range(len(batch["input_ids"])): padding_length = max_length - len(batch["input_ids"][i]) if padding_length > 0: batch["input_ids"][i].extend([self.pad_token_id] * padding_length) batch["attention_mask"][i].extend([0] * padding_length) batch["labels"][i].extend([-100] * padding_length) # Convert to tensors batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()} # Log stats periodically log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100) if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0: logger.info(f"Data collator stats: processed={self.stats['processed']}, " f"skipped={self.stats['skipped']}, " f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}") return batch class LoggingCallback(TrainerCallback): def __init__(self): super().__init__() self.training_started = time.time() self.last_log_time = time.time() self.last_step = 0 self.verify_sequence = None self.sequence_samples = None self.sample_indices = None def on_step_end(self, args, state, control, **kwargs): # Log every 50 steps or every 5 minutes, whichever comes first current_time = time.time() # Perform actual sequence integrity verification if enabled if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples: try: # Get a batch of data without disturbing the training train_dataloader = trainer.get_train_dataloader() if train_dataloader is None: log_info("Warning: Could not get train dataloader for verification") else: batch_iterator = iter(train_dataloader) if batch_iterator is None: log_info("Warning: Could not get batch iterator for verification") else: try: batch = next(batch_iterator) if batch is None: log_info("Warning: Could not get batch for verification") elif 'input_ids' in batch and 'labels' in batch: log_info("Verifying data sequence integrity...") # Check if we can access some of our reference samples if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None: log_info("Warning: Train dataset is not available") else: # Get current samples defensively current_samples = [] current_indices = list(range(min(3, len(trainer.train_dataset)))) for idx in current_indices: try: if idx < len(trainer.train_dataset): current_samples.append(trainer.train_dataset[idx]) except Exception as e: log_info(f"Warning: Error accessing dataset at index {idx}: {e}") # Only proceed if we have samples to compare if current_samples and self.sequence_samples: # Compare current samples with our reference samples from training start is_sequence_maintained = True for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)): # Check if sample index is valid if i < len(current_samples): current_sample = current_samples[i] # Compare prompt numbers if available if ('prompt_number' in orig_sample and 'prompt_number' in current_sample and orig_sample['prompt_number'] is not None and current_sample['prompt_number'] is not None): if orig_sample['prompt_number'] != current_sample['prompt_number']: log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}") is_sequence_maintained = False # Also compare IDs as a backup check elif ('article_id' in orig_sample and 'article_id' in current_sample and orig_sample['article_id'] is not None and current_sample['article_id'] is not None): if orig_sample['article_id'] != current_sample['article_id']: log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}") is_sequence_maintained = False # Compare input fingerprints if ('conversations' in orig_sample and 'conversations' in current_sample and orig_sample['conversations'] is not None and current_sample['conversations'] is not None): orig_len = len(orig_sample['conversations']) curr_len = len(current_sample['conversations']) if orig_len != curr_len: log_info(f"WARNING: Sequence integrity compromised! Sample {i} conversation length changed from {orig_len} to {curr_len}") is_sequence_maintained = False if is_sequence_maintained: log_info("Data sequence integrity check: OK") else: log_info("CRITICAL WARNING: Data sequence integrity check FAILED!") else: log_info("Warning: Not enough samples available for sequence verification") except StopIteration: log_info("Warning: No batches available in the dataloader") except Exception as e: log_info(f"Warning: Error iterating through dataloader: {e}") except Exception as e: log_info(f"Warning: Couldn't verify sequence integrity: {e}") time_interval = current_time - self.last_log_time step_interval = state.global_step - self.last_step if step_interval >= 50 or time_interval >= 300: # 5 minutes = 300 seconds # Calculate throughput examples_per_second = step_interval * args.per_device_train_batch_size * args.gradient_accumulation_steps / max(time_interval, 1e-6) elapsed_total = time.strftime("%H:%M:%S", time.gmtime(current_time - self.training_started)) # Log progress log_info(f"Step: {state.global_step}, Loss: {state.log_history[-1]['loss']:.4f}, " f"Rate: {examples_per_second:.2f} examples/sec, Elapsed: {elapsed_total}") # Report memory usage if CUDA is available if CUDA_AVAILABLE: log_info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB allocated, " f"{torch.cuda.max_memory_reserved() / 1024**3:.2f} GB reserved") # Reset for next interval self.last_log_time = current_time self.last_step = state.global_step def on_train_begin(self, args, state, control, **kwargs): log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===") log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M") # Set up sequence verification with actual sample capturing try: self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False) if self.verify_sequence: log_info("Sequence integrity verification enabled during training") # Save actual samples for later verification if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None: # Get some reference samples from the beginning of the dataset defensively self.sample_indices = [] self.sequence_samples = [] max_samples = min(5, len(trainer.train_dataset)) for i in range(max_samples): try: if i < len(trainer.train_dataset): self.sample_indices.append(i) self.sequence_samples.append(trainer.train_dataset[i]) except Exception as e: log_info(f"Warning: Error capturing reference sample at index {i}: {e}") if self.sequence_samples: log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification") # Log sample prompt numbers for debugging sample_prompt_numbers = [] for s in self.sequence_samples: if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None: sample_prompt_numbers.append(s.get('prompt_number')) if sample_prompt_numbers: log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}") else: log_info("Warning: No reference samples were captured") else: log_info("Warning: Could not capture reference samples - verification will be limited") except Exception as e: log_info(f"Warning: Could not set up sequence integrity verification: {e}") self.verify_sequence = False log_info("=== Training is starting ===") # Log important training parameters for visibility total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total") log_info(f"Learning rate: {args.learning_rate}") log_info(f"Epochs: {args.num_train_epochs}") # Log memory information in compact format if CUDA_AVAILABLE: memory_info = [] for i in range(NUM_GPUS): allocated = torch.cuda.memory_allocated(i) / 1024**2 max_mem = torch.cuda.max_memory_allocated(i) / 1024**2 memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)") log_info(f"Initial memory usage - {', '.join(memory_info)}") def on_train_end(self, args, state, control, **kwargs): training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started)) log_info(f"=== Training completed in {training_time} ===") # Log final memory usage if CUDA_AVAILABLE: for i in range(NUM_GPUS): max_mem = torch.cuda.max_memory_allocated(i) / 1024**3 # GB log_info(f"GPU {i} max memory: {max_mem:.2f} GB") # Clear GPU memory torch.cuda.empty_cache() log_info("GPU memory cleared") log_info(f"Total steps: {state.global_step}") log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}") def check_dependencies(): """Check if all required dependencies are installed.""" missing_packages = [] # Critical packages if not unsloth_available: missing_packages.append("unsloth>=2024.3") if not peft_available: missing_packages.append("peft>=0.9.0") # Optional packages - don't add to missing list, just log if find_spec("flash_attn"): logger.info("flash-attn found. Flash attention will be used for faster training.") else: logger.warning("flash-attn not found. Training will work but may be slower.") logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation") # If critical packages are missing, exit with instructions if missing_packages: logger.error("Critical dependencies missing:") for pkg in missing_packages: logger.error(f" - {pkg}") logger.error("Please ensure the space has these packages in requirements.txt") return False return True def main(): # Set up logging logger.info("Starting training process") # Parse arguments args = parse_args() # Load environment variables load_env_variables() # Load configuration try: transformers_config = load_configs(args.config) hardware_config = transformers_config.get("hardware", {}) dataset_config = transformers_config.get("dataset", {}) logger.info("Configuration loaded successfully") except Exception as e: logger.error(f"Error loading configuration: {e}") return 1 # Check dependencies if not check_dependencies(): logger.error("Aborting due to missing critical dependencies") return 1 # Check if we're in distributed mode is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 if is_distributed: local_rank = int(os.environ.get("LOCAL_RANK", "0")) log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}") else: log_info("Running in non-distributed mode (single process)") # Set random seed for reproducibility seed = transformers_config.get("seed", 42) set_seed(seed) logger.info(f"Set random seed to {seed}") # Load model and tokenizer using the consolidated config model, tokenizer = load_model_and_tokenizer(transformers_config) # Empty CUDA cache to ensure clean state if CUDA_AVAILABLE: torch.cuda.empty_cache() log_info("Cleared CUDA cache") # Setup environment variable for CUDA memory allocation if CUDA_AVAILABLE: system_settings = hardware_config.get("system_settings", {}) cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85) if cuda_memory_fraction < 1.0: os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True" log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128") try: log_info("Loading dataset...") dataset = load_dataset_with_mapping(dataset_config) log_info(f"Dataset loaded with {len(dataset)} examples") # Minimal validation before proceeding if dataset is None or len(dataset) == 0: logger.error("Dataset is empty or None! Cannot proceed with training.") return 1 # Create data collator data_collator = SimpleDataCollator(tokenizer, dataset_config) # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence # First check hardware config, then transformers config use_bf16 = False use_fp16 = False # Check hardware config first hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "") if hardware_precision.lower() == "bf16": use_bf16 = True log_info("Using BF16 precision from hardware config") elif hardware_precision.lower() == "fp16": use_fp16 = True log_info("Using FP16 precision from hardware config") else: # Fall back to transformers config use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16" use_fp16 = transformers_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}") # Get per device batch size - from transformers config, but possibly overridden by hardware config per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16) gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3) # Get multi-GPU strategy from hardware config (default to data_parallel) multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel") logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}") # For multi-GPU setup, adjust for better balance if CUDA_AVAILABLE and NUM_GPUS > 1: log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs") # Set up FSDP for multi-GPU training if specified and in distributed mode fsdp_config = None if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1: try: from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, MixedPrecision, BackwardPrefetch, ShardingStrategy, CPUOffload, ) from torch.distributed.fsdp.wrap import ( transformer_auto_wrap_policy, enable_wrap, wrap, ) log_info("Using FSDP for distributed training") # Configure FSDP fsdp_config = { "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"], "fsdp_offload_params": False, "fsdp_backward_prefetch": "BACKWARD_PRE", "fsdp_min_num_params": 1e6, "fsdp_sharding_strategy": 1, # FULL_SHARD } if use_bf16 or use_fp16: precision_type = "bf16" if use_bf16 else "fp16" fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT" log_info(f"FSDP using mixed precision: {precision_type}") except ImportError: log_info("FSDP imports failed, falling back to standard DDP") fsdp_config = None elif multi_gpu_strategy == "fsdp" and not is_distributed: log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)") log_info("Using DataParallel for multi-GPU training instead") else: log_info(f"Using {multi_gpu_strategy} for multi-GPU training") # Get system settings from hardware config dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2) pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True) # Set up training arguments log_info("Setting up training arguments") training_args = TrainingArguments( output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"), num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3), per_device_train_batch_size=per_device_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5), weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01), warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05), lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"), logging_steps=transformers_config.get("training", {}).get("logging_steps", 10), save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"), save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100), save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3), fp16=use_fp16, bf16=use_bf16, max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0), push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False), hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None), hub_token=os.environ.get("HF_TOKEN", None), report_to="tensorboard", remove_unused_columns=False, # Keep all columns gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True), dataloader_pin_memory=pin_memory, optim=transformers_config.get("training", {}).get("optim", "adamw_torch"), ddp_find_unused_parameters=False, # Improve distributed training efficiency dataloader_drop_last=False, # Process all examples dataloader_num_workers=dataloader_workers, no_cuda=False if CUDA_AVAILABLE else True, # Use CUDA if available # Only add FSDP if we're in distributed mode with FSDP strategy fsdp=fsdp_config if is_distributed and multi_gpu_strategy == "fsdp" else None, ) # Create sequential sampler to maintain original dataset order sequential_sampler = torch.utils.data.SequentialSampler(dataset) # Initialize trainer first log_info("Initializing Trainer") trainer = Trainer( model=model, args=training_args, train_dataset=dataset, # We'll override this with our custom dataloader data_collator=data_collator, callbacks=[LoggingCallback()], ) # Then override the get_train_dataloader method def custom_get_train_dataloader(): """Custom dataloader that preserves original dataset order""" log_info("Creating sequential dataloader to maintain original dataset order") # Create a simple sequential sampler sequential_sampler = torch.utils.data.SequentialSampler(dataset) # Verification of sequence preservation flags - simplified data_loading_config = dataset_config.get("data_loading", {}) shuffle_enabled = data_loading_config.get("shuffle", False) if shuffle_enabled: log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!") raise ValueError("Dataset shuffling is enabled but sequential processing is required. " + "Please disable shuffling in your configuration.") # Calculate batch size based on device availability if getattr(training_args, "no_cuda", False): batch_size = training_args.per_device_train_batch_size else: batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1) log_info(f"Using sequential sampler with batch size {batch_size}") # Return DataLoader with sequential sampler return torch.utils.data.DataLoader( dataset, batch_size=batch_size, sampler=sequential_sampler, collate_fn=data_collator, drop_last=training_args.dataloader_drop_last, num_workers=training_args.dataloader_num_workers, pin_memory=training_args.dataloader_pin_memory, ) # Override the get_train_dataloader method trainer.get_train_dataloader = custom_get_train_dataloader # Start training log_info("=== Starting Training ===") try: # Empty cache again right before training if CUDA_AVAILABLE: torch.cuda.empty_cache() log_info("Cleared CUDA cache before training") # Display compact training info total_steps = int(len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps) * training_args.num_train_epochs) log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps") trainer.train() log_info("Training completed successfully!") # Save the final model log_info("Saving final model...") trainer.save_model() log_info(f"Model saved to {training_args.output_dir}") # Push to hub if enabled if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False): hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model") log_info(f"Pushing model to Hugging Face Hub as {hub_id}...") trainer.push_to_hub() log_info("Model successfully pushed to Hub") return 0 except Exception as e: logger.error(f"Training failed with error: {str(e)}") # Log CUDA memory info if available in compact format if CUDA_AVAILABLE: memory_info = [] for i in range(NUM_GPUS): allocated = torch.cuda.memory_allocated(i) / 1024**2 reserved = torch.cuda.memory_reserved(i) / 1024**2 max_mem = torch.cuda.max_memory_allocated(i) / 1024**2 memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)") logger.error(f"GPU memory at failure: {', '.join(memory_info)}") raise except Exception as e: logger.error(f"Error in main training loop: {str(e)}") return 1 if __name__ == "__main__": sys.exit(main())