charlescxk committed on
Commit
a38a851
·
1 Parent(s): 2c415b6
Files changed (40)
  1. README.md +0 -13
  2. Upsample/__init__.py +1 -0
  3. Upsample/__pycache__/__init__.cpython-38.pyc +0 -0
  4. Upsample/__pycache__/arch_utils.cpython-38.pyc +0 -0
  5. Upsample/__pycache__/model.cpython-38.pyc +0 -0
  6. Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc +0 -0
  7. Upsample/__pycache__/utils.cpython-38.pyc +0 -0
  8. Upsample/arch_utils.py +197 -0
  9. Upsample/model.py +93 -0
  10. Upsample/rrdbnet_arch.py +121 -0
  11. Upsample/utils.py +135 -0
  12. app.py +268 -0
  13. doge.png +0 -0
  14. equation.png +0 -0
  15. janus/__init__.py +31 -0
  16. janus/__pycache__/__init__.cpython-38.pyc +0 -0
  17. janus/models/__init__.py +28 -0
  18. janus/models/__pycache__/__init__.cpython-38.pyc +0 -0
  19. janus/models/__pycache__/clip_encoder.cpython-38.pyc +0 -0
  20. janus/models/__pycache__/image_processing_vlm.cpython-38.pyc +0 -0
  21. janus/models/__pycache__/modeling_vlm.cpython-38.pyc +0 -0
  22. janus/models/__pycache__/processing_vlm.cpython-38.pyc +0 -0
  23. janus/models/__pycache__/projector.cpython-38.pyc +0 -0
  24. janus/models/__pycache__/siglip_vit.cpython-38.pyc +0 -0
  25. janus/models/__pycache__/vq_model.cpython-38.pyc +0 -0
  26. janus/models/clip_encoder.py +122 -0
  27. janus/models/image_processing_vlm.py +208 -0
  28. janus/models/modeling_vlm.py +272 -0
  29. janus/models/processing_vlm.py +418 -0
  30. janus/models/projector.py +100 -0
  31. janus/models/siglip_vit.py +681 -0
  32. janus/models/vq_model.py +527 -0
  33. janus/utils/__init__.py +18 -0
  34. janus/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  35. janus/utils/__pycache__/conversation.cpython-38.pyc +0 -0
  36. janus/utils/__pycache__/io.cpython-38.pyc +0 -0
  37. janus/utils/conversation.py +365 -0
  38. janus/utils/io.py +89 -0
  39. requirements.txt +8 -0
  40. weights/RealESRGAN_x2.pth +3 -0
README.md DELETED
@@ -1,13 +0,0 @@
- ---
- title: Janus Pro 7B
- emoji: 🐢
- colorFrom: purple
- colorTo: gray
- sdk: gradio
- sdk_version: 5.13.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Upsample/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import RealESRGAN
Upsample/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (213 Bytes).
 
Upsample/__pycache__/arch_utils.cpython-38.pyc ADDED
Binary file (7.14 kB).
 
Upsample/__pycache__/model.cpython-38.pyc ADDED
Binary file (3.11 kB).
 
Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc ADDED
Binary file (4.47 kB).
 
Upsample/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.05 kB).
 
Upsample/arch_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ @torch.no_grad()
9
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
10
+ """Initialize network weights.
11
+
12
+ Args:
13
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
14
+ scale (float): Scale initialized weights, especially for residual
15
+ blocks. Default: 1.
16
+ bias_fill (float): The value to fill bias. Default: 0
17
+ kwargs (dict): Other arguments for initialization function.
18
+ """
19
+ if not isinstance(module_list, list):
20
+ module_list = [module_list]
21
+ for module in module_list:
22
+ for m in module.modules():
23
+ if isinstance(m, nn.Conv2d):
24
+ init.kaiming_normal_(m.weight, **kwargs)
25
+ m.weight.data *= scale
26
+ if m.bias is not None:
27
+ m.bias.data.fill_(bias_fill)
28
+ elif isinstance(m, nn.Linear):
29
+ init.kaiming_normal_(m.weight, **kwargs)
30
+ m.weight.data *= scale
31
+ if m.bias is not None:
32
+ m.bias.data.fill_(bias_fill)
33
+ elif isinstance(m, _BatchNorm):
34
+ init.constant_(m.weight, 1)
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+
38
+
39
+ def make_layer(basic_block, num_basic_block, **kwarg):
40
+ """Make layers by stacking the same blocks.
41
+
42
+ Args:
43
+ basic_block (nn.module): nn.module class for basic block.
44
+ num_basic_block (int): number of blocks.
45
+
46
+ Returns:
47
+ nn.Sequential: Stacked blocks in nn.Sequential.
48
+ """
49
+ layers = []
50
+ for _ in range(num_basic_block):
51
+ layers.append(basic_block(**kwarg))
52
+ return nn.Sequential(*layers)
53
+
54
+
55
+ class ResidualBlockNoBN(nn.Module):
56
+ """Residual block without BN.
57
+
58
+ It has a style of:
59
+ ---Conv-ReLU-Conv-+-
60
+ |________________|
61
+
62
+ Args:
63
+ num_feat (int): Channel number of intermediate features.
64
+ Default: 64.
65
+ res_scale (float): Residual scale. Default: 1.
66
+ pytorch_init (bool): If set to True, use pytorch default init,
67
+ otherwise, use default_init_weights. Default: False.
68
+ """
69
+
70
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
71
+ super(ResidualBlockNoBN, self).__init__()
72
+ self.res_scale = res_scale
73
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
74
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
75
+ self.relu = nn.ReLU(inplace=True)
76
+
77
+ if not pytorch_init:
78
+ default_init_weights([self.conv1, self.conv2], 0.1)
79
+
80
+ def forward(self, x):
81
+ identity = x
82
+ out = self.conv2(self.relu(self.conv1(x)))
83
+ return identity + out * self.res_scale
84
+
85
+
86
+ class Upsample(nn.Sequential):
87
+ """Upsample module.
88
+
89
+ Args:
90
+ scale (int): Scale factor. Supported scales: 2^n and 3.
91
+ num_feat (int): Channel number of intermediate features.
92
+ """
93
+
94
+ def __init__(self, scale, num_feat):
95
+ m = []
96
+ if (scale & (scale - 1)) == 0: # scale = 2^n
97
+ for _ in range(int(math.log(scale, 2))):
98
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
99
+ m.append(nn.PixelShuffle(2))
100
+ elif scale == 3:
101
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
102
+ m.append(nn.PixelShuffle(3))
103
+ else:
104
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
105
+ super(Upsample, self).__init__(*m)
106
+
107
+
108
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
109
+ """Warp an image or feature map with optical flow.
110
+
111
+ Args:
112
+ x (Tensor): Tensor with size (n, c, h, w).
113
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
114
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
115
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
116
+ Default: 'zeros'.
117
+ align_corners (bool): Before pytorch 1.3, the default value is
118
+ align_corners=True. After pytorch 1.3, the default value is
119
+ align_corners=False. Here, we use True as the default.
120
+
121
+ Returns:
122
+ Tensor: Warped image or feature map.
123
+ """
124
+ assert x.size()[-2:] == flow.size()[1:3]
125
+ _, _, h, w = x.size()
126
+ # create mesh grid
127
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
128
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
129
+ grid.requires_grad = False
130
+
131
+ vgrid = grid + flow
132
+ # scale grid to [-1,1]
133
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
134
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
135
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
136
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
137
+
138
+ # TODO, what if align_corners=False
139
+ return output
140
+
141
+
142
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
143
+ """Resize a flow according to ratio or shape.
144
+
145
+ Args:
146
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
147
+ size_type (str): 'ratio' or 'shape'.
148
+ sizes (list[int | float]): the ratio for resizing or the final output
149
+ shape.
150
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
151
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
152
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
153
+ ratio > 1.0).
154
+ 2) The order of output_size should be [out_h, out_w].
155
+ interp_mode (str): The mode of interpolation for resizing.
156
+ Default: 'bilinear'.
157
+ align_corners (bool): Whether align corners. Default: False.
158
+
159
+ Returns:
160
+ Tensor: Resized flow.
161
+ """
162
+ _, _, flow_h, flow_w = flow.size()
163
+ if size_type == 'ratio':
164
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
165
+ elif size_type == 'shape':
166
+ output_h, output_w = sizes[0], sizes[1]
167
+ else:
168
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
169
+
170
+ input_flow = flow.clone()
171
+ ratio_h = output_h / flow_h
172
+ ratio_w = output_w / flow_w
173
+ input_flow[:, 0, :, :] *= ratio_w
174
+ input_flow[:, 1, :, :] *= ratio_h
175
+ resized_flow = F.interpolate(
176
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
177
+ return resized_flow
178
+
179
+
180
+ # TODO: may write a cpp file
181
+ def pixel_unshuffle(x, scale):
182
+ """ Pixel unshuffle.
183
+
184
+ Args:
185
+ x (Tensor): Input feature with shape (b, c, hh, hw).
186
+ scale (int): Downsample ratio.
187
+
188
+ Returns:
189
+ Tensor: the pixel unshuffled feature.
190
+ """
191
+ b, c, hh, hw = x.size()
192
+ out_channel = c * (scale**2)
193
+ assert hh % scale == 0 and hw % scale == 0
194
+ h = hh // scale
195
+ w = hw // scale
196
+ x_view = x.view(b, c, h, scale, w, scale)
197
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
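For reference, a minimal shape-check sketch (not part of this commit) for the pixel_unshuffle helper and the Upsample block above, assuming the repository's requirements are installed and the package is importable as Upsample.arch_utils:

import torch
from Upsample.arch_utils import Upsample, pixel_unshuffle

x = torch.randn(1, 3, 64, 64)           # (b, c, h, w)
down = pixel_unshuffle(x, scale=2)      # channels grow by scale**2 -> (1, 12, 32, 32)
up = Upsample(scale=4, num_feat=16)     # scale 2^n is built from repeated PixelShuffle(2) stages
feat = torch.randn(1, 16, 24, 24)
print(down.shape, up(feat).shape)       # (1, 12, 32, 32) and (1, 16, 96, 96)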
Upsample/model.py ADDED
@@ -0,0 +1,93 @@
1
+ import os
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+ from huggingface_hub import hf_hub_url, hf_hub_download
8
+
9
+ from .rrdbnet_arch import RRDBNet
10
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
11
+ unpad_image
12
+
13
+ HF_MODELS = {
14
+ 2: dict(
15
+ repo_id='sberbank-ai/Real-ESRGAN',
16
+ filename='RealESRGAN_x2.pth',
17
+ ),
18
+ 4: dict(
19
+ repo_id='sberbank-ai/Real-ESRGAN',
20
+ filename='RealESRGAN_x4.pth',
21
+ ),
22
+ 8: dict(
23
+ repo_id='sberbank-ai/Real-ESRGAN',
24
+ filename='RealESRGAN_x8.pth',
25
+ ),
26
+ }
27
+
28
+
29
+ class RealESRGAN:
30
+ def __init__(self, device, scale=4):
31
+ self.device = device
32
+ self.scale = scale
33
+ self.model = RRDBNet(
34
+ num_in_ch=3, num_out_ch=3, num_feat=64,
35
+ num_block=23, num_grow_ch=32, scale=scale
36
+ )
37
+
38
+ def load_weights(self, model_path, download=True):
39
+ if not os.path.exists(model_path) and download:
40
+ assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
41
+ config = HF_MODELS[self.scale]
42
+ cache_dir = os.path.dirname(model_path)
43
+ local_filename = os.path.basename(model_path)
44
+ config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
45
+ htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
46
+ filename=config['filename'])
47
+ print(htr)
48
+ # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
49
+ print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
50
+
51
+ loadnet = torch.load(model_path)
52
+ if 'params' in loadnet:
53
+ self.model.load_state_dict(loadnet['params'], strict=True)
54
+ elif 'params_ema' in loadnet:
55
+ self.model.load_state_dict(loadnet['params_ema'], strict=True)
56
+ else:
57
+ self.model.load_state_dict(loadnet, strict=True)
58
+ self.model.eval()
59
+ self.model.to(self.device)
60
+
61
+ # @torch.cuda.amp.autocast()
62
+ def predict(self, lr_image, batch_size=4, patches_size=192,
63
+ padding=24, pad_size=15):
64
+ torch.autocast(device_type=self.device.type)  # NOTE: no effect as written; autocast only applies when used as a context manager ('with torch.autocast(...)')
65
+ scale = self.scale
66
+ device = self.device
67
+ lr_image = np.array(lr_image)
68
+ lr_image = pad_reflect(lr_image, pad_size)
69
+
70
+ patches, p_shape = split_image_into_overlapping_patches(
71
+ lr_image, patch_size=patches_size, padding_size=padding
72
+ )
73
+ img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
74
+
75
+ with torch.no_grad():
76
+ res = self.model(img[0:batch_size])
77
+ for i in range(batch_size, img.shape[0], batch_size):
78
+ res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
79
+
80
+ sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
81
+ np_sr_image = sr_image.numpy()
82
+
83
+ padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
84
+ scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
85
+ np_sr_image = stich_together(
86
+ np_sr_image, padded_image_shape=padded_size_scaled,
87
+ target_shape=scaled_image_shape, padding_size=padding * scale
88
+ )
89
+ sr_img = (np_sr_image * 255).astype(np.uint8)
90
+ sr_img = unpad_image(sr_img, pad_size * scale)
91
+ sr_img = Image.fromarray(sr_img)
92
+
93
+ return sr_img
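For reference, a usage sketch (not part of this commit) of the RealESRGAN wrapper above, mirroring how app.py calls it and assuming the x2 weights added under weights/ in this commit:

import torch
from PIL import Image
from Upsample import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RealESRGAN(device, scale=2)
model.load_weights('weights/RealESRGAN_x2.pth', download=False)

lr = Image.open('doge.png').convert('RGB')   # doge.png is also added by this commit
sr = model.predict(lr)                       # returns a PIL.Image roughly 2x the input resolution
sr.save('doge_x2.png')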
Upsample/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
6
+
7
+
8
+ class ResidualDenseBlock(nn.Module):
9
+ """Residual Dense Block.
10
+
11
+ Used in RRDB block in ESRGAN.
12
+
13
+ Args:
14
+ num_feat (int): Channel number of intermediate features.
15
+ num_grow_ch (int): Channels for each growth.
16
+ """
17
+
18
+ def __init__(self, num_feat=64, num_grow_ch=32):
19
+ super(ResidualDenseBlock, self).__init__()
20
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
21
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
22
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
25
+
26
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27
+
28
+ # initialization
29
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
30
+
31
+ def forward(self, x):
32
+ x1 = self.lrelu(self.conv1(x))
33
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
34
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
35
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
36
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
37
+ # Empirically, we use 0.2 to scale the residual for better performance
38
+ return x5 * 0.2 + x
39
+
40
+
41
+ class RRDB(nn.Module):
42
+ """Residual in Residual Dense Block.
43
+
44
+ Used in RRDB-Net in ESRGAN.
45
+
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ num_grow_ch (int): Channels for each growth.
49
+ """
50
+
51
+ def __init__(self, num_feat, num_grow_ch=32):
52
+ super(RRDB, self).__init__()
53
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
54
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+
57
+ def forward(self, x):
58
+ out = self.rdb1(x)
59
+ out = self.rdb2(out)
60
+ out = self.rdb3(out)
61
+ # Empirically, we use 0.2 to scale the residual for better performance
62
+ return out * 0.2 + x
63
+
64
+
65
+ class RRDBNet(nn.Module):
66
+ """Networks consisting of Residual in Residual Dense Block, which is used
67
+ in ESRGAN.
68
+
69
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
70
+
71
+ We extend ESRGAN for scale x2 and scale x1.
72
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
73
+ We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
74
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
75
+
76
+ Args:
77
+ num_in_ch (int): Channel number of inputs.
78
+ num_out_ch (int): Channel number of outputs.
79
+ num_feat (int): Channel number of intermediate features.
80
+ Default: 64
81
+ num_block (int): Block number in the trunk network. Defaults: 23
82
+ num_grow_ch (int): Channels for each growth. Default: 32.
83
+ """
84
+
85
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
86
+ super(RRDBNet, self).__init__()
87
+ self.scale = scale
88
+ if scale == 2:
89
+ num_in_ch = num_in_ch * 4
90
+ elif scale == 1:
91
+ num_in_ch = num_in_ch * 16
92
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
93
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
94
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
95
+ # upsample
96
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98
+ if scale == 8:
99
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
+
103
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
+
105
+ def forward(self, x):
106
+ if self.scale == 2:
107
+ feat = pixel_unshuffle(x, scale=2)
108
+ elif self.scale == 1:
109
+ feat = pixel_unshuffle(x, scale=4)
110
+ else:
111
+ feat = x
112
+ feat = self.conv_first(feat)
113
+ body_feat = self.conv_body(self.body(feat))
114
+ feat = feat + body_feat
115
+ # upsample
116
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
+ if self.scale == 8:
119
+ feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
120
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
121
+ return out
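For reference, a quick shape check (not part of this commit) for RRDBNet: with scale=2 the input is pixel-unshuffled internally (channels x4, spatial /2) and then upsampled x4, so the output is still 2x the input. num_block is reduced here only to keep the sketch light; the wrapper in model.py uses the default 23.

import torch
from Upsample.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=4, num_grow_ch=32, scale=2)
with torch.no_grad():
    out = net(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 128, 128])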
Upsample/utils.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import io
6
+
7
+
8
+ def pad_reflect(image, pad_size):
9
+ imsize = image.shape
10
+ height, width = imsize[:2]
11
+ new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
12
+ new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
13
+
14
+ new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) # top
15
+ new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) # bottom
16
+ new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1) # left
17
+ new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1) # right
18
+
19
+ return new_img
20
+
21
+
22
+ def unpad_image(image, pad_size):
23
+ return image[pad_size:-pad_size, pad_size:-pad_size, :]
24
+
25
+
26
+ def process_array(image_array, expand=True):
27
+ """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
28
+
29
+ image_batch = image_array / 255.0
30
+ if expand:
31
+ image_batch = np.expand_dims(image_batch, axis=0)
32
+ return image_batch
33
+
34
+
35
+ def process_output(output_tensor):
36
+ """ Transforms the 4-dimensional output tensor into a suitable image format. """
37
+
38
+ sr_img = output_tensor.clip(0, 1) * 255
39
+ sr_img = np.uint8(sr_img)
40
+ return sr_img
41
+
42
+
43
+ def pad_patch(image_patch, padding_size, channel_last=True):
44
+ """ Pads image_patch with with padding_size edge values. """
45
+
46
+ if channel_last:
47
+ return np.pad(
48
+ image_patch,
49
+ ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
50
+ 'edge',
51
+ )
52
+ else:
53
+ return np.pad(
54
+ image_patch,
55
+ ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
56
+ 'edge',
57
+ )
58
+
59
+
60
+ def unpad_patches(image_patches, padding_size):
61
+ return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
62
+
63
+
64
+ def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
65
+ """ Splits the image into partially overlapping patches.
66
+ The patches overlap by padding_size pixels.
67
+ Pads the image twice:
68
+ - first to have a size multiple of the patch size,
69
+ - then to have equal padding at the borders.
70
+ Args:
71
+ image_array: numpy array of the input image.
72
+ patch_size: size of the patches from the original image (without padding).
73
+ padding_size: size of the overlapping area.
74
+ """
75
+
76
+ xmax, ymax, _ = image_array.shape
77
+ x_remainder = xmax % patch_size
78
+ y_remainder = ymax % patch_size
79
+
80
+ # the modulo avoids extending by a full patch_size when the remainder is already 0
81
+ x_extend = (patch_size - x_remainder) % patch_size
82
+ y_extend = (patch_size - y_remainder) % patch_size
83
+
84
+ # make sure the image is divisible into regular patches
85
+ extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
86
+
87
+ # add padding around the image to simplify computations
88
+ padded_image = pad_patch(extended_image, padding_size, channel_last=True)
89
+
90
+ xmax, ymax, _ = padded_image.shape
91
+ patches = []
92
+
93
+ x_lefts = range(padding_size, xmax - padding_size, patch_size)
94
+ y_tops = range(padding_size, ymax - padding_size, patch_size)
95
+
96
+ for x in x_lefts:
97
+ for y in y_tops:
98
+ x_left = x - padding_size
99
+ y_top = y - padding_size
100
+ x_right = x + patch_size + padding_size
101
+ y_bottom = y + patch_size + padding_size
102
+ patch = padded_image[x_left:x_right, y_top:y_bottom, :]
103
+ patches.append(patch)
104
+
105
+ return np.array(patches), padded_image.shape
106
+
107
+
108
+ def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
109
+ """ Reconstruct the image from overlapping patches.
110
+ After scaling, shapes and padding should be scaled too.
111
+ Args:
112
+ patches: patches obtained with split_image_into_overlapping_patches
113
+ padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
114
+ target_shape: shape of the final image
115
+ padding_size: size of the overlapping area.
116
+ """
117
+
118
+ xmax, ymax, _ = padded_image_shape
119
+ patches = unpad_patches(patches, padding_size)
120
+ patch_size = patches.shape[1]
121
+ n_patches_per_row = ymax // patch_size
122
+
123
+ complete_image = np.zeros((xmax, ymax, 3))
124
+
125
+ row = -1
126
+ col = 0
127
+ for i in range(len(patches)):
128
+ if i % n_patches_per_row == 0:
129
+ row += 1
130
+ col = 0
131
+ complete_image[
132
+ row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
133
+ ] = patches[i]
134
+ col += 1
135
+ return complete_image[0: target_shape[0], 0: target_shape[1], :]
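For reference, a round-trip sketch (not part of this commit): the split/stitch helpers above reproduce the input exactly when no model runs in between; in model.py the stitch step instead receives shapes multiplied by the SR scale.

import numpy as np
from Upsample.utils import split_image_into_overlapping_patches, stich_together

img = np.random.randint(0, 256, (200, 300, 3), dtype=np.uint8)
patches, padded_shape = split_image_into_overlapping_patches(img, patch_size=96, padding_size=8)
restored = stich_together(patches, padded_image_shape=padded_shape,
                          target_shape=img.shape, padding_size=8)
assert np.array_equal(restored.astype(np.uint8), img)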
app.py ADDED
@@ -0,0 +1,268 @@
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoConfig, AutoModelForCausalLM
4
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
5
+ from janus.utils.io import load_pil_images
6
+ from PIL import Image
7
+
8
+ import numpy as np
9
+ import os
10
+ import time
11
+ from Upsample import RealESRGAN
12
+ import spaces # Import spaces for ZeroGPU compatibility
13
+
14
+
15
+ # Load model and processor
16
+ model_path = "deepseek-ai/Janus-Pro-7B"
17
+ config = AutoConfig.from_pretrained(model_path)
18
+ language_config = config.language_config
19
+ language_config._attn_implementation = 'eager'
20
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
21
+ language_config=language_config,
22
+ trust_remote_code=True)
23
+ if torch.cuda.is_available():
24
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
25
+ else:
26
+ vl_gpt = vl_gpt.to(torch.float16)
27
+
28
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
29
+ tokenizer = vl_chat_processor.tokenizer
30
+ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
+
32
+ # SR model
33
+ sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
34
+ sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
35
+
36
+ @torch.inference_mode()
37
+ @spaces.GPU(duration=120)
38
+ # Multimodal Understanding function
39
+ def multimodal_understanding(image, question, seed, top_p, temperature):
40
+ # Clear CUDA cache before generating
41
+ torch.cuda.empty_cache()
42
+
43
+ # set seed
44
+ torch.manual_seed(seed)
45
+ np.random.seed(seed)
46
+ torch.cuda.manual_seed(seed)
47
+
48
+ conversation = [
49
+ {
50
+ "role": "<|User|>",
51
+ "content": f"<image_placeholder>\n{question}",
52
+ "images": [image],
53
+ },
54
+ {"role": "<|Assistant|>", "content": ""},
55
+ ]
56
+
57
+ pil_images = [Image.fromarray(image)]
58
+ prepare_inputs = vl_chat_processor(
59
+ conversations=conversation, images=pil_images, force_batchify=True
60
+ ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
61
+
62
+
63
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
64
+
65
+ outputs = vl_gpt.language_model.generate(
66
+ inputs_embeds=inputs_embeds,
67
+ attention_mask=prepare_inputs.attention_mask,
68
+ pad_token_id=tokenizer.eos_token_id,
69
+ bos_token_id=tokenizer.bos_token_id,
70
+ eos_token_id=tokenizer.eos_token_id,
71
+ max_new_tokens=512,
72
+ do_sample=False if temperature == 0 else True,
73
+ use_cache=True,
74
+ temperature=temperature,
75
+ top_p=top_p,
76
+ )
77
+
78
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
79
+ return answer
80
+
81
+
82
+ def generate(input_ids,
83
+ width,
84
+ height,
85
+ temperature: float = 1,
86
+ parallel_size: int = 5,
87
+ cfg_weight: float = 5,
88
+ image_token_num_per_image: int = 576,
89
+ patch_size: int = 16):
90
+ # Clear CUDA cache before generating
91
+ torch.cuda.empty_cache()
92
+
93
+ tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
94
+ for i in range(parallel_size * 2):
95
+ tokens[i, :] = input_ids
96
+ if i % 2 != 0:
97
+ tokens[i, 1:-1] = vl_chat_processor.pad_id
98
+ inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
99
+ generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
100
+
101
+ pkv = None
102
+ for i in range(image_token_num_per_image):
103
+ with torch.no_grad():
104
+ outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
105
+ use_cache=True,
106
+ past_key_values=pkv)
107
+ pkv = outputs.past_key_values
108
+ hidden_states = outputs.last_hidden_state
109
+ logits = vl_gpt.gen_head(hidden_states[:, -1, :])
110
+ logit_cond = logits[0::2, :]
111
+ logit_uncond = logits[1::2, :]
112
+ logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
113
+ probs = torch.softmax(logits / temperature, dim=-1)
114
+ next_token = torch.multinomial(probs, num_samples=1)
115
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
116
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
117
+
118
+ img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
119
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
120
+
121
+
122
+
123
+ patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
124
+ shape=[parallel_size, 8, width // patch_size, height // patch_size])
125
+
126
+ return generated_tokens.to(dtype=torch.int), patches
127
+
128
+ def unpack(dec, width, height, parallel_size=5):
129
+ dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
130
+ dec = np.clip((dec + 1) / 2 * 255, 0, 255)
131
+
132
+ visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
133
+ visual_img[:, :, :] = dec
134
+
135
+ return visual_img
136
+
137
+
138
+
139
+ @torch.inference_mode()
140
+ @spaces.GPU(duration=120) # Specify a duration to avoid timeout
141
+ def generate_image(prompt,
142
+ seed=None,
143
+ guidance=5,
144
+ t2i_temperature=1.0):
145
+ # Clear CUDA cache and avoid tracking gradients
146
+ torch.cuda.empty_cache()
147
+ # Set the seed for reproducible results
148
+ if seed is not None:
149
+ torch.manual_seed(seed)
150
+ torch.cuda.manual_seed(seed)
151
+ np.random.seed(seed)
152
+ width = 384
153
+ height = 384
154
+ parallel_size = 5
155
+
156
+ with torch.no_grad():
157
+ messages = [{'role': '<|User|>', 'content': prompt},
158
+ {'role': '<|Assistant|>', 'content': ''}]
159
+ text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
160
+ sft_format=vl_chat_processor.sft_format,
161
+ system_prompt='')
162
+ text = text + vl_chat_processor.image_start_tag
163
+
164
+ input_ids = torch.LongTensor(tokenizer.encode(text))
165
+ output, patches = generate(input_ids,
166
+ width // 16 * 16,
167
+ height // 16 * 16,
168
+ cfg_weight=guidance,
169
+ parallel_size=parallel_size,
170
+ temperature=t2i_temperature)
171
+ images = unpack(patches,
172
+ width // 16 * 16,
173
+ height // 16 * 16,
174
+ parallel_size=parallel_size)
175
+
176
+ # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
177
+ stime = time.time()
178
+ ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
179
+ print(f'upsample time: {time.time() - stime}')
180
+ return ret_images
181
+
182
+
183
+ @spaces.GPU(duration=60)
184
+ def image_upsample(img: Image.Image) -> Image.Image:
185
+ if img is None:
186
+ raise Exception("Image not uploaded")
187
+
188
+ width, height = img.size
189
+
190
+ if width >= 5000 or height >= 5000:
191
+ raise Exception("The image is too large.")
192
+
193
+ global sr_model
194
+ result = sr_model.predict(img.convert('RGB'))
195
+ return result
196
+
197
+
198
+ # Gradio interface
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown(value="# Multimodal Understanding")
201
+ with gr.Row():
202
+ image_input = gr.Image()
203
+ with gr.Column():
204
+ question_input = gr.Textbox(label="Question")
205
+ und_seed_input = gr.Number(label="Seed", precision=0, value=42)
206
+ top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
207
+ temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
208
+
209
+ understanding_button = gr.Button("Chat")
210
+ understanding_output = gr.Textbox(label="Response")
211
+
212
+ examples_inpainting = gr.Examples(
213
+ label="Multimodal Understanding examples",
214
+ examples=[
215
+ [
216
+ "explain this meme",
217
+ "doge.png",
218
+ ],
219
+ [
220
+ "Convert the formula into latex code.",
221
+ "equation.png",
222
+ ],
223
+ ],
224
+ inputs=[question_input, image_input],
225
+ )
226
+
227
+
228
+ gr.Markdown(value="# Text-to-Image Generation")
229
+
230
+
231
+
232
+ with gr.Row():
233
+ cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
234
+ t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
235
+
236
+ prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
237
+ seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
238
+
239
+ generation_button = gr.Button("Generate Images")
240
+
241
+ image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
242
+
243
+ examples_t2i = gr.Examples(
244
+ label="Text to image generation examples.",
245
+ examples=[
246
+ "Master shifu racoon wearing drip attire as a street gangster.",
247
+ "The face of a beautiful girl",
248
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
249
+ "A glass of red wine on a reflective surface.",
250
+ "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
251
+ "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
252
+ ],
253
+ inputs=prompt_input,
254
+ )
255
+
256
+ understanding_button.click(
257
+ multimodal_understanding,
258
+ inputs=[image_input, question_input, und_seed_input, top_p, temperature],
259
+ outputs=understanding_output
260
+ )
261
+
262
+ generation_button.click(
263
+ fn=generate_image,
264
+ inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
265
+ outputs=image_output
266
+ )
267
+
268
+ demo.launch(share=True)
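For reference, a sketch (not part of this commit) of the classifier-free guidance step inside generate() above, on dummy logits: rows alternate conditional/unconditional (which is why the prompt tokens are duplicated and the odd copies padded), and cfg_weight > 1 pushes sampling toward the prompt. The vocabulary size below is an assumed placeholder.

import torch

parallel_size, vocab = 5, 16384                  # 16384 is an assumed codebook size
logits = torch.randn(parallel_size * 2, vocab)
logit_cond, logit_uncond = logits[0::2], logits[1::2]
cfg_weight, temperature = 5.0, 1.0
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # one image token per parallel sample
print(next_token.shape)                                # torch.Size([5, 1])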
doge.png ADDED
equation.png ADDED
janus/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ # check if the python version is 3.10 or above
22
+ import sys
23
+
24
+ if sys.version_info >= (3, 10):
25
+ print("Python version is above 3.10, patching the collections module.")
26
+ # Monkey patch collections
27
+ import collections
28
+ import collections.abc
29
+
30
+ for type_name in collections.abc.__all__:
31
+ setattr(collections, type_name, getattr(collections.abc, type_name))
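For reference, a small sketch (not part of this commit) of what the monkey patch above restores on Python 3.10+, where the legacy collections aliases were removed:

import janus             # applying the patch is a side effect of the import
import collections
import collections.abc

assert collections.Mapping is collections.abc.Mapping   # legacy attribute access works again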
janus/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (433 Bytes).
 
janus/models/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from .image_processing_vlm import VLMImageProcessor
21
+ from .modeling_vlm import MultiModalityCausalLM
22
+ from .processing_vlm import VLChatProcessor
23
+
24
+ __all__ = [
25
+ "VLMImageProcessor",
26
+ "VLChatProcessor",
27
+ "MultiModalityCausalLM",
28
+ ]
janus/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (391 Bytes).
 
janus/models/__pycache__/clip_encoder.cpython-38.pyc ADDED
Binary file (2.74 kB).
 
janus/models/__pycache__/image_processing_vlm.cpython-38.pyc ADDED
Binary file (4.98 kB).
 
janus/models/__pycache__/modeling_vlm.cpython-38.pyc ADDED
Binary file (7.1 kB).
 
janus/models/__pycache__/processing_vlm.cpython-38.pyc ADDED
Binary file (11.1 kB).
 
janus/models/__pycache__/projector.cpython-38.pyc ADDED
Binary file (2.23 kB).
 
janus/models/__pycache__/siglip_vit.cpython-38.pyc ADDED
Binary file (18.4 kB).
 
janus/models/__pycache__/vq_model.cpython-38.pyc ADDED
Binary file (12.5 kB).
 
janus/models/clip_encoder.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Dict, List, Literal, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision.transforms
25
+ from einops import rearrange
26
+
27
+ from janus.models.siglip_vit import create_siglip_vit
28
+
29
+
30
+ class CLIPVisionTower(nn.Module):
31
+ def __init__(
32
+ self,
33
+ model_name: str = "siglip_large_patch16_384",
34
+ image_size: Union[Tuple[int, int], int] = 336,
35
+ select_feature: str = "patch",
36
+ select_layer: int = -2,
37
+ select_layers: list = None,
38
+ ckpt_path: str = "",
39
+ pixel_mean: Optional[List[float]] = None,
40
+ pixel_std: Optional[List[float]] = None,
41
+ **kwargs,
42
+ ):
43
+ super().__init__()
44
+
45
+ self.model_name = model_name
46
+ self.select_feature = select_feature
47
+ self.select_layer = select_layer
48
+ self.select_layers = select_layers
49
+
50
+ vision_tower_params = {
51
+ "model_name": model_name,
52
+ "image_size": image_size,
53
+ "ckpt_path": ckpt_path,
54
+ "select_layer": select_layer,
55
+ }
56
+ vision_tower_params.update(kwargs)
57
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
58
+ vision_tower_params
59
+ )
60
+
61
+ if pixel_mean is not None and pixel_std is not None:
62
+ image_norm = torchvision.transforms.Normalize(
63
+ mean=pixel_mean, std=pixel_std
64
+ )
65
+ else:
66
+ image_norm = None
67
+
68
+ self.image_norm = image_norm
69
+
70
+ def build_vision_tower(self, vision_tower_params):
71
+ if self.model_name.startswith("siglip"):
72
+ self.select_feature = "same"
73
+ vision_tower = create_siglip_vit(**vision_tower_params)
74
+ forward_kwargs = dict()
75
+
76
+ elif self.model_name.startswith("sam"):
77
+ vision_tower = create_sam_vit(**vision_tower_params)  # NOTE: create_sam_vit is not imported in this file
78
+ forward_kwargs = dict()
79
+
80
+ else: # huggingface
81
+ from transformers import CLIPVisionModel
82
+
83
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
84
+ forward_kwargs = dict(output_hidden_states=True)
85
+
86
+ return vision_tower, forward_kwargs
87
+
88
+ def feature_select(self, image_forward_outs):
89
+ if isinstance(image_forward_outs, torch.Tensor):
90
+ # the output is already the features from self.select_layer
91
+ image_features = image_forward_outs
92
+ else:
93
+ image_features = image_forward_outs.hidden_states[self.select_layer]
94
+
95
+ if self.select_feature == "patch":
96
+ # if the output has cls_token
97
+ image_features = image_features[:, 1:]
98
+ elif self.select_feature == "cls_patch":
99
+ image_features = image_features
100
+ elif self.select_feature == "same":
101
+ image_features = image_features
102
+
103
+ else:
104
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
105
+ return image_features
106
+
107
+ def forward(self, images):
108
+ """
109
+
110
+ Args:
111
+ images (torch.Tensor): [b, 3, H, W]
112
+
113
+ Returns:
114
+ image_features (torch.Tensor): [b, n_patch, d]
115
+ """
116
+
117
+ if self.image_norm is not None:
118
+ images = self.image_norm(images)
119
+
120
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
121
+ image_features = self.feature_select(image_forward_outs)
122
+ return image_features
janus/models/image_processing_vlm.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import List, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torchvision
25
+ import torchvision.transforms.functional
26
+ from PIL import Image
27
+ from transformers import AutoImageProcessor, PretrainedConfig
28
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
29
+ from transformers.image_utils import to_numpy_array
30
+ from transformers.utils import logging
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
35
+ IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
36
+ IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
37
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
38
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
39
+
40
+
41
+ def expand2square(pil_img, background_color):
42
+ width, height = pil_img.size
43
+ if width == height:
44
+ return pil_img
45
+ elif width > height:
46
+ result = Image.new(pil_img.mode, (width, width), background_color)
47
+ result.paste(pil_img, (0, (width - height) // 2))
48
+ return result
49
+ else:
50
+ result = Image.new(pil_img.mode, (height, height), background_color)
51
+ result.paste(pil_img, ((height - width) // 2, 0))
52
+ return result
53
+
54
+
55
+ class VLMImageProcessorConfig(PretrainedConfig):
56
+ model_type = "deepseek_vlm"
57
+ image_size: int
58
+ min_size: int
59
+ image_mean: Union[Tuple[float, float, float], List[float]]
60
+ image_std: Union[Tuple[float, float, float], List[float]]
61
+ rescale_factor: float
62
+ do_normalize: bool
63
+
64
+ def __init__(
65
+ self,
66
+ image_size: int,
67
+ min_size: int = 14,
68
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
69
+ 0.48145466,
70
+ 0.4578275,
71
+ 0.40821073,
72
+ ),
73
+ image_std: Union[Tuple[float, float, float], List[float]] = (
74
+ 0.26862954,
75
+ 0.26130258,
76
+ 0.27577711,
77
+ ),
78
+ rescale_factor: float = 1.0 / 255.0,
79
+ do_normalize: bool = True,
80
+ **kwargs,
81
+ ):
82
+ self.image_size = image_size
83
+ self.min_size = min_size
84
+ self.image_mean = image_mean
85
+ self.image_std = image_std
86
+ self.rescale_factor = rescale_factor
87
+ self.do_normalize = do_normalize
88
+
89
+ super().__init__(**kwargs)
90
+
91
+
92
+ class VLMImageProcessor(BaseImageProcessor):
93
+ model_input_names = ["pixel_values"]
94
+
95
+ def __init__(
96
+ self,
97
+ image_size: int,
98
+ min_size: int = 14,
99
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
100
+ 0.48145466,
101
+ 0.4578275,
102
+ 0.40821073,
103
+ ),
104
+ image_std: Union[Tuple[float, float, float], List[float]] = (
105
+ 0.26862954,
106
+ 0.26130258,
107
+ 0.27577711,
108
+ ),
109
+ rescale_factor: float = 1.0 / 255.0,
110
+ do_normalize: bool = True,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(**kwargs)
114
+
115
+ self.image_size = image_size
116
+ self.rescale_factor = rescale_factor
117
+ self.image_mean = image_mean
118
+ self.image_std = image_std
119
+ self.min_size = min_size
120
+ self.do_normalize = do_normalize
121
+
122
+ if image_mean is None:
123
+ self.background_color = (127, 127, 127)
124
+ else:
125
+ self.background_color = tuple([int(x * 255) for x in image_mean])
126
+
127
+ def resize(self, pil_img: Image) -> np.ndarray:
128
+ """
129
+
130
+ Args:
131
+ pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
132
+
133
+ Returns:
134
+ x (np.ndarray): [3, self.image_size, self.image_size]
135
+ """
136
+
137
+ width, height = pil_img.size
138
+ max_size = max(width, height)
139
+
140
+ size = [
141
+ max(int(height / max_size * self.image_size), self.min_size),
142
+ max(int(width / max_size * self.image_size), self.min_size),
143
+ ]
144
+
145
+ if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
146
+ print(f"orig size = {pil_img.size}, new size = {size}")
147
+ raise ValueError("Invalid size!")
148
+
149
+ pil_img = torchvision.transforms.functional.resize(
150
+ pil_img,
151
+ size,
152
+ interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
153
+ antialias=True,
154
+ )
155
+
156
+ pil_img = expand2square(pil_img, self.background_color)
157
+ x = to_numpy_array(pil_img)
158
+
159
+ # [H, W, 3] -> [3, H, W]
160
+ x = np.transpose(x, (2, 0, 1))
161
+
162
+ return x
163
+
164
+ def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
165
+ # resize and pad to [self.image_size, self.image_size]
166
+ # then convert from [H, W, 3] to [3, H, W]
167
+ images: List[np.ndarray] = [self.resize(image) for image in images]
168
+
169
+ # resacle from [0, 255] -> [0, 1]
170
+ images = [
171
+ self.rescale(
172
+ image=image,
173
+ scale=self.rescale_factor,
174
+ input_data_format="channels_first",
175
+ )
176
+ for image in images
177
+ ]
178
+
179
+ # normalize
180
+ if self.do_normalize:
181
+ images = [
182
+ self.normalize(
183
+ image=image,
184
+ mean=self.image_mean,
185
+ std=self.image_std,
186
+ input_data_format="channels_first",
187
+ )
188
+ for image in images
189
+ ]
190
+
191
+ data = {"pixel_values": images}
192
+ return BatchFeature(data=data, tensor_type=return_tensors)
193
+
194
+ @property
195
+ def default_shape(self):
196
+ return [3, self.image_size, self.image_size]
197
+
198
+
199
+ AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ image_processor = VLMImageProcessor(
204
+ image_size=1024,
205
+ image_mean=IMAGENET_INCEPTION_MEAN,
206
+ image_std=IMAGENET_INCEPTION_STD,
207
+ do_normalize=True,
208
+ )
janus/models/modeling_vlm.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import torch
21
+ from attrdict import AttrDict
22
+ from einops import rearrange
23
+ from transformers import (
24
+ AutoConfig,
25
+ AutoModelForCausalLM,
26
+ LlamaConfig,
27
+ LlamaForCausalLM,
28
+ PreTrainedModel,
29
+ )
30
+ from transformers.configuration_utils import PretrainedConfig
31
+
32
+ from janus.models.clip_encoder import CLIPVisionTower
33
+ from janus.models.projector import MlpProjector
34
+
35
+
36
+ class vision_head(torch.nn.Module):
37
+ def __init__(self, params):
38
+ super().__init__()
39
+ self.output_mlp_projector = torch.nn.Linear(
40
+ params.n_embed, params.image_token_embed
41
+ )
42
+ self.vision_activation = torch.nn.GELU()
43
+ self.vision_head = torch.nn.Linear(
44
+ params.image_token_embed, params.image_token_size
45
+ )
46
+
47
+ def forward(self, x):
48
+ x = self.output_mlp_projector(x)
49
+ x = self.vision_activation(x)
50
+ x = self.vision_head(x)
51
+ return x
52
+
53
+
54
+ def model_name_to_cls(cls_name):
55
+ if "MlpProjector" in cls_name:
56
+ cls = MlpProjector
57
+
58
+ elif "CLIPVisionTower" in cls_name:
59
+ cls = CLIPVisionTower
60
+
61
+ elif "VQ" in cls_name:
62
+ from janus.models.vq_model import VQ_models
63
+
64
+ cls = VQ_models[cls_name]
65
+ elif "vision_head" in cls_name:
66
+ cls = vision_head
67
+ else:
68
+ raise ValueError(f"class_name {cls_name} is invalid.")
69
+
70
+ return cls
71
+
72
+
73
+ class VisionConfig(PretrainedConfig):
74
+ model_type = "vision"
75
+ cls: str = ""
76
+ params: AttrDict = {}
77
+
78
+ def __init__(self, **kwargs):
79
+ super().__init__(**kwargs)
80
+
81
+ self.cls = kwargs.get("cls", "")
82
+ if not isinstance(self.cls, str):
83
+ self.cls = self.cls.__name__
84
+
85
+ self.params = AttrDict(kwargs.get("params", {}))
86
+
87
+
88
+ class AlignerConfig(PretrainedConfig):
89
+ model_type = "aligner"
90
+ cls: str = ""
91
+ params: AttrDict = {}
92
+
93
+ def __init__(self, **kwargs):
94
+ super().__init__(**kwargs)
95
+
96
+ self.cls = kwargs.get("cls", "")
97
+ if not isinstance(self.cls, str):
98
+ self.cls = self.cls.__name__
99
+
100
+ self.params = AttrDict(kwargs.get("params", {}))
101
+
102
+
103
+ class GenVisionConfig(PretrainedConfig):
104
+ model_type = "gen_vision"
105
+ cls: str = ""
106
+ params: AttrDict = {}
107
+
108
+ def __init__(self, **kwargs):
109
+ super().__init__(**kwargs)
110
+
111
+ self.cls = kwargs.get("cls", "")
112
+ if not isinstance(self.cls, str):
113
+ self.cls = self.cls.__name__
114
+
115
+ self.params = AttrDict(kwargs.get("params", {}))
116
+
117
+
118
+ class GenAlignerConfig(PretrainedConfig):
119
+ model_type = "gen_aligner"
120
+ cls: str = ""
121
+ params: AttrDict = {}
122
+
123
+ def __init__(self, **kwargs):
124
+ super().__init__(**kwargs)
125
+
126
+ self.cls = kwargs.get("cls", "")
127
+ if not isinstance(self.cls, str):
128
+ self.cls = self.cls.__name__
129
+
130
+ self.params = AttrDict(kwargs.get("params", {}))
131
+
132
+
133
+ class GenHeadConfig(PretrainedConfig):
134
+ model_type = "gen_head"
135
+ cls: str = ""
136
+ params: AttrDict = {}
137
+
138
+ def __init__(self, **kwargs):
139
+ super().__init__(**kwargs)
140
+
141
+ self.cls = kwargs.get("cls", "")
142
+ if not isinstance(self.cls, str):
143
+ self.cls = self.cls.__name__
144
+
145
+ self.params = AttrDict(kwargs.get("params", {}))
146
+
147
+
148
+ class MultiModalityConfig(PretrainedConfig):
149
+ model_type = "multi_modality"
150
+ vision_config: VisionConfig
151
+ aligner_config: AlignerConfig
152
+
153
+ gen_vision_config: GenVisionConfig
154
+ gen_aligner_config: GenAlignerConfig
155
+ gen_head_config: GenHeadConfig
156
+
157
+ language_config: LlamaConfig
158
+
159
+ def __init__(self, **kwargs):
160
+ super().__init__(**kwargs)
161
+ vision_config = kwargs.get("vision_config", {})
162
+ self.vision_config = VisionConfig(**vision_config)
163
+
164
+ aligner_config = kwargs.get("aligner_config", {})
165
+ self.aligner_config = AlignerConfig(**aligner_config)
166
+
167
+ gen_vision_config = kwargs.get("gen_vision_config", {})
168
+ self.gen_vision_config = GenVisionConfig(**gen_vision_config)
169
+
170
+ gen_aligner_config = kwargs.get("gen_aligner_config", {})
171
+ self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
172
+
173
+ gen_head_config = kwargs.get("gen_head_config", {})
174
+ self.gen_head_config = GenHeadConfig(**gen_head_config)
175
+
176
+ language_config = kwargs.get("language_config", {})
177
+ if isinstance(language_config, LlamaConfig):
178
+ self.language_config = language_config
179
+ else:
180
+ self.language_config = LlamaConfig(**language_config)
181
+
182
+
183
+ class MultiModalityPreTrainedModel(PreTrainedModel):
184
+ config_class = MultiModalityConfig
185
+ base_model_prefix = "multi_modality"
186
+ _no_split_modules = []
187
+ _skip_keys_device_placement = "past_key_values"
188
+
189
+
190
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
191
+ def __init__(self, config: MultiModalityConfig):
192
+ super().__init__(config)
193
+
194
+ vision_config = config.vision_config
195
+ vision_cls = model_name_to_cls(vision_config.cls)
196
+ self.vision_model = vision_cls(**vision_config.params)
197
+
198
+ aligner_config = config.aligner_config
199
+ aligner_cls = model_name_to_cls(aligner_config.cls)
200
+ self.aligner = aligner_cls(aligner_config.params)
201
+
202
+ gen_vision_config = config.gen_vision_config
203
+ gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
204
+ self.gen_vision_model = gen_vision_cls()
205
+
206
+ gen_aligner_config = config.gen_aligner_config
207
+ gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
208
+ self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
209
+
210
+ gen_head_config = config.gen_head_config
211
+ gen_head_cls = model_name_to_cls(gen_head_config.cls)
212
+ self.gen_head = gen_head_cls(gen_head_config.params)
213
+
214
+ self.gen_embed = torch.nn.Embedding(
215
+ gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
216
+ )
217
+
218
+ language_config = config.language_config
219
+ self.language_model = LlamaForCausalLM(language_config)
220
+
221
+ def prepare_inputs_embeds(
222
+ self,
223
+ input_ids: torch.LongTensor,
224
+ pixel_values: torch.FloatTensor,
225
+ images_seq_mask: torch.LongTensor,
226
+ images_emb_mask: torch.LongTensor,
227
+ **kwargs,
228
+ ):
229
+ """
230
+
231
+ Args:
232
+ input_ids (torch.LongTensor): [b, T]
233
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
234
+ images_seq_mask (torch.BoolTensor): [b, T]
235
+ images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
236
+
237
+ assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
238
+
239
+ Returns:
240
+ input_embeds (torch.Tensor): [b, T, D]
241
+ """
242
+
243
+ bs, n = pixel_values.shape[0:2]
244
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
245
+ # [b x n, T2, D]
246
+ images_embeds = self.aligner(self.vision_model(images))
247
+
248
+ # [b x n, T2, D] -> [b, n x T2, D]
249
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
250
+ # [b, n, T2] -> [b, n x T2]
251
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
252
+
253
+ # [b, T, D]
254
+ input_ids[input_ids < 0] = 0 # ignore the image embeddings
255
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
256
+
257
+ # replace with the image embeddings
258
+ inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
259
+
260
+ return inputs_embeds
261
+
262
+ def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
263
+ return self.gen_aligner(self.gen_embed(image_ids))
264
+
265
+
266
+ AutoConfig.register("vision", VisionConfig)
267
+ AutoConfig.register("aligner", AlignerConfig)
268
+ AutoConfig.register("gen_vision", GenVisionConfig)
269
+ AutoConfig.register("gen_aligner", GenAlignerConfig)
270
+ AutoConfig.register("gen_head", GenHeadConfig)
271
+ AutoConfig.register("multi_modality", MultiModalityConfig)
272
+ AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
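Because `modeling_vlm.py` registers `MultiModalityConfig` and `MultiModalityCausalLM` with the `transformers` auto classes, the model can be driven through the standard API. The sketch below shows the understanding path (`prepare_inputs_embeds` feeding `language_model.generate`); the checkpoint id, the example image path, and the `load_pil_images` helper from `janus/utils/io.py` are assumptions based on upstream Janus usage, not something this diff itself pins down.

```python
import torch
from transformers import AutoModelForCausalLM

# Importing these modules also runs the AutoConfig/AutoModelForCausalLM
# register calls shown above.
from janus.models.modeling_vlm import MultiModalityCausalLM
from janus.models.processing_vlm import VLChatProcessor
from janus.utils.io import load_pil_images  # assumed helper, as in upstream Janus

model_path = "deepseek-ai/Janus-Pro-7B"  # assumed checkpoint id

processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
model = model.to(torch.bfloat16).cuda().eval()

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["./example.jpg"],  # placeholder path
    },
    {"role": "Assistant", "content": ""},
]

pil_images = load_pil_images(conversation)
inputs = processor(conversations=conversation, images=pil_images).to(model.device)

# Text tokens go through the LLM embedding table, image patches through
# SigLIP + aligner; prepare_inputs_embeds merges them at the image positions.
inputs_embeds = model.prepare_inputs_embeds(**inputs)
outputs = model.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    pad_token_id=processor.tokenizer.eos_token_id,
    max_new_tokens=256,
    do_sample=False,
)
print(processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True))
```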
janus/models/processing_vlm.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Dict, List
22
+
23
+ import torch
24
+ from PIL.Image import Image
25
+ from transformers import LlamaTokenizerFast
26
+ from transformers.processing_utils import ProcessorMixin
27
+
28
+ from janus.models.image_processing_vlm import VLMImageProcessor
29
+ from janus.utils.conversation import get_conv_template
30
+
31
+
32
+ class DictOutput(object):
33
+ def keys(self):
34
+ return self.__dict__.keys()
35
+
36
+ def __getitem__(self, item):
37
+ return self.__dict__[item]
38
+
39
+ def __setitem__(self, key, value):
40
+ self.__dict__[key] = value
41
+
42
+
43
+ @dataclass
44
+ class VLChatProcessorOutput(DictOutput):
45
+ sft_format: str
46
+ input_ids: torch.Tensor
47
+ pixel_values: torch.Tensor
48
+ num_image_tokens: torch.IntTensor
49
+
50
+ def __len__(self):
51
+ return len(self.input_ids)
52
+
53
+
54
+ @dataclass
55
+ class BatchedVLChatProcessorOutput(DictOutput):
56
+ sft_format: List[str]
57
+ input_ids: torch.Tensor
58
+ pixel_values: torch.Tensor
59
+ attention_mask: torch.Tensor
60
+ images_seq_mask: torch.BoolTensor
61
+ images_emb_mask: torch.BoolTensor
62
+
63
+ def to(self, device, dtype=torch.bfloat16):
64
+ self.input_ids = self.input_ids.to(device)
65
+ self.attention_mask = self.attention_mask.to(device)
66
+ self.images_seq_mask = self.images_seq_mask.to(device)
67
+ self.images_emb_mask = self.images_emb_mask.to(device)
68
+ self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
69
+ return self
70
+
71
+
72
+ class VLChatProcessor(ProcessorMixin):
73
+ image_processor_class = "AutoImageProcessor"
74
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
75
+
76
+ attributes = ["image_processor", "tokenizer"]
77
+
78
+ system_prompt = (
79
+ "You are a helpful language and vision assistant. "
80
+ "You are able to understand the visual content that the user provides, "
81
+ "and assist the user with a variety of tasks using natural language."
82
+ )
83
+
84
+ def __init__(
85
+ self,
86
+ image_processor: VLMImageProcessor,
87
+ tokenizer: LlamaTokenizerFast,
88
+ image_tag: str = "<image_placeholder>",
89
+ image_start_tag: str = "<begin_of_image>",
90
+ image_end_tag: str = "<end_of_image>",
91
+ pad_tag: str = "<|▁pad▁|>",
92
+ num_image_tokens: int = 576,
93
+ add_special_token: bool = False,
94
+ sft_format: str = "deepseek",
95
+ mask_prompt: bool = True,
96
+ ignore_id: int = -100,
97
+ **kwargs,
98
+ ):
99
+ self.image_processor = image_processor
100
+ self.tokenizer = tokenizer
101
+
102
+ image_id = self.tokenizer.vocab.get(image_tag)
103
+ if image_id is None:
104
+ special_tokens = [image_tag]
105
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
106
+ self.tokenizer.add_special_tokens(special_tokens_dict)
107
+ print(f"Add image tag = {image_tag} to the tokenizer")
108
+
109
+ self.image_tag = image_tag
110
+ self.image_start_tag = image_start_tag
111
+ self.image_end_tag = image_end_tag
112
+ self.pad_tag = pad_tag
113
+
114
+ self.num_image_tokens = num_image_tokens
115
+ self.add_special_token = add_special_token
116
+ self.sft_format = sft_format
117
+ self.mask_prompt = mask_prompt
118
+ self.ignore_id = ignore_id
119
+
120
+ super().__init__(
121
+ image_processor,
122
+ tokenizer,
123
+ image_tag,
124
+ num_image_tokens,
125
+ add_special_token,
126
+ sft_format,
127
+ mask_prompt,
128
+ ignore_id,
129
+ **kwargs,
130
+ )
131
+
132
+ def new_chat_template(self):
133
+ conv = get_conv_template(self.sft_format)
134
+ conv.set_system_message(self.system_prompt)
135
+ return conv
136
+
137
+ def apply_sft_template_for_multi_turn_prompts(
138
+ self,
139
+ conversations: List[Dict[str, str]],
140
+ sft_format: str = "deepseek",
141
+ system_prompt: str = "",
142
+ ):
143
+ """
144
+ Applies the SFT template to conversation.
145
+
146
+ An example of conversation:
147
+ conversation = [
148
+ {
149
+ "role": "User",
150
+ "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
151
+ "images": [
152
+ "./multi-images/attribute_comparison_1.png",
153
+ "./multi-images/attribute_comparison_2.png"
154
+ ]
155
+ },
156
+ {
157
+ "role": "Assistant",
158
+ "content": ""
159
+ }
160
+ ]
161
+
162
+ Args:
163
+ conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
164
+ sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
165
+ system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
166
+
167
+ Returns:
168
+ sft_prompt (str): The formatted text.
169
+ """
170
+
171
+ conv = get_conv_template(sft_format)
172
+ conv.set_system_message(system_prompt)
173
+ for message in conversations:
174
+ conv.append_message(message["role"], message["content"].strip())
175
+ sft_prompt = conv.get_prompt().strip()
176
+
177
+ return sft_prompt
178
+
179
+ @property
180
+ def image_token(self):
181
+ return self.image_tag
182
+
183
+ @property
184
+ def image_id(self):
185
+ image_id = self.tokenizer.vocab.get(self.image_tag)
186
+ return image_id
187
+
188
+ @property
189
+ def image_start_id(self):
190
+ image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
191
+ return image_start_id
192
+
193
+ @property
194
+ def image_end_id(self):
195
+ image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
196
+ return image_end_id
197
+
198
+ @property
199
+ def image_start_token(self):
200
+ return self.image_start_tag
201
+
202
+ @property
203
+ def image_end_token(self):
204
+ return self.image_end_tag
205
+
206
+ @property
207
+ def pad_id(self):
208
+ pad_id = self.tokenizer.vocab.get(self.pad_tag)
209
+ # pad_id = self.tokenizer.pad_token_id
210
+ # if pad_id is None:
211
+ # pad_id = self.tokenizer.eos_token_id
212
+
213
+ return pad_id
214
+
215
+ def add_image_token(
216
+ self,
217
+ image_indices: List[int],
218
+ input_ids: torch.LongTensor,
219
+ ):
220
+ """
221
+
222
+ Args:
223
+ image_indices (List[int]): [index_0, index_1, ..., index_j]
224
+ input_ids (torch.LongTensor): [N]
225
+
226
+ Returns:
227
+ input_ids (torch.LongTensor): [N + image tokens]
228
+ num_image_tokens (torch.IntTensor): [n_images]
229
+ """
230
+
231
+ input_slices = []
232
+
233
+ start = 0
234
+ for index in image_indices:
235
+ if self.add_special_token:
236
+ end = index + 1
237
+ else:
238
+ end = index
239
+
240
+ # original text tokens
241
+ input_slices.append(input_ids[start:end])
242
+
243
+ # add boi, image tokens, eoi and set the mask as False
244
+ input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
245
+ input_slices.append(
246
+ self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
247
+ )
248
+ input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
249
+ start = index + 1
250
+
251
+ # the left part
252
+ input_slices.append(input_ids[start:])
253
+
254
+ # concat all slices
255
+ input_ids = torch.cat(input_slices, dim=0)
256
+ num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
257
+
258
+ return input_ids, num_image_tokens
259
+
260
+ def process_one(
261
+ self,
262
+ prompt: str = None,
263
+ conversations: List[Dict[str, str]] = None,
264
+ images: List[Image] = None,
265
+ **kwargs,
266
+ ):
267
+ """
268
+
269
+ Args:
270
+ prompt (str): the formatted prompt;
271
+ conversations (List[Dict]): conversations with a list of messages;
272
+ images (List[ImageType]): the list of images;
273
+ **kwargs:
274
+
275
+ Returns:
276
+ outputs (BaseProcessorOutput): the output of the processor,
277
+ - input_ids (torch.LongTensor): [N + image tokens]
278
+ - target_ids (torch.LongTensor): [N + image tokens]
279
+ - images (torch.FloatTensor): [n_images, 3, H, W]
280
+ - image_id (int): the id of the image token
281
+ - num_image_tokens (List[int]): the number of image tokens
282
+ """
283
+
284
+ assert (
285
+ prompt is None or conversations is None
286
+ ), "prompt and conversations cannot be used at the same time."
287
+
288
+ if prompt is None:
289
+ # apply sft format
290
+ sft_format = self.apply_sft_template_for_multi_turn_prompts(
291
+ conversations=conversations,
292
+ sft_format=self.sft_format,
293
+ system_prompt=self.system_prompt,
294
+ )
295
+ else:
296
+ sft_format = prompt
297
+
298
+ # tokenize
299
+ input_ids = self.tokenizer.encode(sft_format)
300
+ input_ids = torch.LongTensor(input_ids)
301
+
302
+ # add image tokens to the input_ids
303
+ image_token_mask: torch.BoolTensor = input_ids == self.image_id
304
+ image_indices = image_token_mask.nonzero()
305
+ input_ids, num_image_tokens = self.add_image_token(
306
+ image_indices=image_indices,
307
+ input_ids=input_ids,
308
+ )
309
+
310
+ # load images
311
+ images_outputs = self.image_processor(images, return_tensors="pt")
312
+
313
+ prepare = VLChatProcessorOutput(
314
+ sft_format=sft_format,
315
+ input_ids=input_ids,
316
+ pixel_values=images_outputs.pixel_values,
317
+ num_image_tokens=num_image_tokens,
318
+ )
319
+
320
+ return prepare
321
+
322
+ def __call__(
323
+ self,
324
+ *,
325
+ prompt: str = None,
326
+ conversations: List[Dict[str, str]] = None,
327
+ images: List[Image] = None,
328
+ force_batchify: bool = True,
329
+ **kwargs,
330
+ ):
331
+ """
332
+
333
+ Args:
334
+ prompt (str): the formatted prompt;
335
+ conversations (List[Dict]): conversations with a list of messages;
336
+ images (List[ImageType]): the list of images;
337
+ force_batchify (bool): force batchify the inputs;
338
+ **kwargs:
339
+
340
+ Returns:
341
+ outputs (BaseProcessorOutput): the output of the processor,
342
+ - input_ids (torch.LongTensor): [N + image tokens]
343
+ - images (torch.FloatTensor): [n_images, 3, H, W]
344
+ - image_id (int): the id of the image token
345
+ - num_image_tokens (List[int]): the number of image tokens
346
+ """
347
+
348
+ prepare = self.process_one(
349
+ prompt=prompt, conversations=conversations, images=images
350
+ )
351
+
352
+ if force_batchify:
353
+ prepare = self.batchify([prepare])
354
+
355
+ return prepare
356
+
357
+ def batchify(
358
+ self, prepare_list: List[VLChatProcessorOutput]
359
+ ) -> BatchedVLChatProcessorOutput:
360
+ """
361
+ Preprocesses the inputs for multimodal inference.
362
+
363
+ Args:
364
+ prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
365
+
366
+ Returns:
367
+ BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
368
+ """
369
+
370
+ batch_size = len(prepare_list)
371
+ sft_format = []
372
+ n_images = []
373
+ seq_lens = []
374
+ for prepare in prepare_list:
375
+ n_images.append(len(prepare.num_image_tokens))
376
+ seq_lens.append(len(prepare))
377
+
378
+ input_token_max_len = max(seq_lens)
379
+ max_n_images = max(1, max(n_images))
380
+
381
+ batched_input_ids = torch.full(
382
+ (batch_size, input_token_max_len), self.pad_id
383
+ ).long() # FIXME
384
+ batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
385
+ batched_pixel_values = torch.zeros(
386
+ (batch_size, max_n_images, *self.image_processor.default_shape)
387
+ ).float()
388
+ batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
389
+ batched_images_emb_mask = torch.zeros(
390
+ (batch_size, max_n_images, self.num_image_tokens)
391
+ ).bool()
392
+
393
+ for i, prepare in enumerate(prepare_list):
394
+ input_ids = prepare.input_ids
395
+ seq_len = len(prepare)
396
+ n_image = len(prepare.num_image_tokens)
397
+ # left-padding
398
+ batched_attention_mask[i, -seq_len:] = 1
399
+ batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
400
+ batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
401
+
402
+ if n_image > 0:
403
+ batched_pixel_values[i, :n_image] = prepare.pixel_values
404
+ for j, n_image_tokens in enumerate(prepare.num_image_tokens):
405
+ batched_images_emb_mask[i, j, :n_image_tokens] = True
406
+
407
+ sft_format.append(prepare.sft_format)
408
+
409
+ batched_prepares = BatchedVLChatProcessorOutput(
410
+ input_ids=batched_input_ids,
411
+ attention_mask=batched_attention_mask,
412
+ pixel_values=batched_pixel_values,
413
+ images_seq_mask=batched_images_seq_mask,
414
+ images_emb_mask=batched_images_emb_mask,
415
+ sft_format=sft_format,
416
+ )
417
+
418
+ return batched_prepares
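The core of `add_image_token` is that each `<image_placeholder>` in the tokenized prompt is expanded into `<begin_of_image>`, `num_image_tokens` (576) copies of the image id, and `<end_of_image>`, after which `batchify` left-pads every sequence to the longest in the batch. A standalone sketch of that expansion, with made-up token ids rather than the real vocabulary:

```python
import torch

# Toy ids standing in for the real vocabulary entries; purely illustrative.
IMAGE_ID, BOI_ID, EOI_ID, NUM_IMAGE_TOKENS = 100, 101, 102, 576

def expand_image_tokens(input_ids: torch.LongTensor) -> torch.LongTensor:
    """Mirror of VLChatProcessor.add_image_token for a single sequence:
    each IMAGE_ID becomes BOI_ID, NUM_IMAGE_TOKENS x IMAGE_ID, EOI_ID."""
    slices, start = [], 0
    for index in (input_ids == IMAGE_ID).nonzero().flatten().tolist():
        slices.append(input_ids[start:index])                     # text before the image
        slices.append(torch.tensor([BOI_ID]))                     # <begin_of_image>
        slices.append(torch.full((NUM_IMAGE_TOKENS,), IMAGE_ID))  # 576 image slots
        slices.append(torch.tensor([EOI_ID]))                     # <end_of_image>
        start = index + 1
    slices.append(input_ids[start:])                               # trailing text
    return torch.cat(slices)

prompt = torch.tensor([1, 2, IMAGE_ID, 3, 4])  # "text <image_placeholder> text"
expanded = expand_image_tokens(prompt)
print(expanded.shape)                           # torch.Size([582]) = 4 text + 1 + 576 + 1
```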
janus/models/projector.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from attrdict import AttrDict
25
+
26
+
27
+ class MlpProjector(nn.Module):
28
+ def __init__(self, cfg):
29
+ super().__init__()
30
+
31
+ self.cfg = cfg
32
+
33
+ if cfg.projector_type == "identity":
34
+ modules = nn.Identity()
35
+
36
+ elif cfg.projector_type == "linear":
37
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
38
+
39
+ elif cfg.projector_type == "mlp_gelu":
40
+ mlp_depth = cfg.get("depth", 1)
41
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
45
+ modules = nn.Sequential(*modules)
46
+
47
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
48
+ mlp_depth = cfg.get("depth", 1)
49
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
50
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
51
+
52
+ modules = []
53
+ for _ in range(1, mlp_depth):
54
+ modules.append(nn.GELU())
55
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
56
+ modules = nn.Sequential(*modules)
57
+
58
+ else:
59
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
60
+
61
+ self.layers = modules
62
+
63
+ def forward(
64
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
65
+ ):
66
+ """
67
+
68
+ Args:
69
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
70
+ then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
71
+ otherwise it is the feature from the single vision encoder.
72
+
73
+ Returns:
74
+ x (torch.Tensor): [b, s, c]
75
+ """
76
+
77
+ if isinstance(x_or_tuple, tuple):
78
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
79
+ high_x, low_x = x_or_tuple
80
+ high_x = self.high_up_proj(high_x)
81
+ low_x = self.low_up_proj(low_x)
82
+ x = torch.concat([high_x, low_x], dim=-1)
83
+ else:
84
+ x = x_or_tuple
85
+
86
+ return self.layers(x)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ cfg = AttrDict(
91
+ input_dim=1024,
92
+ n_embed=2048,
93
+ depth=2,
94
+ projector_type="low_high_hybrid_split_mlp_gelu",
95
+ )
96
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
97
+
98
+ m = MlpProjector(cfg)
99
+ out = m(inputs)
100
+ print(out.shape)
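The `__main__` block above exercises only the hybrid `low_high_hybrid_split_mlp_gelu` path; the aligners in Janus are typically configured with the plain `mlp_gelu` projector. A shape-only sketch follows; the 1024-to-2048 dimensions and the depth are illustrative, the real values come from the checkpoint's `aligner_config`.

```python
import torch
from attrdict import AttrDict
from janus.models.projector import MlpProjector

# Illustrative config only; in practice these numbers come from the checkpoint.
cfg = AttrDict(projector_type="mlp_gelu", input_dim=1024, n_embed=2048, depth=2)
aligner = MlpProjector(cfg)

vision_features = torch.rand(2, 576, 1024)  # [batch, image tokens, vision dim]
print(aligner(vision_features).shape)       # torch.Size([2, 576, 2048])
```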
janus/models/siglip_vit.py ADDED
@@ -0,0 +1,681 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import (
26
+ Callable,
27
+ Dict,
28
+ Final,
29
+ List,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from timm.layers import (
43
+ AttentionPoolLatent,
44
+ DropPath,
45
+ LayerType,
46
+ Mlp,
47
+ PatchDropout,
48
+ PatchEmbed,
49
+ resample_abs_pos_embed,
50
+ )
51
+ from timm.models._manipulate import checkpoint_seq, named_apply
52
+
53
+
54
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
55
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
56
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
57
+ def norm_cdf(x):
58
+ # Computes standard normal cumulative distribution function
59
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
60
+
61
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
62
+ warnings.warn(
63
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
64
+ "The distribution of values may be incorrect.",
65
+ stacklevel=2,
66
+ )
67
+
68
+ with torch.no_grad():
69
+ # Values are generated by using a truncated uniform distribution and
70
+ # then using the inverse CDF for the normal distribution.
71
+ # Get upper and lower cdf values
72
+ l = norm_cdf((a - mean) / std) # noqa: E741
73
+ u = norm_cdf((b - mean) / std)
74
+
75
+ # Uniformly fill tensor with values from [l, u], then translate to
76
+ # [2l-1, 2u-1].
77
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
78
+
79
+ # Use inverse cdf transform for normal distribution to get truncated
80
+ # standard normal
81
+ tensor.erfinv_()
82
+
83
+ # Transform to proper mean, std
84
+ tensor.mul_(std * math.sqrt(2.0))
85
+ tensor.add_(mean)
86
+
87
+ # Clamp to ensure it's in the proper range
88
+ tensor.clamp_(min=a, max=b)
89
+ return tensor
90
+
91
+
92
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
93
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
94
+ r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
95
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
96
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
97
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98
+ with values outside :math:`[a, b]` redrawn until they are within
99
+ the bounds. The method used for generating the random values works
100
+ best when :math:`a \leq \text{mean} \leq b`.
101
+ Args:
102
+ tensor: an n-dimensional `torch.Tensor`
103
+ mean: the mean of the normal distribution
104
+ std: the standard deviation of the normal distribution
105
+ a: the minimum cutoff value
106
+ b: the maximum cutoff value
107
+ Examples:
108
+ >>> w = torch.empty(3, 5)
109
+ >>> nn.init.trunc_normal_(w)
110
+ """
111
+
112
+ with torch.no_grad():
113
+ dtype = tensor.dtype
114
+ tensor_fp32 = tensor.float()
115
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
116
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
117
+ tensor.copy_(tensor_dtype)
118
+
119
+
120
+ def init_weights(self):
121
+ if self.pos_embed is not None:
122
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
123
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
124
+
125
+
126
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
127
+ """ViT weight initialization, original timm impl (for reproducibility)"""
128
+ if isinstance(module, nn.Linear):
129
+ trunc_normal_(module.weight, std=0.02)
130
+ if module.bias is not None:
131
+ nn.init.zeros_(module.bias)
132
+ elif hasattr(module, "init_weights"):
133
+ module.init_weights()
134
+
135
+
136
+ class Attention(nn.Module):
137
+ fused_attn: Final[bool]
138
+
139
+ def __init__(
140
+ self,
141
+ dim: int,
142
+ num_heads: int = 8,
143
+ qkv_bias: bool = False,
144
+ qk_norm: bool = False,
145
+ attn_drop: float = 0.0,
146
+ proj_drop: float = 0.0,
147
+ norm_layer: nn.Module = nn.LayerNorm,
148
+ ) -> None:
149
+ super().__init__()
150
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
151
+ self.num_heads = num_heads
152
+ self.head_dim = dim // num_heads
153
+ self.scale = self.head_dim**-0.5
154
+ # self.fused_attn = use_fused_attn()
155
+ self.fused_attn = True
156
+
157
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
158
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
159
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
160
+ self.attn_drop = nn.Dropout(attn_drop)
161
+ self.proj = nn.Linear(dim, dim)
162
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ B, N, C = x.shape
166
+ qkv = (
167
+ self.qkv(x)
168
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
169
+ .permute(2, 0, 3, 1, 4)
170
+ )
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = q @ k.transpose(-2, -1)
184
+ attn = attn.softmax(dim=-1)
185
+ attn = self.attn_drop(attn)
186
+ x = attn @ v
187
+
188
+ x = x.transpose(1, 2).reshape(B, N, C)
189
+ x = self.proj(x)
190
+ x = self.proj_drop(x)
191
+ return x
192
+
193
+
194
+ class LayerScale(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ init_values: float = 1e-5,
199
+ inplace: bool = False,
200
+ ) -> None:
201
+ super().__init__()
202
+ self.inplace = inplace
203
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
204
+
205
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
206
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
207
+
208
+
209
+ class Block(nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ num_heads: int,
214
+ mlp_ratio: float = 4.0,
215
+ qkv_bias: bool = False,
216
+ qk_norm: bool = False,
217
+ proj_drop: float = 0.0,
218
+ attn_drop: float = 0.0,
219
+ init_values: Optional[float] = None,
220
+ drop_path: float = 0.0,
221
+ act_layer: nn.Module = nn.GELU,
222
+ norm_layer: nn.Module = nn.LayerNorm,
223
+ mlp_layer: nn.Module = Mlp,
224
+ ) -> None:
225
+ super().__init__()
226
+ self.norm1 = norm_layer(dim)
227
+ self.attn = Attention(
228
+ dim,
229
+ num_heads=num_heads,
230
+ qkv_bias=qkv_bias,
231
+ qk_norm=qk_norm,
232
+ attn_drop=attn_drop,
233
+ proj_drop=proj_drop,
234
+ norm_layer=norm_layer,
235
+ )
236
+ self.ls1 = (
237
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
238
+ )
239
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
240
+
241
+ self.norm2 = norm_layer(dim)
242
+ self.mlp = mlp_layer(
243
+ in_features=dim,
244
+ hidden_features=int(dim * mlp_ratio),
245
+ act_layer=act_layer,
246
+ drop=proj_drop,
247
+ )
248
+ self.ls2 = (
249
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
250
+ )
251
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
255
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
256
+ return x
257
+
258
+
259
+ class VisionTransformer(nn.Module):
260
+ """Vision Transformer
261
+
262
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
263
+ - https://arxiv.org/abs/2010.11929
264
+ """
265
+
266
+ dynamic_img_size: Final[bool]
267
+
268
+ def __init__(
269
+ self,
270
+ img_size: Union[int, Tuple[int, int]] = 224,
271
+ patch_size: Union[int, Tuple[int, int]] = 16,
272
+ in_chans: int = 3,
273
+ num_classes: int = 1000,
274
+ global_pool: Literal["", "avg", "token", "map"] = "token",
275
+ embed_dim: int = 768,
276
+ depth: int = 12,
277
+ num_heads: int = 12,
278
+ mlp_ratio: float = 4.0,
279
+ qkv_bias: bool = True,
280
+ qk_norm: bool = False,
281
+ init_values: Optional[float] = None,
282
+ class_token: bool = True,
283
+ no_embed_class: bool = False,
284
+ reg_tokens: int = 0,
285
+ pre_norm: bool = False,
286
+ fc_norm: Optional[bool] = None,
287
+ dynamic_img_size: bool = False,
288
+ dynamic_img_pad: bool = False,
289
+ drop_rate: float = 0.0,
290
+ pos_drop_rate: float = 0.0,
291
+ patch_drop_rate: float = 0.0,
292
+ proj_drop_rate: float = 0.0,
293
+ attn_drop_rate: float = 0.0,
294
+ drop_path_rate: float = 0.0,
295
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
296
+ embed_layer: Callable = PatchEmbed,
297
+ norm_layer: Optional[LayerType] = None,
298
+ act_layer: Optional[LayerType] = None,
299
+ block_fn: Type[nn.Module] = Block,
300
+ mlp_layer: Type[nn.Module] = Mlp,
301
+ ignore_head: bool = False,
302
+ ) -> None:
303
+ """
304
+ Args:
305
+ img_size: Input image size.
306
+ patch_size: Patch size.
307
+ in_chans: Number of image input channels.
308
+ num_classes: Number of classes for classification head.
309
+ global_pool: Type of global pooling for final sequence (default: 'token').
310
+ embed_dim: Transformer embedding dimension.
311
+ depth: Depth of transformer.
312
+ num_heads: Number of attention heads.
313
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
314
+ qkv_bias: Enable bias for qkv projections if True.
315
+ init_values: Layer-scale init values (layer-scale enabled if not None).
316
+ class_token: Use class token.
317
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
318
+ reg_tokens: Number of register tokens.
319
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
320
+ drop_rate: Head dropout rate.
321
+ pos_drop_rate: Position embedding dropout rate.
322
+ attn_drop_rate: Attention dropout rate.
323
+ drop_path_rate: Stochastic depth rate.
324
+ weight_init: Weight initialization scheme.
325
+ embed_layer: Patch embedding layer.
326
+ norm_layer: Normalization layer.
327
+ act_layer: MLP activation layer.
328
+ block_fn: Transformer block layer.
329
+ """
330
+ super().__init__()
331
+ assert global_pool in ("", "avg", "token", "map")
332
+ assert class_token or global_pool != "token"
333
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
334
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
335
+ # act_layer = get_act_layer(act_layer) or nn.GELU
336
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
337
+ act_layer = nn.GELU
338
+
339
+ self.num_classes = num_classes
340
+ self.global_pool = global_pool
341
+ self.num_features = self.embed_dim = (
342
+ embed_dim # num_features for consistency with other models
343
+ )
344
+ self.num_prefix_tokens = 1 if class_token else 0
345
+ self.num_prefix_tokens += reg_tokens
346
+ self.num_reg_tokens = reg_tokens
347
+ self.has_class_token = class_token
348
+ self.no_embed_class = (
349
+ no_embed_class # don't embed prefix positions (includes reg)
350
+ )
351
+ self.dynamic_img_size = dynamic_img_size
352
+ self.grad_checkpointing = False
353
+ self.ignore_head = ignore_head
354
+
355
+ embed_args = {}
356
+ if dynamic_img_size:
357
+ # flatten deferred until after pos embed
358
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
359
+ self.patch_embed = embed_layer(
360
+ img_size=img_size,
361
+ patch_size=patch_size,
362
+ in_chans=in_chans,
363
+ embed_dim=embed_dim,
364
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
365
+ dynamic_img_pad=dynamic_img_pad,
366
+ **embed_args,
367
+ )
368
+ num_patches = self.patch_embed.num_patches
369
+
370
+ self.cls_token = (
371
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
372
+ )
373
+ self.reg_token = (
374
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
375
+ )
376
+ embed_len = (
377
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
378
+ )
379
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
380
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
381
+ if patch_drop_rate > 0:
382
+ self.patch_drop = PatchDropout(
383
+ patch_drop_rate,
384
+ num_prefix_tokens=self.num_prefix_tokens,
385
+ )
386
+ else:
387
+ self.patch_drop = nn.Identity()
388
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
389
+
390
+ dpr = [
391
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
392
+ ] # stochastic depth decay rule
393
+ self.blocks = nn.Sequential(
394
+ *[
395
+ block_fn(
396
+ dim=embed_dim,
397
+ num_heads=num_heads,
398
+ mlp_ratio=mlp_ratio,
399
+ qkv_bias=qkv_bias,
400
+ qk_norm=qk_norm,
401
+ init_values=init_values,
402
+ proj_drop=proj_drop_rate,
403
+ attn_drop=attn_drop_rate,
404
+ drop_path=dpr[i],
405
+ norm_layer=norm_layer,
406
+ act_layer=act_layer,
407
+ mlp_layer=mlp_layer,
408
+ )
409
+ for i in range(depth)
410
+ ]
411
+ )
412
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
413
+
414
+ # Classifier Head
415
+ if global_pool == "map":
416
+ AttentionPoolLatent.init_weights = init_weights
417
+ self.attn_pool = AttentionPoolLatent(
418
+ self.embed_dim,
419
+ num_heads=num_heads,
420
+ mlp_ratio=mlp_ratio,
421
+ norm_layer=norm_layer,
422
+ )
423
+ else:
424
+ self.attn_pool = None
425
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
426
+ self.head_drop = nn.Dropout(drop_rate)
427
+ self.head = (
428
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
429
+ )
430
+
431
+ if weight_init != "skip":
432
+ self.init_weights(weight_init)
433
+
434
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
435
+ assert mode in ("jax", "jax_nlhb", "moco", "")
436
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
437
+ trunc_normal_(self.pos_embed, std=0.02)
438
+ if self.cls_token is not None:
439
+ nn.init.normal_(self.cls_token, std=1e-6)
440
+ named_apply(init_weights_vit_timm, self)
441
+
442
+ @torch.jit.ignore
443
+ def no_weight_decay(self) -> Set:
444
+ return {"pos_embed", "cls_token", "dist_token"}
445
+
446
+ @torch.jit.ignore
447
+ def group_matcher(self, coarse: bool = False) -> Dict:
448
+ return dict(
449
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
450
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
451
+ )
452
+
453
+ @torch.jit.ignore
454
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
455
+ self.grad_checkpointing = enable
456
+
457
+ @torch.jit.ignore
458
+ def get_classifier(self) -> nn.Module:
459
+ return self.head
460
+
461
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
462
+ self.num_classes = num_classes
463
+ if global_pool is not None:
464
+ assert global_pool in ("", "avg", "token", "map")
465
+ if global_pool == "map" and self.attn_pool is None:
466
+ assert (
467
+ False
468
+ ), "Cannot currently add attention pooling in reset_classifier()."
469
+ elif global_pool != "map " and self.attn_pool is not None:
470
+ self.attn_pool = None # remove attention pooling
471
+ self.global_pool = global_pool
472
+ self.head = (
473
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
474
+ )
475
+
476
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477
+ if self.dynamic_img_size:
478
+ B, H, W, C = x.shape
479
+ pos_embed = resample_abs_pos_embed(
480
+ self.pos_embed,
481
+ (H, W),
482
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483
+ )
484
+ x = x.view(B, -1, C)
485
+ else:
486
+ pos_embed = self.pos_embed
487
+
488
+ to_cat = []
489
+ if self.cls_token is not None:
490
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491
+ if self.reg_token is not None:
492
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493
+
494
+ if self.no_embed_class:
495
+ # deit-3, updated JAX (big vision)
496
+ # position embedding does not overlap with class token, add then concat
497
+ x = x + pos_embed
498
+ if to_cat:
499
+ x = torch.cat(to_cat + [x], dim=1)
500
+ else:
501
+ # original timm, JAX, and deit vit impl
502
+ # pos_embed has entry for class token, concat then add
503
+ if to_cat:
504
+ x = torch.cat(to_cat + [x], dim=1)
505
+ x = x + pos_embed
506
+
507
+ return self.pos_drop(x)
508
+
509
+ def _intermediate_layers(
510
+ self,
511
+ x: torch.Tensor,
512
+ n: Union[int, Sequence] = 1,
513
+ ) -> List[torch.Tensor]:
514
+ outputs, num_blocks = [], len(self.blocks)
515
+ take_indices = set(
516
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517
+ )
518
+
519
+ # forward pass
520
+ x = self.patch_embed(x)
521
+ x = self._pos_embed(x)
522
+ x = self.patch_drop(x)
523
+ x = self.norm_pre(x)
524
+ for i, blk in enumerate(self.blocks):
525
+ x = blk(x)
526
+ if i in take_indices:
527
+ outputs.append(x)
528
+
529
+ return outputs
530
+
531
+ def get_intermediate_layers(
532
+ self,
533
+ x: torch.Tensor,
534
+ n: Union[int, Sequence] = 1,
535
+ reshape: bool = False,
536
+ return_prefix_tokens: bool = False,
537
+ norm: bool = False,
538
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
540
+ Inspired by DINO / DINOv2 interface
541
+ """
542
+ # take last n blocks if n is an int; if n is a sequence, select by matching indices
543
+ outputs = self._intermediate_layers(x, n)
544
+ if norm:
545
+ outputs = [self.norm(out) for out in outputs]
546
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548
+
549
+ if reshape:
550
+ grid_size = self.patch_embed.grid_size
551
+ outputs = [
552
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553
+ .permute(0, 3, 1, 2)
554
+ .contiguous()
555
+ for out in outputs
556
+ ]
557
+
558
+ if return_prefix_tokens:
559
+ return tuple(zip(outputs, prefix_tokens))
560
+ return tuple(outputs)
561
+
562
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
563
+ x = self.patch_embed(x)
564
+ x = self._pos_embed(x)
565
+ x = self.patch_drop(x)
566
+ x = self.norm_pre(x)
567
+ if self.grad_checkpointing and not torch.jit.is_scripting():
568
+ x = checkpoint_seq(self.blocks, x)
569
+ else:
570
+ x = self.blocks(x)
571
+ x = self.norm(x)
572
+ return x
573
+
574
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
575
+ if self.attn_pool is not None:
576
+ x = self.attn_pool(x)
577
+ elif self.global_pool == "avg":
578
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
579
+ elif self.global_pool:
580
+ x = x[:, 0] # class token
581
+ x = self.fc_norm(x)
582
+ x = self.head_drop(x)
583
+ return x if pre_logits else self.head(x)
584
+
585
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
586
+ x = self.forward_features(x)
587
+ if not self.ignore_head:
588
+ x = self.forward_head(x)
589
+ return x
590
+
591
+
592
+ @dataclass
593
+ class SigLIPVisionCfg:
594
+ width: int = 1152
595
+ layers: Union[Tuple[int, int, int, int], int] = 27
596
+ heads: int = 16
597
+ patch_size: int = 14
598
+ image_size: Union[Tuple[int, int], int] = 336
599
+ global_pool: str = "map"
600
+ mlp_ratio: float = 3.7362
601
+ class_token: bool = False
602
+ num_classes: int = 0
603
+ use_checkpoint: bool = False
604
+
605
+
606
+ SigLIP_MODEL_CONFIG = {
607
+ "siglip_so400m_patch14_384": {
608
+ "image_size": 336,
609
+ "patch_size": 14,
610
+ "width": 1152,
611
+ "layers": 27,
612
+ "heads": 16,
613
+ "mlp_ratio": 3.7362,
614
+ "global_pool": "map",
615
+ "use_checkpoint": False,
616
+ },
617
+ "siglip_so400m_patch14_224": {
618
+ "image_size": 224,
619
+ "patch_size": 14,
620
+ "width": 1152,
621
+ "layers": 27,
622
+ "heads": 16,
623
+ "mlp_ratio": 3.7362,
624
+ "global_pool": "map",
625
+ "use_checkpoint": False,
626
+ },
627
+ "siglip_large_patch16_384": {
628
+ "image_size": 384,
629
+ "patch_size": 16,
630
+ "width": 1024,
631
+ "layers": 24,
632
+ "heads": 16,
633
+ "mlp_ratio": 4,
634
+ "global_pool": "map",
635
+ "use_checkpoint": False,
636
+ },
637
+ }
638
+
639
+
640
+ def create_siglip_vit(
641
+ model_name: str = "siglip_so400m_patch14_384",
642
+ image_size: int = 384,
643
+ select_layer: int = -1,
644
+ ckpt_path: str = "",
645
+ **kwargs,
646
+ ):
647
+ assert (
648
+ model_name in SigLIP_MODEL_CONFIG.keys()
649
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
650
+
651
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
652
+
653
+ if select_layer <= 0:
654
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
655
+ else:
656
+ layers = min(vision_cfg.layers, select_layer)
657
+
658
+ model = VisionTransformer(
659
+ img_size=image_size,
660
+ patch_size=vision_cfg.patch_size,
661
+ embed_dim=vision_cfg.width,
662
+ depth=layers,
663
+ num_heads=vision_cfg.heads,
664
+ mlp_ratio=vision_cfg.mlp_ratio,
665
+ class_token=vision_cfg.class_token,
666
+ global_pool=vision_cfg.global_pool,
667
+ ignore_head=kwargs.get("ignore_head", True),
668
+ weight_init=kwargs.get("weight_init", "skip"),
669
+ num_classes=0,
670
+ )
671
+
672
+ if ckpt_path:
673
+ state_dict = torch.load(ckpt_path, map_location="cpu")
674
+
675
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
676
+ print(
677
+ f"SigLIP-ViT restores from {ckpt_path},\n"
678
+ f"\tincompatible_keys:', {incompatible_keys}."
679
+ )
680
+
681
+ return model
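A quick shape check of `create_siglip_vit` with randomly initialized weights (no `ckpt_path`); the 384-pixel input is just one possible configuration, and in Janus the returned patch features are what `CLIPVisionTower` hands on to the aligner.

```python
import torch
from janus.models.siglip_vit import create_siglip_vit

# Build the SigLIP-SO400M trunk; ignore_head defaults to True here, so the
# forward pass returns patch token features rather than pooled logits.
model = create_siglip_vit("siglip_so400m_patch14_384", image_size=384).eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 384, 384))

# With patch_size=14 and a 384x384 input this gives a 27x27 grid, i.e.
# 729 tokens of width 1152; the exact count depends on image/patch size.
print(feats.shape)  # torch.Size([1, 729, 1152])
```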
janus/models/vq_model.py ADDED
@@ -0,0 +1,527 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ from dataclasses import dataclass, field
22
+ from typing import List
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from functools import partial
29
+
30
+
31
+ @dataclass
32
+ class ModelArgs:
33
+ codebook_size: int = 16384
34
+ codebook_embed_dim: int = 8
35
+ codebook_l2_norm: bool = True
36
+ codebook_show_usage: bool = True
37
+ commit_loss_beta: float = 0.25
38
+ entropy_loss_ratio: float = 0.0
39
+
40
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
41
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
42
+ z_channels: int = 256
43
+ dropout_p: float = 0.0
44
+
45
+
46
+ class Encoder(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_channels=3,
50
+ ch=128,
51
+ ch_mult=(1, 1, 2, 2, 4),
52
+ num_res_blocks=2,
53
+ norm_type="group",
54
+ dropout=0.0,
55
+ resamp_with_conv=True,
56
+ z_channels=256,
57
+ ):
58
+ super().__init__()
59
+ self.num_resolutions = len(ch_mult)
60
+ self.num_res_blocks = num_res_blocks
61
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
62
+
63
+ # downsampling
64
+ in_ch_mult = (1,) + tuple(ch_mult)
65
+ self.conv_blocks = nn.ModuleList()
66
+ for i_level in range(self.num_resolutions):
67
+ conv_block = nn.Module()
68
+ # res & attn
69
+ res_block = nn.ModuleList()
70
+ attn_block = nn.ModuleList()
71
+ block_in = ch * in_ch_mult[i_level]
72
+ block_out = ch * ch_mult[i_level]
73
+ for _ in range(self.num_res_blocks):
74
+ res_block.append(
75
+ ResnetBlock(
76
+ block_in, block_out, dropout=dropout, norm_type=norm_type
77
+ )
78
+ )
79
+ block_in = block_out
80
+ if i_level == self.num_resolutions - 1:
81
+ attn_block.append(AttnBlock(block_in, norm_type))
82
+ conv_block.res = res_block
83
+ conv_block.attn = attn_block
84
+ # downsample
85
+ if i_level != self.num_resolutions - 1:
86
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
87
+ self.conv_blocks.append(conv_block)
88
+
89
+ # middle
90
+ self.mid = nn.ModuleList()
91
+ self.mid.append(
92
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
93
+ )
94
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
95
+ self.mid.append(
96
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
97
+ )
98
+
99
+ # end
100
+ self.norm_out = Normalize(block_in, norm_type)
101
+ self.conv_out = nn.Conv2d(
102
+ block_in, z_channels, kernel_size=3, stride=1, padding=1
103
+ )
104
+
105
+ def forward(self, x):
106
+ h = self.conv_in(x)
107
+ # downsampling
108
+ for i_level, block in enumerate(self.conv_blocks):
109
+ for i_block in range(self.num_res_blocks):
110
+ h = block.res[i_block](h)
111
+ if len(block.attn) > 0:
112
+ h = block.attn[i_block](h)
113
+ if i_level != self.num_resolutions - 1:
114
+ h = block.downsample(h)
115
+
116
+ # middle
117
+ for mid_block in self.mid:
118
+ h = mid_block(h)
119
+
120
+ # end
121
+ h = self.norm_out(h)
122
+ h = nonlinearity(h)
123
+ h = self.conv_out(h)
124
+ return h
125
+
126
+
127
+ class Decoder(nn.Module):
128
+ def __init__(
129
+ self,
130
+ z_channels=256,
131
+ ch=128,
132
+ ch_mult=(1, 1, 2, 2, 4),
133
+ num_res_blocks=2,
134
+ norm_type="group",
135
+ dropout=0.0,
136
+ resamp_with_conv=True,
137
+ out_channels=3,
138
+ ):
139
+ super().__init__()
140
+ self.num_resolutions = len(ch_mult)
141
+ self.num_res_blocks = num_res_blocks
142
+
143
+ block_in = ch * ch_mult[self.num_resolutions - 1]
144
+ # z to block_in
145
+ self.conv_in = nn.Conv2d(
146
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
147
+ )
148
+
149
+ # middle
150
+ self.mid = nn.ModuleList()
151
+ self.mid.append(
152
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
153
+ )
154
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
155
+ self.mid.append(
156
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
157
+ )
158
+
159
+ # upsampling
160
+ self.conv_blocks = nn.ModuleList()
161
+ for i_level in reversed(range(self.num_resolutions)):
162
+ conv_block = nn.Module()
163
+ # res & attn
164
+ res_block = nn.ModuleList()
165
+ attn_block = nn.ModuleList()
166
+ block_out = ch * ch_mult[i_level]
167
+ for _ in range(self.num_res_blocks + 1):
168
+ res_block.append(
169
+ ResnetBlock(
170
+ block_in, block_out, dropout=dropout, norm_type=norm_type
171
+ )
172
+ )
173
+ block_in = block_out
174
+ if i_level == self.num_resolutions - 1:
175
+ attn_block.append(AttnBlock(block_in, norm_type))
176
+ conv_block.res = res_block
177
+ conv_block.attn = attn_block
178
+ # upsample
179
+ if i_level != 0:
180
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
181
+ self.conv_blocks.append(conv_block)
182
+
183
+ # end
184
+ self.norm_out = Normalize(block_in, norm_type)
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ @property
190
+ def last_layer(self):
191
+ return self.conv_out.weight
192
+
193
+ def forward(self, z):
194
+ # z to block_in
195
+ h = self.conv_in(z)
196
+
197
+ # middle
198
+ for mid_block in self.mid:
199
+ h = mid_block(h)
200
+
201
+ # upsampling
202
+ for i_level, block in enumerate(self.conv_blocks):
203
+ for i_block in range(self.num_res_blocks + 1):
204
+ h = block.res[i_block](h)
205
+ if len(block.attn) > 0:
206
+ h = block.attn[i_block](h)
207
+ if i_level != self.num_resolutions - 1:
208
+ h = block.upsample(h)
209
+
210
+ # end
211
+ h = self.norm_out(h)
212
+ h = nonlinearity(h)
213
+ h = self.conv_out(h)
214
+ return h
215
+
216
+
217
+ class VectorQuantizer(nn.Module):
218
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
219
+ super().__init__()
220
+ self.n_e = n_e
221
+ self.e_dim = e_dim
222
+ self.beta = beta
223
+ self.entropy_loss_ratio = entropy_loss_ratio
224
+ self.l2_norm = l2_norm
225
+ self.show_usage = show_usage
226
+
227
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
228
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
229
+ if self.l2_norm:
230
+ self.embedding.weight.data = F.normalize(
231
+ self.embedding.weight.data, p=2, dim=-1
232
+ )
233
+ if self.show_usage:
234
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
235
+
236
+ def forward(self, z):
237
+ # reshape z -> (batch, height, width, channel) and flatten
238
+ z = torch.einsum("b c h w -> b h w c", z).contiguous()
239
+ z_flattened = z.view(-1, self.e_dim)
240
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
241
+
242
+ if self.l2_norm:
243
+ z = F.normalize(z, p=2, dim=-1)
244
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
245
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
246
+ else:
247
+ embedding = self.embedding.weight
248
+
249
+ d = (
250
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
251
+ + torch.sum(embedding**2, dim=1)
252
+ - 2
253
+ * torch.einsum(
254
+ "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
255
+ )
256
+ )
257
+
258
+ min_encoding_indices = torch.argmin(d, dim=1)
259
+ z_q = embedding[min_encoding_indices].view(z.shape)
260
+ perplexity = None
261
+ min_encodings = None
262
+ vq_loss = None
263
+ commit_loss = None
264
+ entropy_loss = None
265
+
266
+ # compute loss for embedding
267
+ if self.training:
268
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
269
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
270
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
271
+
272
+ # preserve gradients
273
+ z_q = z + (z_q - z).detach()
274
+
275
+ # reshape back to match original input shape
276
+ z_q = torch.einsum("b h w c -> b c h w", z_q)
277
+
278
+ return (
279
+ z_q,
280
+ (vq_loss, commit_loss, entropy_loss),
281
+ (perplexity, min_encodings, min_encoding_indices),
282
+ )
283
+
284
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
285
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
286
+ if self.l2_norm:
287
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
288
+ else:
289
+ embedding = self.embedding.weight
290
+ z_q = embedding[indices] # (b*h*w, c)
291
+
292
+ if shape is not None:
293
+ if channel_first:
294
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
295
+ # reshape back to match original input shape
296
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
297
+ else:
298
+ z_q = z_q.view(shape)
299
+ return z_q
300
+
301
+
302
+ class ResnetBlock(nn.Module):
303
+ def __init__(
304
+ self,
305
+ in_channels,
306
+ out_channels=None,
307
+ conv_shortcut=False,
308
+ dropout=0.0,
309
+ norm_type="group",
310
+ ):
311
+ super().__init__()
312
+ self.in_channels = in_channels
313
+ out_channels = in_channels if out_channels is None else out_channels
314
+ self.out_channels = out_channels
315
+ self.use_conv_shortcut = conv_shortcut
316
+
317
+ self.norm1 = Normalize(in_channels, norm_type)
318
+ self.conv1 = nn.Conv2d(
319
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
320
+ )
321
+ self.norm2 = Normalize(out_channels, norm_type)
322
+ self.dropout = nn.Dropout(dropout)
323
+ self.conv2 = nn.Conv2d(
324
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
325
+ )
326
+
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+ h = self.norm2(h)
343
+ h = nonlinearity(h)
344
+ h = self.dropout(h)
345
+ h = self.conv2(h)
346
+
347
+ if self.in_channels != self.out_channels:
348
+ if self.use_conv_shortcut:
349
+ x = self.conv_shortcut(x)
350
+ else:
351
+ x = self.nin_shortcut(x)
352
+ return x + h
353
+
354
+
355
+ class AttnBlock(nn.Module):
356
+ def __init__(self, in_channels, norm_type="group"):
357
+ super().__init__()
358
+ self.norm = Normalize(in_channels, norm_type)
359
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
360
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
361
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
362
+ self.proj_out = nn.Conv2d(
363
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
364
+ )
365
+
366
+ def forward(self, x):
367
+ h_ = x
368
+ h_ = self.norm(h_)
369
+ q = self.q(h_)
370
+ k = self.k(h_)
371
+ v = self.v(h_)
372
+
373
+ # compute attention
374
+ b, c, h, w = q.shape
375
+ q = q.reshape(b, c, h * w)
376
+ q = q.permute(0, 2, 1) # b,hw,c
377
+ k = k.reshape(b, c, h * w) # b,c,hw
378
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
379
+ w_ = w_ * (int(c) ** (-0.5))
380
+ w_ = F.softmax(w_, dim=2)
381
+
382
+ # attend to values
383
+ v = v.reshape(b, c, h * w)
384
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
385
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
386
+ h_ = h_.reshape(b, c, h, w)
387
+
388
+ h_ = self.proj_out(h_)
389
+
390
+ return x + h_
391
+
392
+
393
+ def nonlinearity(x):
394
+ # swish
395
+ return x * torch.sigmoid(x)
396
+
397
+
398
+ def Normalize(in_channels, norm_type="group"):
399
+ assert norm_type in ["group", "batch"]
400
+ if norm_type == "group":
401
+ return nn.GroupNorm(
402
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
403
+ )
404
+ elif norm_type == "batch":
405
+ return nn.SyncBatchNorm(in_channels)
406
+
407
+
408
+ class Upsample(nn.Module):
409
+ def __init__(self, in_channels, with_conv):
410
+ super().__init__()
411
+ self.with_conv = with_conv
412
+ if self.with_conv:
413
+ self.conv = nn.Conv2d(
414
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
415
+ )
416
+
417
+ def forward(self, x):
418
+ if x.dtype != torch.float32:
419
+ x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
420
+ torch.bfloat16
421
+ )
422
+ else:
423
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
424
+
425
+ if self.with_conv:
426
+ x = self.conv(x)
427
+ return x
428
+
429
+
430
+ class Downsample(nn.Module):
431
+ def __init__(self, in_channels, with_conv):
432
+ super().__init__()
433
+ self.with_conv = with_conv
434
+ if self.with_conv:
435
+ # no asymmetric padding in torch conv, must do it ourselves
436
+ self.conv = nn.Conv2d(
437
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
438
+ )
439
+
440
+ def forward(self, x):
441
+ if self.with_conv:
442
+ pad = (0, 1, 0, 1)
443
+ x = F.pad(x, pad, mode="constant", value=0)
444
+ x = self.conv(x)
445
+ else:
446
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
447
+ return x
448
+
449
+
450
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
451
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
452
+ flat_affinity /= temperature
453
+ probs = F.softmax(flat_affinity, dim=-1)
454
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
455
+ if loss_type == "softmax":
456
+ target_probs = probs
457
+ else:
458
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
459
+ avg_probs = torch.mean(target_probs, dim=0)
460
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
461
+ sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
462
+ loss = sample_entropy - avg_entropy
463
+ return loss
464
+
465
+
466
+ class VQModel(nn.Module):
467
+ def __init__(self, config: ModelArgs):
468
+ super().__init__()
469
+ self.config = config
470
+ self.encoder = Encoder(
471
+ ch_mult=config.encoder_ch_mult,
472
+ z_channels=config.z_channels,
473
+ dropout=config.dropout_p,
474
+ )
475
+ self.decoder = Decoder(
476
+ ch_mult=config.decoder_ch_mult,
477
+ z_channels=config.z_channels,
478
+ dropout=config.dropout_p,
479
+ )
480
+
481
+ self.quantize = VectorQuantizer(
482
+ config.codebook_size,
483
+ config.codebook_embed_dim,
484
+ config.commit_loss_beta,
485
+ config.entropy_loss_ratio,
486
+ config.codebook_l2_norm,
487
+ config.codebook_show_usage,
488
+ )
489
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
490
+ self.post_quant_conv = nn.Conv2d(
491
+ config.codebook_embed_dim, config.z_channels, 1
492
+ )
493
+
494
+ def encode(self, x):
495
+ h = self.encoder(x)
496
+ h = self.quant_conv(h)
497
+ quant, emb_loss, info = self.quantize(h)
498
+ return quant, emb_loss, info
499
+
500
+ def decode(self, quant):
501
+ quant = self.post_quant_conv(quant)
502
+ dec = self.decoder(quant)
503
+ return dec
504
+
505
+ def decode_code(self, code_b, shape=None, channel_first=True):
506
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
507
+ dec = self.decode(quant_b)
508
+ return dec
509
+
510
+ def forward(self, input):
511
+ quant, diff, _ = self.encode(input)
512
+ dec = self.decode(quant)
513
+ return dec, diff
514
+
515
+
516
+ #################################################################################
517
+ # VQ Model Configs #
518
+ #################################################################################
519
+ def VQ_16(**kwargs):
520
+ return VQModel(
521
+ ModelArgs(
522
+ encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
523
+ )
524
+ )
525
+
526
+
527
+ VQ_models = {"VQ-16": VQ_16}
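For orientation, a minimal usage sketch of the VQ-16 tokenizer assembled above (illustrative only, not part of the committed file; it assumes the ModelArgs defaults declared earlier in janus/models/vq_model.py and does not load trained weights):

import torch
from janus.models.vq_model import VQ_models

vq = VQ_models["VQ-16"]().eval()   # VQ_16() wraps VQModel with ch_mult [1, 1, 2, 2, 4]
x = torch.randn(1, 3, 384, 384)    # dummy normalized RGB batch (hypothetical size)
with torch.no_grad():
    quant, _, (_, _, indices) = vq.encode(x)             # quantized latents and code indices
    recon = vq.decode_code(indices, shape=quant.shape)   # map code indices back to pixel space
print(quant.shape, indices.shape, recon.shape)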
janus/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
janus/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (174 Bytes)
janus/utils/__pycache__/conversation.cpython-38.pyc ADDED
Binary file (7.5 kB)
janus/utils/__pycache__/io.cpython-38.pyc ADDED
Binary file (2.06 kB)
janus/utils/conversation.py ADDED
@@ -0,0 +1,365 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ """
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+ """
+
+ import dataclasses
+ from enum import IntEnum, auto
+ from typing import Dict, List
+
+
+ class SeparatorStyle(IntEnum):
+     """Separator styles."""
+
+     ADD_COLON_SINGLE = auto()
+     ADD_COLON_TWO = auto()
+     ADD_COLON_SPACE_SINGLE = auto()
+     NO_COLON_SINGLE = auto()
+     NO_COLON_TWO = auto()
+     ADD_NEW_LINE_SINGLE = auto()
+     LLAMA2 = auto()
+     CHATGLM = auto()
+     CHATML = auto()
+     CHATINTERN = auto()
+     DOLLY = auto()
+     RWKV = auto()
+     PHOENIX = auto()
+     ROBIN = auto()
+     DeepSeek = auto()
+     PLAIN = auto()
+     ALIGNMENT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that manages prompt templates and keeps all conversation history."""
+
+     # The name of this template
+     name: str
+     # The template of the system prompt
+     system_template: str = "{system_message}"
+     # The system message
+     system_message: str = ""
+     # The names of two roles
+     roles: List[str] = (("USER", "ASSISTANT"),)
+     # All messages. Each item is (role, message).
+     messages: List[List[str]] = ()
+     # The number of few shot examples
+     offset: int = 0
+     # The separator style and configurations
+     sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+     sep: str = "\n"
+     sep2: str = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: str = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: List[int] = None
+
+     def get_prompt(self) -> str:
+         """Get the prompt for generation."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+
+         if self.sep_style == SeparatorStyle.DeepSeek:
+             seps = [self.sep, self.sep2]
+             if system_prompt == "" or system_prompt is None:
+                 ret = ""
+             else:
+                 ret = system_prompt + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.LLAMA2:
+             seps = [self.sep, self.sep2]
+             if self.system_message:
+                 ret = system_prompt
+             else:
+                 ret = "[INST] "
+             for i, (role, message) in enumerate(self.messages):
+                 tag = self.roles[i % 2]
+                 if message:
+                     if type(message) is tuple:  # multimodal message
+                         message, _ = message
+                     if i == 0:
+                         ret += message + " "
+                     else:
+                         ret += tag + " " + message + seps[i % 2]
+                 else:
+                     ret += tag
+             return ret
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += message + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         elif self.sep_style == SeparatorStyle.ALIGNMENT:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += "<image>\n" + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def get_prompt_for_current_round(self, content=None):
+         """Get current round formatted question prompt during sft training"""
+         if self.sep_style == SeparatorStyle.PLAIN:
+             formatted_question = "<image>\n"
+         elif self.sep_style == SeparatorStyle.DeepSeek:
+             formatted_question = (
+                 f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
+             )
+         else:
+             raise ValueError(f"Unsupported sep_style: {self.sep_style}")
+         return formatted_question
+
+     def set_system_message(self, system_message: str):
+         """Set the system message."""
+         self.system_message = system_message
+
+     def append_message(self, role: str, message: str):
+         """Append a new message."""
+         self.messages.append([role, message])
+
+     def reset_message(self):
+         """Reset a new message."""
+         self.messages = []
+
+     def update_last_message(self, message: str):
+         """Update the last output.
+
+         The last message is typically set to be None when constructing the prompt,
+         so we need to update it in-place after getting the response from a model.
+         """
+         self.messages[-1][1] = message
+
+     def to_gradio_chatbot(self):
+         """Convert the conversation to gradio chatbot format."""
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def to_openai_api_messages(self):
+         """Convert the conversation to OpenAI chat completion format."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+         ret = [{"role": "system", "content": system_prompt}]
+
+         for i, (_, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append({"role": "user", "content": msg})
+             else:
+                 if msg is not None:
+                     ret.append({"role": "assistant", "content": msg})
+         return ret
+
+     def copy(self):
+         return Conversation(
+             name=self.name,
+             system_template=self.system_template,
+             system_message=self.system_message,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             stop_str=self.stop_str,
+             stop_token_ids=self.stop_token_ids,
+         )
+
+     def dict(self):
+         return {
+             "template_name": self.name,
+             "system_message": self.system_message,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+         }
+
+
+ # A global registry for all conversation templates
+ conv_templates: Dict[str, Conversation] = {}
+
+
+ def register_conv_template(template: Conversation, override: bool = False):
+     """Register a new conversation template."""
+     if not override:
+         assert (
+             template.name not in conv_templates
+         ), f"{template.name} has been registered."
+
+     conv_templates[template.name] = template
+
+
+ def get_conv_template(name: str) -> Conversation:
+     """Get a conversation template."""
+     return conv_templates[name].copy()
+
+
+ # llava_llama2 template
+ register_conv_template(
+     Conversation(
+         name="llava_llama2",
+         system_message="You are a helpful language and vision assistant. "
+         "You are able to understand the visual content that the user provides, "
+         "and assist the user with a variety of tasks using natural language.",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+ # llama2 template
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+ register_conv_template(
+     Conversation(
+         name="llama-2",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+
+ # deepseek template
+ register_conv_template(
+     Conversation(
+         name="deepseek_old",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("User", "Assistant"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["User:", "<|end▁of▁sentence|>"],
+     )
+ )
+ register_conv_template(
+     Conversation(
+         name="deepseek",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("<|User|>", "<|Assistant|>"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["<|User|>", "<|end▁of▁sentence|>"]
+     )
+ )
+
+ register_conv_template(
+     Conversation(
+         name="plain",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.PLAIN,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ register_conv_template(
+     Conversation(
+         name="alignment",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.ALIGNMENT,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ if __name__ == "__main__":
+     # print("Llama-2 template:")
+     # conv = get_conv_template("llama-2")
+     # conv.set_system_message("You are a helpful, respectful and honest assistant.")
+     # conv.append_message(conv.roles[0], "Hello!")
+     # conv.append_message(conv.roles[1], "Hi!")
+     # conv.append_message(conv.roles[0], "How are you?")
+     # conv.append_message(conv.roles[1], None)
+     # print(conv.get_prompt())
+
+     # print("\n")
+
+     print("deepseek template:")
+     conv = get_conv_template("deepseek")
+     conv.append_message(conv.roles[0], "Hello!")
+     conv.append_message(conv.roles[1], "Hi! This is Tony.")
+     conv.append_message(conv.roles[0], "Who are you?")
+     conv.append_message(conv.roles[1], "I am a helpful assistant.")
+     conv.append_message(conv.roles[0], "How are you?")
+     conv.append_message(conv.roles[1], None)
+     print(conv.get_prompt())
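As a quick illustration of what the "deepseek" template registered above renders for a single image-description turn (a sketch, not part of the committed file):

from janus.utils.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "Describe this image.")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation
# get_prompt() returns:
# "You are a helpful assistant.\n\n<|User|>: Describe this image.\n\n<|Assistant|>:"
print(conv.get_prompt())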
janus/utils/io.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ import json
+ from typing import Dict, List
+
+ import PIL.Image
+ import torch
+ import base64
+ import io
+ from transformers import AutoModelForCausalLM
+
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
+
+
+ def load_pretrained_model(model_path: str):
+     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+     tokenizer = vl_chat_processor.tokenizer
+
+     vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+         model_path, trust_remote_code=True
+     )
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+     return tokenizer, vl_chat_processor, vl_gpt
+
+
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
+     """
+
+     Support file path or base64 images.
+
+     Args:
+         conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
+             [
+                 {
+                     "role": "User",
+                     "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
+                     "images": ["./examples/table_datasets.png"]
+                 },
+                 {"role": "Assistant", "content": ""},
+             ]
+
+     Returns:
+         pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+     """
+
+     pil_images = []
+
+     for message in conversations:
+         if "images" not in message:
+             continue
+
+         for image_data in message["images"]:
+             if image_data.startswith("data:image"):
+                 # Image data is in base64 format
+                 _, image_data = image_data.split(",", 1)
+                 image_bytes = base64.b64decode(image_data)
+                 pil_img = PIL.Image.open(io.BytesIO(image_bytes))
+             else:
+                 # Image data is a file path
+                 pil_img = PIL.Image.open(image_data)
+             pil_img = pil_img.convert("RGB")
+             pil_images.append(pil_img)
+
+     return pil_images
+
+
+ def load_json(filepath):
+     with open(filepath, "r") as f:
+         data = json.load(f)
+     return data
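A short sketch of how these helpers fit together (illustrative only, not part of the committed file; the model id is an assumption, and load_pretrained_model requires a CUDA device since it calls .cuda()):

from janus.utils.io import load_pretrained_model, load_pil_images

tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model("deepseek-ai/Janus-Pro-7B")

conversation = [
    {
        "role": "<|User|>",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["./doge.png"],
    },
    {"role": "<|Assistant|>", "content": ""},
]
pil_images = load_pil_images(conversation)  # list of RGB PIL images, from paths or base64 data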
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate
+ diffusers
+ gradio
+ numpy
+ torch
+ safetensors
+ transformers
+ git+https://github.com/deepseek-ai/Janus
weights/RealESRGAN_x2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
+ size 67061725