Spaces:
Runtime error
Runtime error
import torch | |
from mmengine.model import BaseModule | |
from torch import nn | |
from mmseg.registry import MODELS | |
class DINOv2(nn.Module):
    """Use a DINOv2 ViT backbone and expose its last four blocks as feature maps.

    The backbone is built from a local torch.hub checkout
    (``torchhub/facebookresearch_dinov2_main``) with ``pretrained=False``;
    weights can optionally be loaded from a checkpoint whose backbone entries
    are stored under the ``pretrained.`` key prefix.

    Args:
        version: Backbone size. Only ``'large'`` (``dinov2_vitl14``) is
            supported.
        freeze: If True, the backbone runs under ``torch.no_grad()`` in
            ``forward`` and its parameters are marked ``requires_grad=False``.
        load_from: Optional path to a checkpoint file. Only entries whose key
            starts with ``pretrained.`` are kept (with that prefix removed) and
            loaded strictly into the backbone.

    Raises:
        NotImplementedError: If ``version`` is not ``'large'``.
    """

    # Key prefix under which the backbone weights live in ``load_from``.
    _CKPT_PREFIX = 'pretrained.'

    def __init__(self, version='large', freeze=False, load_from=None):
        super().__init__()
        if version == 'large':
            self.dinov2 = torch.hub.load(
                'torchhub/facebookresearch_dinov2_main',
                'dinov2_vitl14',
                source='local',
                pretrained=False,
            )
        else:
            raise NotImplementedError
        if load_from is not None:
            # NOTE(review): torch.load unpickles the file, so ``load_from``
            # must point at a trusted checkpoint.
            checkpoint = torch.load(load_from, map_location='cpu')
            self.dinov2.load_state_dict(
                self._extract_backbone_state_dict(checkpoint))
        self.freeze = freeze
        if freeze:
            # forward() never lets gradients reach the backbone, so mark its
            # parameters frozen as well; otherwise wrappers that check for
            # unused parameters (e.g. DistributedDataParallel) expect
            # gradients for them.
            self.dinov2.requires_grad_(False)

    @classmethod
    def _extract_backbone_state_dict(cls, state_dict):
        """Return entries whose key starts with the backbone prefix, prefix removed.

        Keys that merely contain the prefix elsewhere (e.g. a head's own
        ``...pretrained...`` entries) are ignored rather than mangled.
        """
        prefix = cls._CKPT_PREFIX
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, inputs):
        """Return the outputs of the last four backbone blocks as feature maps.

        Args:
            inputs: 4-D image batch ``(B, C_in, h, w)``. ``h`` and ``w`` are
                assumed to be multiples of the backbone patch size.

        Returns:
            list[Tensor]: One ``(B, C, h // p, w // p)`` tensor per block,
            where ``p`` is the patch size and ``C`` the token embedding dim.
        """
        B, _, h, w = inputs.shape
        if self.freeze:
            with torch.no_grad():
                features = self.dinov2.get_intermediate_layers(inputs, 4)
        else:
            features = self.dinov2.get_intermediate_layers(inputs, 4)
        # dinov2_vitl14 uses 14x14 patches; prefer the model's own value when
        # it exposes one so other patch sizes reshape correctly.
        patch = getattr(self.dinov2, 'patch_size', 14)
        if isinstance(patch, (tuple, list)):
            patch_h, patch_w = patch
        else:
            patch_h = patch_w = patch
        grid_h, grid_w = h // patch_h, w // patch_w
        outs = []
        for feature in features:
            # (B, N, C) token sequence -> (B, C, grid_h, grid_w) feature map.
            C = feature.shape[-1]
            outs.append(
                feature.permute(0, 2, 1).reshape(B, C, grid_h, grid_w).contiguous()
            )
        return outs