import gradio as gr
import torch
import transformers


def reduce_sum(value, mask, axis=None):
    """Sum ``value * mask``, over all elements or along ``axis``.

    Args:
        value: Tensor of values.
        mask: Tensor broadcastable to ``value``; masked-out entries are 0/False.
        axis: Dimension to reduce over, or None to reduce over everything.

    Returns:
        The masked sum (a scalar tensor when ``axis`` is None).
    """
    if axis is None:
        return torch.sum(value * mask)
    return torch.sum(value * mask, axis)


def reduce_mean(value, mask, axis=None):
    """Mean of ``value`` over the entries selected by ``mask``.

    Computed as ``sum(value * mask) / sum(mask)``, either over everything
    (``axis=None``) or along ``axis``.
    """
    if axis is None:
        return torch.sum(value * mask) / torch.sum(mask)
    return reduce_sum(value, mask, axis) / torch.sum(mask, axis)


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class InteractiveRainier:
    """Knowledge-augmented multiple-choice QA.

    Rainier generates knowledge statements for a question. UnifiedQA then
    scores each answer choice with and without each statement.
    """

    def __init__(self):
        # Both models share the UnifiedQA (T5) tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
        # Per-token losses (no reduction) so they can be averaged per sequence with a mask.
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    def parse_choices(self, s):
        """Split a serialized choice string into its individual choices.

        Args:
            s: Serialized choices such as '(A) foo (B) bar (C) baz', expected to
                start with the first marker. Markers are upper-case '(A)', '(B)', ...
                when '(A)' occurs in ``s``, otherwise lower-case '(a)', '(b)', ...

        Returns:
            The choice texts in order with surrounding spaces stripped,
            e.g. ['foo', 'bar', 'baz'].
        """
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            choice = s[3:pos]  # drop the 3-character '(X)' marker
            s = s[pos:]
            choice = choice.strip(' ')
            choices.append(choice)
            key = chr(ord(key) + 1)
        choice = s[3:]
        choice = choice.strip(' ')
        choices.append(choice)
        return choices

    def run(self, question, max_input_len, max_output_len, m, top_p):
        """Answer ``question`` with and without generated knowledge.

        Args:
            question: UnifiedQA-formatted question, '{question} \\n (A) ... (B) ...',
                where '\\n' is a literal backslash-n.
            max_input_len: Max token length for the question / prompts.
            max_output_len: Max token length of each generated knowledge.
            m: Number of knowledge statements to sample.
            top_p: Nucleus-sampling threshold for knowledge generation.

        Returns:
            A dict with keys 'question', 'knowledges' (index 0 is the empty,
            knowledge-free entry), 'knowless_pred', 'knowful_pred' (lower-cased
            choice texts) and 'selected_knowledge'.

        Raises:
            ValueError: if ``question`` has no '\\n'-separated choice list.
        """
        # Gradio widgets may return floats; generate() and the tokenizer need ints.
        max_input_len = int(max_input_len)
        max_output_len = int(max_output_len)
        m = int(m)
        top_p = float(top_p)

        # Validate the format before running the (expensive) models.
        parts = question.split('\\n')
        if len(parts) < 2:
            raise ValueError('Question must follow the UnifiedQA format "{question} \\n (A) ... (B) ...".')
        choices = self.parse_choices(parts[1].strip(' '))

        knowledges = self._generate_knowledges(question, max_input_len, max_output_len, m, top_p)

        prompts = [
            (question + (f' \\n {knowledge}' if knowledge != '' else '')).lower()
            for knowledge in knowledges
        ]
        choices = [choice.lower() for choice in choices]

        answer_logitss = self._score_choices(prompts, choices, max_input_len)  # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)

        # Ensemble
        # Knowledge-free prediction: row 0 is the bare question.
        knowless_pred = choices[answer_probss[0, :].argmax(dim=0).item()]
        # Knowledge-augmented prediction: each choice keeps its best probability over all prompts.
        answer_probs = answer_probss.max(dim=0).values  # (C)
        knowful_pred = choices[answer_probs.argmax(dim=0).item()]
        # The prompt whose most confident choice is the most confident overall.
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]

        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }

    def _generate_knowledges(self, question, max_input_len, max_output_len, m, top_p):
        """Sample knowledge statements; return [''] followed by unique non-empty ones."""
        tokenized = self.tokenizer(
            question, return_tensors='pt', padding='max_length',
            truncation='longest_first', max_length=max_input_len,
        ).to(device)  # (1, L)
        knowledges_ids = self.rainier_model.generate(
            input_ids=tokenized.input_ids,
            # The input is padded to max_length, so pass the mask explicitly.
            attention_mask=tokenized.attention_mask,
            max_length=max_output_len + 1,  # +1 for the leading decoder start token
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        )  # (K, L); begins with 0 (decoder start); ends with 1 ([EOS])
        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
        knowledges = self.tokenizer.batch_decode(
            knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True,
        )
        # Deduplicate while keeping generation order. A set would make the order, and so
        # argmax tie-breaking, vary between processes. '' is reserved for the
        # knowledge-free prompt, so empty generations are dropped.
        return [''] + list(dict.fromkeys(k for k in knowledges if k != ''))

    def _score_choices(self, prompts, choices, max_input_len):
        """Return a (len(prompts), len(choices)) tensor of negative mean token losses."""
        # The prompt batch is the same for every choice, so tokenize it once.
        tokenized_prompts = self.tokenizer(
            prompts, return_tensors='pt', padding='max_length',
            truncation='longest_first', max_length=max_input_len,
        ).to(device)  # (1+K, L)

        answer_logitss = []
        for choice in choices:
            # The decoder is causal, so padding labels beyond the choice's own length
            # only wastes compute. 'longest' on a single sequence adds no padding.
            tokenized_choice = self.tokenizer(
                [choice], return_tensors='pt', padding='longest',
                truncation='longest_first', max_length=max_input_len,
            ).to(device)  # (1, T)
            pad_mask = tokenized_choice.input_ids == self.tokenizer.pad_token_id  # (1, T)
            labels = tokenized_choice.input_ids.masked_fill(pad_mask, -100)
            labels = labels.repeat(len(prompts), 1)  # (1+K, T)
            with torch.no_grad():
                logits = self.qa_model(
                    input_ids=tokenized_prompts.input_ids,
                    attention_mask=tokenized_prompts.attention_mask,
                    labels=labels,
                ).logits  # (1+K, T, V)
            losses = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            losses = losses.view(labels.shape)  # (1+K, T)
            losses = reduce_mean(losses, ~pad_mask, axis=-1)  # (1+K)
            answer_logitss.append(-losses)
        return torch.stack(answer_logitss, dim=1)  # (1+K, C)


