|
|
|
""" |
|
Understand and convert the MXFP4-Q4 format - FIXED |
|
""" |
|
|
|
import torch |
|
from safetensors import safe_open |
|
import numpy as np |
|
import json |
|
import logging |
|
from pathlib import Path |
|
from transformers import AutoConfig |
|
|
|
# Module-level logging: timestamped INFO messages; `logger` is used by all
# functions below instead of print() so output carries level/time context.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

logger = logging.getLogger(__name__)
|
|
|
def analyze_safetensors(model_path):
    """Inspect the safetensors shard files under *model_path* and log a summary.

    For the first shard found, logs the tensor count and a preview of the
    first five tensors (name, shape, dtype).  If the repo has a
    ``model.safetensors.index.json``, also logs the distribution of weight
    name suffixes (e.g. ``weight``, ``bias``, quantization side tensors).

    Args:
        model_path: Directory containing ``model-*.safetensors`` shards.
    """
    logger.info("π Analyzing safetensors files...")

    # FIX: sort the glob result — glob order is filesystem-dependent, so the
    # "first" shard was previously nondeterministic.
    model_files = sorted(Path(model_path).glob("model-*.safetensors"))
    logger.info(f"π¦ Found {len(model_files)} model files")

    if model_files:
        first_file = model_files[0]
        logger.info(f"π Analyzing: {first_file.name}")

        try:
            with safe_open(first_file, framework="pt") as f:
                keys = list(f.keys())
                logger.info(f"  Contains {len(keys)} tensors")

                # Preview at most five tensors to keep the log readable.
                # FIX: the original combined a [:5] slice with a redundant
                # `if i >= 4: break` inside the loop.
                for key in keys[:5]:
                    tensor = f.get_tensor(key)
                    logger.info(f"    {key}: shape {tuple(tensor.shape)}, dtype {tensor.dtype}")

                # FIX: only claim more tensors exist when they actually do —
                # the original logged this even for a file with exactly 5.
                if len(keys) > 5:
                    logger.info("    ... (more tensors in file)")

        except Exception as e:
            # Best-effort diagnostics: report and continue to the index summary.
            logger.error(f"β Error analyzing {first_file}: {e}")

    index_file = Path(model_path) / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file, 'r') as f:
            index_data = json.load(f)
        logger.info(f"π Total weights in index: {len(index_data['weight_map'])}")

        # Histogram of the final dotted component of each weight name,
        # e.g. "model.layers.0.mlp.weight" -> "weight".
        weight_types = {}
        for weight_name in index_data['weight_map'].keys():
            weight_type = weight_name.split('.')[-1] if '.' in weight_name else weight_name
            weight_types[weight_type] = weight_types.get(weight_type, 0) + 1

        logger.info("π Weight types distribution:")
        for wt, count in sorted(weight_types.items()):
            logger.info(f"  {wt}: {count}")
|
|
|
def check_quantization_method(model_path):
    """Report the quantization method recorded for the model at *model_path*.

    Tries the Transformers ``AutoConfig`` view first, then reads
    ``config.json`` directly — the direct read works even when Transformers
    does not recognize the model's custom quantization scheme.

    Args:
        model_path: Directory of the model (must contain ``config.json``
            for the direct check to run).
    """
    logger.info("π Checking quantization method...")

    # FIX: AutoConfig.from_pretrained can raise for models with custom or
    # unregistered quantization configs — exactly the models this script is
    # meant to diagnose.  Catch it so the direct config.json check still runs.
    try:
        config = AutoConfig.from_pretrained(model_path)
        if hasattr(config, 'quantization_config'):
            logger.info(f"π Quantization config: {config.quantization_config}")
        else:
            logger.info("π No quantization config found in model config")
    except Exception as e:
        logger.error(f"β AutoConfig could not load the model config: {e}")

    # FIX: guard against a missing config.json instead of crashing on open().
    config_file = Path(model_path) / "config.json"
    if config_file.exists():
        with open(config_file, 'r') as f:
            config_data = json.load(f)
        if 'quantization_config' in config_data:
            logger.info(f"π― Quantization method: {config_data['quantization_config']}")
        else:
            logger.info("βΉοΈ Model uses custom MXFP4-Q4 quantization (Apple MLX optimized)")
|
|
|
def _main():
    """Script entry point: analyze the local model and report its quantization."""
    model_path = "./my_model"

    analyze_safetensors(model_path)
    check_quantization_method(model_path)

    logger.info("\nπ‘ This model uses MXFP4-Q4 quantization optimized for Apple MLX")
    logger.info("  It requires custom loading rather than standard Transformers")
    logger.info("  Consider using the original MLX implementation from the model authors")


if __name__ == "__main__":
    _main()