Jiaming Han committed
Commit 87f9ab9 · 1 Parent(s): c8a1461
.gitattributes CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/animal.png filter=lfs diff=lfs merge=lfs -text
+examples/bell_ring.wav filter=lfs diff=lfs merge=lfs -text
+examples/caixukun.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/flower.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/food_menu.png filter=lfs diff=lfs merge=lfs -text
+examples/star_kun.mp4 filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/lib.linux-x86_64-cpython-39/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/ball_query.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/group_points.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/interpolate.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/pointnet2_api.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/sampling.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/dist/pointnet2-0.0.0-py3.9-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+__pycache__
+*.pyc
+*.egg-info
+dist
+
+output
+output_dir
+*.pth
+*.log
+weights
README.md CHANGED
@@ -1,13 +1,15 @@
1
  ---
2
- title: Tar 7B
3
- emoji: 🐨
4
- colorFrom: indigo
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.34.2
8
  app_file: app.py
9
  pinned: false
 
 
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Tar
3
+ emoji: 🚀
4
+ colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.34.0
8
  app_file: app.py
9
  pinned: false
10
+ python_version: 3.10.18
11
+ short_description: Unified MLLM with Text-Aligned Representations
12
  license: apache-2.0
13
  ---
14
 
15
+ # Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
app.py ADDED
@@ -0,0 +1,119 @@
+import os
+import gradio as gr
+from torchvision.transforms.functional import to_tensor
+from huggingface_hub import hf_hub_download, snapshot_download, login
+import spaces
+
+from tok.ar_dtok.ar_model import ARModel
+from t2i_inference import T2IConfig, TextToImageInference
+
+def generate_text(self, image, prompt: str) -> str:  # image: PIL.Image from the Gradio component
+    image = image.convert('RGB')
+    image = to_tensor(image).unsqueeze(0).to(self.device)
+
+    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
+    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": f"{image_text}\n{prompt}"}
+    ]
+
+    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = self.tokenizer(input_text, return_tensors="pt")
+
+    gen_ids = self.model.generate(
+        inputs.input_ids.to(self.device),
+        max_new_tokens=512,
+        do_sample=True)
+    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
+
+login(token=os.getenv('HF_TOKEN'))
+config = T2IConfig()
+config.model_path = snapshot_download("csuhan/Tar-7B-v0.1")  # field name must match T2IConfig.model_path
+config.ar_path = {
+    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
+    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
+}
+config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
+config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
+inference = TextToImageInference(config)
+
+@spaces.GPU(duration=240)
+def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
+    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
+    return image
+
+def clear_inputs_t2i():
+    return "", None
+
+@spaces.GPU(duration=240)
+def understand_image(image, prompt):
+    return generate_text(inference, image, prompt)
+
+def clear_inputs_i2t():
+    return None, "", ""  # one value per output: image, instruction, response
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        <div align="center">
+
+        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
+
+        [🕸️ Project Page](http://tar.csuhan.com) • [📄 Paper](http://arxiv.org/abs/2506.18898) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/collections/csuhan/tar-68538273b5537d0bee712648)
+
+        </div>
+        """,
+        elem_id="title",
+    )
+    with gr.Tab("Image Generation"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
+                with gr.Accordion("Advanced Settings", open=False):
+                    resolution = gr.Radio(
+                        ["512px", "1024px"], value="1024px", label="Resolution"
+                    )
+                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
+                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
+                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
+                with gr.Row():
+                    generate_btn = gr.Button("Generate")
+                    clear_btn = gr.Button("Clear")
+            with gr.Column(scale=2):
+                output_image = gr.Image(label="Generated Image")
+
+        generate_btn.click(
+            generate_image,
+            inputs=[prompt, resolution, top_p, top_k, cfg_scale],
+            outputs=output_image
+        )
+        clear_btn.click(
+            clear_inputs_t2i,
+            outputs=[prompt, output_image]
+        )
+
+    with gr.Tab("Image Understanding"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_input = gr.Image(label="Upload Image", type="pil")
+                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
+                with gr.Row():
+                    qa_btn = gr.Button("Generate")
+                    clear_btn_i2t = gr.Button("Clear")
+            with gr.Column(scale=1):
+                answer_output = gr.Textbox(label="Response", lines=4)
+
+        qa_btn.click(
+            understand_image,
+            inputs=[image_input, question_input],
+            outputs=answer_output
+        )
+
+        clear_btn_i2t.click(
+            clear_inputs_i2t,
+            outputs=[image_input, question_input, answer_output]
+        )
+
+demo.launch(share=True)
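The understanding path in `app.py` works by flattening the TA-Tok encoder output into special `<I...>` text tokens and prepending them to the user instruction. Below is a minimal sketch of that prompt construction, using a hypothetical `fake_code` tensor in place of the real `visual_tokenizer.encoder(...)['bottleneck_rep']` output.

```python
# Sketch of how generate_text turns encoder indices into a chat prompt.
# `fake_code` is a made-up stand-in for a (1, seq_len) tensor of codebook indices.
import torch

fake_code = torch.tensor([[17, 4096, 233, 8]])            # hypothetical indices
image_text = "".join(f"<I{x}>" for x in fake_code[0].tolist())
print(image_text)                                          # <I17><I4096><I233><I8>

# The image tokens are simply prepended to the instruction text.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": f"{image_text}\nDescribe the image shortly."},
]
```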
requirements.txt ADDED
@@ -0,0 +1,17 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+accelerate==0.28.0
+datasets==2.16.1
+deepspeed==0.14.4
+einops==0.8.1
+gradio==5.34.0
+huggingface_hub==0.29.1
+numpy==1.26.1
+Pillow==11.2.1
+pyarrow==17.0.0
+PyYAML==6.0.2
+torch==2.1.2
+torchvision==0.16.2
+tqdm==4.66.5
+transformers==4.50.0
+wandb
+easydict
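The requirements are pinned exactly (CUDA 12.1 wheels for torch 2.1.2). A small, optional sketch for checking that the active environment matches the pins; the package subset below is just the one most likely to cause version conflicts.

```python
# Sanity-check sketch: compare installed package versions against the pins above.
from importlib.metadata import version, PackageNotFoundError

pins = {
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "transformers": "4.50.0",
    "gradio": "5.34.0",
    "huggingface_hub": "0.29.1",
}

for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    status = "ok" if installed == expected else f"expected {expected}"
    print(f"{name}: {installed} ({status})")
```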
t2i_inference.py ADDED
@@ -0,0 +1,87 @@
+import re
+from dataclasses import dataclass
+
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+from tok.mm_autoencoder import MMAutoEncoder
+
+
+@dataclass
+class T2IConfig:
+    model_path: str = "csuhan/Tar-1.5B"
+    # visual tokenizer config
+    ar_path: dict = None  # per-resolution AR decoder checkpoints, set by the caller
+    encoder_path: str = 'ta_tok.pth'
+    decoder_path: str = 'vq_ds16_t2i.pt'
+
+    device: str = "cuda:0"
+    dtype: torch.dtype = torch.bfloat16
+    # generation parameters
+    scale: int = 0  # choose from [0, 1, 2]
+    seq_len: int = 729  # choose from [729, 169, 81]
+    temperature: float = 1.0
+    top_p: float = 0.95
+    top_k: int = 1200
+    cfg_scale: float = 4.0
+
+class TextToImageInference:
+    def __init__(self, config: T2IConfig):
+        self.config = config
+        self.device = torch.device(config.device)
+        self._load_models()
+
+    def _load_models(self):
+        self.model = Qwen2ForCausalLM.from_pretrained(self.config.model_path, torch_dtype=self.config.dtype).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
+
+        # Initialize visual tokenizer
+        config = dict(
+            ar_path_dict=self.config.ar_path,
+            encoder_path=self.config.encoder_path,
+            decoder_path=self.config.decoder_path,
+            encoder_args={'input_type': 'rec'},
+            decoder_args={},
+        )
+        self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
+        for ar_model in self.visual_tokenizer.ar_model.values():
+            ar_model.cls_token_num = self.config.seq_len
+        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1
+
+    def generate_image(self, prompt, resolution, top_p, top_k, cfg_scale) -> Image.Image:
+        # Prepare prompt
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+
+        input_text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True)
+        input_text += f"<im_start><S{self.config.scale}>"
+
+        # Generate tokens
+        inputs = self.tokenizer(input_text, return_tensors="pt")
+        gen_ids = self.model.generate(
+            inputs.input_ids.to(self.device),
+            max_new_tokens=self.config.seq_len,
+            do_sample=True,
+            temperature=self.config.temperature,
+            top_p=top_p,
+            top_k=top_k)
+
+        # Process generated tokens
+        gen_text = self.tokenizer.batch_decode(gen_ids)[0]
+        gen_code = [int(x) for x in re.findall(r'<I(\d+)>', gen_text)]
+        gen_code = gen_code[:self.config.seq_len] + [0] * max(0, self.config.seq_len - len(gen_code))
+        gen_code = torch.tensor(gen_code).unsqueeze(0).to(self.device)
+
+        gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
+            gen_code,
+            {'cfg_scale': cfg_scale, 'resolution': resolution},
+        )
+        gen_image = Image.fromarray(gen_tensor[0].numpy())
+        return gen_image
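`generate_image` recovers the image purely from the `<I...>` tokens the LLM emits: a regex pulls out the indices and the list is truncated or zero-padded to `seq_len` (729 by default) before being handed to the visual decoder. A standalone sketch of that post-processing step, with a made-up `gen_text`:

```python
# Sketch of the post-processing step in generate_image.
import re
import torch

seq_len = 729
gen_text = "<im_start><S0><I12><I7><I4095>"                 # hypothetical decoded output
gen_code = [int(x) for x in re.findall(r'<I(\d+)>', gen_text)]
# Truncate or zero-pad to exactly seq_len entries.
gen_code = gen_code[:seq_len] + [0] * max(0, seq_len - len(gen_code))
gen_code = torch.tensor(gen_code).unsqueeze(0)              # (1, 729), ready for the decoder
print(gen_code.shape, gen_code[0, :5])                      # torch.Size([1, 729]) tensor([12, 7, 4095, 0, 0])
```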
tok/__init__.py ADDED
@@ -0,0 +1 @@
+from .ar_dtok import *
tok/ar_dtok/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .bottleneck import Bottleneck, SimVectorQuantizer
+from .vqvae import VQVAE
tok/ar_dtok/ar_model.py ADDED
@@ -0,0 +1,533 @@
1
+ import os
2
+ from contextlib import contextmanager
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from .. import models
12
+ from .generate import generate as ar_generate
13
+
14
+
15
+ def find_multiple(n: int, k: int):
16
+ if n % k == 0:
17
+ return n
18
+ return n + k - (n % k)
19
+
20
+
21
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, scale_factor=10000):
22
+ """
23
+ embed_dim: output dimension for each position
24
+ pos: a list of positions to be encoded: size (M,)
25
+ out: (M, D)
26
+ scale_factor: the base for the scaling factor, default is 10000
27
+ """
28
+ assert embed_dim % 2 == 0
29
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
30
+ omega /= embed_dim / 2.
31
+ omega = 1. / scale_factor**omega # Parameterized scaling factor (D/2,)
32
+
33
+ pos = pos.reshape(-1) # (M,)
34
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
35
+
36
+ emb_sin = np.sin(out) # (M, D/2)
37
+ emb_cos = np.cos(out) # (M, D/2)
38
+
39
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
40
+ return emb
41
+
42
+
43
+ @dataclass
44
+ class ModelArgs:
45
+ dim: int = 4096
46
+ n_layer: int = 32
47
+ n_head: int = 32
48
+
49
+ n_kv_head: Optional[int] = None
50
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
51
+ ffn_dim_multiplier: Optional[float] = None
52
+ rope_base: float = 10000
53
+ norm_eps: float = 1e-5
54
+ initializer_range: float = 0.02
55
+
56
+ token_dropout_p: float = 0.1
57
+ attn_dropout_p: float = 0.0
58
+ resid_dropout_p: float = 0.1
59
+ ffn_dropout_p: float = 0.1
60
+ drop_path_rate: float = 0.0
61
+
62
+ num_classes: int = 1000
63
+ class_dropout_prob: float = 0.1
64
+ model_type: str = 'class_cond' # clip_cond, indice_cond
65
+ cond_dim: int = 1152
66
+ cond_vocab_size: int = 8192
67
+
68
+ vocab_size: int = 8192
69
+ cls_token_num: int = 1
70
+
71
+ max_batch_size: int = 32
72
+ max_seq_len: int = 2048
73
+
74
+ use_fixed_pe: bool = False
75
+
76
+ frame_prediction: bool = False
77
+
78
+
79
+ class RMSNorm(torch.nn.Module):
80
+ def __init__(self, dim: int, eps: float = 1e-5):
81
+ super().__init__()
82
+ self.eps = eps
83
+ self.weight = nn.Parameter(torch.ones(dim))
84
+
85
+ @torch.autocast(device_type='cuda', enabled=False)
86
+ def _norm(self, x):
87
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x):
90
+ output = self._norm(x.float()).type_as(x)
91
+ return output * self.weight
92
+
93
+
94
+ class MLP(nn.Module):
95
+ def __init__(self, in_features, hidden_features, out_features):
96
+ super().__init__()
97
+ out_features = out_features or in_features
98
+ hidden_features = hidden_features or in_features
99
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
100
+ self.act = nn.GELU(approximate='tanh')
101
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
102
+
103
+ def forward(self, x):
104
+ x = self.fc1(x)
105
+ x = self.act(x)
106
+ x = self.fc2(x)
107
+ return x
108
+
109
+
110
+ #################################################################################
111
+ # Drop Path Implementation #
112
+ #################################################################################
113
+
114
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
115
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
116
+
117
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
118
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
119
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
120
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
121
+ 'survival rate' as the argument.
122
+
123
+ """
124
+ if drop_prob == 0. or not training:
125
+ return x
126
+ keep_prob = 1 - drop_prob
127
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
128
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
129
+ if keep_prob > 0.0 and scale_by_keep:
130
+ random_tensor.div_(keep_prob)
131
+ return x * random_tensor
132
+
133
+
134
+ class DropPath(torch.nn.Module):
135
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
136
+ """
137
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
138
+ super(DropPath, self).__init__()
139
+ self.drop_prob = drop_prob
140
+ self.scale_by_keep = scale_by_keep
141
+
142
+ def forward(self, x):
143
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
144
+
145
+ def extra_repr(self):
146
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
147
+
148
+
149
+ #################################################################################
150
+ # AR Model #
151
+ #################################################################################
152
+
153
+ class FeedForward(nn.Module):
154
+ def __init__(self, config: ModelArgs):
155
+ super().__init__()
156
+ hidden_dim = 4 * config.dim
157
+ hidden_dim = int(2 * hidden_dim / 3)
158
+ # custom dim factor multiplier
159
+ if config.ffn_dim_multiplier is not None:
160
+ hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
161
+ hidden_dim = find_multiple(hidden_dim, config.multiple_of)
162
+
163
+ self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
164
+ self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
165
+ self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
166
+ self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
167
+
168
+ def forward(self, x):
169
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
170
+
171
+
172
+ class KVCache(nn.Module):
173
+ def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
174
+ super().__init__()
175
+ cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
176
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
177
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
178
+
179
+ def update(self, input_pos, k_val, v_val):
180
+ # input_pos: [S], k_val: [B, H, S, D]
181
+ assert input_pos.shape[0] == k_val.shape[2], f"{input_pos.shape[0]} != {k_val.shape[2]}"
182
+ k_out = self.k_cache
183
+ v_out = self.v_cache
184
+ k_out[:, :, input_pos] = k_val.to(k_out.dtype)
185
+ v_out[:, :, input_pos] = v_val.to(v_out.dtype)
186
+
187
+ return k_out, v_out
188
+
189
+
190
+ class Attention(nn.Module):
191
+ def __init__(self, config: ModelArgs):
192
+ super().__init__()
193
+ assert config.dim % config.n_head == 0
194
+ self.dim = config.dim
195
+ self.head_dim = config.dim // config.n_head
196
+ self.n_head = config.n_head
197
+ self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
198
+ total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
199
+
200
+ # key, query, value projections for all heads, but in a batch
201
+ self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
202
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
203
+ self.kv_cache = None
204
+
205
+ # regularization
206
+ self.attn_dropout_p = config.attn_dropout_p
207
+ self.resid_dropout = nn.Dropout(config.resid_dropout_p)
208
+
209
+ def forward(
210
+ self, x: torch.Tensor,
211
+ input_pos: Optional[torch.Tensor] = None,
212
+ mask: Optional[torch.Tensor] = None
213
+ ):
214
+ bsz, seqlen, _ = x.shape
215
+ kv_size = self.n_kv_head * self.head_dim
216
+ xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
217
+
218
+ xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
219
+ xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
220
+ xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
221
+
222
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
223
+
224
+ if self.kv_cache is not None:
225
+ keys, values = self.kv_cache.update(input_pos, xk, xv)
226
+ else:
227
+ keys, values = xk, xv
228
+ keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
229
+ values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
230
+
231
+ output = F.scaled_dot_product_attention(
232
+ xq, keys, values,
233
+ attn_mask=mask,
234
+ is_causal=True if mask is None else False, # is_causal=False is for KV cache
235
+ dropout_p=self.attn_dropout_p if self.training else 0)
236
+
237
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
238
+
239
+ output = self.resid_dropout(self.wo(output))
240
+ return output
241
+
242
+
243
+ class TransformerBlock(nn.Module):
244
+ def __init__(self, config: ModelArgs, drop_path: float):
245
+ super().__init__()
246
+ self.attention = Attention(config)
247
+ self.feed_forward = FeedForward(config)
248
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
249
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
250
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
251
+
252
+ def forward(
253
+ self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
254
+ h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
255
+ out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
256
+ return out
257
+
258
+
259
+ class LabelEmbedder(nn.Module):
260
+ """
261
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
262
+ """
263
+ def __init__(self, num_classes, hidden_size, dropout_prob):
264
+ super().__init__()
265
+ use_cfg_embedding = dropout_prob > 0
266
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
267
+ self.num_classes = num_classes
268
+ self.dropout_prob = dropout_prob
269
+
270
+ def token_drop(self, labels, force_drop_ids=None):
271
+ """
272
+ Drops labels to enable classifier-free guidance.
273
+ """
274
+ if force_drop_ids is None:
275
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
276
+ else:
277
+ drop_ids = force_drop_ids == 1
278
+ labels = torch.where(drop_ids, self.num_classes, labels)
279
+ return labels
280
+
281
+ def forward(self, labels, train, force_drop_ids=None):
282
+ use_dropout = self.dropout_prob > 0
283
+ if (train and use_dropout) or (force_drop_ids is not None):
284
+ labels = self.token_drop(labels, force_drop_ids)
285
+
286
+ # replace all negative labels with the last class (unconditional class)
287
+ labels = torch.where(labels < 0, self.num_classes, labels)
288
+ embeddings = self.embedding_table(labels)
289
+ return embeddings
290
+
291
+
292
+ class ARModel(nn.Module):
293
+ def __init__(self, config: ModelArgs):
294
+ super().__init__()
295
+ self.config = config
296
+ self.vocab_size = config.vocab_size
297
+ self.n_layer = config.n_layer
298
+ self.max_seq_length = config.max_seq_len
299
+ self.num_classes = config.num_classes
300
+ self.model_type = config.model_type
301
+ self.cls_token_num = config.cls_token_num
302
+ self.is_sampling = False
303
+ self.frame_prediction = config.frame_prediction
304
+
305
+ if self.model_type == 'class_cond':
306
+ self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
307
+ elif self.model_type == 'clip_cond':
308
+ self.clip_proj = nn.Linear(config.cond_dim, config.dim)
309
+ elif self.model_type == 'indice_cond':
310
+ self.clip_proj = LabelEmbedder(config.cond_vocab_size + 1, config.dim, 0.0)
311
+ else:
312
+ raise Exception("please check model type")
313
+
314
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
315
+ self.tok_dropout = nn.Dropout(config.token_dropout_p)
316
+
317
+ # transformer blocks
318
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
319
+ self.layers = torch.nn.ModuleList()
320
+ for layer_id in range(config.n_layer):
321
+ self.layers.append(TransformerBlock(config, dpr[layer_id]))
322
+
323
+ # output layer
324
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
325
+ self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
326
+
327
+ if config.use_fixed_pe:
328
+ self.register_buffer('abs_pe', torch.zeros(1, config.max_seq_len + config.cls_token_num - 1, config.dim))
329
+ abs_pe = get_1d_sincos_pos_embed_from_grid(embed_dim=config.dim, pos=np.arange(config.max_seq_len + config.cls_token_num - 1))
330
+ self.abs_pe.copy_(torch.from_numpy(abs_pe).float().reshape_as(self.abs_pe))
331
+ print(f"Using fixed absolute PE")
332
+ else:
333
+ self.abs_pe = nn.Parameter(torch.randn(1, config.max_seq_len + config.cls_token_num - 1, config.dim) * 0.02)
334
+ print(f"Using learned absolute PE")
335
+
336
+ self.initialize_weights()
337
+
338
+ def initialize_weights(self):
339
+ # Initialize nn.Linear and nn.Embedding
340
+ self.apply(self._init_weights)
341
+
342
+ # Zero-out output layers:
343
+ if hasattr(self.output, 'weight') and isinstance(self.output.weight, nn.Parameter):
344
+ nn.init.constant_(self.output.weight, 0)
345
+
346
+ def _init_weights(self, module):
347
+ std = self.config.initializer_range
348
+ if isinstance(module, nn.Linear):
349
+ module.weight.data.normal_(mean=0.0, std=std)
350
+ if module.bias is not None:
351
+ module.bias.data.zero_()
352
+ elif isinstance(module, nn.Embedding):
353
+ module.weight.data.normal_(mean=0.0, std=std)
354
+
355
+
356
+ @property
357
+ def device(self):
358
+ return next(self.parameters()).device
359
+
360
+ @property
361
+ def dtype(self):
362
+ return next(self.parameters()).dtype
363
+
364
+
365
+ @contextmanager
366
+ def sampling(self):
367
+ self.is_sampling = True
368
+ try:
369
+ yield
370
+ finally:
371
+ self.is_sampling = False
372
+
373
+
374
+ def setup_caches(self, max_batch_size, max_seq_length, dtype):
375
+ assert max_seq_length == self.max_seq_length + self.cls_token_num, f'{max_seq_length} != {self.max_seq_length} + {self.cls_token_num=}'
376
+
377
+ head_dim = self.config.dim // self.config.n_head
378
+ max_seq_length = find_multiple(max_seq_length, 8)
379
+
380
+ for b in self.layers:
381
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
382
+
383
+ causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))
384
+ self.causal_mask = causal_mask.unsqueeze(0).repeat(max_batch_size, 1, 1)
385
+
386
+
387
+ def reset_caches(self):
388
+ for b in self.layers:
389
+ b.attention.kv_cache = None
390
+
391
+ def clip_embedding(self, x):
392
+ if self.model_type == 'clip_cond':
393
+ if self.training and self.config.class_dropout_prob > 0:
394
+ drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
395
+ x[drop_ids] = 0.
396
+ x = self.clip_proj(x.to(self.dtype)) # Linear
397
+ elif self.model_type == 'indice_cond':
398
+ if self.training and self.config.class_dropout_prob > 0:
399
+ drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
400
+ x[drop_ids] = self.config.cond_vocab_size
401
+ x = self.clip_proj(x, train=self.training) # Embedding
402
+ return x
403
+
404
+ def forward(
405
+ self,
406
+ idx: Optional[torch.Tensor], # (b, n)
407
+ cond_idx: Optional[torch.Tensor], # cond_idx_or_embed
408
+ input_pos: Optional[torch.Tensor] = None,
409
+ targets: Optional[torch.Tensor] = None,
410
+ mask: Optional[torch.Tensor] = None,
411
+ valid: Optional[torch.Tensor] = None,
412
+ ):
413
+ if idx is not None and cond_idx is not None: # training or naive inference
414
+ if self.model_type == 'class_cond':
415
+ cond_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:,:self.cls_token_num]
416
+ elif self.model_type in ['clip_cond', 'indice_cond']:
417
+ cond_embeddings = self.clip_embedding(cond_idx)
418
+ token_embeddings = self.tok_embeddings(idx) # (b, n, d)
419
+ token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1) # (b, cls_token_num + n, d)
420
+ h = self.tok_dropout(token_embeddings)
421
+ else:
422
+ if cond_idx is not None: # prefill in inference
423
+ if self.model_type == 'class_cond':
424
+ token_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:,:self.cls_token_num]
425
+ elif self.model_type in ['clip_cond', 'indice_cond']:
426
+ token_embeddings = self.clip_embedding(cond_idx)
427
+ else: # decode_n_tokens(kv cache) in inference
428
+ token_embeddings = self.tok_embeddings(idx)
429
+
430
+ bs = token_embeddings.shape[0]
431
+ mask = self.causal_mask[:bs, None, input_pos]
432
+ h = self.tok_dropout(token_embeddings)
433
+
434
+ if self.is_sampling:
435
+ h = h + self.abs_pe[:, input_pos]
436
+ else:
437
+ h = h + self.abs_pe[:, :h.shape[1]]
438
+
439
+ # transformer blocks
440
+ for layer in self.layers:
441
+ h = layer(h, input_pos, mask)
442
+
443
+ # output layers
444
+ h = self.norm(h)
445
+ logits = self.output(h)
446
+ # if self.training or self.is_sampling:
447
+ if cond_idx is not None:
448
+ # if self.training:
449
+ # logits = logits[:, self.cls_token_num - 1:].contiguous()
450
+ logits = logits[:, cond_idx.size(1) - 1:].contiguous()
451
+
452
+ # if we are given some desired targets also calculate the loss
453
+ loss = None
454
+ if valid is not None:
455
+ loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
456
+ valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
457
+ loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
458
+ elif targets is not None:
459
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
460
+ return logits, loss
461
+
462
+
463
+ @torch.inference_mode()
464
+ def sample(
465
+ self,
466
+ c,
467
+ cfg_scale=2.0,
468
+ cfg_interval=-1,
469
+ temperature=1.0,
470
+ top_k=0,
471
+ top_p=1.0,
472
+ seq_length=None,
473
+ ):
474
+ seq_length = self.max_seq_length if seq_length is None else seq_length
475
+ with self.sampling():
476
+ sampled_seqs = ar_generate(
477
+ self, c, seq_length,
478
+ cfg_scale=cfg_scale, cfg_interval=cfg_interval,
479
+ temperature=temperature, top_k=top_k,
480
+ top_p=top_p, sample_logits=True,
481
+ )
482
+ return sampled_seqs
483
+
484
+
485
+ @classmethod
486
+ def from_checkpoint(cls, ckpt, load_state_dict=True):
487
+ if isinstance(ckpt, str):
488
+ assert os.path.exists(ckpt), f"checkpoint {ckpt} does not exist"
489
+ ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
490
+ else:
491
+ assert isinstance(
492
+ ckpt, dict
493
+ ), f"checkpoint must be a dict or a path to a checkpoint"
494
+ model = models.make(ckpt["model"], load_sd=load_state_dict)
495
+ return model
496
+
497
+
498
+ #################################################################################
499
+ # LLAMA-ABS Configs #
500
+ #################################################################################
501
+
502
+ def LLAMA_ABS_XXXL(**kwargs):
503
+ return ARModel(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
504
+
505
+ def LLAMA_ABS_XXL(**kwargs):
506
+ return ARModel(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
507
+
508
+ def LLAMA_ABS_XL(**kwargs):
509
+ return ARModel(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
510
+
511
+ def LLAMA_ABS_LP(**kwargs):
512
+ return ARModel(ModelArgs(n_layer=30, n_head=20, dim=1280, **kwargs)) # 632M
513
+
514
+ def LLAMA_ABS_L(**kwargs):
515
+ return ARModel(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
516
+
517
+ def LLAMA_ABS_B(**kwargs):
518
+ return ARModel(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
519
+
520
+ def LLAMA_ABS_S(**kwargs):
521
+ return ARModel(ModelArgs(n_layer=12, n_head=6, dim=384, **kwargs)) # 21.7M
522
+
523
+ ar_models = {
524
+ 'llama-abs-S': LLAMA_ABS_S,
525
+ 'llama-abs-B': LLAMA_ABS_B,
526
+ 'llama-abs-L': LLAMA_ABS_L,
527
+ 'llama-abs-LP': LLAMA_ABS_LP,
528
+ 'llama-abs-XL': LLAMA_ABS_XL,
529
+ 'llama-abs-XXL': LLAMA_ABS_XXL,
530
+ 'llama-abs-XXXL': LLAMA_ABS_XXXL,
531
+ }
532
+
533
+ models.models.update(ar_models)
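Two of the helpers above are easy to sanity-check in isolation. The sketch below (assuming the full `tok` package from the repo, including `tok.models` and `tok.utils`, is importable) exercises `find_multiple`, which rounds the KV-cache length up to a multiple of 8, and the fixed sin/cos positional table.

```python
# Sanity-check sketch for two helpers in tok/ar_dtok/ar_model.py.
import numpy as np
from tok.ar_dtok.ar_model import find_multiple, get_1d_sincos_pos_embed_from_grid

# find_multiple(n, k) returns the smallest multiple of k that is >= n.
assert find_multiple(730, 8) == 736
assert find_multiple(736, 8) == 736

# One row per position; sin halves and cos halves are concatenated.
pe = get_1d_sincos_pos_embed_from_grid(embed_dim=64, pos=np.arange(10))
print(pe.shape)                                              # (10, 64)
print(np.allclose(pe[0, :32], 0.0), np.allclose(pe[0, 32:], 1.0))  # sin(0)=0, cos(0)=1
```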
tok/ar_dtok/bottleneck.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ from .. import models
7
+ from ..models import register
8
+
9
+
10
+ @register("bottleneck")
11
+ class Bottleneck(nn.Module):
12
+ def __init__(
13
+ self,
14
+ bottleneck_dim: int,
15
+ input_dim: int,
16
+ output_dim: int,
17
+ token_nums: int,
18
+ regularizer=None,
19
+ **kwargs
20
+ ):
21
+ super().__init__()
22
+ self.token_nums = token_nums
23
+ self.input_dim = input_dim
24
+ self.output_dim = output_dim
25
+ if bottleneck_dim > 0:
26
+ self.bottleneck_dim = bottleneck_dim
27
+ else:
28
+ assert self.input_dim == self.output_dim, "input_dim and output_dim must be the same when bottleneck_dim is not specified"
29
+ self.bottleneck_dim = self.input_dim
30
+
31
+ self.project_dim = self.bottleneck_dim
32
+
33
+ if self.bottleneck_dim > 0:
34
+ self.in_linear = nn.Linear(self.input_dim, self.project_dim)
35
+ self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
36
+ else:
37
+ self.in_linear = self.out_linear = lambda x: x
38
+
39
+ regularizer['args']['dim'] = self.bottleneck_dim
40
+ regularizer['args']['token_nums'] = self.token_nums
41
+ self.regularizer = models.make(regularizer)
42
+
43
+ def project_in(self, x):
44
+ assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
45
+ z = self.in_linear(x)
46
+ return z
47
+
48
+ def project_out(self, z_cat):
49
+ z = self.out_linear(z_cat)
50
+ return z
51
+
52
+ def decode(self, bottleneck_rep):
53
+ regularized_z = self.regularizer.decode(bottleneck_rep)
54
+ return self.project_out(regularized_z)
55
+
56
+ def forward(self, x):
57
+ z = self.project_in(x)
58
+ projected_z = z
59
+ regularized_output = self.regularizer(z)
60
+ x_hat = self.project_out(regularized_output['regularized_z'])
61
+ bottleneck_rep = regularized_output.pop('bottleneck_rep')
62
+ return {
63
+ 'output': x_hat,
64
+ 'bottleneck_rep': bottleneck_rep,
65
+ 'projected_z': projected_z,
66
+ **regularized_output,
67
+ }
68
+
69
+
70
+ @register("simvq")
71
+ class SimVectorQuantizer(nn.Module):
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ codebook_size,
76
+ l2_normalized=False,
77
+ same_index_shape=True,
78
+ stochastic=False,
79
+ stochastic_temperature=1.0,
80
+ **kwargs,
81
+ ):
82
+ super().__init__()
83
+ self.codebook_size = codebook_size
84
+ self.dim = dim
85
+ assert isinstance(l2_normalized, bool)
86
+ self.l2_normalized = l2_normalized
87
+ self.stochastic = stochastic
88
+ self.eval_deterministic = False
89
+ self.default_stochastic_temperature = stochastic_temperature
90
+
91
+ if self.stochastic:
92
+ if stochastic_temperature > 0: # fixed temperature
93
+ self.stochastic_temperature_inv = 1 / stochastic_temperature
94
+ else: # set stochastic_temperature < 0 to use learnable temperature
95
+ self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))
96
+
97
+ # for clear inference code, we remove the codebook init from LLM's embedding
98
+ self.embedding = nn.Embedding(self.codebook_size, self.dim)
99
+ self.embedding_proj = nn.Linear(self.dim, self.dim)
100
+
101
+ self.same_index_shape = same_index_shape
102
+
103
+ def set_eval_deterministic(self, deterministic=True):
104
+ self.eval_deterministic = deterministic
105
+
106
+ def set_stochastic_temperature(self, temperature):
107
+ self.stochastic_temperature_inv = 1 / temperature
108
+
109
+ @torch.autocast(device_type='cuda', enabled=False)
110
+ def get_emb(self):
111
+ emb = self.embedding_proj(self.embedding.weight)
112
+ if self.l2_normalized:
113
+ emb = F.normalize(emb, p=2, dim=-1)
114
+ # assert emb.dtype == torch.float32, f"Embedding weight dtype is {emb.dtype}, expected float32"
115
+ return emb
116
+
117
+ @torch.autocast(device_type='cuda', enabled=False)
118
+ def forward(self, z):
119
+ emb = self.get_emb()
120
+ z = z.to(emb)
121
+ # z = z.float()
122
+ assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
123
+ if self.l2_normalized:
124
+ z = F.normalize(z, p=2, dim=-1)
125
+
126
+ z_flattened = rearrange(z, 'b n d -> (b n) d')
127
+
128
+ if self.stochastic:
129
+ # sample the softmaxed cosine similarity
130
+ assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
131
+ cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
132
+ probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
133
+ if self.eval_deterministic and not self.training:
134
+ q_indices = torch.argmax(probs, dim=-1)
135
+ else:
136
+ q_indices = torch.multinomial(probs, 1).squeeze(-1)
137
+ else:
138
+ d = (
139
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
140
+ + torch.sum(emb**2, dim=1)
141
+ - 2
142
+ * torch.einsum(
143
+ "bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n")
144
+ )
145
+ )
146
+ q_indices = torch.argmin(d, dim=1)
147
+
148
+ quantized = F.embedding(q_indices, emb, self.embedding.padding_idx, self.embedding.max_norm,
149
+ self.embedding.norm_type, self.embedding.scale_grad_by_freq, self.embedding.sparse).view(z.shape) # (b, n, d)
150
+
151
+ # preserve gradients
152
+ quantized = z + (quantized - z).detach()
153
+
154
+ if self.same_index_shape:
155
+ q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])
156
+
157
+ return_dict = {
158
+ 'unregularized_z': z, # but l2 normalized if l2_normalized=True
159
+ 'emb': emb, # but l2 normalized if l2_normalized=True
160
+ 'regularized_z': quantized,
161
+ 'bottleneck_rep': q_indices
162
+ }
163
+ return return_dict
164
+
165
+ def get_codebook_entry(self, indices, shape=None):
166
+ # shape specifying (batch, height, width, channel)
167
+ indices_shape = indices.shape
168
+ indices_flatten = rearrange(indices, '... -> (...)')
169
+
170
+ # get quantized latent vectors
171
+ emb = self.get_emb()
172
+ z_q = F.embedding(indices_flatten, emb)
173
+ # z_q = self.embedding(indices_flatten)
174
+ if self.l2_normalized:
175
+ z_q = F.normalize(z_q, p=2, dim=-1)
176
+
177
+ if shape is not None:
178
+ z_q = z_q.reshape(shape)
179
+ else:
180
+ z_q = z_q.reshape([*indices_shape, self.dim])
181
+ return z_q
182
+
183
+ def decode(self, indices):
184
+ return self.get_codebook_entry(indices)
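`SimVectorQuantizer` can be exercised on its own as a nearest-codeword quantizer. A minimal sketch, assuming the `tok` package (including the `tok.models` registry, which is not part of this view) is importable; the sizes are made up.

```python
# Sketch: SimVectorQuantizer as a standalone nearest-codeword quantizer.
import torch
from tok.ar_dtok.bottleneck import SimVectorQuantizer

vq = SimVectorQuantizer(dim=32, codebook_size=64, l2_normalized=True).eval()

z = torch.randn(2, 16, 32)           # (batch, n_tokens, dim)
with torch.no_grad():
    out = vq(z)

print(sorted(out.keys()))            # ['bottleneck_rep', 'emb', 'regularized_z', 'unregularized_z']
print(out["bottleneck_rep"].shape)   # (2, 16) codebook indices
print(out["regularized_z"].shape)    # (2, 16, 32) straight-through quantized vectors

# The indices alone are enough to reconstruct the codewords.
z_q = vq.decode(out["bottleneck_rep"])
print(torch.allclose(z_q, out["regularized_z"]))  # True (up to float rounding)
```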
tok/ar_dtok/generate.py ADDED
@@ -0,0 +1,188 @@
1
+ # Modified from:
2
+ # llamagen: https://github.com/FoundationVision/LlamaGen/blob/main/autoregressive/models/generate.py
3
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
4
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
5
+
6
+
7
+ import torch
8
+ import torch._dynamo.config
9
+ import torch._inductor.config
10
+ from torch.nn import functional as F
11
+
12
+
13
+ ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
14
+ def top_k_top_p_filtering(
15
+ logits,
16
+ top_k: int = 0,
17
+ top_p: float = 1.0,
18
+ filter_value: float = -float("Inf"),
19
+ min_tokens_to_keep: int = 1,
20
+ ):
21
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
22
+ Args:
23
+ logits: logits distribution shape (batch size, vocabulary size)
24
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
25
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
26
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
27
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
28
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
29
+ """
30
+ if top_k > 0:
31
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
32
+ # Remove all tokens with a probability less than the last token of the top-k
33
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
34
+ logits[indices_to_remove] = filter_value
35
+
36
+ if top_p < 1.0:
37
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
38
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
39
+
40
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
41
+ sorted_indices_to_remove = cumulative_probs > top_p
42
+ if min_tokens_to_keep > 1:
43
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
44
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
45
+ # Shift the indices to the right to keep also the first token above the threshold
46
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
47
+ sorted_indices_to_remove[..., 0] = 0
48
+
49
+ # scatter sorted tensors to original indexing
50
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
51
+ logits[indices_to_remove] = filter_value
52
+ return logits
53
+
54
+
55
+ def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
56
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
57
+ if top_k > 0 or top_p < 1.0:
58
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
59
+
60
+ # improve numerical stability of softmax
61
+ probs = F.softmax(logits.float(), dim=-1)
62
+ if sample_logits:
63
+ idx = torch.multinomial(probs, num_samples=1)
64
+ else:
65
+ _, idx = torch.topk(probs, k=1, dim=-1)
66
+ return idx, probs
67
+
68
+
69
+ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
70
+ logits = logits / max(temperature, 1e-5)
71
+ if top_k > 0 or top_p < 1.0:
72
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
73
+ probs = torch.nn.functional.softmax(logits, dim=-1)
74
+ return probs
75
+
76
+
77
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs):
78
+ if cfg_scale > 1.0:
79
+ logits, _ = model(None, cond_idx, input_pos)
80
+ logits_combined = logits
81
+ cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
82
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
83
+ else:
84
+ logits, _ = model(None, cond_idx, input_pos)
85
+
86
+ return sample(logits, **sampling_kwargs)[0]
87
+
88
+
89
+ def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
90
+ assert input_pos.shape[-1] == 1
91
+ if cfg_scale > 1.0:
92
+ x_combined = torch.cat([x, x])
93
+ logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
94
+ logits_combined = logits
95
+ cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
96
+ if cfg_flag:
97
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
98
+ else:
99
+ logits = cond_logits
100
+ else:
101
+ logits, _ = model(x, cond_idx=None, input_pos=input_pos)
102
+ return sample(logits, **sampling_kwargs)
103
+
104
+
105
+ def decode_n_tokens(
106
+ model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
107
+ cfg_scale: float, cfg_interval: int,
108
+ **sampling_kwargs):
109
+ new_tokens, new_probs = [], []
110
+ cfg_flag = True
111
+ for i in range(num_new_tokens):
112
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
113
+ if cfg_interval > -1 and i > cfg_interval:
114
+ cfg_flag = False
115
+ next_token, next_prob = decode_one_token(
116
+ model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
117
+ )
118
+ input_pos += 1
119
+ new_tokens.append(next_token.clone())
120
+ new_probs.append(next_prob.clone())
121
+ cur_token = next_token.view(-1, 1)
122
+
123
+ return new_tokens, new_probs
124
+
125
+
126
+ @torch.no_grad()
127
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
128
+ if model.frame_prediction:
129
+ assert cfg_scale == 1.0, "frame prediction requires cfg_scale=1.0 (no classifier-free guidance)"
130
+ cond_combined = cond
131
+ T = cond.shape[1]
132
+ elif model.model_type == 'class_cond':
133
+ if cfg_scale > 1.0:
134
+ cond_null = torch.ones_like(cond) * model.num_classes
135
+ cond_combined = torch.cat([cond, cond_null])
136
+ else:
137
+ cond_combined = cond
138
+ T = 1
139
+ elif model.model_type == 'clip_cond':
140
+ if cfg_scale > 1.0:
141
+ cond_null = torch.zeros_like(cond)
142
+ cond_combined = torch.cat([cond, cond_null])
143
+ else:
144
+ cond_combined = cond
145
+ T = model.cls_token_num
146
+ elif model.model_type == 'indice_cond':
147
+ if cfg_scale > 1.0:
148
+ cond_null = torch.ones_like(cond) * model.cond_vocab_size
149
+ cond_combined = torch.cat([cond, cond_null])
150
+ else:
151
+ cond_combined = cond
152
+ T = model.cls_token_num
153
+ else:
154
+ raise Exception("please check model type")
155
+
156
+ T_new = T + max_new_tokens
157
+ max_seq_length = T_new
158
+ max_batch_size = cond.shape[0]
159
+
160
+ device = cond.device
161
+ with torch.device(device):
162
+ max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
163
+ model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
164
+
165
+ if emb_masks is not None:
166
+ assert emb_masks.shape[0] == max_batch_size
167
+ assert emb_masks.shape[-1] == T
168
+ if cfg_scale > 1.0:
169
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
170
+ else:
171
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
172
+
173
+ eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
174
+ model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
175
+
176
+ # create an empty tensor of the expected final shape and fill in the current tokens
177
+ seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
178
+
179
+ input_pos = torch.arange(0, T, device=device)
180
+
181
+ next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
182
+ seq[:, T:T+1] = next_token
183
+
184
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
185
+ generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs)
186
+ seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
187
+
188
+ return seq[:, T:]
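`top_k_top_p_filtering` is the only piece of `generate.py` with no model dependency, so it is easy to check directly. A small sketch of its effect on one row of logits (values are made up):

```python
# Sketch: effect of top-k / top-p filtering on a single row of logits.
import torch
from tok.ar_dtok.generate import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

# top_k=2: everything outside the two largest logits is set to -inf,
# so sampling can only ever pick tokens 0 or 1.
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
print(filtered)                      # tensor([[2., 1., -inf, -inf, -inf]])

# top_p=0.6: keep the smallest set of tokens whose softmax mass reaches 0.6
# (the top token always survives the shift-right trick in the implementation).
filtered = top_k_top_p_filtering(logits.clone(), top_p=0.6)
print(torch.isfinite(filtered))      # [[True, True, False, False, False]] with these logits
```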
tok/ar_dtok/vqvae.py ADDED
@@ -0,0 +1,457 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ..models import register
9
+ from ..utils import ScalingLayer
10
+
11
+
12
+ @register('vqvae')
13
+ class VQVAE(nn.Module):
14
+ def __init__(
15
+ self,
16
+ model='VQ-16',
17
+ ckpt='',
18
+ codebook_size=16384,
19
+ codebook_embed_dim=8,
20
+ bottleneck_token_num=256,
21
+ input_size=256,
22
+ *args,
23
+ **kwargs,
24
+ ):
25
+ super().__init__()
26
+ self.codebook_size = codebook_size
27
+ self.codebook_embed_dim = codebook_embed_dim
28
+ self.bottleneck_token_num = bottleneck_token_num
29
+ self.input_size = input_size
30
+ self.model = VQ_models[model](
31
+ codebook_size=codebook_size,
32
+ codebook_embed_dim=codebook_embed_dim)
33
+ ckpt = torch.load(ckpt, map_location='cpu')
34
+ self.model.load_state_dict(ckpt['model'])
35
+ self.model.eval()
36
+
37
+ self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
38
+
39
+ @classmethod
40
+ def from_checkpoint(cls, ckpt, **kwargs):
41
+ model = cls(ckpt=ckpt, **kwargs)
42
+ return model
43
+
44
+ def decode_from_bottleneck(self, z):
45
+ if z.ndim == 2:
46
+ b = z.size(0)
47
+ h = w = int(z.size(-1) ** 0.5)
48
+ z = self.model.decode_code(z, (b, self.codebook_embed_dim, h, w))
49
+ return self.scale_layer.inv(z)
50
+
51
+
52
+ # Adapt from https://github.com/FoundationVision/LlamaGen/blob/main/tokenizer/tokenizer_image/vq_model.py
53
+ @dataclass
54
+ class ModelArgs:
55
+ codebook_size: int = 16384
56
+ codebook_embed_dim: int = 8
57
+ codebook_l2_norm: bool = True
58
+ codebook_show_usage: bool = True
59
+ commit_loss_beta: float = 0.25
60
+ entropy_loss_ratio: float = 0.0
61
+
62
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
63
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
64
+ z_channels: int = 256
65
+ dropout_p: float = 0.0
66
+
67
+
68
+ class VQModel(nn.Module):
69
+ def __init__(self, config: ModelArgs):
70
+ super().__init__()
71
+ self.config = config
72
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
73
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
74
+
75
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
76
+ config.commit_loss_beta, config.entropy_loss_ratio,
77
+ config.codebook_l2_norm, config.codebook_show_usage)
78
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
79
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
80
+
81
+ def encode(self, x):
82
+ h = self.encoder(x)
83
+ h = self.quant_conv(h)
84
+ quant, emb_loss, info = self.quantize(h)
85
+ return quant, emb_loss, info
86
+
87
+ def decode(self, quant):
88
+ quant = self.post_quant_conv(quant)
89
+ dec = self.decoder(quant)
90
+ return dec
91
+
92
+ def decode_code(self, code_b, shape=None, channel_first=True):
93
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
94
+ dec = self.decode(quant_b)
95
+ return dec
96
+
97
+ def forward(self, input):
98
+ quant, diff, _ = self.encode(input)
99
+ dec = self.decode(quant)
100
+ return dec, diff
101
+
102
+
103
+ class Encoder(nn.Module):
104
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
105
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
106
+ super().__init__()
107
+ self.num_resolutions = len(ch_mult)
108
+ self.num_res_blocks = num_res_blocks
109
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
110
+
111
+ # downsampling
112
+ in_ch_mult = (1,) + tuple(ch_mult)
113
+ self.conv_blocks = nn.ModuleList()
114
+ for i_level in range(self.num_resolutions):
115
+ conv_block = nn.Module()
116
+ # res & attn
117
+ res_block = nn.ModuleList()
118
+ attn_block = nn.ModuleList()
119
+ block_in = ch*in_ch_mult[i_level]
120
+ block_out = ch*ch_mult[i_level]
121
+ for _ in range(self.num_res_blocks):
122
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
123
+ block_in = block_out
124
+ if i_level == self.num_resolutions - 1:
125
+ attn_block.append(AttnBlock(block_in, norm_type))
126
+ conv_block.res = res_block
127
+ conv_block.attn = attn_block
128
+ # downsample
129
+ if i_level != self.num_resolutions-1:
130
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
131
+ self.conv_blocks.append(conv_block)
132
+
133
+ # middle
134
+ self.mid = nn.ModuleList()
135
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
136
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
137
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
138
+
139
+ # end
140
+ self.norm_out = Normalize(block_in, norm_type)
141
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
142
+
143
+ def forward(self, x):
144
+ h = self.conv_in(x)
145
+ # downsampling
146
+ for i_level, block in enumerate(self.conv_blocks):
147
+ for i_block in range(self.num_res_blocks):
148
+ h = block.res[i_block](h)
149
+ if len(block.attn) > 0:
150
+ h = block.attn[i_block](h)
151
+ if i_level != self.num_resolutions - 1:
152
+ h = block.downsample(h)
153
+
154
+ # middle
155
+ for mid_block in self.mid:
156
+ h = mid_block(h)
157
+
158
+ # end
159
+ h = self.norm_out(h)
160
+ h = nonlinearity(h)
161
+ h = self.conv_out(h)
162
+ return h
163
+
164
+
165
+ class Decoder(nn.Module):
166
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
167
+ dropout=0.0, resamp_with_conv=True, out_channels=3):
168
+ super().__init__()
169
+ self.num_resolutions = len(ch_mult)
170
+ self.num_res_blocks = num_res_blocks
171
+
172
+ block_in = ch*ch_mult[self.num_resolutions-1]
173
+ # z to block_in
174
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
175
+
176
+ # middle
177
+ self.mid = nn.ModuleList()
178
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
179
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
180
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
181
+
182
+ # upsampling
183
+ self.conv_blocks = nn.ModuleList()
184
+ for i_level in reversed(range(self.num_resolutions)):
185
+ conv_block = nn.Module()
186
+ # res & attn
187
+ res_block = nn.ModuleList()
188
+ attn_block = nn.ModuleList()
189
+ block_out = ch*ch_mult[i_level]
190
+ for _ in range(self.num_res_blocks + 1):
191
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
192
+ block_in = block_out
193
+ if i_level == self.num_resolutions - 1:
194
+ attn_block.append(AttnBlock(block_in, norm_type))
195
+ conv_block.res = res_block
196
+ conv_block.attn = attn_block
197
+ # downsample
198
+ if i_level != 0:
199
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
200
+ self.conv_blocks.append(conv_block)
201
+
202
+ # end
203
+ self.norm_out = Normalize(block_in, norm_type)
204
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
205
+
206
+ @property
207
+ def last_layer(self):
208
+ return self.conv_out.weight
209
+
210
+ def forward(self, z):
211
+ # z to block_in
212
+ h = self.conv_in(z)
213
+
214
+ # middle
215
+ for mid_block in self.mid:
216
+ h = mid_block(h)
217
+
218
+ # upsampling
219
+ for i_level, block in enumerate(self.conv_blocks):
220
+ for i_block in range(self.num_res_blocks + 1):
221
+ h = block.res[i_block](h)
222
+ if len(block.attn) > 0:
223
+ h = block.attn[i_block](h)
224
+ if i_level != self.num_resolutions - 1:
225
+ h = block.upsample(h)
226
+
227
+ # end
228
+ h = self.norm_out(h)
229
+ h = nonlinearity(h)
230
+ h = self.conv_out(h)
231
+ return h
232
+
233
+
234
+ class VectorQuantizer(nn.Module):
235
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
236
+ super().__init__()
237
+ self.n_e = n_e
238
+ self.e_dim = e_dim
239
+ self.beta = beta
240
+ self.entropy_loss_ratio = entropy_loss_ratio
241
+ self.l2_norm = l2_norm
242
+ self.show_usage = show_usage
243
+
244
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
245
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
246
+ if self.l2_norm:
247
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
248
+ if self.show_usage:
249
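+ # rolling FIFO of the most recently used code indices; 65536 is a fixed window length (not tied to n_e), used below to estimate codebook usage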
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
250
+
251
+ def forward(self, z):
252
+ # reshape z -> (batch, height, width, channel) and flatten
253
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
254
+ z_flattened = z.view(-1, self.e_dim)
255
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
256
+
257
+ if self.l2_norm:
258
+ z = F.normalize(z, p=2, dim=-1)
259
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
260
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
261
+ else:
262
+ embedding = self.embedding.weight
263
+
264
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
265
+ torch.sum(embedding**2, dim=1) - 2 * \
266
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
267
+
268
+ min_encoding_indices = torch.argmin(d, dim=1)
269
+ z_q = embedding[min_encoding_indices].view(z.shape)
270
+ perplexity = None
271
+ min_encodings = None
272
+ vq_loss = None
273
+ commit_loss = None
274
+ entropy_loss = None
275
+ codebook_usage = 0
276
+
277
+ if self.show_usage and self.training:
278
+ cur_len = min_encoding_indices.shape[0]
279
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
280
+ self.codebook_used[-cur_len:] = min_encoding_indices
281
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
282
+
283
+ # compute loss for embedding
284
+ if self.training:
285
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
286
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
287
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
288
+
289
+ # preserve gradients
290
+ z_q = z + (z_q - z).detach()
291
+
292
+ # reshape back to match original input shape
293
+ z_q = torch.einsum('b h w c -> b c h w', z_q)
294
+
295
+ return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
296
+
297
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
298
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
299
+ if self.l2_norm:
300
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
301
+ else:
302
+ embedding = self.embedding.weight
303
+ z_q = embedding[indices] # (b*h*w, c)
304
+
305
+ if shape is not None:
306
+ if channel_first:
307
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
308
+ # reshape back to match original input shape
309
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
310
+ else:
311
+ z_q = z_q.view(shape)
312
+ return z_q
313
+
314
+
315
+ class ResnetBlock(nn.Module):
316
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
317
+ super().__init__()
318
+ self.in_channels = in_channels
319
+ out_channels = in_channels if out_channels is None else out_channels
320
+ self.out_channels = out_channels
321
+ self.use_conv_shortcut = conv_shortcut
322
+
323
+ self.norm1 = Normalize(in_channels, norm_type)
324
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
325
+ self.norm2 = Normalize(out_channels, norm_type)
326
+ self.dropout = nn.Dropout(dropout)
327
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
328
+
329
+ if self.in_channels != self.out_channels:
330
+ if self.use_conv_shortcut:
331
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
332
+ else:
333
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
334
+
335
+ def forward(self, x):
336
+ h = x
337
+ h = self.norm1(h)
338
+ h = nonlinearity(h)
339
+ h = self.conv1(h)
340
+ h = self.norm2(h)
341
+ h = nonlinearity(h)
342
+ h = self.dropout(h)
343
+ h = self.conv2(h)
344
+
345
+ if self.in_channels != self.out_channels:
346
+ if self.use_conv_shortcut:
347
+ x = self.conv_shortcut(x)
348
+ else:
349
+ x = self.nin_shortcut(x)
350
+ return x+h
351
+
352
+
353
+ class AttnBlock(nn.Module):
354
+ def __init__(self, in_channels, norm_type='group'):
355
+ super().__init__()
356
+ self.norm = Normalize(in_channels, norm_type)
357
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
358
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
359
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
360
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
361
+
362
+ def forward(self, x):
363
+ h_ = x
364
+ h_ = self.norm(h_)
365
+ q = self.q(h_)
366
+ k = self.k(h_)
367
+ v = self.v(h_)
368
+
369
+ # compute attention
370
+ b,c,h,w = q.shape
371
+ q = q.reshape(b,c,h*w)
372
+ q = q.permute(0,2,1) # b,hw,c
373
+ k = k.reshape(b,c,h*w) # b,c,hw
374
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
375
+ w_ = w_ * (int(c)**(-0.5))
376
+ w_ = F.softmax(w_, dim=2)
377
+
378
+ # attend to values
379
+ v = v.reshape(b,c,h*w)
380
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
381
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
382
+ h_ = h_.reshape(b,c,h,w)
383
+
384
+ h_ = self.proj_out(h_)
385
+
386
+ return x+h_
387
+
388
+ def nonlinearity(x):
389
+ # swish
390
+ return x*torch.sigmoid(x)
391
+
392
+ def Normalize(in_channels, norm_type='group'):
393
+ assert norm_type in ['group', 'batch']
394
+ if norm_type == 'group':
395
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
396
+ elif norm_type == 'batch':
397
+ return nn.SyncBatchNorm(in_channels)
398
+
399
+
400
+ class Upsample(nn.Module):
401
+ def __init__(self, in_channels, with_conv):
402
+ super().__init__()
403
+ self.with_conv = with_conv
404
+ if self.with_conv:
405
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
406
+
407
+ def forward(self, x):
408
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
409
+ if self.with_conv:
410
+ x = self.conv(x)
411
+ return x
412
+
413
+
414
+ class Downsample(nn.Module):
415
+ def __init__(self, in_channels, with_conv):
416
+ super().__init__()
417
+ self.with_conv = with_conv
418
+ if self.with_conv:
419
+ # no asymmetric padding in torch conv, must do it ourselves
420
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
421
+
422
+ def forward(self, x):
423
+ if self.with_conv:
424
+ pad = (0,1,0,1)
425
+ x = F.pad(x, pad, mode="constant", value=0)
426
+ x = self.conv(x)
427
+ else:
428
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
429
+ return x
430
+
431
+
432
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
433
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
434
+ flat_affinity /= temperature
435
+ probs = F.softmax(flat_affinity, dim=-1)
436
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
437
+ if loss_type == "softmax":
438
+ target_probs = probs
439
+ else:
440
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
441
+ avg_probs = torch.mean(target_probs, dim=0)
442
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
443
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
444
+ loss = sample_entropy - avg_entropy
445
+ return loss
446
+
447
+
448
+ #################################################################################
449
+ # VQ Model Configs #
450
+ #################################################################################
451
+ def VQ_8(**kwargs):
452
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
453
+
454
+ def VQ_16(**kwargs):
455
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
456
+
457
+ VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
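
For readers wiring these pieces up by hand, here is a minimal, hedged smoke test. It assumes this file is importable (the path `tok.ar_dtok.vqvae` below is a guess — adjust it to wherever the file actually lives) and, for simplicity, matches the code dimension `e_dim` to `z_channels` instead of going through the full `VQModel` wrapper:

```python
# Hedged sketch: Encoder -> VectorQuantizer -> Decoder round trip.
# Assumptions: the import path, Encoder defaults matching the 16x /
# z_channels=256 configuration, and e_dim == z_channels (no quant conv).
import torch

from tok.ar_dtok.vqvae import Decoder, Encoder, VectorQuantizer  # adjust import path

enc = Encoder()                      # assumed defaults: 3-channel input, 16x downsample
dec = Decoder()                      # defaults above: ch_mult=(1,1,2,2,4), z_channels=256
vq = VectorQuantizer(n_e=16384, e_dim=256, beta=0.25,
                     entropy_loss_ratio=0.0, l2_norm=True, show_usage=False)

enc.eval(); dec.eval(); vq.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)
    z = enc(x)                                # (1, 256, 16, 16)
    z_q, _, (_, _, indices) = vq(z)           # losses are None outside training
    x_rec = dec(z_q)                          # (1, 3, 256, 256)
print(x_rec.shape, indices.shape)             # -> [1, 3, 256, 256], [256]
```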
tok/mm_autoencoder.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from tok.ar_dtok.ar_model import ARModel
5
+ from tok.ar_dtok.vqvae import VQVAE
6
+ from tok.ta_tok import TextAlignedTokenizer
7
+
8
+
9
+ class MMAutoEncoder(nn.Module):
10
+ def __init__(self,
11
+ ar_path_dict,
12
+ encoder_path, decoder_path,
13
+ encoder_args={}, decoder_args={}):
14
+ super().__init__()
15
+ self.ar_model = nn.ModuleDict({resolution: ARModel.from_checkpoint(ar_path) for resolution, ar_path in ar_path_dict.items()})
16
+
17
+ self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
18
+ self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)
19
+
20
+ def ar_sample(self, x, args):
21
+ resolution = args.get("resolution", "1024px")
22
+ x = self.ar_model[resolution].sample(
23
+ x,
24
+ cfg_scale=args.get('cfg_scale', 1.0),
25
+ cfg_interval=args.get('cfg_interval', -1),
26
+ temperature=args.get('temperature', 1.0),
27
+ top_k=args.get('top_k', 0),
28
+ top_p=args.get('top_p', 1.0)
29
+ )
30
+ return x
31
+
32
+ def post_process(self, x):
33
+ x = x.cpu().float().clamp(0., 1.) * 255.
34
+ x = x.permute(0, 2, 3, 1) # [b, h, w, c]
35
+ x = x.to(torch.uint8)
36
+ return x
37
+
38
+ def encode(self, x):
39
+ return self.encoder(x.to(self.encoder.dtype))['encoded']
40
+
41
+ def get_encoder_indices(self, x):
42
+ # img -> encoder -> indices
43
+ return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']
44
+
45
+ @torch.inference_mode()
46
+ def decode_from_encoder_indices(self, indices, args={}):
47
+ # indices -> encoder feats -> ar -> decoder
48
+ encoder_x = self.encoder.decode_from_bottleneck(indices)
49
+ ar_indices = self.ar_sample(encoder_x, args)
50
+ decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
51
+ x = self.post_process(decoder_x)
52
+ return x
53
+
54
+ def decode_from_vqvae_indices(self, indices):
55
+ decoder_x = self.decoder.decode_from_bottleneck(indices)
56
+ x = self.post_process(decoder_x)
57
+ return x
58
+
59
+ @torch.inference_mode()
60
+ def forward(self, x, args={}):
61
+ encoder_x = self.encoder(x.to(self.encoder.dtype))['encoded']
62
+ ar_indices = self.ar_sample(encoder_x, args)
63
+ decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
64
+ x = self.post_process(decoder_x)
65
+ return x
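
Taken together with the AR model and the pixel VQVAE, `MMAutoEncoder` is the full image-to-tokens-to-image loop. A hedged usage sketch follows; every checkpoint path below is a placeholder, not a file shipped with this commit:

```python
# Hedged sketch of the full MMAutoEncoder loop. Paths are placeholders; the
# sampling args mirror the keys read in ar_sample() above.
import torch
from tok.mm_autoencoder import MMAutoEncoder

pipe = MMAutoEncoder(
    ar_path_dict={'1024px': 'weights/ar_1024px.pth'},  # placeholder checkpoint
    encoder_path='weights/ta_tok.pth',                 # placeholder checkpoint
    decoder_path='weights/vqvae.pth',                  # placeholder checkpoint
).eval()

image = torch.rand(1, 3, 384, 384)                     # assumed to be in [0, 1]
args = {'resolution': '1024px', 'cfg_scale': 4.0, 'top_p': 0.95, 'temperature': 1.0}
out = pipe(image, args)                                # uint8 tensor, (b, h, w, c)
print(out.shape, out.dtype)
```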
tok/models.py ADDED
@@ -0,0 +1,31 @@
1
+ import copy
2
+ import inspect
3
+
4
+ import torch
5
+
6
+ models = {}
7
+
8
+
9
+ def register(name):
10
+ def decorator(cls):
11
+ models[name] = cls
12
+ return cls
13
+ return decorator
14
+
15
+
16
+ def make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
17
+ if args is not None:
18
+ model_args = copy.deepcopy(model_spec['args'])
19
+ model_args.update(args)
20
+ else:
21
+ model_args = model_spec['args']
22
+ model_params = inspect.signature(models[model_spec['name']]).parameters
23
+ if 'kwargs' not in model_params:
24
+ model_args = {k: v for k, v in model_args.items() if k in model_params}
25
+ model = models[model_spec['name']](**model_args)
26
+ if load_sd:
27
+ if ('abs_pe' in model_spec['sd']) and hasattr(model, 'abs_pe') and model_spec['sd']['abs_pe'].shape != model.abs_pe.shape:
28
+ del model_spec['sd']['abs_pe']
29
+ msg = model.load_state_dict(model_spec['sd'], strict=False)
30
+ print(msg)
31
+ return model
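
The `register`/`make` pair is a small model registry: classes register themselves under a string name, and `make` rebuilds them from a spec dict, silently dropping spec args the constructor does not accept. A toy illustration (the `toy-linear` name is invented for this sketch, not a project model):

```python
# Toy use of the registry in tok/models.py.
import torch
import torch.nn as nn
from tok import models


@models.register('toy-linear')
class ToyLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.proj(x)


spec = {'name': 'toy-linear', 'args': {'in_dim': 8, 'out_dim': 4, 'extra': 1}}
model = models.make(spec)                  # 'extra' is filtered out by make()
print(model(torch.randn(2, 8)).shape)      # torch.Size([2, 4])
```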
tok/ta_tok.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from torchvision.transforms import Resize
6
+ from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel
7
+
8
+ from . import models
9
+ from .utils import ScalingLayer
10
+
11
+
12
+ class TextAlignedTokenizer(nn.Module):
13
+ def __init__(
14
+ self,
15
+ bottleneck,
16
+ bottleneck_token_num=256,
17
+ input_size=384,
18
+ teacher='google/siglip2-so400m-patch14-384',
19
+ input_type='quant', # choose from ['quant', 'rec', 'indices']
20
+ pool_scale=1, # choose from [1, 2, 3]
21
+ decoder_depth=3,
22
+ select_layer_id=-2,
23
+ *args,
24
+ **kwargs
25
+ ):
26
+ super().__init__()
27
+ self.input_size = input_size
28
+ self.bottleneck_token_num = bottleneck_token_num
29
+ self.teacher = teacher
30
+ self.input_type = input_type
31
+ self.pool_scale = pool_scale
32
+ self.decoder_depth = decoder_depth
33
+ self.select_layer_id = select_layer_id
34
+
35
+ self.bottleneck_dim = bottleneck['args']['bottleneck_dim']
36
+
37
+ self.encoder_config = AutoConfig.from_pretrained(teacher)
38
+ self.encoder = AutoModel.from_config(self.encoder_config).vision_model
39
+
40
+ self.encoder_hidden_dim = self.encoder.config.hidden_size
41
+
42
+ self.decoder_config = Siglip2VisionConfig()
43
+ self.decoder_config.update({
44
+ 'patch_size': 1,
45
+ 'num_hidden_layers': self.decoder_depth,
46
+ 'num_channels': self.bottleneck_dim,
47
+ 'hidden_size': self.encoder_hidden_dim,
48
+ })
49
+ self.decoder = Siglip2VisionModel(self.decoder_config)
50
+
51
+ self.encode_task_layer = nn.Sequential(
52
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
53
+ nn.Tanh())
54
+ self.decode_task_layer = nn.Sequential(
55
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
56
+ nn.Tanh(),
57
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim))
58
+
59
+ bottleneck_args = {
60
+ 'token_nums': self.bottleneck_token_num,
61
+ 'input_dim': self.encoder_hidden_dim,
62
+ 'output_dim': self.bottleneck_dim}
63
+ self.bottleneck = models.make(bottleneck, args=bottleneck_args)
64
+
65
+ self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
66
+ self.image_resize = Resize((self.input_size, self.input_size))
67
+
68
+ def set_vq_eval_deterministic(self, deterministic=True):
69
+ self.bottleneck.regularizer.set_eval_deterministic(deterministic)
70
+
71
+ @property
72
+ def device(self):
73
+ return next(self.parameters()).device
74
+
75
+ @property
76
+ def dtype(self):
77
+ return next(self.parameters()).dtype
78
+
79
+ @classmethod
80
+ def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
81
+ ckpt = torch.load(ckpt, map_location='cpu')
82
+ ckpt_kwargs = ckpt["model"]["args"]
83
+ model = cls(**kwargs, **ckpt_kwargs)
84
+ sd = ckpt["model"]["sd"]
85
+ if not load_teacher:
86
+ sd = {k: v for k, v in sd.items() if not k.startswith('teacher')}
87
+ model.load_state_dict(sd, strict=True)
88
+ return model
89
+
90
+ def encode(self, x, **kwargs):
91
+ if x.ndim == 5:
92
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
93
+ x = self.scale_layer(x)
94
+ if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
95
+ x = self.image_resize(x)
96
+ vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[self.select_layer_id]
97
+
98
+ pool_scale = self.pool_scale
99
+ pool_scale = kwargs.get("pool_scale", pool_scale)
100
+ if pool_scale != 1:
101
+ vq_feats = self.avg_pool(vq_feats, pool_scale)
102
+ vq_feats = self.encode_task_layer(vq_feats.to(x))
103
+
104
+ bottleneck_out = self.bottleneck(vq_feats)
105
+ z = bottleneck_out.pop('output')
106
+
107
+ return {'encoded': z, 'pool_scale': pool_scale, 'vq_feats': vq_feats, **bottleneck_out}
108
+
109
+ def avg_pool(self, z, pool_scale=1):
110
+ if z.ndim == 3:
111
+ b, n, c = z.shape
112
+ p = int(n ** 0.5)
113
+ z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
114
+ else:
115
+ b, c, p, _ = z.shape
116
+ p_s = int(p // pool_scale)
117
+ z = F.avg_pool2d(
118
+ z,
119
+ kernel_size=(pool_scale, pool_scale),
120
+ stride=(pool_scale, pool_scale)
121
+ ).contiguous()
122
+ z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
123
+ return z
124
+
125
+ def decode(self, z):
126
+ if z.ndim == 4:
127
+ z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
128
+ attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
129
+ p = int(z.shape[1]**0.5)
130
+ spatial_shape = torch.tensor([[p, p]]*z.shape[0], device=self.device)
131
+ z = self.decoder(z, attention_mask, spatial_shape, output_hidden_states=True).last_hidden_state
132
+ z = self.decode_task_layer(z)
133
+ return z
134
+
135
+ def decode_from_bottleneck(self, bottleneck_rep):
136
+ z = self.bottleneck.decode(bottleneck_rep) # (b, n, c)
137
+ p = int(z.shape[1]**0.5)
138
+ z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
139
+ return self.decode(z)
140
+
141
+ def forward(self, data, **kwargs):
142
+ # data: video in shape (b, c, t, h, w)
143
+ encode_output = self.encode(data, **kwargs)
144
+ vq_feats = encode_output['encoded']
145
+ p = int(vq_feats.shape[1] ** 0.5)
146
+ vq_feats = rearrange(vq_feats, 'b (h w) c -> b c h w', h=p, w=p)
147
+ pred_feats = self.decode(vq_feats)
148
+
149
+ if self.input_type == 'quant':
150
+ z = encode_output["regularized_z"] # [b, n, c]
151
+ elif self.input_type == 'indices':
152
+ z = encode_output["bottleneck_rep"] # [b, n]
153
+ elif self.input_type == 'rec':
154
+ z = pred_feats # [b, n, c]
155
+ encode_output['encoded'] = z
156
+ return encode_output
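
A hedged sketch of using `TextAlignedTokenizer` on its own: load it from a checkpoint (placeholder path), encode an image into text-aligned tokens, and decode the discrete ids back into teacher-space features. The key names follow the dict returned by `encode()`/`forward()` above; building the SigLIP2 teacher config may require network access on first use.

```python
# Hedged sketch; the checkpoint path is a placeholder and output shapes depend
# on the stored pool_scale / bottleneck settings.
import torch
from tok.ta_tok import TextAlignedTokenizer

tok = TextAlignedTokenizer.from_checkpoint('weights/ta_tok.pth',  # placeholder
                                           load_teacher=False).eval()

image = torch.rand(1, 3, 384, 384)               # assumed to be in [0, 1]
with torch.no_grad():
    enc = tok.encode(image)                       # dict: 'encoded', 'bottleneck_rep', ...
    ids = enc['bottleneck_rep']                   # discrete token ids
    feats = tok.decode_from_bottleneck(ids)       # back to teacher feature space
print(ids.shape, feats.shape)
```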
tok/utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class ScalingLayer(nn.Module):
6
+ def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
7
+ super().__init__()
8
+ self.register_buffer('shift', torch.Tensor(mean)[None, :, None, None])
9
+ self.register_buffer('scale', torch.Tensor(std)[None, :, None, None])
10
+
11
+ def forward(self, inp):
12
+ return (inp - self.shift) / self.scale
13
+
14
+ def inv(self, inp):
15
+ return inp * self.scale + self.shift
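
`ScalingLayer` with the default mean/std of 0.5 maps `[0, 1]` images to `[-1, 1]`, and `inv()` undoes it exactly:

```python
# Round-trip check for ScalingLayer's normalize / inverse pair.
import torch
from tok.utils import ScalingLayer

scale = ScalingLayer()
x = torch.rand(2, 3, 8, 8)                 # toy batch in [0, 1]
y = scale(x)                               # now in [-1, 1]
assert torch.allclose(scale.inv(y), x, atol=1e-6)
print(y.min().item(), y.max().item())
```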