xj committed · Commit f7a3ffd
1 Parent(s): a76d79d

feat: added a Gradio interface.

Files changed:
- README.md +27 -3
- gradio_app.py +353 -0
- modeling_tio.py +12 -2
- utils_tio.py +56 -0
README.md
CHANGED
@@ -9,6 +9,27 @@ language:
 
 TiO is an Interactive Visual Grounding Model for Disambiguation. (WIP)
 
+## Online / offline Demo
+
+```python
+from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+
+model_id = "jxu124/TiO"
+model = AutoModel.from_pretrained(
+    model_id,
+    trust_remote_code=True,
+    torch_dtype=torch.float16,
+    device_map='cuda',
+    # load_in_4bit=True,
+    # bnb_4bit_compute_dtype=torch.float16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+image_processor = AutoImageProcessor.from_pretrained(model_id)
+# setup gradio demo
+model.get_gradio_demo(tokenizer, image_processor).\
+    queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
+```
+
 ## Mini-Example
 ```python
 from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
@@ -21,18 +42,21 @@ import requests
 tokenizer = AutoTokenizer.from_pretrained("jxu124/TiO", use_fast=False)
 image_processor = AutoImageProcessor.from_pretrained("jxu124/TiO")
 model = AutoModel.from_pretrained("jxu124/TiO", trust_remote_code=True)
-model = model.to(torch.float16).cuda()  # It
+model = model.to(torch.float16).cuda()  # It would be faster.
 
 # Prepare example
 image = Image.open(BytesIO(requests.get("http://images.cocodataset.org/val2014/COCO_val2014_000000429913.jpg").content))
-text = "
+text = """\
+#instruction: can you specify which region the context describes?
+#context:
+human: look that man in white!"""
 
 # Inference
 with torch.no_grad():
     pt_txt = tokenizer([text], return_tensors="pt").input_ids.cuda()
     pt_img = image_processor([image], return_tensors="pt").pixel_values.to(torch.float16).cuda()
     gen = model.generate(pt_txt, patch_images=pt_img, top_p=0.5, do_sample=True, no_repeat_ngram_size=3, max_length=256)
-print(tokenizer.batch_decode(gen, skip_special_tokens=True))
+print(tokenizer.batch_decode(gen, skip_special_tokens=True).replace("not yet.", ""))
 # e.g. [' is he the one who just threw the ball?']  # Due to the generator, different results may be output
 ```
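The prompt in the Mini-Example uses the same `#instruction:` / `#context:` layout that `get_prompt` in the new gradio_app.py builds from the chat history. As a rough illustration (the dialogue lines below are invented; only the layout follows the code), a follow-up turn simply appends alternating `human:` / `agent:` lines:

```python
# Illustration only (not part of the commit): a two-turn prompt in the same
# format that gradio_app.get_prompt serializes from the chat history.
text = """\
#instruction: can you specify which region the context describes?
#context:
human: look that man in white!
agent: is he the one who just threw the ball?
human: yes, that's him."""
```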
gradio_app.py
ADDED
@@ -0,0 +1,353 @@
from threading import Thread
from typing import Iterator
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, TextIteratorStreamer
from PIL import Image as PILImage
import tempfile
import torch
import gradio as gr


def get_gradio_demo(model, tokenizer, image_processor) -> gr.Interface:

    def get_prompt(message: str, chat_history: list[tuple[str, str]],
                   system_prompt: str) -> str:
        texts = [f'#instruction: {system_prompt}\n', '#context:\n']
        texts += [f"human: {user_input.strip()}\nagent: {response.strip()}\n" for user_input, response in chat_history if isinstance(user_input, str)]
        texts += [f'human: {message.strip()}']
        return ''.join(texts)

    def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
        prompt = get_prompt(message, chat_history, system_prompt)
        input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
        return input_ids.shape[-1]

    def run(image: PILImage.Image,
            message: str,
            chat_history: list[tuple[str, str]],
            system_prompt: str,
            max_new_tokens: int = 192,
            temperature: float = 0.1,
            top_p: float = 0.9,
            top_k: int = 50) -> Iterator[str]:
        prompt = get_prompt(message, chat_history, system_prompt)
        patch_images = image_processor([image], return_tensors="pt").pixel_values.to(torch.float16).to('cuda')
        inputs = tokenizer([prompt], return_tensors='pt').to('cuda')

        streamer = TextIteratorStreamer(tokenizer, timeout=10.)  #
        generate_kwargs = dict(
            inputs,
            patch_images=patch_images,
            streamer=streamer,
            max_length=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_beams=1,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield ''.join(outputs).replace("not yet.", "").replace("<s>", "").replace("</s>", "").strip()

    # -------

    DEFAULT_SYSTEM_PROMPT = """can you specify which region the context describes?"""
    MAX_MAX_NEW_TOKENS = 512
    DEFAULT_MAX_NEW_TOKENS = 128
    MAX_INPUT_TOKEN_LENGTH = 512

    DESCRIPTION = """<h1 align="center">TiO Demo</h1>
<div align="center">https://huggingface.co/jxu124/TiO</div>
"""

    LICENSE = """
<p/>

---
"""

    if not torch.cuda.is_available():
        DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'

    def upload_image(file_obj):
        chatbot = [[(file_obj.name,), None]]
        return (gr.update(visible=False), gr.update(interactive=True, placeholder='Type a message...',), chatbot)

    def clear_and_save_textbox(message: str) -> tuple[str, str]:
        return '', message

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        if len(history) == 0:
            raise gr.Error(f'Upload an image first and try again.')
        history.append((message, ''))
        return history

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def generate(
        message: str,
        history_with_input: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Iterator[list[tuple[str, str]]]:
        if max_new_tokens > MAX_MAX_NEW_TOKENS:
            raise ValueError

        image = PILImage.open(history_with_input[0][0][0])
        history = history_with_input[:-1]
        generator = run(image, message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
        try:
            first_response = next(generator)
            yield history + [(message, first_response)]
        except StopIteration:
            yield history + [(message, '')]
        for response in generator:
            chatbot = history + [(message, response)]
            if "region:" in response:
                bboxes = model.utils.sbbox_to_bbox(response)
                if len(bboxes):
                    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                        model.utils.show_mask(image, bboxes).save(f)
                        chatbot += [(None, (f.name,))]
            yield chatbot

    def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
        generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 192, 1, 0.95, 50)
        for x in generator:
            pass
        return '', x

    def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
        input_token_length = get_input_token_length(message, chat_history[:-1], system_prompt)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Group():
            chatbot = gr.Chatbot(label='Chatbot')
            imagebox = gr.File(
                file_types=["image"],
                show_label=False,
            )
            with gr.Row():
                textbox = gr.Textbox(
                    container=False,
                    show_label=False,
                    interactive=False,
                    placeholder='Upload an image...',
                    scale=10,
                )
                submit_button = gr.Button('Submit',
                                          variant='primary',
                                          scale=1,
                                          min_width=0)
        with gr.Row():
            retry_button = gr.Button('🔄 Retry', variant='secondary')
            undo_button = gr.Button('↩️ Undo', variant='secondary')
            clear_button = gr.Button('🗑️ Clear', variant='secondary')

        saved_input = gr.State()

        with gr.Accordion(label='Advanced options', open=False):
            system_prompt = gr.Textbox(label='System prompt',
                                       value=DEFAULT_SYSTEM_PROMPT,
                                       lines=6)
            max_new_tokens = gr.Slider(
                label='Max new tokens',
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature',
                minimum=0.1,
                maximum=4.0,
                step=0.1,
                value=0.5,
            )
            top_p = gr.Slider(
                label='Top-p (nucleus sampling)',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=0.9,
            )
            top_k = gr.Slider(
                label='Top-k',
                minimum=1,
                maximum=1000,
                step=1,
                value=20,
            )

        gr.Markdown(LICENSE)
        imagebox.upload(
            fn=upload_image,
            inputs=imagebox,
            outputs=[imagebox, textbox, chatbot],
            api_name=None,
            queue=False,
        )

        textbox.submit(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name="generate",
        )

        button_event_preprocess = submit_button.click(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )

        retry_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )

        undo_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=lambda x: x,
            inputs=[saved_input],
            outputs=textbox,
            api_name=None,
            queue=False,
        )

        clear_button.click(
            fn=lambda: ([], '', gr.update(value=None, visible=True), gr.update(interactive=False, placeholder='Upload an image...',)),
            outputs=[chatbot, saved_input, imagebox, textbox],
            queue=False,
            api_name=None,
        )

    return demo


def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: int = None):
    if torch.cuda.is_available():
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map='cuda',
            # load_in_4bit=True,
            # bnb_4bit_compute_dtype=torch.float16,
        )
    else:
        model = None
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    model.get_gradio_demo(tokenizer, image_processor).queue(max_size=20).launch(server_name=host, server_port=port)


if __name__ == "__main__":
    import fire
    fire.Fire(main)
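Because the file ends with `fire.Fire(main)`, the demo can be launched from the command line (for example `python gradio_app.py --port 7860`) or from Python. A minimal sketch, assuming `gradio_app.py` is on the import path and a CUDA GPU is available (otherwise `main` leaves `model` as `None`):

```python
# Minimal launch sketch (not part of the commit). Assumes gradio_app.py is
# importable and torch.cuda.is_available() is True, since main() only loads
# the model on GPU.
from gradio_app import main

main(model_id="jxu124/TiO", host="0.0.0.0", port=7860)
```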
modeling_tio.py
CHANGED
@@ -87,8 +87,12 @@ def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
     sign = torch.sign(relative_pos)
     mid = bucket_size // 2
     abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))
-
-    log_pos =
+    # import pdb; pdb.set_trace()
+    # log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
+    # log_pos = log_pos.int()
+    import numpy as np
+    log_pos = np.ceil(np.log(abs_pos.numpy() / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
+    log_pos = torch.tensor(log_pos.astype('int64'))
     bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long()
     return bucket_pos + bucket_size - 1
 
@@ -2013,3 +2017,9 @@ class TiOModel(TiOPreTrainedModel):
         )
         model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
+
+
+from .utils_tio import Utils
+from .gradio_app import get_gradio_demo
+TiOModel.utils = Utils
+TiOModel.get_gradio_demo = get_gradio_demo
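The new numpy-based `log_pos` computes the same logarithmic position buckets as the commented-out torch expression; only where the arithmetic runs changes. A standalone sketch (my own check, with made-up `bucket_size` / `max_position` values) comparing the two:

```python
# Standalone comparison (not part of the commit): the numpy rewrite of log_pos
# against the original torch expression. bucket_size / max_position are made up.
import math
import numpy as np
import torch

bucket_size, max_position = 32, 1024
mid = bucket_size // 2
abs_pos = torch.arange(mid, max_position, dtype=torch.float32)

log_pos_torch = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
log_pos_numpy = np.ceil(np.log(abs_pos.numpy() / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid

# Expected to be 0 (or at most 1 at a ceil boundary, from float rounding).
print((log_pos_torch - torch.from_numpy(log_pos_numpy)).abs().max())
```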
utils_tio.py
ADDED
@@ -0,0 +1,56 @@
from PIL import Image as PILImage
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
import re
import cv2
import numpy as np


class Utils():
    def xywh2xyxy(b):
        b[..., 2:] += b[..., :2]
        return b

    def bbox_to_sbbox(bbox):
        # xyxy in [0, 1]
        assert len(bbox) == 4
        bbox = (np.asarray(bbox) * 1000).astype(np.int16)
        bbox = np.clip(bbox, 0, 999)
        bbox = " ".join([f"<bin_{i}>" for i in bbox])
        return bbox

    def sbbox_to_bbox(sbbox):
        sbbox = [re.findall(r"<bin_(\d+)>", s)[:4] for s in sbbox.split("region:")]
        bbox = np.asarray([s for s in sbbox if len(s) >= 4], dtype=int)
        bbox = np.clip(bbox / 1000, 1e-3, 1 - 1e-3)
        return bbox.reshape(-1, 4)

    def make_dialog_context(dialog: list, text_human: str = None) -> str:
        # dialog: [("pass me an apple.", "which apple do you want?"), ...]
        context = "".join([f"human: {d[0]}\nagent: {d[1]}\n" for d in dialog])
        if text_human is not None:
            context += f"human: {text_human}"
        return context

    def show_mask(image: PILImage.Image, bboxes=None, masks=None, show_id=False, text_size=1) -> PILImage.Image:
        """ Draw masks/boxes on the image: only the RGB values of the mask-marked regions are changed. """
        import colorsys
        colors = [tuple(int(c * 255) for c in colorsys.hsv_to_rgb(i * 1.0 / 36, 1, 1)) for i in range(36)]
        size = image.size
        image = np.asarray(image)
        if bboxes is not None:
            bboxes = np.array(bboxes).reshape(-1, 4)
            for k, bbox in enumerate(bboxes):
                bbox = (np.asarray(bbox) * np.asarray([*size, *size])).astype(int)
                image = cv2.rectangle(image, tuple(bbox[:2]), tuple(bbox[2:]), tuple(colors[k]), thickness=2)
            if show_id:
                for k, bbox in enumerate(bboxes):
                    bbox = (np.asarray(bbox) * np.asarray([*size, *size])).astype(int)
                    image = cv2.putText(image, str(k), tuple(bbox[:2] + np.array([2, 28 * text_size])), cv2.FONT_HERSHEY_SIMPLEX, text_size, (255, 255, 255), 2, cv2.LINE_AA)
                    image = cv2.putText(image, str(k), tuple(bbox[:2] + np.array([2, 28 * text_size])), cv2.FONT_HERSHEY_SIMPLEX, text_size, tuple(colors[k % len(colors)]), 1, cv2.LINE_AA)

        if masks is not None:
            for k, mask in enumerate(masks):
                mask_color = (mask[..., None] * colors[k % len(colors)][:3]).astype(np.uint8)
                image_mask = cv2.addWeighted(mask_color, 0.5, image * mask[..., None], 0.5, 0)
                image = cv2.add(image * ~mask[..., None], image_mask)
        return PILImage.fromarray(image)
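For reference, the `<bin_*>` serialization used for regions round-trips through these helpers. A quick illustration (not part of the commit, assuming `utils_tio.py` is importable and its dependencies are installed):

```python
# Quick round-trip illustration (not part of the commit): a normalized xyxy box
# through the <bin_*> serialization. Assumes utils_tio.py is on the import path.
from utils_tio import Utils

sbbox = Utils.bbox_to_sbbox([0.25, 0.30, 0.75, 0.90])  # xyxy, normalized to [0, 1]
print(sbbox)                                    # <bin_250> <bin_300> <bin_750> <bin_900>
print(Utils.sbbox_to_bbox("region: " + sbbox))  # [[0.25 0.3  0.75 0.9 ]]
```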