xj
commited on
Commit
·
f4d598d
1
Parent(s):
a2c1aea
fix: fix bugs.
Browse files- gradio_app.py +5 -12
- modeling_tio.py +1 -1
gradio_app.py
CHANGED
@@ -333,22 +333,15 @@ def get_gradio_demo(model, tokenizer, image_processor) -> gr.Interface:
|
|
333 |
|
334 |
|
335 |
def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: int = None):
|
336 |
-
|
337 |
-
|
338 |
-
model_id,
|
339 |
-
trust_remote_code=True,
|
340 |
-
torch_dtype=torch.float16,
|
341 |
-
device_map='cuda',
|
342 |
-
# load_in_4bit=True,
|
343 |
-
# bnb_4bit_compute_dtype=torch.float16,
|
344 |
-
)
|
345 |
-
else:
|
346 |
-
model = None
|
347 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
348 |
image_processor = AutoImageProcessor.from_pretrained(model_id)
|
|
|
|
|
349 |
model.get_gradio_demo(tokenizer, image_processor).queue(max_size=20).launch(server_name=host, server_port=port)
|
350 |
|
351 |
|
352 |
if __name__ == "__main__":
|
353 |
import fire
|
354 |
-
fire.Fire(main)
|
|
|
333 |
|
334 |
|
335 |
def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: int = None):
|
336 |
+
assert torch.cuda.is_available()
|
337 |
+
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16).cuda()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
339 |
image_processor = AutoImageProcessor.from_pretrained(model_id)
|
340 |
+
|
341 |
+
# ---- gradio demo ----
|
342 |
model.get_gradio_demo(tokenizer, image_processor).queue(max_size=20).launch(server_name=host, server_port=port)
|
343 |
|
344 |
|
345 |
if __name__ == "__main__":
|
346 |
import fire
|
347 |
+
fire.Fire(main)
|
modeling_tio.py
CHANGED
@@ -103,7 +103,7 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
|
|
103 |
"""
|
104 |
coords_h = torch.arange(bucket_size)
|
105 |
coords_w = torch.arange(bucket_size)
|
106 |
-
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
107 |
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
108 |
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
109 |
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
|
|
103 |
"""
|
104 |
coords_h = torch.arange(bucket_size)
|
105 |
coords_w = torch.arange(bucket_size)
|
106 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
107 |
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
108 |
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
109 |
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|