import tempfile
from threading import Thread
from typing import Iterator, Optional

import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import (AutoImageProcessor, AutoModel, AutoTokenizer,
                          TextIteratorStreamer)

def get_gradio_demo(model, tokenizer, image_processor) -> gr.Blocks:

    def get_prompt(message: str, chat_history: list[tuple[str, str]],
                   system_prompt: str) -> str:
        texts = [f'#instruction: {system_prompt}\n', '#context:\n']
        # Entries whose user part is not a str (image tuples from uploads or
        # bbox overlays) are skipped when serializing the history.
        texts += [f"human: {user_input.strip()}\nagent: {response.strip()}\n"
                  for user_input, response in chat_history
                  if isinstance(user_input, str)]
        texts += [f'human: {message.strip()}']
        return ''.join(texts)
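
    # Illustrative shape of an assembled prompt (the dialogue is hypothetical):
    #   #instruction: can you specify which region the context describes?
    #   #context:
    #   human: the person wearing a hat
    #   agent: not yet.
    #   human: the one on the left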

    def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
                               system_prompt: str) -> int:
        prompt = get_prompt(message, chat_history, system_prompt)
        input_ids = tokenizer([prompt], return_tensors='np',
                              add_special_tokens=False)['input_ids']
        return input_ids.shape[-1]

    def run(image: PILImage.Image,
            message: str,
            chat_history: list[tuple[str, str]],
            system_prompt: str,
            max_new_tokens: int = 192,
            temperature: float = 0.1,
            top_p: float = 0.9,
            top_k: int = 50) -> Iterator[str]:
        prompt = get_prompt(message, chat_history, system_prompt)
        patch_images = (image_processor([image], return_tensors="pt")
                        .pixel_values.to(torch.float16).to('cuda'))
        inputs = tokenizer([prompt], return_tensors='pt').to('cuda')

        # Generation runs on a worker thread; TextIteratorStreamer yields the
        # decoded text incrementally, so partial responses can be streamed to
        # the UI as tokens arrive.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.)
        generate_kwargs = dict(
            inputs,  # merge input_ids/attention_mask with the kwargs below
            patch_images=patch_images,
            streamer=streamer,
            max_length=max_new_tokens,  # this model's generate() takes max_length
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_beams=1,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            # Strip BOS/EOS markers and the literal "not yet." string before
            # showing the partial response.
            yield (''.join(outputs).replace("not yet.", "")
                   .replace("<s>", "").replace("</s>", "").strip())
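
    # Minimal sketch of driving run() directly (image path and question are
    # hypothetical; assumes a CUDA device, as above):
    #   image = PILImage.open('example.jpg').convert('RGB')
    #   for partial in run(image, 'which person is described?', [],
    #                      DEFAULT_SYSTEM_PROMPT):
    #       print(partial)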

    DEFAULT_SYSTEM_PROMPT = """can you specify which region the context describes?"""
    MAX_MAX_NEW_TOKENS = 512
    DEFAULT_MAX_NEW_TOKENS = 128
    MAX_INPUT_TOKEN_LENGTH = 512

    DESCRIPTION = """<h1 align="center">TiO Demo</h1>
<div align="center">https://huggingface.co/jxu124/TiO</div>
"""

    LICENSE = """
<p/>

---
"""

    if not torch.cuda.is_available():
        DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'

    def upload_image(file_obj):
        # A (filepath,) tuple in the user slot makes gr.Chatbot render the
        # file as media rather than text.
        chatbot = [[(file_obj.name,), None]]
        return (gr.update(visible=False),
                gr.update(interactive=True, placeholder='Type a message...'),
                chatbot)

    def clear_and_save_textbox(message: str) -> tuple[str, str]:
        return '', message

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        if len(history) == 0:
            raise gr.Error('Upload an image first and try again.')
        history.append((message, ''))
        return history

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def generate(
        message: str,
        history_with_input: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Iterator[list[tuple[str, str]]]:
        if max_new_tokens > MAX_MAX_NEW_TOKENS:
            raise ValueError

        # The first history entry is the uploaded image: ((filepath,), None).
        image = PILImage.open(history_with_input[0][0][0])
        # display_input already appended (message, '') as the last entry.
        history = history_with_input[:-1]
        generator = run(image, message, history, system_prompt, max_new_tokens,
                        temperature, top_p, top_k)
        try:
            first_response = next(generator)
            yield history + [(message, first_response)]
        except StopIteration:
            yield history + [(message, '')]
        for response in generator:
            chatbot = history + [(message, response)]
            if "region:" in response:
                # The model answered with a serialized bounding box; draw the
                # box on the image and show the overlay instead of raw text.
                bboxes = model.utils.sbbox_to_bbox(response)
                if len(bboxes):
                    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                        model.utils.show_mask(image, bboxes).save(f)
                    chatbot = history + [(message, "OK, I see."), (None, (f.name,))]
            yield chatbot

    def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
        # Apparently intended for gr.Examples; not wired into the UI below.
        x = []
        for x in generate(message, [], DEFAULT_SYSTEM_PROMPT, 192, 1, 0.95, 50):
            pass
        return '', x

    def check_input_token_length(message: str, chat_history: list[tuple[str, str]],
                                 system_prompt: str) -> None:
        # The last history entry is the just-echoed (message, '') pair, so it
        # is excluded from the length check.
        input_token_length = get_input_token_length(message, chat_history[:-1], system_prompt)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            raise gr.Error(f'The accumulated input is too long ({input_token_length} > '
                           f'{MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Group():
            chatbot = gr.Chatbot(label='Chatbot')
            imagebox = gr.File(
                file_types=["image"],
                show_label=False,
            )
            with gr.Row():
                textbox = gr.Textbox(
                    container=False,
                    show_label=False,
                    interactive=False,
                    placeholder='Upload an image...',
                    scale=10,
                )
                submit_button = gr.Button('Submit',
                                          variant='primary',
                                          scale=1,
                                          min_width=0)
        with gr.Row():
            retry_button = gr.Button('🔄 Retry', variant='secondary')
            undo_button = gr.Button('↩️ Undo', variant='secondary')
            clear_button = gr.Button('🗑️ Clear', variant='secondary')

        saved_input = gr.State()

        with gr.Accordion(label='Advanced options', open=False):
            system_prompt = gr.Textbox(label='System prompt',
                                       value=DEFAULT_SYSTEM_PROMPT,
                                       lines=6)
            max_new_tokens = gr.Slider(
                label='Max new tokens',
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature',
                minimum=0.1,
                maximum=4.0,
                step=0.1,
                value=0.5,
            )
            top_p = gr.Slider(
                label='Top-p (nucleus sampling)',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=0.9,
            )
            top_k = gr.Slider(
                label='Top-k',
                minimum=1,
                maximum=1000,
                step=1,
                value=20,
            )

        gr.Markdown(LICENSE)

        imagebox.upload(
            fn=upload_image,
            inputs=imagebox,
            outputs=[imagebox, textbox, chatbot],
            api_name=None,
            queue=False,
        )
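
        # Submit pipeline (Enter key): clear the textbox while saving its
        # text, echo the message into the chat, validate the accumulated
        # prompt length, then stream the response. .success() runs only if
        # check_input_token_length raised no gr.Error.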

        textbox.submit(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name="generate",
        )
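
        # The Submit button runs the same pipeline as pressing Enter; only
        # the textbox chain above is exposed as the "generate" API endpoint
        # (the button chain uses api_name=None).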

        button_event_preprocess = submit_button.click(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )
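
        # Retry: pop the last exchange, re-echo the same message, and
        # generate a fresh response for it.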

        retry_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )

        undo_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=lambda x: x,
            inputs=[saved_input],
            outputs=textbox,
            api_name=None,
            queue=False,
        )

        clear_button.click(
            fn=lambda: ([], '', gr.update(value=None, visible=True),
                        gr.update(interactive=False, placeholder='Upload an image...')),
            outputs=[chatbot, saved_input, imagebox, textbox],
            queue=False,
            api_name=None,
        )

    return demo

def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: Optional[int] = None):
    assert torch.cuda.is_available(), 'This demo requires a CUDA device.'
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
                                      torch_dtype=torch.float16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    image_processor = AutoImageProcessor.from_pretrained(model_id)

    get_gradio_demo(model, tokenizer, image_processor).queue(max_size=20).launch(
        server_name=host, server_port=port)
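
# Example invocation (python-fire maps main()'s arguments to CLI flags; the
# script name is illustrative):
#   python tio_demo.py --model_id jxu124/TiO --port 7860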

if __name__ == "__main__":
    import fire
    fire.Fire(main)