update
Browse files
app.py
CHANGED
@@ -18,7 +18,10 @@ model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_
|
|
18 |
|
19 |
# List of available tasks
|
20 |
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']
|
21 |
-
|
|
|
|
|
|
|
22 |
def preprocess_response(response, mask_token="[MASK]"):
|
23 |
"""Extracts the response after the [MASK] token."""
|
24 |
if mask_token in response:
|
@@ -48,8 +51,8 @@ def generate(dna_sequence, task_type, sample_num=1):
|
|
48 |
|
49 |
response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
|
50 |
reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
|
51 |
-
|
52 |
-
return
|
53 |
|
54 |
def extract_label(message, task_type):
|
55 |
"""Extracts the prediction label from the model's response."""
|
@@ -65,7 +68,7 @@ interface = gr.Interface(
|
|
65 |
gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
|
66 |
gr.Dropdown(choices=tasks, label="Select Task Type"),
|
67 |
],
|
68 |
-
outputs=gr.Textbox(label="Predicted
|
69 |
title="Omni-DNA Multitask Prediction",
|
70 |
description="Select a DNA-related task and input a sequence to generate function predictions.",
|
71 |
)
|
|
|
18 |
|
19 |
# List of available tasks
|
20 |
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']
|
21 |
+
mapping = {'1':'It is a',
|
22 |
+
'0':'It is not a',
|
23 |
+
'No valid prediction':'Cannot be determined whether or not it is a',
|
24 |
+
}
|
25 |
def preprocess_response(response, mask_token="[MASK]"):
|
26 |
"""Extracts the response after the [MASK] token."""
|
27 |
if mask_token in response:
|
|
|
51 |
|
52 |
response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
|
53 |
reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
|
54 |
+
pred = extract_label(reply, task_type)
|
55 |
+
return f"{mapping[pred]} {task_type}"
|
56 |
|
57 |
def extract_label(message, task_type):
|
58 |
"""Extracts the prediction label from the model's response."""
|
|
|
68 |
gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
|
69 |
gr.Dropdown(choices=tasks, label="Select Task Type"),
|
70 |
],
|
71 |
+
outputs=gr.Textbox(label="Predicted Type"),
|
72 |
title="Omni-DNA Multitask Prediction",
|
73 |
description="Select a DNA-related task and input a sequence to generate function predictions.",
|
74 |
)
|