amanchahar committed
Commit f80ec8d · verified · 1 Parent(s): 1fe1f4b

Update app.py

Files changed (1)
  1. app.py +16 -35
app.py CHANGED
@@ -1,49 +1,30 @@
  import streamlit as st
  from transformers import AutoTokenizer, TextGenerationPipeline
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
- import logging
+ from auto_gptq import AutoGPTQForCausalLM

- # Load tokenizer and model
- device = "cuda" if st.checkbox("Use GPU") else "cpu"
+ # Load the tokenizer and model
+ pretrained_model_dir = "TheBloke/Llama-2-7b-Chat-GPTQ"
+ quantized_model_dir = "amanchahar/llama2_finetune_Restaurants"

- model_name = "amanchahar/llama2_finetune_Restaurants"
- model = AutoGPTQForCausalLM.from_quantized(model_name, device="cuda:0", use_safetensors=True, use_triton=False)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+ model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")

- # Function to generate model response
- def generate_response(messages):
-     encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
-     model_inputs = encodeds.to(device)
-     generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
-     decoded_response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-     return decoded_response[0]
+ # Create a text generation pipeline
+ pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)

- # Streamlit app
+ # Define the Streamlit app
  def main():
-     st.title("Interactive Conversational AI")
+     st.title("Restaurants Auto-GPTQ Text Generation")

-     messages = []
+     # User input text box
+     user_input = st.text_input("Enter your query:", "auto-gptq is")

-     while True:
-         # Display chat interface
-         user_input = st.text_input("You:", key="user_input")
-         if st.button("Send"):
-             if user_input.strip():
-                 # Add user message to chat history
-                 messages.append({"role": "user", "content": user_input})
-
-                 # Generate response from the model
-                 response = generate_response(messages)
-
-                 # Display assistant's response
-                 st.text_area("Assistant:", value=response, height=150)
-
-                 # Add assistant's response to chat history
-                 messages.append({"role": "assistant", "content": response})
+     if st.button("Generate"):
+         # Generate response based on user input
+         generated_text = pipeline(user_input)[0]["generated_text"]
+         st.markdown(f"**Generated Response:** {generated_text}")

  if __name__ == "__main__":
      main()

-
-
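
A minimal sketch of exercising the new call pattern outside Streamlit, assuming a CUDA-capable GPU and that the transformers and auto-gptq packages are installed (the app itself is normally launched with streamlit run app.py; the prompt and the max_new_tokens cap below are illustrative, not part of the commit):

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM

# Load the same tokenizer and quantized model as the updated app.py
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7b-Chat-GPTQ", use_fast=True)
model = AutoGPTQForCausalLM.from_quantized("amanchahar/llama2_finetune_Restaurants", device="cuda:0")
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)

# The pipeline returns a list of dicts; "generated_text" holds the completion.
outputs = pipeline("auto-gptq is", max_new_tokens=64)  # max_new_tokens is an assumed, optional cap
print(outputs[0]["generated_text"])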