Cédric KACZMAREK commited on
Commit
861919a
·
1 Parent(s): bb49fd9

après hackathon

Browse files
Files changed (3) hide show
  1. app.py +48 -118
  2. requirements.txt +3 -6
  3. src/utils_fct.py +117 -0
app.py CHANGED
@@ -1,99 +1,53 @@
1
  import os
2
  import json
 
3
  import gradio as gr
4
  from llama_index.core import (
5
  VectorStoreIndex,
6
  download_loader,
7
  StorageContext
8
  )
9
- from dotenv import load_dotenv, find_dotenv
10
-
11
- import chromadb
12
-
13
- from llama_index.llms.mistralai import MistralAI
14
- from llama_index.embeddings.mistralai import MistralAIEmbedding
15
- from llama_index.vector_stores.chroma import ChromaVectorStore
16
- from llama_index.core.indices.service_context import ServiceContext
17
 
 
 
18
  from pathlib import Path
19
 
 
 
 
 
 
 
20
  TITLE = "RIZOA-AUCHAN Chatbot Demo"
21
- DESCRIPTION = "Example of an assistant with Gradio, coupling with function calling and Mistral AI via its API"
22
  PLACEHOLDER = (
23
- "Vous pouvez me posez une question sur ce contexte, appuyer sur Entrée pour valider"
24
  )
25
- PLACEHOLDER_URL = "Extract text from this url"
26
- llm_model = "mistral-medium"
 
 
 
27
 
28
  load_dotenv()
29
- env_api_key = os.environ.get("MISTRAL_API_KEY")
30
- query_engine = None
 
 
31
 
32
  # Define LLMs
33
- llm = MistralAI(api_key=env_api_key, model=llm_model)
34
- embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=env_api_key)
35
-
36
- # create client and a new collection
37
- db = chromadb.PersistentClient(path="./chroma_db")
38
- chroma_collection = db.get_or_create_collection("quickstart")
39
-
40
- # set up ChromaVectorStore and load in data
41
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
42
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
43
- service_context = ServiceContext.from_defaults(
44
- chunk_size=1024, llm=llm, embed_model=embed_model
45
- )
46
-
47
- #PDFReader = download_loader("PDFReader")
48
- #loader = PDFReader()
49
-
50
- index = VectorStoreIndex(
51
- [], service_context=service_context, storage_context=storage_context
52
- )
53
- query_engine = index.as_query_engine(similarity_top_k=5)
54
-
55
- FILE = Path(__file__).resolve()
56
- BASE_PATH = FILE.parents[0]
57
-
58
- '''
59
- image = os.path.join(BASE_PATH,"img","logo_rizoa_auchan.jpg")
60
- print(f"Chemin de l'image : {image}")
61
- image = os.path.join("img","logo_rizoa_auchan.jpg")
62
- print(f"chemin 2 : {image}")
63
- image = os.path.abspath(os.path.join("img", "logo_rizoa_auchan.jpg"))
64
- print(f"Image 3 : {image}")
65
- image = os.path.join("https://huggingface.co/spaces/rizoa-auchan-hack/hack/blob/main/img/logo_rizoa_auchan.jpg")
66
- print(f"Image 4 : {image}")
67
- '''
68
- image = os.path.join("logo_rizoa_auchan.jpg")
69
- print(f"Chemin:{image}")
70
-
71
- if os.path.exists(image):
72
- print("Image existe")
73
- else:
74
- print("Image n'existe pas")
75
-
76
-
77
- PLACEHOLDER = (image)
78
 
79
  with gr.Blocks() as demo:
80
  with gr.Row():
81
-
82
  with gr.Column(scale=1):
83
- '''
84
- gr.Image(
85
- #value=os.path.join(BASE_PATH,"img","logo_rizoa_auchan.jpg"),
86
- #value=os.path.join("img","logo_rizoa_auchan.jpg"),
87
- value="logo_rizoa_auchan.jpg",
88
  height=250,
89
  width=250,
90
- container=False,
91
  show_download_button=False
92
  )
93
- '''
94
- gr.HTML(
95
- value = '<img src="https://huggingface.co/spaces/rizoa-auchan-hack/hack/resolve/main/LOGO_RIZOA_CARRE.jpg">'
96
- )
97
  with gr.Column(scale=4):
98
  gr.Markdown(
99
  """
@@ -103,59 +57,35 @@ with gr.Blocks() as demo:
103
  """
104
  )
105
 
