cai-qi committed
Commit aa4fdd4 · verified · 0 Parent(s)

Super-squash branch 'main' using huggingface_hub
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/test_3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/test.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,53 @@
+ *.swp
+ **/__pycache__/**
+ **/.ipynb_checkpoints/**
+ .idea/*
+ llava/
+ _vis_cached/
+ _vqgan/
+ _vae/
+ _vae*/
+ ckpt/
+ log/
+ tb*/
+ img*/
+ local_output*
+ _auto_*
+ sd-vae-ft-mse/
+ stable-diffusion-v1-4/
+ *.pth
+ *.pth.tar
+ *.ckpt
+ *.log
+ *.txt
+ *.ipynb
+ toscli
+ *.hydra
+ wandb
+ *.jsonl
+ *.jpg
+ *.png
+ *.json
+ *.csv
+ *.tar.gz
+ *.bin
+ data/
+ tmp
+ output
+ *.tsv
+ *.mp4
+ output/*
+ results/
+ *.JPEG
+ debug/
+ weights
+ checkpoints
+ ref.py
+ wandb
+ .DS_Store
+ clean*
+ *_local.sh
+ *_local.py
+ HiDream-ai/
+ *.pyc
+ __pycache__
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 FoundationVision
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,153 @@
+ ---
+ title: VAREdit-8B-512
+ emoji: 🚀
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.43.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ models:
+ - HiDream-ai/VAREdit
+ ---
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # VAREdit
+
+ ![VAREdit Demo](assets/demo.jpg)
+
+ [VAREdit](https://github.com/HiDream-ai/VAREdit) is an advanced image editing model built on the [Infinity](https://huggingface.co/FoundationVision/infinity) models, designed for high-quality instruction-based image editing.
+
+ ## 🌟 Key Features
+
+ - **Strong Instruction Following**: Follows instructions more accurately thanks to the autoregressive nature of the model.
+ - **Efficient Inference**: Optimized for fast generation, with under one second per edit for the 8B model at 512×512.
+ - **Flexible Resolution**: Supports 512×512 and 1024×1024 image resolutions.
+
+ ![VAREdit Demo](assets/framework.jpg)
+
+ ## 📊 Model Variants
+
+ | Model Variant   | Resolution | HuggingFace Model                                            | Time (H800) | VRAM (GB) |
+ |-----------------|------------|--------------------------------------------------------------|-------------|-----------|
+ | VAREdit-8B-512  | 512×512    | [VAREdit-8B-512](https://huggingface.co/HiDream-ai/VAREdit)  | ~0.7s       | 50.41     |
+ | VAREdit-8B-1024 | 1024×1024  | [VAREdit-8B-1024](https://huggingface.co/HiDream-ai/VAREdit) | ~1.99s      | 50.41     |
+
+ ## 🚀 Quick Start
+
+ ### Prerequisites
+
+ Before starting, ensure you have:
+ - Python 3.8+
+ - A CUDA-compatible GPU with sufficient VRAM (8 GB+ for the 2B model, 24 GB+ for the 8B model)
+ - Required dependencies installed
+
+ ### Installation
+
+ 1. **Clone the repository**
+ ```bash
+ git clone https://github.com/HiDream-ai/VAREdit.git
+ cd VAREdit
+ ```
+
+ 2. **Install dependencies**
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. **Download model checkpoints**
+
+ Download the VAREdit model checkpoints:
+ ```bash
+ # Download from HuggingFace
+ git lfs install
+ git clone https://huggingface.co/HiDream-ai/VAREdit
+ ```
+
+ ### Basic Usage
+
+ ```python
+ from infer import load_model, generate_image
+
+ model_components = load_model(
+     pretrain_root="HiDream-ai/VAREdit",
+     model_path="HiDream-ai/VAREdit/8B-1024.pth",
+     model_size="8B",
+     image_size=1024
+ )
+
+ # Generate edited image
+ edited_image = generate_image(
+     model_components,
+     src_img_path="assets/test.jpg",
+     instruction="Add glasses to this girl and change hair color to red",
+     cfg=3.0,  # Classifier-free guidance scale
+     tau=0.1,  # Temperature parameter
+     seed=42   # Optional random seed
+ )
+ ```
+
+ ## 📝 Detailed Configuration
+
+ ### Model Sampling Parameters
+
+ | Parameter | Description | Default |
+ |-----------|-------------|---------|
+ | `cfg` | Classifier-free guidance scale | 3.0 |
+ | `tau` | Temperature for sampling | 1.0 |
+ | `seed` | Random seed for reproducibility | -1 (random) |
+
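+ For reproducible edits, fix `seed`; with `seed=-1` a fresh random seed is drawn on every call. A minimal sketch (assuming `model_components` has been loaded as in Basic Usage above):
+
+ ```python
+ # Two runs with the same seed produce the same edit.
+ edit_a = generate_image(model_components, "assets/test.jpg",
+                         "Add glasses to this girl", cfg=3.0, tau=0.1, seed=42)
+ edit_b = generate_image(model_components, "assets/test.jpg",
+                         "Add glasses to this girl", cfg=3.0, tau=0.1, seed=42)
+ # edit_a and edit_b are identical; seed=-1 would randomize each call.
+ ```
+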
+ ## 📂 Project Structure
+
+ ```
+ VAREdit/
+ ├── infer.py              # Main inference script
+ ├── infinity/             # Core model implementations
+ │   ├── models/           # Model architectures
+ │   ├── dataset/          # Data processing utilities
+ │   └── utils/            # Helper functions
+ ├── tools/                # Additional tools and scripts
+ │   └── run_infinity.py   # Model execution utilities
+ ├── assets/               # Demo images and resources
+ └── README.md             # This file
+ ```
+
+ ## 📊 Performance Benchmarks
+
+ | **Method** | **Size** | **EMU-Edit Bal.** | **PIE-Bench Bal.** | **Time (A800)** |
+ |:---|:---:|:---:|:---:|:---:|
+ | InstructPix2Pix | 1.1B | 2.923 | 4.034 | 3.5s |
+ | UltraEdit | 7.7B | 4.541 | 5.580 | 2.6s |
+ | OmniGen | 3.8B | 4.674 | 3.492 | 16.5s |
+ | AnySD | 2.9B | 3.129 | 3.326 | 3.4s |
+ | EditAR | 0.8B | 3.305 | 4.707 | 45.5s |
+ | ACE++ | 16.9B | 2.076 | 2.574 | 5.7s |
+ | ICEdit | 17.0B | 4.785 | 4.933 | 8.4s |
+ | **VAREdit** (256px) | 2.2B | 5.565 | 6.684 | 0.5s |
+ | **VAREdit** (512px) | 2.2B | 5.662 | 6.996 | 0.7s |
+ | **VAREdit** (512px) | 8.4B | 7.7923 | 8.1055 | 1.2s |
+ | **VAREdit** (1024px) | 8.4B | 7.3797 | 7.6880 | 3.9s |
+
+ **Note**: The released 8B models were trained longer and on more data, so their performance is better than that reported in the paper.
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## 📚 Citation
+
+ If you use VAREdit in your research, please cite:
+
+ ```bibtex
+ @article{varedit2025,
+   title={Visual Autoregressive Modeling for Instruction-Guided Image Editing},
+   author={Mao, Qingyang and Cai, Qi and Li, Yehao and Pan, Yingwei and Cheng, Mingyue and Yao, Ting and Liu, Qi and Mei, Tao},
+   journal={arXiv preprint},
+   year={2025}
+ }
+ ```
+
+ ## 🙏 Acknowledgments
+
+ - Built on the [Infinity](https://huggingface.co/FoundationVision/infinity) models
+
+ **Note**: This project is under active development. Features and code may change.
app.py ADDED
@@ -0,0 +1,144 @@
+ """
+ Gradio app for VAREdit image editing model.
+ Provides a web interface for editing images with text instructions.
+ """
+ import spaces
+ import gradio as gr
+ import os
+ import tempfile
+ from PIL import Image
+ import logging
+ from infer import load_model, generate_image
+ from huggingface_hub import snapshot_download
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ @spaces.GPU
+ def edit_image(
+     input_image: Image.Image,
+     instruction: str,
+     cfg: float = 4.0,
+     tau: float = 0.5,
+     seed: int = -1
+ ) -> Image.Image:
+     """Edit image based on text instruction."""
+     if input_image is None:
+         raise gr.Error("Please upload an image")
+
+     if not instruction.strip():
+         raise gr.Error("Please provide an editing instruction")
+
+     try:
+         # Save input image to a temporary file for the inference pipeline
+         with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
+             input_image.save(tmp_file.name, 'JPEG')
+             temp_path = tmp_file.name
+
+         try:
+             # Generate edited image
+             result_image = generate_image(
+                 model_components,
+                 temp_path,
+                 instruction,
+                 cfg=cfg,
+                 tau=tau,
+                 seed=seed if seed != -1 else None
+             )
+             return result_image
+
+         finally:
+             # Clean up temporary file
+             if os.path.exists(temp_path):
+                 os.unlink(temp_path)
+
+     except Exception as e:
+         logger.error(f"Image editing failed: {e}")
+         raise gr.Error(f"Failed to edit image: {str(e)}")
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="VAREdit Image Editor") as demo:
+         gr.Markdown("# VAREdit Image Editor")
+         gr.Markdown("Edit images using natural language instructions with the VAREdit model.")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     type="pil",
+                     label="Input Image",
+                 )
+
+                 instruction = gr.Textbox(
+                     label="Editing Instruction",
+                     placeholder="e.g., 'Remove glasses from this person', 'Change the sky to sunset', 'Add a hat'",
+                     lines=2
+                 )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     cfg = gr.Slider(
+                         minimum=1.0,
+                         maximum=10.0,
+                         value=3.0,
+                         step=0.5,
+                         label="CFG Scale (Guidance Strength)"
+                     )
+
+                     tau = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=0.1,
+                         step=0.01,
+                         label="Temperature (Tau)"
+                     )
+
+                     seed = gr.Number(
+                         value=-1,
+                         label="Seed (-1 for random)",
+                         precision=0
+                     )
+
+                 edit_btn = gr.Button("Edit Image", variant="primary", size="lg")
+
+             with gr.Column():
+                 output_image = gr.Image(
+                     label="Edited Image",
+                 )
+
+         # Example images and instructions
+         gr.Markdown("## Examples")
+         gr.Examples(
+             examples=[
+                 ["assets/test_3.jpg", "change shirt to a black-and-white striped Breton top, add a red beret, set the background to an artist's loft with a window view of the Eiffel Tower"],
+                 ["assets/test.jpg", "Add glasses to this girl and change hair color to red"],
+                 ["assets/test_1.jpg", "replace all the bullets with shimmering, multi-colored butterflies."],
+                 ["assets/test_4.jpg", "Set the scene against a dark, blurred-out server room, make all text and arrows glow with a vibrant cyan light"],
+             ],
+             inputs=[input_image, instruction],
+             outputs=output_image,
+             fn=lambda img, inst: edit_image(img, inst),
+             cache_examples=False
+         )
+
+         # Set up event handler
+         edit_btn.click(
+             fn=edit_image,
+             inputs=[input_image, instruction, cfg, tau, seed],
+             outputs=output_image
+         )
+
+     return demo
+
+ # Fetch the model repository locally, then load the 8B 512px checkpoint once at startup
+ model_path = "HiDream-ai/VAREdit"
+ snapshot_download(repo_id=model_path, max_workers=16, repo_type="model",
+                   local_dir=model_path)
+ model_components = load_model("HiDream-ai/VAREdit", "HiDream-ai/VAREdit/8B-512.pth", "8B", 512)
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.queue(max_size=50, default_concurrency_limit=16).launch(show_api=False)
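+
+ # Running `python app.py` locally launches the same interface; the queue above
+ # holds up to 50 pending requests, with at most 16 handled concurrently.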
assets/keep ADDED
File without changes
assets/test.jpg ADDED

Git LFS Details

  • SHA256: 313485a23fe9574c8968717398520d2b0c061aee460b317b93c5cb9100395cdd
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
assets/test_1.jpg ADDED
assets/test_3.jpg ADDED

Git LFS Details

  • SHA256: cf117f67ef5a8056eabc4adbc3174a043a0811925f39ba2d1cd7c081e8f6b1fc
  • Pointer size: 131 Bytes
  • Size of remote file: 332 kB
assets/test_4.jpg ADDED
infer.py ADDED
@@ -0,0 +1,218 @@
+ """
+ Image inference module for the VAREdit model.
+ Supports 2B and 8B model variants for image editing with text instructions.
+ """
+ import argparse
+ import logging
+ import time
+ from typing import Tuple, Any, Optional
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ import PIL.Image as PImage
+ from torchvision.transforms.functional import to_tensor
+
+ from tools.run_infinity import (
+     load_tokenizer, load_visual_tokenizer, load_transformer,
+     gen_one_img, h_div_w_templates, dynamic_resolution_h_w
+ )
+
+ def transform(pil_img, target_image_size):
+     # Currently only supports square model inputs: pad with white to a square,
+     # resize to the target size, and return an op that undoes the padding.
+     width, height = pil_img.size
+     max_dim = max(width, height)
+     padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+     padded_image.paste(pil_img, (0, 0))
+     def crop_op(image):
+         # Inverse of the padding: upscale back to the padded square, then
+         # crop out the original (width, height) region.
+         image = image.resize((max_dim, max_dim), resample=PImage.LANCZOS)
+         crop_image = image.crop((0, 0, width, height))
+         return crop_image
+     padded_image = padded_image.resize((target_image_size, target_image_size), resample=PImage.LANCZOS)
+     im = to_tensor(np.array(padded_image))
+     return im.add(im).add_(-1), crop_op  # normalize [0, 1] -> [-1, 1] via 2x - 1
+
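+ # A minimal sketch (not used by the pipeline) showing that `transform` and its
+ # returned `crop_op` round-trip a non-square image; assumes `assets/test.jpg`
+ # exists, as in the examples below.
+ def _example_pad_crop_roundtrip():
+     with Image.open("assets/test.jpg") as img:
+         img = img.convert("RGB")
+         tensor, crop_op = transform(img, 512)  # CHW tensor in [-1, 1]
+         # Denormalize to uint8 HWC, then undo the square padding.
+         arr = ((tensor + 1) * 127.5).clamp(0, 255).byte().permute(1, 2, 0).numpy()
+         restored = crop_op(Image.fromarray(arr))
+         assert restored.size == img.size
+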
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Model configurations
+ MODEL_CONFIGS = {
+     '2B': {
+         'vae_filename': 'infinity_vae_d32reg.pth',
+         'vae_type': 32,
+         'model_type': 'infinity_2b',
+         'apply_spatial_patchify': 0,
+     },
+     '8B': {
+         'vae_filename': 'infinity_vae_d56_f8_14_patchify.pth',
+         'vae_type': 14,
+         'model_type': 'infinity_8b',
+         'apply_spatial_patchify': 1,
+     }
+ }
+
+ # Common model arguments
+ COMMON_ARGS = {
+     'cfg_insertion_layer': 0,
+     'add_lvl_embeding_only_first_block': 1,
+     'use_bit_label': 1,
+     'rope2d_each_sa_layer': 1,
+     'rope2d_normalized_by_hw': 2,
+     'use_scale_schedule_embedding': 0,
+     'sampling_per_bits': 1,
+     'text_channels': 2048,
+     'h_div_w_template': 1.000,
+     'use_flex_attn': 0,
+     'cache_dir': '/dev/shm',
+     'checkpoint_type': 'torch',
+     'bf16': 1,
+     'enable_model_cache': 0,
+ }
+
+
+ def load_model(pretrain_root: str, model_path: str, model_size: str, image_size: int) -> Tuple[Any, ...]:
+     """
+     Load the model and its components.
+
+     Args:
+         pretrain_root: Root directory for pretrained models
+         model_path: Path to the specific model checkpoint
+         model_size: Model variant, '2B' or '8B'
+         image_size: Target resolution, 512 or 1024
+
+     Returns:
+         Tuple of (args, model, vae, tokenizer, text_encoder, image_size)
+
+     Raises:
+         ValueError: If an unsupported model size or image size is specified
+     """
+     if model_size not in MODEL_CONFIGS:
+         raise ValueError(f"Unsupported model size: {model_size}. Choose '2B' or '8B'.")
+
+     config = MODEL_CONFIGS[model_size]
+
+     # Build arguments
+     args_dict = {
+         **COMMON_ARGS,
+         **config,
+         'model_path': model_path,
+         'vae_path': f"{pretrain_root}/{config['vae_filename']}",
+         'text_encoder_ckpt': f"{pretrain_root}/flan-t5-xl"
+     }
+     args = argparse.Namespace(**args_dict)
+     if image_size == 512:
+         args.pn = "0.25M"
+     elif image_size == 1024:
+         args.pn = "1M"
+     else:
+         raise ValueError(f"Unsupported image size: {image_size}. Choose 512 or 1024.")
+     logger.info(f"Loading {model_size} model from {model_path}")
+
+     # Load components
+     text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
+     vae = load_visual_tokenizer(args)
+     model = load_transformer(vae, args)
+
+     logger.info("Model loaded successfully")
+     return args, model, vae, text_tokenizer, text_encoder, image_size
+
+
+ def generate_image(
+     model_components: Tuple[Any, ...],
+     src_img_path: str,
+     instruction: str,
+     cfg: float = 4.0,
+     tau: float = 0.5,
+     seed: Optional[int] = -1,
+ ) -> Image.Image:
+     """
+     Generate an edited image from a source image and a text instruction.
+
+     Args:
+         model_components: Tuple of (args, model, vae, tokenizer, text_encoder, image_size)
+         src_img_path: Path to source image
+         instruction: Text instruction for editing
+         cfg: Classifier-free guidance scale
+         tau: Temperature parameter
+         seed: Random seed; -1 picks a random seed
+     """
+     args, model, vae, tokenizer, text_encoder, image_size = model_components
+
+     # Set default image size
+     assert image_size in [512, 1024], f"Invalid image size: {image_size}, expected 512 or 1024"
+     if image_size == 512:
+         pn = "0.25M"
+     elif image_size == 1024:
+         pn = "1M"
+
+     # Load and preprocess source image
+     try:
+         with Image.open(src_img_path) as src_img:
+             src_img = src_img.convert('RGB')
+             src_img_tensor, crop_op = transform(src_img, image_size)
+     except Exception as e:
+         logger.error(f"Failed to load source image: {e}")
+         raise
+
+     # Set up generation parameters
+     aspect_ratio = 1.0  # h:w ratio
+     h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - aspect_ratio))]
+     scale_schedule = [(1, h, w) for (_, h, w) in dynamic_resolution_h_w[h_div_w_template][pn]['scales']]
+
+     logger.info(f"Generating image with instruction: '{instruction}'")
+
+     # Generate image
+     if seed == -1:
+         seed = np.random.randint(0, 1000000)
+     torch.cuda.empty_cache()
+     start_time = time.time()
+     generated_image = gen_one_img(
+         model, vae, tokenizer, text_encoder,
+         instruction, src_img_tensor,
+         g_seed=seed,
+         gt_leak=0,
+         gt_ls_Bl=None,
+         cfg_list=cfg,
+         tau_list=tau,
+         scale_schedule=scale_schedule,
+         cfg_insertion_layer=[args.cfg_insertion_layer],
+         vae_type=args.vae_type,
+         sampling_per_bits=args.sampling_per_bits,
+         enable_positive_prompt=0,
+         apply_spatial_patchify=args.apply_spatial_patchify,
+     )
+     end_time = time.time()
+     logger.info(f"Time taken: {end_time - start_time:.2f} seconds")
+     max_memory = torch.cuda.max_memory_allocated() / 1024 ** 3
+     logger.info(f"Max memory: {max_memory:.2f} GB")
+     generated_image_np = generated_image.cpu().numpy()
+     if generated_image_np.shape[2] == 3:
+         generated_image_np = generated_image_np[..., ::-1]  # BGR -> RGB
+     result_image = Image.fromarray(generated_image_np.astype(np.uint8))
+     result_image = crop_op(result_image)
+     return result_image
+
+ def main():
+     """Main execution function with example usage."""
+     try:
+         # Load model
+         model_components = load_model(
+             "HiDream-ai/VAREdit",
+             "HiDream-ai/VAREdit/8B-1024.pth",
+             "8B",
+             1024
+         )
+
+         # Generate image
+         generate_image(
+             model_components,
+             "assets/test.jpg",
+             "Add glasses to this girl and change hair color to red",
+             cfg=3.0,
+             tau=1.0,
+             seed=42
+         )
+
+     except Exception as e:
+         logger.error(f"Inference failed: {e}")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
infinity/dataset/build.py ADDED
@@ -0,0 +1,192 @@
+ import datetime
+ import os
+ import os.path as osp
+ import random
+ import subprocess
+ from functools import partial
+ from typing import Optional
+ import time
+
+ import pytz
+ from infinity.dataset.webdataset import WDSEditDataset
+
+ try:
+     from grp import getgrgid
+     from pwd import getpwuid
+ except ImportError:  # e.g. on Windows, where grp/pwd are unavailable
+     pass
+ import PIL.Image as PImage
+ from PIL import ImageFile
+ import numpy as np
+ from torchvision.transforms import transforms
+ from torchvision.transforms.functional import resize, to_tensor
+ import torch.distributed as tdist
+
+ from torchvision.transforms import InterpolationMode
+ bicubic = InterpolationMode.BICUBIC
+ lanczos = InterpolationMode.LANCZOS
+ PImage.MAX_IMAGE_PIXELS = (1024 * 1024 * 1024 // 4 // 3) * 5
+ ImageFile.LOAD_TRUNCATED_IMAGES = False
+
+
+ def time_str(fmt='[%m-%d %H:%M:%S]'):
+     return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
+
+
+ def normalize_01_into_pm1(x):  # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
+     return x.add(x).add_(-1)
+
+
+ def denormalize_pm1_into_01(x):  # denormalize x from [-1, 1] to [0, 1]
+     return x.add(1).mul_(0.5)
+
+
44
+ def center_crop_arr(pil_image, image_size):
45
+ """
46
+ Center cropping implementation from ADM.
47
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
48
+ """
49
+ while min(*pil_image.size) >= 2 * image_size:
50
+ pil_image = pil_image.resize(
51
+ tuple(x // 2 for x in pil_image.size), resample=PImage.BOX
52
+ )
53
+
54
+ scale = image_size / min(*pil_image.size)
55
+ pil_image = pil_image.resize(
56
+ tuple(round(x * scale) for x in pil_image.size), resample=PImage.LANCZOS
57
+ )
58
+
59
+ arr = np.array(pil_image)
60
+ crop_y = (arr.shape[0] - image_size) // 2
61
+ crop_x = (arr.shape[1] - image_size) // 2
62
+ return PImage.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
63
+
64
+
65
+ class RandomResize:
66
+ def __init__(self, mid_reso, final_reso, interpolation):
67
+ ub = max(round((mid_reso + (mid_reso-final_reso) / 8) / 4) * 4, mid_reso)
68
+ self.reso_lb, self.reso_ub = final_reso, ub
69
+ self.interpolation = interpolation
70
+
71
+ def __call__(self, img):
72
+ return resize(img, size=random.randint(self.reso_lb, self.reso_ub), interpolation=self.interpolation)
73
+
74
+ def __repr__(self):
75
+ return f'RandomResize(reso=({self.reso_lb}, {self.reso_ub}), interpolation={self.interpolation})'
76
+
77
+
78
+ def load_save(reso=512):
79
+ import os
80
+ from PIL import Image as PImage
81
+ from torchvision.transforms import transforms, InterpolationMode
82
+ aug = transforms.Compose([
83
+ transforms.Resize(512, interpolation=InterpolationMode.LANCZOS),
84
+ transforms.CenterCrop((512, 512))
85
+ ])
86
+ src_folder = r'C:\Users\16333\Pictures\imgs_to_visual_v2'
87
+ ls = [os.path.join(src_folder, x) for x in ('1.jpg', '2.jpg', '3.png', '4.png', '5.png')]
88
+ print(ls)
89
+ imgs = []
90
+ for i, fname in enumerate(ls):
91
+ assert os.path.exists(fname)
92
+ with PImage.open(fname) as img:
93
+ img = img.convert('RGB')
94
+ img = aug(img)
95
+ imgs.append(img)
96
+ dst_d, dst_f = os.path.split(fname)
97
+ dst = os.path.join(dst_d, f'crop{dst_f.replace(".jpg", ".png")}')
98
+ img.save(dst)
99
+
100
+ W, H = imgs[0].size
101
+ WW = W * len(imgs)
102
+ new_im = PImage.new('RGB', (WW, H))
103
+ x_offset = 0
104
+ for img in imgs:
105
+ new_im.paste(img, (x_offset, 0))
106
+ x_offset += W
107
+ dst = os.path.join(src_folder, f'junfeng.png')
108
+ new_im.save(dst)
109
+
110
+
111
+ def print_aug(transform, label):
112
+ print(f'Transform {label} = ')
113
+ if hasattr(transform, 'transforms'):
114
+ for t in transform.transforms:
115
+ print(t)
116
+ else:
117
+ print(transform)
118
+ print('---------------------------\n')
119
+
120
+
121
+ def build_t2i_dataset(
122
+ args,
123
+ data_path: str,
124
+ data_load_reso: int,
125
+ max_caption_len: int,
126
+ short_prob=0.2,
127
+ load_vae_instead_of_image=False
128
+ ):
129
+ if args.use_streaming_dataset:
130
+ # return T2IIterableDataset(
131
+ # data_path,
132
+ # max_caption_len=max_caption_len,
133
+ # short_prob=short_prob,
134
+ # load_vae_instead_of_image=load_vae_instead_of_image,
135
+ # buffersize=args.iterable_data_buffersize,
136
+ # pn=args.pn,
137
+ # online_t5=args.online_t5,
138
+ # batch_size=args.batch_size,
139
+ # num_replicas=tdist.get_world_size(), # 1,
140
+ # rank=tdist.get_rank(), # 0
141
+ # dataloader_workers=args.workers,
142
+ # dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus,
143
+ # enable_dynamic_length_prompt=args.enable_dynamic_length_prompt,
144
+ # seed=args.seed if args.seed is not None else int(time.time()),
145
+ # )
146
+ return WDSEditDataset(
147
+ data_path,
148
+ buffersize=args.iterable_data_buffersize,
149
+ pn=args.pn,
150
+ batch_size=args.batch_size,
151
+ num_replicas=tdist.get_world_size(), # 1,
152
+ rank=tdist.get_rank(), # 0
153
+ # dataloader_workers=args.workers,
154
+ # dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus,
155
+ # enable_dynamic_length_prompt=args.enable_dynamic_length_prompt,
156
+ # seed=args.seed if args.seed is not None else int(time.time()),
157
+ )
158
+ else:
159
+ raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported')
160
+
161
+
162
+ def pil_load(path: str, proposal_size):
163
+ with open(path, 'rb') as f:
164
+ img: PImage.Image = PImage.open(f)
165
+ w: int = img.width
166
+ h: int = img.height
167
+ sh: int = min(h, w)
168
+ if sh > proposal_size:
169
+ ratio: float = proposal_size / sh
170
+ w = round(ratio * w)
171
+ h = round(ratio * h)
172
+ img.draft('RGB', (w, h))
173
+ img = img.convert('RGB')
174
+ return img
175
+
176
+
177
+ def rewrite(im: PImage, file: str, info: str):
178
+ kw = dict(quality=100)
179
+ if file.lower().endswith('.tif') or file.lower().endswith('.tiff'):
180
+ kw['compression'] = 'none'
181
+ elif file.lower().endswith('.webp'):
182
+ kw['lossless'] = True
183
+
184
+ st = os.stat(file)
185
+ uname = getpwuid(st.st_uid).pw_name
186
+ gname = getgrgid(st.st_gid).gr_name
187
+ mode = oct(st.st_mode)[-3:]
188
+
189
+ local_file = osp.basename(file)
190
+ im.save(local_file, **kw)
191
+ print(f'************* <REWRITE: {info}> ************* @ {file}')
192
+ subprocess.call(f'sudo mv {local_file} {file}; sudo chown {uname}:{gname} {file}; sudo chmod {mode} {file}', shell=True)
infinity/dataset/webdataset.py ADDED
@@ -0,0 +1,94 @@
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, get_h_div_w_template2indices, h_div_w_templates
+
+ import webdataset as wds
+ from torch.utils.data import DataLoader
+ from torchvision.transforms.functional import to_tensor
+ import numpy as np
+ import PIL.Image as PImage
+ import io
+
+
+ def pad_image_to_square(img):
+     width, height = img.size
+     max_side = max(width, height)
+     new_img = PImage.new("RGB", (max_side, max_side), (0, 0, 0))
+     paste_position = ((max_side - width) // 2, (max_side - height) // 2)
+     new_img.paste(img, paste_position)
+     return new_img
+
+
+ def transform(pil_img, tgt_h, tgt_w):
+     # Resize so the image covers the target box, preserving aspect ratio
+     width, height = pil_img.size
+     if width / height <= tgt_w / tgt_h:
+         resized_width = tgt_w
+         resized_height = int(tgt_w / (width / height))
+     else:
+         resized_height = tgt_h
+         resized_width = int((width / height) * tgt_h)
+     pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
+     # crop the center out
+     arr = np.array(pil_img)
+     crop_y = (arr.shape[0] - tgt_h) // 2
+     crop_x = (arr.shape[1] - tgt_w) // 2
+     im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
+     return im.add(im).add_(-1)  # normalize [0, 1] -> [-1, 1]
+
+
+ def preprocess(sample):
+     src, tgt, prompt = sample
+     h, w = dynamic_resolution_h_w[h_div_w_template][PN]['pixel']
+     src_img = PImage.open(io.BytesIO(src)).convert('RGB')
+     tgt_img = PImage.open(io.BytesIO(tgt)).convert('RGB').resize(src_img.size)
+     src_img = transform(src_img, h, w)
+     tgt_img = transform(tgt_img, h, w)
+     instruction = prompt.decode('utf-8')
+     return src_img, tgt_img, instruction
+
+
+ def WDSEditDataset(
+     data_path,
+     buffersize,
+     pn,
+     batch_size,
+ ):
+     urls = []
+     overall_length = 0
+
+     with open(f"{data_path}/SEEDEdit.txt", "r") as file:
+         info_file = file.readlines()
+         urls_base = "SEED_EDIT_DATA_SHARD_BASE"
+         data_file = []
+         for item in info_file:
+             file_name, length, shard_num = item.strip('\n').split('\t')
+             length, shard_num = int(length), int(shard_num)
+             for shard in range(shard_num):
+                 data_file.append(f"wds_{file_name}_{shard:04d}.tar")
+             overall_length += length
+         urls += [urls_base.replace("<FILE>", file) for file in data_file]
+
+     with open(f"{data_path}/ImgEdit.txt", "r") as file:
+         info_file = file.readlines()
+         urls_base = "IMG_EDIT_DATA_SHARD_BASE"
+         data_file = []
+         for item in info_file:
+             file_name, length, shard_num = item.strip('\n').split('\t')
+             length, shard_num = int(length), int(shard_num)
+             for shard in range(shard_num):
+                 data_file.append(f"wds_{file_name}_{shard:04d}.tar")
+             overall_length += length
+         urls += [urls_base.replace("<FILE>", file) for file in data_file]
+
+     global PN
+     PN = pn
+     global h_div_w_template
+     h_div_w_template = h_div_w_templates[np.argmin(np.abs(1.0 - h_div_w_templates))]
+     dataset = wds.WebDataset(
+         urls,
+         nodesplitter=wds.shardlists.split_by_node,
+         shardshuffle=True,
+         resampled=True,
+         cache_size=buffersize,
+         handler=wds.handlers.warn_and_continue,
+     ).with_length(overall_length).shuffle(100).to_tuple("src.jpg", "tgt.jpg", "txt").map(preprocess).batched(batch_size, partial=False).with_epoch(100000)
+     return dataset
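+
+
+ # Minimal usage sketch (assumes the SEEDEdit.txt / ImgEdit.txt index files exist
+ # under ./data and that the *_DATA_SHARD_BASE URL templates have been filled in):
+ def _example_loader():
+     dataset = WDSEditDataset("./data", buffersize=1 << 30, pn="0.25M", batch_size=4)
+     # The pipeline already batches, so the DataLoader must not re-batch.
+     loader = DataLoader(dataset, batch_size=None, num_workers=4)
+     src, tgt, instruction = next(iter(loader))  # src/tgt: (4, 3, H, W) in [-1, 1]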
infinity/models/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ from timm.loss import SoftTargetCrossEntropy
+
+ from timm.models.layers import DropPath
+
+ from .infinity import Infinity, sample_with_top_k_top_p_also_inplace_modifying_logits_
+
+ def _ex_repr(self):
+     return ', '.join(
+         f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
+         for k, v in vars(self).items()
+         if not k.startswith('_') and k != 'training'
+         and not isinstance(v, (torch.nn.Module, torch.Tensor))
+     )
+ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy):  # no longer __repr__ DropPath with drop_prob
+     if hasattr(clz, 'extra_repr'):
+         clz.extra_repr = _ex_repr
+     else:
+         clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
+
+ DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
+
+ alias_dict = {}
+ for d in range(6, 40+2, 2):
+     alias_dict[f'd{d}'] = f'infinity_d{d}'
+ alias_dict_inv = {v: k for k, v in alias_dict.items()}
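+ # e.g. alias_dict == {'d6': 'infinity_d6', 'd8': 'infinity_d8', ..., 'd40': 'infinity_d40'}
+ # and alias_dict_inv['infinity_d16'] == 'd16'.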
infinity/models/basic.py ADDED
@@ -0,0 +1,646 @@
+ """
+ Definitions of the blocks of the VAR transformer model.
+ """
+
+ import math
+ import os
+ from functools import partial
+ from typing import Optional, Tuple, Union, List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from timm.models.layers import DropPath, drop_path
+ from torch.utils.checkpoint import checkpoint
+
+ # Import flash_attn's attention
+ from flash_attn import flash_attn_func                  # q, k, or v: BLHc, ret: BLHc
+ from flash_attn import flash_attn_varlen_kvpacked_func  # qkv: N3Hc, ret: NHc
+
+ from torch.nn.functional import scaled_dot_product_attention as slow_attn  # q, k, v: BHLc
+
+ # Import flash_attn's fused ops; fall back to plain implementations if absent
+ try:
+     from flash_attn.ops.layer_norm import dropout_add_layer_norm
+     from flash_attn.ops.rms_norm import dropout_add_rms_norm
+     from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
+     from flash_attn.ops.fused_dense import fused_mlp_func
+     flash_fused_op_installed = True
+ except ImportError:
+     dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
+     flash_fused_op_installed = False
+
+     def rms_norm_impl(x, weight, epsilon):
+         return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(epsilon))) * weight
+
+
+ def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
+     # split the dimension into halves, one for x and one for y
+     half_dim = dim // 2
+     inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim))  # theta: 1 / (10000^(i/half_dim)), i = 0, 2, ..., half_dim-2
+     t_height = torch.arange(max_height * 2, device=device, dtype=torch.int64).type_as(inv_freq)
+     t_width = torch.arange(max_width * 2, device=device, dtype=torch.int64).type_as(inv_freq)
+     t_height = t_height / scaling_factor
+     freqs_height = torch.outer(t_height, inv_freq)  # (max_height*2, half_dim/2), namely y*theta
+     t_width = t_width / scaling_factor
+     freqs_width = torch.outer(t_width, inv_freq)    # (max_width*2, half_dim/2), namely x*theta
+     freqs_grid_map = torch.concat([
+         freqs_height[:, None, :].expand(-1, max_width * 2, -1),
+         freqs_width[None, :, :].expand(max_height * 2, -1, -1),
+     ], dim=-1)  # (max_height*2, max_width*2, half_dim)
+     freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
+     # (2, max_height*2, max_width*2, half_dim)
+
+     rope2d_freqs_grid = {}
+     for h_div_w in dynamic_resolution_h_w:
+         scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales']
+         _, ph, pw = scale_schedule[-1]
+         max_edge_length = freqs_grid_map.shape[1]
+         if ph >= pw:
+             uph, upw = max_edge_length, int(max_edge_length / ph * pw)
+         else:
+             uph, upw = int(max_edge_length / pw * ph), max_edge_length
+
+         rope_cache_list = []
+         _, uph, upw = scale_schedule[-1]
+         src_indices = torch.stack([
+             (torch.arange(uph)).reshape(uph, 1).expand(uph, upw) + uph,
+             (torch.arange(upw)).reshape(1, upw).expand(uph, upw) + upw,
+         ], dim=-1).round().int()
+         src_indices = src_indices.reshape(-1, 2)
+         src_rope_cache = freqs_grid_map[:, src_indices[:, 0], src_indices[:, 1], :]  # (2, uph*upw, half_head_dim)
+         src_rope_cache = src_rope_cache.reshape(2, uph, upw, -1)
+         rope_cache_list.append(src_rope_cache.reshape(2, uph * upw, -1))
+
+         for i, (_, ph, pw) in enumerate(scale_schedule):
+             ph_mul_pw = ph * pw
+             if rope2d_normalized_by_hw == 1:  # downsample
+                 rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0, 3, 1, 2]), size=(ph, pw), mode='bilinear', align_corners=True)
+                 rope_cache = rope_cache.permute([0, 2, 3, 1])  # (2, ph, pw, half_head_dim)
+             elif rope2d_normalized_by_hw == 2:  # star style
+                 _, uph, upw = scale_schedule[-1]
+                 tgt_indices = torch.stack([
+                     (torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
+                     (torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
+                 ], dim=-1).round().int()
+                 tgt_indices = tgt_indices.reshape(-1, 2)  # (ph*pw, 2)
+                 tgt_rope_cache = freqs_grid_map[:, tgt_indices[:, 0], tgt_indices[:, 1], :]  # (2, ph*pw, half_head_dim)
+                 tgt_rope_cache = tgt_rope_cache.reshape(2, ph, pw, -1)
+             elif rope2d_normalized_by_hw == 0:
+                 rope_cache = freqs_grid_map[:, :ph, :pw, :]  # (2, ph, pw, half_head_dim)
+             else:
+                 raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
+             rope_cache_list.append(tgt_rope_cache.reshape(2, ph_mul_pw, -1))
+         cat_rope_cache = torch.cat(rope_cache_list, 1)  # (2, seq_len, half_head_dim)
+
+         if cat_rope_cache.shape[1] % pad_to_multiplier:
+             pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim)
+             cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
+         cat_rope_cache = cat_rope_cache[:, None, None, None]  # (2, 1, 1, 1, seq_len, half_dim)
+         for pn in dynamic_resolution_h_w[h_div_w]:
+             scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales']
+             tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
+             rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
+     return rope2d_freqs_grid
+
+
+ def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind, src=True):
+     qk = torch.stack((q, k), dim=0)  # (2, batch_size, heads, seq_len, head_dim)
+     device_type = qk.device.type
+     device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+     with torch.autocast(device_type=device_type, enabled=False):
+         seq_len = qk.shape[3]
+         assert len(scale_schedule[0]) == 3
+         start = 0
+
+         # Source-image tokens occupy the first scale_schedule[-1] positions;
+         # target tokens of the preceding scales follow them.
+         if not src:
+             start += np.array(scale_schedule[-1]).prod()
+             for i in range(scale_ind):
+                 start += np.array(scale_schedule[i]).prod()
+
+         start = int(start)
+         rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
+         assert start + seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
+         rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start + seq_len]  # (2, 1, 1, 1, seq_len, half_head_dim)
+         qk = qk.reshape(*qk.shape[:-1], -1, 2)  # (2, batch_size, heads, seq_len, half_head_dim, 2)
+         qk = torch.stack([
+             rope_cache[0] * qk[..., 0] - rope_cache[1] * qk[..., 1],
+             rope_cache[1] * qk[..., 0] + rope_cache[0] * qk[..., 1],
+         ], dim=-1)  # (2, batch_size, heads, seq_len, half_head_dim, 2); stack then reshape here, not concat
+         qk = qk.reshape(*qk.shape[:-2], -1)  # (2, batch_size, heads, seq_len, head_dim)
+         q, k = qk.unbind(dim=0)  # (batch_size, heads, seq_len, head_dim)
+     return q, k
+
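+
+ # Toy sketch (not used by the model) of the pairwise rotation applied above:
+ # each (even, odd) channel pair (x0, x1) is rotated by a cached angle theta,
+ # giving (cos*x0 - sin*x1, sin*x0 + cos*x1); rotation preserves the vector norm.
+ def _example_rope_rotation():
+     theta = torch.rand(4)                # one angle per channel pair
+     cos, sin = torch.cos(theta), torch.sin(theta)
+     x = torch.randn(8)                   # a single head_dim=8 vector
+     x0, x1 = x.reshape(4, 2).unbind(-1)  # split into (even, odd) pairs
+     rotated = torch.stack([cos * x0 - sin * x1,
+                            sin * x0 + cos * x1], dim=-1).reshape(8)
+     assert torch.allclose(rotated.norm(), x.norm(), atol=1e-5)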
135
+
136
+ class FastRMSNorm(nn.Module):
137
+ def __init__(self, C, eps=1e-6, elementwise_affine=True):
138
+ super().__init__()
139
+ self.C = C
140
+ self.eps = eps
141
+ self.elementwise_affine = elementwise_affine
142
+ if self.elementwise_affine:
143
+ self.weight = nn.Parameter(torch.ones(C))
144
+ else:
145
+ self.register_buffer('weight', torch.ones(C))
146
+
147
+ def forward(self, x):
148
+ src_type = x.dtype
149
+ return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type)
150
+
151
+ def extra_repr(self) -> str:
152
+ return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
153
+
154
+
155
+ def get_dropout_layer(p):
156
+ return nn.Dropout(p, inplace=True) if p > 0 else nn.Identity()
157
+
158
+
159
+ class FFN(nn.Module):
160
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False):
161
+ super().__init__()
162
+ self.fused_mlp_func = fused_mlp_func if fused_mlp else None
163
+ out_features = out_features or in_features
164
+ hidden_features = hidden_features or in_features
165
+ self.fc1 = nn.Linear(in_features, hidden_features)
166
+ self.act = nn.GELU(approximate='tanh')
167
+ self.fc2 = nn.Linear(hidden_features, out_features)
168
+ self.drop = get_dropout_layer(drop)
169
+ self.heuristic = -1
170
+
171
+ def forward(self, x):
172
+ if self.fused_mlp_func is not None:
173
+ return self.drop(self.fused_mlp_func(
174
+ x=x,
175
+ weight1=self.fc1.weight,
176
+ weight2=self.fc2.weight,
177
+ bias1=self.fc1.bias,
178
+ bias2=self.fc2.bias,
179
+ activation='gelu_approx',
180
+ save_pre_act=self.training,
181
+ return_residual=False,
182
+ checkpoint_lvl=0,
183
+ heuristic=self.heuristic,
184
+ process_group=None,
185
+ ))
186
+ else:
187
+ return self.drop(self.fc2( self.act(self.fc1(x)) ))
188
+
189
+ def extra_repr(self) -> str:
190
+ return f'fused_mlp={self.fused_mlp_func is not None}'
191
+
192
+
193
+ class FFNSwiGLU(nn.Module):
194
+ def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False):
195
+ super().__init__()
196
+ self.fused_mlp_func = None
197
+ hidden_features = round(2 * hidden_features / 3 / 256) * 256
198
+
199
+ out_features = out_features or in_features
200
+ self.fcg = nn.Linear(in_features, hidden_features, bias=False)
201
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
202
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
203
+ self.drop = get_dropout_layer(drop)
204
+
205
+ def forward(self, x):
206
+ return self.drop(self.fc2( F.silu(self.fcg(x), inplace=True).mul_(self.fc1(x)) ))
207
+
208
+ def extra_repr(self) -> str:
209
+ return f'fused_mlp={self.fused_mlp_func is not None}'
210
+
211
+
212
+ class SelfAttention(nn.Module):
213
+ def __init__(
214
+ self, embed_dim=768, num_heads=12,
215
+ proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False,
216
+ batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0,
217
+ ):
218
+ """
219
+ :param embed_dim: model's width
220
+ :param num_heads: num heads of multi-head attention
221
+ :param proj_drop: always 0 for testing
222
+ :param tau: always 1
223
+ :param cos_attn: always True: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
224
+ :param customized_flash_attn:
225
+ """
226
+ super().__init__()
227
+ assert embed_dim % num_heads == 0
228
+ self.using_flash = customized_flash_attn
229
+
230
+ self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
231
+ self.tau, self.cos_attn = tau, cos_attn
232
+ if self.cos_attn:
233
+ self.scale = 1
234
+ size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
235
+ # size: 11H1 or 1H11
236
+ self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True)
237
+ self.max_scale_mul = torch.log(torch.tensor(100)).item()
238
+ else:
239
+ self.scale = 1 / math.sqrt(self.head_dim) / self.tau
240
+
241
+ self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
242
+ self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
243
+ self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
244
+
245
+ self.proj = nn.Linear(embed_dim, embed_dim)
246
+ self.proj_drop = get_dropout_layer(proj_drop)
247
+
248
+ self.caching = False # kv caching: only used during inference
249
+ self.cached_k = None # kv caching: only used during inference
250
+ self.cached_v = None # kv caching: only used during inference
251
+
252
+ self.batch_size = batch_size
253
+ self.use_flex_attn = use_flex_attn
254
+ self.pad_to_multiplier = pad_to_multiplier
255
+
256
+ self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
257
+
258
+
259
+ def kv_caching(self, enable: bool): # kv caching: only used during inference
260
+ self.caching = enable
261
+ self.cached_k = None
262
+ self.cached_v = None
263
+ self.cached_init_k = None
264
+ self.cached_init_v = None
265
+
266
+ # NOTE: attn_bias_or_two_vector is None during inference
267
+ def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, start_layer=False, scale_ind=0, src=True):
268
+ """
269
+ :param (fp32) x: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
270
+ :param (fp32) attn_bias_or_two_vector:
271
+ if not using_flash:
272
+ a block-wise, lower-triangle matrix, like:
273
+ [[[[0, -, -, -, -, -, -, -, -, -, -, -, -, -],
274
+ [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
275
+ [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
276
+ [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
277
+ [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
278
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
279
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
280
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
281
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
282
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
283
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
284
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
285
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
286
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]
287
+ where 0 means visible and - means invisible (-inf)
288
+ else:
289
+ a tuple of two 1-dim int vector (VAR_visible_kvlen, VAR_invisible_qlen)
290
+ :return: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
291
+ """
292
+ # x: fp32
293
+ B, L, C = x.shape
294
+
295
+ # qkv: amp, bf16
296
+ qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) # BL3Hc
297
+ if self.using_flash: q, k, v = qkv.unbind(dim=2); L_dim = 1 # q or k or v: all are shaped in (B:batch_size, L:seq_len, H:heads, c:head_dim)
298
+ else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); L_dim = 2 # q or k or v: all are shaped in (B:batch_size, H:heads, L:seq_len, c:head_dim)
299
+
300
+ if self.cos_attn: # always True
301
+ scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() # 11H1 (flash), or 1H11 (not flash)
302
+ q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous() # fp32
303
+ k = F.normalize(k, dim=-1, eps=1e-12).contiguous() # fp32
304
+ v = v.contiguous() # bf16
305
+ else: # be contiguous, to make kernel happy
306
+ q = q.contiguous() # bf16
307
+ k = k.contiguous() # bf16
308
+ v = v.contiguous() # bf16
309
+ if rope2d_freqs_grid is not None:
310
+ q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind, src=src) #, freqs_cis=freqs_cis)
311
+
312
+ def down_func(tensor, src_scale, tgt_scale):
313
+ """
314
+ Downsample the tensor from src_scale to tgt_scale with area interpolation.
315
+ :param tensor: [B, H, L, c]
316
+ :param src_scale: (1, h, w)
317
+ :param tgt_scale: (1, h, w)
318
+ :return: [B, H, L', c]
319
+ """
320
+ B, H, L, C = tensor.shape
321
+ src_h, src_w = src_scale[-2], src_scale[-1]
322
+ tgt_h, tgt_w = tgt_scale[-2], tgt_scale[-1]
323
+ if src_h == tgt_h and src_w == tgt_w:
324
+ return tensor
325
+ else:
326
+ # area interpolation
327
+ tensor = tensor.permute(0, 1, 3, 2)
328
+ tensor = tensor.reshape(B, H, C, src_h, src_w).reshape(B, -1, src_h, src_w)
329
+ tensor = F.interpolate(tensor, size=(tgt_h, tgt_w), mode='area')
330
+ tensor = tensor.reshape(B, H, C, tgt_h, tgt_w).reshape(B, H, C, -1)
331
+ tensor = tensor.permute(0, 1, 3, 2)
332
+ return tensor
333
+
334
+ if self.caching: # kv caching: only used during inference
335
+ if start_layer:
336
+ if self.cached_init_k is None: self.cached_init_k = k; self.cached_init_v = v
337
+ else:
338
+ k_src = down_func(self.cached_init_k, scale_schedule[-1], scale_schedule[scale_ind])
339
+ v_src = down_func(self.cached_init_v, scale_schedule[-1], scale_schedule[scale_ind])
340
+ if self.cached_k is None: self.cached_k = k; self.cached_v = v
341
+ else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)
342
+ k = torch.cat([k_src, k], dim=L_dim); v = torch.cat([v_src, v], dim=L_dim)
343
+ else:
344
+ if self.cached_k is None: self.cached_k = k; self.cached_v = v
345
+ else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)
346
+
347
+ if not self.caching and start_layer: # train & first layer
348
+ length_list = [int(np.array(scale_schedule[i]).prod()) for i in range(len(scale_schedule))]
349
+ length_list = [length_list[-1]] + length_list
350
+ length_list += [L - sum(length_list)]
351
+ q_list = torch.split(q, length_list, dim=L_dim) # [B, H, L, c]
352
+ k_list = torch.split(k, length_list, dim=L_dim) # [B, H, L, c]
353
+ v_list = torch.split(v, length_list, dim=L_dim) # [B, H, L, c]
354
+ outputs = []
355
+ for i in range(len(length_list)-1):
356
+ k_src = down_func(k_list[0], scale_schedule[-1], scale_schedule[i-1])
357
+ k_cal = torch.cat([k_src, *k_list[1:i+1]], dim=L_dim) if i > 0 else k_src
358
+ v_src = down_func(v_list[0], scale_schedule[-1], scale_schedule[i-1])
359
+ v_cal = torch.cat([v_src, *v_list[1:i+1]], dim=L_dim) if i > 0 else v_src
360
+ output = slow_attn(query=q_list[i], key=k_cal, value=v_cal, scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, -1, C)
361
+ outputs.append(output)
362
+ pad_zeros = torch.zeros_like(v_list[-1], dtype=v.dtype, device=v.device).permute(0, 2, 1, 3)
363
+ pad_zeros = pad_zeros.reshape(*pad_zeros.shape[:-2], -1) # [B, H, L', c]
364
+ outputs.append(pad_zeros)
365
+ oup = torch.cat(outputs, dim=1) # [B, L, C]
366
+ elif self.using_flash: # Default false
367
+ if attn_bias_or_two_vector is not None: # training
368
+ kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
369
+ else: # inference (autoregressive sampling)
370
+ kw = dict()
371
+ oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
372
+ else:
373
+ # if self.cos_attn: q, k are in fp32; v is in bf16
374
+ # else: q, k, v are in bf16
375
+ if self.use_flex_attn and attn_fn is not None: # train & high layers
376
+ oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C)
377
+ else: # inference
378
+ oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
379
+ # oup: bf16
380
+
381
+ return self.proj_drop(self.proj(oup))
382
+
383
+ def extra_repr(self) -> str:
384
+ tail = ''
385
+ return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}'
386
+
387
+
388
class CrossAttention(nn.Module):
    def __init__(
        self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12,
        proj_drop=0., cos_attn=False,
    ):
        """
        :param for_attn_pool: only used in VAR.text_proj_for_sos
        :param embed_dim: Q's dim
        :param kv_dim: K's and V's dim
        :param num_heads: num heads of multi-head attention
        :param proj_drop: proj dropout rate
        :param cos_attn: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H1
        """
        cos_attn = False  # TODO: never use cos attn in cross attention with T5 kv
        super().__init__()
        self.for_attn_pool = for_attn_pool
        self.embed_dim = embed_dim
        self.kv_dim = kv_dim
        assert embed_dim % num_heads == 0
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads  # =64
        self.cos_attn = cos_attn
        if self.cos_attn:
            self.scale = 1
            self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 1 / math.sqrt(self.head_dim)

        if for_attn_pool:
            q = torch.empty(1, self.num_heads, self.head_dim)
            nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3))
            self.mat_q = nn.Parameter(q)
        else:
            self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.mat_kv = nn.Linear(kv_dim, embed_dim*2, bias=False)
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = get_dropout_layer(proj_drop)

    def forward(self, q, ca_kv):
        """
        :param q: shaped as (batch, seq_len, Q_dim)
        :param ca_kv: contains several vectors, each of which is shaped as (len_i, KV_dim). We have [len_1xKV_dim, len_2xKV_dim, len_3xKV_dim, ...] and lens == [len_1, len_2, len_3, ...]
            - kv_compact: shaped as (sum(lens), KV_dim)
            - cu_seqlens_k: cumulative sum of lens
            - max_seqlen_k: int, max(lens)
        NOTE: seq_len (num of Qs) can reach 10k; but len_i (num of KVs) must be <= 256

        :return: shaped as (batch, seq_len, Q_dim)
        """
        kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv
        N = kv_compact.shape[0]

        kv_compact = F.linear(kv_compact, weight=self.mat_kv.weight, bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim)  # NC => N2Hc
        # attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens

        if not self.for_attn_pool:
            B, Lq = q.shape[:2]
            q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim)
        else:
            B = cu_seqlens_k.shape[0] - 1
            Lq = 1
            q_compact = self.mat_q.repeat(B, 1, 1).to(dtype=kv_compact.dtype)

        if self.cos_attn:  # always False
            scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp()
            k, v = kv_compact.unbind(dim=1)
            q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)
            kv_compact = torch.stack((k, v), dim=1)

        q_compact = q_compact.contiguous()
        kv_compact = kv_compact.contiguous()

        cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
        if q_compact.dtype == torch.float32:  # todo: fp16 or bf16?
            oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
            oup = oup.float()
        else:
            oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)

        return self.proj_drop(self.proj(oup))

    def extra_repr(self) -> str:
        return f'Cq={self.embed_dim}, Ckv={self.kv_dim}, cos_attn={self.cos_attn}'

class SelfAttnBlock(nn.Module):
    def __init__(
        self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
        num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
        swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
    ):
        super(SelfAttnBlock, self).__init__()
        self.C, self.D = embed_dim, cond_dim
        self.drop_path_rate = drop_path
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
        )
        self.using_swiglu = swiglu
        self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)

        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector):  # todo: minGPT and vqgan also use pre-norm, just like this, while MaskGiT uses post-norm
        with torch.cuda.amp.autocast(enabled=False):
            if self.shared_aln:  # always True; (1, 1, 6, C) + (B, 1, 6, C)
                gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
            else:
                gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

            if self.fused_norm_func is None:
                x = x + self.drop_path(self.attn(self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
                x = x + self.drop_path(self.ffn(self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            else:
                x = x + self.drop_path(self.attn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
                x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}'

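# A minimal numeric sketch of the shared AdaLN modulation used by the blocks
# above (illustration only, not part of the original file): a learned
# (1, 1, 6, C) table is added to the per-sample condition and unbound into six
# (B, 1, C) modulation tensors. `_demo_shared_adaln` is a hypothetical helper.
def _demo_shared_adaln():
    import torch
    B, C = 2, 16
    ada_gss = torch.randn(1, 1, 6, C) / C**0.5   # shared across the batch
    cond_BD = torch.randn(B, 1, 6, C)            # per-sample condition
    gamma1, gamma2, scale1, scale2, shift1, shift2 = (ada_gss + cond_BD).unbind(2)
    x = torch.randn(B, 10, C)
    x_mod = x.mul(scale1.add(1)).add_(shift1)    # pre-attention modulation
    assert x_mod.shape == (B, 10, C) and gamma1.shape == (B, 1, C)
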
class CrossAttnBlock(nn.Module):
    def __init__(
        self,
        embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
        num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
        swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
        use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False,
    ):
        super(CrossAttnBlock, self).__init__()
        self.C, self.D = embed_dim, cond_dim
        self.drop_path_rate = drop_path
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sa = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn)
        self.using_swiglu = swiglu
        self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        self.ca_norm = norm_layer(embed_dim, elementwise_affine=True)

        self.shared_aln = shared_aln
        if self.shared_aln:  # always True
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

        if cross_attn_layer_scale >= 0:
            self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True)
        else:
            self.ca_gamma = 1

        self.checkpointing_sa_only = checkpointing_sa_only

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, start_layer=False, scale_ind=0, src=True):  # todo: minGPT and vqgan also use pre-norm, just like this, while MaskGiT uses post-norm
        with torch.cuda.amp.autocast(enabled=False):  # disable half precision
            if self.shared_aln:  # always True; (1, 1, 6, C) + (B, 1, 6, C)
                gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
            else:
                gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

            if self.fused_norm_func is None:
                x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
                if self.checkpointing_sa_only and self.training:
                    x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False, start_layer=start_layer, src=src)
                else:
                    x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, start_layer=start_layer, src=src)
                x = x + self.drop_path(x_sa.mul_(gamma1))
                x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
                x = x + self.drop_path(self.ffn(self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            else:
                x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
                if self.checkpointing_sa_only and self.training:
                    x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False, start_layer=start_layer, src=src)
                else:
                    x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, start_layer=start_layer, scale_ind=scale_ind, src=src)
                x = x + self.drop_path(x_sa.mul_(gamma1))
                x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
                x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}'

class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None):  # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        lin = nn.Linear(D, 2*C)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

    def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        if self.fused_norm_func is None:
            return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
        else:
            return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift)

def main():
    dev = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
    rng = torch.Generator(device=dev)
    # for Li in ([1, 3, 5], [1, 3]):
    rng.manual_seed(0)
    B, H, cq, ckv = 4, 8, 64, 96
    Cq = H*cq
    Ckv = H*ckv

    Li = [5, 4, 7, 6]
    Lq = 10
    L = max(Li)
    attn_bias = torch.zeros(B, 1, Lq, L, device=dev)
    for i, x in enumerate(Li):
        attn_bias[i, 0, :, x:] = -torch.inf

    q = torch.randn(B, Lq, H, cq, generator=rng, device=dev)
    k = torch.randn(B, L, H, ckv, generator=rng, device=dev)
    v = torch.randn(B, L, H, ckv, generator=rng, device=dev)
    tq, tk, tv = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # BHLc

    seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
    cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.int32), (1, 0))
    kv = torch.stack([k, v], dim=2)
    kv_compact = torch.cat([kv[i, :Li[i]] for i in range(B)], dim=0)

    ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
    ca(q, (kv_compact, cu_seqlens_k, max(Li))).mean().backward()


if __name__ == '__main__':
    main()
infinity/models/bitwise_self_correction.py ADDED
@@ -0,0 +1,105 @@
import os
import os.path as osp

import cv2
import torch
import torch.nn.functional as F
import numpy as np


def labels2image(vae, all_indices, label_type='int_label', scale_schedule=None):
    summed_codes, recons_imgs = vae.decode_from_indices(all_indices, scale_schedule, label_type)
    recons_img = recons_imgs[0]
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:, :, ::-1]
    return recons_img

def features2image(vae, raw_features):
    recons_imgs = vae.decode(raw_features.squeeze(-3))
    recons_img = recons_imgs[0]
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:, :, ::-1]
    return recons_img

class BitwiseSelfCorrection(object):
    def __init__(self, vae, args):
        self.noise_apply_layers = args.noise_apply_layers
        self.noise_apply_requant = args.noise_apply_requant
        self.noise_apply_strength = args.noise_apply_strength
        self.apply_spatial_patchify = args.apply_spatial_patchify
        self.vae = vae
        self.debug_bsc = args.debug_bsc

    def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device, src=False):
        with torch.amp.autocast('cuda', enabled=False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            if src:
                residual = F.interpolate(codes_out, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_down).contiguous()
                if self.apply_spatial_patchify:
                    # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
                    residual = torch.nn.functional.pixel_unshuffle(residual.squeeze(-3), 2)
                x_BLC_wo_prefix = residual.reshape(*residual.shape[:2], -1).permute(0, 2, 1)
                gt_ms_idx_Bl = None
            else:
                x_BLC_wo_prefix = []
                for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                    residual = codes_out - cum_var_input
                    if si != len(vae_scale_schedule)-1:
                        residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                    quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual)  # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae]
                    gt_all_bit_indices.append(bit_indices)
                    if not src and si < self.noise_apply_layers:
                        noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                        mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                        pred_bit_indices = bit_indices.clone()
                        pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                        pred_all_bit_indices.append(pred_bit_indices)
                        if self.noise_apply_requant:
                            quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type='bit_label')
                    else:
                        pred_all_bit_indices.append(bit_indices)
                    cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if si < len(vae_scale_schedule)-1:
                        this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                        if self.apply_spatial_patchify:
                            # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
                            this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                        x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0, 2, 1))  # (B,H/2*W/2,4C) or (B,H*W,C)
                x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)
                if self.apply_spatial_patchify:
                    gt_ms_idx_Bl = []
                    for item in gt_all_bit_indices:
                        # item shape: (B,1,H,W,d)
                        item = item.squeeze(1).permute(0, 3, 1, 2)  # (B,d,H,W)
                        # (B,d,H,W) -> (B,4d,H/2,W/2)
                        item = torch.nn.functional.pixel_unshuffle(item, 2)
                        # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*W/2,4d)
                        item = item.permute(0, 2, 3, 1).reshape(B, -1, 4*self.vae.codebook_dim)
                        gt_ms_idx_Bl.append(item)
                else:
                    gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]

        # if self.debug_bsc:
        #     self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)

        return x_BLC_wo_prefix, gt_ms_idx_Bl

    def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
        gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
        gt_img = gt_img[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)[:, :, ::-1]
        recons_img_2 = labels2image(self.vae, gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_3 = labels2image(self.vae, pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
        save_path = osp.abspath('non_teacher_force.jpg')
        cv2.imwrite(save_path, cat_image)
        print(f'Save to {save_path}')
        import pdb; pdb.set_trace()
        print(cat_image.shape)

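# A minimal sketch of the bit-flip noise used in flip_requant above
# (illustration only, not part of the original file): a Bernoulli mask flips
# a fraction of the 0/1 bit labels, mimicking prediction errors before
# requantization. `_demo_bit_flip` is a hypothetical helper.
def _demo_bit_flip(noise_apply_strength=0.2):
    import torch
    bit_indices = torch.randint(0, 2, (1, 1, 4, 4, 8))  # (B, 1, h, w, d_vae)
    mask = torch.rand(*bit_indices.shape) < noise_apply_strength
    pred = bit_indices.clone()
    pred[mask] = 1 - pred[mask]                          # flip masked bits
    assert ((pred != bit_indices) == mask).all()
    return pred
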
infinity/models/bsq_vae/conv.py ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, cnn_type="2d", causal_offset=0, temporal_down=False):
        super().__init__()
        self.cnn_type = cnn_type
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        if cnn_type == "3d":
            if not temporal_down:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # Temporal causal padding
                padding,  # Height padding
                padding   # Width padding
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            assert self.stride[0] == 1 or self.stride[0] == 2, "only temporal stride = 1 or 2 are supported"
            xs = []
            for i in range(0, x.shape[2], self.slice_seq_len+self.stride[0]-1):
                st = i
                en = min(i+self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # Width
                                    self.padding[1], self.padding[1],  # Height
                                    self.padding[0], 0))               # Temporal
                else:
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # Width
                                    self.padding[1], self.padding[1],  # Height
                                    padding_0, 0))                     # Temporal
                    _x[:, :, :padding_0,
                       self.padding[1]:_x.shape[-2]-self.padding[1],
                       self.padding[2]:_x.shape[-1]-self.padding[2]] += x[:, :, i-padding_0:i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:  # concat can run out of GPU memory; fall back to CPU
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
            return x
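# A quick causality check for the 3d branch of Conv above (illustration only,
# not part of the original file): because padding is applied to the front of
# the time axis, output frame t depends only on input frames <= t.
# `_demo_causal_conv` is a hypothetical helper.
def _demo_causal_conv():
    conv = Conv(4, 4, kernel_size=3, stride=1, padding=1, cnn_type="3d")
    x = torch.randn(1, 4, 5, 8, 8)
    y1 = conv(x)
    x2 = x.clone()
    x2[:, :, -1] += 1.0  # perturb only the last frame
    y2 = conv(x2)
    assert torch.allclose(y1[:, :, :-1], y2[:, :, :-1], atol=1e-5)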
infinity/models/bsq_vae/dynamic_resolution.py ADDED
@@ -0,0 +1,32 @@
import json
import numpy as np
import tqdm

vae_stride = 16
ratio2hws = {
    1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
    1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
    1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
    1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
    1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
    2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
    2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
    3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
}
full_ratio2hws = {}
for ratio, hws in ratio2hws.items():
    full_ratio2hws[ratio] = hws
    full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]

dynamic_resolution_h_w = {}
predefined_HW_Scales_dynamic = {}
for ratio in full_ratio2hws:
    dynamic_resolution_h_w[ratio] = {}
    for ind, leng in enumerate([7, 10, 13]):
        h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1]  # feature map size
        pixel = (h * vae_stride, w * vae_stride)  # The original image (H, W)
        dynamic_resolution_h_w[ratio][pixel[1]] = {
            'pixel': pixel,
            'scales': full_ratio2hws[ratio][:leng]
        }  # W as key
        predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
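# A small usage sketch (illustration only, not part of the original file):
# for aspect ratio 1.0 the table maps a 256-pixel width to a 16x16 final
# feature map reached through the first 7 scales. `_demo_dynamic_resolution`
# is a hypothetical helper.
def _demo_dynamic_resolution():
    info = dynamic_resolution_h_w[1.000][256]
    assert info['pixel'] == (256, 256)   # 16 * vae_stride
    assert len(info['scales']) == 7
    assert info['scales'][-1] == (16, 16)
    return info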
infinity/models/bsq_vae/flux_vqgan.py ADDED
@@ -0,0 +1,557 @@
import argparse
import os
import imageio
import torch
import numpy as np
from einops import rearrange
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from safetensors.torch import load_file
import torch.utils.checkpoint as checkpoint

from .conv import Conv
from .multiscale_bsq import MultiScaleBSQ

ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}

class Normalize(nn.Module):
    def __init__(self, in_channels, norm_type, norm_axis="spatial"):
        super().__init__()
        self.norm_axis = norm_axis
        assert norm_type in ['group', 'batch', "no"]
        if norm_type == 'group':
            if in_channels % 32 == 0:
                self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
            elif in_channels % 24 == 0:
                self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
            else:
                raise NotImplementedError
        elif norm_type == 'batch':
            self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False)  # Runtime Error: grad inplace if set track_running_stats to True
        elif norm_type == 'no':
            self.norm = nn.Identity()

    def forward(self, x):
        if self.norm_axis == "spatial":
            if x.ndim == 4:
                x = self.norm(x)
            else:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.norm(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
        elif self.norm_axis == "spatial-temporal":
            x = self.norm(x)
        else:
            raise NotImplementedError
        return x

def swish(x: Tensor) -> Tensor:
    try:
        return x * torch.sigmoid(x)
    except RuntimeError:  # out of GPU memory; retry via pinned CPU memory
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)

class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])

        self.q = Conv(in_channels, in_channels, kernel_size=1)
        self.k = Conv(in_channels, in_channels, kernel_size=1)
        self.v = Conv(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        B, _, T, _, _ = h_.shape
        h_ = self.norm(h_)
        h_ = rearrange(h_, "B C T H W -> (B T) C H W")  # spatial attention only
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))

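# A minimal sketch of the per-frame attention pattern used by AttnBlock above
# (illustration only, not part of the original file): the time axis is folded
# into the batch, so attention mixes positions within each frame but never
# across frames. `_demo_spatial_only_attention` is a hypothetical helper.
def _demo_spatial_only_attention():
    B, C, T, H, W = 2, 8, 3, 4, 4
    x = torch.randn(B, C, T, H, W)
    q = rearrange(x, "B C T H W -> (B T) 1 (H W) C")
    out = nn.functional.scaled_dot_product_attention(q, q, q)  # self-attention per frame
    out = rearrange(out, "(B T) 1 (H W) C -> B C T H W", B=B, T=T, H=H, W=W)
    assert out.shape == x.shape  # shape preserved; frames never attend to each other
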
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["half", "full"]:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["full"]:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h

class Downsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
        super().__init__()
        assert spatial_down == True
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        if cnn_type == "3d":
            self.pad = (0, 1, 0, 1, 0, 0)  # add padding to the right for h-axis and w-axis. No padding for t-axis
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)

    def forward(self, x: Tensor):
        x = nn.functional.pad(x, self.pad, mode="constant", value=0)
        x = self.conv(x)
        return x

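# A small sketch of the asymmetric-padding trick above (illustration only,
# not part of the original file): padding one extra row/column on the right
# and bottom before a stride-2, kernel-3 convolution halves H and W exactly,
# e.g. (16+1-3)//2+1 = 8. `_demo_asymmetric_downsample` is hypothetical.
def _demo_asymmetric_downsample():
    x = torch.randn(1, 4, 16, 16)
    x = nn.functional.pad(x, (0, 1, 0, 1))  # right/bottom only -> 17x17
    conv = nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=0)
    y = conv(x)
    assert y.shape[-2:] == (8, 8)
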
class Upsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
        super().__init__()
        if cnn_type == "2d":
            self.scale_factor = 2
            self.causal_offset = 0
        else:
            assert spatial_up == True
            if temporal_up:
                self.scale_factor = (2, 2, 2)
                self.causal_offset = -1
            else:
                self.scale_factor = (1, 2, 2)
                self.causal_offset = 0
        self.use_pxsl = use_pxsl
        if self.use_pxsl:
            self.conv = Conv(in_channels, in_channels*4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
            self.pxsl = nn.PixelShuffle(2)
        else:
            self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)

    def forward(self, x: Tensor):
        if self.use_pxsl:
            x = self.conv(x)
            x = self.pxsl(x)
        else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            except RuntimeError:  # out of memory: interpolate channel by channel
                # shard across channel
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(x[:, i:i+1, ...], scale_factor=self.scale_factor, mode="nearest")
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
            x = self.conv(x)
        return x

class Encoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        in_channels=3,
        patch_size=8, temporal_patch_size=4,
        norm_type='group', cnn_param=None,
        use_checkpoint=False,
        use_vae=True,
    ):
        super().__init__()
        self.max_down = np.log2(patch_size)
        self.temporal_max_down = np.log2(temporal_patch_size)
        self.temporal_down_offset = self.max_down - self.temporal_max_down
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        # downsampling
        # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        # cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
        if cnn_param["conv_in_out_2d"] == "yes":  # "yes" for video
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE
            spatial_down = True if i_level < self.max_down else False
            temporal_down = True if i_level < self.max_down and i_level >= self.temporal_down_offset else False
            if spatial_down or temporal_down:
                down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)

        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, x, return_hidden=False):
        if not self.use_checkpoint:
            return self._forward(x, return_hidden=return_hidden)
        else:
            return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)

    def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
        # downsampling
        h0 = self.conv_in(x)
        hs = [h0]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        hs_mid = [h]
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        hs_mid.append(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        if return_hidden:
            return h, hs, hs_mid
        else:
            return h

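# A quick check of the encoder's level arithmetic (illustration only, not
# part of the original file): with patch_size=8 and temporal_patch_size=4,
# max_down = 3 and the temporal offset is 1, so levels 0-2 downsample
# spatially but only levels 1-2 downsample in time, giving the 4x8x8 video
# patch. `_demo_down_levels` is a hypothetical helper.
def _demo_down_levels(num_resolutions=4, patch_size=8, temporal_patch_size=4):
    max_down = np.log2(patch_size)                    # 3.0
    offset = max_down - np.log2(temporal_patch_size)  # 1.0
    levels = []
    for i in range(num_resolutions):
        levels.append((i, i < max_down, i < max_down and i >= offset))
    return levels  # [(0, True, False), (1, True, True), (2, True, True), (3, False, False)]
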
class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        out_ch=3,
        patch_size=8, temporal_patch_size=4,
        norm_type="group", cnn_param=None,
        use_checkpoint=False,
        use_freq_dec=False,  # use frequency features for decoder
        use_pxsf=False
    ):
        super().__init__()
        self.max_up = np.log2(patch_size)
        self.temporal_max_up = np.log2(temporal_patch_size)
        self.temporal_up_offset = self.max_up - self.temporal_max_up
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.ffactor = 2 ** (self.num_resolutions - 1)
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        self.use_freq_dec = use_freq_dec
        self.use_pxsf = use_pxsf

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            # upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder
            # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
            spatial_up = True if 1 <= i_level <= self.max_up else False
            temporal_up = True if 1 <= i_level <= self.max_up and i_level >= self.temporal_up_offset+1 else False
            if spatial_up or temporal_up:
                up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, z):
        if not self.use_checkpoint:
            return self._forward(z)
        else:
            return checkpoint.checkpoint(self._forward, z, use_reentrant=False)

    def _forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h

class AutoEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        cnn_param = dict(
            cnn_type=args.cnn_type,
            conv_in_out_2d=args.conv_in_out_2d,
            res_conv_2d=args.res_conv_2d,
            cnn_attention=args.cnn_attention,
            cnn_norm_axis=args.cnn_norm_axis,
            conv_inner_2d=args.conv_inner_2d,
        )
        self.encoder = Encoder(
            ch=args.base_ch,
            ch_mult=args.encoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_vae=args.use_vae,
        )
        self.decoder = Decoder(
            ch=args.base_ch,
            ch_mult=args.decoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_freq_dec=args.use_freq_dec,
            use_pxsf=args.use_pxsf  # pixelshuffle for upsampling
        )
        self.z_drop = nn.Dropout(args.z_drop)
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.codebook_dim = self.embed_dim = args.codebook_dim

        self.gan_feat_weight = args.gan_feat_weight
        self.video_perceptual_weight = args.video_perceptual_weight
        self.recon_loss_type = args.recon_loss_type
        self.l1_weight = args.l1_weight
        self.use_vae = args.use_vae
        self.kl_weight = args.kl_weight
        self.lfq_weight = args.lfq_weight
        self.image_gan_weight = args.image_gan_weight  # image GAN loss weight
        self.video_gan_weight = args.video_gan_weight  # video GAN loss weight
        self.perceptual_weight = args.perceptual_weight
        self.flux_weight = args.flux_weight
        self.cycle_weight = args.cycle_weight
        self.cycle_feat_weight = args.cycle_feat_weight
        self.cycle_gan_weight = args.cycle_gan_weight

        self.flux_image_encoder = None

        if not args.use_vae:
            if args.quantizer_type == 'MultiScaleBSQ':
                self.quantizer = MultiScaleBSQ(
                    dim=args.codebook_dim,  # this is the input feature dimension, defaults to log2(codebook_size) if not defined
                    codebook_size=args.codebook_size,  # codebook size, must be a power of 2
                    entropy_loss_weight=args.entropy_loss_weight,  # how much weight to place on entropy loss
                    diversity_gamma=args.diversity_gamma,  # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
                    preserve_norm=args.preserve_norm,  # preserve norm of the input for BSQ
                    ln_before_quant=args.ln_before_quant,  # use layer norm before quantization
                    ln_init_by_sqrt=args.ln_init_by_sqrt,  # layer norm init value 1/sqrt(d)
                    commitment_loss_weight=args.commitment_loss_weight,  # loss weight of commitment loss
                    new_quant=args.new_quant,
                    use_decay_factor=args.use_decay_factor,
                    mask_out=args.mask_out,
                    use_stochastic_depth=args.use_stochastic_depth,
                    drop_rate=args.drop_rate,
                    schedule_mode=args.schedule_mode,
                    keep_first_quant=args.keep_first_quant,
                    keep_last_quant=args.keep_last_quant,
                    remove_residual_detach=args.remove_residual_detach,
                    use_out_phi=args.use_out_phi,
                    use_out_phi_res=args.use_out_phi_res,
                    random_flip=args.random_flip,
                    flip_prob=args.flip_prob,
                    flip_mode=args.flip_mode,
                    max_flip_lvl=args.max_flip_lvl,
                    random_flip_1lvl=args.random_flip_1lvl,
                    flip_lvl_idx=args.flip_lvl_idx,
                    drop_when_test=args.drop_when_test,
                    drop_lvl_idx=args.drop_lvl_idx,
                    drop_lvl_num=args.drop_lvl_num,
                )
                self.quantize = self.quantizer
                self.vocab_size = args.codebook_size
            else:
                raise NotImplementedError(f"{args.quantizer_type} not supported")

    def forward(self, x):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]

        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        # print(z.shape)
        # Multiscale LFQ
        z, all_indices, _, _, all_loss, _ = self.quantizer(h)
        x_recon = self.decoder(z)
        vq_output = {
            "commitment_loss": torch.mean(all_loss) * self.lfq_weight,  # here commitment loss is sum of commitment loss and entropy penalty
            "encodings": all_indices,
        }
        return x_recon, vq_output

    def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1

        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W

        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        return h, hs, hs_mid

    def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
        h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
        # Multiscale LFQ
        z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
        return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input

    def decode(self, z):
        x_recon = self.decoder(z)
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return x_recon

    def decode_from_indices(self, all_indices, scale_schedule, label_type):
        summed_codes = 0
        for idx_Bl in all_indices:
            codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
            summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
        assert summed_codes.shape[-3] == 1
        x_recon = self.decoder(summed_codes.squeeze(-3))
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return summed_codes, x_recon

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--flux_weight", type=float, default=0)
        parser.add_argument("--cycle_weight", type=float, default=0)
        parser.add_argument("--cycle_feat_weight", type=float, default=0)
        parser.add_argument("--cycle_gan_weight", type=float, default=0)
        parser.add_argument("--cycle_loop", type=int, default=0)
        parser.add_argument("--z_drop", type=float, default=0.)
        return parser

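# A minimal sketch of the multi-scale summation performed by
# decode_from_indices above (illustration only, not part of the original
# file): per-scale codes are upsampled to the final grid and accumulated
# before a single decoder pass. Tensors here are stand-ins, not real
# codebook outputs. `_demo_multiscale_code_summation` is hypothetical.
def _demo_multiscale_code_summation():
    scale_schedule = [(1, 4, 4), (1, 8, 8), (1, 16, 16)]
    summed = 0
    for (t, h, w) in scale_schedule:
        codes = torch.randn(1, 32, t, h, w)  # would come from indices_to_codes
        summed = summed + F.interpolate(codes, size=scale_schedule[-1], mode="trilinear")
    return summed  # shape (1, 32, 1, 16, 16)
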
infinity/models/bsq_vae/multiscale_bsq.py ADDED
@@ -0,0 +1,718 @@
"""
Binary Spherical Quantization
Proposed in https://arxiv.org/abs/2406.07548

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

import random
from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch.distributed as dist
from torch.distributed import nn as dist_nn

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast
import numpy as np

from einops import rearrange, reduce, pack, unpack

# from einx import get_at

from .dynamic_resolution import predefined_HW_Scales_dynamic

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim=-1)

# entropy

def log(t, eps=1e-5):
    return t.clamp(min=eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale=1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim=-1)
        w = F.normalize(self.weight, dim=0)
        return (x @ w) * self.scale

def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"):
    assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3"]
    predefined_HW_Scales = {
        # 256 * 256
        (32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)],
        (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
        # 1024x1024
        (64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)],

        (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
    }
    if mode == "dynamic":
        predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
    elif mode == "dense":
        predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
        predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
    elif mode.startswith("same"):
        num_quant = int(mode[len("same"):])
        predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
        predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
        predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]

    predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17]
    patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
    if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):
        # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!")
        predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
    patch_THW_shape_per_scale = [(min(T, t), h, w) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
    return patch_THW_shape_per_scale

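# A small illustration of the schedule above (not part of the original file):
# for a 17-frame latent at 32x32, "original" mode pairs the 10 spatial scales
# with the first 10 temporal scales, capped at T. `_demo_latent2scale` is a
# hypothetical helper.
def _demo_latent2scale():
    schedule = get_latent2scale_schedule(T=17, H=32, W=32, mode="original")
    # first scale is a single token cube, the last covers the full latent grid
    assert schedule[0] == (1, 1, 1)
    assert schedule[-1] == (13, 32, 32)
    return schedule
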
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    normalized_shape: int
    """
    def __init__(self, normalized_shape, norm_weight=False, eps=1e-6, data_format="channels_first"):
        super().__init__()
        if norm_weight:
            self.weight = nn.Parameter(torch.ones(normalized_shape)/(normalized_shape**0.5))
        else:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if x.ndim == 4:  # (b, c, h, w)
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            elif x.ndim == 5:  # (b, c, t, h, w)
                x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            else:
                raise ValueError("the number of dimensions of the input should be 4 or 5")
            return x

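# A quick equivalence check for the channels_first path above (illustration
# only, not part of the original file): normalizing dim 1 manually should
# match F.layer_norm applied after moving channels last.
# `_demo_channels_first_ln` is a hypothetical helper.
def _demo_channels_first_ln():
    x = torch.randn(2, 8, 4, 4)
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    manual = (x - u) / torch.sqrt(s + 1e-6)
    reference = F.layer_norm(x.permute(0, 2, 3, 1), (8,), eps=1e-6).permute(0, 3, 1, 2)
    assert torch.allclose(manual, reference, atol=1e-4)
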
+ class MultiScaleBSQ(Module):
171
+ """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
172
+
173
+ def __init__(
174
+ self,
175
+ *,
176
+ dim,
177
+ codebook_size,
178
+ soft_clamp_input_value = None,
179
+ aux_loss = False, # intermediate auxiliary loss
180
+ ln_before_quant=False, # add a LN before multi-scale RQ
181
+ ln_init_by_sqrt=False, # weight init by 1/sqrt(d)
182
+ use_decay_factor=False,
183
+ use_stochastic_depth=False,
184
+ drop_rate=0.,
185
+ schedule_mode="original", # ["original", "dynamic", "dense"]
186
+ keep_first_quant=False,
187
+ keep_last_quant=False,
188
+ remove_residual_detach=False,
189
+ random_flip = False,
190
+ flip_prob = 0.5,
191
+ flip_mode = "stochastic", # "stochastic", "deterministic"
192
+ max_flip_lvl = 1,
193
+ random_flip_1lvl = False, # random flip one level each time
194
+ flip_lvl_idx = None,
195
+ drop_when_test=False,
196
+ drop_lvl_idx=None,
197
+ drop_lvl_num=0,
198
+ **kwargs
199
+ ):
200
+ super().__init__()
201
+ codebook_dim = int(log2(codebook_size))
202
+
203
+ requires_projection = codebook_dim != dim
204
+ self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
205
+ self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
206
+ self.has_projections = requires_projection
207
+ self.layernorm = LayerNorm(codebook_dim, norm_weight=ln_init_by_sqrt) if ln_before_quant else nn.Identity()
208
+ self.use_stochastic_depth = use_stochastic_depth
209
+ self.drop_rate = drop_rate
210
+ self.remove_residual_detach = remove_residual_detach
211
+ self.random_flip = random_flip
212
+ self.flip_prob = flip_prob
213
+ self.flip_mode = flip_mode
214
+ self.max_flip_lvl = max_flip_lvl
215
+ self.random_flip_1lvl = random_flip_1lvl
216
+ self.flip_lvl_idx = flip_lvl_idx
217
+ assert (random_flip and random_flip_1lvl) == False
218
+ self.drop_when_test = drop_when_test
219
+ self.drop_lvl_idx = drop_lvl_idx
220
+ self.drop_lvl_num = drop_lvl_num
221
+ if self.drop_when_test:
222
+ assert drop_lvl_idx is not None
223
+ assert drop_lvl_num > 0
224
+
225
+ self.lfq = BSQ(
226
+ dim = codebook_dim,
227
+ codebook_scale = 1/np.sqrt(codebook_dim),
228
+ soft_clamp_input_value = soft_clamp_input_value,
229
+ # experimental_softplus_entropy_loss=True,
230
+ # entropy_loss_offset=2,
231
+ **kwargs
232
+ )
233
+
234
+ self.z_interplote_up = 'trilinear'
235
+ self.z_interplote_down = 'area'
236
+
237
+ self.use_decay_factor = use_decay_factor
238
+ self.schedule_mode = schedule_mode
239
+ self.keep_first_quant = keep_first_quant
240
+ self.keep_last_quant = keep_last_quant
241
+ if self.use_stochastic_depth and self.drop_rate > 0:
242
+ assert self.keep_first_quant or self.keep_last_quant
243
+
244
+ @property
245
+ def codebooks(self):
246
+ return self.lfq.codebook
247
+
248
+ def get_codes_from_indices(self, indices_list):
249
+ all_codes = []
250
+ for indices in indices_list:
251
+ codes = self.lfq.indices_to_codes(indices)
252
+ all_codes.append(codes)
253
+ _, _, T, H, W = all_codes[-1].size()
254
+ summed_codes = 0
255
+ for code in all_codes:
256
+ summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
257
+ return summed_codes
258
+
259
+ def get_output_from_indices(self, indices):
260
+ codes = self.get_codes_from_indices(indices)
261
+ codes_summed = reduce(codes, 'q ... -> ...', 'sum')
262
+ return self.project_out(codes_summed)
263
+
264
+ def flip_quant(self, x):
265
+ assert self.flip_mode == 'stochastic'
266
+ flip_mask = torch.rand_like(x) < self.flip_prob
267
+ x = x.clone()
268
+ x[flip_mask] = -x[flip_mask]
269
+ return x
270
+
+     def forward(
+         self,
+         x,
+         scale_schedule=None,
+         mask=None,
+         return_all_codes=False,
+         return_residual_norm_per_scale=False
+     ):
+         if x.ndim == 4:
+             x = x.unsqueeze(2)
+         B, C, T, H, W = x.size()
+ 
+         if scale_schedule is None:
+             if self.schedule_mode.startswith("same"):
+                 scale_num = int(self.schedule_mode[len("same"):])
+                 assert T == 1
+                 scale_schedule = [(1, H, W)] * scale_num
+             else:
+                 scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode)
+                 scale_num = len(scale_schedule)
+ 
+         # project in on the channel-last layout, then restore channel-first
+         x = x.permute(0, 2, 3, 4, 1).contiguous()  # (b, c, t, h, w) => (b, t, h, w, c)
+         x = self.project_in(x)
+         x = x.permute(0, 4, 1, 2, 3).contiguous()  # (b, t, h, w, c) => (b, c, t, h, w)
+         x = self.layernorm(x)
+ 
+         quantized_out = 0.
+         residual = x
+ 
+         all_losses = []
+         all_indices = []
+         all_bit_indices = []
+         var_inputs = []
+         residual_norm_per_scale = []
+ 
+         # go through the layers
+         out_fact = init_out_fact = 1.0
+         # residual_list = []
+         # interpolate_residual_list = []
+         # quantized_list = []
+         if self.drop_when_test:
+             drop_lvl_start = self.drop_lvl_idx
+             drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num
+         scale_num = len(scale_schedule)
+         with autocast('cuda', enabled=False):
+             for si, (pt, ph, pw) in enumerate(scale_schedule):
+                 out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact
+                 if (pt, ph, pw) != (T, H, W):
+                     interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down)
+                 else:
+                     interpolate_residual = residual
+                 if return_residual_norm_per_scale:
+                     residual_norm_per_scale.append((torch.abs(interpolate_residual) < 0.05 * self.lfq.codebook_scale).sum() / interpolate_residual.numel())
+                 # residual_list.append(torch.norm(residual.detach(), dim=1).mean())
+                 # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean())
+                 if self.training and self.use_stochastic_depth and random.random() < self.drop_rate:
+                     if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant):
+                         # capture bit_indices as well (the original discarded it here but appended the stale value below)
+                         quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
+                         quantized = quantized * out_fact
+                         all_indices.append(indices)
+                         # loss is appended once below (the original also appended it here, duplicating it)
+                     else:
+                         quantized = torch.zeros_like(interpolate_residual)
+                         # nothing was quantized at this scale; record placeholders so the
+                         # bookkeeping below stays aligned (the original left these stale or undefined)
+                         bit_indices = None
+                         loss = interpolate_residual.new_zeros(())
+                 elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end:
+                     continue
+                 else:
+                     # residual_norm = torch.norm(interpolate_residual.detach(), dim=1)  # (b, t, h, w)
+                     # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean())
+                     quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
+                     if self.random_flip and si < self.max_flip_lvl:
+                         quantized = self.flip_quant(quantized)
+                     if self.random_flip_1lvl and si == self.flip_lvl_idx:
+                         quantized = self.flip_quant(quantized)
+                     quantized = quantized * out_fact
+                     all_indices.append(indices)
+                 # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean())
+                 if (pt, ph, pw) != (T, H, W):
+                     quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous()
+ 
+                 if self.remove_residual_detach:
+                     residual = residual - quantized
+                 else:
+                     residual = residual - quantized.detach()
+                 quantized_out = quantized_out + quantized
+ 
+                 all_bit_indices.append(bit_indices)
+                 all_losses.append(loss)
+                 if si != scale_num - 1:
+                     var_inputs.append(F.interpolate(quantized_out, size=scale_schedule[si + 1], mode=self.z_interplote_down).contiguous())
+ 
+                 if self.use_decay_factor:
+                     out_fact -= 0.1
+         # print("residual_list:", residual_list)
+         # print("interpolate_residual_list:", interpolate_residual_list)
+         # print("quantized_list:", quantized_list)
+ 
+         # project out, if needed
+         quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous()  # (b, c, t, h, w) => (b, t, h, w, c)
+         quantized_out = self.project_out(quantized_out)
+         quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous()  # (b, t, h, w, c) => (b, c, t, h, w)
+ 
+         # image
+         if quantized_out.size(2) == 1:
+             quantized_out = quantized_out.squeeze(2)
+ 
+         # stack all losses and indices
+         all_losses = torch.stack(all_losses, dim=-1)
+ 
+         ret = (quantized_out, all_indices, all_bit_indices, residual_norm_per_scale, all_losses, var_inputs)
+ 
+         if not return_all_codes:
+             return ret
+ 
+         # whether to return all codes from all codebooks across layers
+         all_codes = self.get_codes_from_indices(all_indices)
+ 
+         # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
+         return (*ret, all_codes)
+ 
+ 
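+ # Editor's sketch (illustrative, not part of the original file): the forward
+ # pass above is multi-scale *residual* quantization -- at every scale the
+ # running residual is downsampled, binarized by BSQ, upsampled back to full
+ # resolution, and subtracted. Stripped of the flip/drop options it reduces to:
+ #
+ #     residual, out = x, 0.
+ #     for (pt, ph, pw) in scale_schedule:
+ #         r = F.interpolate(residual, size=(pt, ph, pw), mode='area')
+ #         q, *_ = lfq(r)                       # sign quantization on the sphere
+ #         q = F.interpolate(q, size=residual.shape[2:], mode='trilinear')
+ #         residual = residual - q.detach()     # straight-through residual update
+ #         out = out + q
+ 
+ 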
+ class BSQ(Module):
+     def __init__(
+         self,
+         *,
+         dim=None,
+         codebook_size=None,
+         entropy_loss_weight=0.1,
+         commitment_loss_weight=0.25,
+         diversity_gamma=1.,
+         straight_through_activation=nn.Identity(),
+         num_codebooks=1,
+         keep_num_codebooks_dim=None,
+         codebook_scale=1.,  # for residual LFQ, codebook scaled down by 2x at each layer
+         frac_per_sample_entropy=1.,  # make less than 1. to only use a random fraction of the probs for per-sample entropy
+         has_projections=None,
+         projection_has_bias=True,
+         soft_clamp_input_value=None,
+         cosine_sim_project_in=False,
+         cosine_sim_project_in_scale=None,
+         channel_first=None,
+         experimental_softplus_entropy_loss=False,
+         entropy_loss_offset=5.,  # how much to shift the loss before softplus
+         spherical=True,  # from https://arxiv.org/abs/2406.07548
+         force_quantization_f32=True,  # will force the quantization step to be full precision
+         inv_temperature=100.0,
+         gamma0=1.0, gamma=1.0, zeta=1.0,
+         preserve_norm=False,  # whether to preserve the original norm info
+         new_quant=False,  # new quant function
+         mask_out=False,  # mask the output as 0 in some conditions
+         use_out_phi=False,  # use output phi network
+         use_out_phi_res=False,  # residual out phi
+     ):
+         super().__init__()
+ 
+         # some assert validations
+         assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
+         assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
+ 
+         codebook_size = default(codebook_size, lambda: 2 ** dim)
+         self.codebook_size = codebook_size
+ 
+         codebook_dim = int(log2(codebook_size))
+         codebook_dims = codebook_dim * num_codebooks
+         dim = default(dim, codebook_dims)
+         self.codebook_dims = codebook_dims
+ 
+         has_projections = default(has_projections, dim != codebook_dims)
+ 
+         if cosine_sim_project_in:
+             cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
+             project_in_klass = partial(CosineSimLinear, scale=cosine_sim_project_in)
+         else:
+             project_in_klass = partial(nn.Linear, bias=projection_has_bias)
+ 
+         self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
+         self.project_out = nn.Linear(codebook_dims, dim, bias=projection_has_bias) if has_projections else nn.Identity()
+         self.has_projections = has_projections
+ 
+         self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
+         self.use_out_phi_res = use_out_phi_res
+         if self.use_out_phi_res:
+             self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True)  # init as zero
+ 
+         self.dim = dim
+         self.codebook_dim = codebook_dim
+         self.num_codebooks = num_codebooks
+ 
+         keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+         assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+         self.keep_num_codebooks_dim = keep_num_codebooks_dim
+ 
+         # channel first
+         self.channel_first = channel_first
+ 
+         # straight through activation
+         self.activation = straight_through_activation
+ 
+         # for BSQ (binary spherical quantization)
+         if not spherical:
+             raise ValueError("For BSQ, spherical must be True.")
+         self.persample_entropy_compute = 'analytical'
+         self.inv_temperature = inv_temperature
+         self.gamma0 = gamma0  # loss weight for entropy penalty
+         self.gamma = gamma  # loss weight for entropy penalty
+         self.zeta = zeta  # loss weight for entire entropy penalty
+         self.preserve_norm = preserve_norm
+         self.new_quant = new_quant
+         self.mask_out = mask_out
+ 
+         # entropy aux loss related weights
+         assert 0 < frac_per_sample_entropy <= 1.
+         self.frac_per_sample_entropy = frac_per_sample_entropy
+ 
+         self.diversity_gamma = diversity_gamma
+         self.entropy_loss_weight = entropy_loss_weight
+ 
+         # codebook scale
+         self.codebook_scale = codebook_scale
+ 
+         # commitment loss
+         self.commitment_loss_weight = commitment_loss_weight
+ 
+         # whether to soft clamp the input value from -value to value
+         self.soft_clamp_input_value = soft_clamp_input_value
+         assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
+ 
+         # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
+         self.entropy_loss_offset = entropy_loss_offset
+         self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
+ 
+         # for no auxiliary loss, during inference
+         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
+         self.register_buffer('zero', torch.tensor(0.), persistent=False)
+ 
+         # whether to force quantization step to be f32
+         self.force_quantization_f32 = force_quantization_f32
+ 
+         # codes (precomputing the full codebook is disabled; it has 2 ** codebook_dim entries)
+         # all_codes = torch.arange(codebook_size)
+         # bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+         # codebook = self.bits_to_codes(bits)
+         # self.register_buffer('codebook', codebook.float(), persistent=False)
+ 
+     def bits_to_codes(self, bits):
+         return bits * self.codebook_scale * 2 - self.codebook_scale
+ 
+     # @property
+     # def dtype(self):
+     #     return self.codebook.dtype
+ 
+     def indices_to_codes(
+         self,
+         indices,
+         label_type='int_label',
+         project_out=True
+     ):
+         assert label_type in ['int_label', 'bit_label']
+         is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+         should_transpose = default(self.channel_first, is_img_or_video)
+ 
+         if not self.keep_num_codebooks_dim:
+             if label_type == 'int_label':
+                 indices = rearrange(indices, '... -> ... 1')
+             else:
+                 indices = indices.unsqueeze(-2)
+ 
+         # indices to codes, which are bits of either -1 or 1
+         if label_type == 'int_label':
+             assert indices[..., None].int().min() >= 0  # index 0 (all bits negative) is valid; the original asserted > 0
+             bits = ((indices[..., None].int() & self.mask) != 0).float()
+         else:
+             bits = indices
+ 
+         codes = self.bits_to_codes(bits)
+ 
+         codes = l2norm(codes)  # must normalize when using BSQ
+ 
+         codes = rearrange(codes, '... c d -> ... (c d)')
+ 
+         # whether to project codes out to original dimensions
+         # if the input feature dimensions were not log2(codebook size)
+         if project_out:
+             codes = self.project_out(codes)
+ 
+         # rearrange codes back to original shape
+         if should_transpose:
+             codes = rearrange(codes, 'b ... d -> b d ...')
+ 
+         return codes
+ 
+     def quantize(self, z):
+         assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
+ 
+         zhat = torch.where(z > 0,
+                            torch.tensor(1, dtype=z.dtype, device=z.device),
+                            torch.tensor(-1, dtype=z.dtype, device=z.device))
+         return z + (zhat - z).detach()
+ 
+     def quantize_new(self, z):
+         assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
+ 
+         zhat = torch.where(z > 0,
+                            torch.tensor(1, dtype=z.dtype, device=z.device),
+                            torch.tensor(-1, dtype=z.dtype, device=z.device))
+ 
+         q_scale = 1. / (self.codebook_dims ** 0.5)
+         zhat = q_scale * zhat  # place the binary code on the unit sphere
+ 
+         return z + (zhat - z).detach()
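+ 
+     # Editor's sketch (illustrative, not in the original file): `quantize_new`
+     # snaps an L2-normalized vector to the nearest corner of the scaled
+     # hypercube {-1/sqrt(d), +1/sqrt(d)}^d, which lies on the unit sphere,
+     # while `z + (zhat - z).detach()` keeps gradients flowing as if no
+     # rounding happened (straight-through estimator). Assuming codebook_dims=64:
+     #
+     #     z = l2norm(torch.randn(2, 16, 1, 64, requires_grad=True))
+     #     q = bsq.quantize_new(z)
+     #     assert torch.allclose(q.norm(dim=-1), torch.ones(()))  # on unit sphere
+     #     q.sum().backward()  # gradients reach z despite the hard sign()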
+ 
+     def soft_entropy_loss(self, z):
+         if self.persample_entropy_compute == 'analytical':
+             # closed-form Bernoulli probability of the positive code for each (l2-normalized) bit
+             p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
+             prob = torch.stack([p, 1 - p], dim=-1)  # (..., d, 2)
+             per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()  # (..., d) -> (...) -> scalar
+         else:
+             # the original referenced `prob` here before assignment; only the analytical
+             # path is ever configured (see __init__), so fail loudly instead
+             raise NotImplementedError(f"persample_entropy_compute={self.persample_entropy_compute!r} is not supported")
+ 
+         # macro average of the probability of each subgroup
+         avg_prob = reduce(prob, '... g d -> g d', 'mean')  # (d, 2)
+         codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
+ 
+         # the approximation of the entropy is the sum of the entropy of each subgroup
+         return per_sample_entropy, codebook_entropy.sum(), avg_prob
+ 
+     def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
+         if normalize:
+             probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
+         else:  # callers in this file pass normalize=False: `count` is already a distribution
+             probs = count
+         H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
+         return H
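+ 
+     # Editor's note (illustrative, not in the original file): with p the
+     # per-bit probability of the positive code, the two entropy terms are
+     #     per-sample entropy:  mean_i sum_d H(p_{i,d}),  H(p) = -p log p - (1-p) log(1-p)
+     #     codebook entropy:    sum_d H(mean_i p_{i,d})
+     # Low per-sample entropy means confident bits; high codebook entropy means
+     # all bits are used evenly across the batch.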
+ 
+     def forward(
+         self,
+         x,
+         return_loss_breakdown=False,
+         mask=None,
+         entropy_weight=0.1
+     ):
+         """
+         einstein notation
+         b - batch
+         n - sequence (or flattened spatial dimensions)
+         d - feature dimension, which is also log2(codebook size)
+         c - number of codebooks
+         """
+ 
+         is_img_or_video = x.ndim >= 4
+         should_transpose = default(self.channel_first, is_img_or_video)
+ 
+         # standardize image or video into (batch, seq, dimension)
+         if should_transpose:
+             x = rearrange(x, 'b d ... -> b ... d')
+             x, ps = pack_one(x, 'b * d')  # x.shape [b, hwt, c]
+ 
+         assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
+ 
+         x = self.project_in(x)
+ 
+         # split out number of codebooks
+         x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)
+ 
+         x = l2norm(x)
+ 
+         # whether to force quantization step to be full precision or not
+         force_f32 = self.force_quantization_f32
+ 
+         quantization_context = partial(autocast, 'cuda', enabled=False) if force_f32 else nullcontext
+ 
+         indices = None
+         with quantization_context():
+ 
+             if force_f32:
+                 orig_dtype = x.dtype
+                 x = x.float()
+ 
+             # use straight-through gradients (optionally with custom activation fn) if training
+             if self.new_quant:
+                 quantized = self.quantize_new(x)
+             else:
+                 # the original left `quantized` undefined on this path; fall back to the
+                 # unscaled straight-through quantizer
+                 quantized = self.quantize(x)
+ 
+             # calculate indices
+             bit_indices = (quantized > 0).int()
+             entropy_penalty = persample_entropy = cb_entropy = self.zero
+             commit_loss = self.zero
+ 
+             # input back to original dtype if needed
+             if force_f32:
+                 x = x.type(orig_dtype)
+ 
+         # merge back codebook dim
+         x = quantized  # rename quantized to x for output
+         x = rearrange(x, 'b n c d -> b n (c d)')
+ 
+         # project out to feature dimension if needed
+         x = self.project_out(x)
+ 
+         # reconstitute image or video dimensions
+         if should_transpose:
+             x = unpack_one(x, ps, 'b * d')
+             x = rearrange(x, 'b ... d -> b d ...')
+ 
+             bit_indices = unpack_one(bit_indices, ps, 'b * c d')
+ 
+         # whether to remove single codebook dim
+         if not self.keep_num_codebooks_dim:
+             bit_indices = rearrange(bit_indices, '... 1 d -> ... d')
+ 
+         # complete aux loss
+         aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature) * entropy_weight
+ 
+         # returns
+         ret = Return(x, indices, bit_indices, aux_loss)
+ 
+         if not return_loss_breakdown:
+             return ret
+ 
+         return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
+ 
infinity/models/bsq_vae/vae.py ADDED
@@ -0,0 +1,255 @@
+ import argparse
+ import torch
+ 
+ from infinity.models.bsq_vae.flux_vqgan import AutoEncoder
+ 
+ 
+ def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
+     delete_keys = []
+     loaded_keys = []
+     for key in state_dict:
+         if key.startswith(prefix):
+             _key = key[len(prefix):]
+             if _key in model.state_dict():
+                 # load nn.Conv2d or nn.Linear to nn.Linear
+                 if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
+                     load_weights = state_dict[key].squeeze()
+                 elif _key.endswith(".conv.weight") and expand:
+                     if model.state_dict()[_key].shape == state_dict[key].shape:
+                         # 2D cnn to 2D cnn
+                         load_weights = state_dict[key]
+                     else:
+                         # 2D cnn to 3D cnn: inflate along the new temporal axis
+                         _expand_dim = model.state_dict()[_key].shape[2]
+                         load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
+                 else:
+                     load_weights = state_dict[key]
+                 model.state_dict()[_key].copy_(load_weights)
+                 delete_keys.append(key)
+                 loaded_keys.append(prefix + _key)
+             # load nn.Conv2d to Conv class
+             conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
+             if any(k in _key for k in conv_list):
+                 if _key.endswith(".weight"):
+                     conv_key = _key.replace(".weight", ".conv.weight")
+                     if conv_key and conv_key in model.state_dict():
+                         if model.state_dict()[conv_key].shape == state_dict[key].shape:
+                             # 2D cnn to 2D cnn
+                             load_weights = state_dict[key]
+                         else:
+                             # 2D cnn to 3D cnn
+                             _expand_dim = model.state_dict()[conv_key].shape[2]
+                             load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
+                         model.state_dict()[conv_key].copy_(load_weights)
+                         delete_keys.append(key)
+                         loaded_keys.append(prefix + conv_key)
+                 if _key.endswith(".bias"):
+                     conv_key = _key.replace(".bias", ".conv.bias")
+                     if conv_key and conv_key in model.state_dict():
+                         model.state_dict()[conv_key].copy_(state_dict[key])
+                         delete_keys.append(key)
+                         loaded_keys.append(prefix + conv_key)
+             # load nn.GroupNorm to Normalize class
+             if "norm" in _key:
+                 if _key.endswith(".weight"):
+                     norm_key = _key.replace(".weight", ".norm.weight")
+                     if norm_key and norm_key in model.state_dict():
+                         model.state_dict()[norm_key].copy_(state_dict[key])
+                         delete_keys.append(key)
+                         loaded_keys.append(prefix + norm_key)
+                 if _key.endswith(".bias"):
+                     norm_key = _key.replace(".bias", ".norm.bias")
+                     if norm_key and norm_key in model.state_dict():
+                         model.state_dict()[norm_key].copy_(state_dict[key])
+                         delete_keys.append(key)
+                         loaded_keys.append(prefix + norm_key)
+ 
+     for key in delete_keys:
+         del state_dict[key]
+ 
+     return model, state_dict, loaded_keys
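+ 
+ # Editor's sketch (illustrative, not part of the original file): the
+ # "2D cnn to 3D cnn" branches above inflate a (C_out, C_in, kH, kW) kernel to
+ # (C_out, C_in, kT, kH, kW) by repeating it along the new temporal axis:
+ #
+ #     w2d = torch.randn(64, 3, 3, 3)                  # pretrained 2D conv weight
+ #     w3d = w2d.unsqueeze(2).repeat(1, 1, 4, 1, 1)    # kT = 4
+ #     assert w3d.shape == (64, 3, 4, 3, 3)
+ #
+ # (Dividing by kT would preserve the response on static inputs; this loader
+ # repeats without rescaling.)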
+ 
+ 
+ def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True,
+               patch_size=16, encoder_ch_mult=[1, 2, 4, 4, 4], decoder_ch_mult=[1, 2, 4, 4, 4]):
+     args = argparse.Namespace(
+         vqgan_ckpt=vqgan_ckpt,
+         sd_ckpt=None,
+         inference_type='image',
+         save='./imagenet_val_bsq',
+         save_prediction=True,
+         image_recon4video=False,
+         junke_old=False,
+         device='cuda',
+         max_steps=1000000.0,
+         log_every=1,
+         visu_every=1000,
+         ckpt_every=1000,
+         default_root_dir='',
+         compile='no',
+         ema='no',
+         lr=0.0001,
+         beta1=0.9,
+         beta2=0.95,
+         warmup_steps=0,
+         optim_type='Adam',
+         disc_optim_type=None,
+         lr_min=0.0,
+         warmup_lr_init=0.0,
+         max_grad_norm=1.0,
+         max_grad_norm_disc=1.0,
+         disable_sch=False,
+         patch_size=patch_size,
+         temporal_patch_size=4,
+         embedding_dim=256,
+         codebook_dim=codebook_dim,
+         num_quantizers=8,
+         quantizer_type='MultiScaleBSQ',
+         use_vae=False,
+         use_freq_enc=False,
+         use_freq_dec=False,
+         preserve_norm=False,
+         ln_before_quant=False,
+         ln_init_by_sqrt=False,
+         use_pxsf=False,
+         new_quant=True,
+         use_decay_factor=False,
+         mask_out=False,
+         use_stochastic_depth=False,
+         drop_rate=0.0,
+         schedule_mode=schedule_mode,
+         lr_drop=None,
+         lr_drop_rate=0.1,
+         keep_first_quant=False,
+         keep_last_quant=False,
+         remove_residual_detach=False,
+         use_out_phi=False,
+         use_out_phi_res=False,
+         use_lecam_reg=False,
+         lecam_weight=0.05,
+         perceptual_model='vgg16',
+         base_ch_disc=64,
+         random_flip=False,
+         flip_prob=0.5,
+         flip_mode='stochastic',
+         max_flip_lvl=1,
+         not_load_optimizer=False,
+         use_lecam_reg_zero=False,
+         freeze_encoder=False,
+         rm_downsample=False,
+         random_flip_1lvl=False,
+         flip_lvl_idx=0,
+         drop_when_test=False,
+         drop_lvl_idx=0,
+         drop_lvl_num=1,
+         disc_version='v1',
+         magvit_disc=False,
+         sigmoid_in_disc=False,
+         activation_in_disc='leaky_relu',
+         apply_blur=False,
+         apply_noise=False,
+         dis_warmup_steps=0,
+         dis_lr_multiplier=1.0,
+         dis_minlr_multiplier=False,
+         disc_channels=64,
+         disc_layers=3,
+         discriminator_iter_start=0,
+         disc_pretrain_iter=0,
+         disc_optim_steps=1,
+         disc_warmup=0,
+         disc_pool='no',
+         disc_pool_size=1000,
+         advanced_disc=False,
+         recon_loss_type='l1',
+         video_perceptual_weight=0.0,
+         image_gan_weight=1.0,
+         video_gan_weight=1.0,
+         image_disc_weight=0.0,
+         video_disc_weight=0.0,
+         l1_weight=4.0,
+         gan_feat_weight=0.0,
+         perceptual_weight=0.0,
+         kl_weight=0.0,
+         lfq_weight=0.0,
+         entropy_loss_weight=0.1,
+         commitment_loss_weight=0.25,
+         diversity_gamma=1,
+         norm_type='group',
+         disc_loss_type='hinge',
+         use_checkpoint=False,
+         precision='fp32',
+         encoder_dtype='fp32',
+         upcast_attention='',
+         upcast_tf32=False,
+         tokenizer='flux',
+         pretrained=None,
+         pretrained_mode='full',
+         inflation_pe=False,
+         init_vgen='no',
+         no_init_idis=False,
+         init_idis='keep',
+         init_vdis='no',
+         enable_nan_detector=False,
+         turn_on_profiler=False,
+         profiler_scheduler_wait_steps=10,
+         debug=True,
+         video_logger=False,
+         bytenas='',
+         username='',
+         seed=1234,
+         vq_to_vae=False,
+         load_not_strict=False,
+         zero=0,
+         bucket_cap_mb=40,
+         manual_gc_interval=1000,
+         data_path=[''],
+         data_type=[''],
+         dataset_list=['imagenet'],
+         fps=-1,
+         dataaug='resizecrop',
+         multi_resolution=False,
+         random_bucket_ratio=0.0,
+         sequence_length=16,
+         resolution=[256, 256],
+         batch_size=[1],
+         num_workers=0,
+         image_channels=3,
+         codebook_size=codebook_size,
+         codebook_l2_norm=True,
+         codebook_show_usage=True,
+         commit_loss_beta=0.25,
+         entropy_loss_ratio=0.0,
+         base_ch=128,
+         num_res_blocks=2,
+         encoder_ch_mult=encoder_ch_mult,
+         decoder_ch_mult=decoder_ch_mult,
+         dropout_p=0.0,
+         cnn_type='2d',
+         cnn_version='v1',
+         conv_in_out_2d='no',
+         conv_inner_2d='no',
+         res_conv_2d='no',
+         cnn_attention='no',
+         cnn_norm_axis='spatial',
+         flux_weight=0,
+         cycle_weight=0,
+         cycle_feat_weight=0,
+         cycle_gan_weight=0,
+         cycle_loop=0,
+         z_drop=0.0)
+ 
+     vae = AutoEncoder(args)
+     use_vae = vae.use_vae
+     if not use_vae:
+         num_codes = args.codebook_size
+     if isinstance(vqgan_ckpt, str):
+         state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
+     else:
+         state_dict = args.vqgan_ckpt
+     if state_dict:
+         if args.ema == "yes":
+             vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["ema"], prefix="", expand=False)
+         else:
+             vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["vae"], prefix="", expand=False)
+     if test_mode:
+         vae.eval()
+         for p in vae.parameters():
+             p.requires_grad_(False)
+     return vae
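+ 
+ # Editor's sketch (illustrative; the checkpoint path and input shapes are
+ # assumptions, not from the original file): typical use of the factory above
+ # to build a frozen, eval-mode tokenizer:
+ #
+ #     vae = vae_model("path/to/vqgan.ckpt", schedule_mode="dynamic",
+ #                     codebook_dim=32, codebook_size=2**32, test_mode=True)
+ #     with torch.no_grad():
+ #         out = vae(images)   # exact return signature is defined by AutoEncoder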
infinity/models/ema.py ADDED
@@ -0,0 +1,23 @@
+ import copy
+ import torch
+ from collections import OrderedDict
+ 
+ 
+ def get_ema_model(model):
+     ema_model = copy.deepcopy(model)
+     ema_model.eval()
+     for param in ema_model.parameters():
+         param.requires_grad = False
+     return ema_model
+ 
+ 
+ @torch.no_grad()
+ def update_ema(ema_model, model, decay=0.9999):
+     """
+     Step the EMA model towards the current model.
+     """
+     ema_params = OrderedDict(ema_model.named_parameters())
+     model_params = OrderedDict(model.named_parameters())
+ 
+     for name, param in model_params.items():
+         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
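+ 
+ # Editor's sketch (illustrative, not part of the original file): the update is
+ # the standard exponential moving average  ema <- decay * ema + (1 - decay) * param,
+ # called once per optimizer step:
+ #
+ #     ema = get_ema_model(model)
+ #     for batch in loader:            # hypothetical training loop
+ #         loss = compute_loss(model, batch)   # assumed helper
+ #         loss.backward(); opt.step(); opt.zero_grad()
+ #         update_ema(ema, model, decay=0.9999)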
infinity/models/flex_attn.py ADDED
@@ -0,0 +1,152 @@
+ """
+ Wrap torch's flex attention and handle the mask bookkeeping (may be refactored later).
+ """
+ from functools import partial
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     flex_attention_available = True
+ except ImportError:
+     print(f"[Warning] flex attention needs pytorch 2.5.0+ but your version is {torch.__version__}")
+     flex_attention_available = False
+ 
+ 
+ def _causal_mask(b, h, q_idx, kv_idx):
+     return q_idx >= kv_idx
+ 
+ 
+ def _length_to_offsets(lengths, device):
+     """Converts a list of lengths to a tensor of cumulative offsets.
+ 
+     Args:
+         lengths: A list of lengths, e.g. [2, 4, 3] -> offsets [0, 2, 6, 9].
+     """
+     offsets = [0]
+     offsets.extend(lengths)
+     offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
+     offsets = torch.cumsum(offsets, dim=-1)
+     return offsets
+ 
+ 
+ def _generate_var_mask_mod(offsets):
+     """Generates mask mods that apply to inputs to flex attention in the sequence stacked
+     format.
+ 
+     Args:
+         offsets: This tensor should be of shape (num_documents + 1),
+             containing the cumulative counts of document tokens.
+             e.g. if you have 3 documents of length 2, 4, 3 then
+             offsets = [0, 2, 6, 9]
+ 
+     Note:
+         What is the sequence stacked format? When assembling batches of inputs, we
+         take multiple sequences and stack them together to form 1 large sequence. We then
+         use masking to ensure that the attention scores are only applied to tokens within
+         the same document.
+     """
+ 
+     def _offsets_to_doc_ids_tensor(offsets):
+         device = offsets.device
+         counts = offsets[1:] - offsets[:-1]
+         return torch.repeat_interleave(
+             torch.arange(len(counts), device=device, dtype=torch.int32), counts
+         )
+ 
+     document_id = _offsets_to_doc_ids_tensor(offsets)
+ 
+     def var_mask_mod(b, h, q_idx, kv_idx):
+         same_doc = document_id[q_idx] == document_id[kv_idx]
+         causal_mask = _causal_mask(b, h, q_idx, kv_idx)
+         return same_doc | causal_mask
+ 
+     return var_mask_mod
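+ 
+ # Editor's sketch (illustrative, not part of the original file): the VAR mask
+ # lets every token attend bidirectionally inside its own scale ("document")
+ # and causally to all earlier scales. For three scales of 1, 4 and 9 tokens:
+ #
+ #     offsets = _length_to_offsets([1, 4, 9], device='cuda')   # [0, 1, 5, 14]
+ #     mask_mod = _generate_var_mask_mod(offsets)
+ #     block_mask = create_block_mask(mask_mod, B=None, H=None,
+ #                                    Q_LEN=14, KV_LEN=14, device='cuda')
+ #     out = flex_attention(q, k, v, block_mask=block_mask)  # q/k/v: (B, H, 14, d)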
+ 
+ 
+ def _generate_var_infer_mask_with_kv_cache(lengths):
+     kv_len = sum(lengths)
+ 
+     def var_mask_mod(b, h, q_idx, kv_idx):
+         return kv_idx < kv_len
+ 
+     return var_mask_mod
+ 
+ 
+ def _generate_var_edit_block_mask_mod(offsets):
+ 
+     def _offsets_to_doc_ids_tensor(offsets):
+         device = offsets.device
+         counts = offsets[1:] - offsets[:-1]
+         return torch.repeat_interleave(
+             torch.arange(len(counts), device=device, dtype=torch.int32), counts
+         )
+ 
+     document_id = _offsets_to_doc_ids_tensor(offsets)
+     text_id = (document_id[-1] + 1) // 2
+ 
+     def var_edit_block_mask_mod(b, h, q_idx, kv_idx):
+         causal_doc = document_id[q_idx] >= document_id[kv_idx]
+         with_edit = (document_id[q_idx] % text_id) >= (document_id[kv_idx] % text_id)
+         return causal_doc & with_edit
+ 
+     return var_edit_block_mask_mod
+ 
+ 
+ class FlexAttn(nn.Module):
+     def __init__(
+         self, block_scales: list, mask_type: str, B, H, L: int, auto_padding=False
+     ):
+         """
+         :param block_scales: accepts VAR's block sizes like [(1, 1, 1), (1, 2, 2), (1, 3, 3)]
+         :param mask_type: one of var / causal / var_infer_mask_with_kv_cache / var_edit_block
+         :param B: batch size
+         :param H: number of heads
+         :param L: sequence length
+         """
+         super().__init__()
+         if not flex_attention_available:
+             raise NotImplementedError(f"[Error] flex attention needs pytorch 2.5.0+ but your version is {torch.__version__}")
+ 
+         self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache", "var_edit_block"]
+         self.auto_padding = auto_padding
+ 
+         self.flex_attention = torch.compile(flex_attention)
+ 
+         self.block_scales = block_scales
+         self.lengths = [x * y * z for x, y, z in block_scales]
+ 
+         self.offsets = _length_to_offsets(self.lengths, device='cuda')
+ 
+         # if L is padded to align to 128, the block mask needs to cover the padding area
+         if self.offsets[-1] < L:
+             self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)
+ 
+         if mask_type == "var":
+             self.mask_mod = _generate_var_mask_mod(self.offsets)
+             self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)
+         elif mask_type == "causal":
+             self.mask_mod = _causal_mask
+             self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)
+         elif mask_type == 'var_infer_mask_with_kv_cache':
+             self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
+             self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)
+         elif mask_type == 'var_edit_block':
+             self.mask_mod = _generate_var_edit_block_mask_mod(self.offsets)
+             self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)
+         else:
+             raise NotImplementedError(f"{mask_type} not supported in FlexAttn, supported types: {self.support_mask_type}")
+ 
+     def forward(self, q, k, v, scale=None):
+         if self.auto_padding:
+             # flex attention requires sequence lengths aligned to 128
+             q_pad_len = (128 - q.shape[-2] % 128) % 128
+             kv_pad_len = (128 - k.shape[-2] % 128) % 128
+             q_pad = F.pad(q, (0, 0, 0, q_pad_len))
+             k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
+             v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
+             oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v_pad.dtype), v_pad, block_mask=self.block_mask, scale=scale)
+             if q_pad_len > 0:
+                 oup = oup[:, :, :-q_pad_len]
+         else:
+             oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask=self.block_mask, scale=scale)
+         return oup
+ 
+     def extra_repr(self) -> str:
+         return f'block size:{self.block_scales}'
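+ 
+ # Editor's sketch (illustrative; shapes are assumptions, not from the original
+ # file): building a VAR-masked attention for scales (1,1,1), (1,2,2), (1,4,4),
+ # with the mask padded out to a 128-aligned length:
+ #
+ #     attn = FlexAttn(block_scales=[(1, 1, 1), (1, 2, 2), (1, 4, 4)],
+ #                     mask_type="var", B=2, H=16, L=128, auto_padding=True)
+ #     q = k = v = torch.randn(2, 16, 21, 64, device='cuda')  # 1 + 4 + 16 = 21 tokens
+ #     out = attn(q, k, v)                                    # (2, 16, 21, 64)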
infinity/models/fused_op.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+ 
+ 
+ @torch.compile(fullgraph=True)
+ def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
+     x = x.float()
+     return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
+ 
+ 
+ @torch.compile(fullgraph=True)
+ def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
+     x = x.float()
+     x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
+     return x.mul(scale.add(1)).add_(shift)
+ 
+ 
+ @torch.compile(fullgraph=True)
+ def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
+     x = x.float()
+     x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps)))
+     return x.mul(scale.add(1)).add_(shift)
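+ 
+ # Editor's sketch (illustrative, not part of the original file): the adaptive
+ # norms above compute  norm(x) * (1 + scale) + shift,  with scale/shift coming
+ # from a conditioning network, e.g.:
+ #
+ #     B, L, C = 2, 16, 1024
+ #     x = torch.randn(B, L, C, device='cuda')
+ #     scale = torch.randn(B, 1, C, device='cuda')
+ #     shift = torch.randn(B, 1, C, device='cuda')
+ #     y = fused_ada_rms_norm(C, 1e-6, x, scale, shift)   # (B, L, C), float32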
infinity/models/infinity.py ADDED
@@ -0,0 +1,847 @@
+ """
+ Definition of the Infinity transformer model.
+ """
+ 
+ import math
+ import random
+ import time
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import List, Optional, Tuple, Union, Dict, Any
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm.models import register_model
+ from torch.utils.checkpoint import checkpoint
+ from PIL import Image
+ import numpy as np
+ 
+ import infinity.utils.dist as dist
+ from infinity.utils.dist import for_visualize
+ from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid
+ from infinity.utils import misc
+ from infinity.models.flex_attn import FlexAttn
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
+ 
+ try:
+     from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm
+ except Exception:  # the original used a bare except; fall back when the fused ops cannot be imported/compiled
+     fused_ada_layer_norm, fused_ada_rms_norm = None, None
+ 
+ 
+ class MultiInpIdentity(nn.Module):
+     def forward(self, x, *args, **kwargs):
+         return x
+ 
+ 
+ class TextAttentivePool(nn.Module):
+     def __init__(self, Ct5: int, D: int):
+         super().__init__()
+         self.Ct5, self.D = Ct5, D
+         if D > 4096:
+             self.head_dim = 64
+         else:
+             self.head_dim = 128
+ 
+         self.num_heads = Ct5 // self.head_dim
+         self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads)
+ 
+     def forward(self, ca_kv):
+         return self.ca(None, ca_kv).squeeze(1)
+ 
+ 
+ class SharedAdaLin(nn.Linear):
+     def forward(self, cond_BD):
+         C = self.weight.shape[0] // 6
+         return super().forward(cond_BD).reshape(-1, 1, 6, C)  # B16C
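+ 
+ 
+ # Editor's note (illustrative, not part of the original file): SharedAdaLin
+ # produces the six AdaLN modulation tensors in one matmul. A (B, D) condition
+ # maps to (B, 1, 6, C), which downstream blocks split into per-branch
+ # (gamma, scale, shift)-style chunks, e.g.:
+ #
+ #     shared = SharedAdaLin(1024, 6 * 1024)
+ #     gss = shared(torch.randn(4, 1024))        # (4, 1, 6, 1024)
+ #     g1, g2, s1, s2, sh1, sh2 = gss.unbind(2)  # exact roles are defined in the blocks
+ 
+ 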
+ class MultipleLayers(nn.Module):
+     def __init__(self, ls, num_blocks_in_a_chunk, index):
+         super().__init__()
+         self.module = nn.ModuleList()
+         for i in range(index, index + num_blocks_in_a_chunk):
+             self.module.append(ls[i])
+ 
+     def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None, start_layer=False, src=True):
+         h = x
+         for m in self.module:
+             if checkpointing_full_block:
+                 h = torch.utils.checkpoint.checkpoint(m, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, start_layer, src, use_reentrant=False)
+             else:
+                 h = m(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, start_layer, src)
+             start_layer = False
+         return h
+ 
+ class Infinity(nn.Module):
+     def __init__(
+         self, vae_local,
+         text_channels=0, text_maxlen=0,  # text-cond generation
+         selecting_idx=None,  # class-cond generation
+         embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,  # model's architecture
+         drop_rate=0., drop_path_rate=0.,  # drop out and drop path
+         norm_eps=1e-6, rms_norm=False,  # norm layer
+         shared_aln=False, head_aln=True,  # adaptive norm
+         cond_drop_rate=0.1,  # for classifier-free guidance
+         rand_uncond=False,
+         cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
+         raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
+         head_depth=1,
+         top_p=0.0, top_k=0.0,
+         customized_flash_attn=False, fused_mlp=False, fused_norm=False,
+         block_chunks=1,
+         checkpointing=None,
+         pad_to_multiplier=0,
+         use_flex_attn=False,
+         batch_size=2,
+         add_lvl_embeding_only_first_block=1,
+         use_bit_label=1,
+         rope2d_each_sa_layer=0,
+         rope2d_normalized_by_hw=0,
+         pn=None,
+         train_h_div_w_list=None,
+         video_frames=1,
+         always_training_scales=20,
+         apply_spatial_patchify=0,
+         inference_mode=False,
+     ):
+         # set hyperparameters (plain attributes only; submodules are registered after super().__init__() below)
+         self.C = embed_dim
+         # self.clip_dim = 1536
+         self.inference_mode = inference_mode
+         self.apply_spatial_patchify = apply_spatial_patchify
+         if self.apply_spatial_patchify:
+             self.d_vae = vae_local.embed_dim * 4
+         else:
+             self.d_vae = vae_local.embed_dim
+         self.use_bit_label = use_bit_label
+         self.codebook_dim = self.d_vae
+         self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
+         self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
+         self.Ct5 = text_channels
+         self.depth = depth
+         self.num_heads = num_heads
+         self.batch_size = batch_size
+         self.mlp_ratio = mlp_ratio
+         self.cond_drop_rate = cond_drop_rate
+         self.norm_eps = norm_eps
+         self.prog_si = -1
+         self.pn = pn
+         self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
+         self.video_frames = video_frames
+         self.always_training_scales = always_training_scales
+ 
+         assert add_lvl_embeding_only_first_block in [0, 1]
+         self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
+         assert rope2d_each_sa_layer in [0, 1]
+         self.rope2d_each_sa_layer = rope2d_each_sa_layer
+         self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
+         print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, '
+               f'self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
+         head_up_method = ''
+         word_patch_size = 1 if head_up_method in {'', 'no'} else 2
+         if word_patch_size > 1:
+             assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'
+ 
+         self.checkpointing = checkpointing
+         self.pad_to_multiplier = max(1, pad_to_multiplier)
+ 
+         customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
+         self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
+         if customized_flash_attn and not customized_kernel_installed:
+             import inspect, warnings
+             file_path = inspect.getsourcefile(flash_attn_func)
+             line_number = inspect.getsourcelines(flash_attn_func)[1]
+             info = (
+                 f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
+                 f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
+                 f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
+             )
+             warnings.warn(info, ImportWarning)
+             print(info, flush=True)
+ 
+         self.raw_scale_schedule = raw_scale_schedule  # 'raw' means before any patchifying
+         self.first_l = 1
+         # solve top-p top-k sampling hyperparameters
+         self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
+         if self.top_p < 1e-5: self.top_p = 0
+         if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0
+ 
+         t = torch.zeros(dist.get_world_size(), device=dist.get_device())
+         t[dist.get_rank()] = float(flash_fused_op_installed)
+         dist.barrier()
+         dist.allreduce(t)
+         assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'
+ 
+         super().__init__()
+         self.rng = torch.Generator(device=dist.get_device())
+         self.maybe_record_function = nullcontext
+         self.text_maxlen = text_maxlen
+         self.t2i = text_channels != 0
+ 
+         # [inp & position embedding]
+         init_std = math.sqrt(1 / self.C / 3)
+         self.norm0_cond = nn.Identity()
+         if self.t2i:
+             self.selecting_idx = None
+             self.num_classes = 0
+             self.D = self.C
+ 
+             cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
+             rng = torch.Generator(device='cpu')
+             rng.manual_seed(0)
+             torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
+             cfg_uncond /= self.Ct5 ** 0.5
+             if rand_uncond:
+                 self.register_buffer('cfg_uncond', cfg_uncond)
+             else:
+                 self.cfg_uncond = nn.Parameter(cfg_uncond)
+ 
+             self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
+             self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
+             self.text_proj_for_ca = nn.Sequential(
+                 nn.Linear(self.Ct5, self.D),
+                 nn.GELU(approximate='tanh'),
+                 nn.Linear(self.D, self.D),
+             )
+             # self.clip_proj_for_sos = nn.Linear(self.clip_dim, self.D // 4)
+         else:  # class-label cond
+             if selecting_idx is None:
+                 num_classes = 1000
+                 print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
+                 selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
+             self.selecting_idx = selecting_idx
+             self.num_classes = selecting_idx.shape[-1]
+             self.D = self.C
+             self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
+             nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
+ 
+         self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
+         nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
+         if self.rope2d_each_sa_layer:
+             rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
+             self.rope2d_freqs_grid = rope2d_freqs_grid
+         else:
+             raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
+         self.lvl_embed = nn.Embedding(15, self.C)
+         nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
+ 
+         # [input layers] input norm && input embedding
+         norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
+         self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
+         self.word_embed = nn.Linear(self.d_vae, self.C)
+ 
+         # [shared adaptive layernorm mapping network]
+         self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
+ 
+         # fused norm
+         if fused_norm:
+             fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
+             if fused_norm_func is not None:  # pre-compile
+                 B = 2
+                 x = torch.randn(B, 1, self.C).requires_grad_(True)
+                 scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
+                 shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
+                 # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward()
+                 del B, x, scale, shift
+         else:
+             fused_norm_func = None
+ 
+         # [backbone and head]
+         self.use_flex_attn = use_flex_attn
+         self.attn_fn_compile_dict = {}
+         self.batch_size = batch_size
+         if self.use_flex_attn:
+             self.attn_fn_compile_dict = self.compile_flex_attn()
+ 
+         self.drop_path_rate = drop_path_rate
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # dpr means drop path rate (linearly increasing)
+         self.unregistered_blocks = []
+         for block_idx in range(depth):
+             block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
+                 embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
+                 num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
+                 swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
+                 checkpointing_sa_only=self.checkpointing == 'self-attn',
+                 use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
+             )
+             self.unregistered_blocks.append(block)
+ 
+         # [head]
+         V = self.V
+         if head_aln:
+             self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
+             self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
+         else:
+             self.head_nm = MultiInpIdentity()
+             self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
+ 
+         self.num_block_chunks = block_chunks or 1
+         self.num_blocks_in_a_chunk = depth // block_chunks
+         print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
+         assert self.num_blocks_in_a_chunk * block_chunks == depth
+         if self.num_block_chunks == 1:
+             self.blocks = nn.ModuleList(self.unregistered_blocks)
+         else:
+             self.block_chunks = nn.ModuleList()
+             for i in range(self.num_block_chunks):
+                 self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
+         print(
+             f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
+             f'    [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
+             f'    [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
+             end='\n\n', flush=True
+         )
+ 
+     def compile_flex_attn(self):
+         attn_fn_compile_dict = {}
+         for h_div_w in self.train_h_div_w_list:
+             h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
+             full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
+             if self.inference_mode:
+                 apply_flex_attn_scales = list(range(1, 1 + len(full_scale_schedule)))
+                 mask_type = "var_infer_mask_with_kv_cache"  # the original used "infinity_infer_mask_with_kv_cache", which FlexAttn does not support
+                 auto_padding = True
+             else:
+                 mask_type = 'var'
+                 # mask_type = 'var_edit_block'
+                 auto_padding = False
+                 apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
+             for scales_num in apply_flex_attn_scales:
+                 print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
+                 scale_schedule = full_scale_schedule[:scales_num]
+                 scale_schedule = [(min(t, self.video_frames // 4 + 1), h, w) for (t, h, w) in scale_schedule]
+                 patchs_nums_tuple = tuple(scale_schedule)
+                 edit_scale_schedule = [scale_schedule[-1]] + scale_schedule
+                 edit_patchs_nums_tuple = tuple(edit_scale_schedule)
+ 
+                 SEQ_L = sum(pt * ph * pw for pt, ph, pw in edit_patchs_nums_tuple)
+                 aligned_L = SEQ_L + (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
+                 attn_fn = FlexAttn(block_scales=edit_patchs_nums_tuple,
+                                    mask_type=mask_type,
+                                    B=self.batch_size,
+                                    H=self.num_heads,
+                                    L=aligned_L,
+                                    auto_padding=auto_padding)
+                 attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
+ 
+                 if self.video_frames > 1:  # append image attn_fn when self.video_frames > 1 (namely videos)
+                     scale_schedule = [(1, h, w) for (t, h, w) in scale_schedule]
+                     patchs_nums_tuple = tuple(scale_schedule)
+                     edit_scale_schedule = [scale_schedule[-1]] + scale_schedule
+                     edit_patchs_nums_tuple = tuple(edit_scale_schedule)
+                     SEQ_L = sum(pt * ph * pw for pt, ph, pw in edit_patchs_nums_tuple)
+                     aligned_L = SEQ_L + (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
+                     attn_fn = FlexAttn(block_scales=edit_patchs_nums_tuple,
+                                        mask_type=mask_type,
+                                        B=self.batch_size,
+                                        H=self.num_heads,
+                                        L=aligned_L)
+                     attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
+         return attn_fn_compile_dict
+ 
+     def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
+         """
+         :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
+         :param cond_BD: shaped (B or batch_size, D or cond_dim)
+         :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size)
+         """
+         with torch.amp.autocast('cuda', enabled=False):
+             return self.head(self.head_nm(h.float(), cond_BD.float()))
+ 
+     def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
+         bs, seq_len, c = feature.shape
+         patch_t, patch_h, patch_w = scale_schedule[scale_ind]
+         t_mul_h_mul_w = patch_t * patch_h * patch_w
+         assert t_mul_h_mul_w + need_to_pad == seq_len
+         feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind * torch.ones((bs, t_mul_h_mul_w), dtype=torch.int).to(feature.device))
+         return feature
+ 
+     def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
+         ptr = 0
+         x_BLC_list = []
+ 
+         # the leading block is sized like the last scale (the source-image tokens in this edit layout)
+         scale_seq_len = np.array(scale_schedule[-1]).prod()
+         x_BLC_this_scale = x_BLC[:, ptr:ptr + scale_seq_len]
+         ptr += scale_seq_len
+         x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, len(scale_schedule) - 1, scale_schedule)
+         x_BLC_list.append(x_BLC_this_scale)
+ 
+         for scale_ind, patch_t_h_w in enumerate(scale_schedule):
+             scale_seq_len = np.array(patch_t_h_w).prod()
+             x_BLC_this_scale = x_BLC[:, ptr:ptr + scale_seq_len]  # shape: [bs, patch_h*patch_w, c]
+             ptr += scale_seq_len
+             x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
+             x_BLC_list.append(x_BLC_this_scale)
+ 
+         assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
+         x_BLC_list.append(x_BLC[:, ptr:])
+         x_BLC = torch.cat(x_BLC_list, dim=1)
+         return x_BLC
+ 
383
+ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], source_x_BLC_wo_prefix: torch.Tensor, target_x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
384
+ cfg_infer=False,
385
+ **kwargs,
386
+ ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV
387
+ """
388
+ label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
389
+ :return: logits BLV, V is vocab_size
390
+ """
391
+ # if cfg_infer:
392
+ # return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, clip_features=clip_features, scale_schedule=scale_schedule, **kwargs)
393
+
394
+ # [1. get input sequence x_BLC]
395
+ with torch.amp.autocast('cuda', enabled=False):
396
+ kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
397
+ # drop cond
398
+ total = 0
399
+ for le in lens:
400
+ if random.random() < self.cond_drop_rate:
401
+ kv_compact[total:total+le] = self.cfg_uncond[:le]
402
+ total += le
403
+ must_on_graph = self.cfg_uncond[0, 0] * 0
404
+ kv_compact = self.text_norm(kv_compact).contiguous()
405
+ sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
406
+ # sos_clip = self.clip_proj_for_sos(clip_features).float().contiguous()
407
+ # sos = cond_BD = torch.cat((sos_text, sos_clip), dim=-1)
408
+ kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
409
+ kv_compact[0, 0] += must_on_graph #ADD
410
+ ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
411
+
412
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32
413
+
414
+ B = source_x_BLC_wo_prefix.shape[0]
415
+ sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
416
+
417
+ # concat the input: sr1, t, sr2, tf1, sr3, tf2, ...
418
+ src = self.word_embed(self.norm0_ve(source_x_BLC_wo_prefix))
419
+ tgt = self.word_embed(self.norm0_ve(target_x_BLC_wo_prefix))
420
+ x_BLC = torch.cat((src, sos, tgt), dim=1)
421
+
422
+ # [1.1. pad the seqlen dim]
423
+ l_end = x_BLC.shape[1]
424
+ need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0
425
+
426
+ if self.customized_flash_attn:
427
+ Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
428
+ Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
429
+ attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
430
+ # todo: solve need_to_pad here
431
+ elif self.use_flex_attn:
432
+ if need_to_pad:
433
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
434
+ assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
435
+ attn_bias_or_two_vector = None
436
+ else:
437
+ d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
438
+ dT = d.transpose(1, 2) # dT: 11L
439
+ attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
440
+ attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL
441
+ if need_to_pad:
442
+ attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
443
+ attn_bias[0, 0, l_end:, 0] = 0
444
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
445
+ attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)
446
+
447
+ if self.use_flex_attn:
448
+ attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
449
+ else:
450
+ attn_fn = None
451
+
452
+ # [2. block loop]
453
+ # per-block computation: see SelfAttnBlock.forward and CrossAttnBlock.forward
454
+ checkpointing_full_block = self.checkpointing == 'full-block' and self.training
455
+ if self.num_block_chunks == 1:
456
+ for i, b in enumerate(self.blocks):
457
+ if self.add_lvl_embeding_only_first_block and i == 0:
458
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
459
+ if not self.add_lvl_embeding_only_first_block:
460
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
461
+ start_layer = (i == 0)
462
+ if checkpointing_full_block:
463
+ x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, start_layer, use_reentrant=False)
464
+ else:
465
+ x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, start_layer=start_layer)
466
+ else:
467
+ for i, chunk in enumerate(self.block_chunks): # this path
468
+ if self.add_lvl_embeding_only_first_block and i == 0:
469
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
470
+ if not self.add_lvl_embeding_only_first_block:
471
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
472
+ start_layer = (i == 0)
473
+ x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid, start_layer=start_layer)
474
+
475
+ # [3. unpad the seqlen dim, and then get logits]
476
+ output = []
477
+ length = 0
478
+ for i, (_, h, w) in enumerate(scale_schedule):
479
+ length += h*w
480
+ start = np.array(scale_schedule[-1]).prod()
481
+ output = x_BLC[:, start:start+length]
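+ # Illustration (hypothetical scale_schedule = [(1,1,1), (1,2,2), (1,4,4)]): the source
+ # prefix occupies the first prod((1,4,4)) = 16 positions and length = 1 + 4 + 16 = 21,
+ # so logits are taken from x_BLC[:, 16:37], i.e. the <sos> token plus the teacher-forced
+ # target positions.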
482
+
483
+ return self.get_logits(output, cond_BD) # return logits BLV, V is vocab_size
484
+
485
+ @torch.no_grad()
486
+ def autoregressive_infer_cfg(
487
+ self,
488
+ vae=None,
489
+ scale_schedule=None, src_img_prefix=None,
490
+ label_B_or_BLT=None, clip_features=None,
491
+ B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
492
+ g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
493
+ returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
494
+ cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
495
+ vae_type=0, softmax_merge_topk=-1, ret_img=False,
496
+ trunk_scale=1000,
497
+ gt_leak=0, gt_ls_Bl=None,
498
+ inference_mode=False,
499
+ save_img_path=None,
500
+ sampling_per_bits=1,
501
+ ): # returns List[idx_Bl]
502
+ if g_seed is None: rng = None
503
+ else: self.rng.manual_seed(g_seed); rng = self.rng
504
+ assert len(cfg_list) >= len(scale_schedule)
505
+ assert len(tau_list) >= len(scale_schedule)
506
+
507
+ # scale_schedule is used by Infinity; vae_scale_schedule is used by the VAE. If spatial patchify is applied,
508
+ # we convert scale_schedule to vae_scale_schedule by multiplying h and w by 2.
509
+ if self.apply_spatial_patchify:
510
+ vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
511
+ else:
512
+ vae_scale_schedule = scale_schedule
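+ # E.g. (hypothetical): scale_schedule = [(1,1,1), (1,2,2)] becomes
+ # vae_scale_schedule = [(1,2,2), (1,4,4)] when apply_spatial_patchify is on,
+ # since the VAE then works at twice the spatial resolution of the transformer tokens.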
513
+
514
+ src_BLC_emb = self.word_embed(self.norm0_ve(src_img_prefix))
515
+
516
+ kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
517
+ if any(np.array(cfg_list) != 1):
518
+ bs = 2*B
519
+ if not negative_label_B_or_BLT:
520
+ kv_compact_un = kv_compact.clone()
521
+ total = 0
522
+ for le in lens:
523
+ kv_compact_un[total:total+le] = self.cfg_uncond[:le]
524
+ total += le
525
+ kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
526
+ cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
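+ # Worked example (hypothetical lengths): cu_seqlens_k = [0, 7, 12] for two prompts of
+ # length 7 and 5; cu_seqlens_k[1:] + cu_seqlens_k[-1] = [19, 24], so the doubled
+ # (cond + uncond) batch uses cu_seqlens_k = [0, 7, 12, 19, 24].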
527
+ else:
528
+ kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
529
+ kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
530
+ cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
531
+ max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
532
+ else:
533
+ bs = B
534
+
535
+ kv_compact = self.text_norm(kv_compact)
536
+ sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
537
+ kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
538
+ ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
539
+ last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
540
+
541
+ with torch.amp.autocast('cuda', enabled=False):
542
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
543
+ accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
544
+ idx_Bl_list, idx_Bld_list = [], []
545
+
546
+ if inference_mode:
547
+ for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
548
+ else:
549
+ assert self.num_block_chunks > 1
550
+ for block_chunk_ in self.block_chunks:
551
+ for module in block_chunk_.module.module:
552
+ (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
553
+
554
+ abs_cfg_insertion_layers = []
555
+ add_cfg_on_logits, add_cfg_on_probs = False, False
556
+ leng = len(self.unregistered_blocks)
557
+ for item in cfg_insertion_layer:
558
+ if item == 0: # add cfg on logits
559
+ add_cfg_on_logits = True
560
+ elif item == 1: # add cfg on probs
561
+ add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs
562
+ elif item < 0: # determine to add cfg at item-th layer's output
563
+ assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
564
+ abs_cfg_insertion_layers.append(leng+item)
565
+ else:
566
+ raise ValueError(f'cfg_insertion_layer: {item} is not valid')
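+ # E.g. cfg_insertion_layer = [-5] with 32 unregistered blocks appends 32-5 = 27, so the
+ # CFG mix (cfg * cond + (1-cfg) * uncond) is applied where layer_idx == 27 in the loop
+ # below; note layer_idx advances once per block chunk. [0] applies the mix to the logits instead.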
567
+
568
+ start = 0
569
+ num_stages_minus_1 = len(scale_schedule)-1
570
+ length = np.array(scale_schedule[-1]).prod()
571
+ src_last_stage = src_BLC_emb
572
+ start += length
573
+ attn_fn = None
574
+ if self.use_flex_attn:
575
+ attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule), None)
576
+ for block_idx, b in enumerate(self.block_chunks):
577
+ if self.add_lvl_embeding_only_first_block and block_idx == 0:
578
+ src_last_stage = self.add_lvl_embeding(src_last_stage, num_stages_minus_1, scale_schedule)
579
+ if not self.add_lvl_embeding_only_first_block:
580
+ src_last_stage = self.add_lvl_embeding(src_last_stage, num_stages_minus_1, scale_schedule)
581
+ start_layer = (block_idx == 0)
582
+ for m in b.module:
583
+ src_last_stage = m(x=src_last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, start_layer=start_layer, scale_ind=num_stages_minus_1)
584
+ start_layer = False
585
+
586
+ summed_codes = 0
587
+ for si, pn in enumerate(scale_schedule): # si: i-th segment
588
+ cfg = cfg_list[si]
589
+ if si >= trunk_scale:
590
+ break
591
+ cur_L += np.array(pn).prod()
592
+
593
+ need_to_pad = 0
594
+ attn_fn = None
595
+ if self.use_flex_attn:
596
+ # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier
597
+ # if need_to_pad:
598
+ # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad))
599
+ attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)
600
+
601
+ # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
602
+ layer_idx = 0
603
+ for block_idx, b in enumerate(self.block_chunks):
604
+ # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int
605
+ if self.add_lvl_embeding_only_first_block and block_idx == 0:
606
+ last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
607
+ if not self.add_lvl_embeding_only_first_block:
608
+ last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
609
+ start_layer = (block_idx == 0)
610
+ for m in b.module:
611
+ last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, start_layer=start_layer, scale_ind=si, src=False)
612
+ start_layer = False
613
+ if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
614
+ # print(f'add cfg={cfg} on {layer_idx}-th layer output')
615
+ last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
616
+ last_stage = torch.cat((last_stage, last_stage), 0)
617
+ layer_idx += 1
618
+
619
+ if (cfg != 1) and add_cfg_on_logits:
620
+ # print(f'add cfg on add_cfg_on_logits')
621
+ logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
622
+ logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
623
+ else:
624
+ logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
625
+
626
+ if self.use_bit_label:
627
+ tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
628
+ logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
629
+ idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
630
+ idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
631
+ else:
632
+ idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
633
+ if vae_type != 0:
634
+ assert returns_vemb
635
+ if si < gt_leak:
636
+ idx_Bld = gt_ls_Bl[si]
637
+ else:
638
+ assert pn[0] == 1
639
+ idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d]
640
+ if self.apply_spatial_patchify: # unpatchify operation
641
+ idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w]
642
+ idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w]
643
+ idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d]
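+ # Shape walk-through of the unpatchify above, assuming d latent channels:
+ # [B, h, w, 4d] -> permute -> [B, 4d, h, w] -> pixel_shuffle(2) -> [B, d, 2h, 2w]
+ # -> permute -> [B, 2h, 2w, d]; each token thus expands into a 2x2 spatial patch.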
644
+ idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]
645
+
646
+ idx_Bld_list.append(idx_Bld)
647
+ codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
648
+ if si != num_stages_minus_1:
649
+ summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
650
+ last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
651
+ last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w]
652
+ if self.apply_spatial_patchify: # patchify operation
653
+ last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w]
654
+ last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w]
655
+ last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d]
656
+ else:
657
+ summed_codes += codes
658
+ else:
659
+ if si < gt_leak:
660
+ idx_Bl = gt_ls_Bl[si]
661
+ h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC
662
+
663
+ # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
664
+ h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
665
+ ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
666
+ idx_Bl_list.append(idx_Bl)
667
+ if si != num_stages_minus_1:
668
+ accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)
669
+
670
+ if si != num_stages_minus_1:
671
+ last_stage = self.word_embed(self.norm0_ve(last_stage))
672
+ last_stage = last_stage.repeat(bs//B, 1, 1)
673
+
674
+ if inference_mode:
675
+ for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
676
+ else:
677
+ assert self.num_block_chunks > 1
678
+ for block_chunk_ in self.block_chunks:
679
+ for module in block_chunk_.module.module:
680
+ (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
681
+
682
+ if not ret_img:
683
+ return ret, idx_Bld_list, []
684
+
685
+ if vae_type != 0:
686
+ img = vae.decode(summed_codes.squeeze(-3))
687
+ else:
688
+ img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
689
+
690
+ img = (img + 1) / 2
691
+ img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
692
+ return ret, idx_Bld_list, img
693
+
694
+ @for_visualize
695
+ def vis_key_params(self, ep):
696
+ return
697
+
698
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
699
+ for k in state_dict:
700
+ if 'cfg_uncond' in k:
701
+ old, new = state_dict[k], self.cfg_uncond.data
702
+ min_tlen = min(old.shape[0], new.shape[0])
703
+ if min_tlen == old.shape[0]:
704
+ state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
705
+ else:
706
+ state_dict[k] = old[:min_tlen]
707
+
708
+ for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
709
+ state_dict.pop(buf_name, None)
710
+ if hasattr(self, buf_name):
711
+ state_dict[buf_name] = getattr(self, buf_name)
712
+
713
+ return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
714
+
715
+ def special_init(
716
+ self,
717
+ aln_init: float,
718
+ aln_gamma_init: float,
719
+ scale_head: float,
720
+ scale_proj: int,
721
+ ):
722
+ # init head's norm
723
+ if isinstance(self.head_nm, AdaLNBeforeHead):
724
+ self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head
725
+ if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
726
+ self.head_nm.ada_lin[-1].bias.data.zero_()
727
+
728
+ # init head's proj
729
+ if scale_head >= 0:
730
+ if isinstance(self.head, nn.Linear):
731
+ self.head.weight.data.mul_(scale_head)
732
+ self.head.bias.data.zero_()
733
+ elif isinstance(self.head, nn.Sequential):
734
+ self.head[-1].weight.data.mul_(scale_head)
735
+ self.head[-1].bias.data.zero_()
736
+
737
+ depth = len(self.unregistered_blocks)
738
+ for block_idx, sab in enumerate(self.unregistered_blocks):
739
+ sab: Union[SelfAttnBlock, CrossAttnBlock]
740
+ # init proj
741
+ scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
742
+ if scale_proj == 1:
743
+ if self.t2i:
744
+ sab.sa.proj.weight.data.mul_(scale)
745
+ sab.ca.proj.weight.data.mul_(scale)
746
+ else:
747
+ sab.attn.proj.weight.data.mul_(scale)
748
+ sab.ffn.fc2.weight.data.mul_(scale)
749
+ # if sab.using_swiglu:
750
+ # nn.init.ones_(sab.ffn.fcg.bias)
751
+ # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
752
+
753
+ # init ada_lin
754
+ if hasattr(sab, 'ada_lin'):
755
+ lin = sab.ada_lin[-1]
756
+ lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma
757
+ lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift
758
+ if hasattr(lin, 'bias') and lin.bias is not None:
759
+ lin.bias.data.zero_()
760
+ elif hasattr(sab, 'ada_gss'):
761
+ sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma
762
+ sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift
763
+
764
+ def extra_repr(self):
765
+ return f'drop_path_rate={self.drop_path_rate}'
766
+
767
+ def get_layer_id_and_scale_exp(self, para_name: str):
768
+ raise NotImplementedError
769
+
770
+
771
+ def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l, num_samples)
772
+ B, l, V = logits_BlV.shape
773
+ if top_k > 0:
774
+ top_k = min(top_k, V)
775
+ idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
776
+ logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
777
+ if top_p > 0:
778
+ sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
779
+ sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
780
+ sorted_idx_to_remove[..., -1:] = False
781
+ logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
782
+ # sample (flatten to 2-D first, since torch.multinomial only accepts 2-D tensors)
783
+ replacement = num_samples >= 0
784
+ num_samples = abs(num_samples)
785
+ return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
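+ # A minimal sanity-check sketch for the sampler above (illustrative only;
+ # `_demo_top_k_sampling` is a hypothetical helper, not part of the original codebase):
+ def _demo_top_k_sampling():
+     logits = torch.tensor([[[4.0, 3.0, -1.0, -2.0]]])  # B=1, l=1, V=4
+     idx = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits, top_k=2, num_samples=1)
+     assert int(idx) in (0, 1)  # tokens 2 and 3 were masked to -inf in place
+     return idx  # shape (1, 1, 1)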
786
+
787
+ def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l, num_samples)
788
+ B, l, V = probs_BlV.shape
789
+ if top_k > 0:
790
+ top_k = min(top_k, V)
791
+ idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
792
+ probs_BlV.masked_fill_(idx_to_remove, 0)
793
+ if top_p > 0:
794
+ sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
795
+ sorted_idx_to_remove = sorted_probs.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
796
+ sorted_idx_to_remove[..., -1:] = False
797
+ probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)
798
+ # sample (flatten to 2-D first, since torch.multinomial only accepts 2-D tensors)
799
+ probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdim=True)
800
+ replacement = num_samples >= 0
801
+ num_samples = abs(num_samples)
802
+ return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
803
+
804
+
805
+ def get_params_num(d, w, mlp):
806
+ m = round(mlp * w / 256) * 256
807
+ s = d * (w**2 * 8 + w*m * 2) # sa+ca, mlp
808
+ s += w**2 * 6 # saln
809
+ s += 4096 * w # pred
810
+ s += 32 * w # we
811
+
812
+ Ct5 = 4096
813
+ s += Ct5*w * 4 # T5 attn pool
814
+ s += Ct5*w + w*w # T5 mlp
815
+ return f'{s/1e9:.2f}B'
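+ # Worked example: get_params_num(32, 2048, 4) gives m = round(4*2048/256)*256 = 8192; the
+ # block term is 32 * (2048**2 * 8 + 2048 * 8192 * 2) ≈ 2.147e9, and adding the saln, pred,
+ # word-embed and T5-projection terms brings the total to ≈ 2.23e9, i.e. '2.23B',
+ # consistent with the infinity_2b configuration below.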
816
+
817
+
818
+ TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
819
+
820
+ @register_model
821
+ def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
822
+
823
+ @register_model
824
+ def infinity_8b(depth=40, embed_dim=3584, num_heads=28, drop_path_rate=0.1, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
825
+
826
+ @register_model
827
+ def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
828
+
829
+ # model configurations for scaling the Infinity transformer
830
+ @register_model
831
+ def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
832
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
833
+ @register_model
834
+ def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
835
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
836
+ @register_model
837
+ def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
838
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
839
+ @register_model
840
+ def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
841
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
842
+ @register_model
843
+ def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
844
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
845
+ @register_model
846
+ def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
847
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
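+ # The factories above are registered through timm's @register_model, so (a hedged usage
+ # sketch; extra kwargs not in TIMM_KEYS are forwarded to Infinity's constructor):
+ #   import timm
+ #   model = timm.create_model('infinity_2b')
+ # or call a factory directly, e.g. model = infinity_2b().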
infinity/models/init_param.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def init_weights(model: nn.Module, conv_std_or_gain: float = 0.02, other_std: float = 0.02):
5
+ """
6
+ :param model: the model to be inited
7
+ :param conv_std_or_gain: how to init every conv layer `m`
8
+ > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
9
+ < 0: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
10
+ :param other_std: how to init every linear layer or embedding layer
11
+ use nn.init.trunc_normal_(m.weight.data, std=other_std)
12
+ """
13
+ skip = abs(conv_std_or_gain) > 10
14
+ if skip: return
15
+ print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}')
16
+ for m in model.modules():
17
+ if isinstance(m, nn.Linear):
18
+ nn.init.trunc_normal_(m.weight.data, std=other_std)
19
+ if m.bias is not None:
20
+ nn.init.constant_(m.bias.data, 0.)
21
+ elif isinstance(m, nn.Embedding):
22
+ nn.init.trunc_normal_(m.weight.data, std=other_std)
23
+ if m.padding_idx is not None:
24
+ m.weight.data[m.padding_idx].zero_()
25
+ elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
26
+ nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) if conv_std_or_gain > 0 else nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) # todo: StyleSwin: (..., gain=.02)
27
+ if hasattr(m, 'bias') and m.bias is not None:
28
+ nn.init.constant_(m.bias.data, 0.)
29
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
30
+ if m.bias is not None:
31
+ nn.init.constant_(m.bias.data, 0.)
32
+ if m.weight is not None:
33
+ nn.init.constant_(m.weight.data, 1.)
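+ # Hedged usage sketch of the initializer above: a positive conv_std_or_gain selects
+ # truncated-normal init for convs, a negative one selects Xavier with gain = -conv_std_or_gain,
+ # and |conv_std_or_gain| > 10 skips initialization entirely (see `skip` above):
+ #   init_weights(model, conv_std_or_gain=-0.02, other_std=0.02)  # xavier(gain=0.02) for convs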
infinity/models/t5.py ADDED
@@ -0,0 +1,369 @@
1
+ import re
2
+ import torch
3
+ import os
4
+ import traceback
5
+ import numpy as np
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import AutoTokenizer, T5EncoderModel
8
+
9
+ import ftfy
10
+ import html
11
+ from bs4 import BeautifulSoup
12
+ import urllib.parse as ul
13
+
14
+
15
+ class T5Embedder:
16
+
17
+ available_models = ['t5-v1_1-xxl']
18
+ bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
19
+
20
+ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
21
+ t5_model_kwargs=None, torch_dtype=torch.bfloat16, use_offload_folder=None, model_max_length=512, padding="max_length", clean_caption_func_name="clean_caption"):
22
+ self.device = torch.device(device)
23
+ self.torch_dtype = torch_dtype
24
+ if t5_model_kwargs is None:
25
+ t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
26
+ if use_offload_folder is not None:
27
+ t5_model_kwargs['offload_folder'] = use_offload_folder
28
+ t5_model_kwargs['device_map'] = {
29
+ 'shared': self.device,
30
+ 'encoder.embed_tokens': self.device,
31
+ 'encoder.block.0': self.device,
32
+ 'encoder.block.1': self.device,
33
+ 'encoder.block.2': self.device,
34
+ 'encoder.block.3': self.device,
35
+ 'encoder.block.4': self.device,
36
+ 'encoder.block.5': self.device,
37
+ 'encoder.block.6': self.device,
38
+ 'encoder.block.7': self.device,
39
+ 'encoder.block.8': self.device,
40
+ 'encoder.block.9': self.device,
41
+ 'encoder.block.10': self.device,
42
+ 'encoder.block.11': self.device,
43
+ 'encoder.block.12': 'disk',
44
+ 'encoder.block.13': 'disk',
45
+ 'encoder.block.14': 'disk',
46
+ 'encoder.block.15': 'disk',
47
+ 'encoder.block.16': 'disk',
48
+ 'encoder.block.17': 'disk',
49
+ 'encoder.block.18': 'disk',
50
+ 'encoder.block.19': 'disk',
51
+ 'encoder.block.20': 'disk',
52
+ 'encoder.block.21': 'disk',
53
+ 'encoder.block.22': 'disk',
54
+ 'encoder.block.23': 'disk',
55
+ 'encoder.final_layer_norm': 'disk',
56
+ 'encoder.dropout': 'disk',
57
+ }
58
+ else:
59
+ t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
60
+
61
+ self.use_text_preprocessing = use_text_preprocessing
62
+ self.hf_token = hf_token
63
+ self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
64
+ self.dir_or_name = dir_or_name
65
+ tokenizer_path, path = dir_or_name, dir_or_name
66
+ if local_cache:
67
+ cache_dir = os.path.join(self.cache_dir, dir_or_name)
68
+ tokenizer_path, path = cache_dir, cache_dir
69
+ elif dir_or_name in self.available_models:
70
+ cache_dir = os.path.join(self.cache_dir, dir_or_name)
71
+ for filename in [
72
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
73
+ 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
74
+ ]:
75
+ hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
76
+ force_filename=filename, token=self.hf_token)
77
+ tokenizer_path, path = cache_dir, cache_dir
78
+ else:
79
+ cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
80
+ for filename in [
81
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
82
+ ]:
83
+ hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
84
+ force_filename=filename, token=self.hf_token)
85
+ tokenizer_path = cache_dir
86
+
87
+ print(f"Loading T5 from {tokenizer_path}")
88
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
89
+ self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
90
+ self.model_max_length = model_max_length
91
+ self.padding = padding
92
+ self.clean_caption_func = getattr(self, clean_caption_func_name)
93
+
94
+ @torch.no_grad()
95
+ def get_text_embeddings(self, texts):
96
+ import time
97
+ start_time = time.time()
98
+
99
+ texts = [self.text_preprocessing(text) for text in texts]
100
+ # print("text_preprocessing: ", time.time() - start_time)
101
+
102
+ text_tokens_and_mask = self.tokenizer(
103
+ texts,
104
+ max_length=self.model_max_length,
105
+ padding=self.padding,
106
+ truncation=True,
107
+ return_attention_mask=True,
108
+ add_special_tokens=True,
109
+ return_tensors='pt'
110
+ )
111
+
112
+ # print("tokenizer: ", time.time() - start_time)
113
+
114
+ text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'].to(self.device)
115
+ text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'].to(self.device)
116
+
117
+ with torch.no_grad():
118
+ text_encoder_embs = self.model(
119
+ input_ids=text_tokens_and_mask['input_ids'],
120
+ attention_mask=text_tokens_and_mask['attention_mask'],
121
+ )['last_hidden_state'].detach()
122
+
123
+ # print("model: ", time.time() - start_time)
124
+ return text_encoder_embs, text_tokens_and_mask['attention_mask'], text_tokens_and_mask['input_ids'], texts
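+ # Hedged usage sketch (shapes assume the default 512-token padding and T5-XXL's
+ # 4096-d hidden size):
+ #   t5 = T5Embedder(device='cuda', dir_or_name='t5-v1_1-xxl')
+ #   embs, mask, ids, cleaned = t5.get_text_embeddings(['a photo of a cat'])
+ #   # embs: [1, 512, 4096]; mask and ids: [1, 512]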
125
+
126
+ def text_preprocessing(self, text):
127
+ if self.use_text_preprocessing:
128
+ try:
129
+ # the exact text cleaning used at the training stage:
130
+ text = self.clean_caption_func(text)
131
+ text = self.clean_caption_func(text)
132
+ return text
133
+ except Exception as e:
134
+ print(f"Error in text preprocessing: {e} with text: {text}")
135
+ print(traceback.format_exc())
136
+ return text
137
+ else:
138
+ return text.lower().strip()
139
+
140
+ @staticmethod
141
+ def basic_clean(text):
142
+ text = ftfy.fix_text(text)
143
+ text = html.unescape(html.unescape(text))
144
+ return text.strip()
145
+
146
+ def clean_caption(self, caption):
147
+ caption = str(caption)
148
+ caption = ul.unquote_plus(caption)
149
+ caption = caption.strip().lower()
150
+ caption = re.sub('<person>', 'person', caption)
151
+ # urls:
152
+ caption = re.sub(
153
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
154
+ '', caption) # regex for urls
155
+ caption = re.sub(
156
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
157
+ '', caption) # regex for urls
158
+ # html:
159
+ try:
160
+ caption = BeautifulSoup(caption, features='html.parser').text
161
+ except Exception as e:
162
+ print(f"Error parsing caption:{caption} with html.parser: {e}")
163
+
164
+ # @<nickname>
165
+ caption = re.sub(r'@[\w\d]+\b', '', caption)
166
+
167
+ # 31C0—31EF CJK Strokes
168
+ # 31F0—31FF Katakana Phonetic Extensions
169
+ # 3200—32FF Enclosed CJK Letters and Months
170
+ # 3300—33FF CJK Compatibility
171
+ # 3400—4DBF CJK Unified Ideographs Extension A
172
+ # 4DC0—4DFF Yijing Hexagram Symbols
173
+ # 4E00—9FFF CJK Unified Ideographs
174
+ caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
175
+ caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
176
+ caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
177
+ caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
178
+ caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
179
+ caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
180
+ caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
181
+ #######################################################
182
+
183
+ # все виды тире / all types of dash --> "-"
184
+ caption = re.sub(
185
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
186
+ '-', caption)
187
+
188
+ # кавычки к одному стандарту
189
+ caption = re.sub(r'[`´«»“”¨]', '"', caption)
190
+ caption = re.sub(r'[‘’]', "'", caption)
191
+
192
+ # &quot;
193
+ caption = re.sub(r'&quot;?', '', caption)
194
+ # &amp
195
+ caption = re.sub(r'&amp', '', caption)
196
+
197
+ # IP addresses:
198
+ caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
199
+
200
+ # article ids:
201
+ caption = re.sub(r'\d:\d\d\s+$', '', caption)
202
+
203
+ # \n
204
+ caption = re.sub(r'\\n', ' ', caption)
205
+
206
+ # "#123"
207
+ caption = re.sub(r'#\d{1,3}\b', '', caption)
208
+ # "#12345.."
209
+ caption = re.sub(r'#\d{5,}\b', '', caption)
210
+ # "123456.."
211
+ caption = re.sub(r'\b\d{6,}\b', '', caption)
212
+ # filenames:
213
+ caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
214
+
215
+ #
216
+ caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
217
+ caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
218
+
219
+ caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
220
+ caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
221
+
222
+ # this-is-my-cute-cat / this_is_my_cute_cat
223
+ regex2 = re.compile(r'(?:\-|\_)')
224
+ if len(re.findall(regex2, caption)) > 3:
225
+ caption = re.sub(regex2, ' ', caption)
226
+
227
+ caption = self.basic_clean(caption)
228
+
229
+ caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
230
+ caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
231
+ caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
232
+
233
+ caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
234
+ caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
235
+ caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
236
+ caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
237
+ caption = re.sub(r'\bpage\s+\d+\b', '', caption)
238
+
239
+ caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
240
+
241
+ caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
242
+
243
+ caption = re.sub(r'\b\s+\:\s+', r': ', caption)
244
+ caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
245
+ caption = re.sub(r'\s+', ' ', caption)
246
+
247
+ caption = caption.strip()
248
+
249
+ caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
250
+ caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
251
+ caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
252
+ caption = re.sub(r'^\.\S+$', '', caption)
253
+
254
+ return caption.strip()
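+ # Rough example of the pipeline above: clean_caption('<person> holding a cat — photo #123')
+ # yields approximately 'person holding a cat - photo': the tag is replaced, the em dash is
+ # normalized to '-', the '#123' id is stripped, and whitespace is collapsed.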
255
+
256
+
257
+ def clean_caption_simplify(self, caption):
+ # convert the caption to a string
+ caption = str(caption)
+
+ # decode URL-encoded strings
+ caption = ul.unquote_plus(caption)
+
+ # strip leading/trailing whitespace and lowercase
+ caption = caption.strip().lower()
+
+ # replace '<person>' with 'person'
+ caption = re.sub('<person>', 'person', caption)
+
+ # remove URLs
+ caption = re.sub(
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ '', caption) # URLs starting with http:// or https://
+ caption = re.sub(
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ '', caption) # URLs starting with www.
+
+ # parse HTML and drop the tags
+ caption = BeautifulSoup(caption, features='html.parser').text
+
+ # remove @nickname mentions
+ caption = re.sub(r'@[\w\d]+\b', '', caption)
+
+ # remove characters from specific Unicode ranges: CJK-related characters
+ caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) # CJK Strokes
+ caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) # Katakana Phonetic Extensions
+ caption = re.sub(r'[\u3200-\u32ff]+', '', caption) # Enclosed CJK Letters and Months
+ caption = re.sub(r'[\u3300-\u33ff]+', '', caption) # CJK Compatibility
+ caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) # CJK Unified Ideographs Extension A
+ caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) # Yijing Hexagram Symbols
+ caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) # CJK Unified Ideographs
+
+ # normalize all types of dash to "-"
+ caption = re.sub(
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',
+ '-', caption) # matches the various Unicode dashes
+
+ # normalize the different kinds of quotation marks
+ caption = re.sub(r'[`´«»“”¨]', '"', caption) # replace quote variants with a standard double quote
+ caption = re.sub(r'[‘’]', "'", caption) # replace curly single quotes with a standard single quote
+
+ # remove &quot; and &amp
+ caption = re.sub(r'&quot;?', '', caption) # remove the HTML entity &quot;
+ caption = re.sub(r'&amp', '', caption) # remove the HTML entity &amp
+
+ # remove IP addresses
+ caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) # matches IPv4 addresses
+
+ # remove article-id formats
+ caption = re.sub(r'\d:\d\d\s+$', '', caption) # matches patterns like '1:23 '
+
+ # remove literal \n escapes
+ caption = re.sub(r'\\n', ' ', caption)
+
+ # remove tags of specific formats (disabled here)
+ # caption = re.sub(r'#\d{1,3}\b', '', caption) # '#123': '#' followed by 1-3 digits
+ # caption = re.sub(r'#\d{5,}\b', '', caption) # '#12345..': '#' followed by 5+ digits
+ # caption = re.sub(r'\b\d{6,}\b', '', caption) # '123456..': bare numbers of 6+ digits
+
+ # remove filenames
+ caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) # matches whole image/video filenames, name plus extension
+
+ # collapse repeated quotes and dots
+ caption = re.sub(r'[\"\']{2,}', r'"', caption) # runs of quotes become one double quote
+ caption = re.sub(r'[\.]{2,}', r' ', caption) # runs of dots become a space
+
+ # clean invalid punctuation with the shared bad-punctuation regex
+ caption = re.sub(self.bad_punct_regex, r' ', caption) # custom bad-punctuation pattern
+ caption = re.sub(r'\s+\.\s+', r' ', caption) # remove isolated ' . '
+
+ # filter text with too many dashes or underscores
+ regex2 = re.compile(r'(?:\-|\_)')
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, ' ', caption)
+
+ # basic cleanup
+ caption = self.basic_clean(caption)
+
+ # remove short alphanumeric strings of specific formats (disabled here)
+ # caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # up to 3 letters followed by 3+ digits
+ # caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # letter-digit-letter strings
+ # caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # digit-letter-digit strings
+
+ # remove specific advertising or call-to-action phrases (disabled here)
+ # caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) # 'worldwide free shipping', 'free shipping'
+ # caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) # 'free download', 'download free'
+ # caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) # 'click for ...' or 'click on ...'
+ # caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) # bare file extensions, possibly followed by 'image(s)'
+ # caption = re.sub(r'\bpage\s+\d+\b', '', caption) # 'page 123'
+
+ # remove strings with complex mixed patterns (disabled here)
+ # caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # 123A456B789
+
+ # remove resolution-like identifiers (e.g. 1920x1080)
+ caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
+
+ # fix extra whitespace and punctuation
+ caption = re.sub(r'\b\s+\:\s+', r': ', caption)
+ caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
+ caption = re.sub(r'\s+', ' ', caption)
+
+ # strip leftover characters at the ends (the original bare `caption.strip()` was a no-op)
+ caption = caption.strip()
+ caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
+ caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
+ caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
+ caption = re.sub(r'^\.\S+$', '', caption)
+
+ return caption.strip()
infinity/utils/amp_opt.py ADDED
@@ -0,0 +1,187 @@
1
+ import math
2
+ import os
3
+ import signal
4
+ import sys
5
+ import time
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
10
+ # from memory_profiler import profile
11
+
12
+ import infinity.utils.dist as dist
13
+ from infinity.utils import misc
14
+
15
+ class NullCtx:
16
+ def __enter__(self):
17
+ pass
18
+
19
+ def __exit__(self, exc_type, exc_val, exc_tb):
20
+ pass
21
+
22
+
23
+ def handle_timeout(signum, frame):
24
+ raise TimeoutError('took too long')
25
+
26
+
27
+ def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> Tuple[float, float]:
28
+ skipped, max_grad = [], 0
29
+ for pi, p in enumerate(parameters):
30
+ if p.grad is not None:
31
+ g = p.grad.data.norm(2).item() + 1e-7
32
+ max_grad = max(max_grad, g)
33
+ clip_coef = thresh / g
34
+ if clip_coef < 1:
35
+ if stable and clip_coef < 0.2:
36
+ skipped.append(clip_coef)
37
+ p.grad.data.mul_(0) # todo NOTE: inf.mul_(0)==nan will shrink the scale ratio, but inf.zero_()==0 won't
38
+ else:
39
+ p.grad.data.mul_(clip_coef)
40
+
41
+ # if fp is not None: fp.write(f'[per_param_clip_grad_norm_:47] finished.\n'); fp.flush()
42
+ return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad
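+ # Worked example: with thresh = 2.0 and a parameter whose grad norm is 10.0,
+ # clip_coef = 0.2, so the grad is scaled by 0.2 in place; in `stable` mode any
+ # clip_coef strictly below 0.2 instead zeroes the grad and is recorded in `skipped`.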
43
+
44
+
45
+ class AmpOptimizer:
46
+ def __init__(
47
+ self,
48
+ model_name_3letters: str, mixed_precision: int,
49
+ optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
50
+ r_accu: float, grad_clip: float, zero: int,
51
+ ):
52
+ self.enable_amp = mixed_precision > 0
53
+ self.zero = zero
54
+ if self.enable_amp:
55
+ self.using_fp16_rather_bf16 = mixed_precision != 2
56
+ self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)
57
+
58
+ # todo: on both V100 and A100, torch.get_autocast_gpu_dtype() returns fp16, not bf16.
59
+ self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0) # todo: cache_enabled=False
60
+ if self.using_fp16_rather_bf16:
61
+ self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
62
+ else:
63
+ self.scaler = None
64
+ else:
65
+ self.using_fp16_rather_bf16 = True
66
+ self.amp_ctx = NullCtx()
67
+ self.scaler = None
68
+
69
+ t = torch.zeros(dist.get_world_size())
70
+ t[dist.get_rank()] = float(self.enable_amp)
71
+ dist.allreduce(t)
72
+ assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'
73
+
74
+ t = torch.zeros(dist.get_world_size())
75
+ t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
76
+ dist.allreduce(t)
77
+ assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'
78
+
79
+ self.model_name_3letters = model_name_3letters
80
+ self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
81
+ self.r_accu = r_accu
82
+
83
+ self.paras = self.names = ... # todo: solve EMA-related codes
84
+
85
+ self.grad_clip, self.grad_clip_we = grad_clip, 0 # todo: disable wclip
86
+ if self.grad_clip > 100:
87
+ self.grad_clip %= 100
88
+ self.per_param = True
89
+ else:
90
+ self.per_param = False
91
+ self.per_param = False # todo: disable wclip
92
+
93
+ self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
94
+ self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') # deepspeed's optimizer
95
+
96
+ self.fp = None
97
+ self.last_orig_norm: torch.Tensor = torch.tensor(0.1)
98
+
99
+ @torch.no_grad()
100
+ def log_param(self, ep: int):
101
+ if self.zero == 0:
102
+ for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
103
+ values: List[float]
104
+ if len(values) == 1: # e.g., cls token will only have one value
105
+ values.append(values[0])
106
+ else:
107
+ ...
108
+ # todo: log params
109
+
110
+ # @profile(precision=4, stream=open('amp_sc.log', 'w+'))
111
+ def backward_clip_step(
112
+ self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
113
+ ) -> Tuple[torch.Tensor, Optional[float]]:
114
+ # backward
115
+ loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
116
+ orig_norm = scaler_sc = None
117
+ # if self.fp is not None:
118
+ # if g_it % 20 == 0: self.fp.seek(0); self.fp.truncate(0)
119
+ if self.scaler is not None:
120
+ self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) # retain_graph=retain_graph, create_graph=create_graph
121
+ else:
122
+ loss.backward(retain_graph=False, create_graph=False)
123
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:131] [it{it}, g_it{g_it}] after backward\n'); self.fp.flush()
124
+
125
+ # clip gradients then step optimizer
126
+ if stepping:
127
+ if self.scaler is not None: self.scaler.unscale_(self.optimizer) # now the gradient can be correctly got
128
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:137] [it{it}, g_it{g_it}] after scaler.unscale_\n'); self.fp.flush()
129
+
130
+ skipped, orig_norm = 0, self.last_orig_norm
131
+ # try:
132
+ if self.fp is not None:
133
+ if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
134
+ self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
135
+ if self.early_clipping:
136
+ c = self.grad_clip * clip_decay_ratio
137
+ if self.zero:
138
+ orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
139
+ else:
140
+ orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)
141
+
142
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:175] [it{it}, g_it{g_it}] before opt step\n'); self.fp.flush()
143
+ if self.scaler is not None:
144
+ self.scaler: torch.cuda.amp.GradScaler
145
+ if self.zero:
146
+ # synchronize found_inf_per_device before calling step, so that even if only some ranks found inf on their sharded params, all other ranks will know
147
+ # otherwise, when saving FSDP optimizer state, it will cause AssertionError saying "Different ranks have different values for step."
148
+ for optimizer_state in self.scaler._per_optimizer_states.values():
149
+ for t in optimizer_state['found_inf_per_device'].values():
150
+ dist.allreduce(t) # ideally, each rank only has one single t; so no need to use async allreduce
151
+
152
+ self.scaler.step(self.optimizer)
153
+ scaler_sc: Optional[float] = self.scaler.get_scale()
154
+ if scaler_sc > self.max_sc: # fp16 overflows above 65536, so multiplying by 32768 could be dangerous
155
+ # print(f'[fp16 scaling] too large loss scale {scaler_sc}! (clip to {self.max_sc:g})')
156
+ self.scaler.update(new_scale=self.max_sc)
157
+ else:
158
+ self.scaler.update()
159
+ try:
160
+ scaler_sc = float(math.log2(scaler_sc))
161
+ except Exception as e:
162
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
163
+ time.sleep(1)
164
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
165
+ raise e
166
+ else:
167
+ self.optimizer.step()
168
+
169
+ if self.late_clipping:
170
+ orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
171
+ self.last_orig_norm = orig_norm
172
+ # no zero_grad calling here, gonna log those gradients!
173
+ return orig_norm, scaler_sc
174
+
175
+ def state_dict(self):
176
+ return {
177
+ 'optimizer': self.optimizer.state_dict()
178
+ } if self.scaler is None else {
179
+ 'scaler': self.scaler.state_dict(),
180
+ 'optimizer': self.optimizer.state_dict()
181
+ }
182
+
183
+ def load_state_dict(self, state, strict=True):
184
+ if self.scaler is not None:
185
+ try: self.scaler.load_state_dict(state['scaler'])
186
+ except Exception as e: print(f'[fp16 load_state_dict err] {e}')
187
+ self.optimizer.load_state_dict(state['optimizer'])
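+ # Hedged sketch of how this wrapper is typically driven from a training loop
+ # (names like `model`, `amp_opt` and `n_accum` are hypothetical):
+ #   with amp_opt.amp_ctx:
+ #       loss = model(inputs)
+ #   orig_norm, scaler_log2 = amp_opt.backward_clip_step(
+ #       ep, it, g_it, stepping=(g_it % n_accum == 0), logging_params=False, loss=loss)
+ #   if stepping: amp_opt.optimizer.zero_grad(set_to_none=True)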
infinity/utils/arg_util.py ADDED
@@ -0,0 +1,482 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import random
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import OrderedDict, deque
9
+ from typing import Optional, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from tap import Tap
14
+
15
+ import infinity.utils.dist as dist
16
+
17
+
18
+ class Args(Tap):
19
+ local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # directory for saving checkpoints
20
+ data_path: str = '' # dataset
21
+ bed: str = '' # additional directory that checkpoints are copied to, apart from local_out_path
22
+ vae_ckpt: str = '' # VAE ckpt
23
+ exp_name: str = '' # experiment name
24
+ ds: str = 'oi' # only used in GPT training::load_viz_data & FID benchmark
25
+ model: str = '' # for VAE training, 'b' or any other for GPT training
26
+ short_cap_prob: float = 0.2 # prob for training with short captions
27
+ project_name: str = 'Infinity' # name of wandb project
28
+ tf32: bool = True # whether to use TensorFloat32
29
+ auto_resume: bool = True # whether to automatically resume from the last checkpoint found in args.bed
30
+ rush_resume: str = '' # pretrained infinity checkpoint
31
+ nowd: int = 1 # whether to disable weight decay on sparse params (like class token)
32
+ enable_hybrid_shard: bool = False # whether to use hybrid FSDP
33
+ inner_shard_degree: int = 1 # inner degree for FSDP
34
+ zero: int = 0 # ds zero
35
+ buck: str = 'chunk' # =0 for using module-wise
36
+ fsdp_orig: bool = True
37
+ enable_checkpointing: str = None # checkpointing strategy: full-block, self-attn
38
+ pad_to_multiplier: int = 1 # >1 for padding the seq len to a multiplier of this
39
+ log_every_iter: bool = False
40
+ checkpoint_type: str = 'torch' # checkpoint_type: torch, onmistore
41
+ seed: int = None # 3407
42
+ rand: bool = True # actual seed = seed + (dist.get_rank()*512 if rand else 0)
43
+ device: str = 'cpu'
44
+ task_id: str = '2493513'
45
+ trial_id: str = '7260554'
46
+ robust_run_id: str = '00'
47
+ ckpt_trials = []
48
+ real_trial_id: str = '7260552'
49
+ chunk_nodes: int = None
50
+ is_master_node: bool = None
51
+ # dir
52
+ log_txt_path: str = ''
53
+ t5_path: str = '' # if not specified: automatically find from all bytenas
54
+ online_t5: bool = True # whether to use online t5 or load local features
55
+ # GPT
56
+ sdpa_mem: bool = True # whether to use with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
57
+ tfast: int = 0 # compile GPT
58
+ model_alias: str = 'b' # [automatically set; don't specify this]
59
+ rms: bool = False
60
+ aln: float = 1e-3 # multiplier of ada_lin.w's initialization
61
+ alng: float = -1 # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
62
+ saln: bool = False # whether to use a shared adaln layer
63
+ haln: bool = True # whether to use a specific adaln layer in head layer
64
+ nm0: bool = False # norm before word proj linear
65
+ tau: float = 1 # tau of self attention in GPT
66
+ cos: bool = True # cosine attn as in swin v2
67
+ swi: bool = False # whether to use FFNSwiGLU, instead of vanilla FFN
68
+ dp: float = -1
69
+ drop: float = 0.0 # GPT's dropout (VAE's is --vd)
70
+ hd: int = 0
71
+ ca_gamma: float = -1 # >=0 for using layer-scale for cross attention
72
+ diva: int = 1 # rescale_attn_fc_weights
73
+ hd0: float = 0.02 # head.w *= hd0
74
+ dec: int = 1 # dec depth
75
+ cum: int = 3 # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
76
+ rwe: bool = False # random word emb
77
+ tp: float = 0.0 # top-p
78
+ tk: float = 0.0 # top-k
79
+ tini: float = 0.02 # init parameters
80
+ cfg: float = 0.1 # >0: classifier-free guidance, drop cond with prob cfg
81
+ rand_uncond = False # whether to use random, unlearnable uncond embeding
82
+ ema: float = 0.9999 # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
83
+ tema: float = 0 # 0.9999 in DiffiT, DiT
84
+ fp16: int = 0 # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier. todo: remember to force quantize-related features to fp32; residuals are also best kept in fp32 (per flash-attention); does nn.Conv2d have a use_float16 argument?
85
+ fuse: bool = False # whether to use fused mlp
86
+ fused_norm: bool = False # whether to use fused norm
87
+ flash: bool = False # whether to use customized flash-attn kernel
88
+ xen: bool = False # whether to use xentropy
89
+ use_flex_attn: bool = False # whether to use flex_attn to speedup training
90
+ stable: bool = False
91
+ gblr: float = 1e-4
92
+ dblr: float = None # =gblr if is None
93
+ tblr: float = 6e-4
94
+ glr: float = None
95
+ dlr: float = None
96
+ tlr: float = None # vqgan: 4e-5
97
+ gwd: float = 0.005
98
+ dwd: float = 0.0005
99
+ twd: float = 0.005 # vqgan: 0.01
100
+ gwde: float = 0
101
+ dwde: float = 0
102
+ twde: float = 0
103
+ ls: float = 0.0 # label smooth
104
+ lz: float = 0.0 # z loss from PaLM = 1e-4 todo
105
+ eq: int = 0 # equalized loss
106
+ ep: int = 100
107
+ wp: float = 0
108
+ wp0: float = 0.005
109
+ wpe: float = 0.3 # 0.001, final cosine lr = wpe * peak lr
110
+ sche: str = '' # cos, exp, lin
111
+ log_freq: int = 50 # log frequency in the stdout
112
+ gclip: float = 6. # <=0 for not grad clip VAE
113
+ dclip: float = 6. # <=0 for not grad clip discriminator
114
+ tclip: float = 2. # <=0 for not grad clip GPT; >100 for per-param clip (%= 100 automatically)
115
+ cdec: bool = False # decay the grad clip thresholds of GPT and GPT's word embed
116
+ opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam) and wd=0.8 (8x higher than Adam); e.g., with small batch sizes Lion underperforms AdamW
117
+ ada: str = '' # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
118
+ dada: str = '' # adam's beta0 and beta1 for discriminator
119
+ oeps: float = 0 # adam's eps, pixart uses 1e-10
120
+ afuse: bool = True # fused adam
121
+ # data
122
+ pn: str = '' # pixel nums, choose from 0.06M, 0.25M, 1M
123
+ scale_schedule: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
124
+ patch_size: int = None # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
125
+ resos: tuple = None # [automatically set; don't specify this]
126
+ data_load_reso: int = None # [automatically set; don't specify this]
127
+ workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
128
+ lbs: int = 0 # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
129
+ bs: int = 0 # global batch size; if lbs != 0, bs will be ignored
130
+ batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
131
+ glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
132
+ ac: int = 1 # gradient accumulation
133
+ r_accu: float = 1.0 # [automatically set; don't specify this] = 1 / args.ac
134
+ norm_eps: float = 1e-6 # norm eps for infinity
135
+ tlen: int = 512 # truncate text embedding to this length
136
+ Ct5: int = 2048 # feature dimension of text encoder
137
+ use_bit_label: int = 1 # pred bitwise labels or index-wise labels
138
+ bitloss_type: str = 'mean' # mean or sum
139
+ dynamic_resolution_across_gpus: int = 1 # allow dynamic resolution across gpus
140
+ enable_dynamic_length_prompt: int = 0 # enable dynamic length prompt during training
141
+ use_streaming_dataset: int = 0 # use streaming dataset
142
+ iterable_data_buffersize: int = 90000 # streaming dataset buffer size
143
+ save_model_iters_freq: int = 1000 # save model iter freq
144
+ noise_apply_layers: int = -1 # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
145
+ noise_apply_strength: float = -1 # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
146
+ noise_apply_requant: int = 1 # Bitwise Self-Correction: requant after apply noise
147
+ rope2d_each_sa_layer: int = 0 # apply rope2d to each self-attention layer
148
+ rope2d_normalized_by_hw: int = 1 # apply normalized rope2d
149
+ use_fsdp_model_ema: int = 0 # use fsdp model ema
150
+ add_lvl_embeding_only_first_block: int = 1 # apply lvl pe embedding only first block or each block
151
+ reweight_loss_by_scale: int = 0 # reweight loss by scale
152
+ always_training_scales: int = 100 # trunc training scales
153
+ vae_type: int = 1 # here 16/32/64 is bsq vae of different quant bits
154
+ fake_vae_input: bool = False # fake vae input for debug
155
+ model_init_device: str = 'cuda' # model_init_device
156
+ prefetch_factor: int = 2 # prefetch_factor for dataset
157
+ apply_spatial_patchify: int = 0 # whether to apply spatial patchify
158
+ debug_bsc: int = 0 # save figs and set breakpoint for debug bsc and check input
159
+ task_type: str = 't2i' # task type: t2i or t2v
160
+
161
+
162
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
163
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
164
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
165
+
166
+
167
+ # would be automatically set in runtime
168
+ branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
169
+ commit_id: str = '' # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
170
+ commit_msg: str = ''# (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
171
+ cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:]) # [automatically set; don't specify this]
172
+ tag: str = 'UK' # [automatically set; don't specify this]
173
+ acc_all: float = None # [automatically set; don't specify this]
174
+ acc_real: float = None # [automatically set; don't specify this]
175
+ acc_fake: float = None # [automatically set; don't specify this]
176
+ last_Lnll: float = None # [automatically set; don't specify this]
177
+ last_L1: float = None # [automatically set; don't specify this]
178
+ last_Ld: float = None # [automatically set; don't specify this]
179
+ last_wei_g: float = None # [automatically set; don't specify this]
180
+ grad_boom: str = None # [automatically set; don't specify this]
181
+ diff: float = None # [automatically set; don't specify this]
182
+ diffs: str = '' # [automatically set; don't specify this]
183
+ diffs_ema: str = None # [automatically set; don't specify this]
184
+ ca_performance: str = '' # [automatically set; don't specify this]
185
+ cur_phase: str = '' # [automatically set; don't specify this]
186
+ cur_it: str = '' # [automatically set; don't specify this]
187
+ cur_ep: str = '' # [automatically set; don't specify this]
188
+ remain_time: str = '' # [automatically set; don't specify this]
189
+ finish_time: str = '' # [automatically set; don't specify this]
190
+ iter_speed: float = None # [automatically set; don't specify this]
191
+ img_per_day: float = None # [automatically set; don't specify this]
192
+ max_nvidia_smi: float = 0 # [automatically set; don't specify this]
193
+ max_memory_allocated: float = None # [automatically set; don't specify this]
194
+ max_memory_reserved: float = None # [automatically set; don't specify this]
195
+ num_alloc_retries: int = None # [automatically set; don't specify this]
196
+ MFU: float = None # [automatically set; don't specify this]
197
+ HFU: float = None # [automatically set; don't specify this]
198
+ # ==================================================================================================================
199
+ # ======================== ignore these parts below since they are only for debug use ==============================
200
+ # ==================================================================================================================
201
+ dbg_modified: bool = False
202
+ dbg_ks: bool = False
203
+ dbg_ks_last = None
204
+ dbg_ks_fp = None
205
+ def dbg_ks_this_line(self, g_it: int):
206
+ if self.dbg_ks:
207
+ if self.dbg_ks_last is None:
208
+ self.dbg_ks_last = deque(maxlen=6)
209
+
210
+ from utils.misc import time_str
211
+ self.dbg_ks_fp.seek(0)
212
+ f_back = sys._getframe().f_back
213
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
214
+ info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
215
+ if g_it is not None:
216
+ info += f' [g_it: {g_it}]'
217
+
218
+ self.dbg_ks_last.append(info)
219
+ self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
220
+ self.dbg_ks_fp.flush()
221
+
222
+ dbg: bool = 'KEVIN_LOCAL' in os.environ # only used when debugging unused parameters in DDP
223
+ ks: bool = False
224
+ nodata: bool = False # if True, will set nova=True as well
225
+ nodata_tlen: int = 320
226
+ nova: bool = False # no val, no FID
227
+ prof: int = 0 # profile
228
+ prof_freq: int = 50 # profile
229
+ tos_profiler_file_prefix: str = 'vgpt_default/'
230
+ profall: int = 0
231
+ @property
232
+ def is_vae_visualization_only(self) -> bool:
233
+ return self.v_seed > 0
234
+ v_seed: int = 0 # v_seed != 0 means the visualization-only mode
235
+ @property
236
+ def is_gpt_visualization_only(self) -> bool:
237
+ return self.g_seed > 0
238
+ g_seed: int = 0 # g_seed != 0 means the visualization-only mode
239
+ # ==================================================================================================================
240
+ # ======================== ignore these parts above since they are only for debug use ==============================
241
+ # ==================================================================================================================
242
+
243
+ @property
244
+ def gpt_training(self):
245
+ return len(self.model) > 0
246
+
247
+ def set_initial_seed(self, benchmark: bool):
248
+ torch.backends.cudnn.enabled = True
249
+ torch.backends.cudnn.benchmark = benchmark
250
+ if self.seed is None:
251
+ torch.backends.cudnn.deterministic = False
252
+ else:
253
+ seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
254
+ torch.backends.cudnn.deterministic = True
255
+ os.environ['PYTHONHASHSEED'] = str(seed)
256
+ random.seed(seed)
257
+ np.random.seed(seed)
258
+ torch.manual_seed(seed)
259
+ if torch.cuda.is_available():
260
+ torch.cuda.manual_seed(seed)
261
+ torch.cuda.manual_seed_all(seed)
262
+
263
+ def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
264
+ if self.seed is None:
265
+ return None
266
+ g = torch.Generator()
267
+ g.manual_seed(self.seed + dist.get_rank()*512)
268
+ return g
269
+
270
+ def compile_model(self, m, fast):
271
+ if fast == 0:
272
+ return m
273
+ return torch.compile(m, mode={
274
+ 1: 'reduce-overhead',
275
+ 2: 'max-autotune',
276
+ 3: 'default',
277
+ }[fast]) if hasattr(torch, 'compile') else m
278
+
279
+ def dump_log(self):
280
+ if not dist.is_local_master():
281
+ return
282
+ nd = {'is_master': dist.is_visualizer()}
283
+ r_trial, trial = str(self.real_trial_id), str(self.trial_id)
284
+ for k, v in {
285
+ 'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
286
+ 'Lnll': self.last_Lnll, 'L1': self.last_L1,
287
+ 'Ld': self.last_Ld,
288
+ 'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
289
+ 'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
290
+ 'grad': self.grad_boom,
291
+
292
+ 'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
293
+ 'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
294
+ 'bsep': f'{self.glb_batch_size}/{self.ep}',
295
+ 'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
296
+ 'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
297
+ 'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
298
+ 'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
299
+ 'opt': self.opt,
300
+ 'is_master_node': self.is_master_node,
301
+ }.items():
302
+ if hasattr(v, 'item'): v = v.item()
303
+ if v is None or (isinstance(v, str) and len(v) == 0): continue
304
+ nd[k] = v
305
+ if r_trial == trial:
306
+ nd.pop('trial', None)
307
+
308
+ with open(self.log_txt_path, 'w') as fp:
309
+ json.dump(nd, fp, indent=2)
310
+
311
+ def touch_log(self): # listener will kill me if log_txt_path is not updated for 120s
312
+ os.utime(self.log_txt_path) # about 2e-6 sec
313
+
314
+ def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
315
+ d = (OrderedDict if key_ordered else dict)()
316
+ # self.as_dict() would contain methods, but we only need variables
317
+ for k in self.class_variables.keys():
318
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
319
+ d[k] = getattr(self, k)
320
+ return d
321
+
322
+ def load_state_dict(self, d: Union[OrderedDict, dict, str]):
323
+ if isinstance(d, str): # for compatibility with old version
324
+ d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
325
+ for k in d.keys():
326
+ if k in {'is_large_model', 'gpt_training'}:
327
+ continue
328
+ try:
329
+ setattr(self, k, d[k])
330
+ except Exception as e:
331
+ print(f'k={k}, v={d[k]}')
332
+ raise e
333
+
334
+ @staticmethod
335
+ def set_tf32(tf32: bool):
336
+ if torch.cuda.is_available():
337
+ torch.backends.cudnn.allow_tf32 = bool(tf32)
338
+ torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
339
+ if hasattr(torch, 'set_float32_matmul_precision'):
340
+ torch.set_float32_matmul_precision('high' if tf32 else 'highest')
341
+ print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
342
+ print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
343
+ print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
344
+
345
+ def __str__(self):
346
+ s = []
347
+ for k in self.class_variables.keys():
348
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
349
+ s.append(f' {k:20s}: {getattr(self, k)}')
350
+ s = '\n'.join(s)
351
+ return f'{{\n{s}\n}}\n'
352
+
353
+
354
+ def init_dist_and_get_args():
355
+ for i in range(len(sys.argv)):
356
+ if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
357
+ del sys.argv[i]
358
+ break
359
+ args = Args(explicit_bool=True).parse_args(known_only=True)
360
+ args.chunk_nodes = int(os.environ.get('CK', '') or '0')
361
+
362
+ if len(args.extra_args) > 0 and args.is_master_node == 0:
363
+ print(f'======================================================================================')
364
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
365
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
366
+ print(f'======================================================================================\n\n')
367
+
368
+ args.set_tf32(args.tf32)
369
+ if args.dbg:
370
+ torch.autograd.set_detect_anomaly(True)
371
+
372
+ try: os.makedirs(args.bed, exist_ok=True)
373
+ except: pass
374
+ try: os.makedirs(args.local_out_path, exist_ok=True)
375
+ except: pass
376
+
377
+ day3 = 60*24*3
378
+ dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
379
+
380
+ args.tlen = max(args.tlen, args.nodata_tlen)
381
+ if args.zero and args.tema != 0:
382
+ args.tema = 0
383
+ print(f'======================================================================================')
384
+ print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
385
+ print(f'======================================================================================\n\n')
386
+
387
+ if args.nodata:
388
+ args.nova = True
389
+
390
+ if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
391
+
392
+ if args.alng < 0:
393
+ args.alng = args.aln
394
+
395
+ args.device = dist.get_device()
396
+ args.r_accu = 1 / args.ac # gradient accumulation
397
+ args.data_load_reso = None
398
+ args.rand |= args.seed is None
399
+ args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
400
+ if args.wp == 0:
401
+ args.wp = args.ep * 1/100
402
+
403
+ di = {
404
+ 'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
405
+ 'at': 'auto', 'auto': 'auto',
406
+ 'v': 'vae',
407
+ 'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu'
408
+ }
409
+
410
+ args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
411
+ args.dada = args.dada or args.ada
412
+ args.opt = args.opt.lower().strip()
413
+
414
+ if args.lbs:
415
+ bs_per_gpu = args.lbs / args.ac
416
+ else:
417
+ bs_per_gpu = args.bs / args.ac / dist.get_world_size()
418
+ bs_per_gpu = round(bs_per_gpu)
419
+ args.batch_size = bs_per_gpu
420
+ args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
421
+ args.workers = min(args.workers, bs_per_gpu)
422
+ args.dblr = args.dblr or args.gblr
423
+ args.glr = args.ac * args.gblr * args.glb_batch_size / 256
424
+ args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
425
+ args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
426
+ args.gwde = args.gwde or args.gwd
427
+ args.dwde = args.dwde or args.dwd
428
+ args.twde = args.twde or args.twd
429
+
430
+ if args.dbg_modified:
431
+ torch.autograd.set_detect_anomaly(True)
432
+ args.dbg_ks &= dist.is_local_master()
433
+ if args.dbg_ks:
434
+ args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')
435
+
436
+ # gpt args
437
+ if args.gpt_training:
438
+ assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
439
+ from infinity.models import alias_dict, alias_dict_inv
440
+ if args.model in alias_dict:
441
+ args.model = alias_dict[args.model]
442
+ args.model_alias = alias_dict_inv[args.model]
443
+ else:
444
+ args.model_alias = args.model
445
+ args.model = f'infinity_{args.model}'
446
+
447
+ args.task_id = '123'
448
+ args.trial_id = '123'
449
+ args.robust_run_id = '0'
450
+ args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
451
+
452
+ ls = []
453
+ if 'AUTO_RESUME' in os.environ:
454
+ ls.append(int(os.environ['AUTO_RESUME']))
455
+ ls = sorted(ls, reverse=True)
456
+ ls = [str(i) for i in ls]
457
+ args.ckpt_trials = ls
458
+ args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
459
+
460
+ args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
461
+ args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
462
+ assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
463
+ f"only support no-checkpointing or full-block/full-attn checkpointing, but got {args.enable_checkpointing}."
464
+
465
+ if len(args.exp_name) == 0:
466
+ args.exp_name = os.path.basename(args.bed) or 'test_exp'
467
+
468
+ if '-' in args.exp_name:
469
+ args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
470
+ else:
471
+ args.tag = 'UK'
472
+
473
+ if dist.is_master():
474
+ os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')
475
+
476
+ if args.sdpa_mem:
477
+ from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
478
+ enable_flash_sdp(True)
479
+ enable_mem_efficient_sdp(True)
480
+ enable_math_sdp(False)
481
+
482
+ return args
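A minimal sketch of the linear LR-scaling arithmetic applied above: glr/dlr/tlr are the base rates scaled by global batch size over 256 (all numbers below are hypothetical):

ac = 1                                    # gradient accumulation steps
gblr = 1e-4                               # base generator lr, defined per 256 samples
world_size, batch_size = 8, 32            # hypothetical GPU count and per-GPU batch
glb_batch_size = batch_size * world_size  # 256
glr = ac * gblr * glb_batch_size / 256    # -> 1e-4: the base lr is recovered at 256 samples
print(f'effective generator lr: {glr:g}')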
infinity/utils/csv_util.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+ import os.path as osp
3
+ import csv
4
+
5
+ import numpy as np
6
+
7
+
8
+ def write_dicts2csv_file(input_dict_list, csv_filename):
9
+ os.makedirs(osp.dirname(csv_filename), exist_ok=True)
10
+ with open(csv_filename, mode='w', newline='', encoding='utf-8') as file:
11
+ fieldnames = input_dict_list[0].keys()
12
+ writer = csv.DictWriter(file, fieldnames=fieldnames)
13
+ writer.writeheader()
14
+ writer.writerows(input_dict_list)
15
+ print(f'"{csv_filename}" has been written.')
16
+
17
+ def load_csv_as_dicts(csv_filename):
18
+ with open(csv_filename, mode='r', newline='', encoding='utf-8') as csvfile:
19
+ reader = csv.DictReader(csvfile)
20
+ return list(reader)
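A minimal round-trip sketch for the two helpers above (hypothetical path; note write_dicts2csv_file derives the output directory from the filename, so the path must contain one):

from infinity.utils.csv_util import write_dicts2csv_file, load_csv_as_dicts

rows = [{'prompt': 'a cat', 'seed': 1}, {'prompt': 'a dog', 'seed': 2}]
write_dicts2csv_file(rows, 'tmp/results/metrics.csv')
loaded = load_csv_as_dicts('tmp/results/metrics.csv')
assert loaded[0]['seed'] == '1'  # DictReader returns every value as a string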
infinity/utils/dist.py ADDED
@@ -0,0 +1,326 @@
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import pytz
9
+ import torch
10
+ import torch.distributed as tdist
11
+ import torch.multiprocessing as mp
12
+
13
+
14
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
15
+ __rank_str_zfill = '0'
16
+ __initialized = False
17
+
18
+
19
+ def initialized():
20
+ return __initialized
21
+
22
+
23
+ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distributed=0, timeout_minutes=30):
24
+ global __device
25
+ if not torch.cuda.is_available():
26
+ print(f'[dist initialize] cuda is not available, using cpu instead', file=sys.stderr)
27
+ return
28
+ elif 'RANK' not in os.environ:
29
+ torch.cuda.set_device(gpu_id_if_not_distributed)
30
+ __device = torch.empty(1).cuda().device
31
+ print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
32
+ return
33
+ # then 'RANK' must exist
34
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
35
+ local_rank = global_rank % num_gpus
36
+ torch.cuda.set_device(local_rank)
37
+
38
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
39
+ """
40
+ if mp.get_start_method(allow_none=True) is None:
41
+ method = 'fork' if fork else 'spawn'
42
+ print(f'[dist initialize] mp method={method}')
43
+ mp.set_start_method(method)
44
+ """
45
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
46
+
47
+ global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
48
+ __local_rank = local_rank
49
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
50
+ __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
51
+ __device = torch.device(local_rank)
52
+ __initialized = True
53
+
54
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
55
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
56
+
57
+
58
+ def get_rank():
59
+ return __rank
60
+
61
+
62
+ def get_rank_given_group(group: tdist.ProcessGroup):
63
+ return tdist.get_rank(group=group)
64
+
65
+
66
+ def get_rank_str_zfill():
67
+ return __rank_str_zfill
68
+
69
+
70
+ def get_local_rank():
71
+ return __local_rank
72
+
73
+
74
+ def get_world_size():
75
+ return __world_size
76
+
77
+
78
+ def get_device():
79
+ return __device
80
+
81
+
82
+ def set_gpu_id(gpu_id: int):
83
+ if gpu_id is None: return
84
+ global __device
85
+ if isinstance(gpu_id, (str, int)):
86
+ torch.cuda.set_device(int(gpu_id))
87
+ __device = torch.empty(1).cuda().device
88
+ else:
89
+ raise NotImplementedError
90
+
91
+
92
+ def is_master():
93
+ return __rank == 0
94
+
95
+
96
+ def is_local_master():
97
+ return __local_rank == 0
98
+
99
+
100
+ def is_visualizer():
101
+ return __rank == 0
102
+ # return __rank == max(__world_size - 8, 0)
103
+
104
+
105
+ def parallelize(net, syncbn=False):
106
+ if syncbn:
107
+ net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
108
+ net = net.cuda()
109
+ net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
110
+ return net
111
+
112
+
113
+ def new_group(ranks: List[int]):
114
+ if __initialized:
115
+ return tdist.new_group(ranks=ranks)
116
+ return None
117
+
118
+
119
+ def new_local_machine_group():
120
+ if __initialized:
121
+ cur_subgroup, subgroups = tdist.new_subgroups()
122
+ return cur_subgroup
123
+ return None
124
+
125
+
126
+ def barrier():
127
+ if __initialized:
128
+ tdist.barrier()
129
+
130
+
131
+ def allreduce(t: torch.Tensor, async_op=False):
132
+ if __initialized:
133
+ if not t.is_cuda:
134
+ cu = t.detach().cuda()
135
+ ret = tdist.all_reduce(cu, async_op=async_op)
136
+ t.copy_(cu.cpu())
137
+ else:
138
+ ret = tdist.all_reduce(t, async_op=async_op)
139
+ return ret
140
+ return None
141
+
142
+
143
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
144
+ if __initialized:
145
+ if not t.is_cuda:
146
+ t = t.cuda()
147
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
148
+ tdist.all_gather(ls, t)
149
+ else:
150
+ ls = [t]
151
+ if cat:
152
+ ls = torch.cat(ls, dim=0)
153
+ return ls
154
+
155
+
156
+ def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
157
+ if __initialized:
158
+ if not t.is_cuda:
159
+ t = t.cuda()
160
+
161
+ t_size = torch.tensor(t.size(), device=t.device)
162
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
163
+ tdist.all_gather(ls_size, t_size)
164
+
165
+ max_B = max(size[0].item() for size in ls_size)
166
+ pad = max_B - t_size[0].item()
167
+ if pad:
168
+ pad_size = (pad, *t.size()[1:])
169
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
170
+
171
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
172
+ tdist.all_gather(ls_padded, t)
173
+ ls = []
174
+ for t, size in zip(ls_padded, ls_size):
175
+ ls.append(t[:size[0].item()])
176
+ else:
177
+ ls = [t]
178
+ if cat:
179
+ ls = torch.cat(ls, dim=0)
180
+ return ls
181
+
182
+
183
+ def broadcast(t: torch.Tensor, src_rank) -> None:
184
+ if __initialized:
185
+ if not t.is_cuda:
186
+ cu = t.detach().cuda()
187
+ tdist.broadcast(cu, src=src_rank)
188
+ t.copy_(cu.cpu())
189
+ else:
190
+ tdist.broadcast(t, src=src_rank)
191
+
192
+
193
+ def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
194
+ if not initialized():
195
+ return torch.tensor([val]) if fmt is None else [fmt % val]
196
+
197
+ ts = torch.zeros(__world_size)
198
+ ts[__rank] = val
199
+ allreduce(ts)
200
+ if fmt is None:
201
+ return ts
202
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
203
+
204
+
205
+ def master_only(func):
206
+ @functools.wraps(func)
207
+ def wrapper(*args, **kwargs):
208
+ force = kwargs.pop('force', False)
209
+ if force or is_master():
210
+ ret = func(*args, **kwargs)
211
+ else:
212
+ ret = None
213
+ barrier()
214
+ return ret
215
+ return wrapper
216
+
217
+
218
+ def local_master_only(func):
219
+ @functools.wraps(func)
220
+ def wrapper(*args, **kwargs):
221
+ force = kwargs.pop('force', False)
222
+ if force or is_local_master():
223
+ ret = func(*args, **kwargs)
224
+ else:
225
+ ret = None
226
+ barrier()
227
+ return ret
228
+ return wrapper
229
+
230
+
231
+ def for_visualize(func):
232
+ @functools.wraps(func)
233
+ def wrapper(*args, **kwargs):
234
+ if is_visualizer():
235
+ # with torch.no_grad():
236
+ ret = func(*args, **kwargs)
237
+ else:
238
+ ret = None
239
+ return ret
240
+ return wrapper
241
+
242
+
243
+ def finalize():
244
+ if __initialized:
245
+ tdist.destroy_process_group()
246
+
247
+
248
+ def init_distributed_mode(local_out_path, fork=False, only_sync_master=False, timeout_minutes=30):
249
+ try:
250
+ __initialize(fork=fork, timeout_minutes=timeout_minutes)
251
+ barrier()
252
+ except RuntimeError as e:
253
+ print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
254
+ raise e
255
+
256
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
257
+ _change_builtin_print(is_local_master())
258
+ # if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
259
+ # sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
260
+
261
+
262
+ def _change_builtin_print(is_master):
263
+ import builtins as __builtin__
264
+
265
+ builtin_print = __builtin__.print
266
+ if type(builtin_print) != type(open):
267
+ return
268
+
269
+ def prt(*args, **kwargs):
270
+ force = kwargs.pop('force', False)
271
+ clean = kwargs.pop('clean', False)
272
+ deeper = kwargs.pop('deeper', False)
273
+ if is_master or force:
274
+ if not clean:
275
+ f_back = sys._getframe().f_back
276
+ if deeper and f_back.f_back is not None:
277
+ f_back = f_back.f_back
278
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
279
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
280
+ builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
281
+ else:
282
+ builtin_print(*args, **kwargs)
283
+
284
+ __builtin__.print = prt
285
+
286
+
287
+ class BackupStreamToFile(object):
288
+ def __init__(self, local_output_dir, for_stdout=True):
289
+ self.for_stdout = for_stdout
290
+ self.terminal_stream = sys.stdout if for_stdout else sys.stderr
291
+ fname = os.path.join(local_output_dir, 'b1_stdout.txt' if for_stdout else 'b2_stderr.txt')
292
+ existing = os.path.exists(fname)
293
+ self.file_stream = open(fname, 'a')
294
+ if existing:
295
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
296
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
297
+ self.file_stream.flush()
298
+ os.system(f'ln -s {fname} /opt/tiger/run_trial/ >/dev/null 2>&1')
299
+ self.enabled = True
300
+
301
+ def write(self, message):
302
+ self.terminal_stream.write(message)
303
+ self.file_stream.write(message)
304
+
305
+ def flush(self):
306
+ self.terminal_stream.flush()
307
+ self.file_stream.flush()
308
+
309
+ def isatty(self):
310
+ return True
311
+
312
+ def close(self):
313
+ if not self.enabled:
314
+ return
315
+ self.enabled = False
316
+ self.file_stream.flush()
317
+ self.file_stream.close()
318
+ if self.for_stdout:
319
+ sys.stdout = self.terminal_stream
320
+ sys.stdout.flush()
321
+ else:
322
+ sys.stderr = self.terminal_stream
323
+ sys.stderr.flush()
324
+
325
+ def __del__(self):
326
+ self.close()
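A minimal sketch of how these helpers degrade gracefully in a single-process run, i.e. when init_distributed_mode is never called (hypothetical usage):

import torch
import infinity.utils.dist as dist

@dist.master_only
def save_report(msg):  # executes only on rank 0; other ranks wait at the barrier and get None
    print(f'[report] {msg}')

save_report('hello')            # world size is 1, so this rank is master and just runs
t = torch.ones(4)
dist.allreduce(t)               # returns None and leaves t untouched when uninitialized
print(dist.dist_fmt_vals(0.5))  # -> ['0.50'] in the single-process case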
infinity/utils/dynamic_resolution.py ADDED
@@ -0,0 +1,73 @@
1
+ import json
2
+ import numpy as np
3
+ import tqdm
4
+
5
+ vae_stride = 16
6
+ ratio2hws = {
7
+ 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
8
+ 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
9
+ 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
10
+ 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
11
+ 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
12
+ 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
13
+ 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
14
+ 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
15
+ }
16
+ predefined_t = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 21]
17
+
18
+ full_ratio2hws = {}
19
+ for ratio, hws in ratio2hws.items():
20
+ full_ratio2hws[ratio] = hws
21
+ if ratio != 1.000:
22
+ full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
23
+
24
+ dynamic_resolution_h_w = {}
25
+ for ratio in full_ratio2hws:
26
+ dynamic_resolution_h_w[ratio] = {}
27
+ for ind, leng in enumerate([7, 10, 12, 13]):
28
+ h_div_w = full_ratio2hws[ratio][leng-1][0] / full_ratio2hws[ratio][leng-1][1]
29
+ assert np.abs(h_div_w-ratio) < 0.01, f'{full_ratio2hws[ratio][leng-1]}: {h_div_w} != {ratio}'
30
+ pixel = (full_ratio2hws[ratio][leng-1][0] * vae_stride, full_ratio2hws[ratio][leng-1][1] * vae_stride)
31
+ if ind == 0:
32
+ total_pixels = '0.06M'
33
+ elif ind == 1:
34
+ total_pixels = '0.25M'
35
+ elif ind == 2:
36
+ total_pixels = '0.60M'
37
+ else:
38
+ total_pixels = '1M'
39
+
40
+ scales = full_ratio2hws[ratio][:leng]
41
+ scales = [ (t, h, w) for t, (h, w) in zip(predefined_t, scales) ]
42
+ dynamic_resolution_h_w[ratio][total_pixels] = {
43
+ 'pixel': pixel,
44
+ 'scales': scales
45
+ }
46
+
47
+ h_div_w_templates = []
48
+ for h_div_w in dynamic_resolution_h_w.keys():
49
+ h_div_w_templates.append(h_div_w)
50
+ h_div_w_templates = np.array(h_div_w_templates)
51
+
52
+ def get_h_div_w_template2indices(h_div_w_list, h_div_w_templates):
53
+ indices = list(range(len(h_div_w_list)))
54
+ h_div_w_template2indices = {}
55
+ pbar = tqdm.tqdm(total=len(indices), desc='get_h_div_w_template2indices...')
56
+ for h_div_w, index in zip(h_div_w_list, indices):
57
+ pbar.update(1)
58
+ nearest_h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))]
59
+ if nearest_h_div_w_template_ not in h_div_w_template2indices:
60
+ h_div_w_template2indices[nearest_h_div_w_template_] = []
61
+ h_div_w_template2indices[nearest_h_div_w_template_].append(index)
62
+ for h_div_w_template_, sub_indices in h_div_w_template2indices.items():
63
+ h_div_w_template2indices[h_div_w_template_] = np.array(sub_indices)
64
+ return h_div_w_template2indices
65
+
66
+ if __name__ == '__main__':
67
+ for h_div_w_template in dynamic_resolution_h_w:
68
+ for total_pixels in dynamic_resolution_h_w[h_div_w_template]:
69
+ scales = np.array(dynamic_resolution_h_w[h_div_w_template][total_pixels]['scales'])
70
+ seq_len = np.sum(scales[:,0]*scales[:,1])
71
+ if total_pixels == '1M':
72
+ string = f'{h_div_w_template}, {total_pixels}, {dynamic_resolution_h_w[h_div_w_template][total_pixels]}, seq_len: {seq_len}'.replace(', ', ',')
73
+ print(string)
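A minimal sketch of querying the tables above: snap an image's h/w to the nearest template ratio, then read off its pixel budget and scale schedule (hypothetical image size):

import numpy as np
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates

h, w = 720, 1280
ratio = h_div_w_templates[np.argmin(np.abs(h / w - h_div_w_templates))]
entry = dynamic_resolution_h_w[ratio]['1M']
seq_len = sum(hh * ww for _, hh, ww in entry['scales'])  # scales are (t, h, w) triples
print(ratio, entry['pixel'], seq_len)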
infinity/utils/large_file_util.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ import os.path as osp
3
+ import time
4
+ import itertools
5
+ import shutil
6
+ import glob
7
+ import argparse
8
+
9
+ import tqdm
10
+ import numpy as np
11
+ import threading
12
+
13
+ def save_lines(lines, filename):
14
+ os.makedirs(osp.dirname(filename), exist_ok=True)
15
+ with open(filename, 'w') as f:
16
+ f.writelines(lines)
17
+ del lines
18
+
19
+ def get_part_jsonls(filepath, total_line_number, parts=512):
20
+ dirname, filename, ext = osp.dirname(filepath), osp.splitext(osp.basename(filepath))[0], osp.splitext(osp.basename(filepath))[1]
21
+ if parts == 1:
22
+ return False, {1: filepath}
23
+ save_dir = osp.join(dirname, f'{parts:04d}_parts')
24
+ chunk_id2save_files = {}
25
+ missing = False
26
+ chunk_size = int(total_line_number/parts)
27
+ for chunk_id in range(1, parts+1):
28
+ if chunk_id == parts:
29
+ num_of_lines = total_line_number - chunk_size * (parts-1)
30
+ else:
31
+ num_of_lines = chunk_size
32
+ chunk_id2save_files[chunk_id] = osp.join(save_dir, f'{filename}_{chunk_id:04d}_{parts:04d}_{num_of_lines:09d}{ext}')
33
+ if not osp.exists(chunk_id2save_files[chunk_id]):
34
+ missing = True
35
+ return missing, chunk_id2save_files
36
+
37
+ def split_large_txt_files(filepath, chunk_id2save_files):
38
+ thread_list = []
39
+ chunk_id = 1
40
+ with open(filepath, 'r') as f:
41
+ chunk = []
42
+ pbar = tqdm.tqdm(total=len(chunk_id2save_files))
43
+ for line in f:
44
+ chunk.append(line)
45
+ cur_chunk_size = int(osp.splitext(osp.basename(chunk_id2save_files[chunk_id]))[0].split('_')[-1])
46
+ if len(chunk) >= cur_chunk_size:
47
+ pbar.update(1)
48
+ thread_list.append(threading.Thread(target=save_lines, args=(chunk, chunk_id2save_files[chunk_id])))
49
+ thread_list[-1].start()
50
+ chunk = []
51
+ chunk_id += 1
52
+ if len(chunk):
53
+ import ipdb; ipdb.set_trace()
54
+ assert not len(chunk)
55
+ for thread in thread_list:
56
+ thread.join()
57
+
58
+ if __name__ == '__main__':
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument('--jsonl_folder', type=str, default='')
61
+ parser.add_argument('--parts', type=int, default=600)
62
+ args = parser.parse_args()
63
+ for jsonl_filepath in sorted(glob.glob(osp.join(args.jsonl_folder, '*.jsonl'))):
64
+ print(jsonl_filepath)
65
+ t1 = time.time()
66
+ line_num = int(jsonl_filepath.split('_')[-1].split('.')[0])
67
+ missing, chunk_id2save_files = get_part_jsonls(jsonl_filepath, line_num, parts=args.parts)
68
+ split_large_txt_files(jsonl_filepath, chunk_id2save_files)
69
+ t2 = time.time()
70
+ print(f'split takes {t2-t1}s')
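A sketch of the chunk-naming convention produced by get_part_jsonls (hypothetical file of 1,000 lines split into 4 parts; the per-chunk line count is encoded in each filename, which split_large_txt_files later relies on):

from infinity.utils.large_file_util import get_part_jsonls

missing, chunk_id2save_files = get_part_jsonls('data/anno.jsonl', 1000, parts=4)
print(chunk_id2save_files[1])
# -> data/0004_parts/anno_0001_0004_000000250.jsonl (250 lines per chunk; the last chunk absorbs any remainder)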
infinity/utils/load.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/python3
2
+ import gc
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import sys
7
+ from copy import deepcopy
8
+ from typing import Tuple, Union
9
+
10
+ import colorama
11
+ import torch
12
+ import yaml
13
+
14
+ import infinity.utils.dist as dist
15
+
16
+ from infinity.models import Infinity
17
+ from infinity.models.ema import get_ema_model
18
+ from infinity.utils import arg_util, misc
19
+ from infinity.utils.misc import os_system
20
+
21
+
22
+ def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
23
+ if args.vae_type in [8,14,16,18,20,24,32,64,128]:
24
+ from infinity.models.bsq_vae.vae import vae_model
25
+ schedule_mode = "dynamic"
26
+ codebook_dim = args.vae_type # 18
27
+ codebook_size = 2**codebook_dim
28
+ if args.apply_spatial_patchify:
29
+ patch_size = 8
30
+ encoder_ch_mult=[1, 2, 4, 4]
31
+ decoder_ch_mult=[1, 2, 4, 4]
32
+ else:
33
+ patch_size = 16
34
+ encoder_ch_mult=[1, 2, 4, 4, 4]
35
+ decoder_ch_mult=[1, 2, 4, 4, 4]
36
+ vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
37
+ encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
38
+ if args.fake_vae_input:
39
+ vae_local.encoder = None
40
+ vae_local.decoder = None
41
+ torch.cuda.empty_cache()
42
+ else:
43
+ raise ValueError(f"vae_type {args.vae_type} not supported")
44
+ if force_flash: args.flash = True
45
+ gpt_kw = dict(
46
+ pretrained=False, global_pool='',
47
+ text_channels=args.Ct5, text_maxlen=args.tlen,
48
+ norm_eps=args.norm_eps, rms_norm=args.rms,
49
+ shared_aln=args.saln, head_aln=args.haln,
50
+ cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
51
+ cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
52
+ raw_scale_schedule=args.scale_schedule,
53
+ head_depth=args.dec,
54
+ top_p=args.tp, top_k=args.tk,
55
+ customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
56
+ checkpointing=args.enable_checkpointing,
57
+ pad_to_multiplier=args.pad_to_multiplier,
58
+ use_flex_attn=args.use_flex_attn,
59
+ batch_size=args.batch_size,
60
+ add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
61
+ use_bit_label=args.use_bit_label,
62
+ rope2d_each_sa_layer=args.rope2d_each_sa_layer,
63
+ rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
64
+ pn=args.pn,
65
+ train_h_div_w_list=args.train_h_div_w_list,
66
+ always_training_scales=args.always_training_scales,
67
+ apply_spatial_patchify=args.apply_spatial_patchify,
68
+ )
69
+ if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
70
+ if args.hd > 0: gpt_kw['num_heads'] = args.hd
71
+
72
+ print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
73
+ gpt_kw['vae_local'] = vae_local
74
+
75
+ model_str = args.model.replace('vgpt', 'infinity') # legacy
76
+ print(f"{model_str=}")
77
+ if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
78
+ model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
79
+ block_chunks = int(block_chunks)
80
+ else:
81
+ block_chunks = 1
82
+ gpt_kw['block_chunks'] = block_chunks
83
+
84
+ from infinity.models import Infinity
85
+ from timm.models import create_model
86
+ gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
87
+ if args.use_fsdp_model_ema:
88
+ gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
89
+ else:
90
+ gpt_wo_ddp_ema = None
91
+ gpt_wo_ddp = gpt_wo_ddp.to(device)
92
+
93
+ assert all(not p.requires_grad for p in vae_local.parameters())
94
+ assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
95
+
96
+ return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
97
+
98
+
99
+ if __name__ == '__main__':
100
+ ld(sys.argv[1])
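A minimal sketch of the model-name parsing in build_vae_gpt: a trailing 'c<N>' on the model string selects the number of block chunks (model names below are hypothetical):

def parse_block_chunks(model_str: str):
    # mirrors build_vae_gpt: 'infinity_2bc8' -> ('infinity_2b', 8); no suffix -> 1 chunk
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        stem, chunks = model_str.rsplit('c', maxsplit=1)
        return stem, int(chunks)
    return model_str, 1

print(parse_block_chunks('infinity_2bc8'))  # ('infinity_2b', 8)
print(parse_block_chunks('infinity_2b'))    # ('infinity_2b', 1)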
infinity/utils/lr_control.py ADDED
@@ -0,0 +1,148 @@
1
+ import math
2
+ from pprint import pformat
3
+ from typing import Tuple, List, Dict, Union
4
+
5
+ import torch.nn
6
+ import infinity.utils.dist as dist
7
+
8
+
9
+ def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
10
+ """Decay the learning rate with half-cycle cosine after warmup"""
11
+ wp_it = round(wp_it)
12
+
13
+ if cur_it < wp_it:
14
+ cur_lr = wp0 + (1-wp0) * cur_it / wp_it
15
+ else:
16
+ pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
17
+ rest = 1 - pasd # [1, 0]
18
+ if sche_type == 'cos':
19
+ cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
20
+ elif sche_type == 'lin':
21
+ T = 0.15; max_rest = 1-T
22
+ if pasd < T: cur_lr = 1
23
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
24
+ elif sche_type == 'lin0':
25
+ T = 0.05; max_rest = 1-T
26
+ if pasd < T: cur_lr = 1
27
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest
28
+ elif sche_type == 'lin00':
29
+ cur_lr = wpe + (1-wpe) * rest
30
+ elif sche_type.startswith('lin'):
31
+ T = float(sche_type[3:]); max_rest = 1-T
32
+ wpe_mid = wpe + (1-wpe) * max_rest
33
+ wpe_mid = (1 + wpe_mid) / 2
34
+ if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
35
+ else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
36
+ elif sche_type == 'exp':
37
+ T = 0.15; max_rest = 1-T
38
+ if pasd < T: cur_lr = 1
39
+ else:
40
+ expo = (pasd-T) / max_rest * math.log(wpe)
41
+ cur_lr = math.exp(expo)
42
+ else:
43
+ raise NotImplementedError(f'unknown sche_type {sche_type}')
44
+
45
+ cur_lr *= peak_lr
46
+ pasd = cur_it / (max_it-1)
47
+ cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
48
+
49
+ inf = 1e6
50
+ min_lr, max_lr = inf, -1
51
+ min_wd, max_wd = inf, -1
52
+ for param_group in optimizer.param_groups:
53
+ param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
54
+ max_lr = max(max_lr, param_group['lr'])
55
+ min_lr = min(min_lr, param_group['lr'])
56
+
57
+ param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
58
+ max_wd = max(max_wd, param_group['weight_decay'])
59
+ if param_group['weight_decay'] > 0:
60
+ min_wd = min(min_wd, param_group['weight_decay'])
61
+
62
+ if min_lr == inf: min_lr = -1
63
+ if min_wd == inf: min_wd = -1
64
+ return min_lr, max_lr, min_wd, max_wd
65
+
66
+
67
+ def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
68
+ List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
69
+ ]:
70
+ with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
71
+ print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')
72
+ para_groups, para_groups_dbg = {}, {}
73
+ names, paras = [], []
74
+ names_no_grad = []
75
+ count, numel = 0, 0
76
+ for name, para in model.named_parameters():
77
+ name = name.replace('_fsdp_wrapped_module.', '')
78
+ if not para.requires_grad:
79
+ names_no_grad.append(name)
80
+ continue # frozen weights
81
+ count += 1
82
+ numel += para.numel()
83
+ names.append(name)
84
+ paras.append(para)
85
+
86
+ if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
87
+ cur_wd_sc, group_name = 0., 'ND'
88
+ # elif any(k in name for k in small_wd_keys):
89
+ # cur_wd_sc, group_name = small_wd, 'small_decay'
90
+ else:
91
+ cur_wd_sc, group_name = 1., 'D'
92
+
93
+ if with_lr_scale:
94
+ layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
95
+ group_name = f'layer{layer_id}_' + group_name
96
+ cur_lr_sc = lr_scale ** scale_exp
97
+ dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
98
+ else:
99
+ cur_lr_sc = 1.
100
+ dbg = f'[no scale]'
101
+
102
+ if group_name not in para_groups:
103
+ para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
104
+ para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
105
+ para_groups[group_name]['params'].append(para)
106
+ para_groups_dbg[group_name]['params'].append(name)
107
+
108
+ for g in para_groups_dbg.values():
109
+ g['params'] = pformat(', '.join(g['params']), width=200)
110
+
111
+ print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
112
+
113
+ for rk in range(dist.get_world_size()):
114
+ dist.barrier()
115
+ if dist.get_rank() == rk:
116
+ print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
117
+ print('')
118
+
119
+ assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
120
+ del ndim_dict
121
+ return names, paras, list(para_groups.values())
122
+
123
+
124
+ def plot():
125
+ import matplotlib.pyplot as plt
126
+ import torch.nn as nn
127
+ from torch.optim import SGD
128
+ # for sche in ('lin', 'lin0', 'lin00', 'lin0.5', 'lin0.75'):
129
+ for sche in ('lin0', ):
130
+ op = SGD(nn.Linear(3, 4).parameters(), lr=1e-3)
131
+ it, lr = [], []
132
+ iters = 500
133
+ wp_it, max_it = 1 * iters, 10 * iters
134
+ for cur_it in range(max_it):
135
+ it.append(cur_it)
136
+ lr.append(lr_wd_annealing(sche, op, 0.1, 1e-5, 1e-5, cur_it, wp_it, max_it, wpe=0.3)[0])
137
+
138
+ plt.figure()
139
+ plt.title(sche)
140
+ plt.plot(it, lr, 'b', label=sche)
141
+ plt.xlabel('it'), plt.ylabel('lr')
142
+ plt.legend()
143
+
144
+ plt.savefig('lr.jpg')
145
+
146
+
147
+ if __name__ == '__main__':
148
+ plot()
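A minimal training-loop sketch for lr_wd_annealing (hypothetical hyperparameters; the function writes the current lr and weight decay into every param group, honoring per-group 'lr_sc'/'wd_sc' scales, and returns their min/max):

import torch.nn as nn
from torch.optim import AdamW
from infinity.utils.lr_control import lr_wd_annealing

opt = AdamW(nn.Linear(8, 8).parameters(), lr=1e-4, weight_decay=0.05)
wp_it, max_it = 500, 10_000
for cur_it in range(max_it):
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        'lin0', opt, 1e-4, 0.05, 0.0, cur_it, wp_it, max_it)
    # ... forward / backward / opt.step() / opt.zero_grad() go here ...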
infinity/utils/misc.py ADDED
@@ -0,0 +1,397 @@
1
+ import datetime
2
+ import functools
3
+ import math
4
+ import os
5
+ import random
6
+ import subprocess
7
+ import sys
8
+ import threading
9
+ import time
10
+ from collections import defaultdict, deque
11
+ from typing import Iterator, List, Tuple
12
+
13
+ import numpy as np
14
+ import pytz
15
+ import torch
16
+ import torch.distributed as tdist
17
+ import torch.nn.functional as F
18
+
19
+ import infinity.utils.dist as dist
20
+
21
+ os_system = functools.partial(subprocess.call, shell=True)
22
+ def echo(info):
23
+ os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
24
+ def os_system_get_stdout(cmd):
25
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
26
+ def os_system_get_stdout_stderr(cmd):
27
+ cnt = 0
28
+ while True:
29
+ try:
30
+ sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
31
+ except subprocess.TimeoutExpired:
32
+ cnt += 1
33
+ print(f'[fetch free_port file] timeout cnt={cnt}')
34
+ else:
35
+ return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
36
+
37
+
38
+ def is_pow2n(x):
39
+ return x > 0 and (x & (x - 1) == 0)
40
+
41
+
42
+ def time_str(fmt='[%m-%d %H:%M:%S]'):
43
+ return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
44
+
45
+
46
+ class DistLogger(object):
47
+ def __init__(self, lg):
48
+ self._lg = lg
49
+
50
+ @staticmethod
51
+ def do_nothing(*args, **kwargs):
52
+ pass
53
+
54
+ def __getattr__(self, attr: str):
55
+ return getattr(self._lg, attr) if self._lg is not None else DistLogger.do_nothing
56
+
57
+ class TensorboardLogger(object):
58
+ def __init__(self, log_dir, filename_suffix):
59
+ try: import tensorflow_io as tfio
60
+ except: pass
61
+ from torch.utils.tensorboard import SummaryWriter
62
+ self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
63
+ self.step = 0
64
+
65
+ def set_step(self, step=None):
66
+ if step is not None:
67
+ self.step = step
68
+ else:
69
+ self.step += 1
70
+
71
+ def loggable(self):
72
+ return self.step == 0 or (self.step + 1) % 500 == 0
73
+
74
+ def update(self, head='scalar', step=None, **kwargs):
75
+ if step is None:
76
+ step = self.step
77
+ if not self.loggable(): return
78
+ for k, v in kwargs.items():
79
+ if v is None: continue
80
+ if hasattr(v, 'item'): v = v.item()
81
+ self.writer.add_scalar(f'{head}/{k}', v, step)
82
+
83
+ def log_tensor_as_distri(self, tag, tensor1d, step=None):
84
+ if step is None:
85
+ step = self.step
86
+ if not self.loggable(): return
87
+ try:
88
+ self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
89
+ except Exception as e:
90
+ print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
91
+
92
+ def log_image(self, tag, img_chw, step=None):
93
+ if step is None:
94
+ step = self.step
95
+ if not self.loggable(): return
96
+ self.writer.add_image(tag, img_chw, step, dataformats='CHW')
97
+
98
+ def flush(self):
99
+ self.writer.flush()
100
+
101
+ def close(self):
102
+ self.writer.close()
103
+
104
+
105
+ class Low_GPU_usage(object):
106
+ def __init__(self, files, sleep_secs, verbose):
107
+ pass
108
+
109
+ def early_stop(self):
110
+ pass
111
+
112
+ def __enter__(self):
113
+ return self
114
+
115
+ def __exit__(self, exc_type, exc_val, exc_tb):
116
+ pass
117
+
118
+ class TouchingDaemonDontForgetToStartMe(threading.Thread):
119
+ def __init__(self, files: List[str], sleep_secs: int, verbose=False):
120
+ super().__init__(daemon=True)
121
+ self.files = tuple(files)
122
+ self.sleep_secs = sleep_secs
123
+ self.is_finished = False
124
+ self.verbose = verbose
125
+
126
+ f_back = sys._getframe().f_back
127
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
128
+ self.print_prefix = f' ({file_desc}, line{f_back.f_lineno:-4d}) @daemon@ '
129
+
130
+ def finishing(self):
131
+ self.is_finished = True
132
+
133
+ def run(self) -> None:
134
+ kw = {}
135
+ if tdist.is_initialized(): kw['clean'] = True
136
+
137
+ stt = time.time()
138
+ if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **kw)
139
+ while not self.is_finished:
140
+ for f in self.files:
141
+ if os.path.exists(f):
142
+ try:
143
+ os.utime(f)
144
+ fp = open(f, 'a')
145
+ fp.close()
146
+ except: pass
147
+ time.sleep(self.sleep_secs)
148
+
149
+ if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-stt:.1f} secs {self.files} per {self.sleep_secs}s. ', **kw)
150
+
151
+
152
+ class SmoothedValue(object):
153
+ """Track a series of values and provide access to smoothed values over a
154
+ window or the global series average.
155
+ """
156
+
157
+ def __init__(self, window_size=30, fmt=None):
158
+ if fmt is None:
159
+ fmt = "{median:.4f} ({global_avg:.4f})"
160
+ self.deque = deque(maxlen=window_size)
161
+ self.total = 0.0
162
+ self.count = 0
163
+ self.fmt = fmt
164
+
165
+ def update(self, value, n=1):
166
+ self.deque.append(value)
167
+ self.count += n
168
+ self.total += value * n
169
+
170
+ def synchronize_between_processes(self):
171
+ """
172
+ Warning: does not synchronize the deque!
173
+ """
174
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
175
+ tdist.barrier()
176
+ tdist.all_reduce(t)
177
+ t = t.tolist()
178
+ self.count = int(t[0])
179
+ self.total = t[1]
180
+
181
+ @property
182
+ def median(self):
183
+ return np.median(self.deque) if len(self.deque) else 0
184
+
185
+ @property
186
+ def avg(self):
187
+ return sum(self.deque) / (len(self.deque) or 1)
188
+
189
+ @property
190
+ def global_avg(self):
191
+ return self.total / (self.count or 1)
192
+
193
+ @property
194
+ def max(self):
195
+ return max(self.deque) if len(self.deque) else 0
196
+
197
+ @property
198
+ def value(self):
199
+ return self.deque[-1] if len(self.deque) else 0
200
+
201
+ def time_preds(self, counts) -> Tuple[float, str, str]:
202
+ remain_secs = counts * self.median
203
+ return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
204
+
205
+ def __str__(self):
206
+ return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
207
+
208
+
209
+ class MetricLogger(object):
210
+ def __init__(self):
211
+ self.meters = defaultdict(SmoothedValue)
212
+ self.iter_end_t = time.time()
213
+ self.log_iters = set()
214
+ self.log_every_iter = False
215
+
216
+ def update(self, **kwargs):
217
+ # if it != 0 and it not in self.log_iters: return
218
+ for k, v in kwargs.items():
219
+ if v is None: continue
220
+ if hasattr(v, 'item'): v = v.item()
221
+ # assert isinstance(v, (float, int)), type(v)
222
+ self.meters[k].update(v)
223
+
224
+ def __getattr__(self, attr):
225
+ if attr in self.meters:
226
+ return self.meters[attr]
227
+ if attr in self.__dict__:
228
+ return self.__dict__[attr]
229
+ raise AttributeError("'{}' object has no attribute '{}'".format(
230
+ type(self).__name__, attr))
231
+
232
+ def __str__(self):
233
+ loss_str = []
234
+ for name, meter in self.meters.items():
235
+ if len(meter.deque):
236
+ loss_str.append(
237
+ "{}: {}".format(name, str(meter))
238
+ )
239
+ return ' '.join(loss_str)
240
+
241
+ def synchronize_between_processes(self):
242
+ for meter in self.meters.values():
243
+ meter.synchronize_between_processes()
244
+
245
+ def add_meter(self, name, meter):
246
+ self.meters[name] = meter
247
+
248
+ def log_every(self, start_it, max_iters, itrt, log_freq, log_every_iter=False, header=''): # also solve logging & skipping iterations before start_it
249
+ start_it = start_it % max_iters
250
+ self.log_iters = set(range(start_it, max_iters, log_freq))
251
+ self.log_iters.add(start_it)
252
+ self.log_iters.add(max_iters-1)
253
+ self.log_iters.add(max_iters)
254
+ self.log_every_iter = log_every_iter
255
+ self.iter_end_t = time.time()
256
+ self.iter_time = SmoothedValue(fmt='{value:.4f}')
257
+ self.data_time = SmoothedValue(fmt='{value:.3f}')
258
+ header_fmt = header + ': [{0:' + str(len(str(max_iters))) + 'd}/{1}]'
259
+
260
+ start_time = time.time()
261
+ if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
262
+ for it in range(start_it, max_iters):
263
+ obj = next(itrt)
264
+ if it < start_it: continue
265
+ self.data_time.update(time.time() - self.iter_end_t)
266
+ yield it, obj
267
+ self.iter_time.update(time.time() - self.iter_end_t)
268
+ if self.log_every_iter or it in self.log_iters:
269
+ eta_seconds = self.iter_time.avg * (max_iters - it)
270
+ print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
271
+ self.iter_end_t = time.time()
272
+ else:
273
+ if isinstance(itrt, int): itrt = range(itrt)
274
+ for it, obj in enumerate(itrt):
275
+ if it < start_it:
276
+ self.iter_end_t = time.time()
277
+ continue
278
+ self.data_time.update(time.time() - self.iter_end_t)
279
+ yield it, obj
280
+ self.iter_time.update(time.time() - self.iter_end_t)
281
+ if self.log_every_iter or it in self.log_iters:
282
+ eta_seconds = self.iter_time.avg * (max_iters - it)
283
+ print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
284
+ self.iter_end_t = time.time()
285
+ cost = time.time() - start_time
286
+ cost_str = str(datetime.timedelta(seconds=int(cost)))
287
+ print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True)
288
+
289
+
290
+ class NullDDP(torch.nn.Module):
291
+ def __init__(self, module, *args, **kwargs):
292
+ super(NullDDP, self).__init__()
293
+ self.module = module
294
+ self.require_backward_grad_sync = False
295
+
296
+ def forward(self, *args, **kwargs):
297
+ return self.module(*args, **kwargs)
298
+
299
+
300
+ def build_2d_sincos_position_embedding(h, w, embed_dim, temperature=10000., sc=0, verbose=True): # (1, hw**2, embed_dim)
301
+ # DiT: sc=0
302
+ # DETR: sc=2?
303
+ grid_w = torch.arange(w, dtype=torch.float32)
304
+ grid_h = torch.arange(h, dtype=torch.float32)
305
+ grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing='ij')
306
+ if sc == 0:
307
+ scale = 1
308
+ elif sc == 1:
309
+ scale = math.pi * 2 / w
310
+ else:
311
+ scale = 1 / w
312
+ grid_w = scale * grid_w.reshape(h*w, 1) # scale * [0, 0, 0, 1, 1, 1, 2, 2, 2]
313
+ grid_h = scale * grid_h.reshape(h*w, 1) # scale * [0, 1, 2, 0, 1, 2, 0, 1, 2]
314
+
315
+ assert embed_dim % 4 == 0, f'Embed dimension ({embed_dim}) must be divisible by 4 for 2D sin-cos position embedding!'
316
+ pos_dim = embed_dim // 4
317
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
318
+ omega = (-math.log(temperature) * omega).exp()
319
+ # omega == (1/T) ** (arange(pos_dim) / pos_dim), a vector only dependent on C
320
+ out_w = grid_w * omega.view(1, pos_dim) # out_w: scale * [0*ome, 0*ome, 0*ome, 1*ome, 1*ome, 1*ome, 2*ome, 2*ome, 2*ome]
321
+ out_h = grid_h * omega.view(1, pos_dim) # out_h: scale * [0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome]
322
+ pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
323
+ if verbose: print(f'[build_2d_sincos_position_embedding @ {h} x {w}] scale_type={sc}, temperature={temperature:g}, shape={pos_emb.shape}')
324
+ return pos_emb # (1, hw**2, embed_dim)
325
+
326
+
327
+ if __name__ == '__main__':
328
+ import seaborn as sns
329
+ import matplotlib.pyplot as plt
330
+ cmap_div = sns.color_palette('icefire', as_cmap=True)
331
+
332
+ scs = [0, 1, 2]
333
+ temps = [20, 50, 100, 1000]
334
+ reso = 3.0
335
+ RR, CC = len(scs), len(temps)
336
+ plt.figure(figsize=(CC * reso, RR * reso)) # figsize=(16, 16)
337
+ for row, sc in enumerate(scs):
338
+ for col, temp in enumerate(temps):
339
+ name = f'sc={sc}, T={temp}'
340
+ hw, C = 16, 512
341
+ N = hw*hw
342
+ pe = build_2d_sincos_position_embedding(hw, hw, C, temperature=temp, sc=sc, verbose=False)[0] # (N, C)
343
+
344
+ hw2 = 16
345
+ N2 = hw2*hw2
346
+ pe2 = build_2d_sincos_position_embedding(hw2, hw2, C, temperature=temp, sc=sc, verbose=False)[0] # (N2, C)
347
+ # pe2 = pe2.flip(dims=(0,))
348
+ bchw, bchw2 = F.normalize(pe.view(hw, hw, C).permute(2, 0, 1).unsqueeze(0), dim=1), F.normalize(pe2.view(hw2, hw2, C).permute(2, 0, 1).unsqueeze(0), dim=1)
349
+ dis = [
350
+ f'{F.mse_loss(bchw, F.interpolate(bchw2, size=bchw.shape[-2], mode=inter)).item():.3f}'
351
+ for inter in ('bilinear', 'bicubic', 'nearest')
352
+ ]
353
+ dis += [
354
+ f'{F.mse_loss(F.interpolate(bchw, size=bchw2.shape[-2], mode=inter), bchw2).item():.3f}'
355
+ for inter in ('area', 'nearest')
356
+ ]
357
+ print(f'[{name:^20s}] dis: {dis}')
358
+ """
359
+ [ sc=0, T=20 ] dis: ['0.010', '0.011', '0.011', '0.009', '0.010']
360
+ [ sc=0, T=100 ] dis: ['0.007', '0.007', '0.007', '0.006', '0.007']
361
+ [ sc=0, T=1000 ] dis: ['0.005', '0.005', '0.005', '0.004', '0.005']
362
+ [ sc=0, T=10000 ] dis: ['0.004', '0.004', '0.004', '0.003', '0.004']
363
+ [ sc=1, T=20 ] dis: ['0.007', '0.008', '0.008', '0.007', '0.008']
364
+ [ sc=1, T=100 ] dis: ['0.005', '0.005', '0.005', '0.005', '0.005']
365
+ [ sc=1, T=1000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
366
+ [ sc=1, T=10000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
367
+ [ sc=2, T=20 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
368
+ [ sc=2, T=100 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
369
+ [ sc=2, T=1000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
370
+ [ sc=2, T=10000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
371
+ Process finished with exit code 0
372
+ """
373
+
374
+ pe = torch.from_numpy(cmap_div(pe.T.numpy())[:, :, :3]) # C, N, 3
375
+ tar_h, tar_w = 1024, 1024
376
+ pe = pe.repeat_interleave(tar_w//pe.shape[0], dim=0).repeat_interleave(tar_h//pe.shape[1], dim=1)
377
+ plt.subplot(RR, CC, 1+row*CC+col)
378
+ plt.title(name)
379
+ plt.xlabel('hxw'), plt.ylabel('C')
380
+ plt.xticks([]), plt.yticks([])
381
+ plt.imshow(pe.mul(255).round().clamp(0, 255).byte().numpy())
382
+ plt.tight_layout(h_pad=0.02)
383
+ plt.show()
384
+
385
+
386
+ def check_randomness(args):
+     U = 16384
+     t = torch.zeros(dist.get_world_size(), 4, dtype=torch.float32, device=args.device)
+     t0 = torch.zeros(1, dtype=torch.float32, device=args.device).random_(U)
+     t[dist.get_rank(), 0] = float(random.randrange(U))
+     t[dist.get_rank(), 1] = float(np.random.randint(U))
+     t[dist.get_rank(), 2] = float(torch.randint(0, U, (1,))[0])
+     t[dist.get_rank(), 3] = float(t0[0])
+     dist.allreduce(t)
+     for rk in range(1, dist.get_world_size()):
+         assert torch.allclose(t[rk - 1], t[rk]), f't={t}'
+     del t0, t, U
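+ 
+ # Minimal usage sketch (assumes the distributed backend has been initialized and that
+ # `args.device` is set, e.g. via arg_util.Args): call once at startup to verify that all
+ # four RNG sources (random, numpy, torch, and tensor.random_) draw identical values on
+ # every rank, i.e. that seeds are synchronized.
+ #
+ #   check_randomness(args)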
infinity/utils/save_and_load.py ADDED
@@ -0,0 +1,160 @@
+ import gc
+ import os
+ import subprocess
+ import time
+ import re
+ from typing import List, Optional, Tuple
+ 
+ import torch
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ 
+ import glob
+ import shutil
+ from infinity.utils import arg_util
+ import infinity.utils.dist as dist
+ 
+ 
+ def glob_with_epoch_iter(pattern, recursive=False):
+     def extract_ep_iter(filename):
+         match = re.search(r'ep(\d+)-iter(\d+)', filename)
+         if match:
+             ep = int(match.group(1))
+             iter_idx = int(match.group(2))
+             return ep, iter_idx
+         return 0, 0
+     return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True)
+ 
+ 
+ def glob_with_global_step(pattern, recursive=False):
+     def extract_global_step(filename):
+         match = re.search(r'global_step_(\d+)', filename)
+         if match:
+             return int(match.group(1))
+         return 0
+     return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_global_step(os.path.basename(x)), reverse=True)
+ 
+ 
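+ # Ordering sketch for the helpers above (hypothetical filenames): checkpoints are returned
+ # newest-first, so index 0 is always the most recent one.
+ #
+ #   glob_with_epoch_iter('ckpt*.pth')
+ #   # e.g. ['ckpt-ep3-iter500.pth', 'ckpt-ep3-iter100.pth', 'ckpt-ep1-iter900.pth']
+ 
+ 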
+ class CKPTSaver(object):
+     def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]):
+         self.is_master = is_master
+         self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
+         self.sp_also: Optional[subprocess.Popen] = None
+         self.sp_best: Optional[subprocess.Popen] = None
+         self.sp_backup: Optional[subprocess.Popen] = None
+         self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone
+ 
+     def sav(
+         self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer,
+         acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
+         also_save_to: Optional[str] = None, best_save_to: Optional[str] = None,
+     ):
+         self.time_stamp[1] = time.time()
+         dist.broadcast(self.time_stamp, src_rank=0)
+         last_save_time, cur_time = self.time_stamp.cpu().tolist()
+ 
+         auto_save = cur_time - last_save_time > 20 * 60
+         need_save = also_save_to is not None or best_save_to is not None or next_ep == args.ep or auto_save
+         if not need_save:
+             return
+ 
+         if acc_str is not None: self.acc_str = acc_str
+         if eval_milestone is not None: self.eval_milestone = eval_milestone
+ 
+         # fname = f'ar-ckpt-giter{g_it//1000:03d}K-ep{next_ep}-iter{next_it}-last.pth' if args.gpt_training else f'ckpt-last.pth'
+         pn = args.pn
+         fname = f'{pn}-{g_it//1000}K.pth'
+         local_out_ckpt = os.path.join(args.local_out_path, fname)
+ 
+         # NOTE: all ranks should call state_dict() here, not just the master!
+         trainer_state = trainer.state_dict()
+ 
+         if self.is_master:
+             stt = time.time()
+             # torch.save({
+             #     'args': args.state_dict(),
+             #     'gpt_training': args.gpt_training,
+             #     'arch': args.model if args.gpt_training else args.vv,
+             #     'epoch': next_ep,
+             #     'iter': next_it,
+             #     'trainer': trainer_state,
+             #     'acc_str': self.acc_str,
+             #     'milestones': self.eval_milestone,
+             # }, local_out_ckpt)
+             torch.save(trainer_state["gpt_fsdp"], local_out_ckpt)
+ 
+             # if g_it not in [1000, 5000]:
+             #     cmd = f"aws s3 cp {local_out_ckpt} s3://hidream-user-maoqingyang/infinity/"
+             #     os.system(cmd)
+             #     cmd = f"rm -rf {local_out_ckpt}"
+             #     os.system(cmd)
+ 
+             print(f'[CKPTSaver][rank00] start: {also_save_to=} {best_save_to=} {(next_ep == args.ep)=} {auto_save=} | see {local_out_ckpt}', flush=True)
+             print(f'[CKPTSaver][rank00] dbg: {args.bed=}', flush=True)
+             # if auto_save:
+             #     if self.sp_backup is not None:
+             #         self.sp_backup.wait(timeout=300); self.sp_backup.kill(); self.sp_backup.communicate()
+             #     self.time_stamp[0] = time.time()
+             #
+             #     def auto_sync(source_filename, target_filename):
+             #         cmd = f'cp -r {source_filename} {target_filename}'
+             #         self.sp_backup = subprocess.Popen(cmd, shell=True, bufsize=-1)
+             #         print(f'[CKPTSaver] auto_save cmd: {cmd}', flush=True)
+             #
+             #     local_files = glob.glob(f"{args.local_out_path}/*")
+             #     for filename in local_files:
+             #         basename = os.path.basename(filename)
+             #         target_filename = f'{args.bed}/{basename}'
+             #         if basename.endswith('.pth'):
+             #             if not os.path.isfile(target_filename):
+             #                 auto_sync(filename, target_filename)
+             #         else:
+             #             auto_sync(filename, target_filename)
+             cost = time.time() - stt
+             print(f'[CKPTSaver][rank00] cost: {cost:.2f}s', flush=True)
+ 
+         del trainer_state
+         time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
+         dist.barrier()
+ 
+ 
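+ # Minimal usage sketch inside a training loop (hypothetical `trainer` object whose
+ # state_dict() contains a 'gpt_fsdp' entry, as sav() above assumes):
+ #
+ #   saver = CKPTSaver(is_master=dist.get_rank() == 0, eval_milestone=[])
+ #   for g_it in range(max_iters):
+ #       ...  # forward / backward / optimizer step
+ #       saver.sav(args, g_it, next_ep=ep + 1, next_it=it + 1, trainer=trainer)
+ 
+ 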
+ def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]:
+     info = []
+     resume = ''
+     if args.auto_resume:
+         for dd in (args.local_out_path, args.bed):
+             all_ckpt = glob_with_epoch_iter(os.path.join(dd, pattern))
+             if len(all_ckpt): break
+         if len(all_ckpt) == 0:
+             info.append(f'[auto_resume] no ckpt found @ {pattern}')
+             info.append(f'[auto_resume quit]')
+         else:
+             resume = all_ckpt[0]
+             info.append(f'[auto_resume] auto load from @ {resume} ...')
+     else:
+         info.append(f'[auto_resume] disabled')
+         info.append(f'[auto_resume quit]')
+ 
+     if len(resume) == 0:
+         return info, 0, 0, '[no acc str]', [], {}, {}
+ 
+     print(f'auto resume from {resume}')
+ 
+     try:
+         ckpt = torch.load(resume, map_location='cpu')
+     except Exception as e:
+         info.append(f'[auto_resume] failed, {e} @ {resume}')
+         if len(all_ckpt) < 2:
+             return info, 0, 0, '[no acc str]', [], {}, {}
+         try:  # another chance to load from bytenas
+             ckpt = torch.load(all_ckpt[1], map_location='cpu')
+         except Exception as e:
+             info.append(f'[auto_resume] failed, {e} @ {all_ckpt[1]}')
+             return info, 0, 0, '[no acc str]', [], {}, {}
+ 
+     dist.barrier()
+     # NOTE: this expects a full checkpoint dict with 'epoch', 'iter', 'trainer' and 'args'
+     # keys (the layout of the commented-out full torch.save in sav() above), not a
+     # weights-only .pth file.
+     ep, it = ckpt['epoch'], ckpt['iter']
+     eval_milestone = ckpt.get('milestones', [])
+     info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}')
+     return info, ep, it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args']
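+ 
+ # Minimal usage sketch (assumes `args` is an arg_util.Args with auto_resume,
+ # local_out_path and bed populated):
+ #
+ #   info, start_ep, start_it, acc_str, milestones, trainer_state, args_state = auto_resume(args)
+ #   for msg in info:
+ #       print(msg)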
infinity/utils/wandb_utils.py ADDED
@@ -0,0 +1,54 @@
+ import wandb
+ import torch
+ from torchvision.utils import make_grid
+ import torch.distributed as dist
+ from PIL import Image
+ import os
+ import argparse
+ import hashlib
+ import math
+ 
+ 
+ def is_main_process():
+     return dist.get_rank() == 0
+ 
+ def namespace_to_dict(namespace):
+     return {
+         k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
+         for k, v in vars(namespace).items()
+     }
+ 
+ 
+ def generate_run_id(exp_name):
+     # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
+     return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
+ 
+ 
+ def initialize(args, entity, exp_name, project_name):
+     config_dict = namespace_to_dict(args)
+     wandb.init(
+         entity=entity,
+         project=project_name,
+         name=exp_name,
+         config=config_dict,
+         id=generate_run_id(exp_name),
+         resume="allow",
+     )
+ 
+ 
+ def log(stats, step=None):
+     if is_main_process():
+         wandb.log({k: v for k, v in stats.items()}, step=step)
+ 
+ 
+ def log_image(name, sample, step=None):
+     if is_main_process():
+         sample = array2grid(sample)
+         wandb.log({f"{name}": wandb.Image(sample), "train_step": step})
+ 
+ 
+ def array2grid(x):
+     nrow = round(math.sqrt(x.size(0)))
+     x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
+     x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
+     return x
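+ 
+ # Minimal usage sketch (assumes torch.distributed is initialized; the entity and
+ # project names are placeholders):
+ #
+ #   initialize(args, entity='my-team', exp_name='infinity-run-1', project_name='infinity')
+ #   log({'loss': 0.123, 'lr': 1e-4}, step=100)
+ #   log_image('samples', images, step=100)  # images: (B, 3, H, W) tensor in [-1, 1]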
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ # Core dependencies
+ torch
+ torchvision
+ numpy
+ pillow
+ opencv-python
+ transformers
+ tokenizers
+ scipy
+ scikit-image
+ pyyaml
+ pandas
+ tqdm
+ webdataset
+ accelerate
+ xformers
+ bitsandbytes
+ jupyter
+ matplotlib
+ timm
+ flash_attn
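+ # Install with `pip install -r requirements.txt` (flash_attn builds against CUDA and
+ # may need a matching toolchain; see its documentation if the build fails).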
tools/run_infinity.py ADDED
@@ -0,0 +1,478 @@
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ import os.path as osp
+ from typing import List
+ import math
+ import time
+ import hashlib
+ import yaml
+ import argparse
+ import shutil
+ import re
+ 
+ import sys
+ sys.path.append('./')
+ 
+ import cv2
+ import numpy as np
+ import torch
+ torch._dynamo.config.cache_size_limit = 64
+ import pandas as pd
+ from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
+ from PIL import Image, ImageEnhance
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast
+ 
+ from infinity.models.infinity import Infinity
+ from infinity.models.basic import *
+ import PIL.Image as PImage
+ from torchvision.transforms.functional import to_tensor
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
+ 
+ 
+ def extract_key_val(text):
+     pattern = r'<(.+?):(.+?)>'
+     matches = re.findall(pattern, text)
+     key_val = {}
+     for match in matches:
+         key_val[match[0]] = match[1].lstrip()
+     return key_val
+ 
+ def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
+     if enable_positive_prompt:
+         print(f'before positive_prompt aug: {prompt}')
+         prompt = aug_with_positive_prompt(prompt)
+         print(f'after positive_prompt aug: {prompt}')
+     print(f'prompt={prompt}')
+     captions = [prompt]
+     tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
+     input_ids = tokens.input_ids.cuda(non_blocking=True)
+     mask = tokens.attention_mask.cuda(non_blocking=True)
+     text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
+     lens: List[int] = mask.sum(dim=-1).tolist()
+     cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
+     Ltext = max(lens)
+     kv_compact = []
+     for len_i, feat_i in zip(lens, text_features.unbind(0)):
+         kv_compact.append(feat_i[:len_i])
+     kv_compact = torch.cat(kv_compact, dim=0)
+     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
+     return text_cond_tuple
+ 
+ def aug_with_positive_prompt(prompt):
+     for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
+                 'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
+         if key in prompt:
+             prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
+             break
+     return prompt
+ 
+ def enhance_image(image):
+     contrast_image = image.copy()
+     contrast_enhancer = ImageEnhance.Contrast(contrast_image)
+     contrast_image = contrast_enhancer.enhance(1.05)  # boost contrast
+     color_image = contrast_image.copy()
+     color_enhancer = ImageEnhance.Color(color_image)
+     color_image = color_enhancer.enhance(1.05)  # boost color saturation
+     return color_image
+ 
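+ # Shape sketch of the tuple returned by encode_prompt() above (single prompt, T5 features):
+ #   kv_compact:   (sum(lens), text_channels) -- non-padded token features, concatenated
+ #   lens:         per-caption token counts
+ #   cu_seqlens_k: (B+1,) int32 cumulative lengths, as consumed by varlen attention kernels
+ #   Ltext:        max(lens)
+ 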
+ def get_image_prefix(input_raw_features, vae, scale_schedule, apply_spatial_patchify=False):
+     with torch.amp.autocast('cuda', enabled=False):
+         if apply_spatial_patchify:
+             vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+         else:
+             vae_scale_schedule = scale_schedule
+ 
+         B = input_raw_features.shape[0]
+         if input_raw_features.dim() == 4:
+             codes_out = input_raw_features.unsqueeze(2)
+         else:
+             codes_out = input_raw_features
+         cum_var_input = 0
+         gt_all_bit_indices = []
+ 
+         residual = F.interpolate(codes_out, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_down).contiguous()
+         if apply_spatial_patchify:
+             residual = torch.nn.functional.pixel_unshuffle(residual.squeeze(-3), 2)
+         x_BLC_wo_prefix = residual.reshape(*residual.shape[:2], -1).permute(0, 2, 1)
+ 
+     return x_BLC_wo_prefix
+ 
+ def gen_one_img(
+     infinity_test,
+     vae,
+     text_tokenizer,
+     text_encoder,
+     prompt,
+     src_img_3HW,
+     cfg_list=[],
+     tau_list=[],
+     negative_prompt='',
+     scale_schedule=None,
+     top_k=900,
+     top_p=0.97,
+     cfg_sc=3,
+     cfg_exp_k=0.0,
+     cfg_insertion_layer=-5,
+     vae_type=0,
+     gumbel=0,
+     softmax_merge_topk=-1,
+     gt_leak=-1,
+     gt_ls_Bl=None,
+     g_seed=None,
+     sampling_per_bits=1,
+     enable_positive_prompt=0,
+     apply_spatial_patchify=False,
+ ):
+     sstt = time.time()
+     if not isinstance(cfg_list, list):
+         cfg_list = [cfg_list] * len(scale_schedule)
+     if not isinstance(tau_list, list):
+         tau_list = [tau_list] * len(scale_schedule)
+     text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
+     if negative_prompt:
+         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
+     else:
+         negative_label_B_or_BLT = None
+ 
+     src_img_3HW = src_img_3HW.unsqueeze(0).to('cuda', non_blocking=True)
+     src_img_features, _, _ = vae.encode_for_raw_features(src_img_3HW, scale_schedule=scale_schedule)
+     print(f'cfg: {cfg_list}, tau: {tau_list}')
+ 
+     src_img_prefix = get_image_prefix(src_img_features, vae, scale_schedule, apply_spatial_patchify)
+ 
+     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+         stt = time.time()
+         _, pred_gt, img_list = infinity_test.autoregressive_infer_cfg(
+             vae=vae,
+             scale_schedule=scale_schedule,
+             src_img_prefix=src_img_prefix,
+             label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
+             B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
+             cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
+             returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
+             cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
+             vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
+             ret_img=True, trunk_scale=1000,
+             gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
+             sampling_per_bits=sampling_per_bits,
+         )
+ 
+     print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
+     img = img_list[0]
+     return img
+ 
+ def get_prompt_id(prompt):
+     md5 = hashlib.md5()
+     md5.update(prompt.encode('utf-8'))
+     prompt_id = md5.hexdigest()
+     return prompt_id
+ 
+ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
+     print('[Save slim model]')
+     full_ckpt = torch.load(infinity_model_path, map_location=device)
+     infinity_slim = full_ckpt['trainer'][key]
+     # ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
+     if not save_file:
+         save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
+     print(f'Save to {save_file}')
+     torch.save(infinity_slim, save_file)
+     print('[Save slim model] done')
+     return save_file
+ 
+ def load_tokenizer(t5_path=''):
+     print(f'[Loading tokenizer and text encoder]')
+     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
+     text_tokenizer.model_max_length = 512
+     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
+     text_encoder.to('cuda')
+     text_encoder.eval()
+     text_encoder.requires_grad_(False)
+     return text_tokenizer, text_encoder
+ 
+ def load_infinity(
+     rope2d_each_sa_layer,
+     rope2d_normalized_by_hw,
+     use_scale_schedule_embedding,
+     pn,
+     use_bit_label,
+     add_lvl_embeding_only_first_block,
+     model_path='',
+     scale_schedule=None,
+     vae=None,
+     device='cuda',
+     model_kwargs=None,
+     text_channels=2048,
+     apply_spatial_patchify=0,
+     use_flex_attn=False,
+     bf16=False,
+     checkpoint_type='torch',
+ ):
+     print(f'[Loading Infinity]')
+     text_maxlen = 512
+     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+         infinity_test: Infinity = Infinity(
+             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
+             shared_aln=True, raw_scale_schedule=scale_schedule,
+             checkpointing='full-block',
+             customized_flash_attn=False,
+             fused_norm=True,
+             pad_to_multiplier=128,
+             use_flex_attn=use_flex_attn,
+             add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
+             use_bit_label=use_bit_label,
+             rope2d_each_sa_layer=rope2d_each_sa_layer,
+             rope2d_normalized_by_hw=rope2d_normalized_by_hw,
+             pn=pn,
+             apply_spatial_patchify=apply_spatial_patchify,
+             inference_mode=True,
+             train_h_div_w_list=[1.0],
+             **model_kwargs,
+         ).to(device=device)
+         print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
+ 
+         if bf16:
+             for block in infinity_test.unregistered_blocks:
+                 block.bfloat16()
+ 
+         infinity_test.eval()
+         infinity_test.requires_grad_(False)
+ 
+         infinity_test.cuda()
+         torch.cuda.empty_cache()
+ 
+     print(f'[Load Infinity weights]')
+     if checkpoint_type == 'torch':
+         state_dict = torch.load(model_path, map_location=device)
+         print(infinity_test.load_state_dict(state_dict))
+     elif checkpoint_type == 'torch_shard':
+         from transformers.modeling_utils import load_sharded_checkpoint
+         load_sharded_checkpoint(infinity_test, model_path, strict=False)
+     infinity_test.rng = torch.Generator()
+     return infinity_test
+ 
+ def transform(pil_img, tgt_h, tgt_w):
+     width, height = pil_img.size
+     if width / height <= tgt_w / tgt_h:
+         resized_width = tgt_w
+         resized_height = int(tgt_w / (width / height))
+     else:
+         resized_height = tgt_h
+         resized_width = int((width / height) * tgt_h)
+     pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
+     # crop the center out
+     arr = np.array(pil_img)
+     crop_y = (arr.shape[0] - tgt_h) // 2
+     crop_x = (arr.shape[1] - tgt_w) // 2
+     im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
+     return im.add(im).add_(-1)  # im + im - 1, i.e. map [0, 1] to [-1, 1]
+ 
+ def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
+     pil_image = Image.open(image_path).convert('RGB')
+     inp = transform(pil_image, tgt_h, tgt_w)
+     inp = inp.unsqueeze(0).to(device)
+     scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
+     t1 = time.time()
+     h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
+     t2 = time.time()
+     recons_img = vae.decode(z)[0]
+     if len(recons_img.shape) == 4:
+         recons_img = recons_img.squeeze(1)
+     print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
+     t3 = time.time()
+     print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
+     recons_img = (recons_img + 1) / 2
+     recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
+     gt_img = (inp[0] + 1) / 2
+     gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
+     print(recons_img.shape, gt_img.shape)
+     return gt_img, recons_img, all_bit_indices
+ 
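+ # Behavior sketch for transform() above (hypothetical helper, not used elsewhere): resize
+ # so the image covers the target box, center-crop to (tgt_h, tgt_w), then map pixel values
+ # from [0, 1] to [-1, 1].
+ def _demo_transform():
+     img = PImage.new('RGB', (800, 600))
+     x = transform(img, 512, 512)
+     assert x.shape == (3, 512, 512) and x.min() >= -1 and x.max() <= 1
+     return x
+ 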
+ def load_visual_tokenizer(args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     # load vae
+     if args.vae_type in [14, 16, 18, 20, 24, 32, 64]:
+         from infinity.models.bsq_vae.vae import vae_model
+         schedule_mode = "dynamic"
+         codebook_dim = args.vae_type
+         codebook_size = 2**codebook_dim
+         if args.apply_spatial_patchify:
+             patch_size = 8
+             encoder_ch_mult = [1, 2, 4, 4]
+             decoder_ch_mult = [1, 2, 4, 4]
+         else:
+             patch_size = 16
+             encoder_ch_mult = [1, 2, 4, 4, 4]
+             decoder_ch_mult = [1, 2, 4, 4, 4]
+         vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
+                         encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
+     else:
+         raise ValueError(f'vae_type={args.vae_type} not supported')
+     return vae
+ 
+ def load_transformer(vae, args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model_path = args.model_path
+     if args.checkpoint_type == 'torch':
+         # copy the large model to local storage; save a slim copy locally; copy the slim copy to nas; then load the local slim model
+         if osp.exists(args.cache_dir):
+             local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
+         else:
+             local_model_path = model_path
+         if args.enable_model_cache:
+             slim_model_path = model_path.replace('ar-', 'slim-')
+             local_slim_model_path = local_model_path.replace('ar-', 'slim-')
+             os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
+             print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
+             print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
+             if not osp.exists(local_slim_model_path):
+                 if osp.exists(slim_model_path):
+                     print(f'copy {slim_model_path} to {local_slim_model_path}')
+                     shutil.copyfile(slim_model_path, local_slim_model_path)
+                 else:
+                     if not osp.exists(local_model_path):
+                         print(f'copy {model_path} to {local_model_path}')
+                         shutil.copyfile(model_path, local_model_path)
+                     save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
+                     print(f'copy {local_slim_model_path} to {slim_model_path}')
+                     if not osp.exists(slim_model_path):
+                         shutil.copyfile(local_slim_model_path, slim_model_path)
+                     os.remove(local_model_path)
+                     os.remove(model_path)
+             slim_model_path = local_slim_model_path
+         else:
+             slim_model_path = model_path
+         print(f'load checkpoint from {slim_model_path}')
+     elif args.checkpoint_type == 'torch_shard':
+         slim_model_path = model_path
+     else:
+         raise ValueError(f'checkpoint_type={args.checkpoint_type} not supported')
+ 
+     if args.model_type == 'infinity_2b':
+         kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)  # 2b model
+     elif args.model_type == 'infinity_8b':
+         kwargs_model = dict(depth=40, embed_dim=3584, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)
+     elif args.model_type == 'infinity_layer12':
+         kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer16':
+         kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer24':
+         kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer32':
+         kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer40':
+         kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer48':
+         kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     else:
+         raise ValueError(f'model_type={args.model_type} not supported')
+     infinity = load_infinity(
+         rope2d_each_sa_layer=args.rope2d_each_sa_layer,
+         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
+         use_scale_schedule_embedding=args.use_scale_schedule_embedding,
+         pn=args.pn,
+         use_bit_label=args.use_bit_label,
+         add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
+         model_path=slim_model_path,
+         scale_schedule=None,
+         vae=vae,
+         device=device,
+         model_kwargs=kwargs_model,
+         text_channels=args.text_channels,
+         apply_spatial_patchify=args.apply_spatial_patchify,
+         use_flex_attn=args.use_flex_attn,
+         bf16=args.bf16,
+         checkpoint_type=args.checkpoint_type,
+     )
+     return infinity
+ 
+ def add_common_arguments(parser):
+     parser.add_argument('--cfg', type=str, default='3')
+     parser.add_argument('--tau', type=float, default=1)
+     parser.add_argument('--pn', type=str, required=True, choices=['0.06M', '0.25M', '1M'])
+     parser.add_argument('--model_path', type=str, required=True)
+     parser.add_argument('--cfg_insertion_layer', type=int, default=0)
+     parser.add_argument('--vae_type', type=int, default=1)  # must be one of [14, 16, 18, 20, 24, 32, 64]; see load_visual_tokenizer
+     parser.add_argument('--vae_path', type=str, default='')
+     parser.add_argument('--add_lvl_embeding_only_first_block', type=int, default=0, choices=[0,1])
+     parser.add_argument('--use_bit_label', type=int, default=1, choices=[0,1])
+     parser.add_argument('--model_type', type=str, default='infinity_2b')
+     parser.add_argument('--rope2d_each_sa_layer', type=int, default=1, choices=[0,1])
+     parser.add_argument('--rope2d_normalized_by_hw', type=int, default=2, choices=[0,1,2])
+     parser.add_argument('--use_scale_schedule_embedding', type=int, default=0, choices=[0,1])
+     parser.add_argument('--sampling_per_bits', type=int, default=1, choices=[1,2,4,8,16])
+     parser.add_argument('--text_encoder_ckpt', type=str, default='')
+     parser.add_argument('--text_channels', type=int, default=2048)
+     parser.add_argument('--apply_spatial_patchify', type=int, default=0, choices=[0,1])
+     parser.add_argument('--h_div_w_template', type=float, default=1.000)
+     parser.add_argument('--use_flex_attn', type=int, default=0, choices=[0,1])
+     parser.add_argument('--enable_positive_prompt', type=int, default=0, choices=[0,1])
+     parser.add_argument('--cache_dir', type=str, default='/dev/shm')
+     parser.add_argument('--enable_model_cache', type=int, default=0, choices=[0,1])
+     parser.add_argument('--checkpoint_type', type=str, default='torch')
+     parser.add_argument('--seed', type=int, default=0)
+     parser.add_argument('--bf16', type=int, default=1, choices=[0,1])
+ 
+ 
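+ # Hypothetical invocation (paths are placeholders; flag names follow add_common_arguments above):
+ #
+ #   python tools/run_infinity.py --pn 1M --model_path <infinity_ckpt.pth> \
+ #       --vae_type 32 --vae_path <bsq_vae_ckpt.pth> --text_encoder_ckpt <t5_dir> \
+ #       --prompt 'a dog' --src_image_path ./source.jpg --save_file ./tmp.jpg
+ 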
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     add_common_arguments(parser)
+     parser.add_argument('--prompt', type=str, default='a dog')
+     parser.add_argument('--src_image_path', type=str, default='./source.jpg')
+     parser.add_argument('--tgt_image_path', type=str, default='./target.jpg')
+     parser.add_argument('--save_file', type=str, default='./tmp.jpg')
+     args = parser.parse_args()
+ 
+     # parse cfg: accept either a single value or one value per scale
+     args.cfg = list(map(float, args.cfg.split(',')))
+     if len(args.cfg) == 1:
+         args.cfg = args.cfg[0]
+ 
+     if args.pn == '0.06M':
+         h, w = 256, 256
+     elif args.pn == '0.25M':
+         h, w = 512, 512
+     elif args.pn == '1M':
+         h, w = 1024, 1024
+ 
+     from infinity.dataset.dataset_t2i_iterable import transform  # NOTE: shadows the transform() defined above
+     with open(args.src_image_path, 'rb') as f:
+         src_img: PImage.Image = PImage.open(f)
+         src_img = src_img.convert('RGB')
+         src_img_3HW = transform(src_img, h, w)
+ 
+     # src_img = (src_img_3HW + 1) / 2
+     # src_img = src_img.permute(1, 2, 0).mul_(255).to(torch.uint8).flip(dims=(2,))
+     # cv2.imwrite("test.jpg", src_img.cpu().numpy())
+ 
+     # load text encoder
+     text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
+     # load vae
+     vae = load_visual_tokenizer(args)
+     # load infinity
+     infinity = load_transformer(vae, args)
+ 
+     scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']
+     scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
+ 
+     with autocast(dtype=torch.bfloat16):
+         with torch.no_grad():
+             generated_image = gen_one_img(
+                 infinity,
+                 vae,
+                 text_tokenizer,
+                 text_encoder,
+                 args.prompt,
+                 src_img_3HW,
+                 g_seed=args.seed,
+                 gt_leak=0,
+                 gt_ls_Bl=None,
+                 cfg_list=args.cfg,
+                 tau_list=args.tau,
+                 scale_schedule=scale_schedule,
+                 cfg_insertion_layer=[args.cfg_insertion_layer],
+                 vae_type=args.vae_type,
+                 sampling_per_bits=args.sampling_per_bits,
+                 enable_positive_prompt=args.enable_positive_prompt,
+             )
+     os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
+     cv2.imwrite(args.save_file, generated_image.cpu().numpy())
+     print(f'Save to {osp.abspath(args.save_file)}')