---
license: apache-2.0
inference: false
tags:
- auto-gptq
pipeline_tag: text-generation
---

# redpajama gptq: RedPajama-INCITE-Chat-3B-v1

<a href="https://colab.research.google.com/gist/pszemraj/86d2e8485df182302646ed2c5a637059/inference-with-redpajama-incite-chat-3b-v1-gptq-4bit-128g.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

A GPTQ quantization of [RedPajama-INCITE-Chat-3B-v1](https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1) (4-bit, group size 128) produced with auto-gptq. The quantized model file is only about 2 GB.
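
For context, a quantization like this one can be produced with auto-gptq roughly as sketched below. This is a minimal, illustrative sketch rather than the exact script used for this checkpoint; the calibration text and output directory name are placeholders.

```python
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

base_model = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tokenizer = AutoTokenizer.from_pretrained(base_model)

# 4-bit weights with a group size of 128 (matching the "4bit-128g" in the repo name)
quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

# calibration data: use a few hundred representative text samples in practice
examples = [tokenizer("The quick brown fox jumps over the lazy dog.")]

model = AutoGPTQForCausalLM.from_pretrained(base_model, quantize_config)
model.quantize(examples)  # runs GPTQ layer by layer on the calibration examples
model.save_quantized("RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g", use_safetensors=True)
```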
					
						
## Usage

> Note that you cannot load directly from the hub with `auto_gptq` yet - if needed, you can use [this function](https://gist.github.com/pszemraj/8368cba3400bda6879e521a55d2346d0) to download the repo locally by name.
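
If you prefer not to use the gist, `huggingface_hub.snapshot_download` does the same job. A minimal sketch, with the repo id left as a placeholder:

```python
from huggingface_hub import snapshot_download

# downloads every file in the repo and returns the local directory path
model_repo = snapshot_download("<hub-username>/RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g")
print(model_repo)  # pass this path to AutoGPTQForCausalLM.from_quantized below
```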
					
						
First, install auto-gptq (the `[triton]` extra requires Linux):

```bash
pip install ninja auto-gptq[triton]
```
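
If you are not on Linux, Triton will not be available; installing without the extra should work instead. A hedged guess at the equivalent command (untested here):

```bash
pip install ninja auto-gptq
```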
					
						
Load the model:

```python
import torch
from pathlib import Path
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

model_repo = Path.cwd() / "RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoGPTQForCausalLM.from_quantized(
    model_repo,
    device=device,
    use_safetensors=True,
    use_triton=device != "cpu",  # comment out / remove if not on Linux
).to(device)
```
					
						
Inference:

```python
import re
import pprint as pp

prompt = "How can I further strive to increase shareholder value even further?"
prompt = f"<human>: {prompt}\n<bot>:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    temperature=0.7,
    do_sample=True,
    max_new_tokens=192,
    length_penalty=0.9,
    pad_token_id=model.config.eos_token_id,
)
result = tokenizer.batch_decode(
    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# keep only the text between <bot>: and the next <human> turn (or the end of the string)
bot_responses = re.findall(r'<bot>:(.*?)(<human>|$)', result[0], re.DOTALL)
bot_responses = [response[0].strip() for response in bot_responses]

print(bot_responses[0])
```
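
Not part of the original card, but if you call the model repeatedly it can be convenient to wrap the `<human>:`/`<bot>:` formatting and the reply extraction in a small helper. A sketch reusing the `model` and `tokenizer` loaded above (the function name and defaults are just illustrative):

```python
import re

def chat(prompt: str, max_new_tokens: int = 192) -> str:
    """Single-turn helper: format the prompt, generate, and return the bot's reply."""
    text = f"<human>: {prompt}\n<bot>:"
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=4,
        temperature=0.7,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=model.config.eos_token_id,
    )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    matches = re.findall(r'<bot>:(.*?)(<human>|$)', decoded, re.DOTALL)
    return matches[0][0].strip() if matches else decoded

print(chat("Summarize what GPTQ quantization does in one sentence."))
```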