106
- # gr.Markdown(""" ### 1 / Extract data from PDF """)
107
-
108
- # with gr.Row():
109
- # with gr.Column():
110
- # input_file = gr.File(
111
- # label="Load a pdf",
112
- # file_types=[".pdf"],
113
- # file_count="single",
114
- # type="filepath",
115
- # interactive=True,
116
- # )
117
- # file_msg = gr.Textbox(
118
- # label="Loaded documents:", container=False, visible=False
119
- # )
120
-
121
- # input_file.upload(
122
- # fn=load_document,
123
- # inputs=[
124
- # input_file,
125
- # ],
126
- # outputs=[file_msg],
127
- # concurrency_limit=20,
128
- # )
129
-
130
- # file_btn = gr.Button(value="Encode file ✅", interactive=True)
131
- # btn_msg = gr.Textbox(container=False, visible=False)
132
-
133
- # with gr.Row():
134
- # db_list = gr.Markdown(value=get_documents_in_db)
135
- # delete_btn = gr.Button(value="Empty db 🗑️", interactive=True, scale=0)
136
-
137
- # file_btn.click(
138
- # load_file,
139
- # inputs=[input_file],
140
- # outputs=[file_msg, btn_msg, db_list],
141
- # show_progress="full",
142
- # )
143
- # delete_btn.click(empty_db, outputs=[db_list], show_progress="minimal")
144
-
145
- gr.Markdown(""" ### Ask a question """)
146
 
147
  chatbot = gr.Chatbot()
148
  msg = gr.Textbox(placeholder=PLACEHOLDER)
149
  clear = gr.ClearButton([msg, chatbot])
150
-
151
  def respond(message, chat_history):
152
- response = query_engine.query(message)
153
- chat_history.append((message, str(response)))
154
- return chat_history
155
-
156
- msg.submit(respond, [msg, chatbot], [chatbot])
157
 
158
- demo.title = TITLE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  if __name__ == "__main__":
161
- demo.launch(allowed_paths=['/home/user/app/img/','./img/','.'])
 
1
  import os
2
  import json
3
+ import pandas as pd
4
  import gradio as gr
5
  from llama_index.core import (
6
  VectorStoreIndex,
7
  download_loader,
8
  StorageContext
9
  )
 
 
 
 
 
 
 
 
10
 
11
+ import logging
12
+ from dotenv import load_dotenv, find_dotenv
13
  from pathlib import Path
14
 
15
+ # from llama_index.llms.mistralai import MistralAI
16
+ from mistralai.client import MistralClient
17
+ from mistralai.models.chat_completion import ChatMessage
18
+ # from llama_index.embeddings.mistralai import MistralAIEmbedding
19
+ from src.utils_fct import *
20
+
21
  TITLE = "RIZOA-AUCHAN Chatbot Demo"
22
+ DESCRIPTION = "Example of an assistant with Gradio, coupling with function callings and Mistral AI via its API"
23
  PLACEHOLDER = (
24
+ "Vous pouvez me posez une question, appuyer sur Entrée pour valider"
25
  )
26
+ EXAMPLES = ["Comment fait on pour produire du maïs ?", "Rédige moi une lettre pour faire un stage dans une exploitation agricole", "Comment reprendre une exploitation agricole ?"]
27
+ MODEL = "mistral-large-latest"
28
+
29
+ # FILE = Path(__file__).resolve()
30
+ # BASE_PATH = FILE.parents[0]
31
 
32
  load_dotenv()
33
+ ENV_API_KEY = os.environ.get("MISTRAL_API_KEY")
34
+ # HISTORY = pd.read_csv(os.path.join(BASE_PATH, "data/cereal_price.csv"), encoding="latin-1")
35
+ # HISTORY = HISTORY[[HISTORY["memberStateName"]=="France"]]
36
+ # HISTORY['price'] = HISTORY['price'].str.replace(",", ".").astype('float64')
37
 
38
  # Define LLMs
39
+ CLIENT = MistralClient(api_key=ENV_API_KEY)
40
+ # EMBED_MODEL = MistralAIEmbedding(model_name="mistral-embed", api_key=ENV_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  with gr.Blocks() as demo:
43
  with gr.Row():
 
44
  with gr.Column(scale=1):
45
+ gr.Image(value= os.path.join(BASE_PATH, "img/logo_rizoa_auchan.jpg"),#".\img\logo_rizoa_auchan.jpg",
 
 
 
 
46
  height=250,
47
  width=250,
48
+ container=False,
49
  show_download_button=False
50
  )
 
 
 
 
51
  with gr.Column(scale=4):
52
  gr.Markdown(
53
  """
 
57
  """
58
  )
59
 
