"""Gradio demo: instruction-following open-world information extraction with Pivoine-7B."""
import json
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# SECURITY(review): this auth token was hard-coded in the source file and is
# therefore leaked -- it should be revoked. Kept only as a fallback so existing
# deployments keep working; prefer supplying a fresh token via the HF_TOKEN
# environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN", "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")
MODEL_NAME = "keminglu/pivoine-7b"

use_cuda = torch.cuda.is_available()

# Left padding so generation continues from the end of the prompt.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, use_auth_token=HF_TOKEN, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, use_auth_token=HF_TOKEN, torch_dtype=torch.float16)
model.requires_grad_(False)  # inference only -- no gradients needed
model.eval()
if use_cuda:
    model = model.to("cuda")

# Close the files promptly instead of leaking the handles.
with open("examples.json") as f:
    examples = json.load(f)
with open("description.txt") as f:
    description = f.read()


def inference(context, instruction, num_beams: int = 4):
    """Generate an extraction for *context* following *instruction*.

    Builds the prompt as the quoted context followed by the instruction,
    runs deterministic beam search, and tries to parse the generation as JSON.

    Returns:
        (num_generated_tokens, raw_string_output, json_output) where
        json_output is the parsed object or an {"error": ...} placeholder
        when the model output is not valid JSON.
    """
    input_str = f"\"{context}\"\n\n{instruction}"
    if not input_str.endswith("."):
        input_str += "."
    input_tokens = tokenizer(input_str, return_tensors="pt", padding=True)
    if use_cuda:
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to("cuda")
    output = model.generate(
        input_tokens["input_ids"],
        num_beams=num_beams,
        do_sample=False,  # deterministic beam search
        max_new_tokens=2048,
        num_return_sequences=1,
        return_dict_in_generate=True,
    )
    # output.sequences includes the prompt; slice it off to keep generation only.
    num_input_tokens = input_tokens["input_ids"].shape[1]
    generated_tokens = output.sequences[:, num_input_tokens:]
    num_generated_tokens = (
        (generated_tokens != tokenizer.pad_token_id).sum(dim=-1).tolist()[0]
    )
    # Prepend a dummy "A" token before decoding, then drop its first character
    # below -- presumably to keep the tokenizer from eating leading whitespace
    # of the generation (TODO confirm against the tokenizer's behavior).
    # BUG FIX: the prefix tensor was unconditionally moved to "cuda", which
    # crashed torch.cat on CPU-only hosts; match the generation's device instead.
    prefix_to_add = torch.tensor(
        [[tokenizer("A")["input_ids"][0]]]
    ).to(generated_tokens.device)
    generated_tokens = torch.cat([prefix_to_add, generated_tokens], dim=1)
    generated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    string_output = [i[1:].strip() for i in generated_text][0]
    try:
        json_output = json.loads(string_output)
    except json.JSONDecodeError:
        json_output = {"error": "Unfortunately, there is a JSON decode error on your output, which is really rare in our experiment :("}
    except Exception as e:
        # BUG FIX: gr.Error takes a message string, not the exception object.
        raise gr.Error(str(e))
    return num_generated_tokens, string_output, json_output


demo = gr.Interface(
    fn=inference,
    inputs=["text", "text", gr.Slider(1, 5, value=4, step=1)],
    outputs=[
        gr.Number(label="Number of Generated Tokens"),
        gr.Textbox(label="Raw String Output"),
        gr.JSON(label="Json Output"),
    ],
    examples=examples,
    examples_per_page=3,
    title="Instruction-following Open-world Information Extraction",
    description=description,
)
demo.launch(show_error=True)