zehui127 committed
Commit 3606fab · 1 Parent(s): 26ab4df
Files changed (1)
  1. app.py +7 -4
app.py CHANGED
@@ -18,7 +18,10 @@ model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_
 
 # List of available tasks
 tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']
-
+mapping = {'1':'It is a',
+           '0':'It is not a',
+           'No valid prediction':'Cannot be determined whether or not it is a',
+           }
 def preprocess_response(response, mask_token="[MASK]"):
     """Extracts the response after the [MASK] token."""
     if mask_token in response:
@@ -48,8 +51,8 @@ def generate(dna_sequence, task_type, sample_num=1):
 
     response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
     reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
-
-    return extract_label(reply, task_type)
+    pred = extract_label(reply, task_type)
+    return f"{mapping[pred]} {task_type}"
 
 def extract_label(message, task_type):
     """Extracts the prediction label from the model's response."""
@@ -65,7 +68,7 @@ interface = gr.Interface(
         gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
         gr.Dropdown(choices=tasks, label="Select Task Type"),
     ],
-    outputs=gr.Textbox(label="Predicted Function"),
+    outputs=gr.Textbox(label="Predicted Type"),
     title="Omni-DNA Multitask Prediction",
     description="Select a DNA-related task and input a sequence to generate function predictions.",
 )
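
For reference, the effect of this commit is that generate() no longer returns the raw label ('1', '0', or 'No valid prediction') but a readable sentence built from the new mapping dict. Below is a minimal sketch of that formatting step only; the model call and extract_label() are left out, and the helper name format_prediction is hypothetical (in app.py the formatting happens inside generate()).

mapping = {'1': 'It is a',
           '0': 'It is not a',
           'No valid prediction': 'Cannot be determined whether or not it is a',
           }

def format_prediction(raw_label, task_type):
    # Mirrors the new return statement in generate(): f"{mapping[pred]} {task_type}"
    return f"{mapping[raw_label]} {task_type}"

print(format_prediction('1', 'H3K4me3'))                 # It is a H3K4me3
print(format_prediction('0', 'H3'))                      # It is not a H3
print(format_prediction('No valid prediction', 'H4ac'))  # Cannot be determined whether or not it is a H4ac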