60
+ gr.Markdown(f""" ### {DESCRIPTION} """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  chatbot = gr.Chatbot()
63
  msg = gr.Textbox(placeholder=PLACEHOLDER)
64
  clear = gr.ClearButton([msg, chatbot])
65
+
66
  def respond(message, chat_history):
67
+ messages = [ChatMessage(role="user", content=message)]
68
+ # response = client.chat(
69
+ # model=MODEL,
70
+ # messages=messages)
 
71
 
72
+ response = forecast(messages)
73
+
74
+ # prompt = f"Reformule le résultat suivant {response}"
75
+ # prompt = [ChatMessage(role="user", content=prompt)]
76
+ # chat_history.append((message, str(response)))
77
+ final_response = CLIENT.chat(
78
+ model=MODEL,
79
+ messages=response
80
+ ).choices[0].message.content
81
+ return "", [[None, None],
82
+ [None, str(final_response)]
83
+ ]
84
+
85
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
86
+
87
+
88
+ # demo.title = TITLE
89
 
90
  if __name__ == "__main__":
91
+ demo.launch()
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
- pypdf
2
  mistralai>=0.1.2
3
- # llama-index-llms-mistralai
4
- # llama-index-embeddings-mistralai
5
- llama-index
6
- gradio
7
- chromadb
 
 
1
  mistralai>=0.1.2
2
+ gradio
3
+ openai
4
+ load_dotenv
 
 
src/utils_fct.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import functools
3
+ import json
4
+ import os
5
+ import pandas as pd
6
+ from prophet import Prophet
7
+ from pathlib import Path
8
+ from mistralai.client import MistralClient
9
+ from mistralai.models.chat_completion import ChatMessage
10
+
11
+ # MODEL
12
+ MODEL = "mistral-large-latest"
13
+ API_KEY=os.environ["MISTRAL_API_KEY"]
14
+ CLIENT = MistralClient(api_key=API_KEY)
15
+
16
+ # PATH
17
+ FILE = Path(__file__).resolve()
18
+ BASE_PATH = FILE.parents[1]
19
+
20
+ HISTORY = pd.read_csv(os.path.join(BASE_PATH, "data/cereal_price.csv"), encoding="latin-1")
21
+ HISTORY = HISTORY[HISTORY["memberStateName"]=="France"]
22
+ HISTORY['price'] = HISTORY['price'].str.replace(",", ".").astype('float64')
23
+
24
+
25
+ def model_predict(week=26):
26
+ """
27
+ Predict future prices using the Prophet model.
28
+
29
+ Parameters:
30
+ - weeks (int): Number of periods to predict into the future (default is 26).
31
+
32
+ Returns:
33
+ - dict: Dictionary containing predicted values and confidence intervals.
34
+ """
35
+
36
+ # Prepare the historical data for the model
37
+ data = HISTORY[['endDate', 'price']]
38
+ data.columns = ['ds', 'y']
39
+
40
+ # Prophet Model
41
+ # Instantiate a Prophet object
42
+ model = Prophet()
43
+
44
+ # Fit the model with historical data
45
+ model.fit(data)
46
+
47
+ # Calculate the current date
48
+ today_date = datetime.now().date()
49
+
50
+ # Calculate the end date for the future DataFrame (specified number of periods from today)
51
+ end_date = today_date + timedelta(weeks=week)
52
+
53
+ # Create a DataFrame with dates starting from today and ending in the specified number of periods
54
+ future_df = pd.date_range(start=today_date, end=end_date, freq='W').to_frame(name='ds').reset_index(drop=True)
55
+
56
+ # Make predictions on the future DataFrame
57
+ forecast = model.predict(future_df)
58
+
59
+ # Return relevant columns from the forecast DataFrame as a dictionary
60
+ result_dict = {
61
+ 'ds': forecast['ds'].tolist(),
62
+ 'yhat_lower': forecast['yhat_lower'].tolist(),
63
+ 'yhat_upper': forecast['yhat_upper'].tolist(),
64
+ 'yhat': forecast['yhat'].tolist()
65
+ }
66
+
67
+ return result_dict
68
+
69
+ model_predict_tool = [{
70
+ "type": "function",
71
+ "function": {
72
+ "name": "model_predict",
73
+ "description": "Predict future prices using the Prophet model.",
74
+ "parameters": {
75
+ "type": "object",
76
+ "properties": {
77
+ "week": {
78
+ "type": "integer",
79
+ "description": "Number of periods to predict into the future (default is 26).",
80
+ },
81
+ },
82
+ "required": ["week"]
83
+ },
84
+ },
85
+ }]
86
+
87
+ names_to_functions = {
88
+ 'model_predict': functools.partial(model_predict),
89
+ }
90
+
91
+ # messages = [
92
+ # ChatMessage(role="user", content="Predict future prices using the Prophet model for 4 weeks in the future")
93
+ # ]
94
+
95
+ def forecast(messages
96
+ ):
97
+ response = CLIENT.chat(
98
+ model=MODEL,
99
+ messages=messages,
100
+ tools=model_predict_tool,
101
+ tool_choice="auto"
102
+ )
103
+
104
+ tool_call = response.choices[0].message.tool_calls[0]
105
+ function_name = tool_call.function.name
106
+ function_params = json.loads(tool_call.function.arguments)
107
+ function_result = names_to_functions[function_name](**function_params)
108
+ date = function_result["ds"][-1]
109
+ lower = function_result["yhat_lower"][-1]
110
+ upper = function_result["yhat_upper"][-1]
111
+ prediction = function_result["yhat"][-1]
112
+
113
+ messages.append(ChatMessage(role="tool",
114
+ name=function_name,
115
+ content=str({"date" : date, "prix_minimum": lower, "prix_maximum": upper, "prix_estimé": prediction})
116
+ ))
117
+ return messages