File size: 4,775 Bytes
b3d9f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoModel, AutoProcessor, GenerationConfig

DESCRIPTION = "# MIL-UT/Asagi-14B"

model_id = "MIL-UT/Asagi-14B"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)


TEMPLATE = (
    "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n"
    "### 指示:\n<image>\n{prompt}\n\n### 応答:\n"
)


@spaces.GPU
def run(
    image: PIL.Image.Image,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
) -> str:
    prompt = TEMPLATE.format(prompt=prompt)

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs_text = processor.tokenizer(prompt, return_tensors="pt")
    inputs["input_ids"] = inputs_text["input_ids"]
    inputs["attention_mask"] = inputs_text["attention_mask"]
    for k, v in inputs.items():
        if v.dtype == torch.float32:
            inputs[k] = v.to(model.dtype)
    inputs = {k: inputs[k].to(model.device) for k in inputs if k != "token_type_ids"}

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0,
        num_beams=5,
    )

    output = model.generate(**inputs, generation_config=generation_config)
    generated_text = processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    # do not print the prompt
    if "<image>" in prompt:
        prompt = prompt.replace("<image>", " ")
    return generated_text.replace(prompt, "")


examples = [
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shibuya.jpg",
        "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真はどこで撮影されたものか教えてください。また、画像の内容についても詳しく説明してください。",
    ],
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/bridge.jpg",
        "この画像を見て、次の指示に詳細かつ具体的に答えてください。この写真の内容について詳しく教えてください。",
    ],
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/tower.jpg",
        "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真について評価してください。",
    ],
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shrine.jpg",
        "この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真の神社について、細かいところまで詳しく説明してください。",
    ],
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/garden.jpg",
        "この画像を見て、次の指示に詳細かつ具体的に答えてください。これは日本庭園の中でも、どのような形式に分類される庭園ですか?また、その理由は何ですか?",  # noqa: RUF001
    ],
    [
        "https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/slope.jpg",
        "この画像を見て、次の質問に詳細に答えてください。この画像の場所を舞台とした小説のあらすじを書いてください。",
    ],
]

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input Image")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=1024,
                    step=1,
                    value=256,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    step=0.1,
                    value=0.7,
                )
        with gr.Column():
            output = gr.Textbox(label="Output")

    gr.Examples(examples=examples, inputs=[image, prompt])

    run_button.click(
        fn=run,
        inputs=[image, prompt, max_new_tokens, temperature],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()