Update app.py
app.py CHANGED
@@ -1,75 +1,139 @@
-import os
 import gradio as gr
-… [removed lines 3–59 not captured in this view]
-        gr.Textbox(lines=2, placeholder="Or type text here", label="Text Input"),
-        gr.State([])
-    ],
-    outputs=[
-        gr.Textbox(label="Chat History", lines=20),
-        gr.Audio(visible=False),
-        gr.Textbox(visible=False),
-        gr.State()
-    ],
-    title="Chat with Llama 3.2-11B With Text or Voice (Whisper Large-v3)",
-    description="Upload an audio file or type text to get a chat response based on the transcription.",
-    allow_flagging='never'  # Prevent flagging to isolate sessions
-)
+import uuid
+from typing import Sequence
+
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.document_loaders import TextLoader
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.vectorstores import InMemoryVectorStore
+
+from langchain_groq import ChatGroq
+from langchain_huggingface import HuggingFaceEmbeddings
+
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.graph.message import add_messages
+from typing_extensions import Annotated, TypedDict
+
+import os
+
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+
+llm = ChatGroq(model="llama-3.2-11b-text-preview", api_key=GROQ_API_KEY, temperature=0)
+
+### Construct retriever ###
+loader = TextLoader("stj.txt")
+docs = loader.load()
+
+model_name = "sentence-transformers/all-MiniLM-L6-v2"
+model_kwargs = {'device': 'cpu'}
+encode_kwargs = {'normalize_embeddings': False}
+hf = HuggingFaceEmbeddings(
+    model_name=model_name,
+    model_kwargs=model_kwargs,
+    encode_kwargs=encode_kwargs
+)
+
+text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+splits = text_splitter.split_documents(docs)
+vectorstore = InMemoryVectorStore.from_documents(
+    documents=splits, embedding=hf
+)
+retriever = vectorstore.as_retriever()
+
+### Contextualize question ###
+contextualize_q_system_prompt = (
+    "Given the chat history and the latest user question, which may reference "
+    "the chat history, reformulate the question so it can be understood as a "
+    "standalone question without the chat history. Do not answer the question; "
+    "just reformulate it if needed and otherwise return it as is."
+)
+contextualize_q_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+history_aware_retriever = create_history_aware_retriever(
+    llm, retriever, contextualize_q_prompt
+)
+
+### Answer question ###
+system_prompt = (
+    "You are an assistant for question-answering tasks. Use the following "
+    "pieces of retrieved context to answer the question. If you don't know "
+    "the answer, say that you don't know. Limit the answer to three "
+    "sentences and keep it concise."
+    "\n\n"
+    "{context}"
+)
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+### Statefully manage chat history ###
+class State(TypedDict):
+    input: str
+    chat_history: Annotated[Sequence[BaseMessage], add_messages]
+    context: str
+    answer: str
+
+def call_model(state: State):
+    response = rag_chain.invoke(state)
+    return {
+        "chat_history": [
+            HumanMessage(state["input"]),
+            AIMessage(response["answer"]),
+        ],
+        "context": response["context"],
+        "answer": response["answer"],
+    }
+
+workflow = StateGraph(state_schema=State)
+workflow.add_edge(START, "model")
+workflow.add_node("model", call_model)
+
+memory = MemorySaver()
+app = workflow.compile(checkpointer=memory)
+
+# Session storage
+session_storage = {}
+
+# Function to interact with the RAG model
+def rag_response(user_input, chat_history, session_id):
+    # Scope the LangGraph checkpoint to this session so histories stay isolated
+    config = {"configurable": {"thread_id": session_id}}
+
+    # The MemorySaver checkpointer restores this thread's chat history,
+    # so only the new input needs to be passed in
+    state = {"input": user_input}
+
+    # Call the RAG model to get the response
+    result = app.invoke(state, config=config)
+
+    # Update session storage (backs the chatbot display)
+    session_storage[session_id]["chat_history"].append((user_input, result["answer"]))
+
+    return "", session_storage[session_id]["chat_history"]
+
+# Define the Gradio interface
+with gr.Blocks() as demo:
+    chatbox = gr.Chatbot(label="Chat History")
+    user_input = gr.Textbox(placeholder="Enter your question", label="User Input")
+    submit_button = gr.Button("Submit")
+
+    # Create a unique session ID; hold it in gr.State so it can be
+    # passed to the click handler as a component
+    session_id = str(uuid.uuid4())
+    session_storage[session_id] = {"chat_history": []}
+    session_state = gr.State(session_id)
+
+    # Connect the button click event to the rag_response function
+    submit_button.click(rag_response, inputs=[user_input, chatbox, session_state], outputs=[user_input, chatbox])
+
+# Launch the Gradio app
+demo.launch()
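
A note on the retriever block: `as_retriever()` wraps similarity search over the in-memory store, so it can be sanity-checked on its own. A minimal sketch, assuming the `retriever` built above; the query text is illustrative:

```python
# Sketch: querying the retriever directly (query text is made up).
# A retriever is a Runnable, so .invoke() returns a list of Documents.
top_docs = retriever.invoke("What are the internship requirements?")
for d in top_docs:
    print(d.page_content[:80])  # first 80 chars of each retrieved chunk
```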
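`create_retrieval_chain` returns a runnable whose output dict carries both the generated answer and the retrieved context documents, which is why `call_model` can read those keys. A sketch of invoking the chain directly, outside the graph (the question is illustrative):

```python
# Sketch: calling rag_chain directly; "answer" and "context" are part of
# create_retrieval_chain's output contract.
result = rag_chain.invoke({
    "input": "What are the internship requirements?",  # illustrative question
    "chat_history": [],  # list of BaseMessage objects; empty on a first turn
})
print(result["answer"])        # the generated reply
print(len(result["context"]))  # the retrieved Document chunks
```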
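On the checkpointer: `MemorySaver` keys saved state by the `thread_id` in the config, which is why `rag_response` passes the session's own id rather than a fixed string. A minimal sketch of the isolation, assuming the compiled `app` above (thread ids are made up):

```python
# Sketch: MemorySaver scopes graph state per thread_id.
config_a = {"configurable": {"thread_id": "session-a"}}  # made-up ids
config_b = {"configurable": {"thread_id": "session-b"}}

app.invoke({"input": "What are the internship requirements?"}, config=config_a)
app.invoke({"input": "Hello"}, config=config_b)

# Each thread accumulates its own chat_history via the add_messages reducer.
print(len(app.get_state(config_a).values["chat_history"]))  # session-a only
print(len(app.get_state(config_b).values["chat_history"]))  # session-b only
```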
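One caveat on the Gradio block: `str(uuid.uuid4())` runs once, when the Blocks layout is built, so every visitor ends up sharing that one id. A hedged sketch of a per-visitor variant using a `load` event; the `new_session` helper is hypothetical, not part of the diff:

```python
import uuid
import gradio as gr

def new_session():
    # Hypothetical helper: runs on every page load, so each visitor
    # gets a fresh id and an empty history slot.
    sid = str(uuid.uuid4())
    session_storage[sid] = {"chat_history": []}
    return sid

with gr.Blocks() as demo:
    chatbox = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(placeholder="Enter your question", label="User Input")
    submit_button = gr.Button("Submit")

    # gr.State holds a per-browser-session value; demo.load fills it in
    session_state = gr.State()
    demo.load(new_session, inputs=None, outputs=session_state)

    submit_button.click(rag_response,
                        inputs=[user_input, chatbox, session_state],
                        outputs=[user_input, chatbox])

demo.launch()
```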