|
import os
import time
import json
from typing import Dict, List, Any

import torch
import pandas as pd
from transformers import PreTrainedTokenizerFast, GenerationConfig

from precious3_gpt_multi_model import Custom_MPTForCausalLM

# Pre-computed gene embedding tables used for the multi-modal inputs.
emb_gpt_genes = pd.read_pickle('./multi-modal-data/emb_gpt_genes.pickle')
emb_hgt_genes = pd.read_pickle('./multi-modal-data/emb_hgt_genes.pickle')
|
|
def create_prompt(prompt_config):
    """Assemble the model prompt string from a config dict of instruction, gene lists and covariates."""
    prompt = "[BOS]"
    multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3

    for k, v in prompt_config.items():
        if k == 'instruction':
            prompt += f"<{v}>"
        elif k in ('up', 'down'):
            # Up- and down-regulated gene lists are preceded by the modality placeholder tokens.
            prompt += f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
        else:
            prompt += f'<{k}>{v}</{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
    return prompt
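
# A minimal illustration of the layout create_prompt produces (hypothetical keys and gene symbols):
# create_prompt({'instruction': 'diff2compound', 'up': ['TP53', 'EGFR'], 'down': ['MYC'], 'tissue': 'lung'})
# -> "[BOS]<diff2compound>"
#    + "<modality0><modality1><modality2><modality3>" * 3 + "<up>TP53 EGFR </up>"
#    + "<modality0><modality1><modality2><modality3>" * 3 + "<down>MYC </down>"
#    + "<tissue>lung</tissue>"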
|
|
|
def custom_generate(input_ids,
                    acc_embs_up_kg_mean,
                    acc_embs_down_kg_mean,
                    acc_embs_up_txt_mean,
                    acc_embs_down_txt_mean,
                    device,
                    max_new_tokens,
                    num_return_sequences,
                    unique_compounds,
                    temperature=0.8,
                    top_p=0.2, top_k=3550, n_next_tokens=50):
    """Sampling loop that injects the four modality embeddings on every forward pass.

    Expects `model` and `tokenizer` to be available as module-level globals.
    """
    torch.manual_seed(137)

    # Add a batch dimension to each pre-computed embedding and move it to the target device.
    modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device)
    modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
    modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device)
    modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)

    outputs = []
    next_token_compounds = []

    for _ in range(num_return_sequences):
        start_time = time.time()
        generated_sequence = []
        current_token = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass; the modality embeddings are substituted at their placeholder token ids.
            logits = model.forward(input_ids=current_token,
                                   modality0_emb=modality0_emb,
                                   modality0_token_id=62191,
                                   modality1_emb=modality1_emb,
                                   modality1_token_id=62192,
                                   modality2_emb=modality2_emb,
                                   modality2_token_id=62193,
                                   modality3_emb=modality3_emb,
                                   modality3_token_id=62194)[0]

            if temperature != 1.0:
                logits = logits / temperature

            # Top-p / top-k masking of the logits before sampling.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            if top_k > 0:
                sorted_indices_to_remove[..., top_k:] = 1

            inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
            logits = logits.where(sorted_indices_to_remove, inf_tensor)

            # If the last token so far is <drug>, record the top candidate compound tokens.
            if current_token[0][-1] == tokenizer.encode('<drug>')[0]:
                next_token_compounds.append(torch.topk(
                    torch.softmax(logits, dim=-1)[0][len(current_token[0]) - 1, :].flatten(),
                    n_next_tokens).indices)

            # Sample the next token from the distribution at the last position.
            next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0],
                                           num_samples=1)[len(current_token[0]) - 1, :].unsqueeze(0)

            generated_sequence.append(next_token.item())

            # Stop generation if an end token is generated.
            if next_token == tokenizer.eos_token_id:
                break

            current_token = torch.cat((current_token, next_token), dim=-1)

        print(time.time() - start_time)
        outputs.append(generated_sequence)

    return outputs, next_token_compounds
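
# A minimal sketch of how custom_generate might be driven (assumes `model` and `tokenizer` exist at
# module level and the numpy embedding arrays have been computed; names here are illustrative only):
# input_ids = tokenizer(create_prompt(config_data), return_tensors="pt")["input_ids"].to("cuda")
# sequences, drug_candidates = custom_generate(input_ids,
#                                              up_kg_emb, down_kg_emb, up_txt_emb, down_txt_emb,
#                                              device="cuda", max_new_tokens=256,
#                                              num_return_sequences=1,
#                                              unique_compounds=unique_compounds_p3)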
|
|
|
|
|
def get_predicted_compounds(input_ids, generation_output, tokenizer, p3_compounds):
    """Take the top-50 tokens scored right after <drug> and keep those that are known P3 compounds."""
    # Position of the <drug> token within the generated part of the sequence;
    # the compound itself is scored at the following position.
    id_4_drug_token = list(generation_output.sequences[0][len(input_ids[0]):]).index(
        tokenizer.convert_tokens_to_ids(['<drug>'])[0])
    id_4_drug_token += 1
    print('Token index at which the compound is predicted: ', id_4_drug_token)

    values, indices = torch.topk(generation_output["scores"][id_4_drug_token].view(-1), k=50)
    indices_decoded = tokenizer.decode(indices, skip_special_tokens=True)

    predicted_compound = indices_decoded.split(' ')
    predicted_compound = [i.strip() for i in predicted_compound]

    # Keep only predictions that are known compounds, preserving the model's ranking.
    valid_compounds = sorted(set(predicted_compound) & set(p3_compounds), key=predicted_compound.index)
    print(f"Model predicted {len(predicted_compound)} tokens. Valid compounds: {len(valid_compounds)}")
    return valid_compounds
|
|
|
|
|
class EndpointHandler:
    def __init__(self, path=""):
        # Load the multi-modal MPT model in bfloat16 together with its fast tokenizer.
        self.model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to('cuda')
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(path, "tokenizer.json"),
                                                 unk_token="[UNK]",
                                                 pad_token="[PAD]",
                                                 eos_token="[EOS]",
                                                 bos_token="[BOS]")
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id

        # Known compound entities, used to filter the model's <drug> predictions.
        unique_entities_p3 = pd.read_csv(os.path.join(path, 'all_entities_with_type.csv'))
        self.unique_compounds_p3 = [i.strip() for i in
                                    unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (dict):
                The payload with the text prompt, generation parameters and mode.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}
        mode = data.pop('mode', 'diff2compound')

        # Only the diff2compound generation config is available; any other mode falls back to it.
        if mode == 'diff2compound':
            with open('./generation-configs/diff2compound.json', 'r') as f:
                config_data = json.load(f)
        else:
            with open('./generation-configs/diff2compound.json', 'r') as f:
                config_data = json.load(f)

        prompt = create_prompt(config_data)

        inputs = self.tokenizer(inputs, return_tensors="pt")
        input_ids = inputs["input_ids"].to('cuda')

        generation_config = GenerationConfig(**parameters,
                                             pad_token_id=self.tokenizer.pad_token_id,
                                             num_return_sequences=1)

        # Generate up to the model's maximum sequence length.
        max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])

        torch.manual_seed(137)

        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens
            )

        if mode == 'diff2compound':
            predicted_compounds = get_predicted_compounds(input_ids=input_ids,
                                                          generation_output=generation_output,
                                                          tokenizer=self.tokenizer,
                                                          p3_compounds=self.unique_compounds_p3)
            output = {'output': predicted_compounds, "mode": mode, 'message': "Done!"}
        else:
            output = {'output': [None], "mode": mode, 'message': "Set mode"}
        return output
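
# A minimal local usage sketch (assumes the model weights, tokenizer.json, all_entities_with_type.csv
# and generation-configs/ live under the given path; payload keys follow __call__ above):
# if __name__ == "__main__":
#     handler = EndpointHandler(path="./")
#     result = handler({"inputs": "[BOS]<diff2compound>...",
#                       "parameters": {"temperature": 0.8, "top_p": 0.2, "top_k": 3550},
#                       "mode": "diff2compound"})
#     print(result["output"])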