nicolasmery committed
Commit 6298526 · verified · 1 Parent(s): 1526fbc

Update app.py

Files changed (1)
  1. app.py +56 -28
app.py CHANGED
@@ -1,7 +1,46 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer, AutoModel
+import torch
+import numpy as np
+import os
 
+# --- 1. Get the Hugging Face token from an environment variable ---
+hf_token = os.environ.get("HF_TOKEN")
+if hf_token is None:
+    raise ValueError("You must set the HF_TOKEN environment variable to your Hugging Face token.")
 
+# --- 2. Load SteelBERT for embeddings ---
+steelbert_tokenizer = AutoTokenizer.from_pretrained(
+    "MGE-LLMs/SteelBERT", token=hf_token
+)
+steelbert_model = AutoModel.from_pretrained(
+    "MGE-LLMs/SteelBERT", token=hf_token
+).eval()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+steelbert_model.to(device)
+
+def embed(text):
+    inputs = steelbert_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+    with torch.no_grad():
+        outputs = steelbert_model(**inputs, output_hidden_states=True)
+    return outputs.hidden_states[-1][:, 0, :].cpu().numpy()[0]  # [CLS] vector of the last layer
+
+# --- 3. Document base (example; replace with your own technical documents) ---
+docs = {
+    "doc1": "L’acier X42 a une résistance à la traction de 415 MPa.",
+    "doc2": "L’acier inoxydable 304 est résistant à la corrosion et à l’oxydation."
+}
+doc_embeddings = {k: embed(v) for k, v in docs.items()}
+
+def search_best_doc(question):
+    q_emb = embed(question)
+    def cosine(a, b): return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+    best_doc = max(docs, key=lambda k: cosine(q_emb, doc_embeddings[k]))
+    return best_doc  # return the key so callers can look up docs[best_doc]
+
+# --- 4. Response function using Mistral 7B Instruct ---
 def respond(
     message,
     history: list[dict[str, str]],
@@ -11,19 +50,21 @@ def respond(
     top_p,
     hf_token: gr.OAuthToken,
 ):
-    """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-    """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
+    client = InferenceClient(token=hf_token.token, model="mistralai/Mistral-7B-Instruct-v0.2")
 
-    messages = [{"role": "system", "content": system_message}]
+    # Retrieve the most relevant context with SteelBERT
+    best_doc = search_best_doc(message)
+    context = docs[best_doc]
 
+    # Build the prompt
+    messages = [{"role": "system", "content": system_message}]
     messages.extend(history)
-
-    messages.append({"role": "user", "content": message})
+    messages.append({
+        "role": "user",
+        "content": f"Question: {message}\nContexte: {context}\nRéponds clairement en français :"
+    })
 
     response = ""
-
    for message in client.chat_completion(
         messages,
         max_tokens=max_tokens,
@@ -31,32 +72,20 @@ def respond(
         temperature=temperature,
         top_p=top_p,
     ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
+        if len(message.choices) and message.choices[0].delta.content:
+            token = message.choices[0].delta.content
+            response += token
+            yield response
 
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
+# --- 5. Gradio interface ---
 chatbot = gr.ChatInterface(
     respond,
     type="messages",
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Textbox(value="Tu es un assistant spécialisé en métallurgie et en acier.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
 )
 
@@ -65,6 +94,5 @@ with gr.Blocks() as demo:
     gr.LoginButton()
     chatbot.render()
 
-
 if __name__ == "__main__":
     demo.launch()
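The retrieval step can be sanity-checked on its own, before the Gradio app is involved. A minimal sketch, assuming the embed, search_best_doc, docs, and doc_embeddings definitions from the diff above; the query string is a hypothetical example:

import numpy as np

query = "Quelle est la résistance à la traction de l'acier X42 ?"  # hypothetical test query
best_key = search_best_doc(query)
print(best_key, "->", docs[best_key])  # expected to pick doc1

# Inspect the raw similarities to see how decisive the match is.
q_emb = embed(query)
for key, emb in doc_embeddings.items():
    sim = np.dot(q_emb, emb) / (np.linalg.norm(q_emb) * np.linalg.norm(emb))
    print(f"{key}: cosine similarity = {sim:.3f}")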
 
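With two documents the per-query loop in search_best_doc is fine, but for a larger corpus the cosine comparisons collapse into a single matrix product over embeddings that are L2-normalized once, up front. A sketch under the same assumptions; search_best_doc_vectorized is a hypothetical name, not part of the commit:

import numpy as np

keys = list(doc_embeddings.keys())
matrix = np.stack([doc_embeddings[k] for k in keys])
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)  # normalize each row once

def search_best_doc_vectorized(question):
    q = embed(question)
    q /= np.linalg.norm(q)
    scores = matrix @ q  # all cosine similarities in one matmul
    return keys[int(np.argmax(scores))]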
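Because doc_embeddings is rebuilt at import time, every Space restart pays the full SteelBERT embedding cost. One optional mitigation, sketched here with an arbitrary cache path (not part of the commit), is to persist the embeddings to disk with numpy:

import os
import numpy as np

CACHE_PATH = "doc_embeddings.npz"  # hypothetical cache location

if os.path.exists(CACHE_PATH):
    cached = np.load(CACHE_PATH)
    doc_embeddings = {k: cached[k] for k in cached.files}
else:
    doc_embeddings = {k: embed(v) for k, v in docs.items()}
    np.savez(CACHE_PATH, **doc_embeddings)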