Model Overview

  • Model Architecture: Meta-Llama-3.1
    • Input: Text
    • Output: Text
  • Supported Hardware Microarchitecture: AMD MI350/MI355
  • Preferred Operating System(s): Linux
  • Inference Engine: vLLM
  • Model Optimizer: AMD-Quark
    • Weight quantization: OCP MXFP4
    • Activation quantization: OCP MXFP4
    • KV cache quantization: OCP FP8
  • Calibration Dataset: Pile

This model is a quantized version of the Meta-Llama-3.1-405B-Instruct model, an auto-regressive language model that uses an optimized transformer architecture. For more information, see the original Meta-Llama-3.1-405B-Instruct model card. The MXFP4 model is quantized with AMD-Quark.
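
To make the data format concrete, the sketch below quantizes one 32-element block the way OCP MXFP4 defines it: FP4 (E2M1) elements sharing a single power-of-two (E8M0) scale. This is a simplified NumPy illustration of the OCP Microscaling format, not AMD-Quark's actual kernels; the scale-selection and rounding details here are assumptions.

import numpy as np

# Magnitudes representable by FP4 E2M1 (the sign is handled separately).
FP4_E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_quant_dequant(block):
    """Quantize-dequantize one 32-element MXFP4 block (simplified sketch)."""
    assert block.size == 32
    amax = np.abs(block).max()
    # Shared E8M0 scale: a power of two chosen so the largest element
    # lands near the E2M1 maximum of 6.0 (the emax of E2M1 is 2).
    exp = np.floor(np.log2(amax)) - 2 if amax > 0 else 0.0
    scale = 2.0 ** exp
    scaled = block / scale
    # Round each element to the nearest representable E2M1 magnitude.
    idx = np.abs(np.abs(scaled)[:, None] - FP4_E2M1[None, :]).argmin(axis=1)
    q = np.sign(scaled) * FP4_E2M1[idx]
    # Real kernels store the 4-bit codes plus the shared scale; here we
    # return the dequantized values so the rounding error can be inspected.
    return q * scale

x = np.random.randn(32).astype(np.float32)
print(np.abs(x - mxfp4_quant_dequant(x)).max())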

Model Quantization

This model was obtained by quantizing the weights and activations of Meta-Llama-3.1-405B-Instruct to MXFP4 and its KV caches to FP8, using the AutoSmoothQuant algorithm in AMD-Quark.

Quantization script:

# Run from a checkout of the AMD-Quark examples; $output_path is your chosen export directory.
cd Quark/examples/torch/language_modeling/llm_ptq/
python3 quantize_quark.py --model_dir "meta-llama/Meta-Llama-3.1-405B-Instruct" \
                          --model_attn_implementation "sdpa" \
                          --quant_scheme w_mxfp4_a_mxfp4 \
                          --kv_cache_dtype fp8 \
                          --quant_algo autosmoothquant \
                          --min_kv_scale 1.0 \
                          --model_export hf_format \
                          --output_dir $output_path \
                          --multi_gpu

Deployment

Use with vLLM

This model can be deployed efficiently using the vLLM backend.
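
A minimal offline-inference sketch with vLLM's Python API is shown below. The parallelism and KV-cache settings mirror the evaluation commands in the Reproduction section and should be adjusted for your hardware; the prompt and sampling parameters are illustrative.

from vllm import LLM, SamplingParams

# Settings mirror the evaluation commands below; adjust for your setup.
llm = LLM(
    model="amd/Llama-3.1-405B-Instruct-MXFP4-Preview",
    tensor_parallel_size=8,
    kv_cache_dtype="fp8",
    gpu_memory_utilization=0.85,
)
outputs = llm.generate(
    ["Briefly explain what MXFP4 quantization is."],
    SamplingParams(temperature=0.6, max_tokens=128),
)
print(outputs[0].outputs[0].text)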

Evaluation

The model was evaluated on MMLU and GSM8K_COT. Evaluation was conducted using the lm-evaluation-harness framework with the vLLM engine.

Accuracy

Benchmark                          Meta-Llama-3.1-405B-Instruct   Meta-Llama-3.1-405B-Instruct-MXFP4 (this model)   Recovery
MMLU (5-shot)                      87.63                          86.62                                             98.85%
GSM8K_COT (8-shot, strict-match)   96.51                          96.06                                             99.53%
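
Recovery is the quantized model's score as a percentage of the baseline score; for MMLU, 86.62 / 87.63 ≈ 98.85%.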

Reproduction

The results were obtained using the following commands:

MMLU

lm_eval \
    --model vllm \
    --model_args pretrained="amd/Llama-3.1-405B-Instruct-MXFP4-Preview",gpu_memory_utilization=0.85,tensor_parallel_size=8,kv_cache_dtype='fp8' \
    --tasks mmlu_llama \
    --fewshot_as_multiturn \
    --apply_chat_template \
    --num_fewshot 5 \
    --batch_size auto

GSM8K_COT

lm_eval \
    --model vllm \
    --model_args pretrained="amd/Llama-3.1-405B-Instruct-MXFP4-Preview",gpu_memory_utilization=0.85,tensor_parallel_size=8,kv_cache_dtype='fp8' \
    --tasks gsm8k_llama \
    --fewshot_as_multiturn \
    --apply_chat_template \
    --num_fewshot 8 \
    --batch_size auto

License

Modifications Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
