Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +21 -0
- .ipynb_checkpoints/geochat_demo-checkpoint.py +707 -0
- .ipynb_checkpoints/pyproject-checkpoint.toml +39 -0
- README.md +227 -8
- demo_images/04133.png +3 -0
- demo_images/04444.png +3 -0
- demo_images/7292.JPG +3 -0
- demo_images/MicrosoftTeams-image.png +3 -0
- demo_images/church_183.png +3 -0
- demo_images/train_2956_0001.png +3 -0
- docs/Customize_Component.md +20 -0
- docs/Data.md +24 -0
- docs/Evaluation.md +54 -0
- docs/LoRA.md +24 -0
- docs/MODEL_ZOO.md +18 -0
- docs/geochat_supp.pdf +3 -0
- geochat.egg-info/PKG-INFO +260 -0
- geochat.egg-info/SOURCES.txt +51 -0
- geochat.egg-info/dependency_links.txt +1 -0
- geochat.egg-info/requires.txt +24 -0
- geochat.egg-info/top_level.txt +3 -0
- geochat/__init__.py +1 -0
- geochat/__pycache__/__init__.cpython-310.pyc +0 -0
- geochat/__pycache__/constants.cpython-310.pyc +0 -0
- geochat/__pycache__/conversation.cpython-310.pyc +0 -0
- geochat/__pycache__/mm_utils.cpython-310.pyc +0 -0
- geochat/__pycache__/utils.cpython-310.pyc +0 -0
- geochat/constants.py +12 -0
- geochat/conversation.py +520 -0
- geochat/eval/batch_geochat_grounding.py +138 -0
- geochat/eval/batch_geochat_referring.py +132 -0
- geochat/eval/batch_geochat_scene.py +139 -0
- geochat/eval/batch_geochat_vqa.py +125 -0
- geochat/mm_utils.py +121 -0
- geochat/model/.ipynb_checkpoints/__init__-checkpoint.py +2 -0
- geochat/model/.ipynb_checkpoints/builder-checkpoint.py +149 -0
- geochat/model/__init__.py +2 -0
- geochat/model/__pycache__/__init__.cpython-310.pyc +0 -0
- geochat/model/__pycache__/builder.cpython-310.pyc +0 -0
- geochat/model/__pycache__/geochat_arch.cpython-310.pyc +0 -0
- geochat/model/apply_delta.py +48 -0
- geochat/model/builder.py +149 -0
- geochat/model/consolidate.py +29 -0
- geochat/model/geochat_arch.py +262 -0
- geochat/model/language_model/.ipynb_checkpoints/geochat_llama-checkpoint.py +140 -0
- geochat/model/language_model/__pycache__/geochat_llama.cpython-310.pyc +0 -0
- geochat/model/language_model/__pycache__/geochat_mpt.cpython-310.pyc +0 -0
- geochat/model/language_model/geochat_llama.py +140 -0
- geochat/model/language_model/geochat_mpt.py +113 -0
- geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo_images/04133.png filter=lfs diff=lfs merge=lfs -text
+demo_images/04444.png filter=lfs diff=lfs merge=lfs -text
+demo_images/7292.JPG filter=lfs diff=lfs merge=lfs -text
+demo_images/MicrosoftTeams-image.png filter=lfs diff=lfs merge=lfs -text
+demo_images/church_183.png filter=lfs diff=lfs merge=lfs -text
+demo_images/train_2956_0001.png filter=lfs diff=lfs merge=lfs -text
+docs/geochat_supp.pdf filter=lfs diff=lfs merge=lfs -text
+geochat/serve/examples/11760.jpg filter=lfs diff=lfs merge=lfs -text
+geochat/serve/examples/11765.jpg filter=lfs diff=lfs merge=lfs -text
+images/architecture.png filter=lfs diff=lfs merge=lfs -text
+images/dataset.png filter=lfs diff=lfs merge=lfs -text
+images/examples.png filter=lfs diff=lfs merge=lfs -text
+images/grounded.jpg filter=lfs diff=lfs merge=lfs -text
+images/iden.jpg filter=lfs diff=lfs merge=lfs -text
+images/logo_geochat.png filter=lfs diff=lfs merge=lfs -text
+images/overview2.png filter=lfs diff=lfs merge=lfs -text
+images/ref1.jpg filter=lfs diff=lfs merge=lfs -text
+images/ref_2.jpg filter=lfs diff=lfs merge=lfs -text
+images/scene.jpg filter=lfs diff=lfs merge=lfs -text
+images/teaser.png filter=lfs diff=lfs merge=lfs -text
+images/vqa.jpg filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/geochat_demo-checkpoint.py
ADDED
@@ -0,0 +1,707 @@
```python
import argparse
import os
import random
from collections import defaultdict

import cv2
import re
import math
import numpy as np
from PIL import Image
import torch
import html
import gradio as gr

import torchvision.transforms as T
import torch.backends.cudnn as cudnn

from geochat.conversation import conv_templates, Chat
from geochat.model.builder import load_pretrained_model
from geochat.mm_utils import get_model_name_from_path


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    # parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu-id", type=str, default=0)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=300)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    # args = parser.parse_args()
    args = parser.parse_args()
    return args


random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
args = parse_args()
# cfg = Config(args)

model_name = get_model_name_from_path(args.model_path)
print(model_name)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

device = 'cuda:{}'.format(args.gpu_id)

# model_config = cfg.model_cfg
# model_config.device_8bit = args.gpu_id
# model_cls = registry.get_model_class(model_config.arch)
# model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100

# vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
# vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

model = model.eval()

CONV_VISION = conv_templates['llava_v1'].copy()


def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
    # Calculate center coordinates
    x_ctr = (x1 + x2) / 2
    y_ctr = (y1 + y2) / 2

    # Calculate width and height
    w = abs(x2 - x1)
    h = abs(y2 - y1)

    # Calculate the angle in radians
    angle_rad = math.radians(a)

    # Calculate coordinates of the four corners of the rotated bounding box
    cos_a = math.cos(angle_rad)
    sin_a = math.sin(angle_rad)

    x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
    y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr

    x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
    y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr

    x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
    y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr

    x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
    y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr

    # Return the polygon coordinates
    polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))

    return polygon_coords


def rotate_bbox(top_right, bottom_left, angle_degrees):
    # Convert angle to radians
    angle_radians = np.radians(angle_degrees)

    # Calculate the center of the rectangle
    center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)

    # Calculate the width and height of the rectangle
    width = top_right[0] - bottom_left[0]
    height = top_right[1] - bottom_left[1]

    # Create a rotation matrix
    rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)

    # Create an array of the rectangle corners
    rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
                                 [top_right[0], bottom_left[1]],
                                 [top_right[0], top_right[1]],
                                 [bottom_left[0], top_right[1]]], dtype=np.float32)

    # Rotate the rectangle points
    rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]

    return rotated_rectangle


def extract_substrings(string):
    # first check if there is no-finished bracket
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings


def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def save_tmp_img(visual_img):
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    file_path = "/tmp/gradio" + file_name
    visual_img.save(file_path)
    return file_path


def mask2bbox(mask):
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text


colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]

color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
    color_id, color in enumerate(colors)
}

used_colors = colors


def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)

    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            if "}{" in string:
                string = string.replace("}{", "}<delim>{")
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    angle = 0
                else:
                    angle = integers[4]
                    integers = integers[:-1]

                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top, angle])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)
        # if len(integers)==4:
        angle = 0
        # else:
        #     angle=integers[4]
        integers = integers[:-1]
        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top, angle])
        else:
            # don't detect any valid bbox to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invaild image path, {image}")
    elif isinstance(image, torch.Tensor):

        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.4
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    # num_bboxes = sum(len(x[-1]) for x in entities)
    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm, angle) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2, angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)

            color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
            top_right = (orig_x1, orig_y1)
            bottom_left = (orig_x2, orig_y2)
            angle = angle
            rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
            new_image = cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True, thickness=2, color=color)

            # new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored


def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # set the upload flag to true when receive a new image.
    # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list:
        replace_flag = 1
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # set the upload flag to true when receive a new image.
    # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list or replace_flag == 1:
        replace_flag = 1

    return upload_flag, replace_flag


def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if len(user_message) == 0:
        text_box_show = 'Input should not be empty!'
    else:
        text_box_show = ''

    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if user provide bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in text
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    if upload_flag:
        if replace_flag:
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
        img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag


# def gradio_answer(chatbot, chat_state, img_list, temperature):
#     llm_message = chat.answer(conv=chat_state,
#                               img_list=img_list,
#                               temperature=temperature,
#                               max_new_tokens=500,
#                               max_length=2000)[0]
#     chatbot[-1][1] = llm_message
#     return chatbot, chat_state


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    if len(img_list) > 0:
        if not isinstance(img_list[0], torch.Tensor):
            chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
                                  img_list=img_list,
                                  temperature=temperature,
                                  max_new_tokens=500,
                                  max_length=2000)
    # chatbot[-1][1] = output
    # chat_state.messages[-1][1] = '</s>'

    output = ''
    for new_output in streamer:
        # print(new_output)
        output = output + new_output
        print(output)
        # if "{" in output:
        #     chatbot[-1][1]="Grounding and referring expression is still under work."
        # else:
        output = escape_markdown(output)
        # output += escapped
        chatbot[-1][1] = output
        yield chatbot, chat_state
    chat_state.messages[-1][1] = '</s>'
    return chatbot, chat_state


def gradio_visualize(chatbot, gr_img):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']

    unescaped = reverse_escape(chatbot[-1][1])
    visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
    if visual_img is not None:
        if len(generation_color):
            chatbot[-1][1] = generation_color
        file_path = save_tmp_img(visual_img)
        chatbot = chatbot + [[None, (file_path,)]]

    return chatbot


def gradio_taskselect(idx):
    prompt_list = [
        '',
        'Classify the image in the following classes: ',
        '[identify] what is this ',
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Type in the classes you want the model to classify in',
        '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw',
    ]
    return prompt_list[idx], instruct_list[idx]


chat = Chat(model, image_processor, tokenizer, device=device)


title = """<h1 align="center">GeoChat Demo</h1>"""
description = 'Welcome to Our GeoChat Chatbot Demo!'
article = """<div style="display: flex;"><p style="display: inline-block;"><a href='https://mbzuai-oryx.github.io/GeoChat'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p style="display: inline-block;"><a href='https://arxiv.org/abs/2311.15826'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p style="display: inline-block;"><a href='https://github.com/mbzuai-oryx/GeoChat/tree/main'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p style="display: inline-block;"><a href='https://youtu.be/KOKtkkKpNDk?feature=shared'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p></div>"""
# article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""

introduction = '''
1. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
2. No Tag: Input whatever you want and CLICK **Send** without any tagging

You can also simply chat in free form!
'''


text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=12)
with gr.Blocks() as demo:
    gr.Markdown(title)
    # gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(introduction)

        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot(label='GeoChat')

            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['No Tag'], ['Scene Classification'], ['Identify']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    with gr.Row():
        with gr.Column():
            gr.Examples(examples=[
                ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?", upload_flag, replace_flag,
                 img_list],
                ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])
        with gr.Column():
            gr.Examples(examples=[
                ["demo_images/church_183.png", "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
                 upload_flag, replace_flag, img_list],
                ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])

    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    send.click(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True, server_name='0.0.0.0')
```
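The demo above draws boxes from strings such as `{<8><26><22><37>}` on a fixed 0-100 grid (`bounding_box_size = 100`), optionally followed by a rotation angle. Below is a minimal sketch of that decoding step using the same OpenCV rotation recipe as `rotate_bbox`; the example box and the 500-pixel image size are illustrative values, not taken from the repository.

```python
import cv2
import numpy as np

def box_string_to_polygon(x0, y0, x1, y1, angle, image_w, image_h, grid=100):
    """Scale a 0-100 normalized box to pixel coordinates and rotate it about its
    center, mirroring the cv2-based approach used by rotate_bbox in the demo."""
    left, top = x0 / grid * image_w, y0 / grid * image_h
    right, bottom = x1 / grid * image_w, y1 / grid * image_h
    center = ((left + right) / 2, (top + bottom) / 2)
    corners = np.array([[left, top], [right, top], [right, bottom], [left, bottom]],
                       dtype=np.float32)
    rot = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.transform(corners[None, :, :], rot)[0]  # 4x2 polygon in pixels

# Hypothetical box {<8><26><22><37>} with no rotation on a 500x500 image.
print(box_string_to_polygon(8, 26, 22, 37, angle=0, image_w=500, image_h=500))
```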
.ipynb_checkpoints/pyproject-checkpoint.toml
ADDED
@@ -0,0 +1,39 @@
```toml
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "geochat"
version = "1.1.1"
description = "Grounded VLM for Remote Sensing"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
    "requests", "sentencepiece", "tokenizers>=0.12.1",
    "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
    "shortuuid", "httpx==0.24.0",
    #"deepspeed==0.9.5",
    "peft==0.4.0",
    "transformers==4.31.0",
    "accelerate==0.21.0",
    "bitsandbytes==0.41.0",
    "scikit-learn==1.2.2",
    "sentencepiece==0.1.99",
    "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
    "gradio_client==0.2.9"
]

[project.urls]
"Homepage" = "https://github.com/mbzuai-oryx/GeoChat"
"Bug Tracker" = "https://github.com/mbzuai-oryx/GeoChat/issues"

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
```
README.md
CHANGED
@@ -1,12 +1,231 @@
 ---
-title:
-
-colorFrom: gray
-colorTo: yellow
+title: csu
+app_file: geochat_demo.py
 sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
+sdk_version: 3.35.2
 ---
+# GeoChat <img src="images/logo_geochat.png" height="40">: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
+<p align="center">
+  <img src="https://i.imgur.com/waxVImv.png" alt="Oryx Video-ChatGPT">
+</p>
 
-
+#### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
+\* Equally contributing first authors
+
+#### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
+
+[](https://mbzuai-oryx.github.io/GeoChat)
+[](https://arxiv.org/abs/2311.15826)
+[](https://youtu.be/KOKtkkKpNDk)
+
+---
+
+## 📢 Latest Updates
+- Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
+- **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
+- **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
+- **Nov-28-23**: The GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
+---
+
+
+
+## <img src="images/logo_geochat.png" height="40"> Overview
+
+GeoChat is the first grounded Large Vision-Language Model, specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels in handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned using the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
+
+---
+## Contents
+- [Install](#install)
+- [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
+- [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
+- [Train](#train)
+- [Evaluation](#evaluation)
+
+## Install
+
+1. Clone this repository and navigate to the GeoChat folder
+```bash
+git clone https://github.com/mbzuai-oryx/GeoChat.git
+cd GeoChat
+```
+
+2. Install the package
+```Shell
+conda create -n geochat python=3.10 -y
+conda activate geochat
+pip install --upgrade pip  # enable PEP 660 support
+pip install -e .
+```
+
+3. Install additional packages for training
+```
+pip install ninja
+pip install flash-attn --no-build-isolation
+```
+
+### Upgrade to the latest code base
+
+```Shell
+git pull
+pip uninstall transformers
+pip install -e .
+```
+
+## GeoChat Weights and Demo
+Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and see [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
+
+## Train
+
+GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset: 318k Vicuna-generated multimodal instruction-following samples, finetuned over the pretrained weights of LLaVA-v1.5.
+
+We train GeoChat on 3 A100 GPUs with 40GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
+
+### Hyperparameters
+We use a similar set of hyperparameters as Vicuna in finetuning. The hyperparameters used in pretraining and finetuning are provided below.
+
+| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
+| --- | ---: | ---: | ---: | ---: | ---: |
+| GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
+
+### Pretrain (feature alignment)
+
+We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. Pretraining takes around 3.5 hours for LLaVA-v1.5-7B.
+
+- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
+- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
+
+### Visual Instruction Tuning
+
+1. Prepare data
+
+Please download the annotation of the final mixture of our instruction tuning data, [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from the [Hugging Face repo](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save the multiple image zips in a single folder and run the following command to merge them:
+```Shell
+cat images_parta* > images.zip
+```
+Unzip the images.zip file to a folder and give the folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
+
+2. Start training!
+
+Visual instruction tuning takes more time due to the increased CLIP resolution of 504×504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40G).
+
+Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
+
+Options to note:
+
+- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
+- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
+- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
+- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct) data.
+
+## Evaluation
+
+We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search, to keep the inference process consistent with the chat demo's real-time outputs.
+See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
+
+## 🏆 Contributions
+
+- **RS multimodal instruction-following dataset.** We present a novel data generation pipeline that leverages existing object detection datasets to create short descriptions of the images, followed by using Vicuna-v1.5 to create conversations from the generated text alone. Further, we add visual question-answering and scene classification abilities
+using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
+- **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully-tuned LLaVA model, whose MLP projection is trained to align images into the word embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain the conversation and instruction-following abilities of LLaVA and extend its domain knowledge to remote sensing tasks.
+
+- **Evaluation Benchmark.** We also address the lack of evaluation benchmarks to assess the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparisons with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote sensing tasks, including image captioning, visual question answering and scene classification, to demonstrate the generalisability of the GeoChat conversational VLM.
+
+---
+## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
+
+GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations - shown on top), visual question answering on images and regions (top left and bottom right, respectively) as well as scene classification (top right) and normal natural language conversations (bottom). This makes it the first RS VLM with grounding capability.
+
+<p align="center">
+  <img src="images/overview2.png" alt="GeoChat Overview">
+</p>
+
+---
+
+## 🛰️ GeoChat : Architecture
+
+An overview of GeoChat - the first grounded large vision-language model for remote sensing. Given an image input together with a user query, a visual backbone is first used to encode patch-level tokens at a higher resolution via interpolating positional encodings. A multi-layer perceptron (MLP) is used to adapt vision tokens to a language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be input to the model together with task-specific prompts that specify the desired task required by the user. Given this context, the LLM can generate natural language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks as shown on top, e.g., scene classification, image/region captioning, VQA and grounded conversations.
+
+<p align="center">
+  <img src="images/architecture.png" alt="GeoChat Architectural">
+</p>
+
+---
+
+## 🔍 RS Multimodal Instruction Dataset
+
+Types of annotations available in the GeoChat instruction set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction set with a total of 318k image-instruction pairs.
+
+<p align="center">
+  <img src="images/dataset.png" alt="Dataset Annotation Pipeline">
+</p>
+
+
+
+## 🤖 Qualitative results of GeoChat
+
+Qualitative results of GeoChat. (<em>left-right</em>) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., <strong>[grounding]</strong>) to shape model responses according to the desired behavior. The model can generate textual responses (<em>right</em>), only visual grounding (<em>center</em>) and both text and object groundings interleaved together (<em>left</em>). The model can also specify object types, object counts, object attributes and object relationships.
+<p align="center">
+  <img src="images/examples.png" alt="Results_GCG">
+</p>
+
+---
+
+## 🤖 Visual Question Answering
+Qualitative examples for visual question answering tasks. GeoChat is able to hold multi-turn conversations based on various types of questions, including presence, count, complex comparisons and so on. It is able to detect objects and hold conversations against low-resolution images as well.
+<p align="center">
+  <img src="images/vqa.jpg" alt="Visual Question Answering">
+</p>
+
+---
+
+## 🤖 Scene Classification
+Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
+<p align="center">
+  <img src="images/scene.jpg" alt="Scene Classification">
+</p>
+
+---
+
+## 🤖 Grounded Description
+When asked to describe the image with the special token '[grounding]', GeoChat outputs both the description of the image as well as the bounding boxes for all the objects detected.
+<p align="center">
+  <img src="images/grounded.jpg" alt="Grounded Description">
+</p>
+
+---
+
+## 🤖 Referring Expression
+When asked about an object as a referring expression, GeoChat is able to locate it and draw rotated bounding boxes around it correspondingly.
+<p align="center">
+  <img src="images/ref1.jpg" alt="Referring Expression">
+</p>
+<p align="center">
+  <img src="images/ref_2.jpg" alt="Referring Expression">
+</p>
+
+---
+
+## 🤖 Region Caption
+Qualitative examples for region-based captioning. Given a bounding box, GeoChat is able to provide brief descriptions about the area or the object covered by the bounding box.
+<p align="center">
+  <img src="images/iden.jpg" alt="Region Caption">
+</p>
+
+---
+
+## 📜 Citation
+```bibtex
+@article{kuckreja2023geochat,
+  title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
+  author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
+  journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  year={2024}
+}
+```
+## 🙏 Acknowledgement
+We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
+
+---
+[<img src="images/IVAL_logo.png" width="200" height="100">](https://www.ival-mbzuai.com)
+[<img src="images/Oryx_logo.png" width="100" height="100">](https://github.com/mbzuai-oryx)
+[<img src="images/MBZUAI_logo.png" width="360" height="85">](https://mbzuai.ac.ae)
demo_images/04133.png
ADDED (binary image stored with Git LFS)

demo_images/04444.png
ADDED (binary image stored with Git LFS)

demo_images/7292.JPG
ADDED (binary image stored with Git LFS)

demo_images/MicrosoftTeams-image.png
ADDED (binary image stored with Git LFS)

demo_images/church_183.png
ADDED (binary image stored with Git LFS)

demo_images/train_2956_0001.png
ADDED (binary image stored with Git LFS)
docs/Customize_Component.md
ADDED
@@ -0,0 +1,20 @@
# Customize Components in GeoChat

This is an initial guide on how to replace the LLMs, visual encoders, etc. with your choice of components.

## LLM

It is quite simple to swap LLaMA for any other LLM. You can refer to our implementation of [`geochat_llama.py`](https://github.com/mbzuai-oryx/GeoChat/blob/main/geochat/model/language_model/geochat_llama.py) for an example of how to replace the LLM.

Although it may seem that it still needs ~100 lines of code, most of them are copied from the original `llama.py` from HF. The only difference is the insertion of some lines for processing the multimodal inputs.

In the `forward` function, you can see that we call `self.prepare_inputs_labels_for_multimodal` to process the multimodal inputs. This function is defined in `GeoChatMetaForCausalLM`, and you just need to insert it into the `forward` function of your LLM.

In the `prepare_inputs_for_generation` function, you can see that we add `images` to the `model_inputs`. This is because we need to pass the images to the LLM during generation.

These are basically all the changes you need to make to replace the LLM.

## Visual Encoder

You can check out [`clip_encoder.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/model/multimodal_encoder/clip_encoder.py) on how we implement the CLIP visual encoder.
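The following is a toy sketch of the two hooks Customize_Component.md describes. It is not the real Hugging Face subclass: the method names follow the doc, but the bodies are placeholders that only show where the multimodal preprocessing and the `images` pass-through belong.

```python
class ToyGeoChatLM:
    """Minimal, self-contained sketch of the two customization points."""

    def prepare_inputs_labels_for_multimodal(self, input_ids, labels, images):
        # In GeoChat this splices projected vision tokens into the text embeddings;
        # here we just bundle the inputs so the data flow is visible.
        return {"input_ids": input_ids, "labels": labels, "vision_tokens": images}

    def forward(self, input_ids, labels=None, images=None):
        # Hook 1: multimodal preprocessing runs before the language model proper.
        batch = self.prepare_inputs_labels_for_multimodal(input_ids, labels, images)
        return batch  # the real model feeds this into the LLaMA backbone

    def prepare_inputs_for_generation(self, input_ids, images=None, **kwargs):
        # Hook 2: carry the images along so each generation step can see them.
        model_inputs = {"input_ids": input_ids, **kwargs}
        model_inputs["images"] = images
        return model_inputs

model = ToyGeoChatLM()
print(model.prepare_inputs_for_generation([1, 2, 3], images="<image tensor>"))
```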
docs/Data.md
ADDED
@@ -0,0 +1,24 @@
## Finetuning Data
We use GeoChat-Instruct to finetune our model. The instruction-following data is provided in GeoChat_Instruct.json and the images are present in the [huggingface repo](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). The images are split into multiple files. Download the separate files into the same folder and run the following script to merge them.

```Shell
cat images_parta* > images.zip
```

Unzip the images into a folder and provide the folder path in the training and evaluation scripts.

| Data file name | Size |
| --- | ---: |
| [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json) | 263 MB |

## Pretraining Dataset
We use the same pretraining dataset as LLaVA-v1.5.
The pretraining dataset used in this release is a subset of the CC-3M dataset, filtered with a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.

If you already have the CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field correspondingly if necessary.

| Data | Chat File | Meta Data | Size |
| --- | --- | --- | ---: |
| CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB |
| LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB |
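For reference, the merge-and-unzip step from Data.md can also be done in Python. The folder names below are assumptions, so point them at wherever you downloaded the `images_parta*` files.

```python
import zipfile
from pathlib import Path

parts_dir = Path("GeoChat_Instruct")  # assumed download folder containing images_parta*
out_zip = parts_dir / "images.zip"

# Equivalent of `cat images_parta* > images.zip`: byte-concatenate the parts in order.
with out_zip.open("wb") as merged:
    for part in sorted(parts_dir.glob("images_parta*")):
        merged.write(part.read_bytes())

# Unzip into a folder whose path you then pass to finetune_lora.sh / the eval scripts.
image_folder = parts_dir / "images"
with zipfile.ZipFile(out_zip) as zf:
    zf.extractall(image_folder)
print(f"extracted {len(list(image_folder.rglob('*')))} files to {image_folder}")
```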
docs/Evaluation.md
ADDED
@@ -0,0 +1,54 @@
# Evaluation

We evaluate GeoChat on a variety of tasks, including scene classification, region captioning, visual grounding, grounded description, and VQA.
Files converted to the GeoChat input format are available at [GeoChat-Bench](https://huggingface.co/datasets/MBZUAI/GeoChat-Bench/tree/main).

Below we provide a general guideline for evaluating each dataset.

1. LRBEN/HRBEN.

Images and ground truth for evaluation need to be downloaded from the following sources: [LRBEN](https://zenodo.org/records/6344334), [HRBEN](https://zenodo.org/records/6344367).
Give the path to the extracted image folder in the evaluation script. We append the following text to each question during our evaluation.
```
<question>
Answer the question using a single word or phrase.
```
```Shell
python geochat/eval/batch_geochat_vqa.py \
    --model-path /path/to/model \
    --question-file path/to/jsonl/file \
    --answer-file path/to/output/jsonl/file \
    --image_folder path/to/image/folder/
```
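To score the resulting answers file, exact-match accuracy over the predicted and reference answers is usually sufficient for these single-word responses. The sketch below assumes each output line carries the prediction under `answer` and the reference under `ground_truth` (the field names used by the batched grounding script in this repo); adjust the keys if the VQA output file differs.

```python
# Scoring sketch: exact-match accuracy over a GeoChat answers .jsonl file.
# The "answer" / "ground_truth" field names are assumptions; adjust if needed.
import json

def exact_match_accuracy(answers_file: str) -> float:
    correct = total = 0
    with open(answers_file) as f:
        for line in f:
            rec = json.loads(line)
            pred = str(rec["answer"]).strip().lower().rstrip(".")
            gt = str(rec["ground_truth"]).strip().lower()
            correct += int(pred == gt)
            total += 1
    return correct / max(total, 1)

print(f"accuracy: {exact_match_accuracy('answers.jsonl'):.4f}")
```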

2. Scene Classification.

Download the images from the following sources: [UCMerced](http://weegee.vision.ucmerced.edu/datasets/landuse.html), [AID](https://drive.google.com/drive/folders/1-1D9DrYYWMGuuxx-qcvIIOV1oUkAVf-M). We append the following text to each question during our evaluation.
```
<question>
Classify the image from the following classes. Answer in one word or a short phrase.
```
```Shell
python geochat/eval/batch_geochat_scene.py \
    --model-path /path/to/model \
    --question-file path/to/jsonl/file \
    --answer-file path/to/output/jsonl/file \
    --image_folder path/to/image/folder/
```

3. Region Captioning/Visual Grounding.

The evaluation images are present in the `images.zip` file of [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/images.zip).
```Shell
python geochat/eval/batch_geochat_grounding.py \
    --model-path /path/to/model \
    --question-file path/to/jsonl/file \
    --answer-file path/to/output/jsonl/file \
    --image_folder path/to/image/folder/
```

```Shell
python geochat/eval/batch_geochat_referring.py \
    --model-path /path/to/model \
    --question-file path/to/jsonl/file \
    --answer-file path/to/output/jsonl/file \
    --image_folder path/to/image/folder/
```
docs/LoRA.md
ADDED
@@ -0,0 +1,24 @@
## Demo (Web UI)

You need GeoChat-7B to run the demo locally. Download the model from [GeoChat-7B](https://huggingface.co/MBZUAI/geochat-7B). After downloading the model, launch the Gradio demo by passing the model path to the following command.

#### Launch the demo
```Shell
python geochat_demo.py --model-path /path/to/model
```

## Training

Please see the sample training script for [LoRA](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).

We provide sample DeepSpeed configs: [`zero3.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3.json) is more like PyTorch FSDP, while [`zero3_offload.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3_offload.json) can further reduce memory consumption by offloading parameters to the CPU. `zero3.json` is usually faster than `zero3_offload.json` but requires more GPU memory; we therefore recommend trying `zero3.json` first and, if you run out of GPU memory, switching to `zero3_offload.json`. You can also tweak `per_device_train_batch_size` and `gradient_accumulation_steps` in the config to save memory; just make sure that `per_device_train_batch_size` x `gradient_accumulation_steps` stays the same so the global batch size is unchanged.

If you are having issues with the ZeRO-3 configs and there is enough VRAM, you may try [`zero2.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero2.json). It consumes slightly more memory than ZeRO-3 and behaves more similarly to PyTorch FSDP, while still supporting parameter-efficient tuning.

## Create Merged Checkpoints

```Shell
python scripts/merge_lora_weights.py \
    --model-path /path/to/lora_model \
    --model-base /path/to/base_model \
    --save-model-path /path/to/merge_model
```
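Conceptually, the merge step folds the LoRA deltas back into the base weights. The sketch below shows the general idea with `peft` (the pinned `peft==0.4.0` provides `merge_and_unload`); it is not the contents of `scripts/merge_lora_weights.py`, which loads the model through the GeoChat classes, and the paths are placeholders.

```python
# Conceptual LoRA-merge sketch; not the repo's merge_lora_weights.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("/path/to/base_model",
                                            torch_dtype=torch.float16)
lora = PeftModel.from_pretrained(base, "/path/to/lora_model")

merged = lora.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("/path/to/merge_model")
AutoTokenizer.from_pretrained("/path/to/base_model").save_pretrained("/path/to/merge_model")
```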
docs/MODEL_ZOO.md
ADDED
@@ -0,0 +1,18 @@
# Model Zoo

| Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
| --- | --- | --- | --- | --- | --- | --- |
| Vicuna-13B-v1.3 | CLIP-L-336px (extended to 504) | LCS-558K | 1e | GeoChat_Instruct | proj-1e, lora-1e | [LoRA-Merged](https://huggingface.co/MBZUAI/geochat-7B) |

## Projector weights

We use the projector from LLaVA-1.5 for initialization: [link](https://huggingface.co/liuhaotian/llava-v1.5-7b-lora).

**NOTE**: When you use our pretrained projector for visual instruction tuning, it is very important to **use the same base LLM and vision encoder** as the one we used for pretraining the projector. Otherwise, the performance will be very bad.

When using these projector weights to instruction-tune your LMM, please make sure that these options are set as follows:

```Shell
--mm_use_im_start_end False
--mm_use_im_patch_token False
```
docs/geochat_supp.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc9b5d3df4af5c06e59fc2258332e00b779c55d441405cb5a7fd7997d29b63fe
size 4839915
geochat.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,260 @@
1 |
+
Metadata-Version: 2.2
|
2 |
+
Name: geochat
|
3 |
+
Version: 1.1.1
|
4 |
+
Summary: Grounded VLM for Remote Sensing
|
5 |
+
Project-URL: Homepage, https://github.com/mbzuai-oryx/GeoChat
|
6 |
+
Project-URL: Bug Tracker, https://github.com/mbzuai-oryx/GeoChat/issues
|
7 |
+
Classifier: Programming Language :: Python :: 3
|
8 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
9 |
+
Requires-Python: >=3.8
|
10 |
+
Description-Content-Type: text/markdown
|
11 |
+
Requires-Dist: einops
|
12 |
+
Requires-Dist: fastapi
|
13 |
+
Requires-Dist: gradio==3.35.2
|
14 |
+
Requires-Dist: markdown2[all]
|
15 |
+
Requires-Dist: numpy
|
16 |
+
Requires-Dist: requests
|
17 |
+
Requires-Dist: sentencepiece
|
18 |
+
Requires-Dist: tokenizers>=0.12.1
|
19 |
+
Requires-Dist: torch==2.0.1
|
20 |
+
Requires-Dist: torchvision==0.15.2
|
21 |
+
Requires-Dist: uvicorn
|
22 |
+
Requires-Dist: wandb
|
23 |
+
Requires-Dist: shortuuid
|
24 |
+
Requires-Dist: httpx==0.24.0
|
25 |
+
Requires-Dist: peft==0.4.0
|
26 |
+
Requires-Dist: transformers==4.31.0
|
27 |
+
Requires-Dist: accelerate==0.21.0
|
28 |
+
Requires-Dist: bitsandbytes==0.41.0
|
29 |
+
Requires-Dist: scikit-learn==1.2.2
|
30 |
+
Requires-Dist: sentencepiece==0.1.99
|
31 |
+
Requires-Dist: einops==0.6.1
|
32 |
+
Requires-Dist: einops-exts==0.0.4
|
33 |
+
Requires-Dist: timm==0.6.13
|
34 |
+
Requires-Dist: gradio_client==0.2.9
|
35 |
+
|
36 |
+
# GeoChat <img src="images/logo_geochat.png" height="40">: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
|
37 |
+
<p align="center">
|
38 |
+
<img src="https://i.imgur.com/waxVImv.png" alt="Oryx Video-ChatGPT">
|
39 |
+
</p>
|
40 |
+
|
41 |
+
#### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
|
42 |
+
\* Equally contributing first authors
|
43 |
+
|
44 |
+
#### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
|
45 |
+
|
46 |
+
[](https://mbzuai-oryx.github.io/GeoChat)
|
47 |
+
[](https://arxiv.org/abs/2311.15826)
|
48 |
+
[](https://youtu.be/KOKtkkKpNDk)
|
49 |
+
|
50 |
+
---
|
51 |
+
|
52 |
+
## 📢 Latest Updates
|
53 |
+
- Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
|
54 |
+
- **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
|
55 |
+
- **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
|
56 |
+
- **Nov-28-23**: GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
|
57 |
+
---
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
## <img src="images/logo_geochat.png" height="40">Overview
|
62 |
+
|
63 |
+
GeoChat is the first grounded Large Vision-Language Model specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels at handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned using the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
|
64 |
+
|
65 |
+
---
|
66 |
+
## Contents
|
67 |
+
- [Install](#install)
|
68 |
+
- [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
|
69 |
+
- [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
|
70 |
+
- [Train](#train)
|
71 |
+
- [Evaluation](#evaluation)
|
72 |
+
|
73 |
+
## Install
|
74 |
+
|
75 |
+
1. Clone this repository and navigate to GeoChat folder
|
76 |
+
```bash
|
77 |
+
git clone https://github.com/mbzuai-oryx/GeoChat.git
|
78 |
+
cd GeoChat
|
79 |
+
```
|
80 |
+
|
81 |
+
2. Install Package
|
82 |
+
```Shell
|
83 |
+
conda create -n geochat python=3.10 -y
|
84 |
+
conda activate geochat
|
85 |
+
pip install --upgrade pip # enable PEP 660 support
|
86 |
+
pip install -e .
|
87 |
+
```
|
88 |
+
|
89 |
+
3. Install additional packages for training cases
|
90 |
+
```
|
91 |
+
pip install ninja
|
92 |
+
pip install flash-attn --no-build-isolation
|
93 |
+
```
|
94 |
+
|
95 |
+
### Upgrade to latest code base
|
96 |
+
|
97 |
+
```Shell
|
98 |
+
git pull
|
99 |
+
pip uninstall transformers
|
100 |
+
pip install -e .
|
101 |
+
```
|
102 |
+
|
103 |
+
## GeoChat Weights and Demo
|
104 |
+
Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and check [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
|
105 |
+
|
106 |
+
## Train
|
107 |
+
|
108 |
+
GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset: 318k Vicuna-generated multimodal instruction-following samples, finetuned over the pretrained weights of LLaVA-v1.5.
|
109 |
+
|
110 |
+
We train GeoChat on 3 A100 GPUs with 40GB memory each. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus` (any combination whose product is 144 reproduces the global batch size in the table below).
|
111 |
+
|
112 |
+
### Hyperparameters
|
113 |
+
We use a similar set of hyperparameters to Vicuna for finetuning. The hyperparameters used for finetuning GeoChat are provided below.
|
114 |
+
|
115 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
116 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
117 |
+
| GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
|
118 |
+
|
119 |
+
### Pretrain (feature alignment)
|
120 |
+
|
121 |
+
We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. This pretraining takes around 3.5 hours for LLaVA-v1.5-7B.
|
122 |
+
|
123 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
124 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
125 |
+
|
126 |
+
### Visual Instruction Tuning
|
127 |
+
|
128 |
+
1. Prepare data
|
129 |
+
|
130 |
+
Please download the annotation file of the final mixture of our instruction tuning data, [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from [Hugging Face](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save the multiple image zips in a single folder and run the following command to merge them:
|
131 |
+
```Shell
|
132 |
+
cat images_parta* > images.zip
|
133 |
+
```
|
134 |
+
Unzip the images.zip file to a folder and give the folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
|
135 |
+
|
136 |
+
2. Start training!
|
137 |
+
|
138 |
+
Visual instruction tuning takes more time due to the increased CLIP input resolution of 504x504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40G).
|
139 |
+
|
140 |
+
Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
|
141 |
+
|
142 |
+
Options to note:
|
143 |
+
|
144 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
145 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
146 |
+
- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
|
147 |
+
- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language-only data (e.g. ShareGPT) and multimodal data (e.g. LLaVA-Instruct).
|
148 |
+
|
149 |
+
## Evaluation
|
150 |
+
|
151 |
+
We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search, so that the inference process is consistent with the real-time chat demo.
|
152 |
+
See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
|
153 |
+
|
154 |
+
## 🏆 Contributions
|
155 |
+
|
156 |
+
- **RS multimodal instruction-following dataset.** We present a novel data-generation pipeline that leverages existing object detection datasets to create short descriptions of the images, and then uses Vicuna-v1.5 to create conversations from the generated text alone. Further, we add visual question-answering and scene classification abilities
|
157 |
+
using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
|
158 |
+
- **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote-sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully tuned LLaVA model, whose MLP projection is trained to align images into the word-embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain the conversation and instruction-following abilities of LLaVA and extend its domain knowledge to remote sensing tasks.
|
159 |
+
|
160 |
+
- **Evaluation Benchmark.** We also address the lack of evaluation benchmarks for assessing the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparisons with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote sensing tasks, including image captioning, visual question answering, and scene classification, to demonstrate the generalisability of the GeoChat conversational VLM.
|
161 |
+
|
162 |
+
---
|
163 |
+
## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
|
164 |
+
|
165 |
+
GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations - shown on top), visual question answering on images and regions (top left and bottom right, respectively) as well as scene classification (top right) and normal natural language conversations (bottom). This makes it the first RS VLM with grounding capability.
|
166 |
+
|
167 |
+
<p align="center">
|
168 |
+
<img src="images/overview2.png" alt="GeoChat Overview">
|
169 |
+
</p>
|
170 |
+
|
171 |
+
---
|
172 |
+
|
173 |
+
## 🛰️ GeoChat : Architecture
|
174 |
+
|
175 |
+
An overview of GeoChat - the first grounded large vision-language model for remote sensing. Given an image input together with a user query, a visual backbone is first used to encode patch-level tokens at a higher resolution via interpolating positional encodings. A multi-layer perceptron (MLP) is used to adapt vision-tokens to language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be input to the model together with task-specific prompts that specify the desired task required by the user. Given this context, the LLM can generate natural language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks as shown on top e.g., scene classification, image/region captioning, VQA and grounded conversations.
|
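The higher input resolution mentioned above is obtained by resizing the positional embeddings of the CLIP backbone rather than retraining it. The snippet below is a generic sketch of that idea for a ViT-L/14 tower going from 336px (a 24x24 patch grid) to 504px (36x36); it is illustrative and not GeoChat's actual code.

```python
# Sketch: bicubic interpolation of CLIP ViT position embeddings to a larger patch grid.
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor, new_grid: int = 36) -> torch.Tensor:
    """pos_embed: (1 + old_grid**2, dim) with a leading CLS slot, e.g. (577, 1024)."""
    cls_tok, patch_tok = pos_embed[:1], pos_embed[1:]
    old_grid = int(patch_tok.shape[0] ** 0.5)              # 24 for 336px with patch size 14
    dim = patch_tok.shape[1]
    patch_tok = patch_tok.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_tok = F.interpolate(patch_tok, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    patch_tok = patch_tok.permute(0, 2, 3, 1).reshape(new_grid * new_grid, dim)
    return torch.cat([cls_tok, patch_tok], dim=0)           # (1 + 36*36, dim)
```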
176 |
+
|
177 |
+
<p align="center">
|
178 |
+
<img src="images/architecture.png" alt="GeoChat Architectural">
|
179 |
+
</p>
|
180 |
+
|
181 |
+
---
|
182 |
+
|
183 |
+
## 🔍 RS Multimodal Instruction Dataset
|
184 |
+
|
185 |
+
Types of annotations available in the GeoChat instruction-set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction-set with a total of 318k image-instruction pairs.
|
186 |
+
|
187 |
+
<p align="center">
|
188 |
+
<img src="images/dataset.png" alt="Dataset Annotation Pipeline">
|
189 |
+
</p>
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
## 🤖 Qualitative results of GeoChat
|
194 |
+
|
195 |
+
Qualitative results of GeoChat. (<em>left-right</em>) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., <strong>[grounding]</strong>) to shape model responses according to the desired behavior. The model can generate textual responses (<em>right</em>), only visual grounding (<em>center</em>) and both text and object groundings interleaved together (<em>left</em>). The model can also specify object types, object counts, object attributes and object relationships.
|
196 |
+
<p align="center">
|
197 |
+
<img src="images/examples.png" alt="Results_GCG">
|
198 |
+
</p>
|
199 |
+
|
200 |
+
---
|
201 |
+
|
202 |
+
## 🤖 Visual Question Answering
|
203 |
+
Qualitative examples for Visual Question Answering tasks. GeoChat is able to hold multi-turn conversations, based on various types of questions, including presence, count, complex comparisons and so on. It is able to detect objects and hold conversations against low resolution images as well.
|
204 |
+
<p align="center">
|
205 |
+
<img src="images/vqa.jpg" alt="Visual Question Answering">
|
206 |
+
</p>
|
207 |
+
|
208 |
+
---
|
209 |
+
|
210 |
+
## 🤖 Scene Classification
|
211 |
+
Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
|
212 |
+
<p align="center">
|
213 |
+
<img src="images/scene.jpg" alt="Visual Question Answering">
|
214 |
+
</p>
|
215 |
+
|
216 |
+
---
|
217 |
+
|
218 |
+
## 🤖 Grounded Description
|
219 |
+
When asked to describe the image with the special token '[grounding]', GeoChat outputs both a description of the image and bounding boxes for all the objects it detects.
|
220 |
+
<p align="center">
|
221 |
+
<img src="images/grounded.jpg" alt="Grounded Description">
|
222 |
+
</p>
|
223 |
+
|
224 |
+
---
|
225 |
+
|
226 |
+
## 🤖 Referring Expression
|
227 |
+
When asked about an object via a referring expression, GeoChat is able to locate it and draw a rotated bounding box around it.
|
228 |
+
<p align="center">
|
229 |
+
<img src="images/ref1.jpg" alt="Referring Expression">
|
230 |
+
</p>
|
231 |
+
<p align="center">
|
232 |
+
<img src="images/ref_2.jpg" alt="Referring Expression">
|
233 |
+
</p>
|
234 |
+
|
235 |
+
---
|
236 |
+
|
237 |
+
## 🤖 Region Caption
|
238 |
+
Qualitative examples for region-based captioning. Given a bounding box, GeoChat is able to provide brief descriptions about the area or the object covered by the bounding box.
|
239 |
+
<p align="center">
|
240 |
+
<img src="images/iden.jpg" alt="Region Caption">
|
241 |
+
</p>
|
242 |
+
|
243 |
+
---
|
244 |
+
|
245 |
+
## 📜 Citation
|
246 |
+
```bibtex
|
247 |
+
@article{kuckreja2023geochat,
|
248 |
+
title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
|
249 |
+
author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
|
250 |
+
journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
251 |
+
year={2024}
|
252 |
+
}
|
253 |
+
```
|
254 |
+
## 🙏 Acknowledgement
|
255 |
+
We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
|
256 |
+
|
257 |
+
---
|
258 |
+
[<img src="images/IVAL_logo.png" width="200" height="100">](https://www.ival-mbzuai.com)
|
259 |
+
[<img src="images/Oryx_logo.png" width="100" height="100">](https://github.com/mbzuai-oryx)
|
260 |
+
[<img src="images/MBZUAI_logo.png" width="360" height="85">](https://mbzuai.ac.ae)
|
geochat.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,51 @@
1 |
+
README.md
|
2 |
+
pyproject.toml
|
3 |
+
geochat/__init__.py
|
4 |
+
geochat/constants.py
|
5 |
+
geochat/conversation.py
|
6 |
+
geochat/mm_utils.py
|
7 |
+
geochat/utils.py
|
8 |
+
geochat.egg-info/PKG-INFO
|
9 |
+
geochat.egg-info/SOURCES.txt
|
10 |
+
geochat.egg-info/dependency_links.txt
|
11 |
+
geochat.egg-info/requires.txt
|
12 |
+
geochat.egg-info/top_level.txt
|
13 |
+
geochat/eval/batch_geochat_grounding.py
|
14 |
+
geochat/eval/batch_geochat_referring.py
|
15 |
+
geochat/eval/batch_geochat_scene.py
|
16 |
+
geochat/eval/batch_geochat_vqa.py
|
17 |
+
geochat/model/__init__.py
|
18 |
+
geochat/model/apply_delta.py
|
19 |
+
geochat/model/builder.py
|
20 |
+
geochat/model/consolidate.py
|
21 |
+
geochat/model/geochat_arch.py
|
22 |
+
geochat/model/make_delta.py
|
23 |
+
geochat/model/utils.py
|
24 |
+
geochat/model/language_model/geochat_llama.py
|
25 |
+
geochat/model/language_model/geochat_mpt.py
|
26 |
+
geochat/model/language_model/mpt/adapt_tokenizer.py
|
27 |
+
geochat/model/language_model/mpt/attention.py
|
28 |
+
geochat/model/language_model/mpt/blocks.py
|
29 |
+
geochat/model/language_model/mpt/configuration_mpt.py
|
30 |
+
geochat/model/language_model/mpt/custom_embedding.py
|
31 |
+
geochat/model/language_model/mpt/flash_attn_triton.py
|
32 |
+
geochat/model/language_model/mpt/hf_prefixlm_converter.py
|
33 |
+
geochat/model/language_model/mpt/meta_init_context.py
|
34 |
+
geochat/model/language_model/mpt/modeling_mpt.py
|
35 |
+
geochat/model/language_model/mpt/norm.py
|
36 |
+
geochat/model/language_model/mpt/param_init_fns.py
|
37 |
+
geochat/model/multimodal_encoder/builder.py
|
38 |
+
geochat/model/multimodal_encoder/clip_encoder.py
|
39 |
+
geochat/model/multimodal_projector/builder.py
|
40 |
+
geochat/serve/__init__.py
|
41 |
+
geochat/serve/cli.py
|
42 |
+
geochat/serve/controller.py
|
43 |
+
geochat/serve/gradio_trial.py
|
44 |
+
geochat/serve/gradio_web_server.py
|
45 |
+
geochat/serve/model_worker.py
|
46 |
+
geochat/serve/register_worker.py
|
47 |
+
geochat/serve/test_message.py
|
48 |
+
geochat/train/geochat_trainer.py
|
49 |
+
geochat/train/llama_flash_attn_monkey_patch.py
|
50 |
+
geochat/train/train.py
|
51 |
+
geochat/train/train_mem.py
|
geochat.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
1 |
+
|
geochat.egg-info/requires.txt
ADDED
@@ -0,0 +1,24 @@
1 |
+
einops
|
2 |
+
fastapi
|
3 |
+
gradio==3.35.2
|
4 |
+
markdown2[all]
|
5 |
+
numpy
|
6 |
+
requests
|
7 |
+
sentencepiece
|
8 |
+
tokenizers>=0.12.1
|
9 |
+
torch==2.0.1
|
10 |
+
torchvision==0.15.2
|
11 |
+
uvicorn
|
12 |
+
wandb
|
13 |
+
shortuuid
|
14 |
+
httpx==0.24.0
|
15 |
+
peft==0.4.0
|
16 |
+
transformers==4.31.0
|
17 |
+
accelerate==0.21.0
|
18 |
+
bitsandbytes==0.41.0
|
19 |
+
scikit-learn==1.2.2
|
20 |
+
sentencepiece==0.1.99
|
21 |
+
einops==0.6.1
|
22 |
+
einops-exts==0.0.4
|
23 |
+
timm==0.6.13
|
24 |
+
gradio_client==0.2.9
|
geochat.egg-info/top_level.txt
ADDED
@@ -0,0 +1,3 @@
1 |
+
demo_images
|
2 |
+
geochat
|
3 |
+
images
|
geochat/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .model import GeoChatLlamaForCausalLM
|
geochat/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (186 Bytes).
geochat/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (446 Bytes).
geochat/__pycache__/conversation.cpython-310.pyc
ADDED
Binary file (14.1 kB).
geochat/__pycache__/mm_utils.cpython-310.pyc
ADDED
Binary file (4.92 kB).
geochat/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.01 kB).
geochat/constants.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
geochat/conversation.py
ADDED
@@ -0,0 +1,520 @@
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
from PIL import Image
|
5 |
+
from threading import Thread
|
6 |
+
|
7 |
+
from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
8 |
+
# from llava.conversation import conv_templates, SeparatorStyle
|
9 |
+
# from llava.model.builder import load_pretrained_model
|
10 |
+
from geochat.utils import disable_torch_init
|
11 |
+
from geochat.mm_utils import process_images_demo, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
12 |
+
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer,TextStreamer
|
13 |
+
import torch
|
14 |
+
import dataclasses
|
15 |
+
from enum import auto, Enum
|
16 |
+
from typing import List, Tuple, Any
|
17 |
+
|
18 |
+
|
19 |
+
class SeparatorStyle(Enum):
|
20 |
+
"""Different separator style."""
|
21 |
+
SINGLE = auto()
|
22 |
+
TWO = auto()
|
23 |
+
MPT = auto()
|
24 |
+
PLAIN = auto()
|
25 |
+
LLAMA_2 = auto()
|
26 |
+
|
27 |
+
|
28 |
+
@dataclasses.dataclass
|
29 |
+
class Conversation:
|
30 |
+
"""A class that keeps all conversation history."""
|
31 |
+
system: str
|
32 |
+
roles: List[str]
|
33 |
+
messages: List[List[str]]
|
34 |
+
offset: int
|
35 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
36 |
+
sep: str = "###"
|
37 |
+
sep2: str = None
|
38 |
+
version: str = "Unknown"
|
39 |
+
|
40 |
+
skip_next: bool = False
|
41 |
+
|
42 |
+
def get_prompt(self):
|
43 |
+
messages = self.messages
|
44 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
45 |
+
messages = self.messages.copy()
|
46 |
+
init_role, init_msg = messages[0].copy()
|
47 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
48 |
+
if 'mmtag' in self.version:
|
49 |
+
messages[0] = (init_role, init_msg)
|
50 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
51 |
+
messages.insert(1, (self.roles[1], "Received."))
|
52 |
+
else:
|
53 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
54 |
+
|
55 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
56 |
+
ret = self.system + self.sep
|
57 |
+
for role, message in messages:
|
58 |
+
if message:
|
59 |
+
if type(message) is tuple:
|
60 |
+
message, _, _ = message
|
61 |
+
ret += role + ": " + message + self.sep
|
62 |
+
else:
|
63 |
+
ret += role + ":"
|
64 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
65 |
+
seps = [self.sep, self.sep2]
|
66 |
+
ret = self.system + seps[0]
|
67 |
+
for i, (role, message) in enumerate(messages):
|
68 |
+
if message:
|
69 |
+
if type(message) is tuple:
|
70 |
+
message, _, _ = message
|
71 |
+
ret += role + ": " + message + seps[i % 2]
|
72 |
+
else:
|
73 |
+
ret += role + ":"
|
74 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
75 |
+
ret = self.system + self.sep
|
76 |
+
for role, message in messages:
|
77 |
+
if message:
|
78 |
+
if type(message) is tuple:
|
79 |
+
message, _, _ = message
|
80 |
+
ret += role + message + self.sep
|
81 |
+
else:
|
82 |
+
ret += role
|
83 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
84 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
85 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
86 |
+
ret = ""
|
87 |
+
|
88 |
+
for i, (role, message) in enumerate(messages):
|
89 |
+
if i == 0:
|
90 |
+
assert message, "first message should not be none"
|
91 |
+
assert role == self.roles[0], "first message should come from user"
|
92 |
+
if message:
|
93 |
+
if type(message) is tuple:
|
94 |
+
message, _, _ = message
|
95 |
+
if i == 0: message = wrap_sys(self.system) + message
|
96 |
+
if i % 2 == 0:
|
97 |
+
message = wrap_inst(message)
|
98 |
+
ret += self.sep + message
|
99 |
+
else:
|
100 |
+
ret += " " + message + " " + self.sep2
|
101 |
+
else:
|
102 |
+
ret += ""
|
103 |
+
ret = ret.lstrip(self.sep)
|
104 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
105 |
+
seps = [self.sep, self.sep2]
|
106 |
+
ret = self.system
|
107 |
+
for i, (role, message) in enumerate(messages):
|
108 |
+
if message:
|
109 |
+
if type(message) is tuple:
|
110 |
+
message, _, _ = message
|
111 |
+
ret += message + seps[i % 2]
|
112 |
+
else:
|
113 |
+
ret += ""
|
114 |
+
else:
|
115 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
116 |
+
|
117 |
+
return ret
|
118 |
+
|
119 |
+
def append_message(self, role, message):
|
120 |
+
self.messages.append([role, message])
|
121 |
+
|
122 |
+
def get_images(self, return_pil=False):
|
123 |
+
images = []
|
124 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
125 |
+
if i % 2 == 0:
|
126 |
+
if type(msg) is tuple:
|
127 |
+
import base64
|
128 |
+
from io import BytesIO
|
129 |
+
from PIL import Image
|
130 |
+
msg, image, image_process_mode = msg
|
131 |
+
if image_process_mode == "Pad":
|
132 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
133 |
+
width, height = pil_img.size
|
134 |
+
if width == height:
|
135 |
+
return pil_img
|
136 |
+
elif width > height:
|
137 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
138 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
139 |
+
return result
|
140 |
+
else:
|
141 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
142 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
143 |
+
return result
|
144 |
+
image = expand2square(image)
|
145 |
+
elif image_process_mode in ["Default", "Crop"]:
|
146 |
+
pass
|
147 |
+
elif image_process_mode == "Resize":
|
148 |
+
image = image.resize((336, 336))
|
149 |
+
else:
|
150 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
151 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
152 |
+
aspect_ratio = max_hw / min_hw
|
153 |
+
max_len, min_len = 800, 400
|
154 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
155 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
156 |
+
W, H = image.size
|
157 |
+
if longest_edge != max(image.size):
|
158 |
+
if H > W:
|
159 |
+
H, W = longest_edge, shortest_edge
|
160 |
+
else:
|
161 |
+
H, W = shortest_edge, longest_edge
|
162 |
+
image = image.resize((W, H))
|
163 |
+
if return_pil:
|
164 |
+
images.append(image)
|
165 |
+
else:
|
166 |
+
buffered = BytesIO()
|
167 |
+
image.save(buffered, format="PNG")
|
168 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
169 |
+
images.append(img_b64_str)
|
170 |
+
return images
|
171 |
+
|
172 |
+
def to_gradio_chatbot(self):
|
173 |
+
ret = []
|
174 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
175 |
+
if i % 2 == 0:
|
176 |
+
if type(msg) is tuple:
|
177 |
+
import base64
|
178 |
+
from io import BytesIO
|
179 |
+
msg, image, image_process_mode = msg
|
180 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
181 |
+
aspect_ratio = max_hw / min_hw
|
182 |
+
max_len, min_len = 800, 400
|
183 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
184 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
185 |
+
W, H = image.size
|
186 |
+
if H > W:
|
187 |
+
H, W = longest_edge, shortest_edge
|
188 |
+
else:
|
189 |
+
H, W = shortest_edge, longest_edge
|
190 |
+
image = image.resize((W, H))
|
191 |
+
buffered = BytesIO()
|
192 |
+
image.save(buffered, format="JPEG")
|
193 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
194 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
195 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
196 |
+
ret.append([msg, None])
|
197 |
+
else:
|
198 |
+
ret.append([msg, None])
|
199 |
+
else:
|
200 |
+
ret[-1][-1] = msg
|
201 |
+
return ret
|
202 |
+
|
203 |
+
def copy(self):
|
204 |
+
return Conversation(
|
205 |
+
system=self.system,
|
206 |
+
roles=self.roles,
|
207 |
+
messages=[[x, y] for x, y in self.messages],
|
208 |
+
offset=self.offset,
|
209 |
+
sep_style=self.sep_style,
|
210 |
+
sep=self.sep,
|
211 |
+
sep2=self.sep2,
|
212 |
+
version=self.version)
|
213 |
+
|
214 |
+
def dict(self):
|
215 |
+
if len(self.get_images()) > 0:
|
216 |
+
return {
|
217 |
+
"system": self.system,
|
218 |
+
"roles": self.roles,
|
219 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
220 |
+
"offset": self.offset,
|
221 |
+
"sep": self.sep,
|
222 |
+
"sep2": self.sep2,
|
223 |
+
}
|
224 |
+
return {
|
225 |
+
"system": self.system,
|
226 |
+
"roles": self.roles,
|
227 |
+
"messages": self.messages,
|
228 |
+
"offset": self.offset,
|
229 |
+
"sep": self.sep,
|
230 |
+
"sep2": self.sep2,
|
231 |
+
}
|
232 |
+
|
233 |
+
|
234 |
+
conv_vicuna_v0 = Conversation(
|
235 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
236 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
237 |
+
roles=("Human", "Assistant"),
|
238 |
+
messages=(
|
239 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
240 |
+
("Assistant",
|
241 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
242 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
243 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
244 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
245 |
+
"renewable and non-renewable energy sources:\n"
|
246 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
247 |
+
"energy sources are finite and will eventually run out.\n"
|
248 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
249 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
250 |
+
"and other negative effects.\n"
|
251 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
252 |
+
"have lower operational costs than non-renewable sources.\n"
|
253 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
254 |
+
"locations than non-renewable sources.\n"
|
255 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
256 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
257 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
258 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
259 |
+
),
|
260 |
+
offset=2,
|
261 |
+
sep_style=SeparatorStyle.SINGLE,
|
262 |
+
sep="###",
|
263 |
+
)
|
264 |
+
|
265 |
+
conv_vicuna_v1 = Conversation(
|
266 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
267 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
268 |
+
roles=("USER", "ASSISTANT"),
|
269 |
+
version="v1",
|
270 |
+
messages=(),
|
271 |
+
offset=0,
|
272 |
+
sep_style=SeparatorStyle.TWO,
|
273 |
+
sep=" ",
|
274 |
+
sep2="</s>",
|
275 |
+
)
|
276 |
+
|
277 |
+
conv_llama_2 = Conversation(
|
278 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
279 |
+
|
280 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
281 |
+
roles=("USER", "ASSISTANT"),
|
282 |
+
version="llama_v2",
|
283 |
+
messages=(),
|
284 |
+
offset=0,
|
285 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
286 |
+
sep="<s>",
|
287 |
+
sep2="</s>",
|
288 |
+
)
|
289 |
+
|
290 |
+
conv_llava_llama_2 = Conversation(
|
291 |
+
system="You are a helpful language and vision assistant. "
|
292 |
+
"You are able to understand the visual content that the user provides, "
|
293 |
+
"and assist the user with a variety of tasks using natural language.",
|
294 |
+
roles=("USER", "ASSISTANT"),
|
295 |
+
version="llama_v2",
|
296 |
+
messages=(),
|
297 |
+
offset=0,
|
298 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
299 |
+
sep="<s>",
|
300 |
+
sep2="</s>",
|
301 |
+
)
|
302 |
+
|
303 |
+
conv_mpt = Conversation(
|
304 |
+
system="""<|im_start|>system
|
305 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
306 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
307 |
+
version="mpt",
|
308 |
+
messages=(),
|
309 |
+
offset=0,
|
310 |
+
sep_style=SeparatorStyle.MPT,
|
311 |
+
sep="<|im_end|>",
|
312 |
+
)
|
313 |
+
|
314 |
+
conv_llava_plain = Conversation(
|
315 |
+
system="",
|
316 |
+
roles=("", ""),
|
317 |
+
messages=(
|
318 |
+
),
|
319 |
+
offset=0,
|
320 |
+
sep_style=SeparatorStyle.PLAIN,
|
321 |
+
sep="\n",
|
322 |
+
)
|
323 |
+
|
324 |
+
conv_llava_v0 = Conversation(
|
325 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
326 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
327 |
+
roles=("Human", "Assistant"),
|
328 |
+
messages=(
|
329 |
+
),
|
330 |
+
offset=0,
|
331 |
+
sep_style=SeparatorStyle.SINGLE,
|
332 |
+
sep="###",
|
333 |
+
)
|
334 |
+
|
335 |
+
conv_llava_v0_mmtag = Conversation(
|
336 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
337 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
338 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
339 |
+
roles=("Human", "Assistant"),
|
340 |
+
messages=(
|
341 |
+
),
|
342 |
+
offset=0,
|
343 |
+
sep_style=SeparatorStyle.SINGLE,
|
344 |
+
sep="###",
|
345 |
+
version="v0_mmtag",
|
346 |
+
)
|
347 |
+
|
348 |
+
conv_llava_v1 = Conversation(
|
349 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
350 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
351 |
+
roles=("USER", "ASSISTANT"),
|
352 |
+
version="v1",
|
353 |
+
messages=(),
|
354 |
+
offset=0,
|
355 |
+
sep_style=SeparatorStyle.TWO,
|
356 |
+
sep=" ",
|
357 |
+
sep2="</s>",
|
358 |
+
)
|
359 |
+
|
360 |
+
conv_llava_v1_mmtag = Conversation(
|
361 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
362 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
363 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
364 |
+
roles=("USER", "ASSISTANT"),
|
365 |
+
messages=(),
|
366 |
+
offset=0,
|
367 |
+
sep_style=SeparatorStyle.TWO,
|
368 |
+
sep=" ",
|
369 |
+
sep2="</s>",
|
370 |
+
version="v1_mmtag",
|
371 |
+
)
|
372 |
+
|
373 |
+
default_conversation = conv_vicuna_v0
|
374 |
+
conv_templates = {
|
375 |
+
"default": conv_vicuna_v0,
|
376 |
+
"v0": conv_vicuna_v0,
|
377 |
+
"v1": conv_vicuna_v1,
|
378 |
+
"vicuna_v1": conv_vicuna_v1,
|
379 |
+
"llama_2": conv_llama_2,
|
380 |
+
|
381 |
+
"plain": conv_llava_plain,
|
382 |
+
"v0_plain": conv_llava_plain,
|
383 |
+
"llava_v0": conv_llava_v0,
|
384 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
385 |
+
"llava_v1": conv_llava_v1,
|
386 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
387 |
+
"llava_llama_2": conv_llava_llama_2,
|
388 |
+
|
389 |
+
"mpt": conv_mpt,
|
390 |
+
}
|
391 |
+
|
392 |
+
class Chat:
|
393 |
+
def __init__(self, model, image_processor,tokenizer, device='cuda:0', stopping_criteria=None):
|
394 |
+
self.device = device
|
395 |
+
self.model = model
|
396 |
+
self.vis_processor = image_processor
|
397 |
+
self.tokenizer=tokenizer
|
398 |
+
|
399 |
+
# if stopping_criteria is not None:
|
400 |
+
# self.stopping_criteria = stopping_criteria
|
401 |
+
# else:
|
402 |
+
# stop_words_ids = [torch.tensor([2]).to(self.device)]
|
403 |
+
# self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
404 |
+
|
405 |
+
def ask(self, text, conv):
|
406 |
+
# import pdb;pdb.set_trace()
|
407 |
+
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
|
408 |
+
and conv.messages[-1][1][-9:] == '<image>\n': # last message is image.
|
409 |
+
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
410 |
+
else:
|
411 |
+
conv.append_message(conv.roles[0], text)
|
412 |
+
|
413 |
+
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
414 |
+
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
|
415 |
+
conv.append_message(conv.roles[1], None)
|
416 |
+
prompt = conv.get_prompt()
|
417 |
+
# prompt='A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: <image>\n hello ASSISTANT:'
|
418 |
+
text_input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device=self.device)
|
419 |
+
|
420 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
421 |
+
keywords = [stop_str]
|
422 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, text_input_ids)
|
423 |
+
current_max_len = text_input_ids.shape[1] + max_new_tokens
|
424 |
+
if current_max_len - max_length > 0:
|
425 |
+
print('Warning: The number of tokens in current conversation exceeds the max length. '
|
426 |
+
'The model will not see the contexts outside the range.')
|
427 |
+
begin_idx = max(0, current_max_len - max_length)
|
428 |
+
embs = text_input_ids[:, begin_idx:]
|
429 |
+
|
430 |
+
generation_kwargs = dict(
|
431 |
+
input_ids=embs,
|
432 |
+
images=img_list[0],
|
433 |
+
max_new_tokens=max_new_tokens,
|
434 |
+
stopping_criteria=[stopping_criteria],
|
435 |
+
num_beams=num_beams,
|
436 |
+
do_sample=True,
|
437 |
+
min_length=min_length,
|
438 |
+
top_p=top_p,
|
439 |
+
use_cache=True,
|
440 |
+
repetition_penalty=repetition_penalty,
|
441 |
+
length_penalty=length_penalty,
|
442 |
+
temperature=float(temperature),
|
443 |
+
)
|
444 |
+
return generation_kwargs
|
445 |
+
|
446 |
+
# def answer(self, conv, img_list, **kargs):
|
447 |
+
# generation_dict = self.answer_prepare(conv, img_list, **kargs)
|
448 |
+
# output_token = self.model_generate(**generation_dict)[0]
|
449 |
+
# output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
450 |
+
|
451 |
+
# output_text = output_text.split('###')[0] # remove the stop sign '###'
|
452 |
+
# output_text = output_text.split('Assistant:')[-1].strip()
|
453 |
+
|
454 |
+
# conv.messages[-1][1] = output_text
|
455 |
+
# return output_text, output_token.cpu().numpy()
|
456 |
+
|
457 |
+
def stream_answer(self, conv, img_list, **kargs):
|
458 |
+
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
459 |
+
|
460 |
+
streamer = TextIteratorStreamer(self.tokenizer,skip_prompt=True, skip_special_tokens=True)
|
461 |
+
generation_kwargs['streamer'] = streamer
|
462 |
+
# import pdb;pdb.set_trace()
|
463 |
+
# output_ids=self.model.generate(*generation_kwargs)
|
464 |
+
output=self.model_generate(kwargs=generation_kwargs)
|
465 |
+
# thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
|
466 |
+
# thread.start()
|
467 |
+
return streamer
|
468 |
+
|
469 |
+
def model_generate(self, *args, **kwargs):
|
470 |
+
# for 8 bit and 16 bit compatibility
|
471 |
+
with torch.inference_mode():
|
472 |
+
output = self.model.generate(kwargs['kwargs']['input_ids'],
|
473 |
+
images=kwargs['kwargs']['images'],
|
474 |
+
do_sample=False,
|
475 |
+
temperature=kwargs['kwargs']['temperature'],
|
476 |
+
max_new_tokens=kwargs['kwargs']['max_new_tokens'],
|
477 |
+
streamer=kwargs['kwargs']['streamer'],
|
478 |
+
use_cache=kwargs['kwargs']['use_cache'],
|
479 |
+
stopping_criteria=kwargs['kwargs']['stopping_criteria'])
|
480 |
+
# import pdb;pdb.set_trace()
|
481 |
+
# print(output)
|
482 |
+
outputs = self.tokenizer.decode(output[0,kwargs['kwargs']['input_ids'].shape[1]:]).strip()
|
483 |
+
# print(outputs)
|
484 |
+
return output
|
485 |
+
|
486 |
+
def encode_img(self, img_list):
|
487 |
+
|
488 |
+
image = img_list[0]
|
489 |
+
# image='/share/data/drive_3/kartik/LLaVA/output_images/output.jpg'
|
490 |
+
img_list.pop(0)
|
491 |
+
if isinstance(image, str): # is a image path
|
492 |
+
raw_image = Image.open(image).convert('RGB')
|
493 |
+
image = process_images_demo([raw_image], self.vis_processor)
|
494 |
+
# print("raw")
|
495 |
+
# image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
496 |
+
elif isinstance(image, Image.Image):
|
497 |
+
raw_image = image
|
498 |
+
image = process_images_demo([raw_image], self.vis_processor )
|
499 |
+
image=image.to(device=self.device,dtype=torch.float16)
|
500 |
+
# print("Image")
|
501 |
+
# image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
502 |
+
elif isinstance(image, torch.Tensor):
|
503 |
+
if len(image.shape) == 3:
|
504 |
+
image = image.unsqueeze(0)
|
505 |
+
image = image.to(self.device)
|
506 |
+
|
507 |
+
# image_emb, _ = self.model.encode_img(image)
|
508 |
+
img_list.append(image)
|
509 |
+
|
510 |
+
def upload_img(self, image, conv, img_list):
|
511 |
+
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN+'\n')
|
512 |
+
img_list.append(image)
|
513 |
+
msg = "Received."
|
514 |
+
|
515 |
+
return msg
|
516 |
+
|
517 |
+
|
518 |
+
|
519 |
+
# if __name__ == "__main__":
|
520 |
+
# print(default_conversation.get_prompt())
|
geochat/eval/batch_geochat_grounding.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from geochat.conversation import conv_templates, SeparatorStyle
|
10 |
+
from geochat.model.builder import load_pretrained_model
|
11 |
+
from geochat.utils import disable_torch_init
|
12 |
+
from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
import math
|
16 |
+
def split_list(lst, n):
|
17 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
18 |
+
chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item is covered
|
19 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
20 |
+
|
21 |
+
|
22 |
+
def get_chunk(lst, n, k):
|
23 |
+
chunks = split_list(lst, n)
|
24 |
+
return chunks[k]
|
25 |
+
|
26 |
+
|
27 |
+
def eval_model(args):
|
28 |
+
# Model
|
29 |
+
disable_torch_init()
|
30 |
+
model_path = os.path.expanduser(args.model_path)
|
31 |
+
model_name = get_model_name_from_path(model_path)
|
32 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
|
33 |
+
# import pdb;pdb.set_trace()  # debug breakpoint disabled so batch evaluation runs unattended
|
34 |
+
# print(model)
|
35 |
+
questions=[]
|
36 |
+
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
|
37 |
+
|
38 |
+
|
39 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
40 |
+
answers_file = os.path.expanduser(args.answers_file)
|
41 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
42 |
+
|
43 |
+
ans_file = open(answers_file, "w")
|
44 |
+
|
45 |
+
for i in tqdm(range(0,len(questions),args.batch_size)):
|
46 |
+
input_batch=[]
|
47 |
+
input_image_batch=[]
|
48 |
+
count=i
|
49 |
+
image_folder=[]
|
50 |
+
batch_end = min(i + args.batch_size, len(questions))
|
51 |
+
|
52 |
+
|
53 |
+
for j in range(i,batch_end):
|
54 |
+
image_file=questions[j]['image_id']+'.png'
qs=questions[j]['question']  # question text for this sample (was previously undefined here)
|
55 |
+
|
56 |
+
if questions[j]['type']=='ref':
|
57 |
+
qs="[refer] Give me the location of <p> " + qs+" </p>"
|
58 |
+
else:
|
59 |
+
qs="[grounding]" + qs
|
60 |
+
|
61 |
+
if model.config.mm_use_im_start_end:
|
62 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
63 |
+
else:
|
64 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
65 |
+
|
66 |
+
conv = conv_templates[args.conv_mode].copy()
|
67 |
+
conv.append_message(conv.roles[0], qs)
|
68 |
+
conv.append_message(conv.roles[1], None)
|
69 |
+
prompt = conv.get_prompt()
|
70 |
+
|
71 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
72 |
+
input_batch.append(input_ids)
|
73 |
+
|
74 |
+
image = Image.open(os.path.join(args.image_folder, image_file))
|
75 |
+
|
76 |
+
image_folder.append(image)
|
77 |
+
|
78 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
79 |
+
keywords = [stop_str]
|
80 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
81 |
+
|
82 |
+
max_length = max(tensor.size(1) for tensor in input_batch)
|
83 |
+
|
84 |
+
final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
|
85 |
+
final_input_tensors=torch.cat(final_input_list,dim=0)
|
86 |
+
image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
|
87 |
+
|
88 |
+
with torch.inference_mode():
|
89 |
+
output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
|
90 |
+
|
91 |
+
input_token_len = final_input_tensors.shape[1]
|
92 |
+
n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
|
93 |
+
if n_diff_input_output > 0:
|
94 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
95 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
|
96 |
+
for k in range(0,len(final_input_list)):
|
97 |
+
output = outputs[k].strip()
|
98 |
+
if output.endswith(stop_str):
|
99 |
+
output = output[:-len(stop_str)]
|
100 |
+
output = output.strip()
|
101 |
+
|
102 |
+
ans_id = shortuuid.uuid()
|
103 |
+
|
104 |
+
ans_file.write(json.dumps({
|
105 |
+
|
106 |
+
"question_id": questions[count]["question_id"],
|
107 |
+
"image_id": questions[count]["image_id"],
|
108 |
+
"answer": output,
|
109 |
+
"ground_truth": questions[count]['ground_truth'],
|
110 |
+
"question":questions[count]['question'],
|
111 |
+
"type": questions[count]['type'],
|
112 |
+
"dataset": questions[count]['dataset'],
|
113 |
+
"obj_ids": questions[count]['obj_ids'],
|
114 |
+
"size_group": questions[count]['size_group'],
|
115 |
+
|
116 |
+
}) + "\n")
|
117 |
+
count=count+1
|
118 |
+
ans_file.flush()
|
119 |
+
ans_file.close()
|
120 |
+
|
121 |
+
|
122 |
+
if __name__ == "__main__":
|
123 |
+
parser = argparse.ArgumentParser()
|
124 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
125 |
+
parser.add_argument("--model-base", type=str, default=None)
|
126 |
+
parser.add_argument("--image-folder", type=str, default="")
|
127 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
128 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
129 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
130 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
131 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
132 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
133 |
+
parser.add_argument("--top_p", type=float, default=None)
|
134 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
135 |
+
parser.add_argument("--batch_size",type=int, default=1)
|
136 |
+
args = parser.parse_args()
|
137 |
+
|
138 |
+
eval_model(args)
|
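For context, a minimal sketch of how one of these batched evaluation entry points can be driven programmatically, assuming the script above is importable as a module; the model id, directory paths, and file names below are illustrative assumptions rather than values taken from this diff:

import types

# Hypothetical argument bundle mirroring the argparse defaults defined above.
args = types.SimpleNamespace(
    model_path="MBZUAI/geochat-7B",        # assumed checkpoint location
    model_base=None,
    image_folder="data/images",            # assumed directory of .png tiles
    question_file="data/questions.jsonl",  # assumed JSONL, one question per line
    answers_file="outputs/answers.jsonl",
    conv_mode="llava_v1",
    num_chunks=1,
    chunk_idx=0,
    temperature=0.2,
    top_p=None,
    num_beams=1,
    batch_size=4,
)
eval_model(args)  # runs batched generation and writes one JSON line per answer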
geochat/eval/batch_geochat_referring.py
ADDED
@@ -0,0 +1,132 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from geochat.conversation import conv_templates, SeparatorStyle
from geochat.model.builder import load_pretrained_model
from geochat.utils import disable_torch_init
from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so no items are dropped
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
    # print(model)
    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]

    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)

    ans_file = open(answers_file, "w")

    for i in tqdm(range(0, len(questions), args.batch_size)):
        input_batch = []
        input_image_batch = []
        count = i
        image_folder = []
        batch_end = min(i + args.batch_size, len(questions))

        for j in range(i, batch_end):
            image_file = questions[j]['image_id'] + '.png'
            qs = "[identify] What is the object present at " + questions[j]['question']

            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            input_batch.append(input_ids)

            image = Image.open(os.path.join(args.image_folder, image_file))
            image_folder.append(image)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        max_length = max(tensor.size(1) for tensor in input_batch)

        final_input_list = [torch.cat((torch.zeros((1, max_length - tensor.size(1)), dtype=tensor.dtype, device=tensor.get_device()), tensor), dim=1) for tensor in input_batch]
        final_input_tensors = torch.cat(final_input_list, dim=0)
        image_tensor_batch = image_processor.preprocess(image_folder, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values']

        with torch.inference_mode():
            output_ids = model.generate(final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False, temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256, length_penalty=2.0, use_cache=True)

        input_token_len = final_input_tensors.shape[1]
        n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
        for k in range(0, len(final_input_list)):
            output = outputs[k].strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            output = output.strip()

            ans_id = shortuuid.uuid()

            ans_file.write(json.dumps({
                "question_id": questions[count]["question_id"],
                "image_id": questions[count]["image_id"],
                "answer": output,
                "ground_truth": questions[count]['ground_truth'],
                "question": questions[count]['question'],
                "type": questions[count]['type'],
                "dataset": questions[count]['dataset'],
                "obj_ids": questions[count]['obj_ids'],
                "size_group": questions[count]['size_group'],
            }) + "\n")
            count = count + 1
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    eval_model(args)
geochat/eval/batch_geochat_scene.py
ADDED
@@ -0,0 +1,139 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from geochat.conversation import conv_templates, SeparatorStyle
from geochat.model.builder import load_pretrained_model
from geochat.utils import disable_torch_init
from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks (mirrors the other eval scripts; required by get_chunk below)."""
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def evaluation_metrics(data_path):

    base = [json.loads(q) for q in open(data_path, "r")]
    correct = 0
    incorrect = 0
    for answers in tqdm(base):
        gt = answers['question_id'].split('/')[0].lower()
        answer = answers['answer'].replace(' ', '').lower().replace('.', '')
        if gt == answer:
            correct = correct + 1
        else:
            incorrect = incorrect + 1
        # else:
        #     continue
    print('correct:', correct)
    print('incorrect:', incorrect)
    print('Total:', correct + incorrect)
    print('Acc:', (correct / (correct + incorrect)))


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
    # print(model)
    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]

    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)

    ans_file = open(answers_file, "w")

    for i in tqdm(range(0, len(questions), args.batch_size)):
        input_batch = []
        input_image_batch = []
        count = i
        image_folder = []
        batch_end = min(i + args.batch_size, len(questions))

        for j in range(i, batch_end):
            image_file = questions[j]['image']
            qs = questions[j]['text']

            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            input_batch.append(input_ids)

            image = Image.open(os.path.join(args.image_folder, image_file))
            image_folder.append(image)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        max_length = max(tensor.size(1) for tensor in input_batch)

        final_input_list = [torch.cat((torch.zeros((1, max_length - tensor.size(1)), dtype=tensor.dtype, device=tensor.get_device()), tensor), dim=1) for tensor in input_batch]
        final_input_tensors = torch.cat(final_input_list, dim=0)
        image_tensor_batch = image_processor.preprocess(image_folder, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values']

        with torch.inference_mode():
            output_ids = model.generate(final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False, temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256, length_penalty=2.0, use_cache=True)

        input_token_len = final_input_tensors.shape[1]
        n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
        for k in range(0, len(final_input_list)):
            output = outputs[k].strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            output = output.strip()

            ans_id = shortuuid.uuid()

            ans_file.write(json.dumps({
                "question_id": questions[count]["question_id"],
                "image_id": questions[count]["image"],
                "answer": output,
                "ground_truth": questions[count]['ground_truth']
            }) + "\n")
            count = count + 1
        ans_file.flush()
    ans_file.close()
    evaluation_metrics(answers_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    eval_model(args)
geochat/eval/batch_geochat_vqa.py
ADDED
@@ -0,0 +1,125 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from geochat.conversation import conv_templates, SeparatorStyle
from geochat.model.builder import load_pretrained_model
from geochat.utils import disable_torch_init
from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so no items are dropped
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]

    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)

    ans_file = open(answers_file, "w")

    for i in tqdm(range(0, len(questions), args.batch_size)):
        input_batch = []
        input_image_batch = []
        count = i
        image_folder = []
        batch_end = min(i + args.batch_size, len(questions))

        for j in range(i, batch_end):
            image_file = questions[j]['image']
            qs = questions[j]['text']

            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            input_batch.append(input_ids)

            image = Image.open(os.path.join(args.image_folder, image_file))
            image_folder.append(image)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        max_length = max(tensor.size(1) for tensor in input_batch)

        final_input_list = [torch.cat((torch.zeros((1, max_length - tensor.size(1)), dtype=tensor.dtype, device=tensor.get_device()), tensor), dim=1) for tensor in input_batch]
        final_input_tensors = torch.cat(final_input_list, dim=0)
        image_tensor_batch = image_processor.preprocess(image_folder, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values']

        with torch.inference_mode():
            output_ids = model.generate(final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False, temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256, length_penalty=2.0, use_cache=True)

        input_token_len = final_input_tensors.shape[1]
        n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
        for k in range(0, len(final_input_list)):
            output = outputs[k].strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            output = output.strip()

            ans_id = shortuuid.uuid()

            ans_file.write(json.dumps({
                "question_id": questions[count]["question_id"],
                "image_id": questions[count]["image"],
                "answer": output,
            }) + "\n")
            count = count + 1
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    eval_model(args)
geochat/mm_utils.py
ADDED
@@ -0,0 +1,121 @@
from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria
from geochat.constants import IMAGE_TOKEN_INDEX
import numpy as np


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values'][0]
            # image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def process_images_demo(images, image_processor):
    new_images = []
    # image_aspect_ratio = 'pad'
    for image in images:
        image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
        image = image_processor.preprocess(image, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values'][0]
        # image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        new_images.append(image)

    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        flag = False
        for output in outputs:
            for keyword in self.keywords:
                if keyword in output:
                    flag = True
                    return flag
        return flag
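As a quick illustration of the helper above: tokenizer_image_token splits the prompt on the literal '<image>' placeholder and splices image_token_index between the tokenized chunks. A minimal sketch follows; the Vicuna tokenizer id is an assumption, and any LLaMA-style tokenizer with a BOS token behaves the same way:

from transformers import AutoTokenizer
from geochat.mm_utils import tokenizer_image_token
from geochat.constants import IMAGE_TOKEN_INDEX

# Assumed base tokenizer for illustration only.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", use_fast=False)
ids = tokenizer_image_token("A chat.\n<image>\nWhat is shown?", tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
print(ids.shape)                         # 1-D LongTensor of token ids
print((ids == IMAGE_TOKEN_INDEX).sum())  # exactly one image placeholder id for this prompt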
geochat/model/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
@@ -0,0 +1,2 @@
from .language_model.geochat_llama import GeoChatLlamaForCausalLM, GeoChatConfig
from .language_model.geochat_mpt import GeoChatMPTForCausalLM, GeoChatMPTConfig
geochat/model/.ipynb_checkpoints/builder-checkpoint.py
ADDED
@@ -0,0 +1,149 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from geochat.model import *
from geochat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    kwargs = {"device_map": device_map}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if 'geochat' in model_name.lower():
        # Load GeoChat (LLaVA-style) model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading GeoChat from base model...')
            model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional GeoChat weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')

                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading GeoChat from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = GeoChatMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = GeoChatMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            else:
                print("Loading GeoChat......")
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = GeoChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'geochat' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
geochat/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .language_model.geochat_llama import GeoChatLlamaForCausalLM, GeoChatConfig
from .language_model.geochat_mpt import GeoChatMPTForCausalLM, GeoChatMPTConfig
geochat/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (338 Bytes)
geochat/model/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (4.7 kB)
geochat/model/__pycache__/geochat_arch.cpython-310.pyc
ADDED
Binary file (8.26 kB)
geochat/model/apply_delta.py
ADDED
@@ -0,0 +1,48 @@
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from geochat import GeoChatLlamaForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = GeoChatLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base.state_dict():
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data += base.state_dict()[name]
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
                f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
            bparam = base.state_dict()[name]
            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)

    args = parser.parse_args()

    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
geochat/model/builder.py
ADDED
@@ -0,0 +1,149 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from geochat.model import *
from geochat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    kwargs = {"device_map": device_map}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if 'geochat' in model_name.lower():
        # Load GeoChat (LLaVA-style) model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading GeoChat from base model...')
            model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional GeoChat weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')

                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading GeoChat from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = GeoChatMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = GeoChatMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            else:
                print("Loading GeoChat......")
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = GeoChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'geochat' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
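A short sketch of how load_pretrained_model is typically called; the checkpoint id below is an assumed example, and the point being illustrated is that the 'geochat' substring in the resolved model name is what routes loading through the multimodal branch above:

from geochat.model.builder import load_pretrained_model
from geochat.mm_utils import get_model_name_from_path

model_path = "MBZUAI/geochat-7B"  # assumed checkpoint; any path whose final component contains 'geochat' selects the GeoChat branch
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name, load_4bit=False
)
print(type(model).__name__, context_len)  # expected: a GeoChat causal LM plus its context length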
geochat/model/consolidate.py
ADDED
@@ -0,0 +1,29 @@
"""
Usage:
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
"""
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from geochat.model import *
from geochat.model.utils import auto_upgrade


def consolidate_ckpt(src_path, dst_path):
    print("Loading model")
    auto_upgrade(src_path)
    src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
    src_model.save_pretrained(dst_path)
    src_tokenizer.save_pretrained(dst_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True)
    parser.add_argument("--dst", type=str, required=True)

    args = parser.parse_args()

    consolidate_ckpt(args.src, args.dst)
geochat/model/geochat_arch.py
ADDED
@@ -0,0 +1,262 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from geochat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class GeoChatMetaModel:

    def __init__(self, config):
        super(GeoChatMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
            # print(mm_projector)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


class GeoChatMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
185 |
+
for cur_new_embed in new_input_embeds:
|
186 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
187 |
+
new_input_embeds_align.append(cur_new_embed)
|
188 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
189 |
+
|
190 |
+
if labels is not None:
|
191 |
+
new_labels_align = []
|
192 |
+
_new_labels = new_labels
|
193 |
+
for cur_new_label in new_labels:
|
194 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
195 |
+
new_labels_align.append(cur_new_label)
|
196 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
197 |
+
|
198 |
+
if attention_mask is not None:
|
199 |
+
new_attention_mask = []
|
200 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
201 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
202 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
203 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
204 |
+
new_attention_mask.append(cur_new_attention_mask)
|
205 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
206 |
+
assert attention_mask.shape == new_labels.shape
|
207 |
+
else:
|
208 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
209 |
+
if labels is not None:
|
210 |
+
new_labels = torch.stack(new_labels, dim=0)
|
211 |
+
|
212 |
+
if attention_mask is not None:
|
213 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
214 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
215 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
216 |
+
|
217 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
218 |
+
|
219 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
220 |
+
if model_args.mm_use_im_patch_token:
|
221 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
222 |
+
self.resize_token_embeddings(len(tokenizer))
|
223 |
+
|
224 |
+
if model_args.mm_use_im_start_end:
|
225 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
226 |
+
self.resize_token_embeddings(len(tokenizer))
|
227 |
+
|
228 |
+
if num_new_tokens > 0:
|
229 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
230 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
231 |
+
|
232 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
233 |
+
dim=0, keepdim=True)
|
234 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
235 |
+
dim=0, keepdim=True)
|
236 |
+
|
237 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
238 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
239 |
+
|
240 |
+
if model_args.tune_mm_mlp_adapter:
|
241 |
+
for p in self.get_input_embeddings().parameters():
|
242 |
+
p.requires_grad = True
|
243 |
+
for p in self.get_output_embeddings().parameters():
|
244 |
+
p.requires_grad = False
|
245 |
+
|
246 |
+
if model_args.pretrain_mm_mlp_adapter:
|
247 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
248 |
+
print(mm_projector_weights)
|
249 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
250 |
+
assert num_new_tokens == 2
|
251 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
252 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
253 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
254 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
255 |
+
else:
|
256 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
257 |
+
elif model_args.mm_use_im_patch_token:
|
258 |
+
if model_args.tune_mm_mlp_adapter:
|
259 |
+
for p in self.get_input_embeddings().parameters():
|
260 |
+
p.requires_grad = False
|
261 |
+
for p in self.get_output_embeddings().parameters():
|
262 |
+
p.requires_grad = False
|
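For reference, the splice that prepare_inputs_labels_for_multimodal performs can be summarized with a toy, self-contained sketch (not repository code): each IMAGE_TOKEN_INDEX placeholder is replaced by the projected image features, and the corresponding label positions are masked with IGNORE_INDEX so they do not contribute to the loss. The constant values below follow the LLaVA-style conventions assumed in geochat/constants.py; the tensor sizes are made up.

import torch

IMAGE_TOKEN_INDEX = -200    # placeholder id in the tokenized prompt (assumed LLaVA-style constant)
IGNORE_INDEX = -100
hidden, num_patches = 8, 4  # toy sizes; the real model uses the LLM hidden size and the vision tower's patch count

input_ids = torch.tensor([1, 5, IMAGE_TOKEN_INDEX, 9, 2])  # <s> ... <image> ... </s>
text_embeds = torch.randn(input_ids.shape[0], hidden)      # stands in for embed_tokens(input_ids)
image_feats = torch.randn(num_patches, hidden)             # stands in for mm_projector(vision_tower(image))

pos = int(torch.where(input_ids == IMAGE_TOKEN_INDEX)[0][0])
spliced = torch.cat([text_embeds[:pos], image_feats, text_embeds[pos + 1:]], dim=0)
labels = torch.cat([input_ids[:pos],
                    torch.full((num_patches,), IGNORE_INDEX),
                    input_ids[pos + 1:]])
# One placeholder token is consumed and num_patches embeddings are inserted in its place.
assert spliced.shape[0] == labels.shape[0] == input_ids.shape[0] - 1 + num_patches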
geochat/model/language_model/.ipynb_checkpoints/geochat_llama-checkpoint.py
ADDED
@@ -0,0 +1,140 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from ..geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM


class GeoChatConfig(LlamaConfig):
    model_type = "geochat"


class GeoChatLlamaModel(GeoChatMetaModel, LlamaModel):
    config_class = GeoChatConfig

    def __init__(self, config: LlamaConfig):
        super(GeoChatLlamaModel, self).__init__(config)


class GeoChatLlamaForCausalLM(LlamaForCausalLM, GeoChatMetaForCausalLM):
    config_class = GeoChatConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = GeoChatLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

AutoConfig.register("geochat", GeoChatConfig)
AutoModelForCausalLM.register(GeoChatConfig, GeoChatLlamaForCausalLM)
geochat/model/language_model/__pycache__/geochat_llama.cpython-310.pyc
ADDED
Binary file (3.56 kB).
geochat/model/language_model/__pycache__/geochat_mpt.cpython-310.pyc
ADDED
Binary file (4.71 kB).
geochat/model/language_model/geochat_llama.py
ADDED
@@ -0,0 +1,140 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from ..geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM


class GeoChatConfig(LlamaConfig):
    model_type = "geochat"


class GeoChatLlamaModel(GeoChatMetaModel, LlamaModel):
    config_class = GeoChatConfig

    def __init__(self, config: LlamaConfig):
        super(GeoChatLlamaModel, self).__init__(config)


class GeoChatLlamaForCausalLM(LlamaForCausalLM, GeoChatMetaForCausalLM):
    config_class = GeoChatConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = GeoChatLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

AutoConfig.register("geochat", GeoChatConfig)
AutoModelForCausalLM.register(GeoChatConfig, GeoChatLlamaForCausalLM)
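Because this module registers GeoChatConfig with AutoConfig and AutoModelForCausalLM, a checkpoint whose config.json declares "model_type": "geochat" can be loaded through the generic Auto classes once the module has been imported. A minimal loading sketch, assuming a hypothetical local checkpoint path (the path and dtype below are placeholders, not files shipped in this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing the module runs the AutoConfig / AutoModelForCausalLM registrations above.
from geochat.model.language_model.geochat_llama import GeoChatLlamaForCausalLM

ckpt = "path/to/geochat-checkpoint"  # hypothetical local directory
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.eval()
# prepare_inputs_for_generation forwards the `images` kwarg, so generation can be
# driven with the preprocessed image tensor alongside the tokenized prompt, e.g.:
# output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=256)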
geochat/model/language_model/geochat_mpt.py
ADDED
@@ -0,0 +1,113 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple
import warnings

import torch
import torch.nn.functional as F
import math

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
from geochat.model.geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM


class GeoChatMPTConfig(MPTConfig):
    model_type = "geochat_mpt"


class GeoChatMPTModel(GeoChatMetaModel, MPTModel):
    config_class = GeoChatMPTConfig

    def __init__(self, config: MPTConfig):
        config.hidden_size = config.d_model
        super(GeoChatMPTModel, self).__init__(config)

    def embed_tokens(self, x):
        return self.wte(x)


class GeoChatMPTForCausalLM(MPTForCausalLM, GeoChatMetaForCausalLM):
    config_class = GeoChatMPTConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super(MPTForCausalLM, self).__init__(config)

        if not config.tie_word_embeddings:
            raise ValueError('MPTForCausalLM only supports tied word embeddings')
        self.transformer = GeoChatMPTModel(config)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == 'inv_sqrt_d_model':
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
            self.logit_scale = logit_scale

    def get_model(self):
        return self.transformer

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GeoChatMPTModel):
            module.gradient_checkpointing = value

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
        outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is not None:
            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
        attention_mask = kwargs['attention_mask'].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError('MPT does not support generation with right padding.')
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get('use_cache') == False:
                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
        else:
            prefix_mask = None
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}


AutoConfig.register("geochat_mpt", GeoChatMPTConfig)
AutoModelForCausalLM.register(GeoChatMPTConfig, GeoChatMPTForCausalLM)
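The MPT variant shifts labels with torch.roll instead of slicing logits and labels as the Llama variant above does. A small toy check (not repository code) that the two formulations give the same causal-LM loss:

import torch
import torch.nn.functional as F

vocab, seq = 11, 5
logits = torch.randn(1, seq, vocab)
labels = torch.randint(0, vocab, (1, seq))

# Llama-style: drop the last logit, drop the first label.
loss_shift = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))

# MPT-style: roll labels left by one and ignore the wrapped-around last position.
rolled = torch.roll(labels, shifts=-1)
rolled[:, -1] = -100  # default ignore_index
loss_roll = F.cross_entropy(logits.reshape(-1, vocab), rolled.reshape(-1), ignore_index=-100)

assert torch.allclose(loss_shift, loss_roll)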
geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc
ADDED
Binary file (2.24 kB).