# meta-llama_Llama-3.2-3B-Instruct-TEQ-int4-gs128-sym / meta-llama_Llama-3.2-3B-Instruct-TEQ-int4-gs128-sym.py
import os
import sys
import time
import random
import torch
from collections import UserDict
from packaging.version import Version
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.common import logger
from neural_compressor.torch.utils import is_hpex_available, get_torch_version
# ====== utils.py content inlined and fixed ======
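# DataloaderPreprocessor collects up to `nsamples` calibration batches from an existing
# dataloader. With use_max_length=False it truncates over-long sequences to a random
# max_seq_length window; with use_max_length=True it keeps only batches that can supply
# exactly max_seq_length tokens and skips shorter ones. It is inlined from utils.py and
# is not invoked by the hardcoded flow further below.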
class DataloaderPreprocessor:
def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
self.dataloader_original = dataloader_original
self.use_max_length = use_max_length
self.max_seq_length = max_seq_length
self.nsamples = nsamples
self.dataloader = []
self.is_ready = False
def get_prepared_dataloader(self):
if not self.is_ready:
self.prepare_dataloader()
return self.dataloader
def prepare_dataloader(self):
if self.use_max_length:
self.obtain_first_n_samples_fulllength()
else:
self.obtain_first_n_samples()
self.is_ready = True
def obtain_first_n_samples(self, seed=0):
"""Get first nsample data as the real calibration dataset."""
self.dataloader.clear()
random.seed(seed)
for batch in self.dataloader_original:
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collect {self.nsamples} calibration samples.")
break
# list, tuple
            if isinstance(batch, (list, tuple)):
if batch[0].shape[-1] > self.max_seq_length:
i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1)
j = i + self.max_seq_length
batch_final = []
for item in batch:
if isinstance(item, torch.Tensor) and item.ndim == 2:
batch_final.append(item[:, i:j])
else:
batch_final.append(item)
else:
batch_final = batch[:]
# dict
elif isinstance(batch, dict):
try:
length = batch["input_ids"].shape[-1]
except Exception:
logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
continue
batch_final = {}
if length > self.max_seq_length:
i = random.randint(0, length - self.max_seq_length - 1)
j = i + self.max_seq_length
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j]
else:
batch_final[key] = batch[key]
else:
batch_final = batch
# tensor
else:
if batch.shape[-1] > self.max_seq_length:
i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1)
j = i + self.max_seq_length
batch_final = batch[:, i:j]
else:
batch_final = batch
self.dataloader.append(batch_final)
if len(self.dataloader) < self.nsamples:
logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.")
def obtain_first_n_samples_fulllength(self, seed=0):
self.dataloader.clear()
random.seed(seed)
unified_length = self.max_seq_length
for batch in self.dataloader_original:
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collect {self.nsamples} calibration samples.")
break
# list & tuple
            if isinstance(batch, (list, tuple)):
if batch[0].shape[-1] == unified_length:
batch_final = batch[:]
elif batch[0].shape[-1] > unified_length:
i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
j = i + unified_length
batch_final = []
for item in batch:
if isinstance(item, torch.Tensor) and item.ndim == 2:
batch_final.append(item[:, i:j])
else:
batch_final.append(item)
else:
continue
# dict
elif isinstance(batch, dict):
try:
length = batch["input_ids"].shape[-1]
except Exception:
logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
continue
batch_final = {}
if length == self.max_seq_length:
batch_final = batch
elif length > self.max_seq_length:
i = random.randint(0, length - self.max_seq_length - 1)
j = i + self.max_seq_length
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j]
else:
batch_final[key] = batch[key]
else:
continue
# tensor
else:
if batch.shape[-1] == unified_length:
batch_final = batch
elif batch.shape[-1] > unified_length:
i = random.randint(0, batch.shape[-1] - unified_length - 1)
j = i + unified_length
batch_final = batch[:, i:j]
else:
continue
self.dataloader.append(batch_final)
if len(self.dataloader) < self.nsamples:
            logger.warning(
                f"Tried to collect {self.nsamples} samples with fixed length {unified_length}, "
                f"but only {len(self.dataloader)} were found. Please use a smaller 'max_seq_length' value."
            )
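# get_example_inputs pulls the first batch from the dataloader, moves it to the model's
# device, and returns it in a format suitable for tracing (tuple, dict, or plain tensor,
# depending on how the dataloader yields data).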
def get_example_inputs(model, dataloader):
version = get_torch_version()
from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
if dataloader is None:
return None
device = next(model.parameters()).device
try:
for idx, (input, label) in enumerate(dataloader):
input = move_input_to_device(input, device)
if isinstance(input, (dict, UserDict)):
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
if "label" in input.keys():
input.pop("label")
if version.release <= Version("2.0.1").release:
return tuple(input.values())
else:
return dict(input)
if isinstance(input, (list, tuple)):
return tuple(input)
if isinstance(input, torch.Tensor):
return input
break
except Exception as e:
for idx, input in enumerate(dataloader):
input = move_input_to_device(input, device)
if isinstance(input, (dict, UserDict)):
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
if "label" in input.keys():
input.pop("label")
if version.release <= Version("2.0.1").release:
return tuple(input.values())
else:
return dict(input)
            if isinstance(input, (list, tuple)):
return tuple(input)
if isinstance(input, torch.Tensor):
return input
break
if idx == 0:
        assert False, "Please check the example_inputs format."
# ====== End of utils.py content ======
# ====== Hardcoded arguments ======
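# Args mirrors the command-line options of the Intel Neural Compressor weight-only
# quantization example, hardcoded here for a TEQ int4, group-size-128, symmetric recipe.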
class Args:
model = "meta-llama/Llama-3.2-3B-Instruct"
trust_remote_code = True
revision = None
dataset = "neuralmagic/LLM_compression_calibration"
output_dir = "meta-llama_Llama-3.2-3B-Instruct-TEQ-int4-gs128-sym"
quantize = True
seed = 42
load = False
accuracy = False
performance = False
iters = 100
batch_size = 1
pad_max_length = 512
calib_iters = 512
tasks = "lambada_openai,hellaswag,winogrande,piqa"
peft_model_id = None
# Weight-only quantization configs
woq_algo = "TEQ"
woq_bits = 4
woq_dtype = "int"
woq_group_size = 128
woq_group_dim = 1
woq_scheme = "sym"
woq_use_mse_search = False
woq_use_full_range = False
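    # quant_lm_head=True quantizes the output lm_head projection in addition to the decoder blocks.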
quant_lm_head = True
use_hf_format = False
# TEQ/AWQ configs
use_auto_scale = False
use_auto_clip = False
folding = False
absorb_layer_dict = {}
# DoubleQuant configs
double_quant_type = None
double_quant_dtype = "fp32"
double_quant_bits = 8
double_quant_use_sym = True
double_quant_group_size = 256
args = Args()
calib_size = 1
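# Run on a Habana HPU when habana_frameworks is available, otherwise fall back to CPU.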
if is_hpex_available():
import habana_frameworks.torch.core as htcore
htcore.hpu_set_inference_env()
device = "hpu"
else:
device = "cpu"
# ====== Helper functions ======
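# get_user_model loads the FP32 model and tokenizer. torchscript=True is set for AWQ/TEQ,
# which trace the model to find the layer pairs that can absorb the learned scales.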
def get_user_model():
torchscript = False
if args.woq_algo in ["AWQ", "TEQ"]:
torchscript = True
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
torchscript=torchscript,
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)
tokenizer = AutoTokenizer.from_pretrained(args.model)
user_model = user_model.float()
user_model = user_model.to(memory_format=torch.channels_last)
user_model.eval()
return user_model, tokenizer
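# calib_func feeds up to args.calib_iters batches of padded input_ids through the prepared
# model so TEQ can observe activations while training its equivalent-transformation scales.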
def calib_func(prepared_model):
for i, calib_input in enumerate(calib_dataloader):
if i > args.calib_iters:
break
prepared_model(calib_input[0])
# ====== Main quantization logic ======
if args.quantize:
user_model, tokenizer = get_user_model()
calib_dataset = load_dataset(args.dataset, split="train")
calib_dataset = calib_dataset.shuffle(seed=args.seed)
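    # Evaluator tokenizes the calibration split and pads/truncates every sample to pad_max
    # tokens so each batch has a fixed [batch, pad_max] input_ids shape.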
class Evaluator:
def __init__(self, dataset, tokenizer, batch_size=8, pad_val=1, pad_max=196, is_calib=False):
self.dataset = dataset
self.tokenizer = tokenizer
self.batch_size = batch_size
self.pad_val = pad_val
self.pad_max = pad_max
self.is_calib = is_calib
self.dataset = self.dataset.map(self.tokenize_function, batched=True)
self.dataset.set_format(type="torch", columns=["input_ids"])
@torch.no_grad()
def tokenize_function(self, examples):
if args.woq_algo in ['TEQ']:
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
example = self.tokenizer(examples["text"], padding="max_length", max_length=self.pad_max)
else:
example = self.tokenizer(examples["text"])
return example
@torch.no_grad()
def collate_batch(self, batch):
input_ids_padded = []
last_ind = []
for text in batch:
input_ids = text["input_ids"]
pad_len = self.pad_max - input_ids.shape[0]
last_ind.append(input_ids.shape[0] - 1)
input_ids = input_ids[:self.pad_max] if len(input_ids) > self.pad_max else input_ids
input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=self.pad_val)
input_ids_padded.append(input_ids)
return (torch.vstack(input_ids_padded), torch.tensor(last_ind))
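    # Build the calibration dataloader; collate_batch yields (padded input_ids, last-token
    # indices) pairs, and calib_func consumes only the input_ids.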
calib_evaluator = Evaluator(calib_dataset, tokenizer, args.batch_size, pad_max=args.pad_max_length, is_calib=True)
calib_dataloader = DataLoader(
calib_evaluator.dataset,
batch_size=calib_size,
shuffle=False,
collate_fn=calib_evaluator.collate_batch,
)
# === TEQ quantization ===
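    # Standard INC 3.x flow: build a TEQConfig, `prepare` the model with example inputs,
    # run the calibration function, then `convert` to produce the weight-only quantized model.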
from neural_compressor.torch.quantization import TEQConfig, prepare, convert
    weight_sym = args.woq_scheme == "sym"
quant_config = TEQConfig(
dtype=args.woq_dtype,
bits=args.woq_bits,
use_sym=weight_sym,
group_size=args.woq_group_size,
group_dim=args.woq_group_dim,
folding=args.folding,
quant_lm_head=args.quant_lm_head,
)
example_inputs = torch.ones([1, args.pad_max_length], dtype=torch.long)
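    # A dummy [1, pad_max_length] int64 tensor is sufficient as an example input for tracing;
    # the actual token values do not matter here.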
run_fn = calib_func
user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
# === Save quantized model ===
os.makedirs(args.output_dir, exist_ok=True)
print("Saving weight-only quantized model to", args.output_dir)
if args.use_hf_format:
user_model.save(args.output_dir, format="huggingface")
tokenizer.save_pretrained(args.output_dir)
else:
user_model.save(args.output_dir)
print("Saved weight-only quantized model.")
else:
print("Quantization not enabled. Exiting.")