candenizkocak committed
Commit f276112 · verified · 1 Parent(s): 65134ea

Update app.py

Files changed (1):
  1. app.py (+72 -136)

app.py CHANGED
@@ -1,139 +1,75 @@
-import gradio as gr
-import uuid
-from typing import Sequence
-
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_community.document_loaders import TextLoader
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.vectorstores import InMemoryVectorStore
-
-from langchain_groq import ChatGroq
-from langchain_huggingface import HuggingFaceEmbeddings
-
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph import START, StateGraph
-from langgraph.graph.message import add_messages
-from typing_extensions import Annotated, TypedDict
-
 import os
-GROQ_API_KEY = os.getenv("GROQ_API_KEY")
-
-llm = ChatGroq(model="llama-3.2-11b-text-preview", api_key=GROQ_API_KEY, temperature=0)
-
-### Construct retriever ###
-loader = TextLoader("stj.txt")
-docs = loader.load()
-
-model_name = "sentence-transformers/all-MiniLM-L6-v2"
-model_kwargs = {'device': 'cpu'}
-encode_kwargs = {'normalize_embeddings': False}
-hf = HuggingFaceEmbeddings(
-    model_name=model_name,
-    model_kwargs=model_kwargs,
-    encode_kwargs=encode_kwargs
-)
-
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-splits = text_splitter.split_documents(docs)
-vectorstore = InMemoryVectorStore.from_documents(
-    documents=splits, embedding=hf
-)
-retriever = vectorstore.as_retriever()
-
-### Contextualize question ###
-contextualize_q_system_prompt = (
-    "Sohbet geçmişi ve en son kullanıcı sorusu verilirse, sohbet geçmişine atıfta bulunabilecek en son kullanıcı sorusunu, sohbet geçmişi olmadan anlaşılabilecek bağımsız bir soru haline getirin. Soruyu yanıtlamayın, sadece yeniden düzenleyin ve gerekirse geri döndürün."
-)
-contextualize_q_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", contextualize_q_system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}"),
-    ]
-)
-history_aware_retriever = create_history_aware_retriever(
-    llm, retriever, contextualize_q_prompt
-)
-
-### Answer question ###
-system_prompt = (
-    "Soru-cevap görevleri için bir asistansın. Soruyu yanıtlamak için alınan aşağıdaki bağlam parçalarını kullan. Cevabı bilmiyorsan, bilmiyorum de. Cevabı üç cümleyle sınırla ve kısa tut."
-    "\n\n"
-    "{context}"
-)
-qa_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}"),
-    ]
-)
-question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-
-rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-### Statefully manage chat history ###
-class State(TypedDict):
-    input: str
-    chat_history: Annotated[Sequence[BaseMessage], add_messages]
-    context: str
-    answer: str
-
-def call_model(state: State):
-    response = rag_chain.invoke(state)
-    return {
-        "chat_history": [
-            HumanMessage(state["input"]),
-            AIMessage(response["answer"]),
+import gradio as gr
+from groq import Groq
+
+api_key = os.getenv("GROQ_API_KEY")
+if not api_key:
+    raise ValueError("API key not found. Please set the GROQ_API_KEY environment variable.")
+
+client = Groq(api_key=api_key)
+
+def transcribe_audio(file_path):
+    with open(file_path, "rb") as file:
+        transcription = client.audio.transcriptions.create(
+            file=(file_path, file.read()),
+            model="whisper-large-v3",
+            response_format="verbose_json",
+        )
+    return transcription.text
+
+def get_chat_completion(prompt):
+    completion = client.chat.completions.create(
+        model="llama-3.2-11b-text-preview",
+        messages=[
+            {
+                "role": "user",
+                "content": prompt
+            }
         ],
-        "context": response["context"],
-        "answer": response["answer"],
-    }
-
-workflow = StateGraph(state_schema=State)
-workflow.add_edge(START, "model")
-workflow.add_node("model", call_model)
-
-memory = MemorySaver()
-app = workflow.compile(checkpointer=memory)
-
-# Session storage
-session_storage = {}
-
-# Function to interact with the RAG model
-def rag_response(user_input, chat_history, session_id):
-    config = {"configurable": {"thread_id": "abc123"}}
-
-    # Prepare the state with input and chat history
-    state = {
-        "input": user_input,
-        "chat_history": session_storage[session_id]["chat_history"]  # Get chat history for this session
-    }
-
-    # Call the RAG model to get the response
-    result = app.invoke(state, config=config)
-
-    # Update session storage
-    session_storage[session_id]["chat_history"].append((user_input, result["answer"]))
-
-    return "", session_storage[session_id]["chat_history"]
-
-# Define the Gradio interface
-with gr.Blocks() as demo:
-
-    chatbox = gr.Chatbot(label="Chat History")
-    user_input = gr.Textbox(placeholder="Enter your question", label="User Input")
-    submit_button = gr.Button("Submit")
-
-    # Create a unique session ID
-    session_id = str(uuid.uuid4())
-    session_storage[session_id] = {"chat_history": []}
-
-    # Connect the button click event to the rag_response function
-    submit_button.click(rag_response, inputs=[user_input, chatbox, session_id], outputs=[user_input, chatbox])
+        temperature=1,
+        max_tokens=1024,
+        top_p=1,
+        stream=True,
+        stop=None,
+    )
+
+    response = ""
+    for chunk in completion:
+        response += chunk.choices[0].delta.content or ""
+    return response
+
+def process_input(audio_file, text_input, chat_history):
+    if audio_file is not None:
+        transcription_text = transcribe_audio(audio_file)
+    else:
+        transcription_text = text_input
+
+    chat_response = get_chat_completion(transcription_text)
+    chat_history.append(("👤", transcription_text))
+    chat_history.append(("🤖", chat_response))
+
+    formatted_history = "\n".join([f"{role}: {content}\n" for role, content in chat_history])
+
+    return formatted_history, gr.update(value=None), gr.update(value=''), chat_history
+
+# Create Gradio interface
+interface = gr.Interface(
+    fn=process_input,
+    inputs=[
+        gr.Audio(type="filepath", label="Upload Audio or Record"),
+        gr.Textbox(lines=2, placeholder="Or type text here", label="Text Input"),
+        gr.State([])
+    ],
+    outputs=[
+        gr.Textbox(label="Chat History", lines=20),
+        gr.Audio(visible=False),
+        gr.Textbox(visible=False),
+        gr.State()
+    ],
+    title="Chat with Llama 3.2-11B With Text or Voice (Whisper Large-v3)",
+    description="Upload an audio file or type text to get a chat response based on the transcription.",
+    allow_flagging='never'  # Prevent flagging to isolate sessions
+)
 
-# Launch the Gradio app
-demo.launch()
+if __name__ == "__main__":
+    interface.launch()
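
For reference, the transcription call added above follows the Groq SDK's audio API: the file is passed as a (filename, bytes) tuple. A minimal standalone sketch of the same pattern (the "sample.mp3" path is a hypothetical placeholder, not a file in this repo):

import os
from groq import Groq

# Assumes GROQ_API_KEY is set in the environment, as in app.py.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# "sample.mp3" is a hypothetical test file, used only for illustration.
with open("sample.mp3", "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=("sample.mp3", f.read()),   # (filename, bytes) tuple, as in transcribe_audio()
        model="whisper-large-v3",
        response_format="verbose_json",  # verbose JSON also carries segment/timing metadata
    )

print(transcription.text)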
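
Similarly, get_chat_completion() buffers the streamed response by concatenating each chunk's delta. The same stream can instead be printed token by token; a minimal sketch, assuming the same client setup and model as app.py:

import os
from groq import Groq

client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# With stream=True, a chunk's .choices[0].delta.content may be None,
# hence the `or ""` guard (the same guard app.py uses while buffering).
stream = client.chat.completions.create(
    model="llama-3.2-11b-text-preview",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    temperature=1,
    max_tokens=1024,
    stream=True,
)

for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()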