roselee committed
Commit 2b49482 · verified · 1 Parent(s): cdb8505

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +21 -0
  2. .ipynb_checkpoints/geochat_demo-checkpoint.py +707 -0
  3. .ipynb_checkpoints/pyproject-checkpoint.toml +39 -0
  4. README.md +227 -8
  5. demo_images/04133.png +3 -0
  6. demo_images/04444.png +3 -0
  7. demo_images/7292.JPG +3 -0
  8. demo_images/MicrosoftTeams-image.png +3 -0
  9. demo_images/church_183.png +3 -0
  10. demo_images/train_2956_0001.png +3 -0
  11. docs/Customize_Component.md +20 -0
  12. docs/Data.md +24 -0
  13. docs/Evaluation.md +54 -0
  14. docs/LoRA.md +24 -0
  15. docs/MODEL_ZOO.md +18 -0
  16. docs/geochat_supp.pdf +3 -0
  17. geochat.egg-info/PKG-INFO +260 -0
  18. geochat.egg-info/SOURCES.txt +51 -0
  19. geochat.egg-info/dependency_links.txt +1 -0
  20. geochat.egg-info/requires.txt +24 -0
  21. geochat.egg-info/top_level.txt +3 -0
  22. geochat/__init__.py +1 -0
  23. geochat/__pycache__/__init__.cpython-310.pyc +0 -0
  24. geochat/__pycache__/constants.cpython-310.pyc +0 -0
  25. geochat/__pycache__/conversation.cpython-310.pyc +0 -0
  26. geochat/__pycache__/mm_utils.cpython-310.pyc +0 -0
  27. geochat/__pycache__/utils.cpython-310.pyc +0 -0
  28. geochat/constants.py +12 -0
  29. geochat/conversation.py +520 -0
  30. geochat/eval/batch_geochat_grounding.py +138 -0
  31. geochat/eval/batch_geochat_referring.py +132 -0
  32. geochat/eval/batch_geochat_scene.py +139 -0
  33. geochat/eval/batch_geochat_vqa.py +125 -0
  34. geochat/mm_utils.py +121 -0
  35. geochat/model/.ipynb_checkpoints/__init__-checkpoint.py +2 -0
  36. geochat/model/.ipynb_checkpoints/builder-checkpoint.py +149 -0
  37. geochat/model/__init__.py +2 -0
  38. geochat/model/__pycache__/__init__.cpython-310.pyc +0 -0
  39. geochat/model/__pycache__/builder.cpython-310.pyc +0 -0
  40. geochat/model/__pycache__/geochat_arch.cpython-310.pyc +0 -0
  41. geochat/model/apply_delta.py +48 -0
  42. geochat/model/builder.py +149 -0
  43. geochat/model/consolidate.py +29 -0
  44. geochat/model/geochat_arch.py +262 -0
  45. geochat/model/language_model/.ipynb_checkpoints/geochat_llama-checkpoint.py +140 -0
  46. geochat/model/language_model/__pycache__/geochat_llama.cpython-310.pyc +0 -0
  47. geochat/model/language_model/__pycache__/geochat_mpt.cpython-310.pyc +0 -0
  48. geochat/model/language_model/geochat_llama.py +140 -0
  49. geochat/model/language_model/geochat_mpt.py +113 -0
  50. geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo_images/04133.png filter=lfs diff=lfs merge=lfs -text
37
+ demo_images/04444.png filter=lfs diff=lfs merge=lfs -text
38
+ demo_images/7292.JPG filter=lfs diff=lfs merge=lfs -text
39
+ demo_images/MicrosoftTeams-image.png filter=lfs diff=lfs merge=lfs -text
40
+ demo_images/church_183.png filter=lfs diff=lfs merge=lfs -text
41
+ demo_images/train_2956_0001.png filter=lfs diff=lfs merge=lfs -text
42
+ docs/geochat_supp.pdf filter=lfs diff=lfs merge=lfs -text
43
+ geochat/serve/examples/11760.jpg filter=lfs diff=lfs merge=lfs -text
44
+ geochat/serve/examples/11765.jpg filter=lfs diff=lfs merge=lfs -text
45
+ images/architecture.png filter=lfs diff=lfs merge=lfs -text
46
+ images/dataset.png filter=lfs diff=lfs merge=lfs -text
47
+ images/examples.png filter=lfs diff=lfs merge=lfs -text
48
+ images/grounded.jpg filter=lfs diff=lfs merge=lfs -text
49
+ images/iden.jpg filter=lfs diff=lfs merge=lfs -text
50
+ images/logo_geochat.png filter=lfs diff=lfs merge=lfs -text
51
+ images/overview2.png filter=lfs diff=lfs merge=lfs -text
52
+ images/ref1.jpg filter=lfs diff=lfs merge=lfs -text
53
+ images/ref_2.jpg filter=lfs diff=lfs merge=lfs -text
54
+ images/scene.jpg filter=lfs diff=lfs merge=lfs -text
55
+ images/teaser.png filter=lfs diff=lfs merge=lfs -text
56
+ images/vqa.jpg filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/geochat_demo-checkpoint.py ADDED
@@ -0,0 +1,707 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ from collections import defaultdict
5
+
6
+ import cv2
7
+ import re
8
+ import math
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+ import html
13
+ import gradio as gr
14
+
15
+ import torchvision.transforms as T
16
+ import torch.backends.cudnn as cudnn
17
+
18
+ from geochat.conversation import conv_templates, Chat
19
+ from geochat.model.builder import load_pretrained_model
20
+ from geochat.mm_utils import get_model_name_from_path
21
+
22
+
23
+ def parse_args():
24
+ parser = argparse.ArgumentParser(description="Demo")
25
+ # parser = argparse.ArgumentParser()
26
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
27
+ parser.add_argument("--model-base", type=str, default=None)
28
+ parser.add_argument("--gpu-id", type=str,default=0)
29
+ parser.add_argument("--device", type=str, default="cuda")
30
+ parser.add_argument("--conv-mode", type=str, default=None)
31
+ parser.add_argument("--max-new-tokens", type=int, default=300)
32
+ parser.add_argument("--load-8bit", action="store_true")
33
+ parser.add_argument("--load-4bit", action="store_true")
34
+ parser.add_argument("--debug", action="store_true")
35
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
36
+ # args = parser.parse_args()
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ random.seed(42)
42
+ np.random.seed(42)
43
+ torch.manual_seed(42)
44
+
45
+ cudnn.benchmark = False
46
+ cudnn.deterministic = True
47
+
48
+ print('Initializing Chat')
49
+ args = parse_args()
50
+ # cfg = Config(args)
51
+
52
+ model_name = get_model_name_from_path(args.model_path)
53
+ print(model_name)
54
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
55
+
56
+ device = 'cuda:{}'.format(args.gpu_id)
57
+
58
+ # model_config = cfg.model_cfg
59
+ # model_config.device_8bit = args.gpu_id
60
+ # model_cls = registry.get_model_class(model_config.arch)
61
+ # model = model_cls.from_config(model_config).to(device)
62
+ bounding_box_size = 100
63
+
64
+ # vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
65
+ # vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
66
+
67
+ model = model.eval()
68
+
69
+ CONV_VISION = conv_templates['llava_v1'].copy()
70
+
71
+ def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
72
+ # Calculate center coordinates
73
+ x_ctr = (x1 + x2) / 2
74
+ y_ctr = (y1 + y2) / 2
75
+
76
+ # Calculate width and height
77
+ w = abs(x2 - x1)
78
+ h = abs(y2 - y1)
79
+
80
+ # Calculate the angle in radians
81
+ angle_rad = math.radians(a)
82
+
83
+ # Calculate coordinates of the four corners of the rotated bounding box
84
+ cos_a = math.cos(angle_rad)
85
+ sin_a = math.sin(angle_rad)
86
+
87
+ x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
88
+ y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr
89
+
90
+ x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
91
+ y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr
92
+
93
+ x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
94
+ y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr
95
+
96
+ x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
97
+ y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr
98
+
99
+ # Return the polygon coordinates
100
+ polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))
101
+
102
+ return polygon_coords
103
+
104
+ def rotate_bbox(top_right, bottom_left, angle_degrees):
105
+ # Convert angle to radians
106
+ angle_radians = np.radians(angle_degrees)
107
+
108
+ # Calculate the center of the rectangle
109
+ center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)
110
+
111
+ # Calculate the width and height of the rectangle
112
+ width = top_right[0] - bottom_left[0]
113
+ height = top_right[1] - bottom_left[1]
114
+
115
+ # Create a rotation matrix
116
+ rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)
117
+
118
+ # Create an array of the rectangle corners
119
+ rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
120
+ [top_right[0], bottom_left[1]],
121
+ [top_right[0], top_right[1]],
122
+ [bottom_left[0], top_right[1]]], dtype=np.float32)
123
+
124
+ # Rotate the rectangle points
125
+ rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]
126
+
127
+ return rotated_rectangle
128
+ def extract_substrings(string):
129
+ # first check if there is no-finished bracket
130
+ index = string.rfind('}')
131
+ if index != -1:
132
+ string = string[:index + 1]
133
+
134
+ pattern = r'<p>(.*?)\}(?!<)'
135
+ matches = re.findall(pattern, string)
136
+ substrings = [match for match in matches]
137
+
138
+ return substrings
139
+
140
+
141
+ def is_overlapping(rect1, rect2):
142
+ x1, y1, x2, y2 = rect1
143
+ x3, y3, x4, y4 = rect2
144
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
145
+
146
+
147
+ def computeIoU(bbox1, bbox2):
148
+ x1, y1, x2, y2 = bbox1
149
+ x3, y3, x4, y4 = bbox2
150
+ intersection_x1 = max(x1, x3)
151
+ intersection_y1 = max(y1, y3)
152
+ intersection_x2 = min(x2, x4)
153
+ intersection_y2 = min(y2, y4)
154
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
155
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
156
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
157
+ union_area = bbox1_area + bbox2_area - intersection_area
158
+ iou = intersection_area / union_area
159
+ return iou
160
+
161
+
162
+ def save_tmp_img(visual_img):
163
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
164
+ file_path = "/tmp/gradio" + file_name
165
+ visual_img.save(file_path)
166
+ return file_path
167
+
168
+
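+ # Convert the sketch mask drawn in the Gradio image editor into GeoChat's textual bbox
+ # format "{<xmin><ymin><xmax><ymax>}", with coordinates expressed on the 100x100 grid
+ # the model expects.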
169
+ def mask2bbox(mask):
170
+ if mask is None:
171
+ return ''
172
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
173
+ mask = np.array(mask)[:, :, 0]
174
+
175
+ rows = np.any(mask, axis=1)
176
+ cols = np.any(mask, axis=0)
177
+
178
+ if rows.sum():
179
+ # Get the top, bottom, left, and right boundaries
180
+ rmin, rmax = np.where(rows)[0][[0, -1]]
181
+ cmin, cmax = np.where(cols)[0][[0, -1]]
182
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
183
+ else:
184
+ bbox = ''
185
+
186
+ return bbox
187
+
188
+
189
+ def escape_markdown(text):
190
+ # List of Markdown special characters that need to be escaped
191
+ md_chars = ['<', '>']
192
+
193
+ # Escape each special character
194
+ for char in md_chars:
195
+ text = text.replace(char, '\\' + char)
196
+
197
+ return text
198
+
199
+
200
+ def reverse_escape(text):
201
+ md_chars = ['\\<', '\\>']
202
+
203
+ for char in md_chars:
204
+ text = text.replace(char, char[1:])
205
+
206
+ return text
207
+
208
+
209
+ colors = [
210
+ (255, 0, 0),
211
+ (0, 255, 0),
212
+ (0, 0, 255),
213
+ (210, 210, 0),
214
+ (255, 0, 255),
215
+ (0, 255, 255),
216
+ (114, 128, 250),
217
+ (0, 165, 255),
218
+ (0, 128, 0),
219
+ (144, 238, 144),
220
+ (238, 238, 175),
221
+ (255, 191, 0),
222
+ (0, 128, 0),
223
+ (226, 43, 138),
224
+ (255, 0, 255),
225
+ (0, 215, 255),
226
+ ]
227
+
228
+ color_map = {
229
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
230
+ color_id, color in enumerate(colors)
231
+ }
232
+
233
+ used_colors = colors
234
+
235
+
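+ # The model emits grounded phrases as "<p>phrase</p>{<x1><y1><x2><y2>|<angle>}", with
+ # coordinates normalised to a 100x100 grid (bounding_box_size above). The function below
+ # parses that string and draws the (optionally rotated) boxes and labels on the image.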
236
+ def visualize_all_bbox_together(image, generation):
237
+ if image is None:
238
+ return None, ''
239
+
240
+ generation = html.unescape(generation)
241
+
242
+ image_width, image_height = image.size
243
+ image = image.resize([500, int(500 / image_width * image_height)])
244
+ image_width, image_height = image.size
245
+
246
+ string_list = extract_substrings(generation)
247
+ if string_list: # it is grounding or detection
248
+ mode = 'all'
249
+ entities = defaultdict(list)
250
+ i = 0
251
+ j = 0
252
+ for string in string_list:
253
+ try:
254
+ obj, string = string.split('</p>')
255
+ except ValueError:
256
+ print('wrong string: ', string)
257
+ continue
258
+ if "}{" in string:
259
+ string=string.replace("}{","}<delim>{")
260
+ bbox_list = string.split('<delim>')
261
+ flag = False
262
+ for bbox_string in bbox_list:
263
+ integers = re.findall(r'-?\d+', bbox_string)
264
+ if len(integers)==4:
265
+ angle=0
266
+ else:
267
+ angle=integers[4]
268
+ integers=integers[:-1]
269
+
270
+ if len(integers) == 4:
271
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
272
+ left = x0 / bounding_box_size * image_width
273
+ bottom = y0 / bounding_box_size * image_height
274
+ right = x1 / bounding_box_size * image_width
275
+ top = y1 / bounding_box_size * image_height
276
+
277
+ entities[obj].append([left, bottom, right, top,angle])
278
+
279
+ j += 1
280
+ flag = True
281
+ if flag:
282
+ i += 1
283
+ else:
284
+ integers = re.findall(r'-?\d+', generation)
285
+ # if len(integers)==4:
286
+ angle=0
287
+ # else:
288
+ # angle=integers[4]
289
+ integers=integers[:-1]
290
+ if len(integers) == 4: # it is refer
291
+ mode = 'single'
292
+
293
+ entities = list()
294
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
295
+ left = x0 / bounding_box_size * image_width
296
+ bottom = y0 / bounding_box_size * image_height
297
+ right = x1 / bounding_box_size * image_width
298
+ top = y1 / bounding_box_size * image_height
299
+ entities.append([left, bottom, right, top,angle])
300
+ else:
301
+ # don't detect any valid bbox to visualize
302
+ return None, ''
303
+
304
+ if len(entities) == 0:
305
+ return None, ''
306
+
307
+ if isinstance(image, Image.Image):
308
+ image_h = image.height
309
+ image_w = image.width
310
+ image = np.array(image)
311
+
312
+ elif isinstance(image, str):
313
+ if os.path.exists(image):
314
+ pil_img = Image.open(image).convert("RGB")
315
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
316
+ image_h = pil_img.height
317
+ image_w = pil_img.width
318
+ else:
319
+ raise ValueError(f"invalid image path, {image}")
320
+ elif isinstance(image, torch.Tensor):
321
+
322
+ image_tensor = image.cpu()
323
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
324
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
325
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
326
+ pil_img = T.ToPILImage()(image_tensor)
327
+ image_h = pil_img.height
328
+ image_w = pil_img.width
329
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
330
+ else:
331
+ raise ValueError(f"invalid image format, {type(image)} for {image}")
332
+
333
+ indices = list(range(len(entities)))
334
+
335
+ new_image = image.copy()
336
+
337
+ previous_bboxes = []
338
+ # size of text
339
+ text_size = 0.4
340
+ # thickness of text
341
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
342
+ box_line = 2
343
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
344
+ base_height = int(text_height * 0.675)
345
+ text_offset_original = text_height - base_height
346
+ text_spaces = 2
347
+
348
+ # num_bboxes = sum(len(x[-1]) for x in entities)
349
+ used_colors = colors # random.sample(colors, k=num_bboxes)
350
+
351
+ color_id = -1
352
+ for entity_idx, entity_name in enumerate(entities):
353
+ if mode == 'single' or mode == 'identify':
354
+ bboxes = entity_name
355
+ bboxes = [bboxes]
356
+ else:
357
+ bboxes = entities[entity_name]
358
+ color_id += 1
359
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm,angle) in enumerate(bboxes):
360
+ skip_flag = False
361
+ orig_x1, orig_y1, orig_x2, orig_y2,angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)
362
+
363
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
364
+ top_right=(orig_x1,orig_y1)
365
+ bottom_left=(orig_x2,orig_y2)
366
+ angle=angle
367
+ rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
368
+ new_image=cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True,thickness=2, color=color)
369
+
370
+ # new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
371
+
372
+ if mode == 'all':
373
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
374
+
375
+ x1 = orig_x1 - l_o
376
+ y1 = orig_y1 - l_o
377
+
378
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
379
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
380
+ x1 = orig_x1 + r_o
381
+
382
+ # add text background
383
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
384
+ text_line)
385
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
386
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
387
+
388
+ for prev_bbox in previous_bboxes:
389
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
390
+ prev_bbox['phrase'] == entity_name:
391
+ skip_flag = True
392
+ break
393
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
394
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
395
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
396
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
397
+
398
+ if text_bg_y2 >= image_h:
399
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
400
+ text_bg_y2 = image_h
401
+ y1 = image_h
402
+ break
403
+ if not skip_flag:
404
+ alpha = 0.5
405
+ for i in range(text_bg_y1, text_bg_y2):
406
+ for j in range(text_bg_x1, text_bg_x2):
407
+ if i < image_h and j < image_w:
408
+ if j < text_bg_x1 + 1.35 * c_width:
409
+ # original color
410
+ bg_color = color
411
+ else:
412
+ # white
413
+ bg_color = [255, 255, 255]
414
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
415
+ np.uint8)
416
+
417
+ cv2.putText(
418
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
419
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
420
+ )
421
+
422
+ previous_bboxes.append(
423
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
424
+
425
+ if mode == 'all':
426
+ def color_iterator(colors):
427
+ while True:
428
+ for color in colors:
429
+ yield color
430
+
431
+ color_gen = color_iterator(colors)
432
+
433
+ # Add colors to phrases and remove <p></p>
434
+ def colored_phrases(match):
435
+ phrase = match.group(1)
436
+ color = next(color_gen)
437
+ return f'<span style="color:rgb{color}">{phrase}</span>'
438
+
439
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
440
+ generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
441
+ else:
442
+ generation_colored = ''
443
+
444
+ pil_image = Image.fromarray(new_image)
445
+ return pil_image, generation_colored
446
+
447
+
448
+ def gradio_reset(chat_state, img_list):
449
+ if chat_state is not None:
450
+ chat_state.messages = []
451
+ if img_list is not None:
452
+ img_list = []
453
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
454
+ interactive=True), chat_state, img_list
455
+
456
+
457
+ def image_upload_trigger(upload_flag, replace_flag, img_list):
458
+ # set the upload flag to true when receive a new image.
459
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
460
+ upload_flag = 1
461
+ if img_list:
462
+ replace_flag = 1
463
+ return upload_flag, replace_flag
464
+
465
+
466
+ def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
467
+ # set the upload flag to true when receive a new image.
468
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
469
+ upload_flag = 1
470
+ if img_list or replace_flag == 1:
471
+ replace_flag = 1
472
+
473
+ return upload_flag, replace_flag
474
+
475
+
476
+ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
477
+ if len(user_message) == 0:
478
+ text_box_show = 'Input should not be empty!'
479
+ else:
480
+ text_box_show = ''
481
+
482
+ if isinstance(gr_img, dict):
483
+ gr_img, mask = gr_img['image'], gr_img['mask']
484
+ else:
485
+ mask = None
486
+
487
+ if '[identify]' in user_message:
488
+ # check if user provide bbox in the text input
489
+ integers = re.findall(r'-?\d+', user_message)
490
+ if len(integers) != 4: # no bbox in text
491
+ bbox = mask2bbox(mask)
492
+ user_message = user_message + bbox
493
+
494
+ if chat_state is None:
495
+ chat_state = CONV_VISION.copy()
496
+
497
+ if upload_flag:
498
+ if replace_flag:
499
+ chat_state = CONV_VISION.copy() # new image, reset everything
500
+ replace_flag = 0
501
+ chatbot = []
502
+ img_list = []
503
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
504
+ upload_flag = 0
505
+
506
+ chat.ask(user_message, chat_state)
507
+
508
+ chatbot = chatbot + [[user_message, None]]
509
+
510
+ if '[identify]' in user_message:
511
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
512
+ if visual_img is not None:
513
+ file_path = save_tmp_img(visual_img)
514
+ chatbot = chatbot + [[(file_path,), None]]
515
+
516
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
517
+
518
+
519
+ # def gradio_answer(chatbot, chat_state, img_list, temperature):
520
+ # llm_message = chat.answer(conv=chat_state,
521
+ # img_list=img_list,
522
+ # temperature=temperature,
523
+ # max_new_tokens=500,
524
+ # max_length=2000)[0]
525
+ # chatbot[-1][1] = llm_message
526
+ # return chatbot, chat_state
527
+
528
+
529
+ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
530
+ if len(img_list) > 0:
531
+ if not isinstance(img_list[0], torch.Tensor):
532
+ chat.encode_img(img_list)
533
+ streamer = chat.stream_answer(conv=chat_state,
534
+ img_list=img_list,
535
+ temperature=temperature,
536
+ max_new_tokens=500,
537
+ max_length=2000)
538
+ # chatbot[-1][1] = output
539
+ # chat_state.messages[-1][1] = '</s>'
540
+
541
+ output = ''
542
+ for new_output in streamer:
543
+ # print(new_output)
544
+ output=output+new_output
545
+ print(output)
546
+ # if "{" in output:
547
+ # chatbot[-1][1]="Grounding and referring expression is still under work."
548
+ # else:
549
+ output = escape_markdown(output)
550
+ # output += escapped
551
+ chatbot[-1][1] = output
552
+ yield chatbot, chat_state
553
+ chat_state.messages[-1][1] = '</s>'
554
+ return chatbot, chat_state
555
+
556
+
557
+ def gradio_visualize(chatbot, gr_img):
558
+ if isinstance(gr_img, dict):
559
+ gr_img, mask = gr_img['image'], gr_img['mask']
560
+
561
+ unescaped = reverse_escape(chatbot[-1][1])
562
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
563
+ if visual_img is not None:
564
+ if len(generation_color):
565
+ chatbot[-1][1] = generation_color
566
+ file_path = save_tmp_img(visual_img)
567
+ chatbot = chatbot + [[None, (file_path,)]]
568
+
569
+ return chatbot
570
+
571
+
572
+ def gradio_taskselect(idx):
573
+ prompt_list = [
574
+ '',
575
+ 'Classify the image in the following classes: ',
576
+ '[identify] what is this ',
577
+ ]
578
+ instruct_list = [
579
+ '**Hint:** Type in whatever you want',
580
+ '**Hint:** Type in the classes you want the model to classify in',
581
+ '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
582
+ ]
583
+ return prompt_list[idx], instruct_list[idx]
584
+
585
+
586
+
587
+
588
+ chat = Chat(model, image_processor,tokenizer, device=device)
589
+
590
+
591
+ title = """<h1 align="center">GeoChat Demo</h1>"""
592
+ description = 'Welcome to Our GeoChat Chatbot Demo!'
593
+ article = """<div style="display: flex;"><p style="display: inline-block;"><a href='https://mbzuai-oryx.github.io/GeoChat'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p style="display: inline-block;"><a href='https://arxiv.org/abs/2311.15826'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p style="display: inline-block;"><a href='https://github.com/mbzuai-oryx/GeoChat/tree/main'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p style="display: inline-block;"><a href='https://youtu.be/KOKtkkKpNDk?feature=shared'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p></div>"""
594
+ # article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
595
+
596
+ introduction = '''
597
+ 1. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
598
+ 2. No Tag: Input whatever you want and CLICK **Send** without any tagging
599
+
600
+ You can also simply chat in free form!
601
+ '''
602
+
603
+
604
+ text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
605
+ scale=12)
606
+ with gr.Blocks() as demo:
607
+ gr.Markdown(title)
608
+ # gr.Markdown(description)
609
+ gr.Markdown(article)
610
+
611
+ with gr.Row():
612
+ with gr.Column(scale=0.5):
613
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
614
+
615
+ temperature = gr.Slider(
616
+ minimum=0.1,
617
+ maximum=1.5,
618
+ value=0.6,
619
+ step=0.1,
620
+ interactive=True,
621
+ label="Temperature",
622
+ )
623
+
624
+ clear = gr.Button("Restart")
625
+
626
+ gr.Markdown(introduction)
627
+
628
+ with gr.Column():
629
+ chat_state = gr.State(value=None)
630
+ img_list = gr.State(value=[])
631
+ chatbot = gr.Chatbot(label='GeoChat')
632
+
633
+ dataset = gr.Dataset(
634
+ components=[gr.Textbox(visible=False)],
635
+ samples=[['No Tag'], ['Scene Classification'],['Identify']],
636
+ type="index",
637
+ label='Task Shortcuts',
638
+ )
639
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
640
+ with gr.Row():
641
+ text_input.render()
642
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
643
+
644
+ upload_flag = gr.State(value=0)
645
+ replace_flag = gr.State(value=0)
646
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
647
+
648
+ with gr.Row():
649
+ with gr.Column():
650
+ gr.Examples(examples=[
651
+ ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?", upload_flag, replace_flag,
652
+ img_list],
653
+ ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
654
+ replace_flag, img_list],
655
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
656
+ outputs=[upload_flag, replace_flag])
657
+ with gr.Column():
658
+ gr.Examples(examples=[
659
+ ["demo_images/church_183.png", "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
660
+ upload_flag, replace_flag, img_list],
661
+ ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
662
+ replace_flag, img_list],
663
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
664
+ outputs=[upload_flag, replace_flag])
665
+
666
+ dataset.click(
667
+ gradio_taskselect,
668
+ inputs=[dataset],
669
+ outputs=[text_input, task_inst],
670
+ show_progress="hidden",
671
+ postprocess=False,
672
+ queue=False,
673
+ )
674
+
675
+ text_input.submit(
676
+ gradio_ask,
677
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
678
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
679
+ ).success(
680
+ gradio_stream_answer,
681
+ [chatbot, chat_state, img_list, temperature],
682
+ [chatbot, chat_state]
683
+ ).success(
684
+ gradio_visualize,
685
+ [chatbot, image],
686
+ [chatbot],
687
+ queue=False,
688
+ )
689
+
690
+ send.click(
691
+ gradio_ask,
692
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
693
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
694
+ ).success(
695
+ gradio_stream_answer,
696
+ [chatbot, chat_state, img_list, temperature],
697
+ [chatbot, chat_state]
698
+ ).success(
699
+ gradio_visualize,
700
+ [chatbot, image],
701
+ [chatbot],
702
+ queue=False,
703
+ )
704
+
705
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
706
+
707
+ demo.launch(share=True, enable_queue=True,server_name='0.0.0.0')
.ipynb_checkpoints/pyproject-checkpoint.toml ADDED
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "geochat"
7
+ version = "1.1.1"
8
+ description = "Grounded VLM for Remote Sensing"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
17
+ "requests", "sentencepiece", "tokenizers>=0.12.1",
18
+ "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
19
+ "shortuuid", "httpx==0.24.0",
20
+ #"deepspeed==0.9.5",
21
+ "peft==0.4.0",
22
+ "transformers==4.31.0",
23
+ "accelerate==0.21.0",
24
+ "bitsandbytes==0.41.0",
25
+ "scikit-learn==1.2.2",
26
+ "sentencepiece==0.1.99",
27
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
28
+ "gradio_client==0.2.9"
29
+ ]
30
+
31
+ [project.urls]
32
+ "Homepage" = "https://github.com/mbzuai-oryx/GeoChat"
33
+ "Bug Tracker" = "https://github.com/mbzuai-oryx/GeoChat/issues"
34
+
35
+ [tool.setuptools.packages.find]
36
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
37
+
38
+ [tool.wheel]
39
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
README.md CHANGED
@@ -1,12 +1,231 @@
1
  ---
2
- title: Csu
3
- emoji: 🏢
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.15.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: csu
3
+ app_file: geochat_demo.py
4
  sdk: gradio
5
+ sdk_version: 3.35.2
6
  ---
7
+ # GeoChat <img src="images/logo_geochat.png" height="40">: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
8
+ <p align="center">
9
+ <img src="https://i.imgur.com/waxVImv.png" alt="Oryx Video-ChatGPT">
10
+ </p>
11
 
12
+ #### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
13
+ \* Equally contributing first authors
14
+
15
+ #### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
16
+
17
+ [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://mbzuai-oryx.github.io/GeoChat)
18
+ [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2311.15826)
19
+ [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](https://youtu.be/KOKtkkKpNDk)
20
+
21
+ ---
22
+
23
+ ## 📢 Latest Updates
24
+ - Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
25
+ - **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
26
+ - **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
27
+ - **Nov-28-23**: GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
28
+ ---
29
+
30
+
31
+
32
+ ## <img src="images/logo_geochat.png" height="40">Overview
33
+
34
+ GeoChat is the first grounded Large Vision-Language Model specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels at handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned on the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
35
+
36
+ ---
37
+ ## Contents
38
+ - [Install](#install)
39
+ - [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
40
+ - [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
41
+ - [Train](#train)
42
+ - [Evaluation](#evaluation)
43
+
44
+ ## Install
45
+
46
+ 1. Clone this repository and navigate to GeoChat folder
47
+ ```bash
48
+ git clone https://github.com/mbzuai-oryx/GeoChat.git
49
+ cd GeoChat
50
+ ```
51
+
52
+ 2. Install Package
53
+ ```Shell
54
+ conda create -n geochat python=3.10 -y
55
+ conda activate geochat
56
+ pip install --upgrade pip # enable PEP 660 support
57
+ pip install -e .
58
+ ```
59
+
60
+ 3. Install additional packages for training cases
61
+ ```
62
+ pip install ninja
63
+ pip install flash-attn --no-build-isolation
64
+ ```
65
+
66
+ ### Upgrade to latest code base
67
+
68
+ ```Shell
69
+ git pull
70
+ pip uninstall transformers
71
+ pip install -e .
72
+ ```
73
+
74
+ ## GeoChat Weights and Demo
75
+ Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and check [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
76
+
77
+ ## Train
78
+
79
+ GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset: 318k Vicuna-generated multimodal instruction-following samples, fine-tuned on top of the pretrained weights of LLaVA-v1.5.
80
+
81
+ We train GeoChat on 3 A100 GPUs with 40GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
82
+
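For example, here is a minimal sanity check (illustrative numbers only, not necessarily the values used in the released scripts) that reproduces the global batch size of 144 listed in the table below:

```python
# global batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
per_device_train_batch_size = 16
gradient_accumulation_steps = 3
num_gpus = 3
assert per_device_train_batch_size * gradient_accumulation_steps * num_gpus == 144
```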
83
+ ### Hyperparameters
84
+ We use a similar set of hyperparameters to Vicuna for finetuning. The hyperparameters used for finetuning are provided below.
85
+
86
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
87
+ | --- | ---: | ---: | ---: | ---: | ---: |
88
+ | GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
89
+
90
+ ### Pretrain (feature alignment)
91
+
92
+ We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. Pretraining this projector takes around 3.5 hours for LLaVA-v1.5-7B.
93
+
94
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
95
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
96
+
97
+ ### Visual Instruction Tuning
98
+
99
+ 1. Prepare data
100
+
101
+ Please download the annotations of the final mixture of our instruction-tuning data, [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from the [Hugging Face dataset repo](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save the multiple image zips in a single folder and run the following command to merge them:
102
+ ```Shell
103
+ cat images_parta* > images.zip
104
+ ```
105
+ Unzip the images.zip file to a folder and give the folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
106
+
107
+ 2. Start training!
108
+
109
+ Visual instruction tuning takes more time due to the increased CLIP input resolution of 504×504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40G).
110
+
111
+ Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
112
+
113
+ Options to note:
114
+
115
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
116
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
117
+ - `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination (see the sketch after this list).
118
+ - `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct).
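As a reference for what `--image_aspect_ratio pad` does, here is a minimal sketch of LLaVA-style pad-to-square preprocessing (illustrative code, not the exact implementation in this repo):

```python
from PIL import Image

def expand2square(pil_img, background_color):
    """Pad a non-square PIL image onto a square canvas instead of cropping it."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
    return result
```

In LLaVA-style pipelines the `background_color` is typically the image processor's mean pixel value, so the padding is close to neutral after normalisation.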
120
+ ## Evaluation
121
+
122
+ We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search, so that the inference process stays consistent with the real-time chat demo.
123
+ See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
124
+
125
+ ## 🏆 Contributions
126
+
127
+ - **RS multimodal instruction-following dataset.** We present a novel data-generation pipeline that leverages existing object-detection datasets to create short descriptions of the images, and then uses Vicuna-v1.5 to create conversations from the generated text alone. Further, we add visual question-answering and scene-classification abilities
128
+ using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
129
+ - **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote-sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully-tuned LLaVA model, whose MLP projection is trained to align images into the word-embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain the conversation and instruction-following abilities of LLaVA and extend its domain knowledge to remote-sensing tasks.
130
+
131
+ - **Evaluation Benchmark.** We also address the lack of evaluation benchmarks to assess the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparison with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote-sensing tasks, including image captioning, visual question answering and scene classification, to demonstrate the generalisability of GeoChat as a conversational VLM.
132
+
133
+ ---
134
+ ## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
135
+
136
+ GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations - shown on top), visual question answering on images and regions (top left and bottom right, respectively) as well as scene classification (top right) and normal natural language conversations (bottom). This makes it the first RS VLM with grounding capability.
137
+
138
+ <p align="center">
139
+ <img src="images/overview2.png" alt="GeoChat Overview">
140
+ </p>
141
+
142
+ ---
143
+
144
+ ## 🛰️ GeoChat : Architecture
145
+
146
+ An overview of GeoChat - the first grounded large vision-language model for remote sensing. Given an image input together with a user query, a visual backbone is first used to encode patch-level tokens at a higher resolution via interpolating positional encodings. A multi-layer perceptron (MLP) is used to adapt vision tokens to a language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be input to the model together with task-specific prompts that specify the task desired by the user. Given this context, the LLM can generate natural-language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks as shown on top, e.g., scene classification, image/region captioning, VQA and grounded conversations.
147
+
148
+ <p align="center">
149
+ <img src="images/architecture.png" alt="GeoChat Architectural">
150
+ </p>
151
+
152
+ ---
153
+
154
+ ## 🔍 RS Multimodal Instruction Dataset
155
+
156
+ Types of annotations available in the GeoChat instruction-set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction-set with a total of 318k image-instruction pairs.
157
+
158
+ <p align="center">
159
+ <img src="images/dataset.png" alt="Dataset Annotation Pipeline">
160
+ </p>
161
+
162
+
163
+
164
+ ## 🤖 Qualitative results of GeoChat
165
+
166
+ Qualitative results of GeoChat. (<em>left-right</em>) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., <strong>[grounding]</strong>) to shape model responses according to the desired behavior. The model can generate textual responses (<em>right</em>), only visual grounding (<em>center</em>) and both text and object groundings interleaved together (<em>left</em>). The model can also specify object types, object counts, object attributes and object relationships.
167
+ <p align="center">
168
+ <img src="images/examples.png" alt="Results_GCG">
169
+ </p>
170
+
171
+ ---
172
+
173
+ ## 🤖 Visual Question Answering
174
+ Qualitative examples for Visual Question Answering tasks. GeoChat is able to hold multi-turn conversations based on various types of questions, including presence, counting, complex comparisons and so on. It can also detect objects and hold conversations over low-resolution images.
175
+ <p align="center">
176
+ <img src="images/vqa.jpg" alt="Visual Question Answering">
177
+ </p>
178
+
179
+ ---
180
+
181
+ ## 🤖 Scene Classification
182
+ Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
183
+ <p align="center">
184
+ <img src="images/scene.jpg" alt="Visual Question Answering">
185
+ </p>
186
+
187
+ ---
188
+
189
+ ## 🤖 Grounded Description
190
+ When asked to describe the image with the special token '[grounding]', GeoChat outputs both a description of the image and bounding boxes for all the objects it detects.
191
+ <p align="center">
192
+ <img src="images/grounded.jpg" alt="Grounded Description">
193
+ </p>
194
+
195
+ ---
196
+
197
+ ## 🤖 Referring Expression
198
+ When asked about an object via a referring expression, GeoChat is able to locate it and draw a rotated bounding box around it.
199
+ <p align="center">
200
+ <img src="images/ref1.jpg" alt="Referring Expression">
201
+ </p>
202
+ <p align="center">
203
+ <img src="images/ref_2.jpg" alt="Referring Expression">
204
+ </p>
205
+
206
+ ---
207
+
208
+ ## 🤖 Region Caption
209
+ Qualitative examples for region-based captioning. Given a bounding box, GeoChat is able to provide brief descriptions about the area or the object covered by the bounding box.
210
+ <p align="center">
211
+ <img src="images/iden.jpg" alt="Region Caption">
212
+ </p>
213
+
214
+ ---
215
+
216
+ ## 📜 Citation
217
+ ```bibtex
218
+ @article{kuckreja2023geochat,
219
+ title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
220
+ author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
221
+ journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
222
+ year={2024}
223
+ }
224
+ ```
225
+ ## 🙏 Acknowledgement
226
+ We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
227
+
228
+ ---
229
+ [<img src="images/IVAL_logo.png" width="200" height="100">](https://www.ival-mbzuai.com)
230
+ [<img src="images/Oryx_logo.png" width="100" height="100">](https://github.com/mbzuai-oryx)
231
+ [<img src="images/MBZUAI_logo.png" width="360" height="85">](https://mbzuai.ac.ae)
demo_images/04133.png ADDED

Git LFS Details

  • SHA256: d554202729a40d67eb39fc38759e196ca628cf6f7f3c2679b075fcb9a9f52e80
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
demo_images/04444.png ADDED

Git LFS Details

  • SHA256: 9c6ec5f948638f44dd80814bfe20205a5c57ec9d01ebf3fcaf50e5a37f2067f5
  • Pointer size: 131 Bytes
  • Size of remote file: 969 kB
demo_images/7292.JPG ADDED

Git LFS Details

  • SHA256: 5a16bbdb6f4743afac0dc3ea914003c5609ff735966d88a3f6cfccddf837baaf
  • Pointer size: 132 Bytes
  • Size of remote file: 3.41 MB
demo_images/MicrosoftTeams-image.png ADDED

Git LFS Details

  • SHA256: 1b20fb8c3e814b8bb1895079ff1c02d0234dc21607ddada6c7cc0ba89f45e479
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
demo_images/church_183.png ADDED

Git LFS Details

  • SHA256: 225af61e9e76edbdfe995b6f8d9d1a07255e05d3b558653db687c800d293bdde
  • Pointer size: 131 Bytes
  • Size of remote file: 686 kB
demo_images/train_2956_0001.png ADDED

Git LFS Details

  • SHA256: 2bcd2e7cd60fb52bd786f7cc7705ea6ddb68238a17fc2da6bd8495d307f11ae9
  • Pointer size: 131 Bytes
  • Size of remote file: 680 kB
docs/Customize_Component.md ADDED
@@ -0,0 +1,20 @@
1
+ # Customize Components in GeoChat
2
+
3
+ This is an initial guide on how to replace the LLMs, visual encoders, etc. with your choice of components.
4
+
5
+ ## LLM
6
+
7
+ It is quite simple to swap out LLaMA for any other LLM. You can refer to our implementation of [`GeoChat_llama.py`](https://github.com/mbzuai-oryx/GeoChat/blob/main/geochat/model/language_model/geochat_llama.py) for an example of how to replace the LLM.
8
+
9
+ Although it may seem to still need ~100 lines of code, most of them are copied from the original `llama.py` in Hugging Face Transformers. The only difference is a few inserted lines for processing the multimodal inputs.
10
+
11
+ In the `forward` function, you can see that we call `self.prepare_inputs_labels_for_multimodal` to process the multimodal inputs. This function is defined in `GeoChatMetaForCausalLM`, and you just need to call it at the start of the `forward` function of your LLM.
12
+
13
+ In the `prepare_inputs_for_generation` function, you can see that we add `images` to the `model_inputs`. This is because we need to pass the images to the LLM during generation.
14
+
15
+ These are basically all the changes you need to make to replace the LLM; a minimal sketch is shown below.
16
+
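A minimal sketch of the pattern described above, loosely following LLaVA's `llava_llama.py` (the class name and argument list here are illustrative, not the exact GeoChat code):

```python
from transformers import LlamaForCausalLM

class MyMultimodalLlamaForCausalLM(LlamaForCausalLM):
    # Assumes a GeoChatMetaForCausalLM-style mixin provides prepare_inputs_labels_for_multimodal.

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                labels=None, images=None, **kwargs):
        # 1) Merge image features into the token embeddings before the LM runs.
        (input_ids, attention_mask, past_key_values, inputs_embeds, labels) = \
            self.prepare_inputs_labels_for_multimodal(
                input_ids, attention_mask, past_key_values, labels, images)
        return super().forward(input_ids=input_ids, attention_mask=attention_mask,
                               past_key_values=past_key_values, inputs_embeds=inputs_embeds,
                               labels=labels, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        # 2) Keep passing the images through during autoregressive generation.
        model_inputs["images"] = kwargs.get("images", None)
        return model_inputs
```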
17
+ ## Visual Encoder
18
+
19
+ You can check out [`clip_encoder.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/model/multimodal_encoder/clip_encoder.py) on how we implement the CLIP visual encoder.
20
+
docs/Data.md ADDED
@@ -0,0 +1,24 @@
1
+ ## Finetuning Data
2
+ We use GeoChat-Instruct to finetune our model. The instruction-following dataset is provided in GeoChat_Instruct.json and the images are available in the [huggingface repo](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). The images are split into multiple zip parts. Download the separate parts into the same folder and run the following command to merge them.
3
+
4
+ ```Shell
5
+ cat images_parta* > images.zip
6
+ ```
7
+
8
+ Unzip the images into a folder and provide the folder's path in the training and evaluation scripts.
9
+
10
+ | Data file name | Size |
11
+ | --- | ---: |
12
+ | [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json) | 263 MB |
13
+
14
+ ## Pretraining Dataset
15
+ We use the same pretraining dataset as LLaVA-v1.5.
16
+ The pretraining dataset used in this release is a subset of the CC-3M dataset, filtered to a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
17
+
18
+ If you already have the CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field accordingly if necessary (a sketch of this kind of edit is shown below).
19
+
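For example, a small sketch of the kind of edit meant above (file names and the target folder are assumptions; only the `image` field follows the chat.json format):

```python
import json

# Point the "image" field of each record at a local copy of the CC-3M images.
with open("chat.json") as f:
    records = json.load(f)

for rec in records:
    if "image" in rec:
        rec["image"] = "cc3m_images/" + rec["image"]  # e.g. cc3m_images/GCC_train_000000000.jpg

with open("chat_local.json", "w") as f:
    json.dump(records, f)
```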
20
+ | Data | Chat File | Meta Data | Size |
21
+ | --- | --- | --- | ---: |
22
+ | CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB
23
+ | LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB
24
+
docs/Evaluation.md ADDED
@@ -0,0 +1,54 @@
1
+ # Evaluation
2
+
3
+ We evaluate GeoChat on a variety of tasks, including scene classification, region captioning, visual grounding, grounding description and VQA.
4
+ Converted files in the input format for GeoChat are available at [GeoChat-Bench](https://huggingface.co/datasets/MBZUAI/GeoChat-Bench/tree/main).
5
+
6
+
7
+ Below we provide a general guideline for evaluating datasets.
8
+
9
+ 1. LRBEN/HRBEN.
10
+ Images and ground truth for evaluation need to be downloaded from the following sources: [LRBEN](https://zenodo.org/records/6344334), [HRBEN](https://zenodo.org/records/6344367).
11
+ Give the path to the extracted image folder in the evaluation script. We add the following text after each question during our evaluation.
12
+ ```
13
+ <question>
14
+ Answer the question using a single word or phrase.
15
+ ```
16
+ ```Shell
17
+ python geochat/eval/batch_geochat_vqa.py \
18
+ --model-path /path/to/model \
19
+ --question-file path/to/jsonl/file \
20
+ --answer-file path/to/output/jsonl/file \
21
+ --image_folder path/to/image/folder/
22
+ ```
23
+ 2. Scene Classification.
24
+ Download the images from the following sources: [UCmerced](http://weegee.vision.ucmerced.edu/datasets/landuse.html), [AID](https://drive.google.com/drive/folders/1-1D9DrYYWMGuuxx-qcvIIOV1oUkAVf-M). We add the following text after each question during our evaluation.
25
+ ```
26
+ <question>
27
+ Classify the image from the following classes. Answer in one word or a short phrase.
28
+ ```
29
+ ```Shell
30
+ python geochat/eval/batch_geochat_scene.py \
31
+ --model-path /path/to/model \
32
+ --question-file path/to/jsonl/file \
33
+ --answer-file path/to/output/jsonl/file \
34
+ --image_folder path/to/image/folder/
35
+ ```
36
+
37
+ 3. Region-Captioning/Visual grounding.
38
+
39
+ The evaluation images are present in the image.zip folder in [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/images.zip).
40
+ ```Shell
41
+ python geochat/eval/batch_geochat_grounding.py \
42
+ --model-path /path/to/model \
43
+ --question-file path/to/jsonl/file \
44
+ --answer-file path/to/output/jsonl/file \
45
+ --image_folder path/to/image/folder/
46
+ ```
47
+
48
+ ```Shell
49
+ python geochat/eval/batch_geochat_referring.py \
50
+ --model-path /path/to/model \
51
+ --question-file path/to/jsonl/file \
52
+ --answer-file path/to/output/jsonl/file \
53
+ --image_folder path/to/image/folder/
54
+ ```
docs/LoRA.md ADDED
@@ -0,0 +1,24 @@
1
+
2
+ ## Demo (Web UI)
3
+ You need GeoChat-7B to run the demo locally. Download the model from [GeoChat-7B](https://huggingface.co/MBZUAI/geochat-7B). After downloading the model, launch the Gradio demo by passing the model path to the command below.
4
+ #### Launch the demo
5
+ ```Shell
6
+ python geochat_demo.py --model-path /path/to/model
7
+ ```
8
+
9
+ ## Training
10
+
11
+ Please see sample training scripts for [LoRA](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh)
12
+
13
+ We provide sample DeepSpeed configs: [`zero3.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3.json) is more like PyTorch FSDP, while [`zero3_offload.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3_offload.json) can further reduce memory consumption by offloading parameters to the CPU. `zero3.json` is usually faster than `zero3_offload.json` but requires more GPU memory; therefore, we recommend trying `zero3.json` first and, if you run out of GPU memory, trying `zero3_offload.json`. You can also tweak `per_device_train_batch_size` and `gradient_accumulation_steps` in the config to save memory; just make sure that `per_device_train_batch_size` × `gradient_accumulation_steps` remains the same.
14
+
15
+ If you are having issues with the ZeRO-3 configs and there is enough VRAM, you may try [`zero2.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero2.json). This consumes slightly more memory than ZeRO-3 and behaves more similarly to PyTorch FSDP, while still supporting parameter-efficient tuning.
16
+
17
+ ## Create Merged Checkpoints
18
+
19
+ ```Shell
20
+ python scripts/merge_lora_weights.py \
21
+ --model-path /path/to/lora_model \
22
+ --model-base /path/to/base_model \
23
+ --save-model-path /path/to/merge_model
24
+ ```
docs/MODEL_ZOO.md ADDED
@@ -0,0 +1,18 @@
1
+ # Model Zoo
2
+
3
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
4
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|------------------
5
+ | Vicuna-13B-v1.3 | CLIP-L-336px(extended to 504) | LCS-558K | 1e | Geochat_Instruct | proj-1e, lora-1e | [LoRA-Merged](https://huggingface.co/MBZUAI/geochat-7B) |
6
+
7
+ ## Projector weights
8
+ We use the projector from LLaVA-1.5 for initialization. [Link](https://huggingface.co/liuhaotian/llava-v1.5-7b-lora)
9
+
10
+ **NOTE**: When you use our pretrained projector for visual instruction tuning, it is very important to **use the same base LLM and vision encoder** as the ones we used for pretraining the projector. Otherwise, performance will degrade significantly.
11
+
12
+ When using these projector weights to instruction-tune your LMM, please make sure the following options are set correctly:
13
+
14
+ ```Shell
15
+ --mm_use_im_start_end False
16
+ --mm_use_im_patch_token False
17
+ ```
18
+
docs/geochat_supp.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc9b5d3df4af5c06e59fc2258332e00b779c55d441405cb5a7fd7997d29b63fe
3
+ size 4839915
geochat.egg-info/PKG-INFO ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.2
2
+ Name: geochat
3
+ Version: 1.1.1
4
+ Summary: Grounded VLM for Remote Sensing
5
+ Project-URL: Homepage, https://github.com/mbzuai-oryx/GeoChat
6
+ Project-URL: Bug Tracker, https://github.com/mbzuai-oryx/GeoChat/issues
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Requires-Python: >=3.8
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: einops
12
+ Requires-Dist: fastapi
13
+ Requires-Dist: gradio==3.35.2
14
+ Requires-Dist: markdown2[all]
15
+ Requires-Dist: numpy
16
+ Requires-Dist: requests
17
+ Requires-Dist: sentencepiece
18
+ Requires-Dist: tokenizers>=0.12.1
19
+ Requires-Dist: torch==2.0.1
20
+ Requires-Dist: torchvision==0.15.2
21
+ Requires-Dist: uvicorn
22
+ Requires-Dist: wandb
23
+ Requires-Dist: shortuuid
24
+ Requires-Dist: httpx==0.24.0
25
+ Requires-Dist: peft==0.4.0
26
+ Requires-Dist: transformers==4.31.0
27
+ Requires-Dist: accelerate==0.21.0
28
+ Requires-Dist: bitsandbytes==0.41.0
29
+ Requires-Dist: scikit-learn==1.2.2
30
+ Requires-Dist: sentencepiece==0.1.99
31
+ Requires-Dist: einops==0.6.1
32
+ Requires-Dist: einops-exts==0.0.4
33
+ Requires-Dist: timm==0.6.13
34
+ Requires-Dist: gradio_client==0.2.9
35
+
36
+ # GeoChat <img src="images/logo_geochat.png" height="40">: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
37
+ <p align="center">
38
+ <img src="https://i.imgur.com/waxVImv.png" alt="Oryx Video-ChatGPT">
39
+ </p>
40
+
41
+ #### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
42
+ \* Equally contributing first authors
43
+
44
+ #### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
45
+
46
+ [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://mbzuai-oryx.github.io/GeoChat)
47
+ [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2311.15826)
48
+ [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](https://youtu.be/KOKtkkKpNDk)
49
+
50
+ ---
51
+
52
+ ## 📢 Latest Updates
53
+ - Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
54
+ - **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
55
+ - **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
56
+ - **Nov-28-23**: GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
57
+ ---
58
+
59
+
60
+
61
+ ## <img src="images/logo_geochat.png" height="40">Overview
62
+
63
+ GeoChat is the first grounded Large Vision-Language Model specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels at handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned on the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
64
+
65
+ ---
66
+ ## Contents
67
+ - [Install](#install)
68
+ - [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
69
+ - [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
70
+ - [Train](#train)
71
+ - [Evaluation](#evaluation)
72
+
73
+ ## Install
74
+
75
+ 1. Clone this repository and navigate to the GeoChat folder
76
+ ```bash
77
+ git clone https://github.com/mbzuai-oryx/GeoChat.git
78
+ cd GeoChat
79
+ ```
80
+
81
+ 2. Install Package
82
+ ```Shell
83
+ conda create -n geochat python=3.10 -y
84
+ conda activate geochat
85
+ pip install --upgrade pip # enable PEP 660 support
86
+ pip install -e .
87
+ ```
88
+
89
+ 3. Install additional packages for training
90
+ ```
91
+ pip install ninja
92
+ pip install flash-attn --no-build-isolation
93
+ ```
94
+
95
+ ### Upgrade to latest code base
96
+
97
+ ```Shell
98
+ git pull
99
+ pip uninstall transformers
100
+ pip install -e .
101
+ ```
102
+
103
+ ## GeoChat Weights and Demo
104
+ Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and check [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
105
+
106
+ ## Train
107
+
108
+ GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset: 318k Vicuna-generated multimodal instruction-following samples, finetuned on top of the pretrained LLaVA-v1.5 weights.
109
+
110
+ We train GeoChat on 3 A100 GPUs with 40GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
111
+
112
+ ### Hyperparameters
113
+ We use a set of hyperparameters similar to Vicuna's for finetuning. The hyperparameters used for finetuning GeoChat-7B are provided below.
114
+
115
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
116
+ | --- | ---: | ---: | ---: | ---: | ---: |
117
+ | GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
118
+
119
+ ### Pretrain (feature alignment)
120
+
121
+ We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. Pretraining takes around 3.5 hours for LLaVA-v1.5-7B.
122
+
123
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector (see the sketch after this list).
124
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
125
+
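+ For intuition, `mlp2x_gelu` denotes a two-layer MLP with a GELU activation in between. A rough sketch is shown below; the dimensions are assumptions based on CLIP ViT-L/14 and a 7B LLM, not values read from the actual config:
+ ```python
+ import torch.nn as nn
+ 
+ vision_hidden_size = 1024   # CLIP ViT-L/14 feature dim (assumed)
+ llm_hidden_size = 4096      # 7B LLM embedding dim (assumed)
+ 
+ # "mlp2x_gelu"-style vision-language connector: Linear -> GELU -> Linear.
+ mm_projector = nn.Sequential(
+     nn.Linear(vision_hidden_size, llm_hidden_size),
+     nn.GELU(),
+     nn.Linear(llm_hidden_size, llm_hidden_size),
+ )
+ ```
+ 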
126
+ ### Visual Instruction Tuning
127
+
128
+ 1. Prepare data
129
+
130
+ Please download the annotations of the final mixture of our instruction-tuning data, [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from [Hugging Face](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save all the image zips in a single folder and run the following command to merge them:
131
+ ```Shell
132
+ cat images_parta* > images.zip
133
+ ```
134
+ Unzip the resulting images.zip into a folder and set that folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
135
+
136
+ 2. Start training!
137
+
138
+ Visual instruction tuning takes longer due to the increased CLIP input resolution of 504×504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40G).
139
+
140
+ Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
141
+
142
+ Options to note:
143
+
144
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
145
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
146
+ - `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
147
+ - `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct).
148
+
149
+ ## Evaluation
150
+
151
+ We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate with beam search, so that the inference process stays consistent with the real-time chat demo.
152
+ See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
153
+
154
+ ## 🏆 Contributions
155
+
156
+ - **RS multimodal instruction-following dataset.** We present a novel data-generation pipeline that leverages existing object detection datasets to create short descriptions of the images, and then uses Vicuna-v1.5 to create conversations from the generated text alone. Further, we add visual question-answering and scene-classification abilities using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
158
+ - **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote-sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully-tuned LLaVA model, whose MLP projection is trained to align images into the word-embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain the conversation and instruction-following abilities of LLaVA and extend its domain knowledge to remote-sensing tasks.
159
+
160
+ - **Evaluation Benchmark.** We also address the lack of evaluation benchmarks to assess the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparisons with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote-sensing tasks, including image captioning, visual question answering, and scene classification, to demonstrate the generalisability of the GeoChat conversational VLM.
161
+
162
+ ---
163
+ ## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
164
+
165
+ GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations, shown on top), answer visual questions about images and regions (top left and bottom right, respectively), perform scene classification (top right), and hold natural-language conversations (bottom). This makes it the first RS VLM with grounding capability.
166
+
167
+ <p align="center">
168
+ <img src="images/overview2.png" alt="GeoChat Overview">
169
+ </p>
170
+
171
+ ---
172
+
173
+ ## 🛰️ GeoChat : Architecture
174
+
175
+ An overview of GeoChat, the first grounded large vision-language model for remote sensing. Given an input image together with a user query, a visual backbone first encodes patch-level tokens at a higher resolution by interpolating its positional encodings. A multi-layer perceptron (MLP) then adapts the vision tokens to a language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be passed to the model together with task-specific prompts that specify the task required by the user. Given this context, the LLM can generate natural-language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks as shown at the top, e.g., scene classification, image/region captioning, VQA, and grounded conversations.
176
+
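+ To make the positional-encoding interpolation step concrete, the sketch below resizes a ViT position-embedding grid from the 336px layout to the 504px layout (with patch size 14, 336px gives a 24x24 grid and 504px a 36x36 grid). This is an illustrative reimplementation, not the exact code used in GeoChat:
+ ```python
+ import torch
+ import torch.nn.functional as F
+ 
+ def interpolate_pos_embed(pos_embed, old_grid=24, new_grid=36):
+     """Resize ViT positional embeddings, keeping the class token unchanged.
+ 
+     pos_embed: tensor of shape (1, 1 + old_grid*old_grid, dim), as in CLIP ViT-L/14-336.
+     """
+     cls_tok, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
+     dim = patch_pos.shape[-1]
+     patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
+     patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
+                               mode="bicubic", align_corners=False)
+     patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
+     return torch.cat([cls_tok, patch_pos], dim=1)
+ ```
+ 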
177
+ <p align="center">
178
+ <img src="images/architecture.png" alt="GeoChat Architectural">
179
+ </p>
180
+
181
+ ---
182
+
183
+ ## 🔍 RS Multimodal Instruction Dataset
184
+
185
+ Types of annotations available in the GeoChat instruction-set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction-set with a total of 318k image-instruction pairs.
186
+
187
+ <p align="center">
188
+ <img src="images/dataset.png" alt="Dataset Annotation Pipeline">
189
+ </p>
190
+
191
+
192
+
193
+ ## 🤖 Qualitative results of GeoChat
194
+
195
+ Qualitative results of GeoChat. (<em>left-right</em>) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., <strong>[grounding]</strong>) to shape model responses according to the desired behavior. The model can generate textual responses (<em>right</em>), only visual grounding (<em>center</em>) and both text and object groundings interleaved together (<em>left</em>). The model can also specify object types, object counts, object attributes and object relationships.
196
+ <p align="center">
197
+ <img src="images/examples.png" alt="Results_GCG">
198
+ </p>
199
+
200
+ ---
201
+
202
+ ## 🤖 Visual Question Answering
203
+ Qualitative examples for Visual Question Answering tasks. GeoChat is able to hold multi-turn conversations based on various types of questions, including presence, counting, complex comparisons, and so on. It can also detect objects and hold conversations over low-resolution images.
204
+ <p align="center">
205
+ <img src="images/vqa.jpg" alt="Visual Question Answering">
206
+ </p>
207
+
208
+ ---
209
+
210
+ ## 🤖 Scene Classification
211
+ Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
212
+ <p align="center">
213
+ <img src="images/scene.jpg" alt="Visual Question Answering">
214
+ </p>
215
+
216
+ ---
217
+
218
+ ## 🤖 Grounded Description
219
+ When asked to describe the image with the special token '[grounding]', GeoChat outputs both a description of the image and bounding boxes for all detected objects.
220
+ <p align="center">
221
+ <img src="images/grounded.jpg" alt="Grounded Description">
222
+ </p>
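+ 
+ A hedged sketch of how such a query can be assembled programmatically, reusing the conversation template and task token in the same way as the evaluation scripts (the question text is only an example):
+ ```python
+ # Build a "[grounding]" prompt with the same template used by the eval scripts.
+ from geochat.constants import DEFAULT_IMAGE_TOKEN
+ from geochat.conversation import conv_templates
+ 
+ question = "Describe the image."          # example query
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + "[grounding]" + question
+ 
+ conv = conv_templates["llava_v1"].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()                # then tokenize with tokenizer_image_token(...)
+ ```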
223
+
224
+ ---
225
+
226
+ ## 🤖 Referring Expression
227
+ When an object is referred to via a referring expression, GeoChat is able to locate it and draw a rotated bounding box around it.
228
+ <p align="center">
229
+ <img src="images/ref1.jpg" alt="Referring Expression">
230
+ </p>
231
+ <p align="center">
232
+ <img src="images/ref_2.jpg" alt="Referring Expression">
233
+ </p>
234
+
235
+ ---
236
+
237
+ ## 🤖 Region Caption
238
+ Qualitative examples for region-based captioning. Given a bounding box, GeoChat is able to provide a brief description of the area or object covered by the bounding box.
239
+ <p align="center">
240
+ <img src="images/iden.jpg" alt="Region Caption">
241
+ </p>
242
+
243
+ ---
244
+
245
+ ## 📜 Citation
246
+ ```bibtex
247
+ @article{kuckreja2023geochat,
248
+ title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
249
+ author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
250
+ journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
251
+ year={2024}
252
+ }
253
+ ```
254
+ ## 🙏 Acknowledgement
255
+ We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
256
+
257
+ ---
258
+ [<img src="images/IVAL_logo.png" width="200" height="100">](https://www.ival-mbzuai.com)
259
+ [<img src="images/Oryx_logo.png" width="100" height="100">](https://github.com/mbzuai-oryx)
260
+ [<img src="images/MBZUAI_logo.png" width="360" height="85">](https://mbzuai.ac.ae)
geochat.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ geochat/__init__.py
4
+ geochat/constants.py
5
+ geochat/conversation.py
6
+ geochat/mm_utils.py
7
+ geochat/utils.py
8
+ geochat.egg-info/PKG-INFO
9
+ geochat.egg-info/SOURCES.txt
10
+ geochat.egg-info/dependency_links.txt
11
+ geochat.egg-info/requires.txt
12
+ geochat.egg-info/top_level.txt
13
+ geochat/eval/batch_geochat_grounding.py
14
+ geochat/eval/batch_geochat_referring.py
15
+ geochat/eval/batch_geochat_scene.py
16
+ geochat/eval/batch_geochat_vqa.py
17
+ geochat/model/__init__.py
18
+ geochat/model/apply_delta.py
19
+ geochat/model/builder.py
20
+ geochat/model/consolidate.py
21
+ geochat/model/geochat_arch.py
22
+ geochat/model/make_delta.py
23
+ geochat/model/utils.py
24
+ geochat/model/language_model/geochat_llama.py
25
+ geochat/model/language_model/geochat_mpt.py
26
+ geochat/model/language_model/mpt/adapt_tokenizer.py
27
+ geochat/model/language_model/mpt/attention.py
28
+ geochat/model/language_model/mpt/blocks.py
29
+ geochat/model/language_model/mpt/configuration_mpt.py
30
+ geochat/model/language_model/mpt/custom_embedding.py
31
+ geochat/model/language_model/mpt/flash_attn_triton.py
32
+ geochat/model/language_model/mpt/hf_prefixlm_converter.py
33
+ geochat/model/language_model/mpt/meta_init_context.py
34
+ geochat/model/language_model/mpt/modeling_mpt.py
35
+ geochat/model/language_model/mpt/norm.py
36
+ geochat/model/language_model/mpt/param_init_fns.py
37
+ geochat/model/multimodal_encoder/builder.py
38
+ geochat/model/multimodal_encoder/clip_encoder.py
39
+ geochat/model/multimodal_projector/builder.py
40
+ geochat/serve/__init__.py
41
+ geochat/serve/cli.py
42
+ geochat/serve/controller.py
43
+ geochat/serve/gradio_trial.py
44
+ geochat/serve/gradio_web_server.py
45
+ geochat/serve/model_worker.py
46
+ geochat/serve/register_worker.py
47
+ geochat/serve/test_message.py
48
+ geochat/train/geochat_trainer.py
49
+ geochat/train/llama_flash_attn_monkey_patch.py
50
+ geochat/train/train.py
51
+ geochat/train/train_mem.py
geochat.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
geochat.egg-info/requires.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ fastapi
3
+ gradio==3.35.2
4
+ markdown2[all]
5
+ numpy
6
+ requests
7
+ sentencepiece
8
+ tokenizers>=0.12.1
9
+ torch==2.0.1
10
+ torchvision==0.15.2
11
+ uvicorn
12
+ wandb
13
+ shortuuid
14
+ httpx==0.24.0
15
+ peft==0.4.0
16
+ transformers==4.31.0
17
+ accelerate==0.21.0
18
+ bitsandbytes==0.41.0
19
+ scikit-learn==1.2.2
20
+ sentencepiece==0.1.99
21
+ einops==0.6.1
22
+ einops-exts==0.0.4
23
+ timm==0.6.13
24
+ gradio_client==0.2.9
geochat.egg-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ demo_images
2
+ geochat
3
+ images
geochat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import GeoChatLlamaForCausalLM
geochat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes). View file
 
geochat/__pycache__/constants.cpython-310.pyc ADDED
Binary file (446 Bytes). View file
 
geochat/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
geochat/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (4.92 kB). View file
 
geochat/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.01 kB). View file
 
geochat/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
geochat/conversation.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ from PIL import Image
5
+ from threading import Thread
6
+
7
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
+ # from llava.conversation import conv_templates, SeparatorStyle
9
+ # from llava.model.builder import load_pretrained_model
10
+ from geochat.utils import disable_torch_init
11
+ from geochat.mm_utils import process_images_demo, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
12
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer,TextStreamer
13
+ import torch
17
+
18
+
19
+ class SeparatorStyle(Enum):
20
+ """Different separator style."""
21
+ SINGLE = auto()
22
+ TWO = auto()
23
+ MPT = auto()
24
+ PLAIN = auto()
25
+ LLAMA_2 = auto()
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class Conversation:
30
+ """A class that keeps all conversation history."""
31
+ system: str
32
+ roles: List[str]
33
+ messages: List[List[str]]
34
+ offset: int
35
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
36
+ sep: str = "###"
37
+ sep2: str = None
38
+ version: str = "Unknown"
39
+
40
+ skip_next: bool = False
41
+
42
+ def get_prompt(self):
43
+ messages = self.messages
44
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
45
+ messages = self.messages.copy()
46
+ init_role, init_msg = messages[0].copy()
47
+ init_msg = init_msg[0].replace("<image>", "").strip()
48
+ if 'mmtag' in self.version:
49
+ messages[0] = (init_role, init_msg)
50
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
51
+ messages.insert(1, (self.roles[1], "Received."))
52
+ else:
53
+ messages[0] = (init_role, "<image>\n" + init_msg)
54
+
55
+ if self.sep_style == SeparatorStyle.SINGLE:
56
+ ret = self.system + self.sep
57
+ for role, message in messages:
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + self.sep
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.TWO:
65
+ seps = [self.sep, self.sep2]
66
+ ret = self.system + seps[0]
67
+ for i, (role, message) in enumerate(messages):
68
+ if message:
69
+ if type(message) is tuple:
70
+ message, _, _ = message
71
+ ret += role + ": " + message + seps[i % 2]
72
+ else:
73
+ ret += role + ":"
74
+ elif self.sep_style == SeparatorStyle.MPT:
75
+ ret = self.system + self.sep
76
+ for role, message in messages:
77
+ if message:
78
+ if type(message) is tuple:
79
+ message, _, _ = message
80
+ ret += role + message + self.sep
81
+ else:
82
+ ret += role
83
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
84
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
85
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
86
+ ret = ""
87
+
88
+ for i, (role, message) in enumerate(messages):
89
+ if i == 0:
90
+ assert message, "first message should not be none"
91
+ assert role == self.roles[0], "first message should come from user"
92
+ if message:
93
+ if type(message) is tuple:
94
+ message, _, _ = message
95
+ if i == 0: message = wrap_sys(self.system) + message
96
+ if i % 2 == 0:
97
+ message = wrap_inst(message)
98
+ ret += self.sep + message
99
+ else:
100
+ ret += " " + message + " " + self.sep2
101
+ else:
102
+ ret += ""
103
+ ret = ret.lstrip(self.sep)
104
+ elif self.sep_style == SeparatorStyle.PLAIN:
105
+ seps = [self.sep, self.sep2]
106
+ ret = self.system
107
+ for i, (role, message) in enumerate(messages):
108
+ if message:
109
+ if type(message) is tuple:
110
+ message, _, _ = message
111
+ ret += message + seps[i % 2]
112
+ else:
113
+ ret += ""
114
+ else:
115
+ raise ValueError(f"Invalid style: {self.sep_style}")
116
+
117
+ return ret
118
+
119
+ def append_message(self, role, message):
120
+ self.messages.append([role, message])
121
+
122
+ def get_images(self, return_pil=False):
123
+ images = []
124
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
125
+ if i % 2 == 0:
126
+ if type(msg) is tuple:
127
+ import base64
128
+ from io import BytesIO
129
+ from PIL import Image
130
+ msg, image, image_process_mode = msg
131
+ if image_process_mode == "Pad":
132
+ def expand2square(pil_img, background_color=(122, 116, 104)):
133
+ width, height = pil_img.size
134
+ if width == height:
135
+ return pil_img
136
+ elif width > height:
137
+ result = Image.new(pil_img.mode, (width, width), background_color)
138
+ result.paste(pil_img, (0, (width - height) // 2))
139
+ return result
140
+ else:
141
+ result = Image.new(pil_img.mode, (height, height), background_color)
142
+ result.paste(pil_img, ((height - width) // 2, 0))
143
+ return result
144
+ image = expand2square(image)
145
+ elif image_process_mode in ["Default", "Crop"]:
146
+ pass
147
+ elif image_process_mode == "Resize":
148
+ image = image.resize((336, 336))
149
+ else:
150
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
151
+ max_hw, min_hw = max(image.size), min(image.size)
152
+ aspect_ratio = max_hw / min_hw
153
+ max_len, min_len = 800, 400
154
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
155
+ longest_edge = int(shortest_edge * aspect_ratio)
156
+ W, H = image.size
157
+ if longest_edge != max(image.size):
158
+ if H > W:
159
+ H, W = longest_edge, shortest_edge
160
+ else:
161
+ H, W = shortest_edge, longest_edge
162
+ image = image.resize((W, H))
163
+ if return_pil:
164
+ images.append(image)
165
+ else:
166
+ buffered = BytesIO()
167
+ image.save(buffered, format="PNG")
168
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
169
+ images.append(img_b64_str)
170
+ return images
171
+
172
+ def to_gradio_chatbot(self):
173
+ ret = []
174
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
175
+ if i % 2 == 0:
176
+ if type(msg) is tuple:
177
+ import base64
178
+ from io import BytesIO
179
+ msg, image, image_process_mode = msg
180
+ max_hw, min_hw = max(image.size), min(image.size)
181
+ aspect_ratio = max_hw / min_hw
182
+ max_len, min_len = 800, 400
183
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
184
+ longest_edge = int(shortest_edge * aspect_ratio)
185
+ W, H = image.size
186
+ if H > W:
187
+ H, W = longest_edge, shortest_edge
188
+ else:
189
+ H, W = shortest_edge, longest_edge
190
+ image = image.resize((W, H))
191
+ buffered = BytesIO()
192
+ image.save(buffered, format="JPEG")
193
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
194
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
195
+ msg = img_str + msg.replace('<image>', '').strip()
196
+ ret.append([msg, None])
197
+ else:
198
+ ret.append([msg, None])
199
+ else:
200
+ ret[-1][-1] = msg
201
+ return ret
202
+
203
+ def copy(self):
204
+ return Conversation(
205
+ system=self.system,
206
+ roles=self.roles,
207
+ messages=[[x, y] for x, y in self.messages],
208
+ offset=self.offset,
209
+ sep_style=self.sep_style,
210
+ sep=self.sep,
211
+ sep2=self.sep2,
212
+ version=self.version)
213
+
214
+ def dict(self):
215
+ if len(self.get_images()) > 0:
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ }
224
+ return {
225
+ "system": self.system,
226
+ "roles": self.roles,
227
+ "messages": self.messages,
228
+ "offset": self.offset,
229
+ "sep": self.sep,
230
+ "sep2": self.sep2,
231
+ }
232
+
233
+
234
+ conv_vicuna_v0 = Conversation(
235
+ system="A chat between a curious human and an artificial intelligence assistant. "
236
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
237
+ roles=("Human", "Assistant"),
238
+ messages=(
239
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
240
+ ("Assistant",
241
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
242
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
243
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
244
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
245
+ "renewable and non-renewable energy sources:\n"
246
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
247
+ "energy sources are finite and will eventually run out.\n"
248
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
249
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
250
+ "and other negative effects.\n"
251
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
252
+ "have lower operational costs than non-renewable sources.\n"
253
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
254
+ "locations than non-renewable sources.\n"
255
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
256
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
257
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
258
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
259
+ ),
260
+ offset=2,
261
+ sep_style=SeparatorStyle.SINGLE,
262
+ sep="###",
263
+ )
264
+
265
+ conv_vicuna_v1 = Conversation(
266
+ system="A chat between a curious user and an artificial intelligence assistant. "
267
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="v1",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.TWO,
273
+ sep=" ",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_llama_2 = Conversation(
278
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
279
+
280
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
281
+ roles=("USER", "ASSISTANT"),
282
+ version="llama_v2",
283
+ messages=(),
284
+ offset=0,
285
+ sep_style=SeparatorStyle.LLAMA_2,
286
+ sep="<s>",
287
+ sep2="</s>",
288
+ )
289
+
290
+ conv_llava_llama_2 = Conversation(
291
+ system="You are a helpful language and vision assistant. "
292
+ "You are able to understand the visual content that the user provides, "
293
+ "and assist the user with a variety of tasks using natural language.",
294
+ roles=("USER", "ASSISTANT"),
295
+ version="llama_v2",
296
+ messages=(),
297
+ offset=0,
298
+ sep_style=SeparatorStyle.LLAMA_2,
299
+ sep="<s>",
300
+ sep2="</s>",
301
+ )
302
+
303
+ conv_mpt = Conversation(
304
+ system="""<|im_start|>system
305
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
306
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
307
+ version="mpt",
308
+ messages=(),
309
+ offset=0,
310
+ sep_style=SeparatorStyle.MPT,
311
+ sep="<|im_end|>",
312
+ )
313
+
314
+ conv_llava_plain = Conversation(
315
+ system="",
316
+ roles=("", ""),
317
+ messages=(
318
+ ),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.PLAIN,
321
+ sep="\n",
322
+ )
323
+
324
+ conv_llava_v0 = Conversation(
325
+ system="A chat between a curious human and an artificial intelligence assistant. "
326
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
327
+ roles=("Human", "Assistant"),
328
+ messages=(
329
+ ),
330
+ offset=0,
331
+ sep_style=SeparatorStyle.SINGLE,
332
+ sep="###",
333
+ )
334
+
335
+ conv_llava_v0_mmtag = Conversation(
336
+ system="A chat between a curious user and an artificial intelligence assistant. "
337
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
338
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
339
+ roles=("Human", "Assistant"),
340
+ messages=(
341
+ ),
342
+ offset=0,
343
+ sep_style=SeparatorStyle.SINGLE,
344
+ sep="###",
345
+ version="v0_mmtag",
346
+ )
347
+
348
+ conv_llava_v1 = Conversation(
349
+ system="A chat between a curious human and an artificial intelligence assistant. "
350
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
351
+ roles=("USER", "ASSISTANT"),
352
+ version="v1",
353
+ messages=(),
354
+ offset=0,
355
+ sep_style=SeparatorStyle.TWO,
356
+ sep=" ",
357
+ sep2="</s>",
358
+ )
359
+
360
+ conv_llava_v1_mmtag = Conversation(
361
+ system="A chat between a curious user and an artificial intelligence assistant. "
362
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
363
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
364
+ roles=("USER", "ASSISTANT"),
365
+ messages=(),
366
+ offset=0,
367
+ sep_style=SeparatorStyle.TWO,
368
+ sep=" ",
369
+ sep2="</s>",
370
+ version="v1_mmtag",
371
+ )
372
+
373
+ default_conversation = conv_vicuna_v0
374
+ conv_templates = {
375
+ "default": conv_vicuna_v0,
376
+ "v0": conv_vicuna_v0,
377
+ "v1": conv_vicuna_v1,
378
+ "vicuna_v1": conv_vicuna_v1,
379
+ "llama_2": conv_llama_2,
380
+
381
+ "plain": conv_llava_plain,
382
+ "v0_plain": conv_llava_plain,
383
+ "llava_v0": conv_llava_v0,
384
+ "v0_mmtag": conv_llava_v0_mmtag,
385
+ "llava_v1": conv_llava_v1,
386
+ "v1_mmtag": conv_llava_v1_mmtag,
387
+ "llava_llama_2": conv_llava_llama_2,
388
+
389
+ "mpt": conv_mpt,
390
+ }
391
+
392
+ class Chat:
393
+ def __init__(self, model, image_processor,tokenizer, device='cuda:0', stopping_criteria=None):
394
+ self.device = device
395
+ self.model = model
396
+ self.vis_processor = image_processor
397
+ self.tokenizer=tokenizer
398
+
399
+ # if stopping_criteria is not None:
400
+ # self.stopping_criteria = stopping_criteria
401
+ # else:
402
+ # stop_words_ids = [torch.tensor([2]).to(self.device)]
403
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
404
+
405
+ def ask(self, text, conv):
406
+ # import pdb;pdb.set_trace()
407
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
408
+ and conv.messages[-1][1][-9:] == '<image>\n': # last message is image.
409
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
410
+ else:
411
+ conv.append_message(conv.roles[0], text)
412
+
413
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
414
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
415
+ conv.append_message(conv.roles[1], None)
416
+ prompt = conv.get_prompt()
417
+ # prompt='A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: <image>\n hello ASSISTANT:'
418
+ text_input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device=self.device)
419
+
420
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
421
+ keywords = [stop_str]
422
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, text_input_ids)
423
+ current_max_len = text_input_ids.shape[1] + max_new_tokens
424
+ if current_max_len - max_length > 0:
425
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
426
+ 'The model will not see the contexts outside the range.')
427
+ begin_idx = max(0, current_max_len - max_length)
428
+ embs = text_input_ids[:, begin_idx:]
429
+
430
+ generation_kwargs = dict(
431
+ input_ids=embs,
432
+ images=img_list[0],
433
+ max_new_tokens=max_new_tokens,
434
+ stopping_criteria=[stopping_criteria],
435
+ num_beams=num_beams,
436
+ do_sample=True,
437
+ min_length=min_length,
438
+ top_p=top_p,
439
+ use_cache=True,
440
+ repetition_penalty=repetition_penalty,
441
+ length_penalty=length_penalty,
442
+ temperature=float(temperature),
443
+ )
444
+ return generation_kwargs
445
+
446
+ # def answer(self, conv, img_list, **kargs):
447
+ # generation_dict = self.answer_prepare(conv, img_list, **kargs)
448
+ # output_token = self.model_generate(**generation_dict)[0]
449
+ # output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
450
+
451
+ # output_text = output_text.split('###')[0] # remove the stop sign '###'
452
+ # output_text = output_text.split('Assistant:')[-1].strip()
453
+
454
+ # conv.messages[-1][1] = output_text
455
+ # return output_text, output_token.cpu().numpy()
456
+
457
+ def stream_answer(self, conv, img_list, **kargs):
458
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
459
+
460
+ streamer = TextIteratorStreamer(self.tokenizer,skip_prompt=True, skip_special_tokens=True)
461
+ generation_kwargs['streamer'] = streamer
462
+ # import pdb;pdb.set_trace()
463
+ # output_ids=self.model.generate(*generation_kwargs)
464
+ output=self.model_generate(kwargs=generation_kwargs)
465
+ # thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
466
+ # thread.start()
467
+ return streamer
468
+
469
+ def model_generate(self, *args, **kwargs):
470
+ # for 8 bit and 16 bit compatibility
471
+ with torch.inference_mode():
472
+ output = self.model.generate(kwargs['kwargs']['input_ids'],
473
+ images=kwargs['kwargs']['images'],
474
+ do_sample=False,
475
+ temperature=kwargs['kwargs']['temperature'],
476
+ max_new_tokens=kwargs['kwargs']['max_new_tokens'],
477
+ streamer=kwargs['kwargs']['streamer'],
478
+ use_cache=kwargs['kwargs']['use_cache'],
479
+ stopping_criteria=kwargs['kwargs']['stopping_criteria'])
480
+ # import pdb;pdb.set_trace()
481
+ # print(output)
482
+ outputs = self.tokenizer.decode(output[0,kwargs['kwargs']['input_ids'].shape[1]:]).strip()
483
+ # print(outputs)
484
+ return output
485
+
486
+ def encode_img(self, img_list):
487
+
488
+ image = img_list[0]
489
+ # image='/share/data/drive_3/kartik/LLaVA/output_images/output.jpg'
490
+ img_list.pop(0)
491
+ if isinstance(image, str): # is a image path
492
+ raw_image = Image.open(image).convert('RGB')
493
+ image = process_images_demo([raw_image], self.vis_processor)
494
+ # print("raw")
495
+ # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
496
+ elif isinstance(image, Image.Image):
497
+ raw_image = image
498
+ image = process_images_demo([raw_image], self.vis_processor )
499
+ image=image.to(device=self.device,dtype=torch.float16)
500
+ # print("Image")
501
+ # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
502
+ elif isinstance(image, torch.Tensor):
503
+ if len(image.shape) == 3:
504
+ image = image.unsqueeze(0)
505
+ image = image.to(self.device)
506
+
507
+ # image_emb, _ = self.model.encode_img(image)
508
+ img_list.append(image)
509
+
510
+ def upload_img(self, image, conv, img_list):
511
+ conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN+'\n')
512
+ img_list.append(image)
513
+ msg = "Received."
514
+
515
+ return msg
516
+
517
+
518
+
519
+ # if __name__ == "__main__":
520
+ # print(default_conversation.get_prompt())
geochat/eval/batch_geochat_grounding.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
34
+ # print(model)
35
+ questions=[]
36
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
+
38
+
39
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
40
+ answers_file = os.path.expanduser(args.answers_file)
41
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
42
+
43
+ ans_file = open(answers_file, "w")
44
+
45
+ for i in tqdm(range(0,len(questions),args.batch_size)):
46
+ input_batch=[]
47
+ input_image_batch=[]
48
+ count=i
49
+ image_folder=[]
50
+ batch_end = min(i + args.batch_size, len(questions))
51
+
52
+
53
+ for j in range(i,batch_end):
54
+ image_file=questions[j]['image_id']+'.png'
+ qs=questions[j]['question']  # base question text; prefixed with a task token below
+ 
+ if questions[j]['type']=='ref':
+ qs="[refer] Give me the location of <p> " + qs + " </p>"
+ else:
+ qs="[grounding]" + qs
60
+
61
+ if model.config.mm_use_im_start_end:
62
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
63
+ else:
64
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
65
+
66
+ conv = conv_templates[args.conv_mode].copy()
67
+ conv.append_message(conv.roles[0], qs)
68
+ conv.append_message(conv.roles[1], None)
69
+ prompt = conv.get_prompt()
70
+
71
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
72
+ input_batch.append(input_ids)
73
+
74
+ image = Image.open(os.path.join(args.image_folder, image_file))
75
+
76
+ image_folder.append(image)
77
+
78
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
79
+ keywords = [stop_str]
80
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
81
+
82
+ max_length = max(tensor.size(1) for tensor in input_batch)
83
+
84
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
85
+ final_input_tensors=torch.cat(final_input_list,dim=0)
86
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
87
+
88
+ with torch.inference_mode():
89
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
90
+
91
+ input_token_len = final_input_tensors.shape[1]
92
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
93
+ if n_diff_input_output > 0:
94
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
95
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
96
+ for k in range(0,len(final_input_list)):
97
+ output = outputs[k].strip()
98
+ if output.endswith(stop_str):
99
+ output = output[:-len(stop_str)]
100
+ output = output.strip()
101
+
102
+ ans_id = shortuuid.uuid()
103
+
104
+ ans_file.write(json.dumps({
105
+
106
+ "question_id": questions[count]["question_id"],
107
+ "image_id": questions[count]["image_id"],
108
+ "answer": output,
109
+ "ground_truth": questions[count]['ground_truth'],
110
+ "question":questions[count]['question'],
111
+ "type": questions[count]['type'],
112
+ "dataset": questions[count]['dataset'],
113
+ "obj_ids": questions[count]['obj_ids'],
114
+ "size_group": questions[count]['size_group'],
115
+
116
+ }) + "\n")
117
+ count=count+1
118
+ ans_file.flush()
119
+ ans_file.close()
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
125
+ parser.add_argument("--model-base", type=str, default=None)
126
+ parser.add_argument("--image-folder", type=str, default="")
127
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
128
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
129
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
130
+ parser.add_argument("--num-chunks", type=int, default=1)
131
+ parser.add_argument("--chunk-idx", type=int, default=0)
132
+ parser.add_argument("--temperature", type=float, default=0.2)
133
+ parser.add_argument("--top_p", type=float, default=None)
134
+ parser.add_argument("--num_beams", type=int, default=1)
135
+ parser.add_argument("--batch_size",type=int, default=1)
136
+ args = parser.parse_args()
137
+
138
+ eval_model(args)
geochat/eval/batch_geochat_referring.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
+ # print(model)
34
+ questions=[]
35
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
36
+
37
+
38
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
+ answers_file = os.path.expanduser(args.answers_file)
40
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
+
42
+ ans_file = open(answers_file, "w")
43
+
44
+ for i in tqdm(range(0,len(questions),args.batch_size)):
45
+ input_batch=[]
46
+ input_image_batch=[]
47
+ count=i
48
+ image_folder=[]
49
+ batch_end = min(i + args.batch_size, len(questions))
50
+
51
+
52
+ for j in range(i,batch_end):
53
+ image_file=questions[j]['image_id']+'.png'
54
+ qs="[identify] What is the object present at " + questions[j]['question']
55
+
56
+ if model.config.mm_use_im_start_end:
57
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
58
+ else:
59
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
60
+
61
+ conv = conv_templates[args.conv_mode].copy()
62
+ conv.append_message(conv.roles[0], qs)
63
+ conv.append_message(conv.roles[1], None)
64
+ prompt = conv.get_prompt()
65
+
66
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
67
+ input_batch.append(input_ids)
68
+
69
+ image = Image.open(os.path.join(args.image_folder, image_file))
70
+
71
+ image_folder.append(image)
72
+
73
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
74
+ keywords = [stop_str]
75
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
76
+
77
+ max_length = max(tensor.size(1) for tensor in input_batch)
78
+
79
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
80
+ final_input_tensors=torch.cat(final_input_list,dim=0)
81
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
82
+
83
+ with torch.inference_mode():
84
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
85
+
86
+ input_token_len = final_input_tensors.shape[1]
87
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
88
+ if n_diff_input_output > 0:
89
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
91
+ for k in range(0,len(final_input_list)):
92
+ output = outputs[k].strip()
93
+ if output.endswith(stop_str):
94
+ output = output[:-len(stop_str)]
95
+ output = output.strip()
96
+
97
+ ans_id = shortuuid.uuid()
98
+
99
+ ans_file.write(json.dumps({
100
+ "question_id": questions[count]["question_id"],
101
+ "image_id": questions[count]["image_id"],
102
+ "answer": output,
103
+ "ground_truth": questions[count]['ground_truth'],
104
+ "question":questions[count]['question'],
105
+ "type": questions[count]['type'],
106
+ "dataset": questions[count]['dataset'],
107
+ "obj_ids": questions[count]['obj_ids'],
108
+ "size_group": questions[count]['size_group'],
109
+
110
+ }) + "\n")
111
+ count=count+1
112
+ ans_file.flush()
113
+ ans_file.close()
114
+
115
+
116
+ if __name__ == "__main__":
117
+ parser = argparse.ArgumentParser()
118
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
119
+ parser.add_argument("--model-base", type=str, default=None)
120
+ parser.add_argument("--image-folder", type=str, default="")
121
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
122
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
123
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
124
+ parser.add_argument("--num-chunks", type=int, default=1)
125
+ parser.add_argument("--chunk-idx", type=int, default=0)
126
+ parser.add_argument("--temperature", type=float, default=0.2)
127
+ parser.add_argument("--top_p", type=float, default=None)
128
+ parser.add_argument("--num_beams", type=int, default=1)
129
+ parser.add_argument("--batch_size",type=int, default=1)
130
+ args = parser.parse_args()
131
+
132
+ eval_model(args)
geochat/eval/batch_geochat_scene.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+ def evaluation_metrics(data_path):
18
+
19
+ base = [json.loads(q) for q in open(data_path, "r")]
20
+ correct=0
21
+ incorrect=0
22
+ for answers in tqdm(base):
23
+ gt=answers['question_id'].split('/')[0].lower()
24
+ answer=answers['answer'].replace(' ','').lower().replace('.','')
25
+ if gt==answer:
26
+ correct=correct+1
27
+ else:
28
+ incorrect=incorrect+1
29
+ # else:
30
+ # continue
31
+ print('correct:',correct)
32
+ print('incorrect:',incorrect)
33
+ print('Total:',correct+incorrect)
34
+ print('Acc:',(correct/(correct+incorrect)))
35
+
36
+
37
+
38
+
39
+ def eval_model(args):
40
+ # Model
41
+ disable_torch_init()
42
+ model_path = os.path.expanduser(args.model_path)
43
+ model_name = get_model_name_from_path(model_path)
44
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
45
+ # print(model)
46
+ questions=[]
47
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
48
+
49
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
50
+ answers_file = os.path.expanduser(args.answers_file)
51
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
52
+
53
+ ans_file = open(answers_file, "w")
54
+
55
+ for i in tqdm(range(0,len(questions),args.batch_size)):
56
+ input_batch=[]
57
+ input_image_batch=[]
58
+ count=i
59
+ image_folder=[]
60
+ batch_end = min(i + args.batch_size, len(questions))
61
+
62
+
63
+ for j in range(i,batch_end):
64
+ image_file=questions[j]['image']
65
+ qs=questions[j]['text']
66
+
67
+ if model.config.mm_use_im_start_end:
68
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
69
+ else:
70
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
71
+
72
+ conv = conv_templates[args.conv_mode].copy()
73
+ conv.append_message(conv.roles[0], qs)
74
+ conv.append_message(conv.roles[1], None)
75
+ prompt = conv.get_prompt()
76
+
77
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
78
+ input_batch.append(input_ids)
79
+
80
+ image = Image.open(os.path.join(args.image_folder, image_file))
81
+
82
+ image_folder.append(image)
83
+
84
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
85
+ keywords = [stop_str]
86
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
87
+
88
+ max_length = max(tensor.size(1) for tensor in input_batch)
89
+
90
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
91
+ final_input_tensors=torch.cat(final_input_list,dim=0)
92
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
93
+
94
+ with torch.inference_mode():
95
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
96
+
97
+ input_token_len = final_input_tensors.shape[1]
98
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
99
+ if n_diff_input_output > 0:
100
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
101
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
102
+ for k in range(0,len(final_input_list)):
103
+ output = outputs[k].strip()
104
+ if output.endswith(stop_str):
105
+ output = output[:-len(stop_str)]
106
+ output = output.strip()
107
+
108
+ ans_id = shortuuid.uuid()
109
+
110
+ ans_file.write(json.dumps({
111
+
112
+ "question_id": questions[count]["question_id"],
113
+ "image_id": questions[count]["image"],
114
+ "answer": output,
115
+ "ground_truth": questions[count]['ground_truth']
116
+ }) + "\n")
117
+ count=count+1
118
+ ans_file.flush()
119
+ ans_file.close()
120
+ evaluation_metrics(answers_file)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ parser = argparse.ArgumentParser()
125
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
126
+ parser.add_argument("--model-base", type=str, default=None)
127
+ parser.add_argument("--image-folder", type=str, default="")
128
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
129
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
130
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
131
+ parser.add_argument("--num-chunks", type=int, default=1)
132
+ parser.add_argument("--chunk-idx", type=int, default=0)
133
+ parser.add_argument("--temperature", type=float, default=0.2)
134
+ parser.add_argument("--top_p", type=float, default=None)
135
+ parser.add_argument("--num_beams", type=int, default=1)
136
+ parser.add_argument("--batch_size",type=int, default=1)
137
+ args = parser.parse_args()
138
+
139
+ eval_model(args)
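For reference, evaluation_metrics above takes the text before the first '/' in question_id as the ground-truth class and compares it against the lower-cased, de-spaced, de-dotted model answer. A hypothetical answers-file line that would count as correct (the ids below are made up for illustration):

    {"question_id": "Beach/1234.png", "image_id": "1234.png", "answer": "Beach.", "ground_truth": "Beach"}

Here "Beach." normalizes to "beach", which matches the "beach" prefix of question_id.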
geochat/eval/batch_geochat_vqa.py ADDED
@@ -0,0 +1,125 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
+
34
+ questions=[]
35
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
36
+
37
+
38
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
+ answers_file = os.path.expanduser(args.answers_file)
40
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
+
42
+ ans_file = open(answers_file, "w")
43
+
44
+ for i in tqdm(range(0,len(questions),args.batch_size)):
45
+ input_batch=[]
46
+ input_image_batch=[]
47
+ count=i
48
+ image_folder=[]
49
+ batch_end = min(i + args.batch_size, len(questions))
50
+
51
+
52
+ for j in range(i,batch_end):
53
+ image_file=questions[j]['image']
54
+ qs=questions[j]['text']
55
+
56
+ if model.config.mm_use_im_start_end:
57
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
58
+ else:
59
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
60
+
61
+ conv = conv_templates[args.conv_mode].copy()
62
+ conv.append_message(conv.roles[0], qs)
63
+ conv.append_message(conv.roles[1], None)
64
+ prompt = conv.get_prompt()
65
+
66
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
67
+ input_batch.append(input_ids)
68
+
69
+ image = Image.open(os.path.join(args.image_folder, image_file))
70
+
71
+ image_folder.append(image)
72
+
73
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
74
+ keywords = [stop_str]
75
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
76
+
77
+ max_length = max(tensor.size(1) for tensor in input_batch)
78
+
79
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
80
+ final_input_tensors=torch.cat(final_input_list,dim=0)
81
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
82
+
83
+ with torch.inference_mode():
84
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
85
+
86
+ input_token_len = final_input_tensors.shape[1]
87
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
88
+ if n_diff_input_output > 0:
89
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
91
+ for k in range(0,len(final_input_list)):
92
+ output = outputs[k].strip()
93
+ if output.endswith(stop_str):
94
+ output = output[:-len(stop_str)]
95
+ output = output.strip()
96
+
97
+ ans_id = shortuuid.uuid()
98
+
99
+ ans_file.write(json.dumps({
100
+ "question_id": questions[count]["question_id"],
101
+ "image_id": questions[count]["image"],
102
+ "answer": output,
103
+ }) + "\n")
104
+ count=count+1
105
+ ans_file.flush()
106
+ ans_file.close()
107
+
108
+
109
+ if __name__ == "__main__":
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
112
+ parser.add_argument("--model-base", type=str, default=None)
113
+ parser.add_argument("--image-folder", type=str, default="")
114
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
115
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
116
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
117
+ parser.add_argument("--num-chunks", type=int, default=1)
118
+ parser.add_argument("--chunk-idx", type=int, default=0)
119
+ parser.add_argument("--temperature", type=float, default=0.2)
120
+ parser.add_argument("--top_p", type=float, default=None)
121
+ parser.add_argument("--num_beams", type=int, default=1)
122
+ parser.add_argument("--batch_size",type=int, default=1)
123
+ args = parser.parse_args()
124
+
125
+ eval_model(args)
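A sketch of a single-GPU invocation of the VQA script above, with placeholder checkpoint and dataset paths (these are assumptions, not files shipped in this upload):

    python geochat/eval/batch_geochat_vqa.py --model-path /path/to/geochat-7B --question-file /path/to/vqa_questions.jsonl --answers-file /path/to/vqa_answers.jsonl --image-folder /path/to/images --batch_size 4

Each question JSONL line needs at least the 'question_id', 'image', and 'text' fields, matching the keys read in eval_model.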
geochat/mm_utils.py ADDED
@@ -0,0 +1,121 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from geochat.constants import IMAGE_TOKEN_INDEX
8
+ import numpy as np
9
+
10
+ def load_image_from_base64(image):
11
+ return Image.open(BytesIO(base64.b64decode(image)))
12
+
13
+
14
+ def expand2square(pil_img, background_color):
15
+ width, height = pil_img.size
16
+ if width == height:
17
+ return pil_img
18
+ elif width > height:
19
+ result = Image.new(pil_img.mode, (width, width), background_color)
20
+ result.paste(pil_img, (0, (width - height) // 2))
21
+ return result
22
+ else:
23
+ result = Image.new(pil_img.mode, (height, height), background_color)
24
+ result.paste(pil_img, ((height - width) // 2, 0))
25
+ return result
26
+
27
+
28
+ def process_images(images, image_processor, model_cfg):
29
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30
+ new_images = []
31
+ if image_aspect_ratio == 'pad':
32
+ for image in images:
33
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34
+ image = image_processor.preprocess(image,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504},return_tensors='pt')['pixel_values'][0]
35
+ # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
36
+
37
+ new_images.append(image)
38
+ else:
39
+ return image_processor(images, return_tensors='pt')['pixel_values']
40
+ if all(x.shape == new_images[0].shape for x in new_images):
41
+ new_images = torch.stack(new_images, dim=0)
42
+ return new_images
43
+
44
+ def process_images_demo(images, image_processor):
45
+ new_images = []
46
+ # image_aspect_ratio = 'pad'
47
+ for image in images:
48
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
49
+ image = image_processor.preprocess(image,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504},return_tensors='pt')['pixel_values'][0]
50
+ # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
51
+
52
+ new_images.append(image)
53
+
54
+ if all(x.shape == new_images[0].shape for x in new_images):
55
+ new_images = torch.stack(new_images, dim=0)
56
+ return new_images
57
+
58
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
59
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
60
+
61
+ def insert_separator(X, sep):
62
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63
+
64
+ input_ids = []
65
+ offset = 0
66
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
67
+ offset = 1
68
+ input_ids.append(prompt_chunks[0][0])
69
+
70
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
71
+ input_ids.extend(x[offset:])
72
+
73
+ if return_tensors is not None:
74
+ if return_tensors == 'pt':
75
+ return torch.tensor(input_ids, dtype=torch.long)
76
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
77
+ return input_ids
78
+
79
+
80
+ def get_model_name_from_path(model_path):
81
+ model_path = model_path.strip("/")
82
+ model_paths = model_path.split("/")
83
+ if model_paths[-1].startswith('checkpoint-'):
84
+ return model_paths[-2] + "_" + model_paths[-1]
85
+ else:
86
+ return model_paths[-1]
87
+
88
+
89
+
90
+
91
+ class KeywordsStoppingCriteria(StoppingCriteria):
92
+ def __init__(self, keywords, tokenizer, input_ids):
93
+ self.keywords = keywords
94
+ self.keyword_ids = []
95
+ self.max_keyword_len = 0
96
+ for keyword in keywords:
97
+ cur_keyword_ids = tokenizer(keyword).input_ids
98
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
99
+ cur_keyword_ids = cur_keyword_ids[1:]
100
+ if len(cur_keyword_ids) > self.max_keyword_len:
101
+ self.max_keyword_len = len(cur_keyword_ids)
102
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
103
+ self.tokenizer = tokenizer
104
+ self.start_len = input_ids.shape[1]
105
+
106
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
107
+ # assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
108
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
109
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
110
+ for keyword_id in self.keyword_ids:
111
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
112
+ return True
113
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
114
+ for keyword in self.keywords:
115
+ if keyword in outputs:
116
+ return True
117
+ return False
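A minimal sketch of how tokenizer_image_token is used (the checkpoint path is a placeholder): the prompt is split on the literal '<image>' placeholder, each text chunk is tokenized normally, and a single IMAGE_TOKEN_INDEX id is spliced in between, to be swapped later for the projected vision-tower features.

    from transformers import AutoTokenizer
    from geochat.mm_utils import tokenizer_image_token
    from geochat.constants import IMAGE_TOKEN_INDEX

    tokenizer = AutoTokenizer.from_pretrained("/path/to/geochat-7B", use_fast=False)
    prompt = "USER: <image>\nClassify the image. ASSISTANT:"
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    # input_ids is a 1-D LongTensor; the eval scripts unsqueeze(0) it before batching.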
geochat/model/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,2 @@
1
+ from .language_model.geochat_llama import GeoChatLlamaForCausalLM, GeoChatConfig
2
+ from .language_model.geochat_mpt import GeoChatMPTForCausalLM, GeoChatMPTConfig
geochat/model/.ipynb_checkpoints/builder-checkpoint.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from geochat.model import *
23
+ from geochat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ if load_8bit:
30
+ kwargs['load_in_8bit'] = True
31
+ elif load_4bit:
32
+ kwargs['load_in_4bit'] = True
33
+ kwargs['quantization_config'] = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type='nf4'
38
+ )
39
+ else:
40
+ kwargs['torch_dtype'] = torch.float16
41
+
42
+ if 'geochat' in model_name.lower():
43
+ # Load LLaVA model
44
+ if 'lora' in model_name.lower() and model_base is None:
45
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
46
+ if 'lora' in model_name.lower() and model_base is not None:
47
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
48
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
49
+ print('Loading Geochat from base model...')
50
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
51
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
52
+ if model.lm_head.weight.shape[0] != token_num:
53
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
54
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
55
+
56
+ print('Loading additional GeoChat weights...')
57
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
58
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
59
+ else:
60
+ # this is probably from HF Hub
61
+ from huggingface_hub import hf_hub_download
62
+ def load_from_hf(repo_id, filename, subfolder=None):
63
+ cache_file = hf_hub_download(
64
+ repo_id=repo_id,
65
+ filename=filename,
66
+ subfolder=subfolder)
67
+ return torch.load(cache_file, map_location='cpu')
68
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
69
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
70
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
71
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
72
+ model.load_state_dict(non_lora_trainables, strict=False)
73
+
74
+ from peft import PeftModel
75
+ print('Loading LoRA weights...')
76
+ model = PeftModel.from_pretrained(model, model_path)
77
+ print('Merging LoRA weights...')
78
+ model = model.merge_and_unload()
79
+ print('Model is loaded...')
80
+ elif model_base is not None:
81
+ # this may be mm projector only
82
+ print('Loading GeoChat from base model...')
83
+ if 'mpt' in model_name.lower():
84
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
85
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
87
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
88
+ model = GeoChatMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
89
+ else:
90
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
92
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93
+
94
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96
+ model.load_state_dict(mm_projector_weights, strict=False)
97
+ else:
98
+ if 'mpt' in model_name.lower():
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
100
+ model = GeoChatMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
101
+ else:
102
+ print("Loading GeoChat......")
103
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
104
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
105
+ else:
106
+ # Load language model
107
+ if model_base is not None:
108
+ # PEFT model
109
+ from peft import PeftModel
110
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
111
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
112
+ print(f"Loading LoRA weights from {model_path}")
113
+ model = PeftModel.from_pretrained(model, model_path)
114
+ print(f"Merging weights")
115
+ model = model.merge_and_unload()
116
+ print('Convert to FP16...')
117
+ model.to(torch.float16)
118
+ else:
119
+ use_fast = False
120
+ if 'mpt' in model_name.lower():
121
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
122
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
123
+ else:
124
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
125
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
126
+
127
+ image_processor = None
128
+
129
+ if 'geochat' in model_name.lower():
130
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
131
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
132
+ if mm_use_im_patch_token:
133
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
134
+ if mm_use_im_start_end:
135
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
136
+ model.resize_token_embeddings(len(tokenizer))
137
+
138
+ vision_tower = model.get_vision_tower()
139
+ if not vision_tower.is_loaded:
140
+ vision_tower.load_model()
141
+ vision_tower.to(device=device, dtype=torch.float16)
142
+ image_processor = vision_tower.image_processor
143
+
144
+ if hasattr(model.config, "max_sequence_length"):
145
+ context_len = model.config.max_sequence_length
146
+ else:
147
+ context_len = 2048
148
+
149
+ return tokenizer, model, image_processor, context_len
geochat/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .language_model.geochat_llama import GeoChatLlamaForCausalLM, GeoChatConfig
2
+ from .language_model.geochat_mpt import GeoChatMPTForCausalLM, GeoChatMPTConfig
geochat/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (338 Bytes).
 
geochat/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (4.7 kB).
 
geochat/model/__pycache__/geochat_arch.cpython-310.pyc ADDED
Binary file (8.26 kB).
 
geochat/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m geochat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from geochat import GeoChatLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = GeoChatLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
geochat/model/builder.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from geochat.model import *
23
+ from geochat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ if load_8bit:
30
+ kwargs['load_in_8bit'] = True
31
+ elif load_4bit:
32
+ kwargs['load_in_4bit'] = True
33
+ kwargs['quantization_config'] = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type='nf4'
38
+ )
39
+ else:
40
+ kwargs['torch_dtype'] = torch.float16
41
+
42
+ if 'geochat' in model_name.lower():
43
+ # Load LLaVA model
44
+ if 'lora' in model_name.lower() and model_base is None:
45
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
46
+ if 'lora' in model_name.lower() and model_base is not None:
47
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
48
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
49
+ print('Loading Geochat from base model...')
50
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
51
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
52
+ if model.lm_head.weight.shape[0] != token_num:
53
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
54
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
55
+
56
+ print('Loading additional GeoChat weights...')
57
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
58
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
59
+ else:
60
+ # this is probably from HF Hub
61
+ from huggingface_hub import hf_hub_download
62
+ def load_from_hf(repo_id, filename, subfolder=None):
63
+ cache_file = hf_hub_download(
64
+ repo_id=repo_id,
65
+ filename=filename,
66
+ subfolder=subfolder)
67
+ return torch.load(cache_file, map_location='cpu')
68
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
69
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
70
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
71
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
72
+ model.load_state_dict(non_lora_trainables, strict=False)
73
+
74
+ from peft import PeftModel
75
+ print('Loading LoRA weights...')
76
+ model = PeftModel.from_pretrained(model, model_path)
77
+ print('Merging LoRA weights...')
78
+ model = model.merge_and_unload()
79
+ print('Model is loaded...')
80
+ elif model_base is not None:
81
+ # this may be mm projector only
82
+ print('Loading GeoChat from base model...')
83
+ if 'mpt' in model_name.lower():
84
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
85
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
87
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
88
+ model = GeoChatMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
89
+ else:
90
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
92
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93
+
94
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96
+ model.load_state_dict(mm_projector_weights, strict=False)
97
+ else:
98
+ if 'mpt' in model_name.lower():
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
100
+ model = GeoChatMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
101
+ else:
102
+ print("Loading GeoChat......")
103
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
104
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
105
+ else:
106
+ # Load language model
107
+ if model_base is not None:
108
+ # PEFT model
109
+ from peft import PeftModel
110
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
111
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
112
+ print(f"Loading LoRA weights from {model_path}")
113
+ model = PeftModel.from_pretrained(model, model_path)
114
+ print(f"Merging weights")
115
+ model = model.merge_and_unload()
116
+ print('Convert to FP16...')
117
+ model.to(torch.float16)
118
+ else:
119
+ use_fast = False
120
+ if 'mpt' in model_name.lower():
121
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
122
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
123
+ else:
124
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
125
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
126
+
127
+ image_processor = None
128
+
129
+ if 'geochat' in model_name.lower():
130
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
131
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
132
+ if mm_use_im_patch_token:
133
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
134
+ if mm_use_im_start_end:
135
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
136
+ model.resize_token_embeddings(len(tokenizer))
137
+
138
+ vision_tower = model.get_vision_tower()
139
+ if not vision_tower.is_loaded:
140
+ vision_tower.load_model()
141
+ vision_tower.to(device=device, dtype=torch.float16)
142
+ image_processor = vision_tower.image_processor
143
+
144
+ if hasattr(model.config, "max_sequence_length"):
145
+ context_len = model.config.max_sequence_length
146
+ else:
147
+ context_len = 2048
148
+
149
+ return tokenizer, model, image_processor, context_len
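A sketch of the loading path the evaluation scripts rely on (the checkpoint path below is a placeholder). Note that get_model_name_from_path must yield a name containing "geochat" to take the GeoChat branch above, and a LoRA checkpoint additionally needs model_base pointing at the base LLM:

    from geochat.mm_utils import get_model_name_from_path
    from geochat.model.builder import load_pretrained_model

    model_path = "/path/to/geochat-7B"
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, None, model_name)  # model_base=None for a merged checkpoint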
geochat/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Usage:
3
+ python3 -m geochat.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from geochat.model import *
10
+ from geochat.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
geochat/model/geochat_arch.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from geochat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+
27
+ class GeoChatMetaModel:
28
+
29
+ def __init__(self, config):
30
+ super(GeoChatMetaModel, self).__init__(config)
31
+
32
+ if hasattr(config, "mm_vision_tower"):
33
+ self.vision_tower = build_vision_tower(config, delay_load=True)
34
+ self.mm_projector = build_vision_projector(config)
35
+
36
+ def get_vision_tower(self):
37
+ vision_tower = getattr(self, 'vision_tower', None)
38
+ if type(vision_tower) is list:
39
+ vision_tower = vision_tower[0]
40
+ return vision_tower
41
+
42
+ def initialize_vision_modules(self, model_args, fsdp=None):
43
+ vision_tower = model_args.vision_tower
44
+ mm_vision_select_layer = model_args.mm_vision_select_layer
45
+ mm_vision_select_feature = model_args.mm_vision_select_feature
46
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
47
+
48
+ self.config.mm_vision_tower = vision_tower
49
+
50
+ if self.get_vision_tower() is None:
51
+ vision_tower = build_vision_tower(model_args)
52
+
53
+ if fsdp is not None and len(fsdp) > 0:
54
+ self.vision_tower = [vision_tower]
55
+ else:
56
+ self.vision_tower = vision_tower
57
+ else:
58
+ if fsdp is not None and len(fsdp) > 0:
59
+ vision_tower = self.vision_tower[0]
60
+ else:
61
+ vision_tower = self.vision_tower
62
+ vision_tower.load_model()
63
+
64
+ self.config.use_mm_proj = True
65
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66
+ self.config.mm_hidden_size = vision_tower.hidden_size
67
+ self.config.mm_vision_select_layer = mm_vision_select_layer
68
+ self.config.mm_vision_select_feature = mm_vision_select_feature
69
+
70
+ if getattr(self, 'mm_projector', None) is None:
71
+ self.mm_projector = build_vision_projector(self.config)
72
+ # print(mm_projector)
73
+
74
+
75
+ if pretrain_mm_mlp_adapter is not None:
76
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
77
+
78
+ def get_w(weights, keyword):
79
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
80
+
81
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
82
+
83
+
84
+
85
+
86
+ class GeoChatMetaForCausalLM(ABC):
87
+
88
+ @abstractmethod
89
+ def get_model(self):
90
+ pass
91
+
92
+ def get_vision_tower(self):
93
+ return self.get_model().get_vision_tower()
94
+
95
+ def encode_images(self, images):
96
+ image_features = self.get_model().get_vision_tower()(images)
97
+ image_features = self.get_model().mm_projector(image_features)
98
+ return image_features
99
+
100
+ def prepare_inputs_labels_for_multimodal(
101
+ self, input_ids, attention_mask, past_key_values, labels, images
102
+ ):
103
+ vision_tower = self.get_vision_tower()
104
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
105
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
106
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
107
+ return input_ids, attention_mask, past_key_values, None, labels
108
+
109
+ if type(images) is list or images.ndim == 5:
110
+ concat_images = torch.cat([image for image in images], dim=0)
111
+ image_features = self.encode_images(concat_images)
112
+ split_sizes = [image.shape[0] for image in images]
113
+ image_features = torch.split(image_features, split_sizes, dim=0)
114
+ image_features = [x.flatten(0, 1) for x in image_features]
115
+ else:
116
+ image_features = self.encode_images(images)
117
+
118
+ new_input_embeds = []
119
+ new_labels = [] if labels is not None else None
120
+ cur_image_idx = 0
121
+ for batch_idx, cur_input_ids in enumerate(input_ids):
122
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
123
+ # multimodal LLM, but the current sample is not multimodal
124
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
125
+ half_len = cur_input_ids.shape[0] // 2
126
+ cur_image_features = image_features[cur_image_idx]
127
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
128
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
129
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
130
+ new_input_embeds.append(cur_input_embeds)
131
+ if labels is not None:
132
+ new_labels.append(labels[batch_idx])
133
+ cur_image_idx += 1
134
+ continue
135
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
136
+ cur_new_input_embeds = []
137
+ if labels is not None:
138
+ cur_labels = labels[batch_idx]
139
+ cur_new_labels = []
140
+ assert cur_labels.shape == cur_input_ids.shape
141
+ while image_token_indices.numel() > 0:
142
+ cur_image_features = image_features[cur_image_idx]
143
+ image_token_start = image_token_indices[0]
144
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
145
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
146
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
147
+ cur_new_input_embeds.append(cur_image_features)
148
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
149
+ if labels is not None:
150
+ cur_new_labels.append(cur_labels[:image_token_start])
151
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
152
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
153
+ cur_labels = cur_labels[image_token_start+2:]
154
+ else:
155
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
156
+ cur_new_input_embeds.append(cur_image_features)
157
+ if labels is not None:
158
+ cur_new_labels.append(cur_labels[:image_token_start])
159
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
160
+ cur_labels = cur_labels[image_token_start+1:]
161
+ cur_image_idx += 1
162
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
163
+ cur_input_ids = cur_input_ids[image_token_start+2:]
164
+ else:
165
+ cur_input_ids = cur_input_ids[image_token_start+1:]
166
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
167
+ if cur_input_ids.numel() > 0:
168
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
169
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
170
+ else:
171
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
172
+ if labels is not None:
173
+ cur_new_labels.append(cur_labels)
174
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
175
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
176
+ new_input_embeds.append(cur_new_input_embeds)
177
+ if labels is not None:
178
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
179
+ new_labels.append(cur_new_labels)
180
+
181
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
182
+ max_len = max(x.shape[0] for x in new_input_embeds)
183
+
184
+ new_input_embeds_align = []
185
+ for cur_new_embed in new_input_embeds:
186
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
187
+ new_input_embeds_align.append(cur_new_embed)
188
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
189
+
190
+ if labels is not None:
191
+ new_labels_align = []
192
+ _new_labels = new_labels
193
+ for cur_new_label in new_labels:
194
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
195
+ new_labels_align.append(cur_new_label)
196
+ new_labels = torch.stack(new_labels_align, dim=0)
197
+
198
+ if attention_mask is not None:
199
+ new_attention_mask = []
200
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
201
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
202
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
203
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
204
+ new_attention_mask.append(cur_new_attention_mask)
205
+ attention_mask = torch.stack(new_attention_mask, dim=0)
206
+ assert attention_mask.shape == new_labels.shape
207
+ else:
208
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
209
+ if labels is not None:
210
+ new_labels = torch.stack(new_labels, dim=0)
211
+
212
+ if attention_mask is not None:
213
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
214
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
215
+ assert attention_mask.shape == new_input_embeds.shape[:2]
216
+
217
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
218
+
219
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
220
+ if model_args.mm_use_im_patch_token:
221
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
222
+ self.resize_token_embeddings(len(tokenizer))
223
+
224
+ if model_args.mm_use_im_start_end:
225
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
226
+ self.resize_token_embeddings(len(tokenizer))
227
+
228
+ if num_new_tokens > 0:
229
+ input_embeddings = self.get_input_embeddings().weight.data
230
+ output_embeddings = self.get_output_embeddings().weight.data
231
+
232
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
233
+ dim=0, keepdim=True)
234
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
235
+ dim=0, keepdim=True)
236
+
237
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
238
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
239
+
240
+ if model_args.tune_mm_mlp_adapter:
241
+ for p in self.get_input_embeddings().parameters():
242
+ p.requires_grad = True
243
+ for p in self.get_output_embeddings().parameters():
244
+ p.requires_grad = False
245
+
246
+ if model_args.pretrain_mm_mlp_adapter:
247
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
248
+ print(mm_projector_weights)
249
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
250
+ assert num_new_tokens == 2
251
+ if input_embeddings.shape == embed_tokens_weight.shape:
252
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
253
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
254
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
255
+ else:
256
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
257
+ elif model_args.mm_use_im_patch_token:
258
+ if model_args.tune_mm_mlp_adapter:
259
+ for p in self.get_input_embeddings().parameters():
260
+ p.requires_grad = False
261
+ for p in self.get_output_embeddings().parameters():
262
+ p.requires_grad = False
geochat/model/language_model/.ipynb_checkpoints/geochat_llama-checkpoint.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM
28
+
29
+
30
+ class GeoChatConfig(LlamaConfig):
31
+ model_type = "geochat"
32
+
33
+
34
+ class GeoChatLlamaModel(GeoChatMetaModel, LlamaModel):
35
+ config_class = GeoChatConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(GeoChatLlamaModel, self).__init__(config)
39
+
40
+
41
+ class GeoChatLlamaForCausalLM(LlamaForCausalLM, GeoChatMetaForCausalLM):
42
+ config_class = GeoChatConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = GeoChatLlamaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ labels: Optional[torch.LongTensor] = None,
63
+ use_cache: Optional[bool] = None,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ images: Optional[torch.FloatTensor] = None,
67
+ return_dict: Optional[bool] = None,
68
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+ output_hidden_states = (
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72
+ )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
76
+
77
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
78
+ outputs = self.model(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict
87
+ )
88
+
89
+ hidden_states = outputs[0]
90
+ logits = self.lm_head(hidden_states)
91
+
92
+ loss = None
93
+ if labels is not None:
94
+ # Shift so that tokens < n predict n
95
+ shift_logits = logits[..., :-1, :].contiguous()
96
+ shift_labels = labels[..., 1:].contiguous()
97
+ # Flatten the tokens
98
+ loss_fct = CrossEntropyLoss()
99
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
100
+ shift_labels = shift_labels.view(-1)
101
+ # Enable model/pipeline parallelism
102
+ shift_labels = shift_labels.to(shift_logits.device)
103
+ loss = loss_fct(shift_logits, shift_labels)
104
+
105
+ if not return_dict:
106
+ output = (logits,) + outputs[1:]
107
+ return (loss,) + output if loss is not None else output
108
+
109
+ return CausalLMOutputWithPast(
110
+ loss=loss,
111
+ logits=logits,
112
+ past_key_values=outputs.past_key_values,
113
+ hidden_states=outputs.hidden_states,
114
+ attentions=outputs.attentions,
115
+ )
116
+
117
+ def prepare_inputs_for_generation(
118
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119
+ ):
120
+ if past_key_values:
121
+ input_ids = input_ids[:, -1:]
122
+
123
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
124
+ if inputs_embeds is not None and past_key_values is None:
125
+ model_inputs = {"inputs_embeds": inputs_embeds}
126
+ else:
127
+ model_inputs = {"input_ids": input_ids}
128
+
129
+ model_inputs.update(
130
+ {
131
+ "past_key_values": past_key_values,
132
+ "use_cache": kwargs.get("use_cache"),
133
+ "attention_mask": attention_mask,
134
+ "images": kwargs.get("images", None),
135
+ }
136
+ )
137
+ return model_inputs
138
+
139
+ AutoConfig.register("geochat", GeoChatConfig)
140
+ AutoModelForCausalLM.register(GeoChatConfig, GeoChatLlamaForCausalLM)
geochat/model/language_model/__pycache__/geochat_llama.cpython-310.pyc ADDED
Binary file (3.56 kB).
 
geochat/model/language_model/__pycache__/geochat_mpt.cpython-310.pyc ADDED
Binary file (4.71 kB).
 
geochat/model/language_model/geochat_llama.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM
28
+
29
+
30
+ class GeoChatConfig(LlamaConfig):
31
+ model_type = "geochat"
32
+
33
+
34
+ class GeoChatLlamaModel(GeoChatMetaModel, LlamaModel):
35
+ config_class = GeoChatConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(GeoChatLlamaModel, self).__init__(config)
39
+
40
+
41
+ class GeoChatLlamaForCausalLM(LlamaForCausalLM, GeoChatMetaForCausalLM):
42
+ config_class = GeoChatConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = GeoChatLlamaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ labels: Optional[torch.LongTensor] = None,
63
+ use_cache: Optional[bool] = None,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ images: Optional[torch.FloatTensor] = None,
67
+ return_dict: Optional[bool] = None,
68
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+ output_hidden_states = (
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72
+ )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model/pipeline parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
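+     # Called by generate() at every decoding step. Besides the standard trimming of
+     # input_ids once a KV cache exists, the GeoChat-specific part is threading the
+     # `images` kwarg through to forward().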
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+             }
+         )
+         return model_inputs
+
+ AutoConfig.register("geochat", GeoChatConfig)
+ AutoModelForCausalLM.register(GeoChatConfig, GeoChatLlamaForCausalLM)
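
Note: the two register() calls above are what let a saved GeoChat checkpoint come back through the generic Transformers Auto API once this module has been imported. A minimal loading sketch, assuming a local checkpoint directory (the path below is a placeholder) that ships a standard LLaMA tokenizer config:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import geochat.model.language_model.geochat_llama  # noqa: F401 (executes the register() calls)

path = "./geochat-checkpoint"  # placeholder, not a path shipped with this repo
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
# `model` resolves to GeoChatLlamaForCausalLM via config.model_type == "geochat".
# At inference time the preprocessed image batch is passed as
# model.generate(..., images=image_tensor) and is threaded through
# prepare_inputs_for_generation on every decoding step.
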
geochat/model/language_model/geochat_mpt.py ADDED
@@ -0,0 +1,113 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple
+ import warnings
+
+ import torch
+ import torch.nn.functional as F
+ import math
+
+ from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
+ from geochat.model.geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM
+
+
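+ # Same wrapper pattern as geochat_llama.py, but on top of MPT, whose module names
+ # differ from LLaMA's: the backbone is `transformer`, the token embedding is `wte`,
+ # and the model width is `d_model` (aliased to `hidden_size` below).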
+ class GeoChatMPTConfig(MPTConfig):
+     model_type = "geochat_mpt"
+
+
+ class GeoChatMPTModel(GeoChatMetaModel, MPTModel):
+     config_class = GeoChatMPTConfig
+
+     def __init__(self, config: MPTConfig):
+         config.hidden_size = config.d_model
+         super(GeoChatMPTModel, self).__init__(config)
+
+     def embed_tokens(self, x):
+         return self.wte(x)
+
+
+ class GeoChatMPTForCausalLM(MPTForCausalLM, GeoChatMetaForCausalLM):
+     config_class = GeoChatMPTConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super(MPTForCausalLM, self).__init__(config)
+
+         if not config.tie_word_embeddings:
+             raise ValueError('MPTForCausalLM only supports tied word embeddings')
+         self.transformer = GeoChatMPTModel(config)
+         self.logit_scale = None
+         if config.logit_scale is not None:
+             logit_scale = config.logit_scale
+             if isinstance(logit_scale, str):
+                 if logit_scale == 'inv_sqrt_d_model':
+                     logit_scale = 1 / math.sqrt(config.d_model)
+                 else:
+                     raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+             self.logit_scale = logit_scale
+
+     def get_model(self):
+         return self.transformer
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, GeoChatMPTModel):
+             module.gradient_checkpointing = value
+
+     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+         outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+         # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
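+         # MPT ties input and output embeddings, so logits are computed against the
+         # token-embedding matrix itself (moved to its device for multi-GPU inference).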
+         logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
+         if self.logit_scale is not None:
+             if self.logit_scale == 0:
+                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
+             logits *= self.logit_scale
+         loss = None
+         if labels is not None:
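+             # Next-token objective: shift labels left by one and mask the final
+             # position with -100 so it is ignored by the cross-entropy loss.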
+             labels = torch.roll(labels, shifts=-1)
+             labels[:, -1] = -100
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
+
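+     # generate() hook: rejects inputs_embeds and right padding, builds the
+     # prefix-LM / sequence-id masks when the config requires them, and forwards
+     # the `images` kwarg alongside the usual cache bookkeeping.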
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         if inputs_embeds is not None:
+             raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
+         attention_mask = kwargs['attention_mask'].bool()
+         if attention_mask[:, -1].sum() != attention_mask.shape[0]:
+             raise NotImplementedError('MPT does not support generation with right padding.')
+         if self.transformer.attn_uses_sequence_id and self.training:
+             sequence_id = torch.zeros_like(input_ids[:1])
+         else:
+             sequence_id = None
+         if past_key_values is not None:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+         if self.transformer.prefix_lm:
+             prefix_mask = torch.ones_like(attention_mask)
+             if kwargs.get('use_cache') == False:
+                 raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
+         else:
+             prefix_mask = None
+         return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
+
+
+ AutoConfig.register("geochat_mpt", GeoChatMPTConfig)
+ AutoModelForCausalLM.register(GeoChatMPTConfig, GeoChatMPTForCausalLM)
geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc ADDED
Binary file (2.24 kB). View file