satyaiyer committed
Commit 87f4bdc · verified · 1 parent: 642595e

Update app.py

Files changed (1): app.py (+30 -19)
app.py CHANGED
@@ -3,22 +3,30 @@ import gradio as gr
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
 import os
+import re
 
-# Load the model (flan-t5-base) and tokenizer
+# Load the model and tokenizer
 model_name = "google/flan-t5-base"
-hf_token = os.environ.get("HF_TOKEN")  # Ensure your token is securely set as a secret
+hf_token = os.environ.get("HF_TOKEN")  # Set as a secret in Hugging Face Space settings
 
 tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=hf_token)
 
 # Move the model to CPU (or GPU if available)
-model.to("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
 
-# Function to generate the prompt for MT QE
+# Function to generate a clean prompt
 def generate_prompt(original, translation):
-    return f"### Task: Machine Translation Quality Estimation\n\nSource: {original}\nTranslation: {translation}\n\nScore (0-1):"
+    return (
+        f"Rate the quality of this translation from 0 (poor) to 1 (excellent). "
+        f"Only respond with a number.\n\n"
+        f"Source: {original}\n"
+        f"Translation: {translation}\n"
+        f"Score:"
+    )
 
-# Function to predict quality scores from the file
+# Main prediction function
 def predict_scores(file):
     df = pd.read_csv(file.name, sep="\t")
     scores = []
@@ -26,32 +34,35 @@ def predict_scores(file):
     for _, row in df.iterrows():
         prompt = generate_prompt(row["original"], row["translation"])
 
-        # Tokenize and generate outputs
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        # Tokenize and send to model
+        inputs = tokenizer(prompt, return_tensors="pt").to(device)
         outputs = model.generate(**inputs, max_new_tokens=10)
-
-        # Decode and extract the score from the response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Extract float value (simple way to extract score from response)
-        score = response.split("Score")[-1].strip()
-        try:
-            score_val = float(score.split()[0])
-        except:
-            score_val = -1  # Fallback in case of error
+        # Debug print (optional)
+        print("Response:", response)
+
+        # Extract numeric score using regex
+        match = re.search(r"\b([01](?:\.\d+)?)\b", response)
+        if match:
+            score_val = float(match.group(1))
+            score_val = max(0, min(score_val, 1))  # Clamp between 0 and 1
+        else:
+            score_val = -1  # fallback if model output is invalid
 
         scores.append(score_val)
 
     df["predicted_score"] = scores
     return df
 
-# Set up the Gradio interface
+# Gradio UI
 iface = gr.Interface(
     fn=predict_scores,
     inputs=gr.File(label="Upload dev.tsv"),
     outputs=gr.Dataframe(label="QE Output with Predicted Score"),
-    title="MT QE with Google FLAN-T5-Base",
+    title="MT QE with FLAN-T5-Base",
+    description="Upload a dev.tsv file with columns: 'original' and 'translation'."
 )
 
-# Launch the Gradio interface
+# Launch app
 iface.launch()
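
Note on the new extraction logic: the regex replaces the old brittle split("Score") parsing and only accepts numbers of the form 0, 1, or 0.x. A minimal standalone check of the pattern on a few plausible responses (the sample strings are assumptions, not actual FLAN-T5 outputs):

import re

pattern = re.compile(r"\b([01](?:\.\d+)?)\b")  # same pattern as in app.py

for response in ["0.85", "Score: 1", "no number here"]:
    match = pattern.search(response)
    # -1 mirrors the fallback used in predict_scores when no score is found
    score = float(match.group(1)) if match else -1
    print(f"{response!r} -> {score}")

Expected output: 0.85, 1.0, and -1 respectively; the clamp in app.py then keeps any parsed value inside [0, 1].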
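
For local testing, predict_scores expects a tab-separated file with "original" and "translation" columns (per pd.read_csv(..., sep="\t") and the new description string). A quick sketch for generating such a file; the two example sentence pairs are placeholders, not data from the repo:

import pandas as pd

# Build a minimal dev.tsv with the two columns predict_scores reads
pd.DataFrame(
    {
        "original": ["The weather is nice today.", "She reads a book."],
        "translation": ["Il fait beau aujourd'hui.", "Elle lit un livre."],
    }
).to_csv("dev.tsv", sep="\t", index=False)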