sanketchaudhary10 committed
Commit 2a24040 · 1 Parent(s): 01d8c2b

Adding files

Files changed (5)
  1. __init__.py +0 -0
  2. chat_logic.py +33 -0
  3. main.py +23 -0
  4. model.py +25 -0
  5. utils.py +6 -0
__init__.py ADDED
File without changes
chat_logic.py ADDED
@@ -0,0 +1,42 @@
+ from model import generator
+
+ HISTORY = []  # Global log of the conversation, shown in the history panel
+
+ def chat(message, history):
+     """
+     Handle user input, generate a response with the model, and update the chat log.
+     `history` is the list of (user, assistant) pairs passed in by gr.ChatInterface.
+     """
+     # Flatten the paired history into a plain-text transcript
+     turns = []
+     for user_msg, assistant_msg in history:
+         turns.append(f"User: {user_msg}")
+         turns.append(f"Assistant: {assistant_msg}")
+     conversation = "\n".join(turns) + f"\nUser: {message}\nAssistant:"
+
+     # Generate a reply; return_full_text=False strips the prompt from the output
+     outputs = generator(
+         conversation,
+         max_new_tokens=256,
+         do_sample=True,
+         temperature=0.7,
+         return_full_text=False,
+     )
+     reply = outputs[0]["generated_text"].strip()
+
+     # Record the exchange in the global history log
+     update_history(message, reply)
+     return reply
+
+ def update_history(message, reply):
+     """
+     Append the latest exchange to the global history.
+     """
+     HISTORY.append(f"User: {message}")
+     HISTORY.append(f"Assistant: {reply}")
+
+ def get_history():
+     """
+     Return the chat history as a single newline-joined string.
+     """
+     return "\n".join(HISTORY)
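Note: meta-llama/Llama-3.1-8B is a base checkpoint, so the plain "User:/Assistant:" transcript above is only a loose convention the model was never tuned to follow. A minimal sketch of the chat-template route, assuming the gated Instruct variant is accessible (the model name, the chat_with_template helper, and the 256-token cap are illustrative, not part of this commit):

from transformers import AutoTokenizer, AutoModelForCausalLM

name = "meta-llama/Llama-3.1-8B-Instruct"  # assumed chat-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", torch_dtype="auto")

def chat_with_template(message, history):
    # Rebuild the conversation in the role/content format the template expects
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # apply_chat_template inserts the model's own special tokens for each turn
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)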
main.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+ from chat_logic import chat, get_history
+
+ def launch_gradio_ui():
+     with gr.Blocks() as gui:
+         gr.Markdown("## Chat With Llama 3.1-8B")
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 chat_interface = gr.ChatInterface(fn=chat)
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### Message History")
+                 history_display = gr.Textbox(label="Chat History", lines=27, interactive=False)
+                 refresh_button = gr.Button("Refresh History")
+
+                 # Update history display when the button is clicked
+                 refresh_button.click(get_history, [], history_display)
+
+     gui.launch(share=True)
+
+ if __name__ == "__main__":
+     launch_gradio_ui()
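Note: running python main.py starts the app locally; share=True in gui.launch additionally creates a temporary public gradio.live URL, which is handy for demos but worth dropping for purely local use.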
model.py ADDED
@@ -0,0 +1,27 @@
+ # Alternative: load the tokenizer and model directly instead of via pipeline
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
+ #
+ # def load_model(model_name="meta-llama/Llama-3.1-8B"):
+ #     """
+ #     Load the Hugging Face Llama model and tokenizer.
+ #     """
+ #     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ #     model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
+ #     return tokenizer, model
+ #
+ # # Initialize model and tokenizer
+ # tokenizer, model = load_model()

+ from transformers import pipeline
+
+ # Hugging Face access token (Llama 3.1 is a gated model, so the token must grant access)
+ api_token = "your_huggingface_api_token"
+
+ # Build a local text-generation pipeline; the token authenticates the model
+ # download rather than routing generation through a hosted API
+ generator = pipeline(
+     "text-generation",
+     model="meta-llama/Llama-3.1-8B",
+     token=api_token,
+     device_map="auto",  # place the model on available GPUs (requires accelerate)
+ )
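Note: transformers.pipeline downloads and runs the weights locally; the token above only unlocks the gated repo, so the original "load the model using the API" comment was misleading. If the intent was to avoid local inference entirely, a sketch against Hugging Face's hosted Inference API via huggingface_hub might look like this (assuming a hosted endpoint actually serves this model to your token; generate_remote is an illustrative name):

from huggingface_hub import InferenceClient

# Assumption: a hosted endpoint serves this model for the given token
client = InferenceClient(model="meta-llama/Llama-3.1-8B", token=api_token)

def generate_remote(prompt, max_new_tokens=256):
    # text_generation posts the prompt to the remote endpoint and returns a string
    return client.text_generation(prompt, max_new_tokens=max_new_tokens, temperature=0.7)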
utils.py ADDED
@@ -0,0 +1,6 @@
+ def truncate_conversation(history, max_tokens=1024):
+     """
+     Keep only the tail of the conversation history. Note: slicing takes the last
+     `max_tokens` elements (characters for a string, entries for a list), not tokens.
+     """
+     return history[-max_tokens:]
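Note: for truncation that actually respects the model's token limit, the tokenizer has to be involved. A minimal sketch, assuming a transformers tokenizer such as the one in model.py's commented-out loader (truncate_to_tokens is an illustrative name, not part of this commit):

def truncate_to_tokens(tokenizer, conversation, max_tokens=1024):
    # Encode, keep the most recent max_tokens ids, then decode back to text
    token_ids = tokenizer.encode(conversation)
    if len(token_ids) <= max_tokens:
        return conversation
    return tokenizer.decode(token_ids[-max_tokens:], skip_special_tokens=True)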