Move inputs to model's device
Browse files- src/rad_dino/__init__.py +6 -2
src/rad_dino/__init__.py
CHANGED
|
@@ -6,7 +6,7 @@ from torch import Tensor
|
|
| 6 |
from torch import nn
|
| 7 |
from transformers import AutoImageProcessor
|
| 8 |
from transformers import AutoModel
|
| 9 |
-
from transformers.
|
| 10 |
|
| 11 |
|
| 12 |
__version__ = "0.1.0"
|
|
@@ -25,6 +25,10 @@ class RadDino(nn.Module):
|
|
| 25 |
self.model = AutoModel.from_pretrained(self._REPO).eval()
|
| 26 |
self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
|
| 29 |
return self.processor(image_or_images, return_tensors="pt")
|
| 30 |
|
|
@@ -53,7 +57,7 @@ class RadDino(nn.Module):
|
|
| 53 |
self,
|
| 54 |
image_or_images: TypeInputImages,
|
| 55 |
) -> tuple[TypeClsToken, TypePatchTokens]:
|
| 56 |
-
inputs = self.preprocess(image_or_images)
|
| 57 |
cls_token, patch_tokens_flat = self.encode(inputs)
|
| 58 |
patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
|
| 59 |
return cls_token, patch_tokens
|
|
|
|
| 6 |
from torch import nn
|
| 7 |
from transformers import AutoImageProcessor
|
| 8 |
from transformers import AutoModel
|
| 9 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 10 |
|
| 11 |
|
| 12 |
__version__ = "0.1.0"
|
|
|
|
| 25 |
self.model = AutoModel.from_pretrained(self._REPO).eval()
|
| 26 |
self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
|
| 27 |
|
| 28 |
+
@property
|
| 29 |
+
def device(self) -> torch.device:
|
| 30 |
+
return next(self.model.parameters()).device
|
| 31 |
+
|
| 32 |
def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
|
| 33 |
return self.processor(image_or_images, return_tensors="pt")
|
| 34 |
|
|
|
|
| 57 |
self,
|
| 58 |
image_or_images: TypeInputImages,
|
| 59 |
) -> tuple[TypeClsToken, TypePatchTokens]:
|
| 60 |
+
inputs = self.preprocess(image_or_images).to(self.device)
|
| 61 |
cls_token, patch_tokens_flat = self.encode(inputs)
|
| 62 |
patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
|
| 63 |
return cls_token, patch_tokens
|