def format_phi_chat(messages, dataset_config):
    """Render a list of chat messages as one plain-text transcript.

    Every message is expected to be a dict with ``role`` and ``content``
    keys. A message is rendered through a per-role ``str.format`` template
    containing a ``{content}`` placeholder. Templates are read from
    ``dataset_config["data_formatting"]["roles"]``; built-in ``System:`` /
    ``Human:`` / ``Assistant:`` templates fill in whatever is missing.

    Ordering rules:
      * The first message whose string content contains
        ``[RESEARCH INTRODUCTION]`` is rendered first with the system
        template, and every message equal to it is dropped from the rest.
      * ``human``/``user`` and ``assistant``/``bot`` messages are appended
        in input order.
      * ``system`` messages are prepended to everything rendered so far, so
        several system messages come out in reverse input order.
      * Any other role is rendered with the system template but appended in
        place, and a warning is logged.

    Args:
        messages: Iterable of message dicts. Entries that are not dicts or
            have no ``content`` key are skipped with a warning.
        dataset_config: Mapping that may hold ``data_formatting.roles``.

    Returns:
        The concatenated transcript with surrounding whitespace stripped.
    """
    default_roles = {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    }
    # `or default_roles` also covers a config that sets roles to null/empty.
    roles = (
        dataset_config.get("data_formatting", {}).get("roles", default_roles)
        or default_roles
    )
    system_template = roles.get("system", default_roles["system"])
    user_template = roles.get("user", roles.get("human", default_roles["human"]))
    assistant_template = roles.get("assistant", default_roles["assistant"])

    # The messages are walked more than once; materialise them so that a
    # generator argument is not exhausted by the metadata search.
    messages = list(messages)

    # Only string content can carry the marker; checking the type first
    # avoids a TypeError on None / numeric content.
    marker = "[RESEARCH INTRODUCTION]"
    metadata = next(
        (
            msg
            for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and marker in msg["content"]
        ),
        None,
    )

    parts = []
    if metadata is not None:
        parts.append(system_template.format(content=metadata["content"]))
        messages = [msg for msg in messages if msg != metadata]

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        # A missing or None role is normalised to "" and handled as unknown.
        role = str(message.get("role") or "").lower()
        content = message["content"]

        if role in ("human", "user"):
            parts.append(user_template.format(content=content))
        elif role in ("assistant", "bot"):
            parts.append(assistant_template.format(content=content))
        elif role == "system":
            # System messages go in front of everything rendered so far.
            parts.insert(0, system_template.format(content=content))
        else:
            logger.warning(f"Unknown role '{role}' - treating as system message")
            parts.append(system_template.format(content=content))

    return "".join(parts).strip()

