svjack committed
Commit 9a81c97 · verified · 1 Parent(s): fe019d0

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. .gitattributes +9 -0
  2. .gitignore +4 -0
  3. .ipynb_checkpoints/README-checkpoint.md +129 -0
  4. .ipynb_checkpoints/Untitled-checkpoint.ipynb +87 -0
  5. 20250205-041232_1234.mp4 +3 -0
  6. 20250205-043500_1234.mp4 +3 -0
  7. Mavuika_im_lora_dir/Mavuika_single_im_lora-000010.safetensors +3 -0
  8. Mavuika_im_lora_dir/Mavuika_single_im_lora-000011.safetensors +3 -0
  9. Mavuika_im_lora_dir/Mavuika_single_im_lora-000012.safetensors +3 -0
  10. Mavuika_im_lora_dir/Mavuika_single_im_lora-000013.safetensors +3 -0
  11. Mavuika_im_lora_dir/Mavuika_single_im_lora-000014.safetensors +3 -0
  12. Mavuika_im_lora_dir/Mavuika_single_im_lora-000015.safetensors +3 -0
  13. Mavuika_im_lora_dir/Mavuika_single_im_lora-000016.safetensors +3 -0
  14. Mavuika_im_lora_dir/Mavuika_single_im_lora-000017.safetensors +3 -0
  15. Mavuika_im_lora_dir/Mavuika_single_im_lora-000018.safetensors +3 -0
  16. Mavuika_im_lora_dir/Mavuika_single_im_lora-000019.safetensors +3 -0
  17. Mavuika_im_lora_dir/Mavuika_single_im_lora-000020.safetensors +3 -0
  18. Mavuika_im_lora_dir/Mavuika_single_im_lora-000021.safetensors +3 -0
  19. Mavuika_im_lora_dir/Mavuika_single_im_lora-000022.safetensors +3 -0
  20. Mavuika_im_lora_dir/Mavuika_single_im_lora-000023.safetensors +3 -0
  21. Mavuika_im_lora_dir/Mavuika_single_im_lora-000024.safetensors +3 -0
  22. Mavuika_im_lora_dir/Mavuika_single_im_lora-000025.safetensors +3 -0
  23. Mavuika_im_lora_dir/Mavuika_single_im_lora-000026.safetensors +3 -0
  24. Mavuika_im_lora_dir/Mavuika_single_im_lora-000027.safetensors +3 -0
  25. Mavuika_im_lora_dir/Mavuika_single_im_lora-000028.safetensors +3 -0
  26. Mavuika_im_lora_dir/Mavuika_single_im_lora-000029.safetensors +3 -0
  27. Mavuika_im_lora_dir/Mavuika_single_im_lora-000030.safetensors +3 -0
  28. Mavuika_im_lora_dir/Mavuika_single_im_lora-000031.safetensors +3 -0
  29. Mavuika_im_lora_dir/Mavuika_single_im_lora-000032.safetensors +3 -0
  30. Mavuika_im_lora_dir/Mavuika_single_im_lora-000033.safetensors +3 -0
  31. Mavuika_im_lora_dir/Mavuika_single_im_lora-000034.safetensors +3 -0
  32. Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors +3 -0
  33. Mavuika_im_lora_dir/Mavuika_single_im_lora-000036.safetensors +3 -0
  34. Mavuika_im_lora_dir/Mavuika_single_im_lora-000037.safetensors +3 -0
  35. Mavuika_im_lora_dir/Mavuika_single_im_lora-000038.safetensors +3 -0
  36. Mavuika_im_lora_dir/Mavuika_single_im_lora-000039.safetensors +3 -0
  37. Mavuika_im_lora_dir/Mavuika_single_im_lora.safetensors +3 -0
  38. README.md +129 -0
  39. cache_latents.py +245 -0
  40. cache_text_encoder_outputs.py +135 -0
  41. convert_lora.py +129 -0
  42. dataset/__init__.py +0 -0
  43. dataset/config_utils.py +359 -0
  44. dataset/dataset_config.md +293 -0
  45. dataset/image_video_dataset.py +1255 -0
  46. hunyuan_model/__init__.py +0 -0
  47. hunyuan_model/activation_layers.py +23 -0
  48. hunyuan_model/attention.py +230 -0
  49. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  50. hunyuan_model/embed_layers.py +132 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 20250131-122504_1234.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ 20250131-125418_1234.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ 20250131-130555_1234.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ 20250203-092003_1234.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ 20250203-112055_1234.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ 20250203-152222_1234.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ 20250203-153526_1234.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ 20250205-041232_1234.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ 20250205-043500_1234.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,129 @@
1
+ # Genshin_Impact_Mavuika HunyuanVideo LoRA
2
+
3
+ This repository contains the necessary setup and scripts to generate videos using the HunyuanVideo model with a LoRA (Low-Rank Adaptation) fine-tuned for Mavuika. Below are the instructions to install dependencies, download models, and run the demo.
4
+
5
+ ---
6
+
7
+ ## Installation
8
+
9
+ ### Step 1: Install System Dependencies
10
+ Run the following command to install required system packages:
11
+ ```bash
12
+ sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
13
+ ```
14
+
15
+ ### Step 2: Clone the Repository
16
+ Clone the repository and navigate to the project directory:
17
+ ```bash
18
+ git clone https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora
19
+ cd Genshin_Impact_Mavuika_HunyuanVideo_lora
20
+ ```
21
+
22
+ ### Step 3: Install Python Dependencies
23
+ Install the required Python packages:
24
+ ```bash
25
+ conda create -n py310 python=3.10
26
+ conda activate py310
27
+ pip install ipykernel
28
+ python -m ipykernel install --user --name py310 --display-name "py310"
29
+
30
+ pip install -r requirements.txt
31
+ pip install ascii-magic matplotlib tensorboard huggingface_hub
32
+ pip install moviepy==1.0.3
33
+ pip install sageattention==1.0.6
34
+
35
+ pip install torch==2.5.0 torchvision
36
+ ```
37
+
38
+ ---
39
+
40
+ ## Download Models
41
+
42
+ ### Step 1: Download HunyuanVideo Model
43
+ Download the HunyuanVideo model and place it in the `ckpts` directory:
44
+ ```bash
45
+ huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
46
+ ```
47
+
48
+ ### Step 2: Download LLaVA Model
49
+ Download the LLaVA model and preprocess it:
50
+ ```bash
51
+ cd ckpts
52
+ huggingface-cli download xtuner/llava-llama-3-8b-v1_1-transformers --local-dir ./llava-llama-3-8b-v1_1-transformers
53
+ wget https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py
54
+ python preprocess_text_encoder_tokenizer_utils.py --input_dir llava-llama-3-8b-v1_1-transformers --output_dir text_encoder
55
+ ```
56
+
57
+ ### Step 3: Download CLIP Model
58
+ Download the CLIP model for the text encoder:
59
+ ```bash
60
+ huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./text_encoder_2
61
+ ```
62
+
63
+ ---
64
+
65
+ ## Demo
66
+
67
+ ### Generate Video 1: Mavuika
68
+ Run the following command to generate a video of Mavuika:
69
+ ```bash
70
+ python hv_generate_video.py \
71
+ --fp8 \
72
+ --video_size 544 960 \
73
+ --video_length 60 \
74
+ --infer_steps 30 \
75
+ --prompt "Mavuika, featuring long, wavy red hair with golden highlights and large, star-shaped earrings. Mavuika wears dark sunglasses, a black choker, and a black leather glove on their left hand. Their attire includes a black and gold armor-like top with intricate designs. The background is a gradient of soft white to light blue, emphasizing Mavuika's confident expression and stylish appearance." \
76
+ --save_path . \
77
+ --output_type both \
78
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
79
+ --attn_mode sdpa \
80
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
81
+ --vae_chunk_size 32 \
82
+ --vae_spatial_tile_sample_min_size 128 \
83
+ --text_encoder1 ckpts/text_encoder \
84
+ --text_encoder2 ckpts/text_encoder_2 \
85
+ --seed 1234 \
86
+ --lora_multiplier 1.0 \
87
+ --lora_weight Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors
88
+
89
+ ```
90
+
91
+
92
+ <video controls autoplay src="https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora/resolve/main/20250205-041232_1234.mp4"></video>
93
+
94
+
95
+ ### Generate Video 2: Mavuika Sun
96
+ Run the following command to generate a video of Mavuika in a warm sunset scene:
97
+ ```bash
98
+ python hv_generate_video.py \
99
+ --fp8 \
100
+ --video_size 544 960 \
101
+ --video_length 60 \
102
+ --infer_steps 30 \
103
+ --prompt "Fantastic artwork of Mavuika, featuring long, wavy red hair with golden highlights and large, star-shaped earrings. Mavuika wears dark sunglasses, a black choker, and a black leather glove on their left hand. Their attire includes a black and gold armor-like top with intricate designs, standing confidently in a warm sunset-lit rural village. The background transitions into the interior of a futuristic spaceship, blending the rustic and sci-fi elements seamlessly. The gradient of soft white to light blue in the sky enhances Mavuika's stylish and commanding presence." \
104
+ --save_path . \
105
+ --output_type both \
106
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
107
+ --attn_mode sdpa \
108
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
109
+ --vae_chunk_size 32 \
110
+ --vae_spatial_tile_sample_min_size 128 \
111
+ --text_encoder1 ckpts/text_encoder \
112
+ --text_encoder2 ckpts/text_encoder_2 \
113
+ --seed 1234 \
114
+ --lora_multiplier 1.0 \
115
+ --lora_weight Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors
116
+ ```
117
+
118
+
119
+ <video controls autoplay src="https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora/resolve/main/20250205-043500_1234.mp4"></video>
120
+
121
+
122
+ ---
123
+
124
+ ## Notes
125
+ - Ensure you have sufficient GPU resources for video generation.
126
+ - Adjust the `--video_size`, `--video_length`, and `--infer_steps` parameters as needed for different output qualities and lengths (a lighter-weight example command is sketched below).
127
+ - The `--prompt` parameter can be modified to generate videos with different scenes or actions.
128
+
129
+ ---
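
As the notes above suggest, the generation settings can be traded off against speed and memory. The sketch below is a lighter-weight variant of the demo command, kept to the flags already used in this README; the reduced `--video_size` and `--infer_steps` values are illustrative rather than tested settings.

```bash
# Lighter-weight variant of the demo command above (resolution and step count are illustrative).
python hv_generate_video.py \
    --fp8 \
    --video_size 384 640 \
    --video_length 60 \
    --infer_steps 20 \
    --prompt "Mavuika, long wavy red hair with golden highlights, dark sunglasses, black and gold armor-like top, confident expression." \
    --save_path . \
    --output_type both \
    --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
    --attn_mode sdpa \
    --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
    --vae_chunk_size 32 \
    --vae_spatial_tile_sample_min_size 128 \
    --text_encoder1 ckpts/text_encoder \
    --text_encoder2 ckpts/text_encoder_2 \
    --seed 1234 \
    --lora_multiplier 1.0 \
    --lora_weight Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors
```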
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,87 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "1ad678b1-90f1-4382-afe3-71e101c1f41a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "python hv_generate_video.py \\\n",
11
+ " --fp8 \\\n",
12
+ " --video_size 544 960 \\\n",
13
+ " --video_length 60 \\\n",
14
+ " --infer_steps 30 \\\n",
15
+ " --prompt \"fantastic artwork of a handsome man img. warm sunset in a rural village. the interior of a futuristic spaceship in the background.\" \\\n",
16
+ " --save_path . \\\n",
17
+ " --output_type both \\\n",
18
+ " --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \\\n",
19
+ " --attn_mode sdpa \\\n",
20
+ " --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \\\n",
21
+ " --vae_chunk_size 32 \\\n",
22
+ " --vae_spatial_tile_sample_min_size 128 \\\n",
23
+ " --text_encoder1 ckpts/text_encoder \\\n",
24
+ " --text_encoder2 ckpts/text_encoder_2 \\\n",
25
+ " --seed 1234 \\\n",
26
+ " --lora_multiplier 1.0 \\\n",
27
+ " --lora_weight Xiang_CID_im_lora_dir/Xiang_CID_im_lora_dir/Xiang_CID_single_im_lora-000004.safetensors\n"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "a0387d95-f527-47c2-8713-6b74d3a0126e",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "python hv_generate_video.py \\\n",
38
+ " --fp8 \\\n",
39
+ " --video_size 544 960 \\\n",
40
+ " --video_length 60 \\\n",
41
+ " --infer_steps 30 \\\n",
42
+ " --prompt \"surrealist painting of a handsome man img. underwater glow, deep sea. a peaceful zen garden with koi pond in the background.\" \\\n",
43
+ " --save_path . \\\n",
44
+ " --output_type both \\\n",
45
+ " --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \\\n",
46
+ " --attn_mode sdpa \\\n",
47
+ " --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \\\n",
48
+ " --vae_chunk_size 32 \\\n",
49
+ " --vae_spatial_tile_sample_min_size 128 \\\n",
50
+ " --text_encoder1 ckpts/text_encoder \\\n",
51
+ " --text_encoder2 ckpts/text_encoder_2 \\\n",
52
+ " --seed 1234 \\\n",
53
+ " --lora_multiplier 1.0 \\\n",
54
+ " --lora_weight Xiang_CID_im_lora_dir/Xiang_CID_im_lora_dir/Xiang_CID_single_im_lora-000010.safetensors\n"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "780799d2-d8d9-4dcd-9f71-f5ee00f52a31",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": []
64
+ }
65
+ ],
66
+ "metadata": {
67
+ "kernelspec": {
68
+ "display_name": "py310",
69
+ "language": "python",
70
+ "name": "py310"
71
+ },
72
+ "language_info": {
73
+ "codemirror_mode": {
74
+ "name": "ipython",
75
+ "version": 3
76
+ },
77
+ "file_extension": ".py",
78
+ "mimetype": "text/x-python",
79
+ "name": "python",
80
+ "nbconvert_exporter": "python",
81
+ "pygments_lexer": "ipython3",
82
+ "version": "3.10.16"
83
+ }
84
+ },
85
+ "nbformat": 4,
86
+ "nbformat_minor": 5
87
+ }
20250205-041232_1234.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5a9d4a1cd062cb7ce96f990ad45d9c6b2f47e098eff4b924ee99a16b2e10d1e
3
+ size 1087467
20250205-043500_1234.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fa3c5fc487581d1780aace1397d83c98ef839a62ea41906c3885d501a8f4940
3
+ size 1169938
Mavuika_im_lora_dir/Mavuika_single_im_lora-000010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb91acd4b648c74e900fdecfad91cc0747527ff2526e0f1ebf33063a6a3fd7c0
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2c213e02f035b426554de4174597938779d6b8e6f875c912b8be3b44bdb0581
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a06bc717fef9fb488af32d90bf916248a53400b71893bad24fe2225d684f24b
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000013.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f450efe7d60a4947564099fc0ee91d2068bba88e96a5c7b6e17bd4ceb5e1fe2a
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000014.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b46777324943392c8f87849edf2f033982748c2ab219150c635dbb97ce97801b
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000015.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed651d914fe3acde3c3b8c02d1fd0df1c45cc484cb281809d57e8e6538a6797a
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45d33d93f77235c6d02bf7b3396285a4883996ee1dc47dde3cb17eb5d3723ea1
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000017.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2414c86d76f7078cac639cabdb3f50b51298efd0cb166d8d443b480af082e99b
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000018.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e490858d84600fdd755e54081562fbca66ad830e8655410ee9e3c13d8ffc2ebe
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000019.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e324c137a77b8a385ddc8e50feb4c440a676fa547fc3f22682fdff48a6618e9
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79a2e8c688e158ae58ac87fb047667f124bcda734ce30f6f67f838a88040deb0
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38e7e22bef98c62cbb2e210e96c82c79ebe922c28ee0077be32e4fe8e1eb1e82
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000022.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09d0eb593087bde6a25d395bbfc28b9c6e5ec537c9956e70f5bb4125c7486e0b
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000023.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f145149d88d80000a6edc4c33d3d5b4f1b2df6c61164b61ece4585245b986d18
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000024.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19736b4e36af6344858d5e1c9fc54265f887f3a366b2af1a9bb75b6a00e09662
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000025.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed416455dc7056bac87c7ec8e91b3ec4c1df7b8446a9da57ef799313c64c033f
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000026.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc74f28d2c03112bd8b0cab464f041e52345f1437cb3e12f3ee297840bd7b42a
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000027.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:998421df3704704d20a6cc84702ce5a365d34df0de2a91bf6454cac3155398f7
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000028.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ccf5e72f56c06a863b51c71d8070631d13a12827fbbb3ab4d94cb24e1e6de95
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000029.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05dab2db6bfe48499ca995a2efb0e2fe8a3a7aa11f79a4d39a7fdc79594d3243
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000030.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:231001ecfe296fead0d6852f0293850ce67d27ad2e6c8bb15fd2dde47bba52e4
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000031.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98db745553e99a7f493bf115242b841ed8ac2a0f9c8c402894832851fa4f59bd
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000032.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9112d516c343773f831bdcece1840c2c02be49fda254ea04a532d72d09b404e
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c72cecc9dcb80e5ff778e3035959a3b10ee200e967510c648ab98158964af2a5
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000034.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01a453cb964672d914979fd1dbb6ebaac4ef501077a87c2964695aaad09ef493
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:373d976b71331525b490dd8195273a015f82beb8d34ddb331a64bd40c2fdb1b5
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000036.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b50e263d65ccbad83a0493e19db1ad412d25902a1574da76cf1034ce61e2e348
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000037.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b610212816f718b05ce45804182bed930bf0974e17de6607816ab44468e7b047
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000038.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37ead1a03403354013a479bd9fc34bdff5eaf95643a9db213a2d64fe227178bb
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora-000039.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fffbf16bc23db5fec01421202c5284769f3910b4b1640af2e64484de3a1a69e1
3
+ size 322557568
Mavuika_im_lora_dir/Mavuika_single_im_lora.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:273327c68dcb2ef962cfef8b5a38e085a14cb42d0012cc72a7f1f8aaabbf1e06
3
+ size 322557568
README.md ADDED
@@ -0,0 +1,129 @@
1
+ # Genshin_Impact_Mavuika HunyuanVideo LoRA
2
+
3
+ This repository contains the necessary setup and scripts to generate videos using the HunyuanVideo model with a LoRA (Low-Rank Adaptation) fine-tuned for Mavuika. Below are the instructions to install dependencies, download models, and run the demo.
4
+
5
+ ---
6
+
7
+ ## Installation
8
+
9
+ ### Step 1: Install System Dependencies
10
+ Run the following command to install required system packages:
11
+ ```bash
12
+ sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
13
+ ```
14
+
15
+ ### Step 2: Clone the Repository
16
+ Clone the repository and navigate to the project directory:
17
+ ```bash
18
+ git clone https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora
19
+ cd Genshin_Impact_Mavuika_HunyuanVideo_lora
20
+ ```
21
+
22
+ ### Step 3: Install Python Dependencies
23
+ Install the required Python packages:
24
+ ```bash
25
+ conda create -n py310 python=3.10
26
+ conda activate py310
27
+ pip install ipykernel
28
+ python -m ipykernel install --user --name py310 --display-name "py310"
29
+
30
+ pip install -r requirements.txt
31
+ pip install ascii-magic matplotlib tensorboard huggingface_hub
32
+ pip install moviepy==1.0.3
33
+ pip install sageattention==1.0.6
34
+
35
+ pip install torch==2.5.0 torchvision
36
+ ```
37
+
38
+ ---
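
Before moving on to the model downloads, it is worth confirming that the freshly installed PyTorch can actually see the GPU; a minimal check using only PyTorch itself:

```bash
# Print the installed torch version and whether CUDA is visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```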
39
+
40
+ ## Download Models
41
+
42
+ ### Step 1: Download HunyuanVideo Model
43
+ Download the HunyuanVideo model and place it in the `ckpts` directory:
44
+ ```bash
45
+ huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
46
+ ```
47
+
48
+ ### Step 2: Download LLaVA Model
49
+ Download the LLaVA model and preprocess it:
50
+ ```bash
51
+ cd ckpts
52
+ huggingface-cli download xtuner/llava-llama-3-8b-v1_1-transformers --local-dir ./llava-llama-3-8b-v1_1-transformers
53
+ wget https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py
54
+ python preprocess_text_encoder_tokenizer_utils.py --input_dir llava-llama-3-8b-v1_1-transformers --output_dir text_encoder
55
+ ```
56
+
57
+ ### Step 3: Download CLIP Model
58
+ Download the CLIP model for the text encoder:
59
+ ```bash
60
+ huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./text_encoder_2
61
+ ```
62
+
63
+ ---
64
+
65
+ ## Demo
66
+
67
+ ### Generate Video 1: Mavuika
68
+ Run the following command to generate a video of Mavuika:
69
+ ```bash
70
+ python hv_generate_video.py \
71
+ --fp8 \
72
+ --video_size 544 960 \
73
+ --video_length 60 \
74
+ --infer_steps 30 \
75
+ --prompt "Mavuika, featuring long, wavy red hair with golden highlights and large, star-shaped earrings. Mavuika wears dark sunglasses, a black choker, and a black leather glove on their left hand. Their attire includes a black and gold armor-like top with intricate designs. The background is a gradient of soft white to light blue, emphasizing Mavuika's confident expression and stylish appearance." \
76
+ --save_path . \
77
+ --output_type both \
78
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
79
+ --attn_mode sdpa \
80
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
81
+ --vae_chunk_size 32 \
82
+ --vae_spatial_tile_sample_min_size 128 \
83
+ --text_encoder1 ckpts/text_encoder \
84
+ --text_encoder2 ckpts/text_encoder_2 \
85
+ --seed 1234 \
86
+ --lora_multiplier 1.0 \
87
+ --lora_weight Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors
88
+
89
+ ```
90
+
91
+
92
+ <video controls autoplay src="https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora/resolve/main/20250205-041232_1234.mp4"></video>
93
+
94
+
95
+ ### Generate Video 2: Mavuika Sun
96
+ Run the following command to generate a video of Mavuika in a warm sunset scene:
97
+ ```bash
98
+ python hv_generate_video.py \
99
+ --fp8 \
100
+ --video_size 544 960 \
101
+ --video_length 60 \
102
+ --infer_steps 30 \
103
+ --prompt "Fantastic artwork of Mavuika, featuring long, wavy red hair with golden highlights and large, star-shaped earrings. Mavuika wears dark sunglasses, a black choker, and a black leather glove on their left hand. Their attire includes a black and gold armor-like top with intricate designs, standing confidently in a warm sunset-lit rural village. The background transitions into the interior of a futuristic spaceship, blending the rustic and sci-fi elements seamlessly. The gradient of soft white to light blue in the sky enhances Mavuika's stylish and commanding presence." \
104
+ --save_path . \
105
+ --output_type both \
106
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
107
+ --attn_mode sdpa \
108
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
109
+ --vae_chunk_size 32 \
110
+ --vae_spatial_tile_sample_min_size 128 \
111
+ --text_encoder1 ckpts/text_encoder \
112
+ --text_encoder2 ckpts/text_encoder_2 \
113
+ --seed 1234 \
114
+ --lora_multiplier 1.0 \
115
+ --lora_weight Mavuika_im_lora_dir/Mavuika_single_im_lora-000035.safetensors
116
+ ```
117
+
118
+
119
+ <video controls autoplay src="https://huggingface.co/svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora/resolve/main/20250205-043500_1234.mp4"></video>
120
+
121
+
122
+ ---
123
+
124
+ ## Notes
125
+ - Ensure you have sufficient GPU resources for video generation.
126
+ - Adjust the `--video_size`, `--video_length`, and `--infer_steps` parameters as needed for different output qualities and lengths.
127
+ - The `--prompt` parameter can be modified to generate videos with different scenes or actions.
128
+
129
+ ---
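
The repository ships roughly thirty intermediate LoRA checkpoints (about 322 MB each, stored via Git LFS) next to the final `Mavuika_single_im_lora.safetensors`, so a full clone is heavy. A sketch of fetching only the final weight with the same `huggingface-cli` tool used for the model downloads above (assuming a reasonably recent `huggingface_hub`):

```bash
# Download only the final LoRA weight instead of cloning the whole repository.
huggingface-cli download svjack/Genshin_Impact_Mavuika_HunyuanVideo_lora \
    Mavuika_im_lora_dir/Mavuika_single_im_lora.safetensors \
    --local-dir .
```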
cache_latents.py ADDED
@@ -0,0 +1,245 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ from PIL import Image
12
+
13
+ import logging
14
+
15
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache
16
+ from hunyuan_model.vae import load_vae
17
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
18
+ from utils.model_utils import str_to_dtype
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
+ def show_image(image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray]) -> int:
25
+ import cv2
26
+
27
+ imgs = (
28
+ [image]
29
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
30
+ else [image[0], image[-1]]
31
+ )
32
+ if len(imgs) > 1:
33
+ print(f"Number of images: {len(image)}")
34
+ for i, img in enumerate(imgs):
35
+ if len(imgs) > 1:
36
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
37
+ else:
38
+ print(f"Image: {img.shape}")
39
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
40
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
41
+ cv2.imshow("image", cv2_img)
42
+ k = cv2.waitKey(0)
43
+ cv2.destroyAllWindows()
44
+ if k == ord("q") or k == ord("d"):
45
+ return k
46
+ return k
47
+
48
+
49
+ def show_console(
50
+ image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray],
51
+ width: int,
52
+ back: str,
53
+ interactive: bool = False,
54
+ ) -> int:
55
+ from ascii_magic import from_pillow_image, Back
56
+
57
+ back = None
58
+ if back is not None:
59
+ back = getattr(Back, back.upper())
60
+
61
+ k = None
62
+ imgs = (
63
+ [image]
64
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
65
+ else [image[0], image[-1]]
66
+ )
67
+ if len(imgs) > 1:
68
+ print(f"Number of images: {len(image)}")
69
+ for i, img in enumerate(imgs):
70
+ if len(imgs) > 1:
71
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
72
+ else:
73
+ print(f"Image: {img.shape}")
74
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
75
+ ascii_img = from_pillow_image(pil_img)
76
+ ascii_img.to_terminal(columns=width, back=back)
77
+
78
+ if interactive:
79
+ k = input("Press q to quit, d to next dataset, other key to next: ")
80
+ if k == "q" or k == "d":
81
+ return ord(k)
82
+
83
+ if not interactive:
84
+ return ord(" ")
85
+ return ord(k) if k else ord(" ")
86
+
87
+
88
+ def show_datasets(
89
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
90
+ ):
91
+ print(f"d: next dataset, q: quit")
92
+
93
+ num_workers = max(1, os.cpu_count() - 1)
94
+ for i, dataset in enumerate(datasets):
95
+ print(f"Dataset [{i}]")
96
+ batch_index = 0
97
+ num_images_to_show = console_num_images
98
+ k = None
99
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
100
+ print(f"bucket resolution: {key}, count: {len(batch)}")
101
+ for j, item_info in enumerate(batch):
102
+ item_info: ItemInfo
103
+ print(f"{batch_index}-{j}: {item_info}")
104
+ if debug_mode == "image":
105
+ k = show_image(item_info.content)
106
+ elif debug_mode == "console":
107
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
108
+ if num_images_to_show is not None:
109
+ num_images_to_show -= 1
110
+ if num_images_to_show == 0:
111
+ k = ord("d") # next dataset
112
+
113
+ if k == ord("q"):
114
+ return
115
+ elif k == ord("d"):
116
+ break
117
+ if k == ord("d"):
118
+ break
119
+ batch_index += 1
120
+
121
+
122
+ def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
123
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
124
+ if len(contents.shape) == 4:
125
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
126
+
127
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
128
+ contents = contents.to(vae.device, dtype=vae.dtype)
129
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
130
+
131
+ # print(f"encode batch: {contents.shape}")
132
+ with torch.no_grad():
133
+ latent = vae.encode(contents).latent_dist.sample()
134
+ latent = latent * vae.config.scaling_factor
135
+
136
+ # # debug: decode and save
137
+ # with torch.no_grad():
138
+ # latent_to_decode = latent / vae.config.scaling_factor
139
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
140
+ # images = (images / 2 + 0.5).clamp(0, 1)
141
+ # images = images.cpu().float().numpy()
142
+ # images = (images * 255).astype(np.uint8)
143
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
144
+ # for b in range(images.shape[0]):
145
+ # for f in range(images.shape[1]):
146
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
147
+ # img = Image.fromarray(images[b, f])
148
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
149
+
150
+ for item, l in zip(batch, latent):
151
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
152
+ save_latent_cache(item, l)
153
+
154
+
155
+ def main(args):
156
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
157
+ device = torch.device(device)
158
+
159
+ # Load dataset config
160
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
161
+ logger.info(f"Load dataset config from {args.dataset_config}")
162
+ user_config = config_utils.load_user_config(args.dataset_config)
163
+ blueprint = blueprint_generator.generate(user_config, args)
164
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
165
+
166
+ datasets = train_dataset_group.datasets
167
+
168
+ if args.debug_mode is not None:
169
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
170
+ return
171
+
172
+ assert args.vae is not None, "vae checkpoint is required"
173
+
174
+ # Load VAE model: HunyuanVideo VAE model is float16
175
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
176
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
177
+ vae.eval()
178
+ print(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
179
+
180
+ if args.vae_chunk_size is not None:
181
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
182
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
183
+ if args.vae_spatial_tile_sample_min_size is not None:
184
+ vae.enable_spatial_tiling(True)
185
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
186
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
187
+ elif args.vae_tiling:
188
+ vae.enable_spatial_tiling(True)
189
+
190
+ # Encode images
191
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
192
+ for i, dataset in enumerate(datasets):
193
+ print(f"Encoding dataset [{i}]")
194
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
195
+ if args.skip_existing:
196
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
197
+ if len(filtered_batch) == 0:
198
+ continue
199
+ batch = filtered_batch
200
+
201
+ bs = args.batch_size if args.batch_size is not None else len(batch)
202
+ for i in range(0, len(batch), bs):
203
+ encode_and_save_batch(vae, batch[i : i + bs])
204
+
205
+
206
+ def setup_parser():
207
+ parser = argparse.ArgumentParser()
208
+
209
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
210
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
211
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
212
+ parser.add_argument(
213
+ "--vae_tiling",
214
+ action="store_true",
215
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
216
+ )
217
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
218
+ parser.add_argument(
219
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
220
+ )
221
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
222
+ parser.add_argument(
223
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
224
+ )
225
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
226
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
227
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
228
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
229
+ parser.add_argument(
230
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
231
+ )
232
+ parser.add_argument(
233
+ "--console_num_images",
234
+ type=int,
235
+ default=None,
236
+ help="debug mode: not interactive, number of images to show for each dataset",
237
+ )
238
+ return parser
239
+
240
+
241
+ if __name__ == "__main__":
242
+ parser = setup_parser()
243
+
244
+ args = parser.parse_args()
245
+ main(args)
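
`cache_latents.py` pre-encodes the training images/videos into VAE latents so that LoRA training does not have to run the VAE on the fly. A minimal invocation sketch, assuming a dataset config at `dataset_config.toml` (a placeholder path; see `dataset/dataset_config.md`) and the VAE checkpoint laid out as in the README; every flag below corresponds to an option defined in the script's argparse:

```bash
# Pre-compute VAE latent caches for the dataset described in dataset_config.toml.
python cache_latents.py \
    --dataset_config dataset_config.toml \
    --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
    --vae_chunk_size 32 \
    --vae_spatial_tile_sample_min_size 128 \
    --skip_existing
```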
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
26
+ data_type = "video" # video only, image is not supported
27
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
28
+
29
+ with torch.no_grad():
30
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
31
+
32
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
33
+
34
+
35
+ def encode_and_save_batch(
36
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
37
+ ):
38
+ prompts = [item.caption for item in batch]
39
+ # print(prompts)
40
+
41
+ # encode prompt
42
+ if accelerator is not None:
43
+ with accelerator.autocast():
44
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
45
+ else:
46
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
47
+
48
+ # # convert to fp16 if needed
49
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
50
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
51
+
52
+ # save prompt cache
53
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
54
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
+ def main(args):
58
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
59
+ device = torch.device(device)
60
+
61
+ # Load dataset config
62
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
63
+ logger.info(f"Load dataset config from {args.dataset_config}")
64
+ user_config = config_utils.load_user_config(args.dataset_config)
65
+ blueprint = blueprint_generator.generate(user_config, args)
66
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
67
+
68
+ datasets = train_dataset_group.datasets
69
+
70
+ # define accelerator for fp8 inference
71
+ accelerator = None
72
+ if args.fp8_llm:
73
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
74
+
75
+ # define encode function
76
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
77
+
78
+ def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):
79
+ for i, dataset in enumerate(datasets):
80
+ print(f"Encoding dataset [{i}]")
81
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
82
+ if args.skip_existing:
83
+ filtered_batch = [item for item in batch if not os.path.exists(item.text_encoder_output_cache_path)]
84
+ if len(filtered_batch) == 0:
85
+ continue
86
+ batch = filtered_batch
87
+
88
+ bs = args.batch_size if args.batch_size is not None else len(batch)
89
+ for i in range(0, len(batch), bs):
90
+ encode_and_save_batch(text_encoder, batch[i : i + bs], is_llm, accelerator)
91
+
92
+ # Load Text Encoder 1
93
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
94
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
95
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
96
+ text_encoder_1.to(device=device)
97
+
98
+ # Encode with Text Encoder 1
99
+ logger.info("Encoding with Text Encoder 1")
100
+ encode_for_text_encoder(text_encoder_1, is_llm=True)
101
+ del text_encoder_1
102
+
103
+ # Load Text Encoder 2
104
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
105
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
106
+ text_encoder_2.to(device=device)
107
+
108
+ # Encode with Text Encoder 2
109
+ logger.info("Encoding with Text Encoder 2")
110
+ encode_for_text_encoder(text_encoder_2, is_llm=False)
111
+ del text_encoder_2
112
+
113
+
114
+ def setup_parser():
115
+ parser = argparse.ArgumentParser()
116
+
117
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
118
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
119
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
120
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
121
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
122
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
123
+ parser.add_argument(
124
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
125
+ )
126
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
127
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
128
+ return parser
129
+
130
+
131
+ if __name__ == "__main__":
132
+ parser = setup_parser()
133
+
134
+ args = parser.parse_args()
135
+ main(args)
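
Similarly, `cache_text_encoder_outputs.py` caches the caption embeddings from both text encoders ahead of training. A sketch using the same placeholder dataset config and the text encoder directories prepared in the README; the batch size is illustrative, and `--fp8_llm` could be added to run Text Encoder 1 in fp8:

```bash
# Cache text-encoder outputs for every caption in the dataset.
python cache_text_encoder_outputs.py \
    --dataset_config dataset_config.toml \
    --text_encoder1 ckpts/text_encoder \
    --text_encoder2 ckpts/text_encoder_2 \
    --batch_size 16
```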
convert_lora.py ADDED
@@ -0,0 +1,129 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ def convert_from_diffusers(prefix, weights_sd):
16
+ # convert from diffusers(?) to default LoRA
17
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
18
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
19
+ # note: Diffusers has no alpha, so alpha is set to rank
20
+ new_weights_sd = {}
21
+ lora_dims = {}
22
+ for key, weight in weights_sd.items():
23
+ diffusers_prefix, key_body = key.split(".", 1)
24
+ if diffusers_prefix != "diffusion_model":
25
+ logger.warning(f"unexpected key: {key} in diffusers format")
26
+ continue
27
+
28
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
29
+ new_weights_sd[new_key] = weight
30
+
31
+ lora_name = new_key.split(".")[0] # before first dot
32
+ if lora_name not in lora_dims and "lora_down" in new_key:
33
+ lora_dims[lora_name] = weight.shape[0]
34
+
35
+ # add alpha with rank
36
+ for lora_name, dim in lora_dims.items():
37
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
38
+
39
+ return new_weights_sd
40
+
41
+
42
+ def convert_to_diffusers(prefix, weights_sd):
43
+ # convert from default LoRA to diffusers
44
+
45
+ # get alphas
46
+ lora_alphas = {}
47
+ for key, weight in weights_sd.items():
48
+ if key.startswith(prefix):
49
+ lora_name = key.split(".", 1)[0] # before first dot
50
+ if lora_name not in lora_alphas and "alpha" in key:
51
+ lora_alphas[lora_name] = weight
52
+
53
+ new_weights_sd = {}
54
+ for key, weight in weights_sd.items():
55
+ if key.startswith(prefix):
56
+ if "alpha" in key:
57
+ continue
58
+
59
+ lora_name = key.split(".", 1)[0] # before first dot
60
+
61
+ # HunyuanVideo lora name to module name: ugly but works
62
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
63
+ module_name = module_name.replace("_", ".") # replace "_" with "."
64
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
65
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
66
+ module_name = module_name.replace("img.", "img_") # fix img
67
+ module_name = module_name.replace("txt.", "txt_") # fix txt
68
+ module_name = module_name.replace("attn.", "attn_") # fix attn
69
+
70
+ diffusers_prefix = "diffusion_model"
71
+ if "lora_down" in key:
72
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
73
+ dim = weight.shape[0]
74
+ elif "lora_up" in key:
75
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
76
+ dim = weight.shape[1]
77
+ else:
78
+ logger.warning(f"unexpected key: {key} in default LoRA format")
79
+ continue
80
+
81
+ # scale weight by alpha
82
+ if lora_name in lora_alphas:
83
+ # we scale both down and up, so scale is sqrt
84
+ scale = lora_alphas[lora_name] / dim
85
+ scale = scale.sqrt()
86
+ weight = weight * scale
87
+ else:
88
+ logger.warning(f"missing alpha for {lora_name}")
89
+
90
+ new_weights_sd[new_key] = weight
91
+
92
+ return new_weights_sd
93
+
94
+
95
+ def convert(input_file, output_file, target_format):
96
+ logger.info(f"loading {input_file}")
97
+ weights_sd = load_file(input_file)
98
+ with safe_open(input_file, framework="pt") as f:
99
+ metadata = f.metadata()
100
+
101
+ logger.info(f"converting to {target_format}")
102
+ prefix = "lora_unet_"
103
+ if target_format == "default":
104
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
105
+ metadata = metadata or {}
106
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
107
+ elif target_format == "other":
108
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
109
+ else:
110
+ raise ValueError(f"unknown target format: {target_format}")
111
+
112
+ logger.info(f"saving to {output_file}")
113
+ save_file(new_weights_sd, output_file, metadata=metadata)
114
+
115
+ logger.info("done")
116
+
117
+
118
+ def parse_args():
119
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
120
+ parser.add_argument("--input", type=str, required=True, help="input model file")
121
+ parser.add_argument("--output", type=str, required=True, help="output model file")
122
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ if __name__ == "__main__":
128
+ args = parse_args()
129
+ convert(args.input, args.output, args.target)
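
`convert_lora.py` translates between the default LoRA layout (`lora_unet_*` keys with `lora_down`/`lora_up` weights and per-module `alpha`) and the Diffusers-style layout (`diffusion_model.*.lora_A`/`lora_B`, with alpha folded into the weights). A sketch converting the bundled LoRA to the Diffusers-style format; the output filename is arbitrary:

```bash
# Convert the bundled LoRA from the default format to the Diffusers-style layout.
python convert_lora.py \
    --input Mavuika_im_lora_dir/Mavuika_single_im_lora.safetensors \
    --output Mavuika_single_im_lora_diffusers.safetensors \
    --target other
```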
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ @dataclass
28
+ class BaseDatasetParams:
29
+ resolution: Tuple[int, int] = (960, 544)
30
+ enable_bucket: bool = False
31
+ bucket_no_upscale: bool = False
32
+ caption_extension: Optional[str] = None
33
+ batch_size: int = 1
34
+ cache_directory: Optional[str] = None
35
+ debug_dataset: bool = False
36
+
37
+
38
+ @dataclass
39
+ class ImageDatasetParams(BaseDatasetParams):
40
+ image_directory: Optional[str] = None
41
+ image_jsonl_file: Optional[str] = None
42
+
43
+
44
+ @dataclass
45
+ class VideoDatasetParams(BaseDatasetParams):
46
+ video_directory: Optional[str] = None
47
+ video_jsonl_file: Optional[str] = None
48
+ target_frames: Sequence[int] = (1,)
49
+ frame_extraction: Optional[str] = "head"
50
+ frame_stride: Optional[int] = 1
51
+ frame_sample: Optional[int] = 1
52
+
53
+
54
+ @dataclass
55
+ class DatasetBlueprint:
56
+ is_image_dataset: bool
57
+ params: Union[ImageDatasetParams, VideoDatasetParams]
58
+
59
+
60
+ @dataclass
61
+ class DatasetGroupBlueprint:
62
+ datasets: Sequence[DatasetBlueprint]
63
+
64
+
65
+ @dataclass
66
+ class Blueprint:
67
+ dataset_group: DatasetGroupBlueprint
68
+
69
+
70
+ class ConfigSanitizer:
71
+ # @curry
72
+ @staticmethod
73
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
74
+ Schema(ExactSequence([klass, klass]))(value)
75
+ return tuple(value)
76
+
77
+ # @curry
78
+ @staticmethod
79
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
80
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
81
+ try:
82
+ Schema(klass)(value)
83
+ return (value, value)
84
+ except:
85
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
86
+
87
+ # datasets schema
88
+ DATASET_ASCENDABLE_SCHEMA = {
89
+ "caption_extension": str,
90
+ "batch_size": int,
91
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
92
+ "enable_bucket": bool,
93
+ "bucket_no_upscale": bool,
94
+ }
95
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
96
+ "image_directory": str,
97
+ "image_jsonl_file": str,
98
+ "cache_directory": str,
99
+ }
100
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
101
+ "video_directory": str,
102
+ "video_jsonl_file": str,
103
+ "target_frames": [int],
104
+ "frame_extraction": str,
105
+ "frame_stride": int,
106
+ "frame_sample": int,
107
+ "cache_directory": str,
108
+ }
109
+
110
+ # options handled by argparse but not handled by user config
111
+ ARGPARSE_SPECIFIC_SCHEMA = {
112
+ "debug_dataset": bool,
113
+ }
114
+
115
+ def __init__(self) -> None:
116
+ self.image_dataset_schema = self.__merge_dict(
117
+ self.DATASET_ASCENDABLE_SCHEMA,
118
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
119
+ )
120
+ self.video_dataset_schema = self.__merge_dict(
121
+ self.DATASET_ASCENDABLE_SCHEMA,
122
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
123
+ )
124
+
125
+ def validate_flex_dataset(dataset_config: dict):
126
+ if "target_frames" in dataset_config:
127
+ return Schema(self.video_dataset_schema)(dataset_config)
128
+ else:
129
+ return Schema(self.image_dataset_schema)(dataset_config)
130
+
131
+ self.dataset_schema = validate_flex_dataset
132
+
133
+ self.general_schema = self.__merge_dict(
134
+ self.DATASET_ASCENDABLE_SCHEMA,
135
+ )
136
+ self.user_config_validator = Schema(
137
+ {
138
+ "general": self.general_schema,
139
+ "datasets": [self.dataset_schema],
140
+ }
141
+ )
142
+ self.argparse_schema = self.__merge_dict(
143
+ self.ARGPARSE_SPECIFIC_SCHEMA,
144
+ )
145
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
146
+
147
+ def sanitize_user_config(self, user_config: dict) -> dict:
148
+ try:
149
+ return self.user_config_validator(user_config)
150
+ except MultipleInvalid:
151
+ # TODO: clarify the error message
152
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
153
+ raise
154
+
155
+ # NOTE: The argparse result does not strictly need to be sanitized,
156
+ # but doing so helps to detect program bugs.
157
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
158
+ try:
159
+ return self.argparse_config_validator(argparse_namespace)
160
+ except MultipleInvalid:
161
+ # XXX: this should be a bug
162
+ logger.error(
163
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
164
+ )
165
+ raise
166
+
167
+ # NOTE: a value is overwritten by a later dict if the same key already exists
168
+ @staticmethod
169
+ def __merge_dict(*dict_list: dict) -> dict:
170
+ merged = {}
171
+ for schema in dict_list:
172
+ # merged |= schema
173
+ for k, v in schema.items():
174
+ merged[k] = v
175
+ return merged
176
+
177
+
178
+ class BlueprintGenerator:
179
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
180
+
181
+ def __init__(self, sanitizer: ConfigSanitizer):
182
+ self.sanitizer = sanitizer
183
+
184
+ # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
185
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
186
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
187
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
188
+
189
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
190
+ general_config = sanitized_user_config.get("general", {})
191
+
192
+ dataset_blueprints = []
193
+ for dataset_config in sanitized_user_config.get("datasets", []):
194
+ is_image_dataset = "target_frames" not in dataset_config
195
+ if is_image_dataset:
196
+ dataset_params_klass = ImageDatasetParams
197
+ else:
198
+ dataset_params_klass = VideoDatasetParams
199
+
200
+ params = self.generate_params_by_fallbacks(
201
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
202
+ )
203
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
204
+
205
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
206
+
207
+ return Blueprint(dataset_group_blueprint)
208
+
209
+ @staticmethod
210
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
211
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
212
+ search_value = BlueprintGenerator.search_value
213
+ default_params = asdict(param_klass())
214
+ param_names = default_params.keys()
215
+
216
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
217
+
218
+ return param_klass(**params)
219
+
220
+ @staticmethod
221
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
222
+ for cand in fallbacks:
223
+ value = cand.get(key)
224
+ if value is not None:
225
+ return value
226
+
227
+ return default_value
228
+
229
+
230
+ # if training is True, it will return a dataset group for training, otherwise for caching
231
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
232
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
233
+
234
+ for dataset_blueprint in dataset_group_blueprint.datasets:
235
+ if dataset_blueprint.is_image_dataset:
236
+ dataset_klass = ImageDataset
237
+ else:
238
+ dataset_klass = VideoDataset
239
+
240
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
241
+ datasets.append(dataset)
242
+
243
+ # print info
244
+ info = ""
245
+ for i, dataset in enumerate(datasets):
246
+ is_image_dataset = isinstance(dataset, ImageDataset)
247
+ info += dedent(
248
+ f"""\
249
+ [Dataset {i}]
250
+ is_image_dataset: {is_image_dataset}
251
+ resolution: {dataset.resolution}
252
+ batch_size: {dataset.batch_size}
253
+ caption_extension: "{dataset.caption_extension}"
254
+ enable_bucket: {dataset.enable_bucket}
255
+ bucket_no_upscale: {dataset.bucket_no_upscale}
256
+ cache_directory: "{dataset.cache_directory}"
257
+ debug_dataset: {dataset.debug_dataset}
258
+ """
259
+ )
260
+
261
+ if is_image_dataset:
262
+ info += indent(
263
+ dedent(
264
+ f"""\
265
+ image_directory: "{dataset.image_directory}"
266
+ image_jsonl_file: "{dataset.image_jsonl_file}"
267
+ \n"""
268
+ ),
269
+ " ",
270
+ )
271
+ else:
272
+ info += indent(
273
+ dedent(
274
+ f"""\
275
+ video_directory: "{dataset.video_directory}"
276
+ video_jsonl_file: "{dataset.video_jsonl_file}"
277
+ target_frames: {dataset.target_frames}
278
+ frame_extraction: {dataset.frame_extraction}
279
+ frame_stride: {dataset.frame_stride}
280
+ frame_sample: {dataset.frame_sample}
281
+ \n"""
282
+ ),
283
+ " ",
284
+ )
285
+ logger.info(f"{info}")
286
+
287
+ # make buckets first because it determines the length of dataset
288
+ # and set the same seed for all datasets
289
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
290
+ for i, dataset in enumerate(datasets):
291
+ # logger.info(f"[Dataset {i}]")
292
+ dataset.set_seed(seed)
293
+ if training:
294
+ dataset.prepare_for_training()
295
+
296
+ return DatasetGroup(datasets)
297
+
298
+
299
+ def load_user_config(file: str) -> dict:
300
+ file: Path = Path(file)
301
+ if not file.is_file():
302
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
303
+
304
+ if file.name.lower().endswith(".json"):
305
+ try:
306
+ with open(file, "r") as f:
307
+ config = json.load(f)
308
+ except Exception:
309
+ logger.error(
310
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
311
+ )
312
+ raise
313
+ elif file.name.lower().endswith(".toml"):
314
+ try:
315
+ config = toml.load(file)
316
+ except Exception:
317
+ logger.error(
318
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
319
+ )
320
+ raise
321
+ else:
322
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
323
+
324
+ return config
325
+
326
+
327
+ # for config test
328
+ if __name__ == "__main__":
329
+ parser = argparse.ArgumentParser()
330
+ parser.add_argument("dataset_config")
331
+ config_args, remain = parser.parse_known_args()
332
+
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--debug_dataset", action="store_true")
335
+ argparse_namespace = parser.parse_args(remain)
336
+
337
+ logger.info("[argparse_namespace]")
338
+ logger.info(f"{vars(argparse_namespace)}")
339
+
340
+ user_config = load_user_config(config_args.dataset_config)
341
+
342
+ logger.info("")
343
+ logger.info("[user_config]")
344
+ logger.info(f"{user_config}")
345
+
346
+ sanitizer = ConfigSanitizer()
347
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
348
+
349
+ logger.info("")
350
+ logger.info("[sanitized_user_config]")
351
+ logger.info(f"{sanitized_user_config}")
352
+
353
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
354
+
355
+ logger.info("")
356
+ logger.info("[blueprint]")
357
+ logger.info(f"{blueprint}")
358
+
359
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
dataset/dataset_config.md ADDED
@@ -0,0 +1,293 @@
1
+ ## Dataset Configuration
2
+
3
+ Please create a TOML file for dataset configuration.
4
+
5
+ Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
6
+
7
+ ### Sample for Image Dataset with Caption Text Files
8
+
9
+ ```toml
10
+ # resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
11
+
12
+ # general configurations
13
+ [general]
14
+ resolution = [960, 544]
15
+ caption_extension = ".txt"
16
+ batch_size = 1
17
+ enable_bucket = true
18
+ bucket_no_upscale = false
19
+
20
+ [[datasets]]
21
+ image_directory = "/path/to/image_dir"
22
+
23
+ # other datasets can be added here. each dataset can have different configurations
24
+ ```
25
+
26
+ ### Sample for Image Dataset with Metadata JSONL File
27
+
28
+ ```toml
29
+ # resolution, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
30
+ # caption_extension is not required for metadata jsonl file
31
+ # cache_directory is required for each dataset with metadata jsonl file
32
+
33
+ # general configurations
34
+ [general]
35
+ resolution = [960, 544]
36
+ batch_size = 1
37
+ enable_bucket = true
38
+ bucket_no_upscale = false
39
+
40
+ [[datasets]]
41
+ image_jsonl_file = "/path/to/metadata.jsonl"
42
+ cache_directory = "/path/to/cache_directory"
43
+
44
+ # other datasets can be added here. each dataset can have different configurations
45
+ ```
46
+
47
+ JSONL file format for metadata:
48
+
49
+ ```json
50
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
51
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
52
+ ```
53
+
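+ The metadata file is plain JSON Lines (one JSON object per line). As a minimal sketch, not part of the repository, a file in this format could be generated like this (the helper name `write_image_metadata_jsonl` is hypothetical):
+
+ ```python
+ import json
+
+ def write_image_metadata_jsonl(pairs, path):
+     """pairs: iterable of (image_path, caption) tuples."""
+     with open(path, "w", encoding="utf-8") as f:
+         for image_path, caption in pairs:
+             f.write(json.dumps({"image_path": image_path, "caption": caption}, ensure_ascii=False) + "\n")
+
+ write_image_metadata_jsonl([("/path/to/image1.jpg", "A caption for image1")], "/path/to/metadata.jsonl")
+ ```
+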
54
+ ### Sample for Video Dataset with Caption Text Files
55
+
56
+ ```toml
57
+ # resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
58
+
59
+ # general configurations
60
+ [general]
61
+ resolution = [960, 544]
62
+ caption_extension = ".txt"
63
+ batch_size = 1
64
+ enable_bucket = true
65
+ bucket_no_upscale = false
66
+
67
+ [[datasets]]
68
+ video_directory = "/path/to/video_dir"
69
+ target_frames = [1, 25, 45]
70
+ frame_extraction = "head"
71
+
72
+ # other datasets can be added here. each dataset can have different configurations
73
+ ```
74
+
75
+ ### Sample for Video Dataset with Metadata JSONL File
76
+
77
+ ```toml
78
+ # resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
79
+ # caption_extension is not required for metadata jsonl file
80
+ # cache_directory is required for each dataset with metadata jsonl file
81
+
82
+ # general configurations
83
+ [general]
84
+ resolution = [960, 544]
85
+ batch_size = 1
86
+ enable_bucket = true
87
+ bucket_no_upscale = false
88
+
89
+ [[datasets]]
90
+ video_jsonl_file = "/path/to/metadata.jsonl"
91
+ target_frames = [1, 25, 45]
92
+ frame_extraction = "head"
93
+ cache_directory = "/path/to/cache_directory"
94
+
95
+ # same metadata jsonl file can be used for multiple datasets
96
+ [[datasets]]
97
+ video_jsonl_file = "/path/to/metadata.jsonl"
98
+ target_frames = [1]
99
+ frame_stride = 10
100
+ cache_directory = "/path/to/cache_directory"
101
+
102
+ # other datasets can be added here. each dataset can have different configurations
103
+ ```
104
+
105
+ JSONL file format for metadata:
106
+
107
+ ```json
108
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
109
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
110
+ ```
111
+
112
+ ### frame_extraction Options
113
+
114
+ - `head`: Extract the first N frames from the video.
115
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
116
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
117
+ - `uniform`: Extract `frame_sample` samples uniformly from the video.
118
+
119
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
120
+
121
+ ```
122
+ Original video, 40 frames (x = extracted frame, o = frame not extracted):
123
+ oooooooooooooooooooooooooooooooooooooooo
124
+
125
+ head, target_frames = [1, 13, 25] -> extract head frames:
126
+ xooooooooooooooooooooooooooooooooooooooo
127
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
128
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
129
+
130
+ chunk, target_frames = [13, 25] -> extract frames by splitting the video into chunks of 13 and 25 frames:
131
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
132
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
133
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
134
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
135
+
136
+ NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk"; it would cause every frame to be extracted as one-frame chunks.
137
+
138
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
139
+ xooooooooooooooooooooooooooooooooooooooo
140
+ ooooooooooxooooooooooooooooooooooooooooo
141
+ ooooooooooooooooooooxooooooooooooooooooo
142
+ ooooooooooooooooooooooooooooooxooooooooo
143
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
144
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
145
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
146
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
147
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
148
+
149
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
150
+ xooooooooooooooooooooooooooooooooooooooo
151
+ oooooooooooooxoooooooooooooooooooooooooo
152
+ oooooooooooooooooooooooooxoooooooooooooo
153
+ ooooooooooooooooooooooooooooooooooooooox
154
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
155
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
156
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
157
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
158
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
159
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
160
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
161
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
162
+ ```
163
+
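+ The following sketch (Python, not part of the repository; the helper name `enumerate_windows` is hypothetical) shows how each mode maps a video of `frame_count` frames to (start, length) windows. It mirrors the extraction logic in `dataset/image_video_dataset.py`:
+
+ ```python
+ import numpy as np
+
+ def enumerate_windows(frame_count, target_frames, mode, frame_stride=1, frame_sample=1):
+     windows = []  # list of (start_index, num_frames)
+     for n in target_frames:
+         if frame_count < n:
+             continue  # video too short for this target length
+         if mode == "head":
+             windows.append((0, n))
+         elif mode == "chunk":
+             # consecutive non-overlapping chunks of n frames
+             for i in range(0, frame_count - n + 1, n):
+                 windows.append((i, n))
+         elif mode == "slide":
+             for i in range(0, frame_count - n + 1, frame_stride):
+                 windows.append((i, n))
+         elif mode == "uniform":
+             for i in np.linspace(0, frame_count - n, frame_sample, dtype=int):
+                 windows.append((int(i), n))
+         else:
+             raise ValueError(f"unknown frame_extraction: {mode}")
+     return windows
+
+ # reproduces the "uniform" diagram above for a 40-frame video
+ print(enumerate_windows(40, [1, 13, 25], "uniform", frame_sample=4))
+ ```
+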
164
+ ## Specifications
165
+
166
+ ```toml
167
+ # general configurations
168
+ [general]
169
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
170
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
171
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
172
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
173
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
174
+
175
+ ### Image Dataset
176
+
177
+ # sample image dataset with caption text files
178
+ [[datasets]]
179
+ image_directory = "/path/to/image_dir"
180
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
181
+ resolution = [960, 544] # required if general resolution is not set
182
+ batch_size = 4 # optional, overwrite the default batch size
183
+ enable_bucket = false # optional, overwrite the default bucketing setting
184
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
185
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
186
+
187
+ # sample image dataset with metadata **jsonl** file
188
+ [[datasets]]
189
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
190
+ resolution = [960, 544] # required if general resolution is not set
191
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
192
+ # caption_extension is not required for metadata jsonl file
193
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
194
+
195
+ ### Video Dataset
196
+
197
+ # sample video dataset with caption text files
198
+ [[datasets]]
199
+ video_directory = "/path/to/video_dir"
200
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
201
+ resolution = [960, 544] # required if general resolution is not set
202
+
203
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
204
+
205
+ # NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk"; it would cause every frame to be extracted as one-frame chunks.
206
+
207
+ frame_extraction = "head" # optional, one of "head", "chunk", "slide", "uniform". Default is "head"
208
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
209
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
210
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
211
+
212
+ # sample video dataset with metadata jsonl file
213
+ [[datasets]]
214
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
215
+
216
+ target_frames = [1, 79]
217
+
218
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
219
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
220
+ ```
221
+
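+ For a quick sanity check of a configuration file, the trainer simply loads it with the standard `toml` package and treats any dataset entry without `target_frames` as an image dataset (see `load_user_config` and `BlueprintGenerator` in `dataset/config_utils.py`). A minimal sketch, with a placeholder path:
+
+ ```python
+ import toml
+
+ config = toml.load("/path/to/dataset_config.toml")  # placeholder path
+ for i, ds in enumerate(config.get("datasets", [])):
+     kind = "video" if "target_frames" in ds else "image"
+     print(f"[dataset {i}] type={kind}, keys={sorted(ds.keys())}")
+ ```
+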
222
+ <!--
223
+ # sample image dataset with lance
224
+ [[datasets]]
225
+ image_lance_dataset = "/path/to/lance_dataset"
226
+ resolution = [960, 544] # required if general resolution is not set
227
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
228
+ -->
229
+
230
+ Metadata in .json format will be supported in the near future.
231
+
232
+
233
+
234
+ <!--
235
+
236
+ ```toml
237
+ # general configurations
238
+ [general]
239
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
240
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
241
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
242
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
243
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
244
+
245
+ # sample image dataset with caption text files
246
+ [[datasets]]
247
+ image_directory = "/path/to/image_dir"
248
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
249
+ resolution = [960, 544] # required if general resolution is not set
250
+ batch_size = 4 # optional, overwrite the default batch size
251
+ enable_bucket = false # optional, overwrite the default bucketing setting
252
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
253
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
254
+
255
+ # sample image dataset with metadata **jsonl** file
256
+ [[datasets]]
257
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
258
+ resolution = [960, 544] # required if general resolution is not set
259
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
260
+ # caption_extension is not required for metadata jsonl file
261
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
262
+
263
+ # sample video dataset with caption text files
264
+ [[datasets]]
265
+ video_directory = "/path/to/video_dir"
266
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
267
+ resolution = [960, 544] # required if general resolution is not set
268
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
269
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
270
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
271
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
272
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
273
+
274
+ # sample video dataset with metadata jsonl file
275
+ [[datasets]]
276
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
277
+ target_frames = [1, 79]
278
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
279
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
280
+ ```
281
+
282
+ # sample image dataset with lance
283
+ [[datasets]]
284
+ image_lance_dataset = "/path/to/lance_dataset"
285
+ resolution = [960, 544] # required if general resolution is not set
286
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
287
+
288
+ The metadata with .json file will be supported in the near future.
289
+
290
+
291
+
292
+
293
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1255 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
28
+
29
+ try:
30
+ import pillow_avif
31
+
32
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
33
+ except ImportError:
34
+ pass
35
+
36
+ # JPEG-XL on Linux
37
+ try:
38
+ from jxlpy import JXLImagePlugin
39
+
40
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
41
+ except ImportError:
42
+ pass
43
+
44
+ # JPEG-XL on Windows
45
+ try:
46
+ import pillow_jxl
47
+
48
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
49
+ except ImportError:
50
+ pass
51
+
52
+ VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"] # some of them are not tested
53
+
54
+ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
55
+
56
+
57
+ def glob_images(directory, base="*"):
58
+ img_paths = []
59
+ for ext in IMAGE_EXTENSIONS:
60
+ if base == "*":
61
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
62
+ else:
63
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
64
+ img_paths = list(set(img_paths)) # remove duplicates
65
+ img_paths.sort()
66
+ return img_paths
67
+
68
+
69
+ def glob_videos(directory, base="*"):
70
+ video_paths = []
71
+ for ext in VIDEO_EXTENSIONS:
72
+ if base == "*":
73
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
74
+ else:
75
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
76
+ video_paths = list(set(video_paths)) # remove duplicates
77
+ video_paths.sort()
78
+ return video_paths
79
+
80
+
81
+ def divisible_by(num: int, divisor: int) -> int:
82
+ return num - num % divisor
83
+
84
+
85
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
86
+ """
87
+ Resize the image to the bucket resolution.
88
+ """
89
+ is_pil_image = isinstance(image, Image.Image)
90
+ if is_pil_image:
91
+ image_width, image_height = image.size
92
+ else:
93
+ image_height, image_width = image.shape[:2]
94
+
95
+ if bucket_reso == (image_width, image_height):
96
+ return np.array(image) if is_pil_image else image
97
+
98
+ bucket_width, bucket_height = bucket_reso
99
+ if bucket_width == image_width or bucket_height == image_height:
100
+ image = np.array(image) if is_pil_image else image
101
+ else:
102
+ # resize the image to the bucket resolution to match the short side
103
+ scale_width = bucket_width / image_width
104
+ scale_height = bucket_height / image_height
105
+ scale = max(scale_width, scale_height)
106
+ image_width = int(image_width * scale + 0.5)
107
+ image_height = int(image_height * scale + 0.5)
108
+
109
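+ # upscale with PIL's LANCZOS filter, downscale with cv2.INTER_AREA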
+ if scale > 1:
110
+ image = Image.fromarray(image) if not is_pil_image else image
111
+ image = image.resize((image_width, image_height), Image.LANCZOS)
112
+ image = np.array(image)
113
+ else:
114
+ image = np.array(image) if is_pil_image else image
115
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
116
+
117
+ # crop the image to the bucket resolution
118
+ crop_left = (image_width - bucket_width) // 2
119
+ crop_top = (image_height - bucket_height) // 2
120
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
121
+ return image
122
+
123
+
124
+ class ItemInfo:
125
+ def __init__(
126
+ self,
127
+ item_key: str,
128
+ caption: str,
129
+ original_size: tuple[int, int],
130
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
131
+ frame_count: Optional[int] = None,
132
+ content: Optional[np.ndarray] = None,
133
+ latent_cache_path: Optional[str] = None,
134
+ ) -> None:
135
+ self.item_key = item_key
136
+ self.caption = caption
137
+ self.original_size = original_size
138
+ self.bucket_size = bucket_size
139
+ self.frame_count = frame_count
140
+ self.content = content
141
+ self.latent_cache_path = latent_cache_path
142
+ self.text_encoder_output_cache_path: Optional[str] = None
143
+
144
+ def __str__(self) -> str:
145
+ return (
146
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
147
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
148
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
149
+ )
150
+
151
+
152
+ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
153
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
154
+ metadata = {
155
+ "architecture": "hunyuan_video",
156
+ "width": f"{item_info.original_size[0]}",
157
+ "height": f"{item_info.original_size[1]}",
158
+ "format_version": "1.0.0",
159
+ }
160
+ if item_info.frame_count is not None:
161
+ metadata["frame_count"] = f"{item_info.frame_count}"
162
+
163
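+ # latent layout is (channel, frame, height, width); the tensor key encodes FxHxW and the dtype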
+ _, F, H, W = latent.shape
164
+ dtype_str = dtype_to_str(latent.dtype)
165
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
166
+
167
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
168
+ os.makedirs(latent_dir, exist_ok=True)
169
+
170
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
171
+
172
+
173
+ def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
174
+ assert (
175
+ embed.dim() == 1 or embed.dim() == 2
176
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
177
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
178
+ metadata = {
179
+ "architecture": "hunyuan_video",
180
+ "caption1": item_info.caption,
181
+ "format_version": "1.0.0",
182
+ }
183
+
184
+ sd = {}
185
+ if os.path.exists(item_info.text_encoder_output_cache_path):
186
+ # load existing cache and update metadata
187
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
188
+ existing_metadata = f.metadata()
189
+ for key in f.keys():
190
+ sd[key] = f.get_tensor(key)
191
+
192
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
193
+ if existing_metadata["caption1"] != metadata["caption1"]:
194
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
195
+ # TODO verify format_version
196
+
197
+ existing_metadata.pop("caption1", None)
198
+ existing_metadata.pop("format_version", None)
199
+ metadata.update(existing_metadata) # copy existing metadata
200
+ else:
201
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
202
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
203
+
204
+ dtype_str = dtype_to_str(embed.dtype)
205
+ text_encoder_type = "llm" if is_llm else "clipL"
206
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
207
+ if mask is not None:
208
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
209
+
210
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
211
+
212
+
213
+ class BucketSelector:
214
+ RESOLUTION_STEPS_HUNYUAN = 16
215
+
216
+ def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
217
+ self.resolution = resolution
218
+ self.bucket_area = resolution[0] * resolution[1]
219
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
220
+
221
+ if not enable_bucket:
222
+ # only define one bucket
223
+ self.bucket_resolutions = [resolution]
224
+ self.no_upscale = False
225
+ else:
226
+ # prepare bucket resolution
227
+ self.no_upscale = no_upscale
228
+ sqrt_size = int(math.sqrt(self.bucket_area))
229
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
230
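+ # enumerate bucket resolutions of roughly constant area, stepping the width by reso_steps and also adding the transposed (h, w) resolution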
+ self.bucket_resolutions = []
231
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
232
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
233
+ self.bucket_resolutions.append((w, h))
234
+ self.bucket_resolutions.append((h, w))
235
+
236
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
237
+ self.bucket_resolutions.sort()
238
+
239
+ # calculate aspect ratio to find the nearest resolution
240
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
241
+
242
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
243
+ """
244
+ return the bucket resolution for the given image size, (width, height)
245
+ """
246
+ area = image_size[0] * image_size[1]
247
+ if self.no_upscale and area <= self.bucket_area:
248
+ w, h = image_size
249
+ w = divisible_by(w, self.reso_steps)
250
+ h = divisible_by(h, self.reso_steps)
251
+ return w, h
252
+
253
+ aspect_ratio = image_size[0] / image_size[1]
254
+ ar_errors = self.aspect_ratios - aspect_ratio
255
+ bucket_id = np.abs(ar_errors).argmin()
256
+ return self.bucket_resolutions[bucket_id]
257
+
258
+
259
+ def load_video(
260
+ video_path: str,
261
+ start_frame: Optional[int] = None,
262
+ end_frame: Optional[int] = None,
263
+ bucket_selector: Optional[BucketSelector] = None,
264
+ ) -> list[np.ndarray]:
265
+ container = av.open(video_path)
266
+ video = []
267
+ bucket_reso = None
268
+ for i, frame in enumerate(container.decode(video=0)):
269
+ if start_frame is not None and i < start_frame:
270
+ continue
271
+ if end_frame is not None and i >= end_frame:
272
+ break
273
+ frame = frame.to_image()
274
+
275
+ if bucket_selector is not None and bucket_reso is None:
276
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
277
+
278
+ if bucket_reso is not None:
279
+ frame = resize_image_to_bucket(frame, bucket_reso)
280
+ else:
281
+ frame = np.array(frame)
282
+
283
+ video.append(frame)
284
+ container.close()
285
+ return video
286
+
287
+
288
+ class BucketBatchManager:
289
+
290
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
291
+ self.batch_size = batch_size
292
+ self.buckets = bucketed_item_info
293
+ self.bucket_resos = list(self.buckets.keys())
294
+ self.bucket_resos.sort()
295
+
296
+ self.bucket_batch_indices = []
297
+ for bucket_reso in self.bucket_resos:
298
+ bucket = self.buckets[bucket_reso]
299
+ num_batches = math.ceil(len(bucket) / self.batch_size)
300
+ for i in range(num_batches):
301
+ self.bucket_batch_indices.append((bucket_reso, i))
302
+
303
+ self.shuffle()
304
+
305
+ def show_bucket_info(self):
306
+ for bucket_reso in self.bucket_resos:
307
+ bucket = self.buckets[bucket_reso]
308
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
309
+
310
+ logger.info(f"total batches: {len(self)}")
311
+
312
+ def shuffle(self):
313
+ for bucket in self.buckets.values():
314
+ random.shuffle(bucket)
315
+ random.shuffle(self.bucket_batch_indices)
316
+
317
+ def __len__(self):
318
+ return len(self.bucket_batch_indices)
319
+
320
+ def __getitem__(self, idx):
321
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
322
+ bucket = self.buckets[bucket_reso]
323
+ start = batch_idx * self.batch_size
324
+ end = min(start + self.batch_size, len(bucket))
325
+
326
+ latents = []
327
+ llm_embeds = []
328
+ llm_masks = []
329
+ clip_l_embeds = []
330
+ for item_info in bucket[start:end]:
331
+ sd = load_file(item_info.latent_cache_path)
332
+ latent = None
333
+ for key in sd.keys():
334
+ if key.startswith("latents_"):
335
+ latent = sd[key]
336
+ break
337
+ latents.append(latent)
338
+
339
+ sd = load_file(item_info.text_encoder_output_cache_path)
340
+ llm_embed = llm_mask = clip_l_embed = None
341
+ for key in sd.keys():
342
+ if key.startswith("llm_mask"):
343
+ llm_mask = sd[key]
344
+ elif key.startswith("llm_"):
345
+ llm_embed = sd[key]
346
+ elif key.startswith("clipL_mask"):
347
+ pass
348
+ elif key.startswith("clipL_"):
349
+ clip_l_embed = sd[key]
350
+ llm_embeds.append(llm_embed)
351
+ llm_masks.append(llm_mask)
352
+ clip_l_embeds.append(clip_l_embed)
353
+
354
+ latents = torch.stack(latents)
355
+ llm_embeds = torch.stack(llm_embeds)
356
+ llm_masks = torch.stack(llm_masks)
357
+ clip_l_embeds = torch.stack(clip_l_embeds)
358
+
359
+ return latents, llm_embeds, llm_masks, clip_l_embeds
360
+
361
+
362
+ class ContentDatasource:
363
+ def __init__(self):
364
+ self.caption_only = False
365
+
366
+ def set_caption_only(self, caption_only: bool):
367
+ self.caption_only = caption_only
368
+
369
+ def is_indexable(self):
370
+ return False
371
+
372
+ def get_caption(self, idx: int) -> tuple[str, str]:
373
+ """
374
+ Returns caption. May not be called if is_indexable() returns False.
375
+ """
376
+ raise NotImplementedError
377
+
378
+ def __len__(self):
379
+ raise NotImplementedError
380
+
381
+ def __iter__(self):
382
+ raise NotImplementedError
383
+
384
+ def __next__(self):
385
+ raise NotImplementedError
386
+
387
+
388
+ class ImageDatasource(ContentDatasource):
389
+ def __init__(self):
390
+ super().__init__()
391
+
392
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
393
+ """
394
+ Returns image data as a tuple of image path, image, and caption for the given index.
395
+ Key must be unique and valid as a file name.
396
+ May not be called if is_indexable() returns False.
397
+ """
398
+ raise NotImplementedError
399
+
400
+
401
+ class ImageDirectoryDatasource(ImageDatasource):
402
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
403
+ super().__init__()
404
+ self.image_directory = image_directory
405
+ self.caption_extension = caption_extension
406
+ self.current_idx = 0
407
+
408
+ # glob images
409
+ logger.info(f"glob images in {self.image_directory}")
410
+ self.image_paths = glob_images(self.image_directory)
411
+ logger.info(f"found {len(self.image_paths)} images")
412
+
413
+ def is_indexable(self):
414
+ return True
415
+
416
+ def __len__(self):
417
+ return len(self.image_paths)
418
+
419
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
420
+ image_path = self.image_paths[idx]
421
+ image = Image.open(image_path).convert("RGB")
422
+
423
+ _, caption = self.get_caption(idx)
424
+
425
+ return image_path, image, caption
426
+
427
+ def get_caption(self, idx: int) -> tuple[str, str]:
428
+ image_path = self.image_paths[idx]
429
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
430
+ with open(caption_path, "r", encoding="utf-8") as f:
431
+ caption = f.read().strip()
432
+ return image_path, caption
433
+
434
+ def __iter__(self):
435
+ self.current_idx = 0
436
+ return self
437
+
438
+ def __next__(self) -> callable:
439
+ """
440
+ Returns a fetcher function that returns image data.
441
+ """
442
+ if self.current_idx >= len(self.image_paths):
443
+ raise StopIteration
444
+
445
+ if self.caption_only:
446
+
447
+ def create_caption_fetcher(index):
448
+ return lambda: self.get_caption(index)
449
+
450
+ fetcher = create_caption_fetcher(self.current_idx)
451
+ else:
452
+
453
+ def create_image_fetcher(index):
454
+ return lambda: self.get_image_data(index)
455
+
456
+ fetcher = create_image_fetcher(self.current_idx)
457
+
458
+ self.current_idx += 1
459
+ return fetcher
460
+
461
+
462
+ class ImageJsonlDatasource(ImageDatasource):
463
+ def __init__(self, image_jsonl_file: str):
464
+ super().__init__()
465
+ self.image_jsonl_file = image_jsonl_file
466
+ self.current_idx = 0
467
+
468
+ # load jsonl
469
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
470
+ self.data = []
471
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
472
+ for line in f:
473
+ data = json.loads(line)
474
+ self.data.append(data)
475
+ logger.info(f"loaded {len(self.data)} images")
476
+
477
+ def is_indexable(self):
478
+ return True
479
+
480
+ def __len__(self):
481
+ return len(self.data)
482
+
483
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
484
+ data = self.data[idx]
485
+ image_path = data["image_path"]
486
+ image = Image.open(image_path).convert("RGB")
487
+
488
+ caption = data["caption"]
489
+
490
+ return image_path, image, caption
491
+
492
+ def get_caption(self, idx: int) -> tuple[str, str]:
493
+ data = self.data[idx]
494
+ image_path = data["image_path"]
495
+ caption = data["caption"]
496
+ return image_path, caption
497
+
498
+ def __iter__(self):
499
+ self.current_idx = 0
500
+ return self
501
+
502
+ def __next__(self) -> callable:
503
+ if self.current_idx >= len(self.data):
504
+ raise StopIteration
505
+
506
+ if self.caption_only:
507
+
508
+ def create_caption_fetcher(index):
509
+ return lambda: self.get_caption(index)
510
+
511
+ fetcher = create_caption_fetcher(self.current_idx)
512
+
513
+ else:
514
+
515
+ def create_fetcher(index):
516
+ return lambda: self.get_image_data(index)
517
+
518
+ fetcher = create_fetcher(self.current_idx)
519
+
520
+ self.current_idx += 1
521
+ return fetcher
522
+
523
+
524
+ class VideoDatasource(ContentDatasource):
525
+ def __init__(self):
526
+ super().__init__()
527
+
528
+ # None means all frames
529
+ self.start_frame = None
530
+ self.end_frame = None
531
+
532
+ self.bucket_selector = None
533
+
534
+ def __len__(self):
535
+ raise NotImplementedError
536
+
537
+ def get_video_data_from_path(
538
+ self,
539
+ video_path: str,
540
+ start_frame: Optional[int] = None,
541
+ end_frame: Optional[int] = None,
542
+ bucket_selector: Optional[BucketSelector] = None,
543
+ ) -> tuple[str, list[Image.Image], str]:
544
+ # this method can resize the video to reduce memory usage if bucket_selector is given
545
+
546
+ start_frame = start_frame if start_frame is not None else self.start_frame
547
+ end_frame = end_frame if end_frame is not None else self.end_frame
548
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
549
+
550
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
551
+ return video
552
+
553
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
554
+ self.start_frame = start_frame
555
+ self.end_frame = end_frame
556
+
557
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
558
+ self.bucket_selector = bucket_selector
559
+
560
+ def __iter__(self):
561
+ raise NotImplementedError
562
+
563
+ def __next__(self):
564
+ raise NotImplementedError
565
+
566
+
567
+ class VideoDirectoryDatasource(VideoDatasource):
568
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
569
+ super().__init__()
570
+ self.video_directory = video_directory
571
+ self.caption_extension = caption_extension
572
+ self.current_idx = 0
573
+
574
+ # glob videos
575
+ logger.info(f"glob images in {self.video_directory}")
576
+ self.video_paths = glob_videos(self.video_directory)
577
+ logger.info(f"found {len(self.video_paths)} videos")
578
+
579
+ def is_indexable(self):
580
+ return True
581
+
582
+ def __len__(self):
583
+ return len(self.video_paths)
584
+
585
+ def get_video_data(
586
+ self,
587
+ idx: int,
588
+ start_frame: Optional[int] = None,
589
+ end_frame: Optional[int] = None,
590
+ bucket_selector: Optional[BucketSelector] = None,
591
+ ) -> tuple[str, list[Image.Image], str]:
592
+ video_path = self.video_paths[idx]
593
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
594
+
595
+ _, caption = self.get_caption(idx)
596
+
597
+ return video_path, video, caption
598
+
599
+ def get_caption(self, idx: int) -> tuple[str, str]:
600
+ video_path = self.video_paths[idx]
601
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
602
+ with open(caption_path, "r", encoding="utf-8") as f:
603
+ caption = f.read().strip()
604
+ return video_path, caption
605
+
606
+ def __iter__(self):
607
+ self.current_idx = 0
608
+ return self
609
+
610
+ def __next__(self):
611
+ if self.current_idx >= len(self.video_paths):
612
+ raise StopIteration
613
+
614
+ if self.caption_only:
615
+
616
+ def create_caption_fetcher(index):
617
+ return lambda: self.get_caption(index)
618
+
619
+ fetcher = create_caption_fetcher(self.current_idx)
620
+
621
+ else:
622
+
623
+ def create_fetcher(index):
624
+ return lambda: self.get_video_data(index)
625
+
626
+ fetcher = create_fetcher(self.current_idx)
627
+
628
+ self.current_idx += 1
629
+ return fetcher
630
+
631
+
632
+ class VideoJsonlDatasource(VideoDatasource):
633
+ def __init__(self, video_jsonl_file: str):
634
+ super().__init__()
635
+ self.video_jsonl_file = video_jsonl_file
636
+ self.current_idx = 0
637
+
638
+ # load jsonl
639
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
640
+ self.data = []
641
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
642
+ for line in f:
643
+ data = json.loads(line)
644
+ self.data.append(data)
645
+ logger.info(f"loaded {len(self.data)} videos")
646
+
647
+ def is_indexable(self):
648
+ return True
649
+
650
+ def __len__(self):
651
+ return len(self.data)
652
+
653
+ def get_video_data(
654
+ self,
655
+ idx: int,
656
+ start_frame: Optional[int] = None,
657
+ end_frame: Optional[int] = None,
658
+ bucket_selector: Optional[BucketSelector] = None,
659
+ ) -> tuple[str, list[Image.Image], str]:
660
+ data = self.data[idx]
661
+ video_path = data["video_path"]
662
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
663
+
664
+ caption = data["caption"]
665
+
666
+ return video_path, video, caption
667
+
668
+ def get_caption(self, idx: int) -> tuple[str, str]:
669
+ data = self.data[idx]
670
+ video_path = data["video_path"]
671
+ caption = data["caption"]
672
+ return video_path, caption
673
+
674
+ def __iter__(self):
675
+ self.current_idx = 0
676
+ return self
677
+
678
+ def __next__(self):
679
+ if self.current_idx >= len(self.data):
680
+ raise StopIteration
681
+
682
+ if self.caption_only:
683
+
684
+ def create_caption_fetcher(index):
685
+ return lambda: self.get_caption(index)
686
+
687
+ fetcher = create_caption_fetcher(self.current_idx)
688
+
689
+ else:
690
+
691
+ def create_fetcher(index):
692
+ return lambda: self.get_video_data(index)
693
+
694
+ fetcher = create_fetcher(self.current_idx)
695
+
696
+ self.current_idx += 1
697
+ return fetcher
698
+
699
+
700
+ class BaseDataset(torch.utils.data.Dataset):
701
+ def __init__(
702
+ self,
703
+ resolution: Tuple[int, int] = (960, 544),
704
+ caption_extension: Optional[str] = None,
705
+ batch_size: int = 1,
706
+ enable_bucket: bool = False,
707
+ bucket_no_upscale: bool = False,
708
+ cache_directory: Optional[str] = None,
709
+ debug_dataset: bool = False,
710
+ ):
711
+ self.resolution = resolution
712
+ self.caption_extension = caption_extension
713
+ self.batch_size = batch_size
714
+ self.enable_bucket = enable_bucket
715
+ self.bucket_no_upscale = bucket_no_upscale
716
+ self.cache_directory = cache_directory
717
+ self.debug_dataset = debug_dataset
718
+ self.seed = None
719
+ self.current_epoch = 0
720
+
721
+ if not self.enable_bucket:
722
+ self.bucket_no_upscale = False
723
+
724
+ def get_metadata(self) -> dict:
725
+ metadata = {
726
+ "resolution": self.resolution,
727
+ "caption_extension": self.caption_extension,
728
+ "batch_size_per_device": self.batch_size,
729
+ "enable_bucket": bool(self.enable_bucket),
730
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
731
+ }
732
+ return metadata
733
+
734
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
735
+ w, h = item_info.original_size
736
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
737
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
738
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")
739
+
740
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
741
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
742
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
743
+ return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")
744
+
745
+ def retrieve_latent_cache_batches(self, num_workers: int):
746
+ raise NotImplementedError
747
+
748
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
749
+ raise NotImplementedError
750
+
751
+ def prepare_for_training(self):
752
+ pass
753
+
754
+ def set_seed(self, seed: int):
755
+ self.seed = seed
756
+
757
+ def set_current_epoch(self, epoch):
758
+ if self.current_epoch != epoch: # shuffle buckets when epoch is incremented
759
+ if epoch > self.current_epoch:
760
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
761
+ num_epochs = epoch - self.current_epoch
762
+ for _ in range(num_epochs):
763
+ self.current_epoch += 1
764
+ self.shuffle_buckets()
765
+ # self.current_epoch seems to be reset to 0 in the next epoch; possibly caused by skipped_dataloader?
766
+ else:
767
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
768
+ self.current_epoch = epoch
769
+
770
+ def set_current_step(self, step):
771
+ self.current_step = step
772
+
773
+ def set_max_train_steps(self, max_train_steps):
774
+ self.max_train_steps = max_train_steps
775
+
776
+ def shuffle_buckets(self):
777
+ raise NotImplementedError
778
+
779
+ def __len__(self):
780
+ raise NotImplementedError
781
+
782
+ def __getitem__(self, idx):
783
+ raise NotImplementedError
784
+
785
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
786
+ datasource.set_caption_only(True)
787
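+ # fetch captions in background threads and yield batches of ItemInfo as the futures complete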
+ executor = ThreadPoolExecutor(max_workers=num_workers)
788
+
789
+ data: list[ItemInfo] = []
790
+ futures = []
791
+
792
+ def aggregate_future(consume_all: bool = False):
793
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
794
+ completed_futures = [future for future in futures if future.done()]
795
+ if len(completed_futures) == 0:
796
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
797
+ time.sleep(0.1)
798
+ continue
799
+ else:
800
+ break # submit batch if possible
801
+
802
+ for future in completed_futures:
803
+ item_key, caption = future.result()
804
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
805
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
806
+ data.append(item_info)
807
+
808
+ futures.remove(future)
809
+
810
+ def submit_batch(flush: bool = False):
811
+ nonlocal data
812
+ if len(data) >= batch_size or (len(data) > 0 and flush):
813
+ batch = data[0:batch_size]
814
+ if len(data) > batch_size:
815
+ data = data[batch_size:]
816
+ else:
817
+ data = []
818
+ return batch
819
+ return None
820
+
821
+ for fetch_op in datasource:
822
+ future = executor.submit(fetch_op)
823
+ futures.append(future)
824
+ aggregate_future()
825
+ while True:
826
+ batch = submit_batch()
827
+ if batch is None:
828
+ break
829
+ yield batch
830
+
831
+ aggregate_future(consume_all=True)
832
+ while True:
833
+ batch = submit_batch(flush=True)
834
+ if batch is None:
835
+ break
836
+ yield batch
837
+
838
+ executor.shutdown()
839
+
840
+
841
+ class ImageDataset(BaseDataset):
842
+ def __init__(
843
+ self,
844
+ resolution: Tuple[int, int],
845
+ caption_extension: Optional[str],
846
+ batch_size: int,
847
+ enable_bucket: bool,
848
+ bucket_no_upscale: bool,
849
+ image_directory: Optional[str] = None,
850
+ image_jsonl_file: Optional[str] = None,
851
+ cache_directory: Optional[str] = None,
852
+ debug_dataset: bool = False,
853
+ ):
854
+ super(ImageDataset, self).__init__(
855
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
856
+ )
857
+ self.image_directory = image_directory
858
+ self.image_jsonl_file = image_jsonl_file
859
+ if image_directory is not None:
860
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
861
+ elif image_jsonl_file is not None:
862
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
863
+ else:
864
+ raise ValueError("image_directory or image_jsonl_file must be specified")
865
+
866
+ if self.cache_directory is None:
867
+ self.cache_directory = self.image_directory
868
+
869
+ self.batch_manager = None
870
+ self.num_train_items = 0
871
+
872
+ def get_metadata(self):
873
+ metadata = super().get_metadata()
874
+ if self.image_directory is not None:
875
+ metadata["image_directory"] = os.path.basename(self.image_directory)
876
+ if self.image_jsonl_file is not None:
877
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
878
+ return metadata
879
+
880
+ def get_total_image_count(self):
881
+ return len(self.datasource) if self.datasource.is_indexable() else None
882
+
883
+ def retrieve_latent_cache_batches(self, num_workers: int):
884
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
885
+ executor = ThreadPoolExecutor(max_workers=num_workers)
886
+
887
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
888
+ futures = []
889
+
890
+ def aggregate_future(consume_all: bool = False):
891
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
892
+ completed_futures = [future for future in futures if future.done()]
893
+ if len(completed_futures) == 0:
894
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
895
+ time.sleep(0.1)
896
+ continue
897
+ else:
898
+ break # submit batch if possible
899
+
900
+ for future in completed_futures:
901
+ original_size, item_key, image, caption = future.result()
902
+ bucket_height, bucket_width = image.shape[:2]
903
+ bucket_reso = (bucket_width, bucket_height)
904
+
905
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
906
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
907
+
908
+ if bucket_reso not in batches:
909
+ batches[bucket_reso] = []
910
+ batches[bucket_reso].append(item_info)
911
+
912
+ futures.remove(future)
913
+
914
+ def submit_batch(flush: bool = False):
915
+ for key in batches:
916
+ if len(batches[key]) >= self.batch_size or flush:
917
+ batch = batches[key][0 : self.batch_size]
918
+ if len(batches[key]) > self.batch_size:
919
+ batches[key] = batches[key][self.batch_size :]
920
+ else:
921
+ del batches[key]
922
+ return key, batch
923
+ return None, None
924
+
925
+ for fetch_op in self.datasource:
926
+
927
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
928
+ image_key, image, caption = op()
929
+ image: Image.Image
930
+ image_size = image.size
931
+
932
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
933
+ image = resize_image_to_bucket(image, bucket_reso)
934
+ return image_size, image_key, image, caption
935
+
936
+ future = executor.submit(fetch_and_resize, fetch_op)
937
+ futures.append(future)
938
+ aggregate_future()
939
+ while True:
940
+ key, batch = submit_batch()
941
+ if key is None:
942
+ break
943
+ yield key, batch
944
+
945
+ aggregate_future(consume_all=True)
946
+ while True:
947
+ key, batch = submit_batch(flush=True)
948
+ if key is None:
949
+ break
950
+ yield key, batch
951
+
952
+ executor.shutdown()
953
+
954
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
955
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
956
+
957
+ def prepare_for_training(self):
958
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
959
+
960
+ # glob cache files
961
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
962
+
963
+ # assign cache files to item info
964
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
965
+ for cache_file in latent_cache_files:
966
+ tokens = os.path.basename(cache_file).split("_")
967
+
968
+ image_size = tokens[-2] # 0000x0000
969
+ image_width, image_height = map(int, image_size.split("x"))
970
+ image_size = (image_width, image_height)
971
+
972
+ item_key = "_".join(tokens[:-2])
973
+ text_encoder_output_cache_file = os.path.join(
974
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
975
+ )
976
+ if not os.path.exists(text_encoder_output_cache_file):
977
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
978
+ continue
979
+
980
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
981
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
982
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
983
+
984
+ bucket = bucketed_item_info.get(bucket_reso, [])
985
+ bucket.append(item_info)
986
+ bucketed_item_info[bucket_reso] = bucket
987
+
988
+ # prepare batch manager
989
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
990
+ self.batch_manager.show_bucket_info()
991
+
992
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
993
+
994
+ def shuffle_buckets(self):
995
+ # set random seed for this epoch
996
+ random.seed(self.seed + self.current_epoch)
997
+ self.batch_manager.shuffle()
998
+
999
+ def __len__(self):
1000
+ if self.batch_manager is None:
1001
+ return 100 # dummy value
1002
+ return len(self.batch_manager)
1003
+
1004
+ def __getitem__(self, idx):
1005
+ return self.batch_manager[idx]
1006
+
1007
+
1008
+ class VideoDataset(BaseDataset):
1009
+ def __init__(
1010
+ self,
1011
+ resolution: Tuple[int, int],
1012
+ caption_extension: Optional[str],
1013
+ batch_size: int,
1014
+ enable_bucket: bool,
1015
+ bucket_no_upscale: bool,
1016
+ frame_extraction: Optional[str] = "head",
1017
+ frame_stride: Optional[int] = 1,
1018
+ frame_sample: Optional[int] = 1,
1019
+ target_frames: Optional[list[int]] = None,
1020
+ video_directory: Optional[str] = None,
1021
+ video_jsonl_file: Optional[str] = None,
1022
+ cache_directory: Optional[str] = None,
1023
+ debug_dataset: bool = False,
1024
+ ):
1025
+ super(VideoDataset, self).__init__(
1026
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
1027
+ )
1028
+ self.video_directory = video_directory
1029
+ self.video_jsonl_file = video_jsonl_file
1030
+ self.target_frames = target_frames
1031
+ self.frame_extraction = frame_extraction
1032
+ self.frame_stride = frame_stride
1033
+ self.frame_sample = frame_sample
1034
+
1035
+ if video_directory is not None:
1036
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1037
+ elif video_jsonl_file is not None:
1038
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
1039
+
1040
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
1041
+ self.frame_extraction = "head"
1042
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
1043
+ if self.frame_extraction == "head":
1044
+ # head extraction. we can limit the number of frames to be extracted
1045
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
1046
+
1047
+ if self.cache_directory is None:
1048
+ self.cache_directory = self.video_directory
1049
+
1050
+ self.batch_manager = None
1051
+ self.num_train_items = 0
1052
+
1053
+ def get_metadata(self):
1054
+ metadata = super().get_metadata()
1055
+ if self.video_directory is not None:
1056
+ metadata["video_directory"] = os.path.basename(self.video_directory)
1057
+ if self.video_jsonl_file is not None:
1058
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1059
+ metadata["frame_extraction"] = self.frame_extraction
1060
+ metadata["frame_stride"] = self.frame_stride
1061
+ metadata["frame_sample"] = self.frame_sample
1062
+ metadata["target_frames"] = self.target_frames
1063
+ return metadata
1064
+
1065
+ def retrieve_latent_cache_batches(self, num_workers: int):
1066
+ bucket_selector = BucketSelector(self.resolution)
1067
+ self.datasource.set_bucket_selector(bucket_selector)
1068
+
1069
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1070
+
1071
+ # key: (width, height, frame_count), value: [ItemInfo]
1072
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1073
+ futures = []
1074
+
1075
+ def aggregate_future(consume_all: bool = False):
1076
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1077
+ completed_futures = [future for future in futures if future.done()]
1078
+ if len(completed_futures) == 0:
1079
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1080
+ time.sleep(0.1)
1081
+ continue
1082
+ else:
1083
+ break # submit batch if possible
1084
+
1085
+ for future in completed_futures:
1086
+ original_frame_size, video_key, video, caption = future.result()
1087
+
1088
+ frame_count = len(video)
1089
+ video = np.stack(video, axis=0)
1090
+ height, width = video.shape[1:3]
1091
+ bucket_reso = (width, height) # already resized
1092
+
1093
+ crop_pos_and_frames = []
1094
+ if self.frame_extraction == "head":
1095
+ for target_frame in self.target_frames:
1096
+ if frame_count >= target_frame:
1097
+ crop_pos_and_frames.append((0, target_frame))
1098
+ elif self.frame_extraction == "chunk":
1099
+ # split by target_frames
1100
+ for target_frame in self.target_frames:
1101
+ for i in range(0, frame_count, target_frame):
1102
+ if i + target_frame <= frame_count:
1103
+ crop_pos_and_frames.append((i, target_frame))
1104
+ elif self.frame_extraction == "slide":
1105
+ # slide window
1106
+ for target_frame in self.target_frames:
1107
+ if frame_count >= target_frame:
1108
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
1109
+ crop_pos_and_frames.append((i, target_frame))
1110
+ elif self.frame_extraction == "uniform":
1111
+ # select N frames uniformly
1112
+ for target_frame in self.target_frames:
1113
+ if frame_count >= target_frame:
1114
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1115
+ for i in frame_indices:
1116
+ crop_pos_and_frames.append((i, target_frame))
1117
+ else:
1118
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1119
+
1120
+ for crop_pos, target_frame in crop_pos_and_frames:
1121
+ cropped_video = video[crop_pos : crop_pos + target_frame]
1122
+ body, ext = os.path.splitext(video_key)
1123
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1124
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1125
+
1126
+ item_info = ItemInfo(
1127
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1128
+ )
1129
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1130
+
1131
+ batch = batches.get(batch_key, [])
1132
+ batch.append(item_info)
1133
+ batches[batch_key] = batch
1134
+
1135
+ futures.remove(future)
1136
+
1137
+ def submit_batch(flush: bool = False):
1138
+ for key in batches:
1139
+ if len(batches[key]) >= self.batch_size or flush:
1140
+ batch = batches[key][0 : self.batch_size]
1141
+ if len(batches[key]) > self.batch_size:
1142
+ batches[key] = batches[key][self.batch_size :]
1143
+ else:
1144
+ del batches[key]
1145
+ return key, batch
1146
+ return None, None
1147
+
1148
+ for operator in self.datasource:
1149
+
1150
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1151
+ video_key, video, caption = op()
1152
+ video: list[np.ndarray]
1153
+ frame_size = (video[0].shape[1], video[0].shape[0])
1154
+
1155
+ # resize if necessary
1156
+ bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
1157
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1158
+
1159
+ return frame_size, video_key, video, caption
1160
+
1161
+ future = executor.submit(fetch_and_resize, operator)
1162
+ futures.append(future)
1163
+ aggregate_future()
1164
+ while True:
1165
+ key, batch = submit_batch()
1166
+ if key is None:
1167
+ break
1168
+ yield key, batch
1169
+
1170
+ aggregate_future(consume_all=True)
1171
+ while True:
1172
+ key, batch = submit_batch(flush=True)
1173
+ if key is None:
1174
+ break
1175
+ yield key, batch
1176
+
1177
+ executor.shutdown()
1178
+
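The four frame_extraction modes handled above are easiest to compare on a toy clip. The standalone sketch below mirrors only the crop-window selection (it is not part of the dataset class, and the clip length and targets are made-up numbers):

import numpy as np

def crop_windows(frame_count, mode, target_frames, frame_stride=1, frame_sample=1):
    # returns (start_frame, length) pairs, mirroring the logic in retrieve_latent_cache_batches
    windows = []
    for target in target_frames:
        if mode == "head" and frame_count >= target:
            windows.append((0, target))
        elif mode == "chunk":
            windows += [(i, target) for i in range(0, frame_count, target) if i + target <= frame_count]
        elif mode == "slide" and frame_count >= target:
            windows += [(i, target) for i in range(0, frame_count - target + 1, frame_stride)]
        elif mode == "uniform" and frame_count >= target:
            windows += [(int(i), target) for i in np.linspace(0, frame_count - target, frame_sample, dtype=int)]
    return windows

print(crop_windows(100, "chunk", [45]))                   # [(0, 45), (45, 45)]
print(crop_windows(100, "slide", [45], frame_stride=10))  # [(0, 45), (10, 45), ..., (50, 45)]
print(crop_windows(100, "uniform", [45], frame_sample=4)) # [(0, 45), (18, 45), (36, 45), (55, 45)]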
1179
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1180
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1181
+
1182
+ def prepare_for_training(self):
1183
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
1184
+
1185
+ # glob cache files
1186
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
1187
+
1188
+ # assign cache files to item info
1189
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
1190
+ for cache_file in latent_cache_files:
1191
+ tokens = os.path.basename(cache_file).split("_")
1192
+
1193
+ image_size = tokens[-2] # 0000x0000
1194
+ image_width, image_height = map(int, image_size.split("x"))
1195
+ image_size = (image_width, image_height)
1196
+
1197
+ frame_pos, frame_count = tokens[-3].split("-")
1198
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
1199
+
1200
+ item_key = "_".join(tokens[:-3])
1201
+ text_encoder_output_cache_file = os.path.join(
1202
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
1203
+ )
1204
+ if not os.path.exists(text_encoder_output_cache_file):
1205
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1206
+ continue
1207
+
1208
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1209
+ bucket_reso = (*bucket_reso, frame_count)
1210
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
1211
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1212
+
1213
+ bucket = bucketed_item_info.get(bucket_reso, [])
1214
+ bucket.append(item_info)
1215
+ bucketed_item_info[bucket_reso] = bucket
1216
+
1217
+ # prepare batch manager
1218
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1219
+ self.batch_manager.show_bucket_info()
1220
+
1221
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1222
+
1223
+ def shuffle_buckets(self):
1224
+ # set random seed for this epoch
1225
+ random.seed(self.seed + self.current_epoch)
1226
+ self.batch_manager.shuffle()
1227
+
1228
+ def __len__(self):
1229
+ if self.batch_manager is None:
1230
+ return 100 # dummy value
1231
+ return len(self.batch_manager)
1232
+
1233
+ def __getitem__(self, idx):
1234
+ return self.batch_manager[idx]
1235
+
1236
+
1237
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1238
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
1239
+ super().__init__(datasets)
1240
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
1241
+ self.num_train_items = 0
1242
+ for dataset in self.datasets:
1243
+ self.num_train_items += dataset.num_train_items
1244
+
1245
+ def set_current_epoch(self, epoch):
1246
+ for dataset in self.datasets:
1247
+ dataset.set_current_epoch(epoch)
1248
+
1249
+ def set_current_step(self, step):
1250
+ for dataset in self.datasets:
1251
+ dataset.set_current_step(step)
1252
+
1253
+ def set_max_train_steps(self, max_train_steps):
1254
+ for dataset in self.datasets:
1255
+ dataset.set_max_train_steps(max_train_steps)
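A rough end-to-end usage sketch for the classes above; it is assembled from the constructor signatures in this file rather than from the repository's training scripts, and the paths and resolutions are placeholders. Because each __getitem__ already returns a fully collated bucket batch, the DataLoader is used with batch_size=None.

import torch

video_ds = VideoDataset(
    resolution=(960, 544),
    caption_extension=".txt",
    batch_size=1,
    enable_bucket=True,
    bucket_no_upscale=False,
    frame_extraction="chunk",
    target_frames=[25, 45],
    video_directory="/path/to/videos",   # placeholder
    cache_directory="/path/to/cache",    # placeholder
)
video_ds.prepare_for_training()          # assumes latent and text-encoder caches already exist;
                                         # must run before DatasetGroup so dataset lengths are correct
group = DatasetGroup([video_ds])
loader = torch.utils.data.DataLoader(group, batch_size=None, num_workers=1)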
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ Callable[[], nn.Module]: a zero-argument factory for the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
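A minimal usage sketch for get_activation_layer: every branch returns a zero-argument factory ("gelu"/"gelu_tanh" return a lambda, "relu"/"silu" return the class itself), so the result is called once to build the module.

import torch

act_layer = get_activation_layer("gelu_tanh")   # zero-argument factory
act = act_layer()                               # nn.GELU(approximate="tanh")
print(act(torch.tensor([-1.0, 0.0, 1.0])))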
hunyuan_model/attention.py ADDED
@@ -0,0 +1,230 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except ImportError:
13
+ flash_attn = None
14
+ flash_attn_varlen_func = None
15
+ _flash_attn_forward = None
16
+
17
+ try:
18
+ print("Trying to import sageattention")
19
+ from sageattention import sageattn_varlen
20
+
21
+ print("Successfully imported sageattention")
22
+ except ImportError:
23
+ print("Failed to import sageattention")
24
+ sageattn_varlen = None
25
+
26
+ MEMORY_LAYOUT = {
27
+ "flash": (
28
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
29
+ lambda x: x,
30
+ ),
31
+ "sageattn": (
32
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
33
+ lambda x: x,
34
+ ),
35
+ "torch": (
36
+ lambda x: x.transpose(1, 2),
37
+ lambda x: x.transpose(1, 2),
38
+ ),
39
+ "vanilla": (
40
+ lambda x: x.transpose(1, 2),
41
+ lambda x: x.transpose(1, 2),
42
+ ),
43
+ }
44
+
45
+
46
+ def get_cu_seqlens(text_mask, img_len):
47
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
48
+
49
+ Args:
50
+ text_mask (torch.Tensor): the mask of text
51
+ img_len (int): the length of image
52
+
53
+ Returns:
54
+ torch.Tensor: the calculated cu_seqlens for flash attention
55
+ """
56
+ batch_size = text_mask.shape[0]
57
+ text_len = text_mask.sum(dim=1)
58
+ max_len = text_mask.shape[1] + img_len
59
+
60
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
61
+
62
+ for i in range(batch_size):
63
+ s = text_len[i] + img_len
64
+ s1 = i * max_len + s
65
+ s2 = (i + 1) * max_len
66
+ cu_seqlens[2 * i + 1] = s1
67
+ cu_seqlens[2 * i + 2] = s2
68
+
69
+ return cu_seqlens
70
+
71
+
72
+ def attention(
73
+ q_or_qkv_list,
74
+ k=None,
75
+ v=None,
76
+ mode="flash",
77
+ drop_rate=0,
78
+ attn_mask=None,
79
+ causal=False,
80
+ cu_seqlens_q=None,
81
+ cu_seqlens_kv=None,
82
+ max_seqlen_q=None,
83
+ max_seqlen_kv=None,
84
+ batch_size=1,
85
+ ):
86
+ """
87
+ Perform QKV self attention.
88
+
89
+ Args:
90
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
91
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
92
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
93
+ mode (str): Attention mode. Choose from 'flash', 'sageattn', 'torch', and 'vanilla'.
94
+ drop_rate (float): Dropout rate in attention map. (default: 0)
95
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
96
+ (default: None)
97
+ causal (bool): Whether to use causal attention. (default: False)
98
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
99
+ used to index into q.
100
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
101
+ used to index into kv.
102
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
103
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
104
+
105
+ Returns:
106
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
107
+ """
108
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
109
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
110
+ q = pre_attn_layout(q)
111
+ k = pre_attn_layout(k)
112
+ v = pre_attn_layout(v)
113
+
114
+ if mode == "torch":
115
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
116
+ attn_mask = attn_mask.to(q.dtype)
117
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
118
+ if type(q_or_qkv_list) == list:
119
+ q_or_qkv_list.clear()
120
+ del q, k, v
121
+ del attn_mask
122
+ elif mode == "flash":
123
+ x = flash_attn_varlen_func(
124
+ q,
125
+ k,
126
+ v,
127
+ cu_seqlens_q,
128
+ cu_seqlens_kv,
129
+ max_seqlen_q,
130
+ max_seqlen_kv,
131
+ )
132
+ if type(q_or_qkv_list) == list:
133
+ q_or_qkv_list.clear()
134
+ del q, k, v
135
+ # x with shape [(bxs), a, d]
136
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
137
+ elif mode == "sageattn":
138
+ x = sageattn_varlen(
139
+ q,
140
+ k,
141
+ v,
142
+ cu_seqlens_q,
143
+ cu_seqlens_kv,
144
+ max_seqlen_q,
145
+ max_seqlen_kv,
146
+ )
147
+ if type(q_or_qkv_list) == list:
148
+ q_or_qkv_list.clear()
149
+ del q, k, v
150
+ # x with shape [(bxs), a, d]
151
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
152
+ elif mode == "vanilla":
153
+ scale_factor = 1 / math.sqrt(q.size(-1))
154
+
155
+ b, a, s, _ = q.shape
156
+ s1 = k.size(2)
157
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
158
+ if causal:
159
+ # Only applied to self attention
160
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
161
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
162
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
163
+ attn_bias.to(q.dtype)
164
+
165
+ if attn_mask is not None:
166
+ if attn_mask.dtype == torch.bool:
167
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
168
+ else:
169
+ attn_bias += attn_mask
170
+
171
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
172
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
173
+ attn += attn_bias
174
+ attn = attn.softmax(dim=-1)
175
+ attn = torch.dropout(attn, p=drop_rate, train=True)
176
+ x = attn @ v
177
+ else:
178
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
179
+
180
+ x = post_attn_layout(x)
181
+ b, s, a, d = x.shape
182
+ out = x.reshape(b, s, -1)
183
+ return out
184
+
185
+
186
+ def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
187
+ attn1 = hybrid_seq_parallel_attn(
188
+ None,
189
+ q[:, :img_q_len, :, :],
190
+ k[:, :img_kv_len, :, :],
191
+ v[:, :img_kv_len, :, :],
192
+ dropout_p=0.0,
193
+ causal=False,
194
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
195
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
196
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
197
+ joint_strategy="rear",
198
+ )
199
+ if flash_attn.__version__ >= "2.7.0":
200
+ attn2, *_ = _flash_attn_forward(
201
+ q[:, cu_seqlens_q[1] :],
202
+ k[:, cu_seqlens_kv[1] :],
203
+ v[:, cu_seqlens_kv[1] :],
204
+ dropout_p=0.0,
205
+ softmax_scale=q.shape[-1] ** (-0.5),
206
+ causal=False,
207
+ window_size_left=-1,
208
+ window_size_right=-1,
209
+ softcap=0.0,
210
+ alibi_slopes=None,
211
+ return_softmax=False,
212
+ )
213
+ else:
214
+ attn2, *_ = _flash_attn_forward(
215
+ q[:, cu_seqlens_q[1] :],
216
+ k[:, cu_seqlens_kv[1] :],
217
+ v[:, cu_seqlens_kv[1] :],
218
+ dropout_p=0.0,
219
+ softmax_scale=q.shape[-1] ** (-0.5),
220
+ causal=False,
221
+ window_size=(-1, -1),
222
+ softcap=0.0,
223
+ alibi_slopes=None,
224
+ return_softmax=False,
225
+ )
226
+ attn = torch.cat([attn1, attn2], dim=1)
227
+ b, s, a, d = attn.shape
228
+ attn = attn.reshape(b, s, -1)
229
+
230
+ return attn
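A small self-contained check of the attention wrapper using the "torch" (scaled_dot_product_attention) path, which needs neither flash-attn nor cu_seqlens; the tensor sizes are arbitrary. The "flash"/"sageattn" paths would additionally take the cu_seqlens produced by get_cu_seqlens (note that helper allocates its tensor on CUDA).

import torch

q = torch.randn(2, 16, 4, 32)   # [batch, seq, heads, head_dim]
k = torch.randn(2, 16, 4, 32)
v = torch.randn(2, 16, 4, 32)
out = attention(q, k, v, mode="torch")
print(out.shape)                # torch.Size([2, 16, 128]) == [batch, seq, heads * head_dim]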
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ try:
28
+ # This diffusers is modified and packed in the mirror.
29
+ from diffusers.loaders import FromOriginalVAEMixin
30
+ except ImportError:
31
+ # Use this to be compatible with the original diffusers.
32
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
127
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
128
+ self.tile_overlap_factor = 0.25
129
+
130
+ def _set_gradient_checkpointing(self, module, value=False):
131
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
132
+ module.gradient_checkpointing = value
133
+
134
+ def enable_temporal_tiling(self, use_tiling: bool = True):
135
+ self.use_temporal_tiling = use_tiling
136
+
137
+ def disable_temporal_tiling(self):
138
+ self.enable_temporal_tiling(False)
139
+
140
+ def enable_spatial_tiling(self, use_tiling: bool = True):
141
+ self.use_spatial_tiling = use_tiling
142
+
143
+ def disable_spatial_tiling(self):
144
+ self.enable_spatial_tiling(False)
145
+
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
150
+ processing larger videos.
151
+ """
152
+ self.enable_spatial_tiling(use_tiling)
153
+ self.enable_temporal_tiling(use_tiling)
154
+
155
+ def disable_tiling(self):
156
+ r"""
157
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
158
+ decoding in one step.
159
+ """
160
+ self.disable_spatial_tiling()
161
+ self.disable_temporal_tiling()
162
+
163
+ def enable_slicing(self):
164
+ r"""
165
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
166
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
167
+ """
168
+ self.use_slicing = True
169
+
170
+ def disable_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
173
+ decoding in one step.
174
+ """
175
+ self.use_slicing = False
176
+
177
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
178
+ # set chunk_size to CausalConv3d recursively
179
+ def set_chunk_size(module):
180
+ if hasattr(module, "chunk_size"):
181
+ module.chunk_size = chunk_size
182
+
183
+ self.apply(set_chunk_size)
184
+
185
+ @property
186
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
187
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
188
+ r"""
189
+ Returns:
190
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
191
+ indexed by its weight name.
192
+ """
193
+ # set recursively
194
+ processors = {}
195
+
196
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
197
+ if hasattr(module, "get_processor"):
198
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
199
+
200
+ for sub_name, child in module.named_children():
201
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
202
+
203
+ return processors
204
+
205
+ for name, module in self.named_children():
206
+ fn_recursive_add_processors(name, module, processors)
207
+
208
+ return processors
209
+
210
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
211
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
212
+ r"""
213
+ Sets the attention processor to use to compute attention.
214
+
215
+ Parameters:
216
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
217
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
218
+ for **all** `Attention` layers.
219
+
220
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
221
+ processor. This is strongly recommended when setting trainable attention processors.
222
+
223
+ """
224
+ count = len(self.attn_processors.keys())
225
+
226
+ if isinstance(processor, dict) and len(processor) != count:
227
+ raise ValueError(
228
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
229
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
230
+ )
231
+
232
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
233
+ if hasattr(module, "set_processor"):
234
+ if not isinstance(processor, dict):
235
+ module.set_processor(processor, _remove_lora=_remove_lora)
236
+ else:
237
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
238
+
239
+ for sub_name, child in module.named_children():
240
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
241
+
242
+ for name, module in self.named_children():
243
+ fn_recursive_attn_processor(name, module, processor)
244
+
245
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
246
+ def set_default_attn_processor(self):
247
+ """
248
+ Disables custom attention processors and sets the default attention implementation.
249
+ """
250
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnAddedKVProcessor()
252
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
253
+ processor = AttnProcessor()
254
+ else:
255
+ raise ValueError(
256
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
257
+ )
258
+
259
+ self.set_attn_processor(processor, _remove_lora=True)
260
+
261
+ @apply_forward_hook
262
+ def encode(
263
+ self, x: torch.FloatTensor, return_dict: bool = True
264
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
+ """
266
+ Encode a batch of images/videos into latents.
267
+
268
+ Args:
269
+ x (`torch.FloatTensor`): Input batch of images/videos.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
+
273
+ Returns:
274
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
275
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
+ """
277
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
278
+
279
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
280
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
281
+
282
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
283
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
284
+
285
+ if self.use_slicing and x.shape[0] > 1:
286
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
287
+ h = torch.cat(encoded_slices)
288
+ else:
289
+ h = self.encoder(x)
290
+
291
+ moments = self.quant_conv(h)
292
+ posterior = DiagonalGaussianDistribution(moments)
293
+
294
+ if not return_dict:
295
+ return (posterior,)
296
+
297
+ return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
301
+
302
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
303
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
304
+
305
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
306
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
307
+
308
+ z = self.post_quant_conv(z)
309
+ dec = self.decoder(z)
310
+
311
+ if not return_dict:
312
+ return (dec,)
313
+
314
+ return DecoderOutput(sample=dec)
315
+
316
+ @apply_forward_hook
317
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(
362
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
363
+ ) -> AutoencoderKLOutput:
364
+ r"""Encode a batch of images/videos using a tiled encoder.
365
+
366
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
367
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
368
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
369
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
370
+ output, but they should be much less noticeable.
371
+
372
+ Args:
373
+ x (`torch.FloatTensor`): Input batch of images/videos.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
376
+
377
+ Returns:
378
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
379
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
380
+ `tuple` is returned.
381
+ """
382
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
383
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
384
+ row_limit = self.tile_latent_min_size - blend_extent
385
+
386
+ # Split video into tiles and encode them separately.
387
+ rows = []
388
+ for i in range(0, x.shape[-2], overlap_size):
389
+ row = []
390
+ for j in range(0, x.shape[-1], overlap_size):
391
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
392
+ tile = self.encoder(tile)
393
+ tile = self.quant_conv(tile)
394
+ row.append(tile)
395
+ rows.append(row)
396
+ result_rows = []
397
+ for i, row in enumerate(rows):
398
+ result_row = []
399
+ for j, tile in enumerate(row):
400
+ # blend the above tile and the left tile
401
+ # to the current tile and add the current tile to the result row
402
+ if i > 0:
403
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
404
+ if j > 0:
405
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
406
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
407
+ result_rows.append(torch.cat(result_row, dim=-1))
408
+
409
+ moments = torch.cat(result_rows, dim=-2)
410
+ if return_moments:
411
+ return moments
412
+
413
+ posterior = DiagonalGaussianDistribution(moments)
414
+ if not return_dict:
415
+ return (posterior,)
416
+
417
+ return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
420
+ r"""
421
+ Decode a batch of images/videos using a tiled decoder.
422
+
423
+ Args:
424
+ z (`torch.FloatTensor`): Input batch of latent vectors.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
427
+
428
+ Returns:
429
+ [`~models.vae.DecoderOutput`] or `tuple`:
430
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
431
+ returned.
432
+ """
433
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
434
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
435
+ row_limit = self.tile_sample_min_size - blend_extent
436
+
437
+ # Split z into overlapping tiles and decode them separately.
438
+ # The tiles have an overlap to avoid seams between tiles.
439
+ rows = []
440
+ for i in range(0, z.shape[-2], overlap_size):
441
+ row = []
442
+ for j in range(0, z.shape[-1], overlap_size):
443
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
444
+ tile = self.post_quant_conv(tile)
445
+ decoded = self.decoder(tile)
446
+ row.append(decoded)
447
+ rows.append(row)
448
+ result_rows = []
449
+ for i, row in enumerate(rows):
450
+ result_row = []
451
+ for j, tile in enumerate(row):
452
+ # blend the above tile and the left tile
453
+ # to the current tile and add the current tile to the result row
454
+ if i > 0:
455
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
456
+ if j > 0:
457
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
458
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
459
+ result_rows.append(torch.cat(result_row, dim=-1))
460
+
461
+ dec = torch.cat(result_rows, dim=-2)
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
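The tiling bookkeeping in temporal_tiled_encode can be previewed without instantiating the VAE. The arithmetic below uses the defaults declared in __init__ (sample_tsize=64, time_compression_ratio=4, tile_overlap_factor=0.25); the 129-frame input length is just an example.

tile_sample_min_tsize = 64
time_compression_ratio = 4
tile_latent_min_tsize = tile_sample_min_tsize // time_compression_ratio   # 16
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_tsize * (1 - tile_overlap_factor))     # 48-frame stride between tiles
blend_extent = int(tile_latent_min_tsize * tile_overlap_factor)           # 4 latent frames cross-faded by blend_t
t_limit = tile_latent_min_tsize - blend_extent                            # 12 latent frames kept per tile

T = 129
print(list(range(0, T, overlap_size)))   # [0, 48, 96]: three overlapping tiles of up to 65 input frames each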
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image/video to Patch Embedding using Conv3d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ The _assert call is removed from the forward function to stay compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
41
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
42
+ if bias:
43
+ nn.init.zeros_(self.proj.bias)
44
+
45
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
46
+
47
+ def forward(self, x):
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
51
+ x = self.norm(x)
52
+ return x
53
+
54
+
55
+ class TextProjection(nn.Module):
56
+ """
57
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {"dtype": dtype, "device": device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
92
+ args = t[:, None].float() * freqs[None]
93
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
94
+ if dim % 2:
95
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
96
+ return embedding
97
+
98
+
99
+ class TimestepEmbedder(nn.Module):
100
+ """
101
+ Embeds scalar timesteps into vector representations.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ act_layer,
108
+ frequency_embedding_size=256,
109
+ max_period=10000,
110
+ out_size=None,
111
+ dtype=None,
112
+ device=None,
113
+ ):
114
+ factory_kwargs = {"dtype": dtype, "device": device}
115
+ super().__init__()
116
+ self.frequency_embedding_size = frequency_embedding_size
117
+ self.max_period = max_period
118
+ if out_size is None:
119
+ out_size = hidden_size
120
+
121
+ self.mlp = nn.Sequential(
122
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
123
+ act_layer(),
124
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
125
+ )
126
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
127
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
128
+
129
+ def forward(self, t):
130
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
131
+ t_emb = self.mlp(t_freq)
132
+ return t_emb
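timestep_embedding and TimestepEmbedder above are self-contained, so they can be exercised directly; the sizes below are illustrative, and importing them as from hunyuan_model.embed_layers import ... is an assumption about the package layout.

import torch
import torch.nn as nn

t = torch.tensor([0.0, 10.0, 500.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)                                          # torch.Size([3, 256]) sinusoidal features

embedder = TimestepEmbedder(hidden_size=128, act_layer=nn.SiLU)
print(embedder(t).shape)                                  # torch.Size([3, 128]) after the two-layer MLP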