amurienne committed on
Commit
6a3db72
·
verified ·
1 Parent(s): 760d624

Initial chatbot

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForSeq2SeqLM,
6
+ AutoModelForCausalLM,
7
+ BitsAndBytesConfig,
8
+ pipeline
9
+ )
10
+
11
+ import torch
12
+
13
# CHAT MODEL

# Instruction-tuned Llama used as the underlying (French-speaking) chat model.
chat_modelcard = 'meta-llama/Llama-3.2-3B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(chat_modelcard)
# NOTE(review): 4-bit bitsandbytes quantization is requested together with
# device_map='cpu'; bitsandbytes 4-bit normally requires a CUDA device —
# confirm this actually loads on CPU-only hardware.
model = AutoModelForCausalLM.from_pretrained(chat_modelcard, quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16), device_map='cpu')

# Sampling text-generation pipeline: temperature 0.5, input/output truncated at
# 512 tokens, and only newly generated text returned (return_full_text=False).
chat_pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer, do_sample=True, temperature=0.5, truncation=True, max_length=512, return_full_text=False)
21
+
22
# TRANSLATION MODELS

# m2m100 fine-tunes: "gallek" is the forward (French -> Breton) model,
# "kellag" the backward (Breton -> French) one.
fw_modelcard = "amurienne/gallek-m2m100"
bw_modelcard = "amurienne/kellag-m2m100"

fw_model = AutoModelForSeq2SeqLM.from_pretrained(fw_modelcard)
fw_tokenizer = AutoTokenizer.from_pretrained(fw_modelcard)

# Forward pipeline: French (fr) -> Breton (br), CPU, outputs capped at 400 tokens.
fw_translation_pipeline = pipeline("translation", model=fw_model, tokenizer=fw_tokenizer, src_lang='fr', tgt_lang='br', max_length=400, device="cpu")

bw_model = AutoModelForSeq2SeqLM.from_pretrained(bw_modelcard)
bw_tokenizer = AutoTokenizer.from_pretrained(bw_modelcard)

# Backward pipeline: Breton (br) -> French (fr), CPU, outputs capped at 400 tokens.
bw_translation_pipeline = pipeline("translation", model=bw_model, tokenizer=bw_tokenizer, src_lang='br', tgt_lang='fr', max_length=400, device="cpu")
36
+
37
# translation function
def translate(text, forward: bool):
    """Translate between French and Breton.

    forward=True sends *text* through the fr->br pipeline, forward=False
    through the br->fr pipeline. Each call prepends the task prefix the
    corresponding model expects and returns the translated string.
    """
    if forward:
        chosen_pipeline = fw_translation_pipeline
        prefix = "traduis de français en breton: "
    else:
        chosen_pipeline = bw_translation_pipeline
        prefix = "treiñ eus ar galleg d'ar brezhoneg: "
    outputs = chosen_pipeline(prefix + text)
    return outputs[0]['translation_text']
43
+
44
# answer function
def answer(text):
    """Run the chat pipeline on a fully formatted prompt and return the
    generated continuation text."""
    generations = chat_pipeline(text, chat_template=None)
    return generations[0]['generated_text']
47
+
48
def format_prompt_with_history(message, native_chat_history):
    """Build a Llama-3-style prompt from the French conversation history.

    Args:
        message: current user message (French), appended as the final user
            turn. Bug fix: previously this parameter was silently ignored,
            so the chat model never saw what the user actually said.
        native_chat_history: list of {'role': ..., 'content': ...} dicts
            holding the prior French-language turns.

    Returns:
        The prompt string, ending with an open assistant header so the model
        generates the assistant's reply next.
    """
    # format the conversation history
    prompt = ""
    for interaction in native_chat_history:
        prompt += f"<|start_header_id|>{interaction['role']}<|end_header_id|>\n{interaction['content']}<|eot_id|>\n"

    # add the instruction plus the current user message (the original code
    # dropped `message` entirely)
    prompt += (
        "<|start_header_id|>user<|end_header_id|>\n"
        "tu es un assistant francophone. Répond en une seule phrase sans formattage.\n"
        f"{message}<|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )

    return prompt
58
+
59
# maximum number of interactions (user+assistant pairs) to keep in history
max_history_length = 3

# keep a hidden model "native" language chat history — the French-side mirror
# of the Breton conversation shown in the UI; mutated by respond() and clear()
native_chat_history = []
64
+
65
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Visible chat UI — messages shown to the user are in Breton.
    chatbot = gr.Chatbot(label="Breton Chatbot (Translation based)", type="messages")
    msg = gr.Textbox(label='User Input')

    def clear(chat_history):
        """
        Handles clearing chat: empties both the visible (Breton) history
        and the hidden French mirror, in place.
        """
        chat_history.clear()
        native_chat_history.clear()

    chatbot.clear(clear, inputs=[chatbot])

    def respond(message, chat_history):
        """
        Handles bot response generation: Breton in -> French -> chat model
        -> French out -> Breton out, keeping both histories in sync.
        """

        # reassigned below when trimming, hence the global declaration
        global native_chat_history

        # Breton user message -> French, so the French chat model can answer it.
        fr_message = translate(message, forward=False)
        print(f"user fr -> {fr_message}")

        prompt = format_prompt_with_history(fr_message, native_chat_history)

        bot_fr_message = answer(prompt)
        print(f"bot fr -> {bot_fr_message}")
        # French answer -> Breton for display.
        bot_br_message = translate( bot_fr_message, forward=True)
        print(f"bot br -> {bot_br_message}")

        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": bot_br_message})

        native_chat_history.append({"role": "user", "content": fr_message})
        native_chat_history.append({"role": "assistant", "content": bot_fr_message})

        # limit the history length (x2 because each interaction is a
        # user+assistant message pair)
        if len(chat_history) > max_history_length * 2:
            chat_history = chat_history[-max_history_length * 2:]
            native_chat_history = native_chat_history[-max_history_length * 2:]

        # clear the textbox and push the updated visible history to the Chatbot
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    demo.launch()