from fastapi import FastAPI
from typing import Dict, List, Any, Tuple
import pickle
import math
import re
import gc

from utils import split
import torch
from build_vocab import WordVocab
from pretrain_trfm import TrfmSeq2seq
from transformers import T5EncoderModel, T5Tokenizer
import numpy as np
import pydantic

app = FastAPI()

# Load the ProtT5 tokenizer and encoder once at module import time.
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
model = T5EncoderModel.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc")


class Item(pydantic.BaseModel):
    sequence: str
    smiles: str


@app.post("/predict")
def predict(item: Item):
    endpoint_handler = EndpointHandler()
    result = endpoint_handler.predict({
        "inputs": {
            "sequence": item.sequence,
            "smiles": item.smiles,
        }
    })
    return result


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = tokenizer
        self.model = model

        # paths to the vocab content and the pretrained SMILES Transformer
        vocab_content_path = "vocab_content.txt"
        trfm_path = "trfm_12_23000.pkl"

        # load the vocab content instead of the pickle file
        with open(vocab_content_path, "r", encoding="utf-8") as f:
            vocab_content = f.read().strip().split("\n")

        # build the vocab and load the SMILES Transformer weights
        self.vocab = WordVocab(vocab_content)
        self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.trfm.load_state_dict(torch.load(trfm_path, map_location=device))
        self.trfm.eval()

        # paths to the pretrained regression models
        self.Km_model_path = "Km.pkl"
        self.Kcat_model_path = "Kcat.pkl"
        self.Kcat_over_Km_model_path = "Kcat_over_Km.pkl"

        # vocab indices
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3

    def predict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Function where the endpoint logic is implemented.

        Args:
            data (Dict[str, Any]): The input data for the endpoint. It contains
                a single key "inputs", which is a dictionary with the
                following keys:
                - sequence (str): Amino acid sequence.
                - smiles (str): SMILES representation of the molecule.

        Returns:
            Dict[str, Any]: The output data for the endpoint, with the
                following keys:
                - Km (float): predicted Km value.
                - Kcat (float): predicted Kcat value.
                - Vmax (float): predicted Vmax value.
        """
        sequence = data["inputs"]["sequence"]
        smiles = data["inputs"]["smiles"]

        # embed the protein and the molecule, then fuse the two vectors
        seq_vec = self.Seq_to_vec(sequence)
        smiles_vec = self.smiles_to_vec(smiles)
        fused_vector = np.concatenate((smiles_vec, seq_vec), axis=1)

        pred_Km = self.predict_feature_using_model(
            fused_vector, self.Km_model_path)
        pred_Kcat = self.predict_feature_using_model(
            fused_vector, self.Kcat_model_path)
        pred_Vmax = self.predict_feature_using_model(
            fused_vector, self.Kcat_over_Km_model_path)

        result = {
            "Km": pred_Km,
            "Kcat": pred_Kcat,
            "Vmax": pred_Vmax,
        }
        return result

    def predict_feature_using_model(self, X: np.ndarray, model_path: str) -> float:
        """
        Predict a kinetic parameter for the fused feature vector using a
        pickled, pretrained regression model.
        """
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        pred_feature = model.predict(X)
        # the regressor predicts a log10-scaled value; map it back to linear scale
        pred_feature_pow = math.pow(10, pred_feature[0])
        return pred_feature_pow
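    # Shape sketch for the feature pipeline above (an assumption about the
    # referenced checkpoints, not something verified against them):
    #   Seq_to_vec(sequence)   -> (1, d_seq)           mean-pooled ProtT5 embedding
    #   smiles_to_vec(smiles)  -> (1, d_smiles)        SMILES Transformer fingerprint
    #   fused_vector           -> (1, d_smiles + d_seq) via np.concatenate(..., axis=1)
    # The pickled regressors (Km.pkl, Kcat.pkl, Kcat_over_Km.pkl) are expected
    # to accept a 2-D array of this fused width.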
""" Smiles = [Smiles] x_split = [split(sm) for sm in Smiles] xid, xseg = self.get_array(x_split, self.vocab) X = self.trfm.encode(torch.t(xid)) return X def get_inputs(self, sm: str, vocab: WordVocab) -> Tuple[List[int], List[int]]: """ Convert smiles to tensor """ seq_len = len(sm) sm = sm.split() ids = [vocab.stoi.get(token, self.unk_index) for token in sm] ids = [self.sos_index] + ids + [self.eos_index] seg = [1]*len(ids) padding = [self.pad_index]*(seq_len - len(ids)) ids.extend(padding), seg.extend(padding) return ids, seg def get_array(self, smiles: list[str], vocab: WordVocab) -> Tuple[torch.tensor, torch.tensor]: """ Convert smiles to tensor """ x_id, x_seg = [], [] for sm in smiles: a,b = self.get_inputs(sm, vocab) x_id.append(a) x_seg.append(b) return torch.tensor(x_id), torch.tensor(x_seg) def Seq_to_vec(self, Sequence: str) -> np.array: """ Function to convert the sequence to a vector using the pretrained model. """ Sequence = [Sequence] sequences_Example = [] for i in range(len(Sequence)): zj = '' for j in range(len(Sequence[i]) - 1): zj += Sequence[i][j] + ' ' zj += Sequence[i][-1] sequences_Example.append(zj) gc.collect() print(torch.cuda.is_available()) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') self.model = self.model.to(device) self.model = self.model.eval() features = [] for i in range(len(sequences_Example)): sequences_Example_i = sequences_Example[i] sequences_Example_i = [re.sub(r"[UZOB]", "X", sequences_Example_i)] ids = self.tokenizer.batch_encode_plus(sequences_Example_i, add_special_tokens=True, padding=True) input_ids = torch.tensor(ids['input_ids']).to(device) attention_mask = torch.tensor(ids['attention_mask']).to(device) with torch.no_grad(): embedding = self.model(input_ids=input_ids, attention_mask=attention_mask) embedding = embedding.last_hidden_state.cpu().numpy() for seq_num in range(len(embedding)): seq_len = (attention_mask[seq_num] == 1).sum() seq_emd = embedding[seq_num][:seq_len - 1] features.append(seq_emd) features_normalize = np.zeros([len(features), len(features[0][0])], dtype=float) for i in range(len(features)): for k in range(len(features[0][0])): for j in range(len(features[i])): features_normalize[i][k] += features[i][j][k] features_normalize[i][k] /= len(features[i]) return features_normalize