#!/usr/bin/env python import gradio as gr import PIL.Image import spaces import torch from transformers import AutoModel, AutoProcessor, GenerationConfig DESCRIPTION = "# MIL-UT/Asagi-14B" model_id = "MIL-UT/Asagi-14B" processor = AutoProcessor.from_pretrained(model_id) model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True) TEMPLATE = ( "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n" "### 指示:\n\n{prompt}\n\n### 応答:\n" ) @spaces.GPU def run( image: PIL.Image.Image, prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, ) -> str: prompt = TEMPLATE.format(prompt=prompt) inputs = processor(text=prompt, images=image, return_tensors="pt") inputs_text = processor.tokenizer(prompt, return_tensors="pt") inputs["input_ids"] = inputs_text["input_ids"] inputs["attention_mask"] = inputs_text["attention_mask"] for k, v in inputs.items(): if v.dtype == torch.float32: inputs[k] = v.to(model.dtype) inputs = {k: inputs[k].to(model.device) for k in inputs if k != "token_type_ids"} generation_config = GenerationConfig( max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0, num_beams=5, ) output = model.generate(**inputs, generation_config=generation_config) generated_text = processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] # do not print the prompt if "" in prompt: prompt = prompt.replace("", " ") return generated_text.replace(prompt, "") examples = [ [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shibuya.jpg", "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真はどこで撮影されたものか教えてください。また、画像の内容についても詳しく説明してください。", ], [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/bridge.jpg", "この画像を見て、次の指示に詳細かつ具体的に答えてください。この写真の内容について詳しく教えてください。", ], [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/tower.jpg", "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真について評価してください。", ], [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shrine.jpg", "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真の神社について、細かいところまで詳しく説明してください。", ], [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/garden.jpg", "この画像を見て、次の指示に詳細かつ具体的に答えてください。これは日本庭園の中でも、どのような形式に分類される庭園ですか?また、その理由は何ですか?", # noqa: RUF001 ], [ "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/slope.jpg", "この画像を見て、次の質問に詳細に答えてください。この画像の場所を舞台とした小説のあらすじを書いてください。", ], ] with gr.Blocks(css_paths="style.css") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): image = gr.Image(label="Input Image") prompt = gr.Textbox(label="Prompt") run_button = gr.Button() with gr.Accordion("Advanced options", open=False): max_new_tokens = gr.Slider( label="Max new tokens", minimum=1, maximum=1024, step=1, value=256, ) temperature = gr.Slider( label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7, ) with gr.Column(): output = gr.Textbox(label="Output") gr.Examples(examples=examples, inputs=[image, prompt]) run_button.click( fn=run, inputs=[image, prompt, max_new_tokens, temperature], outputs=output, ) if __name__ == "__main__": demo.launch()