gufett0 committed on
Commit c5ffbe6 · 1 Parent(s): 9c4af10

vectostoreindex

Files changed (6)
  1. .gitignore +1 -0
  2. app.py +4 -0
  3. appLlama.py +0 -29
  4. appcompleta.py +0 -262
  5. interface.py +0 -105
  6. requirements.txt +1 -0
.gitignore CHANGED
@@ -2,3 +2,4 @@
  __pycache__/
  appcompleta.py
  appLlama.py
+ interface.py
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
  from huggingface_hub import login
  from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, PromptTemplate, load_index_from_storage, StorageContext
  from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.embeddings.instructor import InstructorEmbedding


  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -41,6 +42,9 @@ model.config.sliding_window = 4096
  #model = model.to(device)
  model.eval()

+ Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
+
+
  ###------####
  # rag
  documents_paths = {
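Note on the change above: with `Settings.embed_model` registered, any `VectorStoreIndex` built in app.py will embed document chunks with hkunlp/instructor-base. A minimal sketch of how this typically wires together follows; the chunk sizes and persistence directory are illustrative assumptions, not code from this commit (the `documents/blockchain` path comes from `documents_paths`).

```python
# Sketch only: assumes the global Settings registered in app.py above.
from llama_index.core import (
    Settings, VectorStoreIndex, SimpleDirectoryReader,
    StorageContext, load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)  # assumed sizes

# Build an index from one of the topic folders listed in documents_paths.
documents = SimpleDirectoryReader("documents/blockchain").load_data()
index = VectorStoreIndex.from_documents(documents)  # nodes embedded via Settings.embed_model

# Optional: persist and reload instead of re-embedding on every start (assumed dir).
index.storage_context.persist(persist_dir="storage/blockchain")
index = load_index_from_storage(StorageContext.from_defaults(persist_dir="storage/blockchain"))
```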
appLlama.py DELETED
@@ -1,29 +0,0 @@
- from backend import handle_query
- import gradio as gr
-
-
- DESCRIPTION = """\
- # <div style="text-align: center;">Odi, l'assistente ricercatore degli Osservatori</div>
-
-
- 👉 Retrieval-Augmented Generation - Ask me anything about the research carried out at the Osservatori.
- """
-
-
- chat_interface =gr.ChatInterface(
-     fn=handle_query,
-     chatbot=gr.Chatbot(height=500),
-     textbox=gr.Textbox(placeholder="Chiedimi qualasiasi cosa relativa agli Osservatori", container=False, scale=7),
-     #examples=[["Ciao, in cosa puoi aiutarmi?"],["Dimmi i risultati e le modalità di conduzione del censimento per favore"]]
- )
-
-
- with gr.Blocks(css=".gradio-container {background-color: #B9D9EB}") as demo:
-     gr.Markdown(DESCRIPTION)
-     #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-
- if __name__ == "__main__":
-     #progress = gr.Progress(track_tqdm=True)
-     demo.launch()
-
appcompleta.py DELETED
@@ -1,262 +0,0 @@
- import os
- import spaces
- from threading import Thread
- from typing import Iterator
- from backend2 import load_documents, prepare_documents, get_context_sources
- import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
- from huggingface_hub import login
- import threading
-
-
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
- login(huggingface_token)
-
- DESCRIPTION = """\
- # La Chatbot degli Osservatori
- """
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- os.environ["MAX_INPUT_TOKEN_LENGTH"] = "4096" #"8192"
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH"))
-
-
- # Force usage of CPU
- #device = torch.device("cpu")
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- model_id = "google/gemma-2-2b-it"
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype= torch.float16 if torch.cuda.is_available() else torch.float32,
- )
- tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
- #tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.use_default_system_prompt = False
- model.config.sliding_window = 4096
- #model = model.to(device)
- model.eval()
-
- ###------####
- # rag
- documents_paths = {
-     'blockchain': 'documents/blockchain',
-     'metaverse': 'documents/metaverso',
-     'payment': 'documents/payment'
- }
-
- """session_state = {"documents_loaded": False,
-                  "document_db": None,
-                  "original_message": None,
-                  "clarification": False}"""
-
- INSTRUCTION_1 = 'In italiano, chiedi molto brevemente se la domanda si riferisce agli "Osservatori Blockchain", "Osservatori Payment" oppure "Osservatori Metaverse".'
- INSTRUCTION_2 = 'Sei un assistente che risponde in italiano alle domande basandosi solo sulle informazioni fornite nel contesto. Se non trovi informazioni, rispondi "Puoi chiedere maggiori informazioni all\'ufficio di riferimento.". Se invece la domanda è completamente fuori contesto, non rispondere e rammenta il topic del contesto'
-
- default_error_response = (
-     'Non sono sicuro che tu voglia indirizzare la tua ricerca su una di queste opzioni: '
-     '"Blockchain", "Metaverse", "Payment". '
-     'Per favore utilizza il nome corretto.'
- )
-
-
- thread_local = threading.local()
-
- def get_session_state():
-     if not hasattr(thread_local, "session_state"):
-         thread_local.session_state = {
-             "documents_loaded": False,
-             "document_db": None,
-             "original_message": None,
-             "clarification": False
-         }
-     return thread_local.session_state
-
- @spaces.GPU(duration=30)
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-
-     session_state = get_session_state()
-
-     global context, sources, conversation
-
-     if (not (session_state["documents_loaded"]) and not (session_state["clarification"])):
-
-         conversation = []
-         for user, assistant in chat_history:
-             conversation.extend(
-                 [
-                     {"role": "user", "content": user},
-                     {"role": "assistant", "content": assistant},
-                 ]
-             )
-         conversation.append({"role": "user", "content": f"Domanda: {message} . Comando: {INSTRUCTION_1}" })
-         conversation.append({"role": "assistant", "content": "Ok."})
-         print("debug - CONV1", conversation)
-
-         session_state["original_message"] = message
-         session_state["clarification"] = True
-
-
-     elif session_state["clarification"]:
-
-         message = message.lower()
-         matched_path = None
-
-         for key, path in documents_paths.items():
-             if key in message:
-                 matched_path = path
-                 break
-
-         if matched_path:
-             yield "Fammi cercare tra i miei documenti..."
-             documents = load_documents(matched_path)
-             session_state["document_db"] = prepare_documents(documents)
-             session_state["documents_loaded"] = True
-             yield f"Ecco, ho raccolto informazioni dagli Osservatori {key.capitalize()}. Ora sto elaborando una risposta per te..."
-             sources = []
-             context, sources = get_context_sources(session_state["original_message"], session_state["document_db"])
-             print("sources ", sources)
-             print("context ", context)
-
-             conversation = []
-             conversation.append({"role": "user", "content": f"{INSTRUCTION_2}"})
-             conversation.append({"role": "assistant", "content": "Ok."})
-             for user, assistant in chat_history:
-                 conversation.extend(
-                     [
-                         {"role": "user", "content": user },
-                         {"role": "assistant", "content": assistant},
-                     ]
-                 )
-             conversation.append({"role": "user", "content": f'Contesto: {context}\n\n Domanda iniziale: {session_state["original_message"]} . Rispondi solo in italiano.'})
-             session_state["clarification"] = False
-             print("debug - CONV2", conversation)
-
-         else:
-             print(default_error_response)
-             gr.Info("NO MATCH")
-
-     else:
-         conversation = []
-         conversation.append({"role": "user", "content": f"Comandi: {INSTRUCTION_2}"})
-         conversation.append({"role": "assistant", "content": "Va bene."})
-         for user, assistant in chat_history:
-             conversation.extend(
-                 [
-                     {"role": "user", "content": user},
-                     {"role": "assistant", "content": assistant},
-                 ]
-             )
-         conversation.append({"role": "user", "content": f"Contesto: {context}\n\n Nuova domanda: {message} . Rispondi in italiano e seguendo i comandi che ti ho dato prima"})
-         print("debug - CONV3", conversation)
-
-     """ retriever = db.as_retriever()
-     qa = RetrievalQA.from_chain_type(llm=model, chain_type="refine", retriever=retriever, return_source_documents=False)
-     question = "Cosa sono i RWA?"
-     result = qa.run({"query": question})
-     print(result["result"]) """
-
-     # Iterate model output
-     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-     if session_state["documents_loaded"]:
-         outputs.append(f"Fonti utilizzate: {sources}")
-         yield "".join(outputs)
-
-     #sources = []
-     print("debug - CHATHISTORY", chat_history)
-
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Ciao, in cosa puoi aiutarmi?"],
-         ["Ciao, in cosa consiste un piatto di spaghetti?"],
-         ["Ciao, quali sono le aziende che hanno iniziato ad integrare le stablecoins? Fammi un breve sommario."],
-         ["Spiegami la differenza tra mondi virtuali pubblici o privati"],
-         ["Trovami un esempio di progetto B2B"],
-         ["Quali sono le regole europee sui bonifici istantanei?"],
-     ],
-     cache_examples=False,
- )
-
- with gr.Blocks(css=".gradio-container {background-color: #B9D9EB}", fill_height=True) as demo:
-     gr.Markdown(DESCRIPTION, elem_classes="centered")
-     chat_interface.render()
-
- if __name__ == "__main__":
-     #demo.queue(max_size=20).launch()
-     demo.launch()
interface.py DELETED
@@ -1,105 +0,0 @@
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
- from llama_index.core.llms.callbacks import llm_completion_callback
- from typing import Any, Iterator
- import torch
- from transformers import TextIteratorStreamer
- from threading import Thread
- from pydantic import Field, field_validator
- import keras
- import keras_nlp
-
- # for transformers 2 (__setattr__ is used to bypass Pydantic check )
- """class GemmaLLMInterface(CustomLLM):
-     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
-         super().__init__(**kwargs)
-         object.__setattr__(self, "model_id", model_id)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             device_map="auto",
-             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-         )
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         object.__setattr__(self, "model", model)
-         object.__setattr__(self, "tokenizer", tokenizer)
-         object.__setattr__(self, "context_window", 8192)
-         object.__setattr__(self, "num_output", 2048)
-
-     def _format_prompt(self, message: str) -> str:
-         return (
-             f"<start_of_turn>user\n{message}<end_of_turn>\n"
-             f"<start_of_turn>model\n"
-         )
-
-     @property
-     def metadata(self) -> LLMMetadata:
-         return LLMMetadata(
-             context_window=self.context_window,
-             num_output=self.num_output,
-             model_name=self.model_id,
-         )
-
-
-     @llm_completion_callback()
-     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-         prompt = self._format_prompt(prompt)
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-         outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
-         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-         response = response[len(prompt):].strip()
-         return CompletionResponse(text=response if response else "No response generated.")
-
-     @llm_completion_callback()
-     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-         #prompt = self._format_prompt(prompt)
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-
-         streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-         generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
-
-         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         streamed_response = ""
-         for new_text in streamer:
-             if new_text:
-                 streamed_response += new_text
-                 yield CompletionResponse(text=streamed_response, delta=new_text)
-
-         if not streamed_response:
-             yield CompletionResponse(text="No response generated.", delta="No response generated.")"""
-
- # for Keras
- class GemmaLLMInterface(CustomLLM):
-     model: keras_nlp.models.GemmaCausalLM = None
-     context_window: int = 8192
-     num_output: int = 2048
-     model_name: str = "gemma_2"
-
-     def _format_prompt(self, message: str) -> str:
-         return (
-             f"<start_of_turn>user\n{message}<end_of_turn>\n" f"<start_of_turn>model\n"
-         )
-
-     @property
-     def metadata(self) -> LLMMetadata:
-         """Get LLM metadata."""
-         return LLMMetadata(
-             context_window=self.context_window,
-             num_output=self.num_output,
-             model_name=self.model_name,
-         )
-
-     @llm_completion_callback()
-     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-         prompt = self._format_prompt(prompt)
-         raw_response = self.model.generate(prompt, max_length=self.num_output)
-         response = raw_response[len(prompt) :]
-         return CompletionResponse(text=response)
-
-     @llm_completion_callback()
-     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-         response = self.complete(prompt).text
-         for token in response:
-             response += token
-             yield CompletionResponse(text=response, delta=token)
requirements.txt CHANGED
@@ -14,6 +14,7 @@ torch==2.2.0
  transformers==4.43.3
  llama-cpp-agent>=0.2.25
  setuptools
+ faiss-cpu

  pydantic
  ipython
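The faiss-cpu dependency added above is not exercised anywhere else in this diff. If the intent is to back the new VectorStoreIndex with a FAISS store, a typical wiring looks like the sketch below; the FaissVectorStore integration (a separate llama-index-vector-stores-faiss package), the 768-dimension figure for hkunlp/instructor-base, and the exact call sites are assumptions rather than code from this commit. The folder path and query string are taken from documents_paths and the example prompts in the deleted appcompleta.py.

```python
# Sketch only: assumed FAISS-backed index; this usage is not shown in the commit.
import faiss
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore  # assumed extra package

embed_dim = 768  # assumed output dimension of hkunlp/instructor-base embeddings
faiss_index = faiss.IndexFlatL2(embed_dim)           # exact L2 search over embeddings
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

documents = SimpleDirectoryReader("documents/payment").load_data()  # path from documents_paths
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
print(index.as_query_engine().query("Quali sono le regole europee sui bonifici istantanei?"))
```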