import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from PIL import Image
from typing import Optional
import torchvision.transforms as transforms
import os
import json

class InitialOnlyImageTagger(nn.Module):
    """
    A lightweight version of ImageTagger that only includes the backbone and initial classifier.
    This model uses significantly less VRAM than the full model.
    """
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 dropout=0.1, pretrained=True):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }
        
        # Core model config
        self.dataset = dataset
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension
        
        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported backbone: {model_name}")
        
        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )
        
        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
        
    @property
    def debug(self):
        return self._flags['debug']
    
    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value
        
    @property
    def model_stats(self):
        return self._flags['model_stats']
    
    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value
        
    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at path: {image_path}")
       
        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
       
        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')
               
                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height
               
                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)
               
                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
               
                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))
               
                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise RuntimeError(f"Error processing {image_path}: {e}") from e
    
    def forward(self, x):
        """Forward pass with only the initial predictions"""
        # Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)
        
        # Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        
        # For API compatibility with the full model, return the same predictions twice
        return initial_preds, initial_preds
    
    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)
        
        # Move to the same device and dtype (e.g. half precision) as the model
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)
        
        # Run inference
        with torch.no_grad():
            initial_preds, _ = self.forward(img_tensor)
            
            # Apply sigmoid to get probabilities
            initial_probs = torch.sigmoid(initial_preds)
            
            # Apply thresholds
            if category_thresholds:
                # Create binary prediction tensors
                initial_binary = torch.zeros_like(initial_probs)
                
                # Apply thresholds by category
                for category, cat_threshold in category_thresholds.items():
                    # Create a mask for tags in this category
                    category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)
                    
                    # Find indices for this category
                    for tag_idx in range(initial_probs.size(-1)):
                        try:
                            _, tag_category = self.dataset.get_tag_info(tag_idx)
                            if tag_category == category:
                                category_mask[:, tag_idx] = True
                        except Exception:
                            continue
                    
                    # Apply threshold only to tags in this category
                    cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                    initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)
                
                predictions = initial_binary
            else:
                # Use the same threshold for all tags
                threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
                predictions = (initial_probs >= threshold_tensor).to(dtype)
            
            # Return the same probabilities for both initial and refined for API compatibility
            return {
                'initial_probabilities': initial_probs,
                'refined_probabilities': initial_probs,  # Same as initial for compatibility
                'predictions': predictions
            }
            
    def get_tags_from_predictions(self, predictions, include_probabilities=True,
                                  probabilities=None):
        """
        Convert model predictions to human-readable tags grouped by category.
        Pass the probability tensor returned by predict() as `probabilities`
        to report real confidences; otherwise the (binary) prediction values
        are reported.
        """
        # Remove batch dimension if present
        if predictions.dim() > 1:
            predictions = predictions[0]
        if probabilities is not None and probabilities.dim() > 1:
            probabilities = probabilities[0]
            
        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()
        
        # Group by category
        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)
            
            if category not in result:
                result[category] = []
            
            if include_probabilities:
                source = probabilities if probabilities is not None else predictions
                result[category].append((tag_name, source[idx].item()))
            else:
                result[category].append(tag_name)
        
        # Sort tags by probability within each category
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)
        
        return result
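
# Illustrative usage sketch for the lightweight model (not executed on import).
# The checkpoint path and format are assumptions; adapt them to your export.
def _demo_initial_only_tagger(image_path: str, dataset: "TagDataset",
                              checkpoint_path: str) -> dict:
    """Sketch: end-to-end inference with InitialOnlyImageTagger."""
    tagger = InitialOnlyImageTagger(total_tags=dataset.total_tags,
                                    dataset=dataset, pretrained=False)
    tagger.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    tagger.eval()
    results = tagger.predict(image_path, threshold=0.325)
    # Report the initial-head probability for each positive tag
    return tagger.get_tags_from_predictions(
        results['predictions'], probabilities=results['initial_probabilities'])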

class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"
        
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        
        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)
        
        self.scale = self.head_dim ** -0.5
        self.debug = False

    def _debug_print(self, name, tensor):
        """Debug helper"""
        if self.debug:
            print(f"\n{name}:")
            print(f"Shape: {tensor.shape}")
            print(f"Device: {tensor.device}")
            print(f"Dtype: {tensor.dtype}")
            if tensor.is_floating_point():
                with torch.no_grad():
                    print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"Mean: {tensor.mean().item():.3f}")
                    print(f"Std: {tensor.std().item():.3f}")
    
    def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape input tensor for flash attention format"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.transpose(1, 2)  # [B, H, S, D]
        return x.contiguous()

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, 
                value: Optional[torch.Tensor] = None, 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with flash attention"""
        if self.debug:
            print("\nFlashAttention Forward Pass")
        
        batch_size = query.size(0)
        
        # Default to self-attention: a missing key falls back to the query, and
        # a missing value falls back to the key, so two-argument calls perform
        # standard cross-attention over the key/value source
        key = query if key is None else key
        value = key if value is None else value
        
        # Project inputs
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        
        if self.debug:
            self._debug_print("Query before reshape", q)
        
        # Reshape for attention [B, H, S, D]
        q = self._reshape_for_flash(q)
        k = self._reshape_for_flash(k)
        v = self._reshape_for_flash(v)
        
        if self.debug:
            self._debug_print("Query after reshape", q)
        
        # Handle masking
        if mask is not None:
            # First convert mask to proper shape based on input dimensionality
            if mask.dim() == 2:  # [B, S]
                mask = mask.view(batch_size, 1, -1, 1)
            elif mask.dim() == 3:  # [B, S, S]
                mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
            elif mask.dim() == 5:  # [B, 1, S, S, S]
                mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))
                
            # Ensure mask is float16 if we're using float16
            mask = mask.to(q.dtype)
            
            if self.debug:
                self._debug_print("Prepared mask", mask)
                print(f"q shape: {q.shape}, mask shape: {mask.shape}")
            
            # Create attention mask that covers the full sequence length
            seq_len = q.size(2)
            if mask.size(-1) != seq_len:
                # Pad or trim mask to match sequence length
                new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
                                    device=mask.device, dtype=mask.dtype)
                min_len = min(seq_len, mask.size(-1))
                new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
                mask = new_mask
            
            # Create a key padding mask; padded key/value positions are zeroed
            # below, which approximates additive masking of the attention scores
            key_padding_mask = mask.squeeze(1).sum(-1) > 0
            key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)
            
            # Apply the key padding mask
            k = k * key_padding_mask
            v = v * key_padding_mask
        
        if self.debug:
            self._debug_print("Query before attention", q)
            self._debug_print("Key before attention", k)
            self._debug_print("Value before attention", v)
        
        # Run scaled dot-product attention; PyTorch dispatches to a FlashAttention
        # kernel when one is available for the given device and dtype. The default
        # softmax scaling, 1/sqrt(head_dim), equals self.scale.
        dropout_p = self.dropout if self.training else 0.0
        output = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        
        if self.debug:
            self._debug_print("Output after attention", output)
        
        # Reshape output [B, H, S, D] -> [B, S, H, D] -> [B, S, D]
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.dim)
        
        # Final projection
        output = self.out_proj(output)
        
        if self.debug:
            self._debug_print("Final output", output)
        
        return output
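
# Minimal shape check for FlashAttention (illustrative sizes; runs on CPU):
def _demo_flash_attention() -> None:
    attn = FlashAttention(dim=1280, num_heads=16).eval()
    x = torch.randn(2, 8, 1280)    # [batch, seq_len, dim]
    out = attn(x)                  # self-attention: key/value default to query
    assert out.shape == x.shape
    q = torch.randn(2, 4, 1280)    # queries with a different sequence length
    out = attn(q, key=x, value=x)  # cross-attention over x
    assert out.shape == q.shape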
    
class OptimizedTagEmbedding(nn.Module):
    def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Single shared embedding for all tags
        self.embedding = nn.Embedding(num_tags, embedding_dim)
        self.attention = FlashAttention(embedding_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        
        # Single importance weighting for all tags
        self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)
        
        # Projection layers for unified tag context
        self.context_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )
        
        self.importance_scale = nn.Parameter(torch.tensor(0.1))
        self.context_scale = nn.Parameter(torch.tensor(1.0))
        self.debug = False
        
    def _debug_print(self, name, tensor, extra_info=None):
        """Memory efficient debug printing with type handling"""
        if self.debug:
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            if isinstance(tensor, torch.Tensor):
                with torch.no_grad():
                    print(f"- Device: {tensor.device}")
                    print(f"- Dtype: {tensor.dtype}")
                    
                    # Convert to float32 for statistics if needed
                    if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
                        calc_tensor = tensor.float()
                    else:
                        calc_tensor = tensor
                        
                    try:
                        min_val = calc_tensor.min().item()
                        max_val = calc_tensor.max().item()
                        mean_val = calc_tensor.mean().item()
                        std_val = calc_tensor.std().item()
                        norm_val = torch.norm(calc_tensor).item()
                        
                        print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
                        print(f"- Mean: {mean_val:.3f}")
                        print(f"- Std: {std_val:.3f}")
                        print(f"- L2 Norm: {norm_val:.3f}")
                        
                        if extra_info:
                            print(f"- Additional info: {extra_info}")
                    except Exception as e:
                        print(f"- Could not compute statistics: {str(e)}")
    
    def _debug_tensor(self, name, tensor):
        """Debug helper with dtype-specific analysis"""
        if self.debug and isinstance(tensor, torch.Tensor):
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            print(f"- Device: {tensor.device}")
            print(f"- Dtype: {tensor.dtype}")
            with torch.no_grad():
                has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
                has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
                print(f"- Contains NaN: {has_nan}")
                print(f"- Contains Inf: {has_inf}")
                
                # Different stats for different dtypes
                if tensor.is_floating_point():
                    print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"- Mean: {tensor.mean().item():.3f}")
                    print(f"- Std: {tensor.std().item():.3f}")
                else:
                    # For integer tensors
                    print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
                    print(f"- Unique values: {tensor.unique().numel()}")
    
    def _process_category(self, indices, masks):
        """Process a single category of tags"""
        # Get embeddings for this category
        embeddings = self.embedding(indices)
        
        if self.debug:
            self._debug_tensor("Category embeddings", embeddings)
        
        # Apply importance weights
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[indices].unsqueeze(-1)
        
        # Apply and normalize
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)
        
        # Apply attention if we have more than one tag
        if embeddings.size(1) > 1:
            if masks is not None:
                attention_mask = torch.einsum('bi,bj->bij', masks, masks)
                attended = self.attention(embeddings, mask=attention_mask)
            else:
                attended = self.attention(embeddings)
            embeddings = self.norm2(attended)
        
        # Pool embeddings with masking
        if masks is not None:
            masked_embeddings = embeddings * masks.unsqueeze(-1)
            pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            pooled = embeddings.mean(dim=1)
            
        return pooled, embeddings
    
    def forward(self, tag_indices_dict, tag_masks_dict=None):
        """
        Process all tags in a unified embedding space
        Args:
            tag_indices_dict: dict of {category: tensor of indices}
            tag_masks_dict: dict of {category: tensor of masks}
        """
        if self.debug:
            print("\nOptimizedTagEmbedding Forward Pass")
        
        # Concatenate all indices and masks
        all_indices = []
        all_masks = []
        batch_size = None
        
        for category, indices in tag_indices_dict.items():
            if batch_size is None:
                batch_size = indices.size(0)
            all_indices.append(indices)
            if tag_masks_dict:
                all_masks.append(tag_masks_dict[category])
        
        # Concatenate along the sequence dimension
        combined_indices = torch.cat(all_indices, dim=1)  # [B, total_seq_len]
        if tag_masks_dict:
            combined_masks = torch.cat(all_masks, dim=1)  # [B, total_seq_len]
        
        if self.debug:
            self._debug_tensor("Combined indices", combined_indices)
            if tag_masks_dict:
                self._debug_tensor("Combined masks", combined_masks)
        
        # Get embeddings for all tags using shared embedding
        embeddings = self.embedding(combined_indices)  # [B, total_seq_len, D]
        
        if self.debug:
            self._debug_tensor("Base embeddings", embeddings)
        
        # Apply unified importance weighting
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[combined_indices].unsqueeze(-1)
        
        # Apply and normalize importance weights
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)
        
        if self.debug:
            self._debug_tensor("Weighted embeddings", embeddings)
        
        # Apply attention across all tags together
        if tag_masks_dict:
            attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
            attended = self.attention(embeddings, mask=attention_mask)
        else:
            attended = self.attention(embeddings)
        
        attended = self.norm2(attended)
        
        if self.debug:
            self._debug_tensor("Attended embeddings", attended)
        
        # Global pooling with masking
        if tag_masks_dict:
            masked_embeddings = attended * combined_masks.unsqueeze(-1)
            tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            tag_context = attended.mean(dim=1)
        
        # Project and scale context
        tag_context = self.context_proj(tag_context)
        context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
        tag_context = tag_context * context_scale
        
        if self.debug:
            self._debug_tensor("Final tag context", tag_context)
        
        return tag_context, attended
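
# Sketch of OptimizedTagEmbedding's dict-of-categories interface (category
# names and sizes below are arbitrary placeholders):
def _demo_tag_embedding() -> None:
    embed = OptimizedTagEmbedding(num_tags=1000, embedding_dim=1280).eval()
    indices = {'general': torch.randint(0, 1000, (2, 4)),
               'character': torch.randint(0, 1000, (2, 2))}
    masks = {'general': torch.ones(2, 4), 'character': torch.ones(2, 2)}
    tag_context, attended = embed(indices, masks)
    assert tag_context.shape == (2, 1280)   # pooled context per image
    assert attended.shape == (2, 6, 1280)   # per-tag attended embeddings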

class TagDataset:
    """Lightweight dataset wrapper for inference only"""
    def __init__(self, total_tags, idx_to_tag, tag_to_category):
        self.total_tags = total_tags
        # JSON object keys are strings, so normalize the index keys to ints
        self.idx_to_tag = {int(k): v for k, v in idx_to_tag.items()}
        self.tag_to_category = tag_to_category
    
    def get_tag_info(self, idx):
        """Get tag name and category for a given index"""
        tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
        category = self.tag_to_category.get(tag_name, "general")
        return tag_name, category
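
def load_tag_dataset(metadata_path: str) -> TagDataset:
    """Convenience sketch: build a TagDataset from an exported metadata.json,
    using the same keys that load_model() below reads."""
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    return TagDataset(
        total_tags=metadata['total_tags'],
        idx_to_tag=metadata['idx_to_tag'],
        tag_to_category=metadata['tag_to_category']
    )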

class ImageTagger(nn.Module):
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 num_heads=16, dropout=0.1, pretrained=True,
                 tag_context_size=256):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }
        
        # Core model config
        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension
        
        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported backbone: {model_name}")
        
        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )
        
        # Tag embeddings at full dimension
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Improved cross attention projection
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )
        
        # Cross attention at full dimension
        self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier with improved bottleneck
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),  # Doubled input size for residual
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )
        
        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
        
    def _get_selected_tags(self, logits):
        """Select top-K tags based on prediction confidence"""
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits)
        
        # Get top-K predictions for each image in the batch
        topk_values, topk_indices = torch.topk(
            probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
        )
        
        return topk_indices, topk_values
    
    @property
    def debug(self):
        return self._flags['debug']
    
    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value
        
    @property
    def model_stats(self):
        return self._flags['model_stats']
    
    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value
        
    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at path: {image_path}")
       
        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
       
        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')
               
                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height
               
                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)
               
                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
               
                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))
               
                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise RuntimeError(f"Error processing {image_path}: {e}") from e
    
    def forward(self, x):
        """Forward pass with simplified feature handling"""
        
        # 1. Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)
        
        # 2. Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        
        # 3. Tag Selection & Embedding (simplified)
        pred_tag_indices, _ = self._get_selected_tags(initial_preds)
        tag_embeddings = self.tag_embedding(pred_tag_indices)
        
        # 4. Self-Attention on Tags
        attended_tags = self.tag_attention(tag_embeddings)
        attended_tags = self.tag_norm(attended_tags)
        
        # 5. Cross-Attention between Features and Tags
        features_proj = self.cross_proj(features)
        features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)
        
        cross_attended = self.cross_attention(features_expanded, attended_tags)
        cross_attended = self.cross_norm(cross_attended)
        
        # 6. Feature Fusion with Residual Connection
        fused_features = cross_attended.mean(dim=1)  # Average across tag dimension
        # Concatenate original and attended features
        combined_features = torch.cat([features, fused_features], dim=-1)
        
        # 7. Refined Predictions
        refined_logits = self.refined_classifier(combined_features)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
        
        # Return both prediction sets
        return initial_preds, refined_preds
    
    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)
        
        # Move to the same device and dtype (e.g. half precision) as the model
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)
        
        # Run inference
        with torch.no_grad():
            initial_preds, refined_preds = self.forward(img_tensor)
            
            # Apply sigmoid to get probabilities
            initial_probs = torch.sigmoid(initial_preds)
            refined_probs = torch.sigmoid(refined_preds)
            
            # Apply thresholds
            if category_thresholds:
                # Create binary prediction tensors
                refined_binary = torch.zeros_like(refined_probs)
                
                # Apply thresholds by category
                for category, cat_threshold in category_thresholds.items():
                    # Create a mask for tags in this category
                    category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)
                    
                    # Find indices for this category
                    for tag_idx in range(refined_probs.size(-1)):
                        try:
                            _, tag_category = self.dataset.get_tag_info(tag_idx)
                            if tag_category == category:
                                category_mask[:, tag_idx] = True
                        except Exception:
                            continue
                    
                    # Apply threshold only to tags in this category - ensure dtype consistency
                    cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                    refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)
                
                predictions = refined_binary
            else:
                # Use the same threshold for all tags
                threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
                predictions = (refined_probs >= threshold_tensor).to(dtype)
            
            # Return both probabilities and thresholded predictions
            return {
                'initial_probabilities': initial_probs,
                'refined_probabilities': refined_probs,
                'predictions': predictions
            }
    
    def get_tags_from_predictions(self, predictions, include_probabilities=True,
                                  probabilities=None):
        """
        Convert model predictions to human-readable tags grouped by category.
        Pass the probability tensor returned by predict() as `probabilities`
        to report real confidences; otherwise the (binary) prediction values
        are reported.
        """
        # Remove batch dimension if present
        if predictions.dim() > 1:
            predictions = predictions[0]
        if probabilities is not None and probabilities.dim() > 1:
            probabilities = probabilities[0]
            
        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()
        
        # Group by category
        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)
            
            if category not in result:
                result[category] = []
            
            if include_probabilities:
                source = probabilities if probabilities is not None else predictions
                result[category].append((tag_name, source[idx].item()))
            else:
                result[category].append(tag_name)
        
        # Sort tags by probability within each category
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)
        
        return result
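
# Smoke-test sketch for the full model's two-stage forward pass (random
# weights, random input, synthetic tag names; for illustration only):
def _demo_image_tagger_forward() -> None:
    dataset = TagDataset(total_tags=100,
                         idx_to_tag={i: f"tag_{i}" for i in range(100)},
                         tag_to_category={})
    model = ImageTagger(total_tags=100, dataset=dataset,
                        pretrained=False, tag_context_size=16).eval()
    with torch.no_grad():
        initial_logits, refined_logits = model(torch.randn(1, 3, 512, 512))
    assert initial_logits.shape == refined_logits.shape == (1, 100)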

def load_model(model_dir, device='cuda'):
    """Load model with better error handling and warnings"""
    print(f"Loading model from {model_dir}")
    
    try:
        # Load metadata
        metadata_path = os.path.join(model_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
            
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        
        # Load model info
        model_info_path = os.path.join(model_dir, "model_info_initial_only.json") 
        if os.path.exists(model_info_path):
            with open(model_info_path, 'r') as f:
                model_info = json.load(f)
        else:
            print("WARNING: Model info file not found, using default settings")
            model_info = {
                "tag_context_size": 256,
                "num_heads": 16,
                "precision": "float16"
            }
        
        # Create dataset wrapper
        dataset = TagDataset(
            total_tags=metadata['total_tags'],
            idx_to_tag=metadata['idx_to_tag'],
            tag_to_category=metadata['tag_to_category']
        )
        
        # Initialize model with exact settings from model_info
        model = ImageTagger(
            total_tags=metadata['total_tags'],
            dataset=dataset,
            num_heads=model_info.get('num_heads', 16),
            tag_context_size=model_info.get('tag_context_size', 256),
            pretrained=False
        )
        
        # Load weights
        state_dict_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")
            
        state_dict = torch.load(state_dict_path, map_location=device)
        
        # First try strict loading
        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model state dict loaded with strict=True successfully")
        except Exception as e:
            print(f"! Strict loading failed: {str(e)}")
            print("Attempting non-strict loading...")
            
            # Try non-strict loading
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            
            print(f"Non-strict loading completed with:")
            print(f"- {len(missing_keys)} missing keys")
            print(f"- {len(unexpected_keys)} unexpected keys")
            
            if len(missing_keys) > 0:
                print(f"Sample missing keys: {missing_keys[:5]}")
            if len(unexpected_keys) > 0:
                print(f"Sample unexpected keys: {unexpected_keys[:5]}")
        
        # Move model to device
        model = model.to(device)
        
        # Set to half precision if needed
        if model_info.get('precision') == 'float16':
            model = model.half()
            print("✓ Model converted to half precision")
        
        # Set to eval mode
        model.eval()
        print("✓ Model set to evaluation mode")
        
        # Verify parameter dtype
        param_dtype = next(model.parameters()).dtype
        print(f"✓ Model loaded with precision: {param_dtype}")
        
        return model, dataset
    
    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

# Example usage
if __name__ == "__main__":
    import sys
    
    # Get model directory from command line or use default
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"
    
    # Load model
    model, dataset = load_model(model_dir)
    
    # Optionally load category thresholds (assumes a thresholds.json exported
    # alongside the model; adjust the filename to match your export)
    thresholds = None
    thresholds_path = os.path.join(model_dir, "thresholds.json")
    if os.path.exists(thresholds_path):
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)
    
    # Display info
    print(f"\nModel information:")
    print(f"  Total tags: {dataset.total_tags}")
    print(f"  Device: {next(model.parameters()).device}")
    print(f"  Precision: {next(model.parameters()).dtype}")
    
    # Test on an image if provided
    if len(sys.argv) > 2:
        image_path = sys.argv[2]
        print(f"\nRunning inference on {image_path}")
        
        # Use category thresholds if available
        if thresholds and 'categories' in thresholds:
            category_thresholds = {cat: opt['balanced']['threshold'] 
                              for cat, opt in thresholds['categories'].items()}
            results = model.predict(image_path, category_thresholds=category_thresholds)
        else:
            results = model.predict(image_path)
        
        # Get tags (report the refined probability for each positive tag)
        tags = model.get_tags_from_predictions(
            results['predictions'],
            probabilities=results['refined_probabilities']
        )
        
        # Print tags by category
        print("\nPredicted tags:")
        for category, category_tags in tags.items():
            print(f"\n{category.capitalize()}:")
            for tag, prob in category_tags:
                print(f"  {tag}: {prob:.3f}")