#!/usr/bin/env python3
"""
Understand the MXFP4-Q4 quantization format used by this checkpoint.
"""
import json
import logging
from pathlib import Path

import torch  # backend required by safe_open(..., framework="pt")
from safetensors import safe_open
from transformers import AutoConfig

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def analyze_safetensors(model_path):
    """Analyze the safetensors files to understand the format."""
    logger.info("πŸ” Analyzing safetensors files...")

    model_files = sorted(Path(model_path).glob("model-*.safetensors"))
    logger.info(f"πŸ“¦ Found {len(model_files)} model files")

    # Inspect the first shard to understand the tensor layout
    if model_files:
        first_file = model_files[0]
        logger.info(f"πŸ“„ Analyzing: {first_file.name}")
        try:
            with safe_open(str(first_file), framework="pt") as f:
                keys = list(f.keys())
                logger.info(f"   Contains {len(keys)} tensors")
                # Show the first few tensors with their shape and dtype
                for key in keys[:5]:
                    tensor = f.get_tensor(key)
                    logger.info(f"   {key}: shape {tuple(tensor.shape)}, dtype {tensor.dtype}")
                if len(keys) > 5:
                    logger.info("   ... (more tensors in file)")
        except Exception as e:
            logger.error(f"❌ Error analyzing {first_file}: {e}")

    # The index file maps every weight name to the shard that stores it
    index_file = Path(model_path) / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file, 'r') as f:
            index_data = json.load(f)
        logger.info(f"πŸ“‹ Total weights in index: {len(index_data['weight_map'])}")

        # Count weights by their final name component (weight, bias, scales, ...)
        weight_types = {}
        for weight_name in index_data['weight_map']:
            weight_type = weight_name.split('.')[-1]
            weight_types[weight_type] = weight_types.get(weight_type, 0) + 1

        logger.info("πŸ“Š Weight types distribution:")
        for wt, count in sorted(weight_types.items()):
            logger.info(f"   {wt}: {count}")
def check_quantization_method(model_path):
    """Check which quantization method the checkpoint declares."""
    logger.info("πŸ” Checking quantization method...")

    # First ask transformers' parsed config for quantization info
    config = AutoConfig.from_pretrained(model_path)
    if hasattr(config, 'quantization_config'):
        logger.info(f"πŸ“Š Quantization config: {config.quantization_config}")
    else:
        logger.info("πŸ“Š No quantization config found in model config")

    # Also read the raw config.json for any quantization metadata
    config_file = Path(model_path) / "config.json"
    if config_file.exists():
        with open(config_file, 'r') as f:
            config_data = json.load(f)
        if 'quantization_config' in config_data:
            logger.info(f"🎯 Quantization method: {config_data['quantization_config']}")
        else:
            logger.info("ℹ️ Model uses custom MXFP4-Q4 quantization (Apple MLX optimized)")

if __name__ == "__main__":
    model_path = "./my_model"
    analyze_safetensors(model_path)
    check_quantization_method(model_path)

    logger.info("\nπŸ’‘ This model uses MXFP4-Q4 quantization optimized for Apple MLX")
    logger.info("   It requires custom loading rather than standard Transformers")
    logger.info("   Consider using the original MLX implementation from the model authors")