import streamlit as st
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextGenerationPipeline

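# Model repos on the Hugging Face Hub: the base GPTQ build of Llama 2 7B Chat
# supplies the tokenizer; the second repo holds the fine-tuned quantized weights.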
pretrained_model_dir = "TheBloke/Llama-2-7b-Chat-GPTQ"
quantized_model_dir = "amanchahar/llama2_finetune_Restaurants"

@st.cache_resource
def load_pipeline() -> TextGenerationPipeline:
    # Cache the loaded artifacts across Streamlit reruns so the GPTQ weights
    # are moved into GPU memory once, not on every interaction.
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
    return TextGenerationPipeline(model=model, tokenizer=tokenizer)


pipeline = load_pipeline()

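# Minimal Streamlit UI: a text box for the prompt and a button that triggers generation.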
def main():
    st.title("Restaurants Auto-GPTQ Text Generation")

    user_input = st.text_input("Enter your query:", "auto-gptq is")

    if st.button("Generate"):
        # The pipeline returns a list of candidates; take the first one's text.
        generated_text = pipeline(user_input)[0]["generated_text"]
        st.markdown(f"**Generated Response:** {generated_text}")

if __name__ == "__main__":
    main()
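# Launch with the Streamlit CLI, e.g. `streamlit run app.py` (script name assumed).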