class SimpleDataCollator:
    """Collate conversation examples into padded causal-LM batches.

    Each example is expected to be a dict-like object with an ``id`` and a
    ``conversations`` list. The conversations are tokenized with the
    tokenizer's chat template; if that fails, the message contents are
    joined with blank lines and tokenized as plain text. Labels are a copy
    of the input ids, with padded positions set to -100.

    Running counters are kept in ``self.stats`` (``processed``, ``skipped``,
    ``total_tokens``).
    """

    def __init__(self, tokenizer, dataset_config):
        """Store the tokenizer/config and derive padding and length limits.

        Args:
            tokenizer: Object exposing ``pad_token_id``,
                ``apply_chat_template`` and ``__call__``.
            dataset_config: Mapping read for
                ``dataset.processing.max_seq_length`` (default 2048) and,
                during collation, ``validation.log_samples`` /
                ``validation.log_interval``.
        """
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Fall back to id 0 when the tokenizer defines no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
        logger.info("Using exact dataset structure without reformatting")

        # Check if we're on GPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _token_ids(encoded):
        """Return a fresh list of token ids from a tokenizer result.

        A tokenizer ``__call__`` returns a mapping holding ``input_ids``;
        ``apply_chat_template`` is expected to return the ids directly
        (NOTE(review): some tokenizer versions/settings may return a mapping
        there too, which this handles the same way).
        """
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def _tokenize(self, conversations, paper_id):
        """Tokenize one conversation, falling back to plain-text encoding."""
        try:
            # Let tokenizer handle the content with the model's chat template
            encoded = self.tokenizer.apply_chat_template(
                conversations,
                return_tensors=None,
                add_generation_prompt=False,
            )
        except Exception as chat_error:
            # Fallback if apply_chat_template fails
            logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")

            # Basic representation: message contents separated by blank lines.
            conversation_text = "".join(
                msg["content"] + "\n\n"
                for msg in conversations
                if isinstance(msg, dict) and "content" in msg
            )
            encoded = self.tokenizer(
                conversation_text,
                add_special_tokens=True,
                return_tensors=None,
            )
        return self._token_ids(encoded)

    def __call__(self, features):
        """Build one padded batch from a list of examples.

        Args:
            features: Iterable of examples, each providing ``id`` and
                ``conversations``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` as
            ``torch.long`` tensors padded to the longest sequence in the
            batch. When no example yields tokens, 1x1 zero tensors are
            returned for all three keys.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        validation_config = self.dataset_config.get("validation", {})
        log_samples = validation_config.get("log_samples", 3)

        for example in features:
            try:
                paper_id = example.get("id", "")

                # Conversations should already contain role and content.
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                inputs = self._tokenize(conversations, paper_id)

                # Apply length cap if needed (shouldn't be necessary for pre-audited data)
                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
                    inputs = inputs[:self.max_seq_length]

                if not inputs:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(inputs)
                batch["attention_mask"].append([1] * len(inputs))
                # For causal language modeling, labels are the same as inputs
                batch["labels"].append(list(inputs))

                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(inputs)

                # Debug logging for first few examples
                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']}:")
                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs)}")
                    logger.info(f"Conversation entries: {len(conversations)}")
            except Exception as e:
                # The example may not even be dict-like; do not let the
                # error report itself raise.
                example_id = example.get("id", "unknown") if hasattr(example, "get") else "unknown"
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example ID: {example_id}")
                self.stats["skipped"] += 1
                continue

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long),
            }

        # Right-pad every sequence to the longest one in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        fill_values = {
            "input_ids": self.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        tensors = {
            key: torch.tensor(
                [seq + [fill_values[key]] * (max_length - len(seq)) for seq in sequences],
                dtype=torch.long,
            )
            for key, sequences in batch.items()
        }

        # Log stats periodically; a zero/None interval disables the report
        # instead of dividing by zero.
        log_interval = validation_config.get("log_interval", 100)
        processed = self.stats["processed"]
        if log_interval and processed > 0 and processed % log_interval == 0:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                       f"skipped={self.stats['skipped']}, "
                       f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")

        return tensors

class LoggingCallback(TrainerCallback):
    """Trainer callback that reports progress and GPU memory via ``log_info``.

    * ``on_step_end`` logs the latest loss every 50 steps or 5 minutes and
      GPU memory every 15 minutes.
    * ``on_train_begin`` logs the main training parameters, dataset size and
      initial GPU memory.
    * ``on_train_end`` logs final GPU memory, step count and loss.
    """

    def __init__(self):
        """Start both the loss-log and memory-log timers at creation time."""
        self.last_log_time = time.time()
        self.last_memory_log_time = time.time()

    @staticmethod
    def _latest_loss(state):
        """Return the most recent ``loss`` in ``state.log_history`` or 'N/A'.

        The newest history entry is not always a training-loss entry, so the
        history is searched backwards instead of reading only the last item.
        """
        for entry in reversed(state.log_history):
            if 'loss' in entry:
                return entry['loss']
        return 'N/A'

    @staticmethod
    def _gpu_memory_summary(show_peak):
        """Format per-GPU memory usage in MB as a comma-separated string.

        Args:
            show_peak: If true, report ``allocated (max: peak allocated)``;
                otherwise report ``allocated/reserved``.
        """
        memory_info = []
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**2
            if show_peak:
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            else:
                reserved = torch.cuda.memory_reserved(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
        return ', '.join(memory_info)

    @staticmethod
    def _resolve_train_dataset(kwargs):
        """Find the training dataset without assuming a global ``trainer``.

        Prefers a module-level ``trainer`` object when one exists, then falls
        back to the ``train_dataloader`` passed to the callback, if any.
        """
        train_dataset = getattr(globals().get("trainer"), "train_dataset", None)
        if train_dataset is None:
            train_dataset = getattr(kwargs.get("train_dataloader"), "dataset", None)
        return train_dataset

    def on_step_end(self, args, state, control, **kwargs):
        """Log loss (every 50 steps / 5 min) and GPU memory (every 15 min)."""
        current_time = time.time()

        # Log loss every 50 steps or 5 minutes, whichever comes first
        if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
            if state.log_history:
                # Use simple formatting for better HF Space log compatibility
                log_info(f"Step {state.global_step}: Loss {self._latest_loss(state)}")
            else:
                log_info(f"Step {state.global_step}: No loss data available")
            self.last_log_time = current_time

        # Log memory usage every 15 minutes
        if current_time - self.last_memory_log_time > 900:
            if torch.cuda.is_available():
                log_info(f"Memory usage - {self._gpu_memory_summary(show_peak=False)}")
            # The timer restarts even without CUDA so the check stays cheap.
            self.last_memory_log_time = current_time

    def on_train_begin(self, args, state, control, **kwargs):
        """Log training parameters, dataset information and initial memory."""
        log_info("=== Training is starting ===")

        # Log important training parameters for visibility
        gpu_count = max(1, torch.cuda.device_count())
        effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * gpu_count
        log_info(f"Per device batch size: {args.per_device_train_batch_size}")
        log_info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
        log_info(f"Number of GPUs: {gpu_count}")
        log_info(f"Total effective batch size: {effective_batch_size}")
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")

        # Log dataset information; purely informational, so any failure
        # (e.g. a dataset without a length) is reported and ignored.
        train_dataset = self._resolve_train_dataset(kwargs)
        if train_dataset is not None:
            try:
                dataset_size = len(train_dataset)
                log_info(f"Dataset size: {dataset_size} examples")
                # Log first few prompt numbers to verify sequence
                prompt_numbers = []
                for i in range(min(5, dataset_size)):
                    sample = train_dataset[i]
                    if 'prompt_number' in sample:
                        prompt_numbers.append(sample['prompt_number'])
                if prompt_numbers:
                    log_info(f"First few prompt numbers: {prompt_numbers}")
            except Exception as e:
                log_info(f"Error accessing dataset samples: {e}")

        # Log memory information in compact format
        if torch.cuda.is_available():
            log_info(f"Initial memory usage - {self._gpu_memory_summary(show_peak=True)}")

    def on_train_end(self, args, state, control, **kwargs):
        """Log final GPU memory, the total step count and the last loss."""
        log_info("=== Training completed ===")
        if torch.cuda.is_available():
            log_info(f"Final memory usage - {self._gpu_memory_summary(show_peak=True)}")

        log_info(f"Total steps: {state.global_step}")
        log_info(f"Final loss: {self._latest_loss(state)}")

def custom_get_train_dataloader():
    """Build the training DataLoader that visits the dataset strictly in order.

    Reads the module-level ``dataset``, ``dataset_config``, ``training_args``,
    ``NUM_GPUS`` and ``data_collator``.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a ``SequentialSampler``.

    Raises:
        ValueError: If ``dataset_config["data_loading"]["shuffle"]`` is truthy.
    """
    log_info("Creating sequential dataloader to maintain original dataset order")

    # Sampler that walks the dataset in index order.
    ordered_sampler = torch.utils.data.SequentialSampler(dataset)

    # Shuffling contradicts the ordering guarantee, so refuse to continue.
    if dataset_config.get("data_loading", {}).get("shuffle", False):
        log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
        raise ValueError(
            "Dataset shuffling is enabled but sequential processing is required. "
            "Please disable shuffling in your configuration."
        )

    for note in (
        "Using SequentialSampler to guarantee original dataset order is preserved",
        "Data order preservation is critical for proper training sequence",
    ):
        log_info(note)

    # With CUDA disabled the per-device size is used as-is; otherwise it is
    # scaled by the GPU count, never dropping below 1.
    per_device = training_args.per_device_train_batch_size
    if getattr(training_args, "no_cuda", False):
        batch_size = per_device
    else:
        batch_size = max(per_device * max(1, NUM_GPUS), 1)

    log_info(f"Using sequential sampler with batch size {batch_size}")

    loader_options = {
        "batch_size": batch_size,
        "sampler": ordered_sampler,
        "collate_fn": data_collator,
        "drop_last": training_args.dataloader_drop_last,
        "num_workers": training_args.dataloader_num_workers,
        "pin_memory": training_args.dataloader_pin_memory,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)

def check_dependencies():
    """Check for critical dependencies and log useful guidance.

    * ``flash_attn`` is optional: its presence or absence is only logged.
    * CUDA availability is logged; a missing GPU is a warning, not an error.
    * ``unsloth`` is required: if it cannot be found, or the lookup itself
      fails, the function returns ``False``.

    Packages are located with ``importlib.util.find_spec`` so nothing is
    actually imported here.

    Returns:
        ``True`` when unsloth is installed, ``False`` otherwise.
    """
    import importlib.util

    # Check for flash attention without attempting import (optional).
    try:
        if importlib.util.find_spec("flash_attn") is not None:
            log_info("flash-attn found! Using Flash Attention for faster training.")
        else:
            log_info("flash-attn not found. Training will continue but may be slower.")
            log_info("To use flash attention, install: pip install flash-attn==2.5.2 --no-build-isolation")
    except Exception as e:
        log_info(f"Error checking for flash-attn: {e}")

    # Check for torch CUDA
    if not torch.cuda.is_available():
        log_info("WARNING: CUDA not available. Training will be extremely slow on CPU!")
    else:
        log_info(f"Found {torch.cuda.device_count()} CUDA devices")

    # Check for unsloth (required).
    try:
        unsloth_found = importlib.util.find_spec("unsloth") is not None
    except Exception as e:
        log_info(f"Error checking for unsloth: {e}")
        return False

    if not unsloth_found:
        log_info("CRITICAL: Unsloth not found. This pipeline requires Unsloth.")
        # The specifier is quoted so a shell does not treat ">" as a redirect.
        log_info('Install with: pip install "unsloth>=2024.3"')
        return False

    log_info("Unsloth found! Using Unsloth for optimized training.")
    return True

def main():
    """Run the training entry point and translate failures into an exit code.

    Parses arguments, loads environment variables and the configuration
    file, verifies dependencies and logs the detected hardware. Any exception
    is logged (with GPU memory usage when CUDA is available) rather than
    propagated.

    Returns:
        0 on success; 1 when a critical dependency is missing or an
        exception was raised.
    """
    try:
        # Initialize logging
        log_info("Starting Phi-4 training process")

        # Parse arguments
        args = parse_args()

        # Load environment variables
        load_env_variables()

        # Load config from file
        config = load_configs(args.config)

        # Extract specific configurations
        hardware_config = config.get("hardware", {})
        dataset_config = config.get("dataset", {})

        # Define multi_gpu_strategy early to prevent undefined errors
        multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
        log_info(f"Multi-GPU strategy: {multi_gpu_strategy}")

        # Check dependencies
        if not check_dependencies():
            log_info("Aborting due to missing critical dependencies")
            return 1

        # Log hardware info
        cuda_available = torch.cuda.is_available()
        num_gpus = torch.cuda.device_count() if cuda_available else 0
        log_info(f"Hardware: {num_gpus} GPUs detected" if cuda_available else "Hardware: CPU only")

        # Rest of training code would go here
        # ...

        return 0
    except Exception as e:
        log_info(f"Error in main training loop: {str(e)}")
        # Log CUDA memory if available
        if torch.cuda.is_available():
            try:
                memory_info = []
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**2
                    reserved = torch.cuda.memory_reserved(i) / 1024**2
                    memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
                log_info(f"GPU memory at failure: {', '.join(memory_info)}")
            except Exception as memory_error:
                # Diagnostics are best-effort, but Ctrl-C / SystemExit must
                # not be swallowed, hence no bare ``except``.
                log_info(f"Could not collect GPU memory info: {memory_error}")
        return 1

# Script entry point: run main() only when this file is executed directly and
# pass its integer return value to sys.exit() as the process exit status.
if __name__ == "__main__":
    import sys
    sys.exit(main())