Update app.py
app.py
CHANGED
@@ -1,49 +1,30 @@
 import streamlit as st
 from transformers import AutoTokenizer, TextGenerationPipeline
-from auto_gptq import AutoGPTQForCausalLM
-import logging
+from auto_gptq import AutoGPTQForCausalLM

-# Load tokenizer and model
-
+# Load the tokenizer and model
+pretrained_model_dir = "TheBloke/Llama-2-7b-Chat-GPTQ"
+quantized_model_dir = "amanchahar/llama2_finetune_Restaurants"

-
-model = AutoGPTQForCausalLM.from_quantized(
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")

-    #
-
-    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
-    model_inputs = encodeds.to(device)
-    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
-    decoded_response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-    return decoded_response[0]
+# Create a text generation pipeline
+pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)

-# Streamlit app
+# Define the Streamlit app
 def main():
-    st.title("
+    st.title("Restaurants Auto-GPTQ Text Generation")

-
+    # User input text box
+    user_input = st.text_input("Enter your query:", "auto-gptq is")

-
-    #
-
-
-    if user_input.strip():
-        # Add user message to chat history
-        messages.append({"role": "user", "content": user_input})
-
-        # Generate response from the model
-        response = generate_response(messages)
-
-        # Display assistant's response
-        st.text_area("Assistant:", value=response, height=150)
-
-        # Add assistant's response to chat history
-        messages.append({"role": "assistant", "content": response})
+    if st.button("Generate"):
+        # Generate response based on user input
+        generated_text = pipeline(user_input)[0]["generated_text"]
+        st.markdown(f"**Generated Response:** {generated_text}")

 if __name__ == "__main__":
     main()

-
-
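The updated pipeline can be sanity-checked outside Streamlit before launching the app with streamlit run app.py. A minimal sketch, assuming a CUDA device is available and the transformers and auto-gptq packages are installed; the model and tokenizer IDs are the ones from the diff above:

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM

# Load the fast tokenizer from the base model repo and the GPTQ weights
# from the fine-tuned repo, exactly as the updated app.py does.
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7b-Chat-GPTQ", use_fast=True)
model = AutoGPTQForCausalLM.from_quantized("amanchahar/llama2_finetune_Restaurants", device="cuda:0")

# The pipeline returns a list of dicts keyed by "generated_text".
pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipe("auto-gptq is")[0]["generated_text"])

Note that the tokenizer comes from the TheBloke/Llama-2-7b-Chat-GPTQ base repo while the quantized weights come from the fine-tuned amanchahar/llama2_finetune_Restaurants repo, so the two must share the same vocabulary for generation to produce sensible text.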