Upload modeling_sa2va_chat.py with huggingface_hub

modeling_sa2va_chat.py: +130 -106 (CHANGED)
@@ -594,116 +594,137 @@ class Sa2VAChatModel(PreTrainedModel):
         assert tokenizer
         self.preparing_for_generation(tokenizer=tokenizer)

-            input_dict['vp_overall_mask'] = None
         else:
-            images = dynamic_preprocess(image, self.min_dynamic_patch,
-                                        self.max_dynamic_patch,
-                                        self.image_size, self.use_thumbnail)
-
-            if mask_prompts is not None:
-                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
-                input_dict['vp_overall_mask'] = vp_overall_mask
-            else:
                 input_dict['vp_overall_mask'] = None

-            region_pixels = []
-            for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
-
-            vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
-            for i in range(len(mask_prompts[0])):
-                vp_token_str = vp_token_str + \
-                    f"region{i + 1}" + self.VP_START_TOKEN + \
-                    self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
-                    self.VP_END_TOKEN
-                if i == len(mask_prompts[0]) - 1:
-                    vp_token_str = vp_token_str + '.\n'
                 else:
-        else:
-            vp_token_str = ''
-
-        image_token_str = f'{self.IMG_START_TOKEN}' \
-                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
-                          f'{self.IMG_END_TOKEN}'
-        image_token_str = image_token_str + '\n'
-        image_token_str = image_token_str * num_frames
-        image_token_str = image_token_str.strip()
-
-        ret_masks = []
-
-        if '<image>' in text or mask_prompts is not None:
-            assert past_text is None or len(past_text) == 0
-            text = text.replace('<image>', image_token_str + vp_token_str)
-        input_text = ''
-        input_text += self.template['INSTRUCTION'].format(
-            input=text, round=1, bot_name=self.bot_name)
-        input_text = past_text + input_text
-        ids = self.tokenizer.encode(input_text)
-        ret_past_text = self.tokenizer.decode(ids)
-        ids = torch.tensor(ids).cuda().unsqueeze(0)

         generate_output = self.generate(
             **mm_inputs,
@@ -716,8 +737,10 @@ class Sa2VAChatModel(PreTrainedModel):
         )
         predict = self.tokenizer.decode(
             generate_output.sequences[0], skip_special_tokens=False).strip()
         # if have seg result, find the seg hidden states
         hidden_states = generate_output.hidden_states
         last_hidden_states = [item[-1][0] for item in hidden_states]
@@ -739,7 +762,8 @@ class Sa2VAChatModel(PreTrainedModel):
             masks = masks.sigmoid() > 0.5
             masks = masks.cpu().numpy()
             ret_masks.append(masks)

 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id

@@ -594,116 +594,137 @@ class Sa2VAChatModel(PreTrainedModel):
         assert tokenizer
         self.preparing_for_generation(tokenizer=tokenizer)

+        if image is None and video is None and '<image>' not in past_text:
+            text = text.replace('<image>', "")
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': None,
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': None,
+                'vp_overall_mask': None,
+            }
+            ret_masks = []
         else:
+            input_dict = {}
+            if video is not None:
+                pixel_values = []
+                extra_pixel_values = []
+                ori_image_size = video[0].size
+                for frame_idx, frame_image in enumerate(video):
+                    assert ori_image_size == frame_image.size
+                    g_image = np.array(frame_image)  # for grounding
+                    g_image = self.extra_image_processor.apply_image(g_image)
+                    g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                    extra_pixel_values.append(g_image)
+                    if frame_idx < 5:
+                        img = self.transformer(frame_image)
+                        pixel_values.append(img)
+
+                pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+                num_image_tokens = self.patch_token
+                num_frames = len(pixel_values)

                 input_dict['vp_overall_mask'] = None
+            else:
+                ori_image_size = image.size

+                # prepare grounding images
+                g_image = np.array(image)  # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+                extra_pixel_values = [g_pixel_values]
+                g_pixel_values = torch.stack([
+                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+                ]).to(self.torch_dtype)
+
+                images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                            self.max_dynamic_patch,
+                                            self.image_size, self.use_thumbnail)
+
+                if mask_prompts is not None:
+                    vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                    input_dict['vp_overall_mask'] = vp_overall_mask
                 else:
+                    input_dict['vp_overall_mask'] = None

+                pixel_values = [self.transformer(image) for image in images]
+                pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+                num_image_tokens = pixel_values.shape[0] * self.patch_token
+                num_frames = 1
+            input_dict['g_pixel_values'] = g_pixel_values
+            input_dict['pixel_values'] = pixel_values

+            if mask_prompts is not None:
+                # reshape mask prompts to feature size
+                mask_prompts = [torch.Tensor(item).to(pixel_values.device) for item in mask_prompts]
+                mask_prompts = [F.interpolate(
+                    item.unsqueeze(0),
+                    size=(int(self.image_size // self.patch_size * self.downsample_ratio),
+                          int(self.image_size // self.patch_size * self.downsample_ratio)),
+                    mode='nearest').squeeze(0) for item in mask_prompts]
+                region_pixels = []
+                for mask_prompt in mask_prompts[0]:
+                    region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
+
+                vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
+                for i in range(len(mask_prompts[0])):
+                    vp_token_str = vp_token_str + \
+                        f"region{i + 1}" + self.VP_START_TOKEN + \
+                        self.IMG_CONTEXT_TOKEN * region_pixels[i] + \
+                        self.VP_END_TOKEN
+                    if i == len(mask_prompts[0]) - 1:
+                        vp_token_str = vp_token_str + '.\n'
+                    else:
+                        vp_token_str = vp_token_str + ', '
+            else:
+                vp_token_str = ''
+
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            image_token_str = image_token_str + '\n'
+            image_token_str = image_token_str * num_frames
+            image_token_str = image_token_str.strip()
+
+            ret_masks = []
+
+            if '<image>' in text or mask_prompts is not None:
+                assert past_text is None or len(past_text) == 0
+                text = text.replace('<image>', image_token_str + vp_token_str)
+            input_text = ''
+            input_text += self.template['INSTRUCTION'].format(
+                input=text, round=1, bot_name=self.bot_name)
+            input_text = past_text + input_text
+            ids = self.tokenizer.encode(input_text)
+            ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+            attention_mask = torch.ones_like(ids, dtype=torch.bool)
+
+            mm_inputs = {
+                'pixel_values': input_dict['pixel_values'],
+                'input_ids': ids,
+                'attention_mask': attention_mask,
+                'position_ids': None,
+                'past_key_values': None,
+                'labels': None,
+                'prompt_masks': mask_prompts,
+                'vp_overall_mask': input_dict['vp_overall_mask'],
+            }

         generate_output = self.generate(
             **mm_inputs,
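
The mask-prompt branch above shrinks each binary region mask down to the vision encoder's feature grid and then emits one IMG_CONTEXT_TOKEN per surviving cell, so region_pixels[i] is exactly the token budget allotted to region i in vp_token_str. Below is a minimal sketch of that bookkeeping; the concrete values for image_size, patch_size, and downsample_ratio are assumptions (in the model they come from self.image_size, self.patch_size, and self.downsample_ratio and are not fixed by this diff).

import torch
import torch.nn.functional as F

# Assumed InternVL-style values; replace with the model's own attributes.
image_size, patch_size, downsample_ratio = 448, 14, 0.5
grid = int(image_size // patch_size * downsample_ratio)  # 448 // 14 * 0.5 -> 16x16 feature grid

# One binary region mask at input resolution, shaped (num_regions, H, W).
mask_prompt = torch.zeros(1, image_size, image_size)
mask_prompt[:, 100:200, 120:260] = 1.0

# Same resize the diff applies: nearest-neighbour down to the feature grid.
small = F.interpolate(mask_prompt.unsqueeze(0), size=(grid, grid), mode='nearest').squeeze(0)

# Number of active cells == number of IMG_CONTEXT_TOKEN placeholders for this region.
region_pixels = small.bool().to(torch.int64).sum()
print(grid, int(region_pixels))
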

@@ -716,8 +737,10 @@ class Sa2VAChatModel(PreTrainedModel):
         )
         predict = self.tokenizer.decode(
             generate_output.sequences[0], skip_special_tokens=False).strip()
+
+        if image is None and video is None and '<image>' not in past_text:
+            return {'prediction': predict, 'prediction_masks': ret_masks, }
+
         # if have seg result, find the seg hidden states
         hidden_states = generate_output.hidden_states
         last_hidden_states = [item[-1][0] for item in hidden_states]

@@ -739,7 +762,8 @@ class Sa2VAChatModel(PreTrainedModel):
             masks = masks.sigmoid() > 0.5
             masks = masks.cpu().numpy()
             ret_masks.append(masks)
+
+        return {'prediction': predict, 'prediction_masks': ret_masks,}

 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
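
Taken together, the change adds a text-only path: when neither image nor video is passed and past_text carries no '<image>' placeholder, mm_inputs is built with pixel_values=None and the call returns right after decoding, with an empty prediction_masks list. A hedged usage sketch follows; the method name predict_forward, its keyword names, and the checkpoint id are assumptions based on how Sa2VA chat checkpoints are typically driven, not something this diff shows.

import torch
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id; any Sa2VA checkpoint shipping this modeling file should behave the same.
path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Text-only round: no image, no video, no '<image>' in past_text,
# so the new branch is taken and no masks are produced.
result = model.predict_forward(   # method name assumed, not shown in this diff
    image=None,
    video=None,
    text="Introduce yourself briefly.",
    past_text='',
    tokenizer=tokenizer,
)
print(result['prediction'])        # decoded answer
print(result['prediction_masks'])  # [] on the text-only path
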