Update modeling_llava_qwen2.py
modeling_llava_qwen2.py (+5, -1): import the spaces package and decorate the GPU-bound entry points (SigLipVisionTower.forward, encode_images, and prepare_inputs_labels_for_multimodal) with @spaces.GPU.
@@ -12,6 +12,7 @@ from PIL import Image
 import torch.utils.checkpoint
 from torch import nn
 import torch
+import spaces
 from transformers.image_processing_utils import BatchFeature, get_size_dict
 from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
 from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
@@ -534,6 +535,7 @@ class SigLipVisionTower(nn.Module):
         self.is_loaded = True
 
     @torch.no_grad()
+    @spaces.GPU
     def forward(self, images):
         if type(images) is list:
             image_features = []
@@ -659,11 +661,13 @@ class LlavaMetaForCausalLM(ABC):
     def get_vision_tower(self):
         return self.get_model().get_vision_tower()
 
+    @spaces.GPU
     def encode_images(self, images):
         image_features = self.get_model().get_vision_tower()(images)
         image_features = self.get_model().mm_projector(image_features)
         return image_features
-
+
+    @spaces.GPU
     def prepare_inputs_labels_for_multimodal(
         self, input_ids, position_ids, attention_mask, past_key_values, labels, images
     ):
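Note: @spaces.GPU comes from the Hugging Face spaces package used on ZeroGPU Spaces; it attaches a GPU only for the duration of the decorated call and is effectively a no-op when run outside a Space. A minimal sketch of the pattern this commit applies, with a hypothetical TinyTower standing in for the SigLip vision tower:

# Minimal sketch, not the repository's code: TinyTower is a hypothetical
# stand-in that shows where the decorators sit in this commit.
import spaces
import torch
from torch import nn


class TinyTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    @torch.no_grad()   # inference only, as in SigLipVisionTower.forward
    @spaces.GPU        # on ZeroGPU, a device is attached while this call runs
    def forward(self, images):
        return self.proj(images)


tower = TinyTower()
features = tower(torch.randn(2, 16))
print(features.shape)  # torch.Size([2, 4])

Decorating only the hot paths, rather than holding a GPU for the whole process, is presumably the point of the change: the Space stays CPU-resident between requests and pays for a GPU allocation only on the calls that actually need CUDA.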