acediaaa commited on
Commit
7679301
·
verified ·
1 Parent(s): 661c67a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -1,42 +1,43 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
 
4
 
5
- # Xác định device: sử dụng GPU nếu có, ngược lại sử dụng CPU
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
-
8
- # Tải tokenizer và model từ repository "vinai/bartpho-syllable"
9
  tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-syllable")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("vinai/bartpho-syllable")
 
 
 
11
  model.to(device)
12
 
13
- def generate_answer(question):
14
  model.eval()
15
- # Tạo đầu vào cho model: thêm tiền tố "hỏi:" để phù hợp với định dạng huấn luyện (nếu có)
16
  input_text = "hỏi: " + question
 
 
17
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
18
- inputs = {k: v.to(device) for k, v in inputs.items()}
 
19
 
20
  with torch.no_grad():
 
21
  outputs = model.generate(
22
- input_ids=inputs["input_ids"],
23
- attention_mask=inputs["attention_mask"],
24
  max_length=512, # Độ dài tối đa của câu trả lời
25
- num_beams=5, # Sử dụng beam search với 5 beams
26
- repetition_penalty=1.2, # Phạt lặp lại từ (giá trị > 1 để hạn chế trùng lặp)
27
- no_repeat_ngram_size=3, # Tránh lặp lại các cụm từ có độ dài 3 token
28
- early_stopping=True # Dừng sớm nếu hình đã sinh đủ văn bản
29
  )
30
 
31
- # Giải mã kết quả từ token sang chuỗi văn bản
32
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
33
  return answer
34
 
35
- # Tạo giao diện Gradio cho ứng dụng
36
- demo = gr.Interface(
37
- fn=generate_answer,
38
- inputs=gr.Textbox(label="Nhập câu hỏi"),
39
- outputs=gr.Textbox(label="Câu trả lời")
40
- )
41
 
42
- demo.launch()
 
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
7
  tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-syllable")
8
  model = AutoModelForSeq2SeqLM.from_pretrained("vinai/bartpho-syllable")
9
+ model_file = hf_hub_download(repo_id="acediaaa/VietTour_0", filename="viettour_model_bartpho.pth")
10
+ state_dict = torch.load(model_file, map_location=torch.device('cpu'))
11
+ model.load_state_dict(state_dict)
12
  model.to(device)
13
 
14
+ def generate_answer(question, model, tokenizer, device):
15
  model.eval()
 
16
  input_text = "hỏi: " + question
17
+
18
+ # Tokenize câu hỏi
19
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
20
+ input_ids = inputs.input_ids.to(device)
21
+ attention_mask = inputs.attention_mask.to(device)
22
 
23
  with torch.no_grad():
24
+ # Sinh câu trả lời từ mô hình
25
  outputs = model.generate(
26
+ input_ids=input_ids,
27
+ attention_mask=attention_mask,
28
  max_length=512, # Độ dài tối đa của câu trả lời
29
+ num_beams=5, # Beam search với 5 beam
30
+ repetition_penalty=1.2, # Phạt lặp từ (giá trị > 1.0 để tránh lặp lại)
31
+ no_repeat_ngram_size=3, # Tránh lặp lại các cụm từ dài 3 từ
32
+ early_stopping=True # Dừng sớm nếu sinh văn bản đủ tốt
33
  )
34
 
35
+ # Giải mã câu trả lời từ token thành chuỗi văn bản
36
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
  return answer
38
 
39
+ def run(ques):
40
+ return generate_answer(ques, model, tokenizer, device)
 
 
 
 
41
 
42
+ demo = gr.Interface(fn=run, inputs=gr.Textbox(label="Nhập câu hỏi"), outputs=gr.Textbox(label="Câu trả lời"))
43
+ demo.launch()