rainier = InteractiveRainier()


def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
    """Gradio callback: run Rainier on ``question`` and format a text report.

    ``kg_model`` and ``qa_model`` come from read-only display textboxes and are
    not used. The models are fixed when ``InteractiveRainier`` is constructed.
    """
    result = rainier.run(question, max_input_len, max_output_len, m, top_p)
    lines = [
        f'QA model answer without knowledge: {result["knowless_pred"]}',
        f'QA model answer with knowledge: {result["knowful_pred"]}',
        '',
        'All generated knowledges:',
        # Skip the '' placeholder that stands for the knowledge-free prompt.
        *(f' {knowledge}' for knowledge in result['knowledges'] if knowledge != ''),
        '',
        f'Knowledge selected to make the prediction: {result["selected_knowledge"]}',
    ]
    return '\n'.join(lines) + '\n'


examples = [
    'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
    'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
    'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
    'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
    'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
    'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
]

# Questions are multiple-choice commonsense questions in the UnifiedQA input format
# "{question} \n (A) ... (B) ... (C) ...", where "\n" is a literal backslash-n
# (written '\\n' in Python source).
input_question = gr.inputs.Dropdown(
    choices=examples,
    label='Question:',
)
input_kg_model = gr.inputs.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
input_qa_model = gr.inputs.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
input_max_input_len = gr.inputs.Number(label='Max question length:', value=256, precision=0)
input_max_output_len = gr.inputs.Number(label='Max knowledge length:', value=32, precision=0)
input_m = gr.inputs.Slider(label='Number of generated knowledges:', value=10, minimum=1, maximum=20, step=1)
input_top_p = gr.inputs.Slider(label='Top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
output_text = gr.outputs.Textbox(label='Output')

gr.Interface(
    fn=predict,
    inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
    outputs=output_text,
    title="Rainier",
).launch()