import logging
from collections import defaultdict

import cv2
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
DEVICE = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()


def _predict_masked_token(prompt):
    """Return the single most likely vocabulary token for the first [MASK] in *prompt*.

    Args:
        prompt: natural-language sentence containing the tokenizer's mask token.

    Returns:
        The top-1 predicted token string (may be a WordPiece such as "##xx").

    Raises:
        ValueError: if the encoded prompt contains no mask token.
    """
    # Calling the tokenizer (rather than tokenize + convert_tokens_to_ids) adds
    # [CLS]/[SEP] and builds a real attention mask. The previous code passed
    # all-zero segment ids as the second positional argument, which HF BERT
    # models interpret as attention_mask, masking out every token.
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    input_ids = encoded["input_ids"][0]
    mask_positions = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    if mask_positions.numel() == 0:
        raise ValueError(f"prompt has no {tokenizer.mask_token} token: {prompt!r}")
    masked_index = mask_positions[0].item()

    with torch.no_grad():
        logits = model(**encoded).logits  # (batch=1, seq_len, vocab_size)

    predicted_id = logits[0, masked_index].argmax().item()
    return tokenizer.convert_ids_to_tokens(predicted_id)


def mlm(image, text):
    """Describe *text* by asking the biomedical MLM for its location, color and shape.

    Args:
        image: value from the Gradio image input; returned unchanged.
        text: name of the thing to describe (substituted into each prompt).

    Returns:
        ``(image, answer)`` where *answer* has the form
        ``"<color> color, <shape> shape, <text> at <location>"``.
    """
    logger.debug("query text: %r", text)
    prompts = {
        "location": f"This {text} is at [MASK] place",
        "color": f"This {text} is in [MASK] color",
        "shape": f"This {text} is in [MASK] shape",
    }
    res = {key: _predict_masked_token(prompt) for key, prompt in prompts.items()}
    ans = f"{res['color']} color, {res['shape']} shape, {text} at {res['location']}"
    logger.debug("answer: %s", ans)
    return image, ans


def to_black(image, text):
    """Return ``[grayscale(image), text]``; not wired into the interface below.

    NOTE(review): uses COLOR_BGR2GRAY. Gradio's numpy image input is RGB,
    so COLOR_RGB2GRAY may be intended if this is hooked up — confirm.
    """
    output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    outputs = [output, text]
    return outputs


interface = gr.Interface(fn=mlm, inputs=["image", "text"], outputs=["image", "text"])
interface.launch()