viral-esm2-3b-hqq

This is a 4-bit HQQ quantized version of mahdi-b/viral-esm2-3b.

Quantization Details

  • Method: HQQ (Half-Quadratic Quantization)
  • Bits: 4
  • Group Size: 16
  • Compute dtype: float16
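
These settings can also be reproduced on the fly from the original checkpoint using transformers' built-in HqqConfig (a minimal sketch, assuming a recent transformers release with HQQ support and the hqq package installed; this downloads the full-precision weights first):

from transformers import AutoModelForMaskedLM, HqqConfig
import torch

# Same settings as this repo: 4-bit weights, group size 16
quant_config = HqqConfig(nbits=4, group_size=16)

# Quantize mahdi-b/viral-esm2-3b while loading (requires a CUDA GPU)
model = AutoModelForMaskedLM.from_pretrained(
    "mahdi-b/viral-esm2-3b",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)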

Usage

Easy loading (recommended):

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# trust_remote_code=True is required so the custom loading code bundled with this repo can run
model = AutoModelForMaskedLM.from_pretrained("mahdi-b/viral-esm2-3b-hqq", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("mahdi-b/viral-esm2-3b-hqq")

# Run a forward pass on an example protein sequence
inputs = tokenizer("MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTL", return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
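
The outputs expose per-position logits over the amino-acid vocabulary. As a minimal sketch (assuming the tokenizer's standard <mask> token handling), one way to use them is to mask a residue and let the model fill it in:

# logits shape: [batch, sequence_length, vocab_size]
sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTL"
masked = sequence[:10] + tokenizer.mask_token + sequence[11:]  # mask the 11th residue

inputs = tokenizer(masked, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits

# Predict the most likely amino acid at the masked position
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
print(tokenizer.decode(logits[mask_pos].argmax(dim=-1)))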

Manual loading:

from transformers import AutoModelForMaskedLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import hf_hub_download
import torch

# Download weights
weights_path = hf_hub_download("mahdi-b/viral-esm2-3b-hqq", "pytorch_model.bin")

# Load the original full-precision model, then quantize it in place with the same settings
model = AutoModelForMaskedLM.from_pretrained("mahdi-b/viral-esm2-3b", torch_dtype=torch.float16)
quant_cfg = BaseQuantizeConfig(nbits=4, group_size=16)
AutoHQQHFModel.quantize_model(model, quant_config=quant_cfg, compute_dtype=torch.float16, device={"": torch.device("cuda:0")})

# Load weights and move to GPU
model.load_state_dict(torch.load(weights_path, map_location="cuda:0"))
model = model.to("cuda:0")

# Ready to use
tokenizer = AutoTokenizer.from_pretrained("mahdi-b/viral-esm2-3b-hqq")
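
As a quick sanity check after manual loading (a sketch; the numbers are approximate and include allocator overhead), you can confirm that GPU memory use is in line with the quantized size reported under Model Size and that a forward pass runs:

# GPU memory in use should be roughly the ~2.8GB quantized size
print(f"Allocated on cuda:0: {torch.cuda.memory_allocated('cuda:0') / 1e9:.2f} GB")

# Quick forward pass to confirm the weights loaded correctly
inputs = tokenizer("MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTL", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    print(model(**inputs).logits.shape)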

Model Size

  • Original model: ~11GB (float16)
  • Quantized model: ~2.8GB (4-bit)

Requirements

  • transformers
  • torch
  • hqq
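
These can typically be installed from PyPI:

pip install transformers torch hqq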