NeverMore0123 committed
Commit 37471f2 · 1 Parent(s): 2375ad5
This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. ar_model.py +5 -5
  2. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/00514ee36fc535c00c979a7802492538d9886fae +0 -76
  3. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/00be0b18e2e656c0d69d8d74298a45195530e8c4 +0 -102
  4. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/05f6f3e548bab062b6f46c0f12377a502bde0dbf +0 -606
  5. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/072076fb853aec819a7298df83e26338e0cb4c3a +0 -187
  6. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/09440b34a95b1708d2154376f2a0202a533cb3b2 +0 -46
  7. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/0c2f9c6280ccfa60e1ba8a38e3062e0caf99e71e +0 -560
  8. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/148897d5cae9165673cb74e336548c71adb261b1 +0 -78
  9. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/184d60dec6f9b0326dc0aa1a3d9b89c06fa7566e +0 -283
  10. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/1e300540d3a022a74d708a0df0f04204a895b189 +0 -903
  11. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/1f41a4225dcea325c5ea283e51e09477ee1d0e6d +0 -149
  12. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2464bc5e1892a3541ce439c0ea36347f43647224 +0 -305
  13. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2984b57e08440bd3117de9e25e4f3cfabd619e80 +0 -195
  14. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/29be4d33e5dfb6255b5db0b99bcbc4311a3faa82 +0 -63
  15. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2a19d3b8e2a1cf29c182f7b25a25d4c1e10089da +0 -491
  16. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2c584c7c9a5e03bcb3b808d053f89e7c2aeaf9cf +0 -119
  17. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/39dca42a0a71383de919b750cedf2606faae206d +0 -65
  18. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/3c5a1dbe30558d9e7e97ad64304161c4e61a00f5 +0 -60
  19. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4146fad65c365a8c4fd6903a0ea33860142f64f5 +0 -323
  20. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/45a2ac6c32e8df9e6836ed55973912b8730c0749 +0 -50
  21. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/46385211d438d1953e9ba21376680dc2c42db01c +0 -219
  22. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4a13a8fde58e7852b683112be63eaed44e1f143f +0 -596
  23. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4c860c42a1c3d8adc417e9593892491d0803fe51 +0 -113
  24. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4de12fae686821ebf94aec3420719e6432856cf4 +0 -421
  25. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/53dea6ed871052e987bf5094f869778412202323 +0 -360
  26. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/54ff4d48b535d2a1f27bbcc75c20ef16821b11e1 +0 -341
  27. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/578cd9ecfca36e5376fef8da5106652c6ca85b68 +0 -262
  28. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5877aa166d1d946b98ce604e2bd1a4284b884ae6 +0 -318
  29. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5d1bc4c8a22a942736ae6b73a4ebb21da4980adc +0 -117
  30. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5e5a5244c87516121f3e7686c924f8b1c66cd772 +0 -360
  31. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/602ea1cb383d8263be06829a466cfb3ba9f97856 +0 -52
  32. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/61f10fe07227a01d582e17f89a9b5089aa506006 +0 -88
  33. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/68e9cbb58aa1a39cd62c15a01b3e6526a49b66b0 +0 -728
  34. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/69f477ced9dfe59deda742bc507addf7d7268bdf +0 -223
  35. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/6bb055d8b2ddd78f626f08bb78f9434de5aef511 +0 -276
  36. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/73755631ed6b97ebf773b3941fc0f6d1621761f7 +0 -231
  37. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/77c3f88ca85134e689203e9ac157673c42edb0b3 +0 -131
  38. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7b5c6e553583e8047a37aea5e4925df659426ea2 +0 -196
  39. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7bebf08cef2869c85553980bf81851635dd74f7e +0 -108
  40. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7c09eb428a97927d5f0407e2328a3f43afbf38fc +0 -72
  41. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/859eb6498143e5b063dbc888dca7748a07cfda9d +0 -45
  42. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/8929f3a211707ad09f7c25b6b6e305360a42d6be +0 -358
  43. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9586934f8c1949d734b4ea3080135d2769ec481a +0 -333
  44. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9861ef45253f4932a362923bdb6f07fd1b39666b +0 -322
  45. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9918ab7cc8f55dc0c159b58c158d3556b6819acd +0 -317
  46. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9bd252316a4bd6fb3a8f8a1c29a8e9ac44ac76fe +0 -60
  47. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9d565d078fbe37e1d31cf8a445a460e2bae291f1 +0 -224
  48. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a209db0eba28a8d8bcb527bfbaca6f5e361ace14 +0 -28
  49. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a2496a4fa280586b62c846c54cfbbc9f8adc0331 +0 -211
  50. cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a24d1a0cbbe184ab0a2bfb5cbee13bfd327810ae +0 -165
ar_model.py CHANGED
@@ -19,7 +19,7 @@ import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set
 
-from .misc import misc
+from .misc import misc, Color, timer
 import torch
 from safetensors.torch import load_file
 from torch.nn.modules.module import _IncompatibleKeys
@@ -96,7 +96,7 @@ class AutoRegressiveModel(torch.nn.Module):
         """
         model_config = self.config
         ckpt_path = model_config.ckpt_path
-        with misc.timer(f"loading checkpoint from {ckpt_path}"):
+        with timer(f"loading checkpoint from {ckpt_path}"):
             if ckpt_path.endswith("safetensors"):
                 # Load with safetensors API
                 checkpoint = load_file(ckpt_path, device="cpu")
@@ -142,7 +142,7 @@ class AutoRegressiveModel(torch.nn.Module):
         )
         # Remove the "model." prefix in the state_dict
         llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
-        with misc.timer("loading state_dict into model"):
+        with timer("loading state_dict into model"):
             missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
         # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
         missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
@@ -217,7 +217,7 @@ class AutoRegressiveModel(torch.nn.Module):
             # Override the default model configuration with the parameters from the checkpoint
             setattr(model_config, key, value)
 
-        with misc.timer(f"loading checkpoint from {ckpt_path}"):
+        with timer(f"loading checkpoint from {ckpt_path}"):
             if ckpt_path.endswith("safetensors"):
                 # Load with safetensors API
                 checkpoint = load_file(ckpt_path, device="cpu")
@@ -293,7 +293,7 @@ class AutoRegressiveModel(torch.nn.Module):
 
         # Remove the "model." prefix in the state_dict
         llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
-        with misc.timer("loading state_dict into model"):
+        with timer("loading state_dict into model"):
             missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
         # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
         missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
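Note: the timer (and Color) helpers imported in the new code come from the repository's .misc module, which is not part of this diff. A minimal sketch of a context-manager timer compatible with the call sites above, not the actual .misc implementation (the log format is made up):

import time
from contextlib import contextmanager

@contextmanager
def timer(label: str):
    # Hypothetical stand-in for the imported timer helper: time the wrapped block.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label} took {time.perf_counter() - start:.2f}s")  # assumed logging

Usage mirrors the call sites in the diff, e.g. with timer("loading state_dict into model"): model.load_state_dict(...).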
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/00514ee36fc535c00c979a7802492538d9886fae DELETED
@@ -1,76 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict, Optional
-
-import torch
-
-# Substrings to ignore when processing state dicts
-substrings_to_ignore = [
-    "_extra_state",  # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
-]
-
-
-def get_partial_state_dict(
-    state_dict: Dict[str, torch.Tensor],
-    prefix: str,
-) -> Dict[str, torch.Tensor]:
-    """
-    Get a partial state dict with keys starting with the given prefix
-    """
-    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
-
-
-def process_state_dict(
-    state_dict: Dict[str, torch.Tensor],
-    device: str = None,
-    dtype: torch.dtype = None,
-    prefix_to_remove: Optional[str] = None,
-) -> Dict[str, torch.Tensor]:
-    """
-    - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
-    - Move tensors to specified device and dtype if provided
-
-    Args:
-        state_dict (Dict[str, torch.Tensor]): The state dict to process
-        device (str, optional): The device to move tensors to. Defaults to None.
-        dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
-        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.
-
-    Returns:
-        Dict[str, torch.Tensor]: The processed state dict
-    """
-    new_state_dict = {}
-    tensor_kwargs = {}
-    if device is not None:
-        tensor_kwargs["device"] = device
-    if dtype is not None:
-        tensor_kwargs["dtype"] = dtype
-
-    for key, value in state_dict.items():
-        # Check if any of the substrings to ignore are in the key
-        skip = False
-        for substr in substrings_to_ignore:
-            if substr in key:
-                skip = True
-                break
-        if skip:
-            continue
-        if len(tensor_kwargs) > 0:
-            value = value.to(**tensor_kwargs)
-        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
-            key = key[len(prefix_to_remove) :]
-        new_state_dict[key] = value
-    return new_state_dict
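For reference, a minimal usage sketch of the deleted process_state_dict helper above (tensor names are made up, and the function is assumed to be imported from its module):

import torch

state_dict = {
    "model.layer.weight": torch.randn(4, 4),
    "model.layer._extra_state": torch.zeros(1),  # dropped: key contains an ignored substring
}
cleaned = process_state_dict(state_dict, device="cpu", dtype=torch.float32, prefix_to_remove="model.")
assert list(cleaned) == ["layer.weight"]  # "_extra_state" entry removed, "model." prefix stripped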
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/00be0b18e2e656c0d69d8d74298a45195530e8c4 DELETED
@@ -1,102 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, List, Union
-
-import attrs
-
-from .ar_config_base_model import ModelConfig, TokenizerConfig
-
-
-@attrs.define(slots=False)
-class DataShapeConfig:
-    latent_shape: list = []
-    num_video_frames: Union[None, int] = None
-    height: Union[None, int] = None
-    width: Union[None, int] = None
-
-
-@attrs.define(slots=False)
-class SamplingConfig:
-    """
-    Sampling config
-    Args:
-        temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
-        top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-        logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
-        echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
-    """
-
-    temperature: float = 0.6
-    top_k: int = None
-    top_p: float = 0.9
-    compile_prefill: bool = False
-    compile_sampling: bool = True
-    logprobs: bool = False
-    echo: bool = False
-
-
-@attrs.define(slots=False)
-class DiffusionDecoderSamplingConfig:
-    """
-    Diffusion decoder sampling config
-    Args:
-        guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8.
-        sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
-        sigma (float): Initial noise level for the diffusion process. Defaults to 8.
-        num_steps (int): Number of denoising steps to perform. Defaults to 35.
-        overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
-        continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
-        continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
-        dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
-    """
-
-    guidance: float = 1.8
-    sigma_min: float = 0.02
-    sigma: float = 8
-    num_steps: int = 15
-    overlap: int = 2
-    continuous_tokenizer_channel = 16
-    continuous_tokenizer_spatial_compression_ratio = 8
-    dd_train_num_video_frames: int = 57
-    max_iter: int = 99
-    fps: int = 24
-
-
-@attrs.define(slots=False)
-class InferenceConfig:
-    """
-    Inference config
-    Args:
-        model_config (ModelConfig): Model config
-        tokenizer_config (TokenizerConfig): Tokenizer config
-        ckpt_path (str): Path to the checkpoint
-        latent_shape (list): Shape of the latent
-    """
-
-    model_config: ModelConfig = None
-    tokenizer_config: TokenizerConfig = None
-    ckpt_path: str = ""
-    data_shape_config: DataShapeConfig = None
-
-    defaults: List[Any] = attrs.field(
-        factory=lambda: [
-            "_self_",
-            {"data_val": None},
-            {"data_shape_config": "video_shape_as_model_config"},
-            {"eval_job": None},
-        ]
-    )
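Two details worth noting about the deleted config file above: the DiffusionDecoderSamplingConfig docstring lists guidance 0.8 and num_steps 35 as defaults, while the actual field defaults are 1.8 and 15. A minimal instantiation sketch, assuming the classes are imported from their module:

sampling = SamplingConfig()                     # temperature=0.6, top_k=None, top_p=0.9
dd_sampling = DiffusionDecoderSamplingConfig()  # guidance=1.8, num_steps=15, fps=24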
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/05f6f3e548bab062b6f46c0f12377a502bde0dbf DELETED
@@ -1,606 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- from abc import ABC, abstractmethod
18
-
19
- import torch
20
- from einops import rearrange
21
- from torch.nn.modules import Module
22
-
23
-
24
- class BaseVAE(torch.nn.Module, ABC):
25
- """
26
- Abstract base class for a Variational Autoencoder (VAE).
27
-
28
- All subclasses should implement the methods to define the behavior for encoding
29
- and decoding, along with specifying the latent channel size.
30
- """
31
-
32
- def __init__(self, channel: int = 3, name: str = "vae"):
33
- super().__init__()
34
- self.channel = channel
35
- self.name = name
36
-
37
- @property
38
- def latent_ch(self) -> int:
39
- """
40
- Returns the number of latent channels in the VAE.
41
- """
42
- return self.channel
43
-
44
- @abstractmethod
45
- def encode(self, state: torch.Tensor) -> torch.Tensor:
46
- """
47
- Encodes the input tensor into a latent representation.
48
-
49
- Args:
50
- - state (torch.Tensor): The input tensor to encode.
51
-
52
- Returns:
53
- - torch.Tensor: The encoded latent tensor.
54
- """
55
- pass
56
-
57
- @abstractmethod
58
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
59
- """
60
- Decodes the latent representation back to the original space.
61
-
62
- Args:
63
- - latent (torch.Tensor): The latent tensor to decode.
64
-
65
- Returns:
66
- - torch.Tensor: The decoded tensor.
67
- """
68
- pass
69
-
70
- @property
71
- def spatial_compression_factor(self) -> int:
72
- """
73
- Returns the spatial reduction factor for the VAE.
74
- """
75
- raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.")
76
-
77
-
78
- class BasePretrainedImageVAE(BaseVAE):
79
- """
80
- A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values
81
- from a remote store, handles data type conversions, and normalization
82
- using provided mean and standard deviation values for latent space representation.
83
- Derived classes should load pre-trained encoder and decoder components from a remote store
84
-
85
- Attributes:
86
- latent_mean (Tensor): The mean used for normalizing the latent representation.
87
- latent_std (Tensor): The standard deviation used for normalizing the latent representation.
88
- dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
89
-
90
- Args:
91
- mean_std_fp (str): File path to the pickle file containing mean and std of the latent space.
92
- latent_ch (int, optional): Number of latent channels (default is 16).
93
- is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
94
- is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
95
- """
96
-
97
- def __init__(
98
- self,
99
- name: str,
100
- latent_ch: int = 16,
101
- is_image: bool = True,
102
- is_bf16: bool = True,
103
- ) -> None:
104
- super().__init__(latent_ch, name)
105
- dtype = torch.bfloat16 if is_bf16 else torch.float32
106
- self.dtype = dtype
107
- self.is_image = is_image
108
- self.name = name
109
-
110
- def register_mean_std(self, vae_dir: str) -> None:
111
- latent_mean, latent_std = torch.load(os.path.join(vae_dir, "image_mean_std.pt"), weights_only=True)
112
-
113
- target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1]
114
-
115
- self.register_buffer(
116
- "latent_mean",
117
- latent_mean.to(self.dtype).reshape(*target_shape),
118
- persistent=False,
119
- )
120
- self.register_buffer(
121
- "latent_std",
122
- latent_std.to(self.dtype).reshape(*target_shape),
123
- persistent=False,
124
- )
125
-
126
- @torch.no_grad()
127
- def encode(self, state: torch.Tensor) -> torch.Tensor:
128
- """
129
- Encode the input state to latent space; also handle the dtype conversion, mean and std scaling
130
- """
131
- in_dtype = state.dtype
132
- latent_mean = self.latent_mean.to(in_dtype)
133
- latent_std = self.latent_std.to(in_dtype)
134
- encoded_state = self.encoder(state.to(self.dtype))
135
- if isinstance(encoded_state, torch.Tensor):
136
- pass
137
- elif isinstance(encoded_state, tuple):
138
- assert isinstance(encoded_state[0], torch.Tensor)
139
- encoded_state = encoded_state[0]
140
- else:
141
- raise ValueError("Invalid type of encoded state")
142
- return (encoded_state.to(in_dtype) - latent_mean) / latent_std
143
-
144
- @torch.no_grad()
145
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
146
- """
147
- Decode the input latent to state; also handle the dtype conversion, mean and std scaling
148
- """
149
- in_dtype = latent.dtype
150
- latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype)
151
- return self.decoder(latent.to(self.dtype)).to(in_dtype)
152
-
153
- def reset_dtype(self, *args, **kwargs):
154
- """
155
- Resets the data type of the encoder and decoder to the model's default data type.
156
-
157
- Args:
158
- *args, **kwargs: Unused, present to allow flexibility in method calls.
159
- """
160
- del args, kwargs
161
- self.decoder.to(self.dtype)
162
- self.encoder.to(self.dtype)
163
-
164
-
165
- class JITVAE(BasePretrainedImageVAE):
166
- """
167
- A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder
168
- and decoder components from a remote store, handles data type conversions, and normalization
169
- using provided mean and standard deviation values for latent space representation.
170
-
171
- Attributes:
172
- encoder (Module): The JIT compiled encoder loaded from storage.
173
- decoder (Module): The JIT compiled decoder loaded from storage.
174
- latent_mean (Tensor): The mean used for normalizing the latent representation.
175
- latent_std (Tensor): The standard deviation used for normalizing the latent representation.
176
- dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
177
-
178
- Args:
179
- name (str): Name of the model, used for differentiating cache file paths.
180
- latent_ch (int, optional): Number of latent channels (default is 16).
181
- is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
182
- is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
183
- """
184
-
185
- def __init__(
186
- self,
187
- name: str,
188
- latent_ch: int = 16,
189
- is_image: bool = True,
190
- is_bf16: bool = True,
191
- ):
192
- super().__init__(name, latent_ch, is_image, is_bf16)
193
-
194
- def load_encoder(self, vae_dir: str) -> None:
195
- """
196
- Load the encoder from the remote store.
197
- """
198
- self.encoder = torch.load(os.path.join(vae_dir, "encoder.jit"), weights_only=True)
199
-
200
- self.encoder.eval()
201
- for param in self.encoder.parameters():
202
- param.requires_grad = False
203
- self.encoder.to(self.dtype)
204
-
205
- def load_decoder(self, vae_dir: str) -> None:
206
- """
207
- Load the decoder from the remote store.
208
- """
209
- self.decoder = torch.load(os.path.join(vae_dir, "decoder.jit"), weights_only=True)
210
-
211
- self.decoder.eval()
212
- for param in self.decoder.parameters():
213
- param.requires_grad = False
214
- self.decoder.to(self.dtype)
215
-
216
-
217
- class BaseVAE(torch.nn.Module, ABC):
218
- """
219
- Abstract base class for a Variational Autoencoder (VAE).
220
-
221
- All subclasses should implement the methods to define the behavior for encoding
222
- and decoding, along with specifying the latent channel size.
223
- """
224
-
225
- def __init__(self, channel: int = 3, name: str = "vae"):
226
- super().__init__()
227
- self.channel = channel
228
- self.name = name
229
-
230
- @property
231
- def latent_ch(self) -> int:
232
- """
233
- Returns the number of latent channels in the VAE.
234
- """
235
- return self.channel
236
-
237
- @abstractmethod
238
- def encode(self, state: torch.Tensor) -> torch.Tensor:
239
- """
240
- Encodes the input tensor into a latent representation.
241
-
242
- Args:
243
- - state (torch.Tensor): The input tensor to encode.
244
-
245
- Returns:
246
- - torch.Tensor: The encoded latent tensor.
247
- """
248
- pass
249
-
250
- @abstractmethod
251
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
252
- """
253
- Decodes the latent representation back to the original space.
254
-
255
- Args:
256
- - latent (torch.Tensor): The latent tensor to decode.
257
-
258
- Returns:
259
- - torch.Tensor: The decoded tensor.
260
- """
261
- pass
262
-
263
- @property
264
- def spatial_compression_factor(self) -> int:
265
- """
266
- Returns the spatial reduction factor for the VAE.
267
- """
268
- raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.")
269
-
270
-
271
- class VideoTokenizerInterface(ABC):
272
- @abstractmethod
273
- def encode(self, state: torch.Tensor) -> torch.Tensor:
274
- pass
275
-
276
- @abstractmethod
277
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
278
- pass
279
-
280
- @abstractmethod
281
- def get_latent_num_frames(self, num_pixel_frames: int) -> int:
282
- pass
283
-
284
- @abstractmethod
285
- def get_pixel_num_frames(self, num_latent_frames: int) -> int:
286
- pass
287
-
288
- @property
289
- @abstractmethod
290
- def spatial_compression_factor(self):
291
- pass
292
-
293
- @property
294
- @abstractmethod
295
- def temporal_compression_factor(self):
296
- pass
297
-
298
- @property
299
- @abstractmethod
300
- def spatial_resolution(self):
301
- pass
302
-
303
- @property
304
- @abstractmethod
305
- def pixel_chunk_duration(self):
306
- pass
307
-
308
- @property
309
- @abstractmethod
310
- def latent_chunk_duration(self):
311
- pass
312
-
313
-
314
- class BasePretrainedVideoTokenizer(ABC):
315
- """
316
- Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing.
317
-
318
- Args:
319
- pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
320
- temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing.
321
- max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow.
322
- max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow.
323
-
324
- The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`)
325
- which define how video data is subdivided and compressed during the encoding and decoding processes. The
326
- `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory
327
- constraints.
328
- """
329
-
330
- def __init__(
331
- self,
332
- pixel_chunk_duration: int = 17,
333
- temporal_compress_factor: int = 8,
334
- max_enc_batch_size: int = 8,
335
- max_dec_batch_size: int = 4,
336
- ):
337
- self._pixel_chunk_duration = pixel_chunk_duration
338
- self._temporal_compress_factor = temporal_compress_factor
339
- self.max_enc_batch_size = max_enc_batch_size
340
- self.max_dec_batch_size = max_dec_batch_size
341
-
342
- def register_mean_std(self, vae_dir: str) -> None:
343
- latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=True)
344
-
345
- latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration]
346
- latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration]
347
-
348
- target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1]
349
-
350
- self.register_buffer(
351
- "latent_mean",
352
- latent_mean.to(self.dtype).reshape(*target_shape),
353
- persistent=False,
354
- )
355
- self.register_buffer(
356
- "latent_std",
357
- latent_std.to(self.dtype).reshape(*target_shape),
358
- persistent=False,
359
- )
360
-
361
- def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor:
362
- """
363
- Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding
364
- """
365
- B, C, T, H, W = state.shape
366
- assert (
367
- T % self.pixel_chunk_duration == 0
368
- ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}"
369
- return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration)
370
-
371
- def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor:
372
- B, _, T, _, _ = latent.shape
373
- assert (
374
- T % self.latent_chunk_duration == 0
375
- ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}"
376
- return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration)
377
-
378
- @torch.no_grad()
379
- def encode(self, state: torch.Tensor) -> torch.Tensor:
380
- if self._temporal_compress_factor == 1:
381
- _, _, origin_T, _, _ = state.shape
382
- state = rearrange(state, "b c t h w -> (b t) c 1 h w")
383
- B, C, T, H, W = state.shape
384
- state = self.transform_encode_state_shape(state)
385
- # use max_enc_batch_size to avoid OOM
386
- if state.shape[0] > self.max_enc_batch_size:
387
- latent = []
388
- for i in range(0, state.shape[0], self.max_enc_batch_size):
389
- latent.append(super().encode(state[i : i + self.max_enc_batch_size]))
390
- latent = torch.cat(latent, dim=0)
391
- else:
392
- latent = super().encode(state)
393
-
394
- latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B)
395
- if self._temporal_compress_factor == 1:
396
- latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T)
397
- return latent
398
-
399
- @torch.no_grad()
400
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
401
- """
402
- Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode,
403
- it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions.
404
-
405
- It can also decode single frame image data.
406
-
407
- Args:
408
- latent (torch.Tensor): The latent space tensor containing encoded video data.
409
-
410
- Returns:
411
- torch.Tensor: The decoded video tensor reconstructed from latent space.
412
- """
413
- if self._temporal_compress_factor == 1:
414
- _, _, origin_T, _, _ = latent.shape
415
- latent = rearrange(latent, "b c t h w -> (b t) c 1 h w")
416
- B, _, T, _, _ = latent.shape
417
- latent = self.transform_decode_state_shape(latent)
418
- # use max_enc_batch_size to avoid OOM
419
- if latent.shape[0] > self.max_dec_batch_size:
420
- state = []
421
- for i in range(0, latent.shape[0], self.max_dec_batch_size):
422
- state.append(super().decode(latent[i : i + self.max_dec_batch_size]))
423
- state = torch.cat(state, dim=0)
424
- else:
425
- state = super().decode(latent)
426
- assert state.shape[2] == self.pixel_chunk_duration
427
- state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)
428
- if self._temporal_compress_factor == 1:
429
- return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T)
430
- return state
431
-
432
- @property
433
- def pixel_chunk_duration(self) -> int:
434
- return self._pixel_chunk_duration
435
-
436
- @property
437
- def latent_chunk_duration(self) -> int:
438
- # return self._latent_chunk_duration
439
- assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, (
440
- f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration "
441
- f"{self.latent_chunk_duration}"
442
- )
443
- return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1
444
-
445
- @property
446
- def temporal_compression_factor(self):
447
- return self._temporal_compress_factor
448
-
449
- def get_latent_num_frames(self, num_pixel_frames: int) -> int:
450
- if num_pixel_frames == 1:
451
- return 1
452
- assert (
453
- num_pixel_frames % self.pixel_chunk_duration == 0
454
- ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}"
455
- return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration
456
-
457
- def get_pixel_num_frames(self, num_latent_frames: int) -> int:
458
- if num_latent_frames == 1:
459
- return 1
460
- assert (
461
- num_latent_frames % self.latent_chunk_duration == 0
462
- ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}"
463
- return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration
464
-
465
-
466
- class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface):
467
- """
468
- Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file
469
- """
470
-
471
- def __init__(
472
- self,
473
- name: str,
474
- latent_ch: int = 16,
475
- is_bf16: bool = True,
476
- spatial_compression_factor: int = 16,
477
- temporal_compression_factor: int = 8,
478
- pixel_chunk_duration: int = 17,
479
- max_enc_batch_size: int = 8,
480
- max_dec_batch_size: int = 4,
481
- spatial_resolution: str = "720",
482
- ):
483
- super().__init__(
484
- pixel_chunk_duration,
485
- temporal_compression_factor,
486
- max_enc_batch_size,
487
- max_dec_batch_size,
488
- )
489
- super(BasePretrainedVideoTokenizer, self).__init__(
490
- name,
491
- latent_ch,
492
- False,
493
- is_bf16,
494
- )
495
-
496
- self._spatial_compression_factor = spatial_compression_factor
497
- self._spatial_resolution = spatial_resolution
498
-
499
- @property
500
- def spatial_compression_factor(self):
501
- return self._spatial_compression_factor
502
-
503
- @property
504
- def spatial_resolution(self) -> str:
505
- return self._spatial_resolution
506
-
507
-
508
- class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface):
509
- def __init__(
510
- self,
511
- image_vae: torch.nn.Module,
512
- video_vae: torch.nn.Module,
513
- name: str,
514
- latent_ch: int = 16,
515
- squeeze_for_image: bool = True,
516
- ):
517
- super().__init__(latent_ch, name)
518
- self.image_vae = image_vae
519
- self.video_vae = video_vae
520
- self.squeeze_for_image = squeeze_for_image
521
-
522
- def encode_image(self, state: torch.Tensor) -> torch.Tensor:
523
- if self.squeeze_for_image:
524
- return self.image_vae.encode(state.squeeze(2)).unsqueeze(2)
525
- return self.image_vae.encode(state)
526
-
527
- def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
528
- if self.squeeze_for_image:
529
- return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2)
530
- return self.image_vae.decode(latent)
531
-
532
- @torch.no_grad()
533
- def encode(self, state: torch.Tensor) -> torch.Tensor:
534
- B, C, T, H, W = state.shape
535
- if T == 1:
536
- return self.encode_image(state)
537
-
538
- return self.video_vae.encode(state)
539
-
540
- @torch.no_grad()
541
- def decode(self, latent: torch.Tensor) -> torch.Tensor:
542
- B, C, T, H, W = latent.shape
543
- if T == 1:
544
- return self.decode_image(latent)
545
- return self.video_vae.decode(latent)
546
-
547
- def reset_dtype(self, *args, **kwargs):
548
- """
549
- Resets the data type of the encoder and decoder to the model's default data type.
550
-
551
- Args:
552
- *args, **kwargs: Unused, present to allow flexibility in method calls.
553
- """
554
- del args, kwargs
555
- self.video_vae.reset_dtype()
556
-
557
- def get_latent_num_frames(self, num_pixel_frames: int) -> int:
558
- if num_pixel_frames == 1:
559
- return 1
560
- return self.video_vae.get_latent_num_frames(num_pixel_frames)
561
-
562
- def get_pixel_num_frames(self, num_latent_frames: int) -> int:
563
- if num_latent_frames == 1:
564
- return 1
565
- return self.video_vae.get_pixel_num_frames(num_latent_frames)
566
-
567
- @property
568
- def spatial_compression_factor(self):
569
- return self.video_vae.spatial_compression_factor
570
-
571
- @property
572
- def temporal_compression_factor(self):
573
- return self.video_vae.temporal_compression_factor
574
-
575
- @property
576
- def spatial_resolution(self) -> str:
577
- return self.video_vae.spatial_resolution
578
-
579
- @property
580
- def pixel_chunk_duration(self) -> int:
581
- return self.video_vae.pixel_chunk_duration
582
-
583
- @property
584
- def latent_chunk_duration(self) -> int:
585
- return self.video_vae.latent_chunk_duration
586
-
587
-
588
- class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer):
589
- """
590
- First version of the ImageVideoVAE trained with Fitsum.
591
- We have to use seperate mean and std for image and video due to non-causal nature of the model.
592
- """
593
-
594
- def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16):
595
- super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False)
596
- assert isinstance(image_vae, JITVAE)
597
- assert isinstance(
598
- video_vae, VideoJITTokenizer
599
- ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}"
600
- # a hack to make the image_vae and video_vae share the same encoder and decoder
601
-
602
- def load_weights(self, vae_dir: str):
603
- self.video_vae.register_mean_std(vae_dir)
604
-
605
- self.video_vae.load_decoder(vae_dir)
606
- self.video_vae.load_encoder(vae_dir)
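As a quick check of the chunking arithmetic in the deleted tokenizer code above, using the class defaults pixel_chunk_duration=17 and temporal_compress_factor=8:

pixel_chunk_duration = 17
temporal_compression_factor = 8
# latent_chunk_duration property: (17 - 1) // 8 + 1 = 3
latent_chunk_duration = (pixel_chunk_duration - 1) // temporal_compression_factor + 1
# get_latent_num_frames for a 34-frame clip (two pixel chunks): 34 // 17 * 3 = 6 latent frames
num_latent_frames = 34 // pixel_chunk_duration * latent_chunk_duration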
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/072076fb853aec819a7298df83e26338e0cb4c3a DELETED
@@ -1,187 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import json
18
- import os
19
- from typing import Iterable, Tuple, Union
20
-
21
- from .misc import misc
22
- import torch
23
- from PIL import Image
24
-
25
- from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
26
- from .guardrail_common_io_utils import get_video_filepaths, read_video
27
- from .guardrail_video_content_safety_filter_model import ModelConfig, VideoSafetyModel
28
- from .guardrail_video_content_safety_filter_vision_encoder import SigLIPEncoder
29
- from .log import log
30
-
31
- DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter"
32
-
33
-
34
- # Define the class index to class name mapping for multi-class classification
35
- CLASS_IDX_TO_NAME = {
36
- 0: "Safe",
37
- 1: "Sexual_Content",
38
- 2: "Violence",
39
- 3: "Drugs",
40
- 4: "Child_Abuse",
41
- 5: "Hate_and_Harassment",
42
- 6: "Self-Harm",
43
- }
44
-
45
-
46
- class VideoContentSafetyFilter(ContentSafetyGuardrail):
47
- def __init__(
48
- self,
49
- checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
50
- device="cuda" if torch.cuda.is_available() else "cpu",
51
- ) -> None:
52
- self.device = device
53
- self.dtype = torch.float32
54
-
55
- # Initialize the SigLIP encoder
56
- self.encoder = SigLIPEncoder(checkpoint_dir=checkpoint_dir, device=device, dtype=self.dtype)
57
-
58
- # Use ModelConfig directly for inference configuration
59
- model_config = ModelConfig(input_size=1152, num_classes=7)
60
-
61
- # Load the multi-class classifier
62
- self.model = VideoSafetyModel(model_config)
63
- safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt")
64
- checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True)
65
- self.model.load_state_dict(checkpoint["model"])
66
- self.model.to(self.device, dtype=self.dtype).eval()
67
-
68
- @torch.inference_mode()
69
- def __infer(self, pil_image: Image.Image) -> int:
70
- """Infer the class of the image."""
71
- image_embs = self.encoder.encode_image(pil_image)
72
- logits = self.model.network(image_embs)
73
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
74
- predicted_class = torch.argmax(probabilities, dim=-1).item()
75
- return predicted_class
76
-
77
- def is_safe_file(self, filepath: str) -> bool:
78
- """Check if the video file is safe."""
79
- video_data = read_video(filepath)
80
-
81
- # Sample frames at 2 FPS
82
- sample_rate = 2 # frames per second
83
- frame_interval = int(video_data.fps / sample_rate)
84
- frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval))
85
-
86
- is_safe = True
87
- frame_scores = []
88
-
89
- for frame_number in frame_numbers:
90
- try:
91
- frame = video_data.frames[frame_number]
92
- pil_image = Image.fromarray(frame)
93
- predicted_class = self.__infer(pil_image)
94
- class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown")
95
- frame_scores.append({"frame_number": frame_number, "class": class_name})
96
-
97
- # If any frame is not "Safe", mark the video as unsafe
98
- if predicted_class != 0:
99
- is_safe = False
100
- break
101
-
102
- except Exception as e:
103
- log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
104
- continue
105
-
106
- # Prepare data for JSON
107
- video_data = {
108
- "filepath": filepath,
109
- "is_safe": is_safe,
110
- "video_length": video_data.duration,
111
- "fps": video_data.fps,
112
- "frame_scores": frame_scores,
113
- }
114
-
115
- log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.")
116
- log.debug(f"Video data: {json.dumps(video_data, indent=4)}")
117
- return is_safe
118
-
119
- def is_safe_frames(self, frames: Iterable) -> bool:
120
- """Check if the video frames are safe."""
121
- is_safe = True
122
- frame_scores = []
123
-
124
- for frame_number, frame in enumerate(frames):
125
- try:
126
- pil_image = Image.fromarray(frame)
127
- predicted_class = self.__infer(pil_image)
128
- class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown")
129
- frame_scores.append({"frame_number": frame_number, "class": class_name})
130
-
131
- # If any frame is not "Safe", mark as not safe
132
- if predicted_class != 0:
133
- is_safe = False
134
- break
135
-
136
- except Exception as e:
137
- log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
138
- continue
139
-
140
- video_data = {
141
- "is_safe": is_safe,
142
- "frame_scores": frame_scores,
143
- }
144
-
145
- log.debug(f"Frames data: {json.dumps(video_data, indent=4)}")
146
- return is_safe
147
-
148
- def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]:
149
- if isinstance(input, str):
150
- is_safe = self.is_safe_file(input)
151
- return is_safe, "safe video detected" if is_safe else "unsafe video detected"
152
- elif isinstance(input, Iterable):
153
- is_safe = self.is_safe_frames(input)
154
- return is_safe, "safe frames detected" if is_safe else "unsafe frames detected"
155
- else:
156
- raise ValueError(f"Input type {type(input)} not supported.")
157
-
158
-
159
- def parse_args():
160
- parser = argparse.ArgumentParser()
161
- parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
162
- parser.add_argument(
163
- "--checkpoint_dir",
164
- type=str,
165
- help="Path to the Video Content Safety Filter checkpoint folder",
166
- default=DEFAULT_CHECKPOINT_DIR,
167
- )
168
- return parser.parse_args()
169
-
170
-
171
- def main(args):
172
- filepaths = get_video_filepaths(args.input_dir)
173
- if not filepaths:
174
- log.error(f"No video files found in directory: {args.input_dir}")
175
- return
176
-
177
- video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir)
178
- runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe")
179
-
180
- for filepath in filepaths:
181
- with misc.timer("video content safety filter"):
182
- _ = runner.run_safety_check(filepath)
183
-
184
-
185
- if __name__ == "__main__":
186
- args = parse_args()
187
- main(args)
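A worked example of the 2 FPS frame sampling performed by is_safe_file above (the fps and duration values are hypothetical):

fps, duration = 24.0, 5.0                            # assumed video metadata
sample_rate = 2                                      # frames per second, as in is_safe_file
frame_interval = int(fps / sample_rate)              # 12
frame_numbers = list(range(0, int(fps * duration), frame_interval))
# -> [0, 12, 24, ..., 108]: 10 frames are run through the safety classifier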
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/09440b34a95b1708d2154376f2a0202a533cb3b2 DELETED
@@ -1,46 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Functions for performing operations with broadcasting to the right axis
-#
-# Example
-#     input1: tensor of size (N1, N2)
-#     input2: tensor of size (N1, N2, N3, N4)
-#     batch_mul(input1, input2) = input1[:, :, None, None] * input2
-#
-# If the common dimensions don't match, we raise an assertion error.
-
-from torch import Tensor
-
-
-def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
-    ndims1 = x.ndim
-    ndims2 = y.ndim
-
-    common_ndims = min(ndims1, ndims2)
-    for axis in range(common_ndims):
-        assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis)
-
-    if ndims1 < ndims2:
-        x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
-    elif ndims2 < ndims1:
-        y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
-
-    return x, y
-
-
-def batch_mul(x: Tensor, y: Tensor) -> Tensor:
-    x, y = common_broadcast(x, y)
-    return x * y
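A quick usage sketch of the deleted broadcasting helpers above, following the shapes from the file's own example comment (batch_mul is assumed to be imported from its module):

import torch

a = torch.ones(2, 3)        # (N1, N2)
b = torch.ones(2, 3, 4, 5)  # (N1, N2, N3, N4)
out = batch_mul(a, b)       # a is right-padded to shape (2, 3, 1, 1) before multiplying
assert out.shape == (2, 3, 4, 5)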
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/0c2f9c6280ccfa60e1ba8a38e3062e0caf99e71e DELETED
@@ -1,560 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """The model definition for 3D layers
17
-
18
- Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/
19
- magvit2_pytorch/magvit2_pytorch.py#L889
20
-
21
- [MIT License Copyright (c) 2023 Phil Wang]
22
- https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE
23
- """
24
- import math
25
- from typing import Tuple, Union
26
-
27
- import numpy as np
28
- import torch
29
- import torch.nn as nn
30
- import torch.nn.functional as F
31
-
32
- from .ar_tokenizer_patching import Patcher3D, UnPatcher3D
33
- from .ar_tokenizer_utils import (
34
- CausalNormalize,
35
- batch2space,
36
- batch2time,
37
- cast_tuple,
38
- is_odd,
39
- nonlinearity,
40
- replication_pad,
41
- space2batch,
42
- time2batch,
43
- )
44
- from .log import log
45
-
46
-
47
- class CausalConv3d(nn.Module):
48
- def __init__(
49
- self,
50
- chan_in: int = 1,
51
- chan_out: int = 1,
52
- kernel_size: Union[int, Tuple[int, int, int]] = 3,
53
- pad_mode: str = "constant",
54
- **kwargs,
55
- ):
56
- super().__init__()
57
- kernel_size = cast_tuple(kernel_size, 3)
58
-
59
- time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
60
-
61
- assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
62
-
63
- dilation = kwargs.pop("dilation", 1)
64
- stride = kwargs.pop("stride", 1)
65
- time_stride = kwargs.pop("time_stride", 1)
66
- time_dilation = kwargs.pop("time_dilation", 1)
67
- padding = kwargs.pop("padding", 1)
68
-
69
- self.pad_mode = pad_mode
70
- time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
71
- self.time_pad = time_pad
72
-
73
- self.spatial_pad = (padding, padding, padding, padding)
74
-
75
- stride = (time_stride, stride, stride)
76
- dilation = (time_dilation, dilation, dilation)
77
- self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
78
-
79
- def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
80
- x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1)
81
- x = torch.cat([x_prev, x], dim=2)
82
- padding = self.spatial_pad + (0, 0)
83
- return F.pad(x, padding, mode=self.pad_mode, value=0.0)
84
-
85
- def forward(self, x: torch.Tensor) -> torch.Tensor:
86
- x = self._replication_pad(x)
87
- return self.conv3d(x)
88
-
89
-
90
- class CausalHybridUpsample3d(nn.Module):
91
- def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None:
92
- super().__init__()
93
- self.conv1 = (
94
- CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0)
95
- if temporal_up
96
- else nn.Identity()
97
- )
98
- self.conv2 = (
99
- CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1)
100
- if spatial_up
101
- else nn.Identity()
102
- )
103
- self.conv3 = (
104
- CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
105
- if spatial_up or temporal_up
106
- else nn.Identity()
107
- )
108
- self.spatial_up = spatial_up
109
- self.temporal_up = temporal_up
110
-
111
- def forward(self, x: torch.Tensor) -> torch.Tensor:
112
- if not self.spatial_up and not self.temporal_up:
113
- return x
114
-
115
- # hybrid upsample temporally.
116
- if self.temporal_up:
117
- time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
118
- if isinstance(time_factor, torch.Tensor):
119
- time_factor = time_factor.item()
120
- x = x.repeat_interleave(int(time_factor), dim=2)
121
- x = x[..., int(time_factor - 1) :, :, :]
122
- x = self.conv1(x) + x
123
-
124
- # hybrid upsample spatially.
125
- if self.spatial_up:
126
- x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
127
- x = self.conv2(x) + x
128
-
129
- # final 1x1x1 conv.
130
- x = self.conv3(x)
131
- return x
132
-
133
-
134
- class CausalHybridDownsample3d(nn.Module):
135
- def __init__(
136
- self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs
137
- ) -> None:
138
- super().__init__()
139
- self.conv1 = (
140
- CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0)
141
- if spatial_down
142
- else nn.Identity()
143
- )
144
- self.conv2 = (
145
- CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0)
146
- if temporal_down
147
- else nn.Identity()
148
- )
149
- self.conv3 = (
150
- CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
151
- if spatial_down or temporal_down
152
- else nn.Identity()
153
- )
154
- self.spatial_down = spatial_down
155
- self.temporal_down = temporal_down
156
-
157
- def forward(self, x: torch.Tensor) -> torch.Tensor:
158
- if not self.spatial_down and not self.temporal_down:
159
- return x
160
-
161
- # hybrid downsample spatially.
162
- if self.spatial_down:
163
- pad = (0, 1, 0, 1, 0, 0)
164
- x = F.pad(x, pad, mode="constant", value=0)
165
- x1 = self.conv1(x)
166
- x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
167
- x = x1 + x2
168
-
169
- # hybrid downsample temporally.
170
- if self.temporal_down:
171
- x = replication_pad(x)
172
- x1 = self.conv2(x)
173
- x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
174
- x = x1 + x2
175
-
176
- # final 1x1x1 conv.
177
- x = self.conv3(x)
178
- return x
179
-
180
-
181
- class CausalResnetBlockFactorized3d(nn.Module):
182
- def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None:
183
- super().__init__()
184
- self.in_channels = in_channels
185
- out_channels = in_channels if out_channels is None else out_channels
186
-
187
- self.norm1 = CausalNormalize(in_channels, num_groups=1)
188
- self.conv1 = nn.Sequential(
189
- CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
190
- CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
191
- )
192
- self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
193
- self.dropout = torch.nn.Dropout(dropout)
194
- self.conv2 = nn.Sequential(
195
- CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
196
- CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
197
- )
198
- self.nin_shortcut = (
199
- CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
200
- if in_channels != out_channels
201
- else nn.Identity()
202
- )
203
-
204
- def forward(self, x: torch.Tensor) -> torch.Tensor:
205
- h = x
206
- h = self.norm1(h)
207
- h = nonlinearity(h)
208
- h = self.conv1(h)
209
-
210
- h = self.norm2(h)
211
- h = nonlinearity(h)
212
- h = self.dropout(h)
213
- h = self.conv2(h)
214
- x = self.nin_shortcut(x)
215
-
216
- return x + h
217
-
218
-
219
- class CausalAttnBlock(nn.Module):
220
- def __init__(self, in_channels: int, num_groups: int) -> None:
221
- super().__init__()
222
-
223
- self.norm = CausalNormalize(in_channels, num_groups=num_groups)
224
- self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
225
- self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
226
- self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
227
- self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
228
-
229
- def forward(self, x: torch.Tensor) -> torch.Tensor:
230
- h_ = x
231
- h_ = self.norm(h_)
232
- q = self.q(h_)
233
- k = self.k(h_)
234
- v = self.v(h_)
235
-
236
- # compute attention
237
- q, batch_size = time2batch(q)
238
- k, batch_size = time2batch(k)
239
- v, batch_size = time2batch(v)
240
-
241
- b, c, h, w = q.shape
242
- q = q.reshape(b, c, h * w)
243
- q = q.permute(0, 2, 1)
244
- k = k.reshape(b, c, h * w)
245
- w_ = torch.bmm(q, k)
246
- w_ = w_ * (int(c) ** (-0.5))
247
- w_ = F.softmax(w_, dim=2)
248
-
249
- # attend to values
250
- v = v.reshape(b, c, h * w)
251
- w_ = w_.permute(0, 2, 1)
252
- h_ = torch.bmm(v, w_)
253
- h_ = h_.reshape(b, c, h, w)
254
-
255
- h_ = batch2time(h_, batch_size)
256
- h_ = self.proj_out(h_)
257
- return x + h_
258
-
259
-
260
- class CausalTemporalAttnBlock(nn.Module):
261
- def __init__(self, in_channels: int, num_groups: int) -> None:
262
- super().__init__()
263
-
264
- self.norm = CausalNormalize(in_channels, num_groups=num_groups)
265
- self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
266
- self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
267
- self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
268
- self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
269
-
270
- def forward(self, x: torch.Tensor) -> torch.Tensor:
271
- h_ = x
272
- h_ = self.norm(h_)
273
- q = self.q(h_)
274
- k = self.k(h_)
275
- v = self.v(h_)
276
-
277
- # compute attention
278
- q, batch_size, height = space2batch(q)
279
- k, _, _ = space2batch(k)
280
- v, _, _ = space2batch(v)
281
-
282
- bhw, c, t = q.shape
283
- q = q.permute(0, 2, 1) # (bhw, t, c)
284
- k = k.permute(0, 2, 1) # (bhw, t, c)
285
- v = v.permute(0, 2, 1) # (bhw, t, c)
286
-
287
- w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t)
288
- w_ = w_ * (int(c) ** (-0.5))
289
-
290
- # Apply causal mask
291
- mask = torch.tril(torch.ones_like(w_))
292
- w_ = w_.masked_fill(mask == 0, float("-inf"))
293
- w_ = F.softmax(w_, dim=2)
294
-
295
- # attend to values
296
- h_ = torch.bmm(w_, v) # (bhw, t, c)
297
- h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t)
298
-
299
- h_ = batch2space(h_, batch_size, height)
300
- h_ = self.proj_out(h_)
301
- return x + h_
302
-
303
-
304
- class EncoderFactorized(nn.Module):
305
- def __init__(
306
- self,
307
- in_channels: int,
308
- channels: int,
309
- channels_mult: list[int],
310
- num_res_blocks: int,
311
- attn_resolutions: list[int],
312
- dropout: float,
313
- resolution: int,
314
- z_channels: int,
315
- spatial_compression: int,
316
- temporal_compression: int,
317
- **ignore_kwargs,
318
- ) -> None:
319
- super().__init__()
320
- self.num_resolutions = len(channels_mult)
321
- self.num_res_blocks = num_res_blocks
322
-
323
- # Patcher.
324
- patch_size = ignore_kwargs.get("patch_size", 1)
325
- self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
326
- in_channels = in_channels * patch_size * patch_size * patch_size
327
-
328
- # calculate the number of downsample operations
329
- self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
330
- assert (
331
- self.num_spatial_downs <= self.num_resolutions
332
- ), f"Spatially downsample {self.num_resolutions} times at most"
333
-
334
- self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
335
- assert (
336
- self.num_temporal_downs <= self.num_resolutions
337
- ), f"Temporally downsample {self.num_resolutions} times at most"
338
-
339
- # downsampling
340
- self.conv_in = nn.Sequential(
341
- CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1),
342
- CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0),
343
- )
344
-
345
- curr_res = resolution // patch_size
346
- in_ch_mult = (1,) + tuple(channels_mult)
347
- self.in_ch_mult = in_ch_mult
348
- self.down = nn.ModuleList()
349
- for i_level in range(self.num_resolutions):
350
- block = nn.ModuleList()
351
- attn = nn.ModuleList()
352
- block_in = channels * in_ch_mult[i_level]
353
- block_out = channels * channels_mult[i_level]
354
- for _ in range(self.num_res_blocks):
355
- block.append(
356
- CausalResnetBlockFactorized3d(
357
- in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
358
- )
359
- )
360
- block_in = block_out
361
- if curr_res in attn_resolutions:
362
- attn.append(
363
- nn.Sequential(
364
- CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
365
- )
366
- )
367
- down = nn.Module()
368
- down.block = block
369
- down.attn = attn
370
- if i_level != self.num_resolutions - 1:
371
- spatial_down = i_level < self.num_spatial_downs
372
- temporal_down = i_level < self.num_temporal_downs
373
- down.downsample = CausalHybridDownsample3d(
374
- block_in, spatial_down=spatial_down, temporal_down=temporal_down
375
- )
376
- curr_res = curr_res // 2
377
- self.down.append(down)
378
-
379
- # middle
380
- self.mid = nn.Module()
381
- self.mid.block_1 = CausalResnetBlockFactorized3d(
382
- in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
383
- )
384
- self.mid.attn_1 = nn.Sequential(
385
- CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
386
- )
387
- self.mid.block_2 = CausalResnetBlockFactorized3d(
388
- in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
389
- )
390
-
391
- # end
392
- self.norm_out = CausalNormalize(block_in, num_groups=1)
393
- self.conv_out = nn.Sequential(
394
- CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
395
- CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
396
- )
397
-
398
- def forward(self, x: torch.Tensor) -> torch.Tensor:
399
- x = self.patcher3d(x)
400
-
401
- # downsampling
402
- h = self.conv_in(x)
403
- for i_level in range(self.num_resolutions):
404
- for i_block in range(self.num_res_blocks):
405
- h = self.down[i_level].block[i_block](h)
406
- if len(self.down[i_level].attn) > 0:
407
- h = self.down[i_level].attn[i_block](h)
408
- if i_level != self.num_resolutions - 1:
409
- h = self.down[i_level].downsample(h)
410
-
411
- # middle
412
- h = self.mid.block_1(h)
413
- h = self.mid.attn_1(h)
414
- h = self.mid.block_2(h)
415
-
416
- # end
417
- h = self.norm_out(h)
418
- h = nonlinearity(h)
419
- h = self.conv_out(h)
420
- return h
421
-
422
-
423
- class DecoderFactorized(nn.Module):
424
- def __init__(
425
- self,
426
- out_channels: int,
427
- channels: int,
428
- channels_mult: list[int],
429
- num_res_blocks: int,
430
- attn_resolutions: list[int],
431
- dropout: float,
432
- resolution: int,
433
- z_channels: int,
434
- spatial_compression: int,
435
- temporal_compression: int,
436
- **ignore_kwargs,
437
- ):
438
- super().__init__()
439
- self.num_resolutions = len(channels_mult)
440
- self.num_res_blocks = num_res_blocks
441
-
442
- # UnPatcher.
443
- patch_size = ignore_kwargs.get("patch_size", 1)
444
- self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
445
- out_ch = out_channels * patch_size * patch_size * patch_size
446
-
447
- # calculate the number of upsample operations
448
- self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
449
- assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most"
450
- self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
451
- assert (
452
- self.num_temporal_ups <= self.num_resolutions
453
- ), f"Temporally upsample {self.num_resolutions} times at most"
454
-
455
- block_in = channels * channels_mult[self.num_resolutions - 1]
456
- curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
457
- self.z_shape = (1, z_channels, curr_res, curr_res)
458
- log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
459
-
460
- # z to block_in
461
- self.conv_in = nn.Sequential(
462
- CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1),
463
- CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0),
464
- )
465
-
466
- # middle
467
- self.mid = nn.Module()
468
- self.mid.block_1 = CausalResnetBlockFactorized3d(
469
- in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
470
- )
471
- self.mid.attn_1 = nn.Sequential(
472
- CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
473
- )
474
- self.mid.block_2 = CausalResnetBlockFactorized3d(
475
- in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
476
- )
477
-
478
- legacy_mode = ignore_kwargs.get("legacy_mode", False)
479
- # upsampling
480
- self.up = nn.ModuleList()
481
- for i_level in reversed(range(self.num_resolutions)):
482
- block = nn.ModuleList()
483
- attn = nn.ModuleList()
484
- block_out = channels * channels_mult[i_level]
485
- for _ in range(self.num_res_blocks + 1):
486
- block.append(
487
- CausalResnetBlockFactorized3d(
488
- in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
489
- )
490
- )
491
- block_in = block_out
492
- if curr_res in attn_resolutions:
493
- attn.append(
494
- nn.Sequential(
495
- CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
496
- )
497
- )
498
- up = nn.Module()
499
- up.block = block
500
- up.attn = attn
501
- if i_level != 0:
502
- # The layer index for temporal/spatial downsampling performed in the encoder should correspond
503
- # to the layer index, in reverse order, where upsampling is performed in the decoder.
504
- # If you have a pre-trained model, you can simply fine-tune it.
505
- # For example:
506
- # Input tensor = (1, 3, 17, 32, 32)
507
- # Patch size = 4 for 3D wavelet transform
508
- # Compression rate = (8x16x16)
509
- #
510
- # We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored.
511
- # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
512
- # DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)`
513
- #
514
- # if legacy_mode is True, the temporal upsampling is not perfectly mirrored.
515
- # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
516
- # DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)`
517
- #
518
- # Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored.
519
- # Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling.
520
- i_level_reverse = self.num_resolutions - i_level - 1
521
- if legacy_mode:
522
- temporal_up = i_level_reverse < self.num_temporal_ups
523
- else:
524
- temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
525
- spatial_up = temporal_up or (
526
- i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups
527
- )
528
- up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up)
529
- curr_res = curr_res * 2
530
- self.up.insert(0, up) # prepend to get consistent order
531
-
532
- # end
533
- self.norm_out = CausalNormalize(block_in, num_groups=1)
534
- self.conv_out = nn.Sequential(
535
- CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
536
- CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
537
- )
538
-
539
- def forward(self, z):
540
- h = self.conv_in(z)
541
-
542
- # middle block.
543
- h = self.mid.block_1(h)
544
- h = self.mid.attn_1(h)
545
- h = self.mid.block_2(h)
546
-
547
- # decoder blocks.
548
- for i_level in reversed(range(self.num_resolutions)):
549
- for i_block in range(self.num_res_blocks + 1):
550
- h = self.up[i_level].block[i_block](h)
551
- if len(self.up[i_level].attn) > 0:
552
- h = self.up[i_level].attn[i_block](h)
553
- if i_level != 0:
554
- h = self.up[i_level].upsample(h)
555
-
556
- h = self.norm_out(h)
557
- h = nonlinearity(h)
558
- h = self.conv_out(h)
559
- h = self.unpatcher3d(h)
560
- return h
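The block above upsamples time causally: frames are duplicated with repeat_interleave and the extra leading copy is trimmed so the first (conditioning) frame is never repeated, taking T frames to 2T - 1 and leaving a single frame untouched. Below is a minimal sketch of just that nearest-neighbour step, without the learned convolution residuals; the function name and shapes are illustrative.

```python
import torch

def causal_temporal_upsample(x: torch.Tensor) -> torch.Tensor:
    """Nearest-neighbour temporal upsample as in CausalHybridUpsample3d.forward."""
    # x: (B, C, T, H, W); a single-frame input (T == 1) passes through unchanged.
    time_factor = 1 + int(x.shape[2] > 1)
    x = x.repeat_interleave(time_factor, dim=2)   # duplicate every frame along time
    return x[..., time_factor - 1 :, :, :]        # drop the duplicated leading frame

x = torch.randn(1, 4, 3, 8, 8)                    # three latent frames
print(causal_temporal_upsample(x).shape)          # torch.Size([1, 4, 5, 8, 8])
```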
 
 
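CausalTemporalAttnBlock above attends over time independently at every spatial location and masks out future frames with a lower-triangular mask. A self-contained sketch of that masked attention follows, with the learned 1x1x1 q/k/v projections and the normalization omitted; names are illustrative.

```python
import torch
import torch.nn.functional as F

def causal_temporal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (B*H*W, T, C), the layout produced by space2batch in the block above.
    scores = torch.bmm(q, k.transpose(1, 2)) * q.shape[-1] ** -0.5  # (BHW, T, T)
    mask = torch.tril(torch.ones_like(scores))                      # causal: no peeking at future frames
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.bmm(F.softmax(scores, dim=-1), v)                  # (BHW, T, C)

q = k = v = torch.randn(2, 4, 8)
print(causal_temporal_attention(q, k, v).shape)  # torch.Size([2, 4, 8])
```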
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/148897d5cae9165673cb74e336548c71adb261b1 DELETED
@@ -1,78 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- from dataclasses import dataclass
18
-
19
- import imageio
20
- import numpy as np
21
-
22
- from .log import log
23
-
24
-
25
- @dataclass
26
- class VideoData:
27
- frames: np.ndarray # Shape: [B, H, W, C]
28
- fps: int
29
- duration: int # in seconds
30
-
31
-
32
- def get_video_filepaths(input_dir: str) -> list[str]:
33
- """Get a list of filepaths for all videos in the input directory."""
34
- paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True)
35
- paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True)
36
- paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True)
37
- paths = sorted(paths)
38
- log.debug(f"Found {len(paths)} videos")
39
- return paths
40
-
41
-
42
- def read_video(filepath: str) -> VideoData:
43
- """Read a video file and extract its frames and metadata."""
44
- try:
45
- reader = imageio.get_reader(filepath, "ffmpeg")
46
- except Exception as e:
47
- raise ValueError(f"Failed to read video file: {filepath}") from e
48
-
49
- # Extract metadata from the video file
50
- try:
51
- metadata = reader.get_meta_data()
52
- fps = metadata.get("fps")
53
- duration = metadata.get("duration")
54
- except Exception as e:
55
- reader.close()
56
- raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e
57
-
58
- # Extract frames from the video file
59
- try:
60
- frames = np.array([frame for frame in reader])
61
- except Exception as e:
62
- raise ValueError(f"Failed to extract frames from video file: {filepath}") from e
63
- finally:
64
- reader.close()
65
-
66
- return VideoData(frames=frames, fps=fps, duration=duration)
67
-
68
-
69
- def save_video(filepath: str, frames: np.ndarray, fps: int) -> None:
70
- """Save a video file from a sequence of frames."""
71
- writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1)
72
- try:
73
- for frame in frames:
74
- writer.append_data(frame)
75
- except Exception as e:
76
- raise ValueError(f"Failed to save video file to {filepath}") from e
77
- finally:
78
- writer.close()
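A possible round trip through the helpers above; the module name video_lib and the directory layout are assumptions for illustration only.

```python
# Hypothetical usage; adjust the import path to wherever these helpers live.
from video_lib import get_video_filepaths, read_video, save_video

for path in get_video_filepaths("input_videos"):
    video = read_video(path)                      # VideoData(frames, fps, duration)
    clip = video.frames[: int(video.fps) * 2]     # keep roughly the first two seconds
    save_video(path.replace(".mp4", "_clip.mp4"), clip, fps=int(video.fps))
```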
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/184d60dec6f9b0326dc0aa1a3d9b89c06fa7566e DELETED
@@ -1,283 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """
17
- A general framework for various sampling algorithms from a diffusion model.
18
- Implementation based on
19
- * Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157
20
- * also includes other implementations: DDIM, DEIS, DPM-Solver, and the EDM sampler.
21
- Most sampling algorithms (Runge-Kutta, multi-step, etc.) can be implemented in this framework by \
22
- adding a new step function in get_runge_kutta_fn or get_multi_step_fn.
23
- """
24
-
25
- import math
26
- from typing import Any, Callable, List, Literal, Optional, Tuple, Union
27
-
28
- import attrs
29
- import torch
30
-
31
- from .df_df_functional_multi_step import get_multi_step_fn, is_multi_step_fn_supported
32
- from .df_df_functional_runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
33
- from .config import make_freezable
34
-
35
- COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"]
36
-
37
-
38
- @make_freezable
39
- @attrs.define(slots=False)
40
- class SolverConfig:
41
- is_multi: bool = False
42
- rk: str = "2mid"
43
- multistep: str = "2ab"
44
- # following parameters control stochasticity, see EDM paper
45
- # By default, we use deterministic sampling with no stochasticity
46
- s_churn: float = 0.0
47
- s_t_max: float = float("inf")
48
- s_t_min: float = 0.05
49
- s_noise: float = 1.0
50
-
51
-
52
- @make_freezable
53
- @attrs.define(slots=False)
54
- class SolverTimestampConfig:
55
- nfe: int = 50
56
- t_min: float = 0.002
57
- t_max: float = 80.0
58
- order: float = 7.0
59
- is_forward: bool = False # whether to generate forward or backward timestamps
60
-
61
-
62
- @make_freezable
63
- @attrs.define(slots=False)
64
- class SamplerConfig:
65
- solver: SolverConfig = attrs.field(factory=SolverConfig)
66
- timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig)
67
- sample_clean: bool = True # whether to run one last step to generate a clean image
68
-
69
-
70
- def get_rev_ts(
71
- t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False
72
- ) -> torch.Tensor:
73
- """
74
- Generate a sequence of reverse time steps.
75
-
76
- Args:
77
- t_min (float): The minimum time value.
78
- t_max (float): The maximum time value.
79
- num_steps (int): The number of time steps to generate.
80
- ts_order (Union[int, float]): The order of the time step progression.
81
- is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False.
82
-
83
- Returns:
84
- torch.Tensor: A tensor containing the generated time steps in reverse or forward order.
85
-
86
- Raises:
87
- ValueError: If `t_min` is not less than `t_max`.
88
- TypeError: If `ts_order` is not an integer or float.
89
- """
90
- if t_min >= t_max:
91
- raise ValueError("t_min must be less than t_max")
92
-
93
- if not isinstance(ts_order, (int, float)):
94
- raise TypeError("ts_order must be an integer or float")
95
-
96
- step_indices = torch.arange(num_steps + 1, dtype=torch.float64)
97
- time_steps = (
98
- t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order))
99
- ) ** ts_order
100
-
101
- if is_forward:
102
- return time_steps.flip(dims=(0,))
103
-
104
- return time_steps
105
-
106
-
107
- class Sampler(torch.nn.Module):
108
- def __init__(self, cfg: Optional[SamplerConfig] = None):
109
- super().__init__()
110
- if cfg is None:
111
- cfg = SamplerConfig()
112
- self.cfg = cfg
113
-
114
- @torch.no_grad()
115
- def forward(
116
- self,
117
- x0_fn: Callable,
118
- x_sigma_max: torch.Tensor,
119
- num_steps: int = 35,
120
- sigma_min: float = 0.002,
121
- sigma_max: float = 80,
122
- rho: float = 7,
123
- S_churn: float = 0,
124
- S_min: float = 0,
125
- S_max: float = float("inf"),
126
- S_noise: float = 1,
127
- solver_option: str = "2ab",
128
- ) -> torch.Tensor:
129
- in_dtype = x_sigma_max.dtype
130
-
131
- def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor:
132
- return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64)
133
-
134
- is_multistep = is_multi_step_fn_supported(solver_option)
135
- is_rk = is_runge_kutta_fn_supported(solver_option)
136
- assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}"
137
-
138
- solver_cfg = SolverConfig(
139
- s_churn=S_churn,
140
- s_t_max=S_max,
141
- s_t_min=S_min,
142
- s_noise=S_noise,
143
- is_multi=is_multistep,
144
- rk=solver_option,
145
- multistep=solver_option,
146
- )
147
- timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho)
148
- sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True)
149
-
150
- return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype)
151
-
152
- @torch.no_grad()
153
- def _forward_impl(
154
- self,
155
- denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
156
- noisy_input_B_StateShape: torch.Tensor,
157
- sampler_cfg: Optional[SamplerConfig] = None,
158
- callback_fns: Optional[List[Callable]] = None,
159
- ) -> torch.Tensor:
160
- """
161
- Internal implementation of the forward pass.
162
-
163
- Args:
164
- denoiser_fn: Function to denoise the input.
165
- noisy_input_B_StateShape: Input tensor with noise.
166
- sampler_cfg: Configuration for the sampler.
167
- callback_fns: List of callback functions to be called during sampling.
168
-
169
- Returns:
170
- torch.Tensor: Denoised output tensor.
171
- """
172
- sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg
173
- solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0])
174
- num_timestamps = sampler_cfg.timestamps.nfe // solver_order
175
-
176
- sigmas_L = get_rev_ts(
177
- sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order
178
- ).to(noisy_input_B_StateShape.device)
179
-
180
- denoised_output = differential_equation_solver(
181
- denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns
182
- )(noisy_input_B_StateShape)
183
-
184
- if sampler_cfg.sample_clean:
185
- # Override denoised_output with fully denoised version
186
- ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype)
187
- denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones)
188
-
189
- return denoised_output
190
-
191
-
192
- def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any:
193
- """
194
- Implements a for loop with a function.
195
-
196
- Args:
197
- lower: Lower bound of the loop (inclusive).
198
- upper: Upper bound of the loop (exclusive).
199
- body_fun: Function to be applied in each iteration.
200
- init_val: Initial value for the loop.
201
-
202
- Returns:
203
- The final result after all iterations.
204
- """
205
- val = init_val
206
- for i in range(lower, upper):
207
- val = body_fun(i, val)
208
- return val
209
-
210
-
211
- def differential_equation_solver(
212
- x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
213
- sigmas_L: torch.Tensor,
214
- solver_cfg: SolverConfig,
215
- callback_fns: Optional[List[Callable]] = None,
216
- ) -> Callable[[torch.Tensor], torch.Tensor]:
217
- """
218
- Creates a differential equation solver function.
219
-
220
- Args:
221
- x0_fn: Function to compute x0 prediction.
222
- sigmas_L: Tensor of sigma values with shape [L,].
223
- solver_cfg: Configuration for the solver.
224
- callback_fns: Optional list of callback functions.
225
-
226
- Returns:
227
- A function that solves the differential equation.
228
- """
229
- num_step = len(sigmas_L) - 1
230
-
231
- if solver_cfg.is_multi:
232
- update_step_fn = get_multi_step_fn(solver_cfg.multistep)
233
- else:
234
- update_step_fn = get_runge_kutta_fn(solver_cfg.rk)
235
-
236
- eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1)
237
-
238
- def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
239
- """
240
- Samples from the differential equation.
241
-
242
- Args:
243
- input_xT_B_StateShape: Input tensor with shape [B, StateShape].
244
-
245
- Returns:
246
- Output tensor with shape [B, StateShape].
247
- """
248
- ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64)
249
-
250
- def step_fn(
251
- i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
252
- ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
253
- input_x_B_StateShape, x0_preds = state
254
- sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
255
-
256
- # algorithm 2: line 4-6
257
- if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max:
258
- hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0
259
- input_x_B_StateShape = input_x_B_StateShape + (
260
- hat_sigma_cur_0**2 - sigma_cur_0**2
261
- ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape)
262
- sigma_cur_0 = hat_sigma_cur_0
263
-
264
- if solver_cfg.is_multi:
265
- x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
266
- output_x_B_StateShape, x0_preds = update_step_fn(
267
- input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
268
- )
269
- else:
270
- output_x_B_StateShape, x0_preds = update_step_fn(
271
- input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn
272
- )
273
-
274
- if callback_fns:
275
- for callback_fn in callback_fns:
276
- callback_fn(**locals())
277
-
278
- return output_x_B_StateShape, x0_preds
279
-
280
- x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None])
281
- return x_at_eps
282
-
283
- return sample_fn
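get_rev_ts above implements the EDM rho-schedule sigma_i = (t_max^(1/rho) + (i/N) * (t_min^(1/rho) - t_max^(1/rho)))^rho for i = 0..N, so noise levels sweep from t_max down to t_min with step density controlled by the order. A standalone sketch with the same t_min/t_max/order defaults as SolverTimestampConfig and a small step count for readability; the function name is illustrative.

```python
import torch

def edm_timesteps(t_min: float = 0.002, t_max: float = 80.0, num_steps: int = 4, rho: float = 7.0) -> torch.Tensor:
    # Same schedule as get_rev_ts: coarse steps at high noise, fine steps near t_min.
    i = torch.arange(num_steps + 1, dtype=torch.float64)
    return (t_max ** (1 / rho) + i / num_steps * (t_min ** (1 / rho) - t_max ** (1 / rho))) ** rho

print(edm_timesteps())  # roughly [80.0, 17.5, 2.5, 0.17, 0.002]
```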
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/1e300540d3a022a74d708a0df0f04204a895b189 DELETED
@@ -1,903 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import os
18
- from typing import List, Optional, Tuple
19
-
20
- from .misc import misc
21
- import numpy as np
22
- import torch
23
- from einops import rearrange
24
-
25
- from .ar_config_base_model_config import create_video2world_model_config
26
- from .ar_config_base_tokenizer import TokenizerConfig
27
- from .ar_config_inference_inference_config import (
28
- DataShapeConfig,
29
- DiffusionDecoderSamplingConfig,
30
- InferenceConfig,
31
- SamplingConfig,
32
- )
33
- from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
34
- from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
35
- from .ar_model import AutoRegressiveModel
36
- from .ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving
37
- from .base_world_generation_pipeline import BaseWorldGenerationPipeline
38
- from .df_inference_inference_utils import (
39
- load_model_by_config,
40
- load_network_model,
41
- load_tokenizer_model,
42
- )
43
- from .log import log
44
-
45
-
46
- def detect_model_size_from_ckpt_path(ckpt_path: str) -> str:
47
- """Detect model size from checkpoint path.
48
-
49
- Args:
50
- ckpt_path: Path to model checkpoint file
51
-
52
- Returns:
53
- str: Model size ('4b', '5b', '12b', or '13b')
54
-
55
- Examples:
56
- >>> detect_model_size_from_ckpt_path("model_4B.pt")
57
- '4b'
58
- """
59
- model_size = "4b"
60
- if "4B" in ckpt_path:
61
- model_size = "4b"
62
- elif "5B" in ckpt_path:
63
- model_size = "5b"
64
- elif "12B" in ckpt_path:
65
- model_size = "12b"
66
- elif "13B" in ckpt_path:
67
- model_size = "13b"
68
- else:
69
- log.warning(f"Could not detect model size from checkpoint path: {ckpt_path}")
70
- return model_size
71
-
72
-
73
- def create_inference_config(
74
- model_ckpt_path: str,
75
- tokenizer_ckpt_path: str,
76
- model_size: str = "4b",
77
- batch_size: int = 1,
78
- inference_type: str = "base",
79
- ) -> InferenceConfig:
80
- """Create inference configuration for model.
81
-
82
- Args:
83
- model_ckpt_path: Path to model checkpoint
84
- tokenizer_ckpt_path: Path to tokenizer checkpoint
85
- model_size: Size of model ('4b', '5b', '12b', '13b')
86
- batch_size: Batch size for inference
87
- inference_type: Type of inference ('base' or 'video2world')
88
-
89
- Returns:
90
- InferenceConfig: Configuration object for inference
91
- """
92
- model_size = model_size.lower()
93
- # For inference config
94
- kwargs = {}
95
- if inference_type == "video2world":
96
- kwargs.update(
97
- dict(
98
- insert_cross_attn=True,
99
- insert_cross_attn_every_k_layers=1,
100
- context_dim=1024,
101
- training_type="text_to_video",
102
- apply_abs_pos_emb=True,
103
- )
104
- )
105
- if model_size == "5b":
106
- model_size = "4b" # The base model (excluding the cross attention layers) is the 4B model
107
- elif model_size == "13b":
108
- model_size = "12b" # The base model (excluding the cross attention layers) is the 12B model
109
- else:
110
- raise ValueError(f"Unsupported model size for video2world inference_type: {model_size}")
111
- else:
112
- assert inference_type == "base", f"Unsupported inference_type: {inference_type}"
113
-
114
- model_config, tokenizer_config = create_video2world_model_config(
115
- model_ckpt_path=model_ckpt_path,
116
- tokenizer_ckpt_path=tokenizer_ckpt_path,
117
- model_size=model_size,
118
- rope_dim="3D",
119
- add_special_tokens=False,
120
- pixel_chunk_duration=33,
121
- num_video_frames=33,
122
- num_condition_latents_t=1,
123
- batch_size=batch_size,
124
- video_height=640,
125
- video_width=1024,
126
- **kwargs,
127
- )
128
-
129
- inference_config = InferenceConfig()
130
-
131
- inference_config.model_config = model_config
132
- inference_config.tokenizer_config = tokenizer_config
133
-
134
- inference_config.data_shape_config = DataShapeConfig(
135
- num_video_frames=model_config.num_video_frames,
136
- height=model_config.video_height,
137
- width=model_config.video_width,
138
- latent_shape=model_config.video_latent_shape,
139
- )
140
- inference_config.model_config.fuse_qkv = False
141
- return inference_config
142
-
143
-
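How the two helpers above compose in practice; the import path and checkpoint locations below are placeholders rather than part of the original file (the tokenizer filename mirrors the one hard-coded in the pipeline class that follows).

```python
# Illustrative only: module and directory names are assumed.
from pipeline_config import create_inference_config, detect_model_size_from_ckpt_path

ckpt = "checkpoints/Cosmos-1.0-Autoregressive-5B-Video2World/model.pt"
cfg = create_inference_config(
    model_ckpt_path=ckpt,
    tokenizer_ckpt_path="checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
    model_size=detect_model_size_from_ckpt_path(ckpt),  # "5B" in the path -> "5b"
    inference_type="video2world",                        # enables the cross-attention settings
)
```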
144
- class ARBaseGenerationPipeline(BaseWorldGenerationPipeline):
145
- """Base class for autoregressive world generation models.
146
-
147
- Handles the core functionality for generating videos using autoregressive models.
148
- Provides configurable GPU memory management through model offloading and supports
149
- different inference types for video generation.
150
-
151
- Attributes:
152
- inference_config (InferenceConfig): Configuration for model inference
153
- tokenizer_config (TokenizerConfig): Configuration for tokenizer
154
- disable_diffusion_decoder (bool): Whether diffusion decoder is disabled
155
- latent_shape (List[int]): Shape of video latents [T, H, W]
156
- _supported_context_len (int): Supported context window length
157
- latent_chunk_duration (int): Duration of latent chunks
158
- pixel_chunk_duration (int): Duration of pixel chunks
159
- diffusion_decoder_model (Optional[nn.Module]): The diffusion decoder model
160
- """
161
-
162
- def __init__(
163
- self,
164
- inference_type: str,
165
- checkpoint_dir: str,
166
- checkpoint_name: str,
167
- has_text_input: bool = False,
168
- offload_network: bool = False,
169
- offload_tokenizer: bool = False,
170
- disable_diffusion_decoder: bool = False,
171
- offload_guardrail_models: bool = False,
172
- offload_diffusion_decoder: bool = False,
173
- ):
174
- """Initialize the autoregressive world generation pipeline.
175
-
176
- Args:
177
- inference_type: Type of world generation ('base' or 'video2world')
178
- checkpoint_dir: Base directory containing model checkpoints
179
- checkpoint_name: Name of the AR checkpoint to load
180
- has_text_input: Whether the pipeline takes text input for world generation
181
- disable_diffusion_decoder: Whether to disable the diffusion decoder stage
182
- offload_network: Whether to offload AR model from GPU after use
183
- offload_guardrail_models: Whether to offload content filtering models
184
- offload_diffusion_decoder: Whether to offload diffusion decoder
185
-
186
- Raises:
187
- AssertionError: If inference_type is not 'base' or 'video2world'
188
- """
189
- assert inference_type in [
190
- "base",
191
- "video2world",
192
- ], "Invalid inference_type, must be 'base' or 'video2world'"
193
-
194
- # Create inference config
195
- model_size = detect_model_size_from_ckpt_path(checkpoint_name)
196
- model_ckpt_path = os.path.join(checkpoint_dir, checkpoint_name, "model.pt")
197
- tokenizer_ckpt_path = os.path.join(checkpoint_dir, "Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit")
198
-
199
- inference_config: InferenceConfig = create_inference_config(
200
- model_ckpt_path=model_ckpt_path,
201
- tokenizer_ckpt_path=tokenizer_ckpt_path,
202
- model_size=model_size,
203
- inference_type=inference_type,
204
- )
205
-
206
- self.inference_config = inference_config
207
- self.disable_diffusion_decoder = disable_diffusion_decoder
208
-
209
- if not disable_diffusion_decoder:
210
- self.diffusion_decoder_ckpt_path = os.path.join(
211
- checkpoint_dir, "Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8/model.pt"
212
- )
213
- self.diffusion_decoder_config = "DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token"
214
- self.diffusion_decoder_tokenizer_path = os.path.join(checkpoint_dir, "Cosmos-1.0-Tokenizer-CV8x8x8")
215
- self.dd_sampling_config = DiffusionDecoderSamplingConfig()
216
- aux_vars_path = os.path.join(os.path.dirname(self.diffusion_decoder_ckpt_path), "aux_vars.pt")
217
- # We use a generic prompt when no text prompts are available for diffusion decoder.
218
- # Generic prompt used - "high quality, 4k, high definition, smooth video"
219
- aux_vars = torch.load(aux_vars_path, weights_only=True)
220
- self.generic_prompt = dict()
221
- self.generic_prompt["context"] = aux_vars["context"].cuda()
222
- self.generic_prompt["context_mask"] = aux_vars["context_mask"].cuda()
223
-
224
- self.latent_shape = inference_config.data_shape_config.latent_shape # [L, 40, 64]
225
- self._supported_context_len = _SUPPORTED_CONTEXT_LEN
226
- self.tokenizer_config = inference_config.tokenizer_config
227
-
228
- self.offload_diffusion_decoder = offload_diffusion_decoder
229
- self.diffusion_decoder_model = None
230
- if not self.offload_diffusion_decoder and not disable_diffusion_decoder:
231
- self._load_diffusion_decoder()
232
-
233
- super().__init__(
234
- inference_type=inference_type,
235
- checkpoint_dir=checkpoint_dir,
236
- checkpoint_name=checkpoint_name,
237
- has_text_input=has_text_input,
238
- offload_guardrail_models=offload_guardrail_models,
239
- offload_network=offload_network,
240
- offload_tokenizer=offload_tokenizer,
241
- offload_text_encoder_model=True,
242
- )
243
-
244
- def _load_model(self):
245
- """Load and initialize the autoregressive model.
246
-
247
- Creates and configures the autoregressive model with appropriate settings.
248
- """
249
- self.model = AutoRegressiveModel(
250
- config=self.inference_config.model_config,
251
- )
252
-
253
- def _load_network(self):
254
- """Load network weights for the autoregressive model."""
255
- self.model.load_ar_model(tokenizer_config=self.inference_config.tokenizer_config)
256
-
257
- def _load_tokenizer(self):
258
- """Load and initialize the tokenizer model.
259
-
260
- Configures the tokenizer using settings from inference_config and
261
- attaches it to the autoregressive model.
262
- """
263
- self.model.load_tokenizer(tokenizer_config=self.inference_config.tokenizer_config)
264
-
265
- def _load_diffusion_decoder(self):
266
- """Load and initialize the diffusion decoder model."""
267
- self.diffusion_decoder_model = load_model_by_config(
268
- config_job_name=self.diffusion_decoder_config,
269
- config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py",
270
- model_class=LatentDiffusionDecoderModel,
271
- )
272
- load_network_model(self.diffusion_decoder_model, self.diffusion_decoder_ckpt_path)
273
- load_tokenizer_model(self.diffusion_decoder_model, self.diffusion_decoder_tokenizer_path)
274
-
275
- def _offload_diffusion_decoder(self):
276
- """Offload diffusion decoder model from GPU memory."""
277
- if self.diffusion_decoder_model is not None:
278
- del self.diffusion_decoder_model
279
- self.diffusion_decoder_model = None
280
- gc.collect()
281
- torch.cuda.empty_cache()
282
-
283
- def _run_model_with_offload(
284
- self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig
285
- ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
286
- """Run the autoregressive model to generate video tokens.
287
-
288
- Takes input video frames and generates new video tokens using the autoregressive model.
289
- Handles context frame selection and token generation.
290
-
291
- Args:
292
- inp_vid (torch.Tensor): Input video tensor of shape (B x T x 3 x H x W)
293
- num_input_frames (int): Number of context frames to use from the input video
294
- seed (int): Random seed for generation
295
- sampling_config (SamplingConfig): Configuration for sampling parameters
296
-
297
- Returns:
298
- tuple: (
299
- List of generated video tensors,
300
- List of token index tensors
301
-
302
- )
303
- """
304
- # Choosing the context length from list of available contexts
305
- latent_context_t_size = 0
306
- context_used = 0
307
- for _clen in self._supported_context_len:
308
- if num_input_frames >= _clen:
309
- context_used = _clen
310
- latent_context_t_size += 1
311
- log.info(f"Using input size of {context_used} frames")
312
-
313
- data_batch = {"video": inp_vid}
314
- data_batch = misc.to(data_batch, "cuda")
315
-
316
- T, H, W = self.latent_shape
317
- num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
318
-
319
- out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch(
320
- data_batch=data_batch,
321
- num_tokens_to_generate=num_gen_tokens,
322
- sampling_config=sampling_config,
323
- tokenizer_config=self.tokenizer_config,
324
- latent_shape=self.latent_shape,
325
- task_condition="video",
326
- num_chunks_to_generate=1,
327
- seed=seed,
328
- )
329
- if self.offload_network:
330
- self._offload_network()
331
- if self.offload_tokenizer:
332
- self._offload_tokenizer()
333
- return out_videos_cur_batch, indices_tensor_cur_batch
334
-
335
- def _run_diffusion_decoder(
336
- self,
337
- out_videos_cur_batch: List[torch.Tensor],
338
- indices_tensor_cur_batch: List[torch.Tensor],
339
- t5_emb_batch: List[torch.Tensor],
340
- ) -> List[torch.Tensor]:
341
- """Process generated tokens through the diffusion decoder.
342
-
343
- Enhances video quality through diffusion-based decoding.
344
-
345
- Args:
346
- out_videos_cur_batch: List of generated video tensors
347
- indices_tensor_cur_batch: List of token indices tensors
348
- t5_emb_batch: List of text embeddings for conditioning
349
-
350
- Returns:
351
- list: Enhanced video tensors after diffusion processing
352
- """
353
- out_videos_cur_batch_dd = diffusion_decoder_process_tokens(
354
- model=self.diffusion_decoder_model,
355
- indices_tensor=indices_tensor_cur_batch,
356
- dd_sampling_config=self.dd_sampling_config,
357
- original_video_example=out_videos_cur_batch[0],
358
- t5_emb_batch=t5_emb_batch,
359
- )
360
- return out_videos_cur_batch_dd
361
-
362
- def _run_diffusion_decoder_with_offload(
363
- self,
364
- out_videos_cur_batch: List[torch.Tensor],
365
- indices_tensor_cur_batch: List[torch.Tensor],
366
- t5_emb_batch: List[torch.Tensor],
367
- ) -> List[torch.Tensor]:
368
- """Run diffusion decoder with memory management.
369
-
370
- Loads decoder if needed, processes videos, and offloads decoder afterward
371
- if configured in offload_diffusion_decoder.
372
-
373
- Args:
374
- out_videos_cur_batch: List of generated video tensors
375
- indices_tensor_cur_batch: List of token indices tensors
376
- t5_emb_batch: List of text embeddings for conditioning
377
-
378
- Returns:
379
- list: Enhanced video tensors after diffusion processing
380
- """
381
- if self.offload_diffusion_decoder:
382
- self._load_diffusion_decoder()
383
- out_videos_cur_batch = self._run_diffusion_decoder(out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch)
384
- if self.offload_diffusion_decoder:
385
- self._offload_diffusion_decoder()
386
- return out_videos_cur_batch
387
-
388
- def generate(
389
- self,
390
- inp_vid: torch.Tensor,
391
- sampling_config: SamplingConfig,
392
- num_input_frames: int = 9,
393
- seed: int = 0,
394
- ) -> np.ndarray | None:
395
- """Generate a video continuation from input frames.
396
-
397
- Pipeline steps:
398
- 1. Generates video tokens using autoregressive model
399
- 2. Optionally enhances quality via diffusion decoder
400
- 3. Applies safety checks if enabled
401
-
402
- Args:
403
- inp_vid: Input video tensor of shape (batch_size, time, channels=3, height, width)
404
- sampling_config: Parameters controlling the generation process
405
- num_input_frames: Number of input frames to use as context (default: 9)
406
- seed: Random seed for reproducibility (default: 0)
407
-
408
- Returns:
409
- np.ndarray | None: Generated video as numpy array (time, height, width, channels)
410
- if generation successful, None if safety checks fail
411
- """
412
- log.info("Run generation")
413
- out_videos_cur_batch, indices_tensor_cur_batch = self._run_model_with_offload(
414
- inp_vid, num_input_frames, seed, sampling_config
415
- )
416
- log.info("Finish AR model generation")
417
-
418
- if not self.disable_diffusion_decoder:
419
- log.info("Run diffusion decoder on generated tokens")
420
- out_videos_cur_batch = self._run_diffusion_decoder_with_offload(
421
- out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch=[self.generic_prompt["context"]]
422
- )
423
- log.info("Finish diffusion decoder on generated tokens")
424
- out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch)
425
- output_video = out_videos_cur_batch[0]
426
-
427
- log.info("Run guardrail on generated video")
428
- output_video = self._run_guardrail_on_video_with_offload(output_video)
429
- if output_video is None:
430
- log.critical("Generated video is not safe")
431
- return None
432
- log.info("Finish guardrail on generated video")
433
-
434
- return output_video
435
-
436
- @torch.inference_mode()
437
- def generate_partial_tokens_from_data_batch(
438
- self,
439
- data_batch: dict,
440
- num_tokens_to_generate: int,
441
- sampling_config: SamplingConfig,
442
- tokenizer_config: TokenizerConfig,
443
- latent_shape: list[int],
444
- task_condition: str,
445
- num_chunks_to_generate: int = 1,
446
- seed: int = 0,
447
- ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
448
- """Generate video tokens from partial input tokens with conditioning.
449
-
450
- Handles token generation and decoding process:
451
- 1. Processes input batch and applies conditioning
452
- 2. Generates specified number of new tokens
453
- 3. Decodes tokens to video frames
454
-
455
- Args:
456
- data_batch: Dictionary containing input data including video and optional context
457
- num_tokens_to_generate: Number of tokens to generate
458
- sampling_config: Configuration for sampling parameters
459
- tokenizer_config: Configuration for tokenizer, including video tokenizer settings
460
- latent_shape: Shape of video latents [T, H, W]
461
- task_condition: Type of generation task ('video' or 'text_and_video')
462
- num_chunks_to_generate: Number of chunks to generate (default: 1)
463
- seed: Random seed for generation (default: 0)
464
-
465
- Returns:
466
- tuple containing:
467
- - List[torch.Tensor]: Generated videos
468
- - List[torch.Tensor]: Token index tensors
469
-
470
-
471
- """
472
- log.debug(f"Starting generate_partial_tokens_from_data_batch with seed {seed}")
473
- log.debug(f"Number of tokens to generate: {num_tokens_to_generate}")
474
- log.debug(f"Latent shape: {latent_shape}")
475
-
476
- video_token_start = tokenizer_config.video_tokenizer.tokenizer_offset
477
- video_vocab_size = tokenizer_config.video_tokenizer.vocab_size
478
- video_token_end = video_token_start + video_vocab_size
479
-
480
- logit_clipping_range = [video_token_start, video_token_end]
481
-
482
- if self.offload_network:
483
- self._offload_network()
484
- if self.offload_tokenizer:
485
- self._load_tokenizer()
486
-
487
- assert logit_clipping_range == [
488
- 0,
489
- self.model.tokenizer.video_vocab_size,
490
- ], f"logit_clipping_range {logit_clipping_range} is not supported for fast generate. Expected [0, {self.model.tokenizer.video_vocab_size}]"
491
-
492
- out_videos = {}
493
- out_indices_tensors = {}
494
-
495
- # For text2world, we only add a <bov> token at the beginning of the video tokens; this applies to the 5B and 13B models
496
- if self.model.tokenizer.tokenizer_config.training_type == "text_to_video":
497
- num_bov_tokens = 1
498
- num_eov_tokens = 0
499
- else:
500
- num_eov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0
501
- num_bov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0
502
-
503
- chunk_idx = 0
504
- out_videos[chunk_idx] = []
505
- out_indices_tensors[chunk_idx] = []
506
-
507
- # get the context embedding and mask
508
- context = data_batch.get("context", None) if task_condition != "video" else None
509
- context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None
510
- if context is not None:
511
- context = misc.to(context, "cuda").detach().clone()
512
- if context_mask is not None:
513
- context_mask = misc.to(context_mask, "cuda").detach().clone()
514
-
515
- # get the video tokens
516
- data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch)
517
- data_tokens = misc.to(data_tokens, "cuda").detach().clone()
518
- batch_size = data_tokens.shape[0]
519
-
520
- for sample_num in range(batch_size):
521
- input_tokens = data_tokens[sample_num][0 : token_boundaries["video"][sample_num][1]] # [B, L]
522
- input_tokens = [
523
- input_tokens[0 : -num_tokens_to_generate - num_eov_tokens].tolist()
524
- ] # exclude the tokens to be generated and the <eov> token
525
- log.debug(
526
- f"Run sampling. # input condition tokens: {len(input_tokens[0])}; # generate tokens: {num_tokens_to_generate + num_eov_tokens}; "
527
- f"full length of the data tokens: {len(data_tokens[sample_num])}: {data_tokens[sample_num]}"
528
- )
529
- video_start_boundary = token_boundaries["video"][sample_num][0] + num_bov_tokens
530
-
531
- video_decoded, indices_tensor = self.generate_video_from_tokens(
532
- prompt_tokens=input_tokens,
533
- latent_shape=latent_shape,
534
- video_start_boundary=video_start_boundary,
535
- max_gen_len=num_tokens_to_generate,
536
- sampling_config=sampling_config,
537
- logit_clipping_range=logit_clipping_range,
538
- seed=seed,
539
- context=context,
540
- context_mask=context_mask,
541
- ) # BCLHW, range [0, 1]
542
-
543
- # For the first chunk, we store the entire generated video
544
- out_videos[chunk_idx].append(video_decoded[sample_num].detach().clone())
545
- out_indices_tensors[chunk_idx].append(indices_tensor[sample_num].detach().clone())
546
-
547
- output_videos = []
548
- output_indice_tensors = []
549
- for sample_num in range(len(out_videos[0])):
550
- tensors_to_concat = [out_videos[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate)]
551
- concatenated = torch.cat(tensors_to_concat, dim=1)
552
- output_videos.append(concatenated)
553
-
554
- indices_tensor_to_concat = [
555
- out_indices_tensors[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate)
556
- ]
557
- concatenated_indices_tensor = torch.cat(indices_tensor_to_concat, dim=1) # BLHW
558
- output_indice_tensors.append(concatenated_indices_tensor)
559
-
560
- return output_videos, output_indice_tensors
561
-
562
- def generate_video_from_tokens(
563
- self,
564
- prompt_tokens: list[torch.Tensor],
565
- latent_shape: list[int],
566
- video_start_boundary: int,
567
- max_gen_len: int,
568
- sampling_config: SamplingConfig,
569
- logit_clipping_range: list[int],
570
- seed: int = 0,
571
- context: Optional[torch.Tensor] = None,
572
- context_mask: Optional[torch.Tensor] = None,
573
- ) -> Tuple[torch.Tensor, torch.Tensor]:
574
- r"""
575
- Function to generate video from input tokens. These input tokens can be initial text tokens (in case of text to video),
576
- or partial ground truth tokens.
577
-
578
- Handles the core token-to-video generation process:
579
- 1. Generates new tokens using the autoregressive model
580
- 2. Handles padding and token sequence completion
581
- 3. Reshapes and processes generated tokens
582
- 4. Decodes final tokens into video frames
583
-
584
- Args:
585
- model (AutoRegressiveModel): LLama model instance
586
- prompt_tokens (list): Prompt tokens used by the model
587
- latent_shape (list): Shape of the video latents
588
- video_start_boundary (int): Index where the video tokens start
589
- max_gen_len (int): Maximum length of the tokens that needs to be generated
590
- sampling_config (SamplingConfig): Config used by sampler during inference
591
- logit_clipping_range (list): Range of indices in the logits to be clipped, e.g. [video_token_start, video_token_end]
592
- context (Optional[torch.Tensor]): The context tensor added via cross-attn.
593
- context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
594
- Returns:
595
- tuple containing:
596
- - torch.Tensor: Decoded video tensor in range [0, 1]
597
- - torch.Tensor: Token index tensor reshaped to (B, T, H, W)
598
-
599
- """
600
- # Total sequence length expected for the full latent video
601
- total_seq_len = np.prod(latent_shape)
602
-
603
- assert not sampling_config.logprobs
604
-
605
- stop_tokens = self.model.tokenizer.stop_tokens
606
- if self.offload_tokenizer:
607
- self._offload_tokenizer()
608
- if self.offload_network:
609
- self._load_network()
610
-
611
- generation_tokens, _ = self.model.generate(
612
- prompt_tokens=prompt_tokens,
613
- temperature=sampling_config.temperature,
614
- top_p=sampling_config.top_p,
615
- echo=sampling_config.echo,
616
- seed=seed,
617
- context=context,
618
- context_mask=context_mask,
619
- max_gen_len=max_gen_len,
620
- compile_sampling=sampling_config.compile_sampling,
621
- compile_prefill=sampling_config.compile_prefill,
622
- stop_tokens=stop_tokens,
623
- verbose=True,
624
- )
625
- generation_tokens = generation_tokens[:, video_start_boundary:]
626
- # Pad if needed; sometimes generation ends before max_gen_len is reached
627
- if generation_tokens.shape[1] < total_seq_len:
628
- log.warning(
629
- f"Generated video tokens (shape:{generation_tokens.shape}) shorted than expected {total_seq_len}. Could be the model produce end token early. Repeat the last token to fill the sequence in order for decoding."
630
- )
631
- padding_len = total_seq_len - generation_tokens.shape[1]
632
- padding_tokens = generation_tokens[:, [-1]].repeat(1, padding_len)
633
- generation_tokens = torch.cat([generation_tokens, padding_tokens], dim=1)
634
- # Cast to LongTensor
635
- indices_tensor = generation_tokens.long()
636
- # First, we reshape the generated tokens into batch x time x height x width
637
- indices_tensor = rearrange(
638
- indices_tensor,
639
- "B (T H W) -> B T H W",
640
- T=latent_shape[0],
641
- H=latent_shape[1],
642
- W=latent_shape[2],
643
- )
644
- log.debug(f"generated video tokens {len(generation_tokens[0])} -> reshape: {indices_tensor.shape}")
645
- # If logit clipping range is specified, offset the generated indices by the logit_clipping_range[0]
646
- # Video decoder always takes tokens in the range (0, N-1). So, this offset is needed.
647
- if len(logit_clipping_range) > 0:
648
- indices_tensor = indices_tensor - logit_clipping_range[0]
649
-
650
- if self.offload_network:
651
- self._offload_network()
652
- if self.offload_tokenizer:
653
- self._load_tokenizer()
654
-
655
- # Now decode the video using tokenizer.
656
- video_decoded = self.model.tokenizer.video_tokenizer.decode(indices_tensor.cuda())
657
- # Normalize decoded video from [-1, 1] to [0, 1], and clip value
658
- video_decoded = (video_decoded * 0.5 + 0.5).clamp_(0, 1)
659
- return video_decoded, indices_tensor
660
-
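The tail of generate_video_from_tokens pads a short generation by repeating its last token and then reshapes the flat token sequence back into a latent video grid before decoding. A stripped-down sketch of just that step, with made-up shapes and token values:

```python
import torch
from einops import rearrange

latent_shape = (2, 3, 4)                       # (T, H, W) of the latent video
total_seq_len = 2 * 3 * 4
tokens = torch.randint(0, 64000, (1, 20))      # generation stopped 4 tokens early
if tokens.shape[1] < total_seq_len:
    pad = tokens[:, [-1]].repeat(1, total_seq_len - tokens.shape[1])
    tokens = torch.cat([tokens, pad], dim=1)   # repeat the last token to fill the grid
indices = rearrange(tokens.long(), "B (T H W) -> B T H W",
                    T=latent_shape[0], H=latent_shape[1], W=latent_shape[2])
print(indices.shape)                           # torch.Size([1, 2, 3, 4])
```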
661
-
662
- class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline):
663
- """Video-to-world generation pipeline with text conditioning capabilities.
664
-
665
- Extends the base autoregressive generation pipeline by adding:
666
- - Text prompt processing and embedding
667
- - Text-conditioned video generation
668
- - Additional safety checks for text input
669
- - Memory management for text encoder model
670
-
671
- Enables generating video continuations that are guided by both
672
- input video frames and text descriptions.
673
-
674
- Additional attributes compared to ARBaseGenerationPipeline:
675
- offload_text_encoder_model (bool): Whether to offload text encoder from GPU after use
676
- """
677
-
678
- def __init__(
679
- self,
680
- checkpoint_dir: str,
681
- checkpoint_name: str,
682
- inference_type: str = None,
683
- has_text_input: bool = True,
684
- disable_diffusion_decoder: bool = False,
685
- offload_guardrail_models: bool = False,
686
- offload_diffusion_decoder: bool = False,
687
- offload_network: bool = False,
688
- offload_tokenizer: bool = False,
689
- offload_text_encoder_model: bool = False,
690
- ):
691
- """Initialize text-conditioned video generation pipeline.
692
-
693
- Args:
694
- checkpoint_dir: Base directory containing model checkpoints
695
- checkpoint_name: Name of the checkpoint to load
696
- inference_type: Type of world generation workflow
697
- has_text_input: Whether the pipeline takes text input for world generation
698
- disable_diffusion_decoder: Whether to disable diffusion decoder stage
699
- offload_guardrail_models: Whether to offload content filtering models
700
- offload_diffusion_decoder: Whether to offload diffusion decoder
701
- offload_network: Whether to offload AR model from GPU
702
- offload_tokenizer: Whether to offload tokenizer from GPU
703
- offload_text_encoder_model: Whether to offload text encoder
704
- """
705
- super().__init__(
706
- checkpoint_dir=checkpoint_dir,
707
- checkpoint_name=checkpoint_name,
708
- inference_type=inference_type,
709
- has_text_input=has_text_input,
710
- disable_diffusion_decoder=disable_diffusion_decoder,
711
- offload_guardrail_models=offload_guardrail_models,
712
- offload_diffusion_decoder=offload_diffusion_decoder,
713
- offload_network=offload_network,
714
- offload_tokenizer=offload_tokenizer,
715
- )
716
- self.offload_text_encoder_model = offload_text_encoder_model
717
- if not self.offload_text_encoder_model:
718
- self._load_text_encoder_model()
719
-
720
- def _run_model_with_offload(
721
- self,
722
- prompt_embedding: torch.Tensor,
723
- prompt_mask: torch.Tensor,
724
- inp_vid: torch.Tensor,
725
- num_input_frames: int,
726
- seed: int,
727
- sampling_config: SamplingConfig,
728
- ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
729
- """Run model generation with memory management.
730
-
731
- Executes generation process and handles model offloading to manage GPU memory.
732
-
733
- Args:
734
- prompt_embedding: Text prompt embeddings tensor
735
- prompt_mask: Attention mask for prompt embeddings
736
- inp_vid: Input video tensor
737
- num_input_frames: Number of input frames to use
738
- seed: Random seed for reproducibility
739
- sampling_config: Configuration for sampling parameters
740
-
741
- Returns:
742
- tuple: (
743
- List of generated video tensors
744
- List of token index tensors
745
- List of prompt embedding tensors
746
- )
747
- """
748
- out_videos, indices_tensor, prompt_embedding = self._run_model(
749
- prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config
750
- )
751
- if self.offload_network:
752
- self._offload_network()
753
- if self.offload_tokenizer:
754
- self._offload_tokenizer()
755
- return out_videos, indices_tensor, prompt_embedding
756
-
757
- def _run_model(
758
- self,
759
- prompt_embedding: torch.Tensor,
760
- prompt_mask: torch.Tensor,
761
- inp_vid: torch.Tensor,
762
- num_input_frames: int,
763
- seed: int,
764
- sampling_config: SamplingConfig,
765
- ) -> tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
766
- """Run core model generation process.
767
-
768
- Handles text-conditioned video generation:
769
- 1. Prepares data batch with text embeddings and video
770
- 2. Determines appropriate context length
771
- 3. Generates video tokens with text conditioning
772
- 4. Processes output tensors
773
-
774
- Args:
775
- prompt_embedding: Text prompt embeddings tensor
776
- prompt_mask: Attention mask for prompt embeddings
777
- inp_vid: Input video tensor
778
- num_input_frames: Number of input frames to use
779
- seed: Random seed for reproducibility
780
- sampling_config: Configuration for sampling parameters,
781
- uses default config if None
782
-
783
- Returns:
784
- tuple: (
785
- List of generated video tensors
786
- List of token index tensors
787
- Text context tensor
788
- )
789
- """
790
- data_batch = {}
791
- data_batch["context"], data_batch["context_mask"] = prompt_embedding, prompt_mask
792
- T, H, W = self.latent_shape
793
-
794
- if sampling_config is None:
795
- sampling_config = self.sampling_config
796
- if type(inp_vid) is list:
797
- batch_size = len(inp_vid)
798
- elif type(inp_vid) is torch.Tensor:
799
- batch_size = 1
800
- data_batch["context"] = data_batch["context"].repeat(batch_size, 1, 1)
801
- data_batch["context_mask"] = data_batch["context_mask"].repeat(batch_size, 1)
802
- data_batch["context_mask"] = torch.ones_like(data_batch["context_mask"]).bool()
803
-
804
- latent_context_t_size = 0
805
-
806
- # Choosing the context length from list of available contexts
807
- context_used = 0
808
- for _clen in self._supported_context_len:
809
- if num_input_frames >= _clen:
810
- context_used = _clen
811
- latent_context_t_size += 1
812
- log.info(f"Using context of {context_used} frames")
813
-
814
- num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W]))
815
-
816
- data_batch["video"] = inp_vid
817
- data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1)
818
-
819
- data_batch = misc.to(data_batch, "cuda")
820
-
821
- log.debug(f" num_tokens_to_generate: {num_gen_tokens}")
822
- log.debug(f" sampling_config: {sampling_config}")
823
- log.debug(f" tokenizer_config: {self.tokenizer_config}")
824
- log.debug(f" latent_shape: {self.latent_shape}")
825
- log.debug(f" latent_context_t_size: {latent_context_t_size}")
826
- log.debug(f" seed: {seed}")
827
-
828
- out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch(
829
- data_batch=data_batch,
830
- num_tokens_to_generate=num_gen_tokens,
831
- sampling_config=sampling_config,
832
- tokenizer_config=self.tokenizer_config,
833
- latent_shape=self.latent_shape,
834
- task_condition="text_and_video",
835
- seed=seed,
836
- )
837
- return out_videos_cur_batch, indices_tensor_cur_batch, data_batch["context"]
838
-
839
- def generate(
840
- self,
841
- inp_prompt: str,
842
- inp_vid: torch.Tensor,
843
- num_input_frames: int = 9,
844
- seed: int = 0,
845
- sampling_config: SamplingConfig = None,
846
- ) -> np.ndarray | None:
847
- """Generate a video guided by text prompt and input frames.
848
-
849
- Pipeline steps:
850
- 1. Validates text prompt safety if enabled
851
- 2. Converts text to embeddings
852
- 3. Generates video with text conditioning
853
- 4. Enhances quality via diffusion decoder
854
- 5. Applies video safety checks if enabled
855
-
856
- Args:
857
- inp_prompt: Text prompt to guide the generation
858
- inp_vid: Input video tensor with shape (batch_size, time, channels=3, height, width)
859
- num_input_frames: Number of frames to use as context (default: 9)
860
- seed: Random seed for reproducibility (default: 0)
861
- sampling_config: Configuration for sampling parameters,
862
- uses default config if None
863
-
864
- Returns:
865
- np.ndarray | None: Generated video as numpy array (time, height, width, channels)
866
- if generation successful, None if safety checks fail
867
- """
868
- log.info("Run guardrail on prompt")
869
- is_safe = self._run_guardrail_on_prompt_with_offload(inp_prompt)
870
- if not is_safe:
871
- log.critical("Input text prompt is not safe")
872
- return None
873
- log.info("Pass guardrail on prompt")
874
-
875
- log.info("Run text embedding on prompt")
876
- prompt_embeddings, prompt_masks = self._run_text_embedding_on_prompt_with_offload([inp_prompt])
877
- prompt_embedding = prompt_embeddings[0]
878
- prompt_mask = prompt_masks[0]
879
- log.info("Finish text embedding on prompt")
880
-
881
- log.info("Run generation")
882
- out_videos_cur_batch, indices_tensor_cur_batch, prompt_embedding = self._run_model_with_offload(
883
- prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config
884
- )
885
- log.info("Finish AR model generation")
886
-
887
- if not self.disable_diffusion_decoder:
888
- log.info("Run diffusion decoder on generated tokens")
889
- out_videos_cur_batch = self._run_diffusion_decoder_with_offload(
890
- out_videos_cur_batch, indices_tensor_cur_batch, [prompt_embedding]
891
- )
892
- log.info("Finish diffusion decoder on generated tokens")
893
- out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch)
894
- output_video = out_videos_cur_batch[0]
895
-
896
- log.info("Run guardrail on generated video")
897
- output_video = self._run_guardrail_on_video_with_offload(output_video)
898
- if output_video is None:
899
- log.critical("Generated video is not safe")
900
- return None
901
- log.info("Finish guardrail on generated video")
902
-
903
- return output_video
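Below is a minimal usage sketch for the pipeline deleted above (not part of the original file). The checkpoint directory, checkpoint name, prompt, and the random conditioning tensor are placeholders; the input shape follows the generate() docstring, i.e. (batch, time, channels=3, height, width).

import imageio
import torch

# Hypothetical paths/names -- substitute the real checkpoint layout.
pipeline = ARVideo2WorldGenerationPipeline(
    checkpoint_dir="checkpoints",
    checkpoint_name="Cosmos-1.0-Autoregressive-5B-Video2World",
    inference_type="video2world",
    offload_diffusion_decoder=True,  # optional: trade speed for GPU memory
)

# Placeholder conditioning clip: 1 video, 9 frames, 3 channels, 640x1024.
inp_vid = torch.rand(1, 9, 3, 640, 1024)
out_vid = pipeline.generate(
    inp_prompt="A robot arm picks up a red cube from a conveyor belt.",
    inp_vid=inp_vid,
    num_input_frames=9,
    seed=0,
)
if out_vid is not None:  # None means a guardrail check rejected the prompt or the output
    imageio.mimsave("output.mp4", out_vid, fps=25)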
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/1f41a4225dcea325c5ea283e51e09477ee1d0e6d DELETED
@@ -1,149 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import os
18
-
19
- import imageio
20
- import torch
21
-
22
- from .world_generation_pipeline import ARVideo2WorldGenerationPipeline
23
- from .ar_utils_inference import load_vision_input, validate_args
24
- from .log import log
25
- from .io import read_prompts_from_file
26
-
27
- # from download_autoregressive import main as download_autoregressive
28
- from transformers import PreTrainedModel, PretrainedConfig
29
-
30
-
31
- class ARVideo2WorldConfig(PretrainedConfig):
32
- model_type = "ARVideo2World"
33
- def __init__(self, **kwargs):
34
- super().__init__(**kwargs)
35
-
36
- self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints")
37
- self.ar_model_dir = kwargs.get("ar_model_dir", "Cosmos-1.0-Autoregressive-5B-Video2World")
38
- self.video_save_name = kwargs.get("video_save_name", "output")
39
- self.video_save_folder = kwargs.get("video_save_folder", "outputs/")
40
- self.prompt = kwargs.get("prompt", None)
41
-
42
- self.input_type = kwargs.get("input_type", "text_and_video")
43
- self.input_image_or_video_path = kwargs.get("input_image_or_video_path", None)
44
- self.batch_input_path = kwargs.get("batch_input_path", None)
45
- self.num_input_frames = kwargs.get("num_input_frames", 9)
46
- self.temperature = kwargs.get("temperature", 1.0)
47
- self.top_p = kwargs.get("top_p", 0.8)
48
- self.seed = kwargs.get("seed", 0)
49
-
50
- self.disable_diffusion_decoder = kwargs.get("disable_diffusion_decoder", False)
51
- self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
52
- self.offload_diffusion_decoder = kwargs.get("offload_diffusion_decoder", False)
53
- self.offload_ar_model = kwargs.get("offload_ar_model", False)
54
- self.offload_tokenizer = kwargs.get("offload_tokenizer", False)
55
- self.offload_text_encoder_model = kwargs.get("offload_text_encoder_model", False)
56
-
57
-
58
- class ARVideo2World(PreTrainedModel):
59
- config_class = ARVideo2WorldConfig
60
-
61
- def __init__(self, args=ARVideo2WorldConfig()):
62
- super().__init__(args)
63
- torch.enable_grad(False)
64
- self.args = args
65
-
66
- inference_type = "video2world" # When the inference_type is "video2world", AR model takes both text and video as input, the world generation is based on the input text prompt and video
67
- self.sampling_config = validate_args(args, inference_type)
68
-
69
- # Initialize prompted base generation model pipeline
70
- self.pipeline = ARVideo2WorldGenerationPipeline(
71
- inference_type=inference_type,
72
- checkpoint_dir=args.checkpoint_dir,
73
- checkpoint_name=args.ar_model_dir,
74
- disable_diffusion_decoder=args.disable_diffusion_decoder,
75
- offload_guardrail_models=args.offload_guardrail_models,
76
- offload_diffusion_decoder=args.offload_diffusion_decoder,
77
- offload_network=args.offload_ar_model,
78
- offload_tokenizer=args.offload_tokenizer,
79
- offload_text_encoder_model=args.offload_text_encoder_model,
80
- )
81
-
82
- def forward(self):
83
- args = self.args
84
-
85
- # Load input image(s) or video(s)
86
- input_videos = load_vision_input(
87
- input_type=args.input_type,
88
- batch_input_path=args.batch_input_path,
89
- input_image_or_video_path=args.input_image_or_video_path,
90
- data_resolution=args.data_resolution,
91
- num_input_frames=args.num_input_frames,
92
- )
93
-
94
- # Load input prompt(s)
95
- if args.batch_input_path:
96
- prompts_list = read_prompts_from_file(args.batch_input_path)
97
- else:
98
- prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}]
99
-
100
- # Iterate through prompts
101
- for idx, prompt_entry in enumerate(prompts_list):
102
- video_path = prompt_entry["visual_input"]
103
- input_filename = os.path.basename(video_path)
104
-
105
- # Check if video exists in loaded videos
106
- if input_filename not in input_videos:
107
- log.critical(f"Input file {input_filename} not found, skipping prompt.")
108
- continue
109
-
110
- inp_vid = input_videos[input_filename]
111
- inp_prompt = prompt_entry["prompt"]
112
-
113
- # Generate video
114
- log.info(f"Run with input: {prompt_entry}")
115
- out_vid = self.pipeline.generate(
116
- inp_prompt=inp_prompt,
117
- inp_vid=inp_vid,
118
- num_input_frames=args.num_input_frames,
119
- seed=args.seed,
120
- sampling_config=self.sampling_config,
121
- )
122
- if out_vid is None:
123
- log.critical("Guardrail blocked video2world generation.")
124
- continue
125
-
126
- # Save video
127
- if args.input_image_or_video_path:
128
- out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4")
129
- else:
130
- out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4")
131
- imageio.mimsave(out_vid_path, out_vid, fps=25)
132
-
133
- log.info(f"Saved video to {out_vid_path}")
134
-
135
- def save_pretrained(self, save_directory, **kwargs):
136
- # We don't save anything, but need this function to override
137
- pass
138
-
139
- @classmethod
140
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
141
- config = kwargs["config"]
142
- other_args = kwargs.copy()
143
- other_args.pop("config")
144
- config.update(other_args)
145
- # model_sizes = ["5B",] if "5B" in config.ar_model_dir else ["13B",]
146
- # model_types = ["Video2World",]
147
- # download_autoregressive(model_types, model_sizes, config.checkpoint_dir)
148
- model = cls(config)
149
- return model
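A hedged sketch of how the HuggingFace wrapper above might be driven end to end. The repo id is the one from this commit; data_resolution and the input clip path are assumptions (forward() reads data_resolution via load_vision_input, but the config shown above does not set a default for it).

config = ARVideo2WorldConfig(
    checkpoint_dir="checkpoints",                  # hypothetical local checkpoint root
    prompt="A forklift drives through a warehouse.",
    input_image_or_video_path="assets/input.mp4",  # hypothetical conditioning clip
    data_resolution=[640, 1024],                   # assumed; consumed by load_vision_input
    video_save_folder="outputs/",
)
model = ARVideo2World.from_pretrained(
    "NeverMore0123/AutoregressiveVideo2WorldGeneration",
    config=config,
)
model.forward()  # runs generation and writes outputs/output.mp4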
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2464bc5e1892a3541ce439c0ea36347f43647224 DELETED
@@ -1,305 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import List, Optional
17
-
18
- import numpy as np
19
- import torch
20
- import transformer_engine as te
21
- from einops import rearrange
22
- from torch import nn
23
- from torch.utils.checkpoint import checkpoint
24
- from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb
25
-
26
- # ---------------------- Feed Forward Network -----------------------
27
-
28
-
29
- class FeedForward(nn.Module):
30
- """
31
- Transformer FFN with optional gating
32
-
33
- Parameters:
34
- d_model (int): Dimensionality of input features.
35
- d_ff (int): Dimensionality of the hidden layer.
36
- dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
37
- activation (callable, optional): The activation function applied after the first linear layer.
38
- Defaults to nn.ReLU().
39
- is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
40
- Defaults to False.
41
- bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
42
-
43
- Example:
44
- >>> ff = FeedForward(d_model=512, d_ff=2048)
45
- >>> x = torch.randn(64, 10, 512) # Example input tensor
46
- >>> output = ff(x)
47
- >>> print(output.shape) # Expected shape: (64, 10, 512)
48
- """
49
-
50
- def __init__(
51
- self,
52
- d_model: int,
53
- d_ff: int,
54
- dropout: float = 0.1,
55
- activation=nn.ReLU(),
56
- is_gated: bool = False,
57
- bias: bool = False,
58
- ) -> None:
59
- super().__init__()
60
-
61
- self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
62
- self.layer2 = nn.Linear(d_ff, d_model, bias=bias)
63
-
64
- self.dropout = nn.Dropout(dropout)
65
- self.activation = activation
66
- self.is_gated = is_gated
67
- if is_gated:
68
- self.linear_gate = nn.Linear(d_model, d_ff, bias=False)
69
-
70
- def forward(self, x: torch.Tensor):
71
- g = self.activation(self.layer1(x))
72
- if self.is_gated:
73
- x = g * self.linear_gate(x)
74
- else:
75
- x = g
76
- assert self.dropout.p == 0.0, "we skip dropout"
77
- return self.layer2(x)
78
-
79
-
80
- class GPT2FeedForward(FeedForward):
81
- def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
82
- super().__init__(
83
- d_model=d_model,
84
- d_ff=d_ff,
85
- dropout=dropout,
86
- activation=nn.GELU(),
87
- is_gated=False,
88
- bias=bias,
89
- )
90
-
91
- def forward(self, x: torch.Tensor):
92
- assert self.dropout.p == 0.0, "we skip dropout"
93
-
94
- x = self.layer1(x)
95
-
96
- def activation_layer2_forward(x):
97
- x = self.activation(x)
98
- x = self.layer2(x)
99
- return x
100
-
101
- x = checkpoint(activation_layer2_forward, x, use_reentrant=False)
102
- return x
103
-
104
-
105
- # ---------------------- Normalization Layer -----------------------
106
-
107
-
108
- def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
109
- """
110
- Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
111
-
112
- Args:
113
- x (torch.Tensor): The input tensor to normalize.
114
- dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
115
- eps (float, optional): A small constant to ensure numerical stability during division.
116
-
117
- Returns:
118
- torch.Tensor: The normalized tensor.
119
- """
120
- if dim is None:
121
- dim = list(range(1, x.ndim))
122
- norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
123
- norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
124
- return x / norm.to(x.dtype)
125
-
126
-
127
- def get_normalization(name: str, channels: int):
128
- if name == "I":
129
- return nn.Identity()
130
- elif name == "R":
131
- return te.pytorch.RMSNorm(channels, eps=1e-6)
132
- else:
133
- raise ValueError(f"Normalization {name} not found")
134
-
135
-
136
- class BaseAttentionOp(nn.Module):
137
- def __init__(self):
138
- super().__init__()
139
-
140
-
141
- class Attention(nn.Module):
142
- """
143
- Generalized attention impl.
144
-
145
- Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
146
- If `context_dim` is None, self-attention is assumed.
147
-
148
- Parameters:
149
- query_dim (int): Dimension of each query vector.
150
- context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
151
- heads (int, optional): Number of attention heads. Defaults to 8.
152
- dim_head (int, optional): Dimension of each head. Defaults to 64.
153
- dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
154
- attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
155
- qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
156
- out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
157
- qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
158
- Defaults to "SSI".
159
- qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
160
- Defaults to 'per_head'. Only support 'per_head'.
161
-
162
- Examples:
163
- >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
164
- >>> query = torch.randn(10, 128) # Batch size of 10
165
- >>> context = torch.randn(10, 256) # Batch size of 10
166
- >>> output = attn(query, context) # Perform the attention operation
167
-
168
- Note:
169
- https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
170
- """
171
-
172
- def __init__(
173
- self,
174
- query_dim: int,
175
- context_dim=None,
176
- heads=8,
177
- dim_head=64,
178
- dropout=0.0,
179
- attn_op: Optional[BaseAttentionOp] = None,
180
- qkv_bias: bool = False,
181
- out_bias: bool = False,
182
- qkv_norm: str = "SSI",
183
- qkv_norm_mode: str = "per_head",
184
- backend: str = "transformer_engine",
185
- qkv_format: str = "bshd",
186
- ) -> None:
187
- super().__init__()
188
-
189
- self.is_selfattn = context_dim is None # self attention
190
-
191
- inner_dim = dim_head * heads
192
- context_dim = query_dim if context_dim is None else context_dim
193
-
194
- self.heads = heads
195
- self.dim_head = dim_head
196
- self.qkv_norm_mode = qkv_norm_mode
197
- self.qkv_format = qkv_format
198
-
199
- if self.qkv_norm_mode == "per_head":
200
- norm_dim = dim_head
201
- else:
202
- raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
203
-
204
- self.backend = backend
205
-
206
- self.to_q = nn.Sequential(
207
- nn.Linear(query_dim, inner_dim, bias=qkv_bias),
208
- get_normalization(qkv_norm[0], norm_dim),
209
- )
210
- self.to_k = nn.Sequential(
211
- nn.Linear(context_dim, inner_dim, bias=qkv_bias),
212
- get_normalization(qkv_norm[1], norm_dim),
213
- )
214
- self.to_v = nn.Sequential(
215
- nn.Linear(context_dim, inner_dim, bias=qkv_bias),
216
- get_normalization(qkv_norm[2], norm_dim),
217
- )
218
-
219
- self.to_out = nn.Sequential(
220
- nn.Linear(inner_dim, query_dim, bias=out_bias),
221
- nn.Dropout(dropout),
222
- )
223
-
224
- if attn_op: # use what is given
225
- self.attn_op = attn_op
226
- elif self.backend == "transformer_engine":
227
- sequence_parallel = False
228
- self.attn_op: BaseAttentionOp = DotProductAttention(
229
- self.heads,
230
- self.dim_head,
231
- num_gqa_groups=self.heads,
232
- attention_dropout=0,
233
- qkv_format=qkv_format,
234
- attn_mask_type="no_mask",
235
- tp_size=1,
236
- tp_group=None,
237
- sequence_parallel=sequence_parallel,
238
- )
239
- else:
240
- raise ValueError(f"Backend {backend} not found")
241
-
242
- def cal_qkv(
243
- self, x, context=None, mask=None, rope_emb=None, **kwargs
244
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
245
- del kwargs
246
-
247
- """
248
- self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
249
- Before 07/24/2024, these modules normalize across all heads.
250
- After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
251
- we now support normalizing per head.
252
- To keep checkpoint compatibility with the previous code,
253
- we keep the nn.Sequential but call the projection and the normalization layers separately.
254
- We use a flag `self.qkv_norm_mode` to control the normalization behavior.
255
- The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
256
- """
257
- if self.qkv_norm_mode == "per_head":
258
- q = self.to_q[0](x)
259
- context = x if context is None else context
260
- k = self.to_k[0](context)
261
- v = self.to_v[0](context)
262
- q, k, v = map(
263
- lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
264
- (q, k, v),
265
- )
266
- else:
267
- raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
268
-
269
- q = self.to_q[1](q)
270
- k = self.to_k[1](k)
271
- v = self.to_v[1](v)
272
- if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
273
- q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
274
- k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
275
- return q, k, v
276
-
277
- def cal_attn(self, q, k, v, mask=None):
278
- if self.backend == "transformer_engine":
279
- seq_dim = self.qkv_format.index("s")
280
- assert (
281
- q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
282
- ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
283
- out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V]
284
- return self.to_out(out)
285
- elif self.backend == "torch":
286
- out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V]
287
- return self.to_out(rearrange(out, " b ... n c -> b ... (n c)"))
288
- else:
289
- raise ValueError(f"Backend {self.backend} not found")
290
-
291
- def forward(
292
- self,
293
- x,
294
- context=None,
295
- mask=None,
296
- rope_emb=None,
297
- **kwargs,
298
- ):
299
- """
300
- Args:
301
- x (Tensor): The query tensor of shape [B, Mq, K]
302
- context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
303
- """
304
- q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
305
- return self.cal_attn(q, k, v, mask)
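For reference, here is a small self-contained sketch (toy sizes, plain PyTorch/einops, no TransformerEngine) of the per-head split that cal_qkv performs before normalization and RoPE: the projection output [B, S, heads*dim_head] is viewed as [B, S, heads, dim_head], so an RMS-style norm over the last dimension acts on each head independently, matching qkv_norm_mode="per_head" above. TE's RMSNorm additionally has a learnable weight, which is omitted here.

import torch
from einops import rearrange

B, S, heads, dim_head = 2, 16, 8, 64
x = torch.randn(B, S, heads * dim_head)  # stand-in for a q/k/v projection output
x_per_head = rearrange(x, "b s (n c) -> b s n c", n=heads, c=dim_head)
# Unweighted RMS normalization per head (statistics over dim_head only).
rms = (x_per_head.pow(2).mean(dim=-1, keepdim=True) + 1e-6).rsqrt()
x_normed = x_per_head * rms
print(x_normed.shape)  # torch.Size([2, 16, 8, 64])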
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2984b57e08440bd3117de9e25e4f3cfabd619e80 DELETED
@@ -1,195 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional, Tuple
17
-
18
- import torch
19
-
20
- from .ar_network_transformer import Transformer
21
-
22
-
23
- def sample_top_p(logits, temperature, top_p, return_probs: bool = False):
24
- """
25
- Perform top-p (nucleus) sampling on a probability distribution.
26
-
27
- Args:
28
- logits (torch.Tensor): Logits of the probability distribution.
29
- temperature (float): Temperature for sampling.
30
- top_p (float): Probability threshold for top-p sampling.
31
-
32
- Returns:
33
- torch.Tensor: Sampled token indices.
34
-
35
- Note:
36
- Top-p sampling selects the smallest set of tokens whose cumulative probability mass
37
- exceeds the threshold p. The distribution is renormalized based on the selected tokens.
38
- """
39
- probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
40
- # Sort the probabilities in descending order and get their indices.
41
- probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
42
- # Compute the cumulative sum of the sorted probabilities.
43
- probs_sum = torch.cumsum(probs_sort, dim=-1)
44
- # Create a mask where the cumulative probability exceeds the threshold p.
45
- mask = probs_sum - probs_sort > top_p
46
- # Set the probabilities that exceed the threshold to 0.
47
- probs_sort[mask] = 0.0
48
- # Renormalize the remaining probabilities so they sum to 1.
49
- probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
50
- # Sample from the renormalized probability distribution.
51
- # next_token = torch.multinomial(probs_sort, num_samples=1)
52
- next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
53
- # Gather the indices of the sampled tokens.
54
- next_token = torch.gather(probs_idx, -1, next_token)
55
- if return_probs:
56
- # Initialize a tensor for unsorted probabilities
57
- probs_unsorted = torch.zeros_like(probs_sort)
58
- # Scatter the sorted probabilities back to their original order
59
- probs_unsorted.scatter_(-1, probs_idx, probs_sort)
60
- else:
61
- probs_unsorted = None
62
- return next_token, probs_unsorted
63
-
64
-
65
- def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
66
- """
67
- Multinomial sampling without a cuda synchronization.
68
- Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
69
- """
70
- q = torch.empty_like(probs_sort).exponential_(1)
71
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)
72
-
73
-
74
- def logits_to_probs(
75
- logits,
76
- temperature: float = 1.0,
77
- top_k: Optional[int] = None,
78
- ):
79
- logits = logits / max(temperature, 1e-5)
80
-
81
- if top_k is not None:
82
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
83
- pivot = v.select(-1, -1).unsqueeze(-1)
84
- logits = torch.where(logits < pivot, -float("Inf"), logits)
85
- probs = torch.nn.functional.softmax(logits, dim=-1)
86
- return probs
87
-
88
-
89
- def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
90
- """
91
- Sample from the logits using top-k sampling.
92
- Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
93
- """
94
- # logits: [batch_size, seq_len, vocab_size]
95
- if temperature == 0.0:
96
- idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
97
- probs = None
98
- else:
99
- probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
100
- idx_next = multinomial_sample_one_no_sync(probs)
101
- return idx_next, probs
102
-
103
-
104
- def prefill(
105
- model: Transformer,
106
- input_pos: torch.Tensor,
107
- tokens: torch.Tensor = None,
108
- token_embeddings: torch.Tensor = None,
109
- temperature: float = 1.0,
110
- top_k: Optional[int] = None,
111
- top_p: Optional[float] = None,
112
- **kwargs,
113
- ) -> torch.Tensor:
114
- logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs)
115
- # Only top-p or top-k can be provided
116
- assert (
117
- top_p is None or top_k is None
118
- ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
119
- if top_p is not None:
120
- return sample_top_p(logits, temperature=temperature, top_p=top_p)[0]
121
- else:
122
- return sample_top_k(logits, temperature=temperature, top_k=top_k)[0]
123
-
124
-
125
- def decode_one_token(
126
- model: Transformer,
127
- tokens: torch.Tensor,
128
- input_pos: torch.Tensor,
129
- temperature: float = 1.0,
130
- top_k: Optional[int] = None,
131
- top_p: Optional[float] = None,
132
- **kwargs,
133
- ) -> Tuple[torch.Tensor, torch.Tensor]:
134
- """
135
- Decode a single token from the autoregressive model.
136
- """
137
- logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
138
- if top_p is not None:
139
- return sample_top_p(logits, temperature=temperature, top_p=top_p)
140
- else:
141
- return sample_top_k(logits, temperature=temperature, top_k=top_k)
142
-
143
-
144
- def decode_n_tokens(
145
- model: Transformer,
146
- cur_token: torch.Tensor,
147
- input_pos: torch.Tensor,
148
- num_new_tokens: int,
149
- stop_tokens: torch.Tensor = None,
150
- temperature: float = 1.0,
151
- top_p: Optional[float] = None,
152
- top_k: Optional[int] = None,
153
- return_probs: bool = False,
154
- decode_one_token_function=decode_one_token,
155
- **kwargs,
156
- ):
157
- """
158
- Decode n tokens from the autoregressive model.
159
- Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
160
- """
161
- new_tokens, new_probs = [], []
162
- batch_size = cur_token.shape[0]
163
- assert (
164
- top_p is None or top_k is None
165
- ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
166
- if stop_tokens is not None:
167
- # Indicator for whether the EOS token (stop token) has been reached for each sample in the batch
168
- eos_reached = torch.tensor([False] * batch_size, device="cuda")
169
- for t in range(num_new_tokens):
170
- with torch.backends.cuda.sdp_kernel(
171
- enable_flash=False, enable_mem_efficient=False, enable_math=True
172
- ): # Actually better for Inductor to codegen attention here
173
- next_token, next_prob = decode_one_token_function(
174
- model,
175
- tokens=cur_token,
176
- input_pos=input_pos,
177
- temperature=temperature,
178
- top_k=top_k,
179
- top_p=top_p,
180
- **kwargs,
181
- )
182
- input_pos += 1
183
- if stop_tokens is not None and len(stop_tokens) > 0:
184
- eos_reached = eos_reached | (torch.isin(next_token, stop_tokens))
185
- if eos_reached.all():
186
- break
187
- new_tokens.append(next_token.clone())
188
- if return_probs:
189
- new_probs.append(next_prob.clone())
190
- cur_token = next_token.clone()
191
-
192
- if return_probs:
193
- return new_tokens, new_probs
194
- else:
195
- return new_tokens
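A quick, self-contained check of the nucleus-sampling helper above on toy logits (shapes are illustrative only): after top-p filtering and renormalization, only tokens inside the top-p probability mass can be drawn.

import torch

torch.manual_seed(0)
batch, seq_len, vocab = 2, 1, 8
logits = torch.randn(batch, seq_len, vocab)
logits[0, -1, 3] += 10.0  # make token 3 dominate the first sample's distribution

next_token, probs = sample_top_p(logits, temperature=1.0, top_p=0.9, return_probs=True)
print(next_token.shape)        # torch.Size([2, 1]): one sampled id per batch element
print(next_token[0].item())    # almost certainly 3
print(probs[0].sum().item())   # renormalized distribution sums to ~1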
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/29be4d33e5dfb6255b5db0b99bcbc4311a3faa82 DELETED
@@ -1,63 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from collections import namedtuple
17
-
18
- import torch
19
- from torch import nn
20
-
21
- from .ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized
22
- from .ar_tokenizer_quantizers import FSQuantizer
23
- from .log import log
24
-
25
- NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])
26
-
27
-
28
- class CausalDiscreteVideoTokenizer(nn.Module):
29
- def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None:
30
- super().__init__()
31
- self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer")
32
- self.embedding_dim = embedding_dim
33
- self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs)
34
- self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs)
35
-
36
- self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0)
37
- self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0)
38
-
39
- self.quantizer = FSQuantizer(**kwargs)
40
-
41
- num_parameters = sum(param.numel() for param in self.parameters())
42
- log.debug(f"model={self.name}, num_parameters={num_parameters:,}")
43
- log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")
44
-
45
- def to(self, *args, **kwargs):
46
- setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16))
47
- return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs)
48
-
49
- def encode(self, x):
50
- h = self.encoder(x)
51
- h = self.quant_conv(h)
52
- return self.quantizer(h)
53
-
54
- def decode(self, quant):
55
- quant = self.post_quant_conv(quant)
56
- return self.decoder(quant)
57
-
58
- def forward(self, input):
59
- quant_info, quant_codes, quant_loss = self.encode(input)
60
- reconstructions = self.decode(quant_codes)
61
- if self.training:
62
- return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
63
- return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
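To connect this tokenizer to the generation pipeline earlier in this diff, here is a toy sketch (illustrative sizes, assumed token offset) of how flat generated token ids are reshaped into a B x T x H x W grid and shifted back into the tokenizer's 0..N-1 index range before being handed to the discrete decode path.

import torch
from einops import rearrange

latent_shape = (5, 40, 64)        # (T, H, W) of the token grid -- illustrative only
video_token_start = 64000         # assumed offset; the real value is logit_clipping_range[0]
flat = torch.randint(video_token_start, video_token_start + 16, (1, 5 * 40 * 64))
indices = rearrange(flat.long(), "B (T H W) -> B T H W",
                    T=latent_shape[0], H=latent_shape[1], W=latent_shape[2])
indices = indices - video_token_start  # decoder expects indices in (0, N-1)
print(indices.shape, indices.min().item(), indices.max().item())
# `indices` would then go through the tokenizer wrapper's decode path, which looks up the
# corresponding codes and calls CausalDiscreteVideoTokenizer.decode on the resulting latents.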
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2a19d3b8e2a1cf29c182f7b25a25d4c1e10089da DELETED
@@ -1,491 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import math
17
- from typing import List, Optional, Tuple
18
-
19
- import numpy as np
20
- import torch
21
- from einops import rearrange, repeat
22
-
23
-
24
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
25
- """
26
- embed_dim: output dimension for each position
27
- pos: a list of positions to be encoded: size (M,)
28
- out: (M, D)
29
- """
30
- assert embed_dim % 2 == 0
31
- omega = np.arange(embed_dim // 2, dtype=np.float64)
32
- omega /= embed_dim / 2.0
33
- omega = 1.0 / 10000**omega # (D/2,)
34
-
35
- pos = pos.reshape(-1) # (M,)
36
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
37
-
38
- emb_sin = np.sin(out) # (M, D/2)
39
- emb_cos = np.cos(out) # (M, D/2)
40
-
41
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
42
- return emb
43
-
44
-
45
- def _rotate_half_te(x: torch.Tensor) -> torch.Tensor:
46
- """
47
- change sign so the last dimension becomes [-odd, +even].
48
- Adapted from TransformerEngine.
49
- Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
50
- """
51
- x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
52
- x1, x2 = x.unbind(dim=-2)
53
- return torch.cat((-x2, x1), dim=-1)
54
-
55
-
56
- def _apply_rotary_pos_emb_te(
57
- t: torch.Tensor,
58
- cos_freqs: torch.Tensor,
59
- sin_freqs: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """
62
- Apply rotary positional embedding tensor to the input tensor.
63
- Adapted from TransformerEngine.
64
- Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
65
-
66
- Parameters
67
- ----------
68
- t: torch.Tensor
69
- Input tensor of shape `[b, s, h, d]`, on which
70
- rotary positional embedding will be applied.
71
- cos_freqs: torch.Tensor
72
- Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
73
- sin_freqs: torch.Tensor
74
- Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
75
- """
76
- rot_dim = cos_freqs.shape[-1]
77
- # ideally t_pass is empty, so the rotary position embedding is applied to the entire tensor t
78
- t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
79
- # first part is cosine component
80
- # second part is sine component, need to change signs with _rotate_half method
81
- t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs)
82
- output = torch.cat((t, t_pass), dim=-1)
83
- return output
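A toy, self-contained application of the two helpers above. The cos/sin tensors are built with shape [1, s, 1, d], matching how RotaryPositionEmbeddingPytorchV2 later in this file caches them; all sizes are illustrative.

import torch

b, s, h, d = 2, 8, 4, 16  # toy sizes; d must be even for the half-rotation
q = torch.randn(b, s, h, d)
pos = torch.arange(s, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
freqs = torch.outer(pos, inv_freq)                  # [s, d/2]
emb = torch.cat((freqs, freqs), dim=-1).view(1, s, 1, d)
q_rot = _apply_rotary_pos_emb_te(q, torch.cos(emb), torch.sin(emb))
print(q_rot.shape)                                  # torch.Size([2, 8, 4, 16])
print(torch.allclose(q_rot[:, 0], q[:, 0]))         # True: position 0 has cos=1, sin=0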
84
-
85
-
86
- class RotaryPositionEmbedding(torch.nn.Module):
87
- """
88
- Rotary Position Embedding module as described in the paper:
89
- https://arxiv.org/abs/2104.09864
90
-
91
- This module implements rotary positional embeddings, which are used to
92
- enhance the performance of transformer models.
93
-
94
- Args:
95
- dim (int): Dimensionality of the input tensor.
96
- max_position_embeddings (Optional[int]): Maximum position embeddings.
97
- original_max_position_embeddings (Optional[int]): Original maximum position embeddings.
98
- rope_theta (Optional[float]): Base for the frequency calculation.
99
- apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another RoPE extensioN) scaling.
100
- scale (Optional[int]): Scaling factor for the frequency calculation.
101
- extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension.
102
- attn_factor (Optional[int]): Attention factor for the frequency calculation.
103
- beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation.
104
- beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation.
105
- rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
106
- latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
107
- original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs.
108
- pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
109
- """
110
-
111
- def __init__(
112
- self,
113
- dim: int,
114
- max_position_embeddings: Optional[int] = None,
115
- original_max_position_embeddings: Optional[int] = None,
116
- rope_theta: Optional[float] = 10000.0,
117
- apply_yarn: Optional[bool] = False,
118
- scale: Optional[int] = None,
119
- extrapolation_factor: Optional[int] = 1,
120
- attn_factor: Optional[int] = 1,
121
- beta_fast: Optional[int] = 32,
122
- beta_slow: Optional[int] = 1,
123
- rope_dim: Optional[str] = "1D",
124
- latent_shape: Optional[List[int]] = None,
125
- original_latent_shape: Optional[List[int]] = None,
126
- pad_to_multiple_of: Optional[int] = None,
127
- ):
128
- super().__init__()
129
-
130
- self.dim = dim
131
- self.max_position_embeddings = max_position_embeddings
132
- self.original_max_position_embeddings = original_max_position_embeddings
133
- self.rope_theta = rope_theta
134
- self.apply_yarn = apply_yarn
135
- self.scale = scale
136
- self.extrapolation_factor = extrapolation_factor
137
- self.attn_factor = attn_factor
138
- self.beta_fast = beta_fast
139
- self.beta_slow = beta_slow
140
- self.mscale = 1.0
141
- self.rope_dim = rope_dim
142
- self.latent_shape = latent_shape
143
- self.original_latent_shape = original_latent_shape
144
- self.pad_to_multiple_of = pad_to_multiple_of
145
- self.get_inv_freq(torch.cuda.current_device())
146
-
147
- def get_mscale(self, scale: float = 1.0) -> float:
148
- """Get the magnitude scaling factor for YaRN."""
149
- if scale <= 1:
150
- return 1.0
151
- return 0.1 * math.log(scale) + 1.0
152
-
153
- def forward(self, seq_len: Optional[int] = None) -> torch.Tensor:
154
- """
155
- Forward pass for the rotary position embedding.
156
-
157
- Args:
158
- seq_len (Optional[int]): Length of the sequence.
159
-
160
- Returns:
161
- torch.Tensor: The computed frequencies for positional embedding.
162
- """
163
-
164
- if self.apply_yarn and seq_len > self.max_seq_len_cached:
165
- self.max_seq_len_cached = seq_len
166
- self.freqs = self.compute_freqs()
167
-
168
- return self.freqs
169
-
170
- def compute_freqs(
171
- self,
172
- ) -> Tuple[torch.Tensor, torch.Tensor]:
173
- """Compute the spatial frequencies for the latent tensor."""
174
- self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda()
175
- if self.rope_dim == "1D":
176
- emb = torch.einsum("i,j->ij", self.seq, self.inv_freq)
177
-
178
- elif self.rope_dim == "2D":
179
- H, W = self.latent_shape
180
- half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
181
- half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
182
- emb = torch.cat(
183
- [
184
- repeat(half_emb_h, "h d -> h w d", w=W),
185
- repeat(half_emb_w, "w d -> h w d", h=H),
186
- ]
187
- * 2,
188
- dim=-1,
189
- )
190
- emb = rearrange(emb, "h w d -> (h w) 1 1 d").float()
191
-
192
- elif self.rope_dim == "3D":
193
- T, H, W = self.latent_shape
194
- half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq)
195
- half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
196
- half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
197
- emb = torch.cat(
198
- [
199
- repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
200
- repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
201
- repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
202
- ]
203
- * 2,
204
- dim=-1,
205
- )
206
- emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
207
- else:
208
- raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
209
- return emb
210
-
211
- def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
212
- """Get the scale factors for YaRN."""
213
- # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called
214
- # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code.
215
- high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
216
- low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
217
- # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear
218
- # interpolation in between.
219
- smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
220
- # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency.
221
- scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
222
- return scale_factors
223
-
224
- def get_inv_freq(self, device: torch.device) -> None:
225
- """Get the inverse frequency."""
226
- if self.rope_dim == "1D":
227
- assert self.max_position_embeddings is not None, "Max position embeddings required."
228
- inv_freq = 1.0 / (
229
- self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
230
- )
231
- if self.apply_yarn:
232
- assert self.original_max_position_embeddings is not None, "Original max position embeddings required."
233
- assert self.beta_slow is not None, "Beta slow value required."
234
- assert self.beta_fast is not None, "Beta fast value required."
235
-
236
- scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings)
237
- # Apply the scaling factors to inv_freq.
238
- inv_freq = inv_freq * scale_factors
239
- # Set the magnitude scaling factor.
240
- self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
241
- self.max_seq_len_cached = self.max_position_embeddings
242
- self.inv_freq = inv_freq
243
-
244
- elif self.rope_dim == "2D":
245
- assert self.latent_shape is not None, "Latent shape required."
246
- dim_h = self.dim // 2
247
- spatial_inv_freq = 1.0 / (
248
- self.rope_theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h)
249
- )
250
- if self.apply_yarn:
251
- assert self.original_latent_shape is not None, "Original latent shape required."
252
- assert self.beta_slow is not None, "Beta slow value required."
253
- assert self.beta_fast is not None, "Beta fast value required."
254
-
255
- scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0])
256
- spatial_inv_freq = spatial_inv_freq * scale_factors
257
- self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
258
- self.spatial_inv_freq = spatial_inv_freq
259
- self.max_seq_len_cached = max(self.latent_shape)
260
-
261
- elif self.rope_dim == "3D":
262
- assert self.latent_shape is not None, "Latent shape required."
263
- dim_h = self.dim // 6 * 2
264
- dim_t = self.dim - 2 * dim_h
265
- self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h
266
- spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
267
- self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t
268
- temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
269
- if self.apply_yarn:
270
- assert self.original_latent_shape is not None, "Original latent shape required."
271
- assert self.beta_slow is not None, "Beta slow value required."
272
- assert self.beta_fast is not None, "Beta fast value required."
273
- scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
274
- spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
275
- scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
276
- temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
277
- self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
278
- self.spatial_inv_freq = spatial_inv_freq
279
- self.temporal_inv_freq = temporal_inv_freq
280
- self.max_seq_len_cached = max(self.latent_shape)
281
- else:
282
- raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
283
-
284
- self.freqs = self.compute_freqs()
285
-
286
-
287
- class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding):
288
- """
289
- Rotary Position Embedding that works in the same way as the TransformerEngine RoPE
290
- (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)
291
-
292
- """
293
-
294
- def __init__(
295
- self,
296
- seq_len: int,
297
- training_type: str = None,
298
- **kwargs,
299
- ):
300
- super().__init__(
301
- **kwargs,
302
- )
303
- emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type)
304
- emb = emb.transpose(0, 1).contiguous() # [seq, 1, 1, dim] -> [1, seq, 1, dim]
305
- assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}"
306
- # cos/sin first then dtype conversion for better precision
307
- self.register_buffer("cos_cached", torch.cos(emb), persistent=False)
308
- self.register_buffer("sin_cached", torch.sin(emb), persistent=False)
309
-
310
- def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor:
311
- """
312
- Create rotary position embedding frequencies.
313
-
314
- Args:
315
- seq_len (int): Sequence length of a sample.
316
-
317
- Returns:
318
- torch.Tensor: The computed positional embeddings.
319
- """
320
- if self.rope_dim == "1D":
321
- freqs = super().forward(seq_len=seq_len)
322
- emb = torch.cat((freqs, freqs), dim=-1)
323
- emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
324
-
325
- elif self.rope_dim in ["2D", "3D"]:
326
- emb = super().forward(seq_len=seq_len)
327
- if training_type == "text_to_video":
328
- # since we added <bov> token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning
329
- bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device)
330
- emb = torch.cat((bov_pe, emb), dim=0)
331
- else:
332
- raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
333
- if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
334
- # Round up to the nearest multiple of pad_to_multiple_of
335
- pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
336
- emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0)
337
-
338
- return emb
339
-
340
- def forward(
341
- self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
342
- ) -> Tuple[torch.Tensor, torch.Tensor]:
343
- if q.dtype != self.cos_cached.dtype:
344
- self.cos_cached = self.cos_cached.to(q.dtype)
345
- self.sin_cached = self.sin_cached.to(q.dtype)
346
-
347
- cos_emb = self.cos_cached
348
- sin_emb = self.sin_cached
349
- if input_pos is not None:
350
- cos_emb = cos_emb[:, input_pos, :, :]
351
- sin_emb = sin_emb[:, input_pos, :, :]
352
- elif seq_len is not None:
353
- cos_emb = cos_emb[:, :seq_len, :, :]
354
- sin_emb = sin_emb[:, :seq_len, :, :]
355
- q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb)
356
- k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb)
357
- return q, k
358
-
359
-
360
- class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding):
361
- """
362
- Rotary Position Embedding that works in the same way as
363
- mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py)
364
- or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py)
365
-
366
- """
367
-
368
- def __init__(
369
- self,
370
- **kwargs,
371
- ):
372
- super().__init__(
373
- **kwargs,
374
- )
375
- if self.rope_dim == "1D":
376
- emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1)
377
- elif self.rope_dim in ["2D", "3D"]:
378
- emb = rearrange(self.freqs, "s 1 1 d -> s d").float()
379
- self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False)
380
- self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False)
381
-
382
- def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
383
- """Rotate half the hidden dimensions of the input tensor."""
384
- x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
385
- x1 = x_reshaped[..., 0]
386
- x2 = x_reshaped[..., 1]
387
- output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape)
388
- return output
389
-
390
- def forward(
391
- self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
392
- ) -> Tuple[torch.Tensor, torch.Tensor]:
393
- """
394
- Forward pass for the rotary position embedding.
395
-
396
- Args:
397
- q (torch.Tensor): Query tensor.
398
- k (torch.Tensor): Key tensor.
399
- input_pos (Optional[torch.Tensor]): Starting position for the sequence.
400
- seq_len (Optional[int]): Length of the sequence.
401
-
402
- Returns:
403
- Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
404
- """
405
- if self.apply_yarn and seq_len > self.max_seq_len_cached:
406
- freqs = super().forward(seq_len)
407
- if self.rope_dim == "1D":
408
- emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1)
409
- elif self.rope_dim in ["2D", "3D"]:
410
- emb = rearrange(freqs, "s 1 1 d -> s d").float()
411
- else:
412
- raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
413
- self.register_buffer(
414
- "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
415
- )
416
- self.register_buffer(
417
- "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
418
- )
419
-
420
- if input_pos is not None:
421
- cos_cached = self.cos_cached[:, input_pos]
422
- sin_cached = self.sin_cached[:, input_pos]
423
- else:
424
- assert (
425
- self.cos_cached.shape[1] >= seq_len
426
- ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}."
427
- cos_cached = self.cos_cached[:, :seq_len, ...]
428
- sin_cached = self.sin_cached[:, :seq_len, ...]
429
- xq = q * cos_cached + self.rotate_half(q) * sin_cached
430
- xk = k * cos_cached + self.rotate_half(k) * sin_cached
431
-
432
- return xq.type_as(q), xk.type_as(k)
433
-
434
-
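
As a point of reference for the class above: the llama3/mistral-style rotation multiplies q and k by cached cos/sin tables and adds a rotated-half term. A minimal standalone sketch of that 1D case follows (illustrative shapes and helper names, not the class API above).

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Interleaved convention: adjacent channel pairs (x1, x2) become (-x2, x1).
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x_pairs[..., 0], x_pairs[..., 1]
    return torch.stack((-x2, x1), dim=-1).reshape(*x.shape)

def build_cos_sin(seq_len: int, head_dim: int, theta: float = 10000.0):
    # Standard 1D RoPE frequencies; each frequency is repeated for its (x1, x2) pair.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)        # [seq, head_dim/2]
    angles = torch.stack((angles, angles), dim=-1).reshape(seq_len, head_dim)
    # Shape [1, seq, 1, head_dim] so the tables broadcast over (batch, heads).
    return angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]

def apply_rope(q, k, cos, sin):
    # q, k: [batch, seq, heads, head_dim]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

if __name__ == "__main__":
    q = torch.randn(2, 16, 4, 64)
    k = torch.randn(2, 16, 4, 64)
    cos, sin = build_cos_sin(seq_len=16, head_dim=64)
    q_rot, k_rot = apply_rope(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # torch.Size([2, 16, 4, 64]) twice
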
435
- class SinCosPosEmbAxisTE(torch.nn.Module):
436
- def __init__(
437
- self,
438
- dim: int,
439
- latent_shape: Optional[List[int]] = None,
440
- pad_to_multiple_of: Optional[int] = None,
441
- dtype: torch.dtype = torch.bfloat16,
442
- **kwargs,
443
- ):
444
- """
445
- Args:
446
- dim (int): Dimensionality of the input tensor.
447
- latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
448
- pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
449
- dtype (torch.dtype): Data type of the position embedding tensor.
450
- """
451
- super().__init__()
452
- dim_h = dim // 6 * 2
453
- dim_w = dim_h
454
- dim_t = dim - 2 * dim_h
455
- assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
456
- self.latent_shape = latent_shape
457
- T, H, W = latent_shape
458
- emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H))
459
- emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W))
460
- emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T))
461
-
462
- self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device="cuda"), persistent=False)
463
- self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device="cuda"), persistent=False)
464
- self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device="cuda"), persistent=False)
465
- self.pad_to_multiple_of = pad_to_multiple_of
466
-
467
- def forward(
468
- self,
469
- training_type: str = None,
470
- ) -> torch.Tensor:
471
- T, H, W = self.latent_shape
472
- emb = torch.cat(
473
- [
474
- repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W),
475
- repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W),
476
- repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H),
477
- ],
478
- dim=-1,
479
- )
480
- # Flatten the T,H,W dimensions
481
- emb = rearrange(emb, "t h w d -> (t h w) d")
482
-
483
- if training_type == "text_to_video":
484
- bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)
485
- emb = torch.cat((bov_pe, emb), dim=0)
486
- if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
487
- pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
488
- emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0)
489
- seq_len, dim = emb.shape
490
- emb = emb.reshape(1, seq_len, dim)
491
- return emb
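
The axis embedding above builds one 1-D sin/cos table per axis (T, H, W), broadcasts each over the other two axes, concatenates along the channel dimension, and flattens to a (T*H*W, dim) sequence. A compact sketch of that factorization follows (sincos_1d is a pure-PyTorch stand-in for get_1d_sincos_pos_embed_from_grid; sizes are illustrative).

import torch

def sincos_1d(dim: int, length: int) -> torch.Tensor:
    # Simple 1-D sin/cos table of shape [length, dim] (dim must be even).
    omega = 1.0 / (10000 ** (torch.arange(dim // 2).float() / (dim // 2)))
    angles = torch.outer(torch.arange(length).float(), omega)   # [length, dim/2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)      # [length, dim]

def factorized_pos_emb(T: int, H: int, W: int, dim: int) -> torch.Tensor:
    # Same channel split as above: two thirds shared by H and W, the rest for T.
    dim_h = dim // 6 * 2
    dim_w, dim_t = dim_h, dim - 2 * dim_h
    emb_t = sincos_1d(dim_t, T)[:, None, None, :].expand(T, H, W, dim_t)
    emb_h = sincos_1d(dim_h, H)[None, :, None, :].expand(T, H, W, dim_h)
    emb_w = sincos_1d(dim_w, W)[None, None, :, :].expand(T, H, W, dim_w)
    emb = torch.cat([emb_t, emb_h, emb_w], dim=-1)              # [T, H, W, dim]
    return emb.reshape(T * H * W, dim)                          # flatten to a sequence

if __name__ == "__main__":
    print(factorized_pos_emb(T=5, H=8, W=8, dim=384).shape)     # torch.Size([320, 384])
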
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/2c584c7c9a5e03bcb3b808d053f89e7c2aeaf9cf DELETED
@@ -1,119 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn.functional as F
18
-
19
-
20
- def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True):
21
- """
22
- Splits the video tensor into chunks of num_video_frames with a specified overlap.
23
-
24
- Args:
25
- - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width].
26
- - num_video_frames (int): Number of frames per chunk.
27
-     - overlap (int): Number of overlapping frames between chunks.
-     - tobf16 (bool): If True, cast each chunk to bfloat16.
28
-
29
- Returns:
30
- - List of torch.Tensors: List of video chunks with overlap.
31
- """
32
- # Get the dimensions of the input tensor
33
- B, C, T, H, W = video_BCTHW.shape
34
-
35
- # Ensure overlap is less than num_video_frames
36
- assert overlap < num_video_frames, "Overlap should be less than num_video_frames."
37
-
38
- # List to store the chunks
39
- chunks = []
40
-
41
- # Step size for the sliding window
42
- step = num_video_frames - overlap
43
-
44
- # Loop through the time dimension (T) with the sliding window
45
- for start in range(0, T - overlap, step):
46
- end = start + num_video_frames
47
- # Handle the case when the last chunk might go out of bounds
48
- if end > T:
49
-             # Reflect-pad the tail so the final chunk still has num_video_frames frames
50
- num_padding_frames = end - T
51
- chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect")
52
- else:
53
- # Regular case: no padding needed
54
- chunk = video_BCTHW[:, :, start:end, :, :]
55
- if tobf16:
56
- chunks.append(chunk.to(torch.bfloat16))
57
- else:
58
- chunks.append(chunk)
59
- return chunks
60
-
61
-
62
- def linear_blend_video_list(videos, D):
63
- """
64
- Linearly blends a list of videos along the time dimension with overlap length D.
65
-
66
- Parameters:
67
- - videos: list of video tensors, each of shape [b, c, t, h, w]
68
- - D: int, overlap length
69
-
70
- Returns:
71
- - output_video: blended video tensor of shape [b, c, L, h, w]
72
- """
73
- assert len(videos) >= 2, "At least two videos are required."
74
- b, c, t, h, w = videos[0].shape
75
- N = len(videos)
76
-
77
- # Ensure all videos have the same shape
78
- for video in videos:
79
- assert video.shape == (b, c, t, h, w), "All videos must have the same shape."
80
-
81
- # Calculate total output length
82
- L = N * t - D * (N - 1)
83
- output_video = torch.zeros((b, c, L, h, w), device=videos[0].device)
84
-
85
- output_index = 0 # Current index in the output video
86
-
87
- for i in range(N):
88
- if i == 0:
89
- # Copy frames from the first video up to t - D
90
- output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :]
91
- output_index += t - D
92
- else:
93
- # Blend overlapping frames between videos[i-1] and videos[i]
94
- blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device)
95
-
96
- for j in range(D):
97
- w1 = 1 - blend_weights[j]
98
- w2 = blend_weights[j]
99
- frame_from_prev = videos[i - 1][:, :, t - D + j, :, :]
100
- frame_from_curr = videos[i][:, :, j, :, :]
101
- output_frame = w1 * frame_from_prev + w2 * frame_from_curr
102
- output_video[:, :, output_index, :, :] = output_frame
103
- output_index += 1
104
-
105
- if i < N - 1:
106
- # Copy non-overlapping frames from current video up to t - D
107
- frames_to_copy = t - 2 * D
108
- if frames_to_copy > 0:
109
- output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][
110
- :, :, D : t - D, :, :
111
- ]
112
- output_index += frames_to_copy
113
- else:
114
- # For the last video, copy frames from D to t
115
- frames_to_copy = t - D
116
- output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :]
117
- output_index += frames_to_copy
118
-
119
- return output_video
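
The two helpers above implement a sliding-window split followed by a linear cross-fade over the overlap region. A rough standalone sketch of the same idea on a toy 1-D signal follows (names and shapes are illustrative, and unlike the deleted helper the toy split drops an incomplete tail instead of reflect-padding it).

import torch

def split_overlapping(x: torch.Tensor, chunk: int, overlap: int):
    # x: [..., T]; take full-size windows with stride (chunk - overlap).
    T = x.shape[-1]
    step = chunk - overlap
    return [x[..., s:s + chunk] for s in range(0, T - overlap, step) if s + chunk <= T]

def blend_pair(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
    # Cross-fade the last `overlap` samples of `a` into the first `overlap` of `b`.
    w = torch.linspace(0, 1, steps=overlap)
    mixed = (1 - w) * a[..., -overlap:] + w * b[..., :overlap]
    return torch.cat([a[..., :-overlap], mixed, b[..., overlap:]], dim=-1)

if __name__ == "__main__":
    signal = torch.arange(20, dtype=torch.float32)            # toy "video" with T=20
    chunks = split_overlapping(signal, chunk=8, overlap=2)    # 8-frame chunks, 2-frame overlap
    merged = chunks[0]
    for nxt in chunks[1:]:
        merged = blend_pair(merged, nxt, overlap=2)
    print(len(chunks), merged.shape)                          # 3 torch.Size([20])
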
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/39dca42a0a71383de919b750cedf2606faae206d DELETED
@@ -1,65 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Any, Dict, List, Union
17
-
18
- from omegaconf import OmegaConf
19
- from omegaconf.base import DictKeyType, SCMode
20
- from omegaconf.dictconfig import DictConfig # pragma: no cover
21
-
22
-
23
- def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
24
- """
25
- Converts an OmegaConf configuration object to a native Python container (dict or list), unless
26
- the configuration is specifically created by LazyCall, in which case the original configuration
27
- is returned directly.
28
-
29
- This function serves as a modification of the original `to_object` method from OmegaConf,
30
- preventing DictConfig objects created by LazyCall from being automatically converted to Python
31
- dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
32
- structure and behavior.
33
-
34
- Differences from OmegaConf's original `to_object`:
35
- - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
36
-
37
- Reference:
38
- - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
39
-
40
- Args:
41
- cfg (Any): The OmegaConf configuration object to convert.
42
-
43
- Returns:
44
- Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
45
- `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
46
-
47
- Examples:
48
- >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
49
- >>> to_object(cfg)
50
- DictConfig({"key": "value", "_target_": "Model"})
51
-
52
- >>> cfg = DictConfig({"list": [1, 2, 3]})
53
- >>> to_object(cfg)
54
- {'list': [1, 2, 3]}
55
- """
56
- if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
57
- return cfg
58
-
59
- return OmegaConf.to_container(
60
- cfg=cfg,
61
- resolve=True,
62
- throw_on_missing=True,
63
- enum_to_str=False,
64
- structured_config_mode=SCMode.INSTANTIATE,
65
- )
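
The helper above differs from OmegaConf's stock conversion only in that configs carrying a `_target_` key (LazyCall-style) are passed through untouched. A simplified, runnable sketch of that dispatch follows (it omits the throw_on_missing and structured-config options used above).

from typing import Any
from omegaconf import OmegaConf, DictConfig

def to_object_sketch(cfg: Any) -> Any:
    # Same dispatch rule as the deleted helper: LazyCall-style configs
    # (those carrying a "_target_" key) are returned untouched.
    if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
        return cfg
    return OmegaConf.to_container(cfg, resolve=True)

lazy_cfg = OmegaConf.create({"_target_": "my_pkg.Model", "hidden": 256})
plain_cfg = OmegaConf.create({"optimizer": {"lr": 1e-3}, "steps": [100, 200]})

print(type(to_object_sketch(lazy_cfg)))   # DictConfig, kept for lazy instantiation later
print(to_object_sketch(plain_cfg))        # {'optimizer': {'lr': 0.001}, 'steps': [100, 200]}
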
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/3c5a1dbe30558d9e7e97ad64304161c4e61a00f5 DELETED
@@ -1,60 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """
17
-     Implementations of multistep methods to solve the ODE in the diffusion model.
18
- """
19
-
20
- from typing import Callable, List, Tuple
21
-
22
- import torch
23
-
24
- from .df_df_functional_runge_kutta import reg_x0_euler_step, res_x0_rk2_step
25
-
26
-
27
- def order2_fn(
28
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor
29
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
30
- """
31
-     Implements the second-order multistep method from https://arxiv.org/pdf/2308.02157
32
-     (Adams-Bashforth approach).
33
- """
34
- if x0_preds:
35
- x0_s1, s1 = x0_preds[0]
36
- x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1)
37
- else:
38
- x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0]
39
- return x_t, [(x0_s, s)]
40
-
41
-
42
- # key: method name, value: method function
43
- # key: order + algorithm name
44
- MULTISTEP_FNs = {
45
- "2ab": order2_fn,
46
- }
47
-
48
-
49
- def get_multi_step_fn(name: str) -> Callable:
50
- if name in MULTISTEP_FNs:
51
- return MULTISTEP_FNs[name]
52
- methods = "\n\t".join(MULTISTEP_FNs.keys())
53
- raise RuntimeError("Only support multistep method\n" + methods)
54
-
55
-
56
- def is_multi_step_fn_supported(name: str) -> bool:
57
- """
58
- Check if the multistep method is supported.
59
- """
60
- return name in MULTISTEP_FNs
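
The registry above maps a short name such as "2ab" to a second-order Adams-Bashforth update that threads the previous (x0, sigma) prediction into the next step. A sketch of how such a registry is typically driven inside a sampling loop follows (the step function and denoiser here are toy stand-ins, not the imported res_x0_rk2_step / reg_x0_euler_step).

import torch

def toy_multistep(x_s, s, t, x0_s, x0_preds):
    # Stand-in for order2_fn: a simple first-order update; a real second-order
    # method would also consume the previous (x0, sigma) pair in x0_preds.
    x_t = x0_s + (x_s - x0_s) * (t / s)
    return x_t, [(x0_s, s)]            # history handed to the next call

TOY_MULTISTEP_FNS = {"toy": toy_multistep}

def get_toy_multi_step_fn(name):
    return TOY_MULTISTEP_FNS[name]

if __name__ == "__main__":
    step_fn = get_toy_multi_step_fn("toy")
    x = torch.randn(4, 3)
    x0_preds = []                                  # empty history on the first step
    sigmas = torch.tensor([10.0, 5.0, 1.0, 0.1])
    for s, t in zip(sigmas[:-1], sigmas[1:]):
        x0_hat = x * 0.9                           # pretend denoiser output
        x, x0_preds = step_fn(x, s, t, x0_hat, x0_preds)
    print(x.shape)                                 # torch.Size([4, 3])
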
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4146fad65c365a8c4fd6903a0ea33860142f64f5 DELETED
@@ -1,323 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import copy
17
- from abc import ABC, abstractmethod
18
- from collections import defaultdict
19
- from dataclasses import dataclass, fields
20
- from enum import Enum
21
- from typing import Any, Dict, List, Optional, Tuple, Union
22
-
23
- import torch
24
- import torch.nn as nn
25
-
26
- from .df_df_functional_batch_ops import batch_mul
27
- from .log import log
28
- from .lazy_config_init import instantiate
29
-
30
-
31
- class BaseConditionEntry(nn.Module):
32
- def __init__(self):
33
- super().__init__()
34
-
35
- self._dropout_rate = None
36
- self._input_key = None
37
- self._return_dict = False
38
-
39
- @property
40
- def dropout_rate(self) -> Union[float, torch.Tensor]:
41
- return self._dropout_rate
42
-
43
- @property
44
- def input_key(self) -> str:
45
- return self._input_key
46
-
47
- @property
48
- def is_return_dict(self) -> bool:
49
- return self._return_dict
50
-
51
- @dropout_rate.setter
52
- def dropout_rate(self, value: Union[float, torch.Tensor]):
53
- self._dropout_rate = value
54
-
55
- @input_key.setter
56
- def input_key(self, value: str):
57
- self._input_key = value
58
-
59
- @is_return_dict.setter
60
- def is_return_dict(self, value: bool):
61
- self._return_dict = value
62
-
63
- @dropout_rate.deleter
64
- def dropout_rate(self):
65
- del self._dropout_rate
66
-
67
- @input_key.deleter
68
- def input_key(self):
69
- del self._input_key
70
-
71
- @is_return_dict.deleter
72
- def is_return_dict(self):
73
- del self._return_dict
74
-
75
- def random_dropout_input(
76
- self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
77
- ) -> torch.Tensor:
78
- del key
79
- dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
80
- return batch_mul(
81
- torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor),
82
- in_tensor,
83
- )
84
-
85
- def summary(self) -> str:
86
- pass
87
-
88
-
89
- class DataType(Enum):
90
- IMAGE = "image"
91
- VIDEO = "video"
92
-
93
-
94
- class TextAttr(BaseConditionEntry):
95
- def __init__(self):
96
- super().__init__()
97
-
98
- def forward(self, token: torch.Tensor, mask: torch.Tensor):
99
- return {"crossattn_emb": token, "crossattn_mask": mask}
100
-
101
- def random_dropout_input(
102
- self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
103
- ) -> torch.Tensor:
104
- if key is not None and "mask" in key:
105
- return in_tensor
106
- return super().random_dropout_input(in_tensor, dropout_rate, key)
107
-
108
-
109
- @dataclass
110
- class BaseVideoCondition:
111
- crossattn_emb: torch.Tensor
112
- crossattn_mask: torch.Tensor
113
- data_type: DataType = DataType.VIDEO
114
- padding_mask: Optional[torch.Tensor] = None
115
- fps: Optional[torch.Tensor] = None
116
- num_frames: Optional[torch.Tensor] = None
117
- image_size: Optional[torch.Tensor] = None
118
- scalar_feature: Optional[torch.Tensor] = None
119
-
120
- def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
121
- return {f.name: getattr(self, f.name) for f in fields(self)}
122
-
123
-
124
- @dataclass
125
- class VideoExtendCondition(BaseVideoCondition):
126
- video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video
127
- gt_latent: Optional[torch.Tensor] = None
128
- condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region
129
-
130
- # condition_video_input_mask will concat to the input of network, along channel dim;
131
- # Will be concat with the input tensor
132
- condition_video_input_mask: Optional[torch.Tensor] = None
133
- # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed"
134
- condition_video_augment_sigma: Optional[torch.Tensor] = None
135
-
136
-
137
- class GeneralConditioner(nn.Module, ABC):
138
- """
139
- An abstract module designed to handle various embedding models with conditional and
140
- unconditional configurations. This abstract base class initializes and manages a collection
141
- of embedders that can dynamically adjust their dropout rates based on conditioning.
142
-
143
- Attributes:
144
- KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation.
145
-         embedders (nn.ModuleDict): A dictionary containing all embedder models initialized and
146
- configured based on the provided configurations.
147
-
148
- Parameters:
149
- emb_models (Union[List, Any]): A dictionary where keys are embedder names and values
150
- are configurations for initializing the embedders.
151
-
152
- """
153
-
154
- KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}
155
-
156
- def __init__(self, **emb_models: Union[List, Any]):
157
- super().__init__()
158
- self.embedders = nn.ModuleDict()
159
- for n, (emb_name, embconfig) in enumerate(emb_models.items()):
160
- embedder = instantiate(embconfig.obj)
161
- assert isinstance(
162
- embedder, BaseConditionEntry
163
- ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
164
- embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)
165
-
166
- if hasattr(embconfig, "input_key"):
167
- embedder.input_key = embconfig.input_key
168
- elif hasattr(embconfig, "input_keys"):
169
- embedder.input_keys = embconfig.input_keys
170
- else:
171
- raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
172
-
173
- log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
174
- self.embedders[emb_name] = embedder
175
-
176
- @abstractmethod
177
- def forward(
178
- self,
179
- batch: Dict,
180
- override_dropout_rate: Optional[Dict[str, float]] = None,
181
- ) -> Any:
182
- """Should be implemented in subclasses to handle conditon datatype"""
183
- raise NotImplementedError
184
-
185
- def _forward(
186
- self,
187
- batch: Dict,
188
- override_dropout_rate: Optional[Dict[str, float]] = None,
189
- ) -> Dict:
190
- """
191
- Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
192
- Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.
193
-
194
- Parameters:
195
- batch (Dict): The input data batch to process.
196
- override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates
197
- per embedder key.
198
-
199
- Returns:
200
- Dict: A dictionary of output tensors concatenated by specified dimensions.
201
-
202
- Note:
203
- In case the network code is sensitive to the order of concatenation, you can either control the order via \
204
- config file or make sure the embedders return a unique key for each output.
205
- """
206
- output = defaultdict(list)
207
- if override_dropout_rate is None:
208
- override_dropout_rate = {}
209
-
210
- # make sure emb_name in override_dropout_rate is valid
211
- for emb_name in override_dropout_rate.keys():
212
- assert emb_name in self.embedders, f"invalid name found {emb_name}"
213
-
214
- for emb_name, embedder in self.embedders.items():
215
- with torch.no_grad():
216
- if hasattr(embedder, "input_key") and (embedder.input_key is not None):
217
- emb_out = embedder(
218
- embedder.random_dropout_input(
219
- batch[embedder.input_key], override_dropout_rate.get(emb_name, None)
220
- )
221
- )
222
- elif hasattr(embedder, "input_keys"):
223
- emb_out = embedder(
224
- *[
225
- embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k)
226
- for k in embedder.input_keys
227
- ]
228
- )
229
- for k, v in emb_out.items():
230
- output[k].append(v)
231
- # Concatenate the outputs
232
- return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}
233
-
234
- def get_condition_uncondition(
235
- self,
236
- data_batch: Dict,
237
- ) -> Tuple[Any, Any]:
238
- """
239
- Processes the provided data batch to generate conditioned and unconditioned outputs.
240
-
241
- This method manipulates dropout rates to simulate two scenarios:
242
- 1. All conditions applied (conditioned)
243
- 2. Conditions removed/reduced to minimum (unconditioned)
244
-
245
- This method sets dropout rates to zero for the conditioned scenario to fully apply
246
- embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is
247
- insignificant) to minimize embedder influences.
248
-
249
- Parameters:
250
- data_batch (Dict): Input data batch containing all necessary information for
251
- embedding processing.
252
-
253
- Returns:
254
- Tuple[Any, Any]: A tuple containing:
255
- - Outputs with all embedders fully applied (conditioned)
256
- - Outputs with embedders minimized/not applied (unconditioned)
257
- """
258
- cond_dropout_rates, dropout_rates = {}, {}
259
- for emb_name, embedder in self.embedders.items():
260
- cond_dropout_rates[emb_name] = 0.0
261
- dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0
262
-
263
- condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
264
- un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
265
- return condition, un_condition
266
-
267
- def get_condition_with_negative_prompt(
268
- self,
269
- data_batch: Dict,
270
- ) -> Tuple[Any, Any]:
271
- """
272
-         Similar functionality to get_condition_uncondition,
273
-         but uses negative prompts for the unconditioned branch.
274
- """
275
- cond_dropout_rates, uncond_dropout_rates = {}, {}
276
- for emb_name, embedder in self.embedders.items():
277
- cond_dropout_rates[emb_name] = 0.0
278
- if isinstance(embedder, TextAttr):
279
- uncond_dropout_rates[emb_name] = 0.0
280
- else:
281
- uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0
282
-
283
- data_batch_neg_prompt = copy.deepcopy(data_batch)
284
- if "neg_t5_text_embeddings" in data_batch_neg_prompt:
285
- if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor):
286
- data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
287
- data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]
288
-
289
- condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
290
- un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)
291
-
292
- return condition, un_condition
293
-
294
-
295
- @dataclass
296
- class CosmosCondition:
297
- crossattn_emb: torch.Tensor
298
- crossattn_mask: torch.Tensor
299
- padding_mask: Optional[torch.Tensor] = None
300
- scalar_feature: Optional[torch.Tensor] = None
301
-
302
- def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
303
- return {f.name: getattr(self, f.name) for f in fields(self)}
304
-
305
-
306
- class VideoConditioner(GeneralConditioner):
307
- def forward(
308
- self,
309
- batch: Dict,
310
- override_dropout_rate: Optional[Dict[str, float]] = None,
311
- ) -> BaseVideoCondition:
312
- output = super()._forward(batch, override_dropout_rate)
313
- return BaseVideoCondition(**output)
314
-
315
-
316
- class VideoExtendConditioner(GeneralConditioner):
317
- def forward(
318
- self,
319
- batch: Dict,
320
- override_dropout_rate: Optional[Dict[str, float]] = None,
321
- ) -> VideoExtendCondition:
322
- output = super()._forward(batch, override_dropout_rate)
323
- return VideoExtendCondition(**output)
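
The conditioner above supports classifier-free guidance by running every embedder twice: dropout forced to 0 for the conditioned pass, and forced to 1 for the unconditioned pass whenever the training dropout was non-zero. A minimal sketch of that batch-wise dropout in isolation follows (toy tensor, not the classes above).

import torch

def random_dropout(x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    # Zero out whole samples with probability dropout_rate (one Bernoulli draw
    # per batch element), mirroring BaseConditionEntry.random_dropout_input.
    keep = torch.bernoulli((1.0 - dropout_rate) * torch.ones(x.shape[0])).type_as(x)
    return keep.view(-1, *([1] * (x.dim() - 1))) * x

if __name__ == "__main__":
    text_emb = torch.randn(4, 77, 1024)                    # toy cross-attention embedding

    cond = random_dropout(text_emb, dropout_rate=0.0)      # conditioned pass: keep everything
    uncond = random_dropout(text_emb, dropout_rate=1.0)    # unconditioned pass: drop everything

    # At sampling time the two branches feed classifier-free guidance:
    #   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    print(cond.abs().sum().item() > 0, uncond.abs().sum().item() == 0.0)   # True True
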
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/45a2ac6c32e8df9e6836ed55973912b8730c0749 DELETED
@@ -1,50 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
-
20
-
21
- class MLP(nn.Module):
22
- def __init__(
23
- self,
24
- dim: int,
25
- hidden_dim: int,
26
- ):
27
- """
28
- Initializes the multilayer perceptron (MLP) module.
29
-
30
- Args:
31
- dim: The input and output dimensionality.
32
- hidden_dim: The dimensionality of the hidden layer.
33
- """
34
- super().__init__()
35
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
36
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
37
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
38
-
39
- def forward(self, x: torch.Tensor) -> torch.Tensor:
40
- """
41
- Performs the forward pass of the MLP module.
42
-
43
- Args:
44
- x: The input tensor of shape (batch_size, dim).
45
-
46
- Returns:
47
- The output tensor of shape (batch_size, dim).
48
- """
49
- output = self.w2(F.silu(self.w1(x)) * self.w3(x))
50
- return output
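
Despite the generic name, the MLP above is the gated SwiGLU feed-forward used in Llama-style transformers: two parallel projections, one gated through SiLU, followed by a down-projection. The same computation in a few lines (hidden sizes are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden_dim = 512, 1376
w1 = nn.Linear(dim, hidden_dim, bias=False)   # "gate" projection
w3 = nn.Linear(dim, hidden_dim, bias=False)   # "value" projection
w2 = nn.Linear(hidden_dim, dim, bias=False)   # down projection

x = torch.randn(8, dim)
y = w2(F.silu(w1(x)) * w3(x))                 # SwiGLU: silu(x @ W1) * (x @ W3), then W2
print(y.shape)                                # torch.Size([8, 512])
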
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/46385211d438d1953e9ba21376680dc2c42db01c DELETED
@@ -1,219 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import os
18
- import re
19
- import string
20
- from difflib import SequenceMatcher
21
-
22
- from .misc import misc
23
- import nltk
24
- from better_profanity import profanity
25
-
26
- from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
27
- from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
28
- from .log import log
29
-
30
- DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
31
- CENSOR = misc.Color.red("*")
32
-
33
-
34
- class Blocklist(ContentSafetyGuardrail):
35
- def __init__(
36
- self,
37
- checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
38
- guardrail_partial_match_min_chars: int = 4,
39
- guardrail_partial_match_letter_count: float = 0.5,
40
- ) -> None:
41
- nltk.data.path.append(os.path.join(checkpoint_dir, "nltk_data"))
42
- self.lemmatizer = nltk.WordNetLemmatizer()
43
- self.profanity = profanity
44
- self.checkpoint_dir = checkpoint_dir
45
- self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars
46
- self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count
47
-
48
- # Load blocklist and whitelist keywords
49
- self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom"))
50
- self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist"))
51
- self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match"))
52
-
53
- self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words)
54
- log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist")
55
- log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist")
56
- log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist")
57
-
58
- def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str:
59
- """Explicitly uncensor words that are in the whitelist."""
60
- input_words = input_prompt.split()
61
- censored_words = censored_prompt.split()
62
- whitelist_words = set(self.whitelist_words)
63
- for i, token in enumerate(input_words):
64
- if token.strip(string.punctuation).lower() in whitelist_words:
65
- censored_words[i] = token
66
- censored_prompt = " ".join(censored_words)
67
- return censored_prompt
68
-
69
- def censor_prompt(self, input_prompt: str) -> tuple[bool, str]:
70
- """Censor the prompt using the blocklist with better-profanity fuzzy matching.
71
-
72
- Args:
73
- input_prompt: input prompt to censor
74
-
75
- Returns:
76
- bool: True if the prompt is blocked, False otherwise
77
- str: A message indicating why the prompt was blocked
78
- """
79
- censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR)
80
- # Uncensor whitelisted words that were censored from blocklist fuzzy matching
81
- censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt)
82
- if CENSOR in censored_prompt:
83
- return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}"
84
- return False, ""
85
-
86
- @staticmethod
87
- def check_partial_match(
88
- normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float
89
- ) -> tuple[bool, str]:
90
- """
91
-         Check robustly whether any same-length span of normalized_prompt differs from normalized_word by at most guardrail_partial_match_letter_count characters.
92
-
93
- Args:
94
- normalized_prompt: a string with many words
95
- normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt
96
- guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters)
97
-
98
- Returns:
99
- bool: True if a match is found, False otherwise
100
- str: A message indicating why the prompt was blocked
101
- """
102
- prompt_words = normalized_prompt.split()
103
- word_length = len(normalized_word.split())
104
- max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float(
105
- len(normalized_word)
106
- )
107
-
108
- for i in range(len(prompt_words) - word_length + 1):
109
- # Extract a substring from the prompt with the same number of words as the normalized_word
110
- substring = " ".join(prompt_words[i : i + word_length])
111
- similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio()
112
- if similarity_ratio >= max_similarity_ratio:
113
- return (
114
- True,
115
- f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}",
116
- )
117
-
118
- return False, ""
119
-
120
- @staticmethod
121
- def check_against_whole_word_blocklist(
122
- prompt: str,
123
- blocklist: list[str],
124
- guardrail_partial_match_min_chars: int = 4,
125
- guardrail_partial_match_letter_count: float = 0.5,
126
-     ) -> tuple[bool, str]:
127
- """
128
- Check if the prompt contains any whole words from the blocklist.
129
- The match is case insensitive and robust to multiple spaces between words.
130
-
131
- Args:
132
- prompt: input prompt to check
133
- blocklist: list of words to check against
134
- guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match
135
- guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match
136
-
137
- Returns:
138
- bool: True if a match is found, False otherwise
139
- str: A message indicating why the prompt was blocked
140
- """
141
- # Normalize spaces and convert to lowercase
142
- normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower()
143
-
144
- for word in blocklist:
145
- # Normalize spaces and convert to lowercase for each blocklist word
146
- normalized_word = re.sub(r"\s+", " ", word).strip().lower()
147
-
148
- # Use word boundaries to ensure whole word match
149
- if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt):
150
- return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}"
151
-
152
- # Check for partial match if the word is long enough
153
- if len(normalized_word) >= guardrail_partial_match_min_chars:
154
- match, message = Blocklist.check_partial_match(
155
- normalized_prompt, normalized_word, guardrail_partial_match_letter_count
156
- )
157
- if match:
158
- return True, message
159
-
160
- return False, ""
161
-
162
- def is_safe(self, input_prompt: str = "") -> tuple[bool, str]:
163
- """Check if the input prompt is safe using the blocklist."""
164
- # Check if the input is empty
165
- if not input_prompt:
166
- return False, "Input is empty"
167
- input_prompt = to_ascii(input_prompt)
168
-
169
- # Check full sentence for censored words
170
- censored, message = self.censor_prompt(input_prompt)
171
- if censored:
172
- return False, message
173
-
174
- # Check lemmatized words for censored words
175
- tokens = nltk.word_tokenize(input_prompt)
176
- lemmas = [self.lemmatizer.lemmatize(token) for token in tokens]
177
- lemmatized_prompt = " ".join(lemmas)
178
- censored, message = self.censor_prompt(lemmatized_prompt)
179
- if censored:
180
- return False, message
181
-
182
- # Check for exact match blocklist words
183
- censored, message = self.check_against_whole_word_blocklist(
184
- input_prompt,
185
- self.exact_match_words,
186
- self.guardrail_partial_match_min_chars,
187
- self.guardrail_partial_match_letter_count,
188
- )
189
- if censored:
190
- return False, message
191
-
192
- # If all these checks pass, the input is safe
193
- return True, "Input is safe"
194
-
195
-
196
- def parse_args():
197
- parser = argparse.ArgumentParser()
198
- parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
199
- parser.add_argument(
200
- "--checkpoint_dir",
201
- type=str,
202
- help="Path to the Blocklist checkpoint folder",
203
- default=DEFAULT_CHECKPOINT_DIR,
204
- )
205
- return parser.parse_args()
206
-
207
-
208
- def main(args):
209
- blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
210
- runner = GuardrailRunner(safety_models=[blocklist])
211
- with misc.timer("blocklist safety check"):
212
- safety, message = runner.run_safety_check(args.prompt)
213
- log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
214
- log.info(f"Message: {message}") if not safety else None
215
-
216
-
217
- if __name__ == "__main__":
218
- args = parse_args()
219
- main(args)
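
The partial-match check above slides a window with the same word count as the blocklist term over the prompt and flags a hit when difflib's similarity ratio reaches a threshold derived from the allowed letter difference. A standalone sketch of that thresholding follows (toy terms, not the Blocklist class API):

from difflib import SequenceMatcher

def partial_match(prompt: str, word: str, letter_slack: float = 0.5) -> bool:
    # Accept a window if it differs from `word` by at most `letter_slack` characters,
    # expressed as a minimum SequenceMatcher ratio.
    prompt_words = prompt.lower().split()
    word = word.lower().strip()
    n = len(word.split())
    min_ratio = (len(word) - letter_slack) / len(word)
    for i in range(len(prompt_words) - n + 1):
        window = " ".join(prompt_words[i:i + n])
        if SequenceMatcher(None, window, word).ratio() >= min_ratio:
            return True
    return False

if __name__ == "__main__":
    print(partial_match("please render a harmless scene", "harmless"))    # True (exact)
    print(partial_match("please render a harmles scene", "harmless", 1))  # True (1 char off)
    print(partial_match("a quiet meadow at dawn", "harmless"))            # False
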
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4a13a8fde58e7852b683112be63eaed44e1f143f DELETED
@@ -1,596 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import json
17
- import os
18
- import time
19
- from pathlib import Path
20
- from typing import Any, Dict, List, Optional, Set
21
-
22
- from .misc import misc
23
- import torch
24
- from safetensors.torch import load_file
25
- from torch.nn.modules.module import _IncompatibleKeys
26
-
27
- from .ar_config_base_model import ModelConfig
28
- from .ar_config_base_tokenizer import TokenizerConfig
29
- from .ar_module_mm_projector import MultimodalProjector
30
- from .ar_network_transformer import Transformer
31
- from .ar_network_vit import VisionTransformer, get_vit_config
32
- from .ar_tokenizer_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
33
- from .ar_utils_checkpoint import (
34
- get_partial_state_dict,
35
- process_state_dict,
36
- substrings_to_ignore,
37
- )
38
- from .ar_utils_sampling import decode_n_tokens, decode_one_token, prefill
39
- from .log import log
40
-
41
-
42
- class AutoRegressiveModel(torch.nn.Module):
43
- """
44
-     A class to build and use an AutoRegressiveModel for text generation.
45
-
46
- Methods:
47
- build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.
48
- generate: Generate text sequences based on provided prompts using the language generation model.
49
- """
50
-
51
- def __init__(
52
- self,
53
- model: Transformer = None,
54
- tokenizer: DiscreteMultimodalTokenizer = None,
55
- config: ModelConfig = None,
56
- vision_encoder: VisionTransformer = None,
57
- mm_projector: MultimodalProjector = None,
58
- ):
59
- """
60
- Initialize the AutoRegressiveModel instance with a model and tokenizer.
61
-
62
- Args:
63
- model (Transformer): The Transformer model for text generation.
64
- tokenizer (Tokenizer): The tokenizer for encoding and decoding text.
65
- config (Config): The configuration for the AutoRegressiveModel model.
66
- vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model.
67
- mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model.
68
- """
69
- super().__init__()
70
- self.model = model
71
- self.tokenizer = tokenizer
72
- self.config = config
73
-
74
- self.vision_encoder = vision_encoder
75
- self.mm_projector = mm_projector
76
-
77
- @property
78
- def precision(self):
79
- return self.model.precision
80
-
81
- def get_num_params(
82
- self,
83
- ) -> int:
84
- """
85
- Return the number of parameters in the model.
86
- """
87
- n_params = sum(p.numel() for p in self.parameters())
88
- return n_params
89
-
90
- def load_ar_model(
91
- self,
92
- tokenizer_config,
93
- ):
94
- """
95
- Load the AR model.
96
- """
97
- model_config = self.config
98
- ckpt_path = model_config.ckpt_path
99
- with misc.timer(f"loading checkpoint from {ckpt_path}"):
100
- if ckpt_path.endswith("safetensors"):
101
- # Load with safetensors API
102
- checkpoint = load_file(ckpt_path, device="cpu")
103
- else:
104
- # The pytorch version
105
- checkpoint = torch.load(
106
- ckpt_path,
107
- map_location="cpu",
108
- mmap=True, # load the checkpoint in memory-mapped mode
109
- weights_only=True,
110
- )
111
- llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
112
- orig_precision = torch.get_default_dtype()
113
- precision = getattr(torch, model_config.precision)
114
- torch.set_default_dtype(precision)
115
- log.debug(f"Setting torch default dtype to {precision}")
116
-
117
- model = Transformer(
118
- params=model_config,
119
- tokenizer_config=tokenizer_config,
120
- )
121
- log.debug(
122
- f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}"
123
- )
124
- vocab_size = update_vocab_size(
125
- existing_vocab_size=0,
126
- to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size,
127
- training_type=tokenizer_config.training_type,
128
- add_special_tokens=False,
129
- )
130
- log.debug(
131
- f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}"
132
- )
133
- # Perform vocab expansion
134
- if vocab_size > model.vocab_size:
135
- log.debug(f"Expanding vocab size to {vocab_size}")
136
- # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
137
- expand_output_layer = not (tokenizer_config.training_type == "text_to_video")
138
- model.expand_vocab(
139
- vocab_size,
140
- init_method="gaussian",
141
- expand_output_layer=expand_output_layer,
142
- )
143
- # Remove the "model." prefix in the state_dict
144
- llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
145
- with misc.timer("loading state_dict into model"):
146
- missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
147
- # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
148
- missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
149
- assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
150
-
151
- self.model = model.to(precision).to("cuda")
152
- torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
153
-
154
- def load_tokenizer(self, tokenizer_config):
155
- """
156
- Load the tokenizer.
157
- """
158
- self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
159
-
160
- @staticmethod
161
- def build(
162
- model_config: ModelConfig = ModelConfig(),
163
- tokenizer_config: TokenizerConfig = None,
164
- ) -> "AutoRegressiveModel":
165
- """
166
- Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.
167
-
168
- Args:
169
- model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig().
170
- tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None.
171
172
- Returns:
173
- AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer.
174
-
175
- Raises:
176
- AssertionError: If there are no checkpoint files in the specified directory.
177
-
178
- Note:
179
- This method sets the device to CUDA and loads the pre-trained model and tokenizer.
180
- """
181
- # Initialize model configuration parameters
182
- config_params = {}
183
-
184
- # Load checkpoint and model parameters
185
-
186
- if model_config.ckpt_path is None:
187
- # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir
188
- ckpt_dir = model_config.ckpt_dir
189
-
190
- # We prioritize safetensors version over the pytorch version, since the former is
191
- # much faster for checkpoint loading.
192
- checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
193
- if len(checkpoints) == 0:
194
- checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
195
-
196
- assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
197
- assert (
198
- len(checkpoints) == 1
199
- ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
200
- ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case
201
-
202
- if os.path.exists(Path(ckpt_dir) / "config.json"):
203
- with open(Path(ckpt_dir) / "config.json", "r") as f:
204
- config_params = json.loads(f.read())
205
- else:
206
- log.info(
207
- f"No params.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config."
208
- )
209
-
210
- else:
211
- # If ckpt_path is provided, we load the model from the specified path,
212
- # and use the default model configuration
213
- ckpt_path = model_config.ckpt_path
214
-
215
- for key, value in config_params.items():
216
- if hasattr(model_config, key):
217
- # Override the default model configuration with the parameters from the checkpoint
218
- setattr(model_config, key, value)
219
-
220
- with misc.timer(f"loading checkpoint from {ckpt_path}"):
221
- if ckpt_path.endswith("safetensors"):
222
- # Load with safetensors API
223
- checkpoint = load_file(ckpt_path, device="cpu")
224
- else:
225
- # The pytorch version
226
- checkpoint = torch.load(
227
- ckpt_path,
228
- map_location="cpu",
229
- mmap=True, # load the checkpoint in memory-mapped mode
230
- weights_only=True,
231
- )
232
- llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
233
-
234
- if model_config.vision_encoder is not None:
235
- # Take the LLM weights (starting with "model.") from the VLM checkpoint
236
- llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")
237
- if model_config.vision_encoder is not None:
238
- # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']`
239
-             # and `checkpoint['mm_projector']` hold the vision encoder and projector weights, respectively.
240
- # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights
241
- if "vision_encoder" in checkpoint:
242
- log.debug("Using pretrained vision_encoder")
243
- vit_checkpoint = checkpoint["vision_encoder"]
244
- else:
245
- log.debug("Using fine-tuned vision_encoder")
246
- vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.")
247
- vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.")
248
- if "mm_projector" in checkpoint:
249
- log.debug("Using pretrained mm_projector")
250
- projector_checkpoint = checkpoint["mm_projector"]
251
- else:
252
- log.debug("Using fine-tuned mm_projector")
253
- projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.")
254
- projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.")
255
- assert (
256
- len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0
257
- ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector."
258
-
259
- tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
260
- orig_precision = torch.get_default_dtype()
261
- precision = getattr(torch, model_config.precision)
262
- torch.set_default_dtype(precision)
263
- log.debug(f"Setting torch default dtype to {precision}")
264
-
265
- model = Transformer(
266
- params=model_config,
267
- tokenizer_config=tokenizer_config,
268
- )
269
- model_kwargs = {}
270
-
271
- if model_config.vision_encoder is not None:
272
- assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided."
273
- vit_config = get_vit_config(model_config.vision_encoder)
274
- vision_encoder = VisionTransformer.build(
275
- vit_config,
276
- )
277
-
278
- mm_projector = MultimodalProjector(
279
- mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
280
- )
281
- model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})
282
-
283
- # Perform vocab expansion
284
- if tokenizer.vocab_size > model.vocab_size:
285
- log.debug(f"Expanding vocab size to {tokenizer.vocab_size}")
286
- # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
287
- expand_output_layer = not (tokenizer.training_type == "text_to_video")
288
- model.expand_vocab(
289
- tokenizer.vocab_size,
290
- init_method="gaussian",
291
- expand_output_layer=expand_output_layer,
292
- )
293
-
294
- # Remove the "model." prefix in the state_dict
295
- llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
296
- with misc.timer("loading state_dict into model"):
297
- missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
298
- # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
299
- missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
300
- assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
301
-
302
- if model_config.vision_encoder is not None:
303
- vision_encoder.load_state_dict(vit_checkpoint)
304
- mm_projector.load_state_dict(projector_checkpoint)
305
- if model_config.vision_encoder_in_channels != 3:
306
- vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)
307
-
308
- model = model.to(precision) # ensure model parameters are in the correct precision
309
- log.debug(f"Model config: {model_config}")
310
-
311
- model_class = AutoRegressiveModel
312
-
313
- torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
314
-
315
- return model_class(model, tokenizer, model_config, **model_kwargs)
316
-
317
- @torch.no_grad()
318
- def generate(
319
- self,
320
- prompt_tokens: List[List[int]] | torch.Tensor,
321
- max_gen_len: int,
322
- temperature: float = 1.0,
323
- top_k: Optional[int] = None,
324
- top_p: Optional[float] = None,
325
- num_gen_seq: int = 1,
326
- logprobs: bool = False,
327
- echo: bool = False,
328
- seed: int = None,
329
- context: Optional[torch.Tensor] = None,
330
- context_mask: Optional[torch.Tensor] = None,
331
- compile_sampling: bool = True,
332
- compile_prefill: bool = False,
333
- verbose: bool = True,
334
- stop_tokens: Optional[Set[int]] = None,
335
- images: Optional[torch.Tensor] = None,
336
- ):
337
- """
338
- Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast).
339
-
340
- Args:
341
- prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len).
342
- max_gen_len (int): Maximum length of the generated text sequence.
343
-             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
344
- top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
345
- top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
346
- num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
347
- echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
348
349
- seed (int, optional): Random seed for reproducibility. Defaults to None.
350
- compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
351
- compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
352
-             verbose (bool, optional): Flag indicating whether to print timing information. Defaults to True.
353
- """
354
-         assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified."
355
- if temperature == 0:
356
- top_p, top_k = None, None
357
- log.debug("Setting top_p and top_k to None because temperature is 0")
358
- if top_p is not None:
359
- log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}")
360
- elif top_k is not None:
361
- log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}")
362
- else:
363
- log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")
364
-
365
- orig_precision = torch.get_default_dtype()
366
- torch.set_default_dtype(self.precision)
367
-
368
- torch._inductor.config.coordinate_descent_tuning = True
369
- torch._inductor.config.triton.unique_kernel_names = True
370
- # Experimental features to reduce compilation times, will be on by default in future
371
- torch._inductor.config.fx_graph_cache = True
372
-
373
- if seed is not None:
374
- misc.set_random_seed(seed)
375
-
376
- assert not logprobs, "logprobs are not supported for fast_generate yet"
377
- # Check whether the prefill and decode_one_token functions have been compiled yet. If not, compile them based on the flags.
378
- if compile_sampling and not getattr(self, "inference_decode_compiled", False):
379
- self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
380
- self.inference_decode_compiled = True
381
- log.info("Compiled AR sampling function. Note: the first run will be slower due to compilation")
382
- if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
383
- self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
384
- self.inference_prefill_compiled = True
385
- log.info("Compiled prefill function. Note: the first run will be slower due to compilation")
386
-
387
- if not hasattr(self, "decode_one_token"):
388
- self.decode_one_token = decode_one_token
389
- if not hasattr(self, "prefill"):
390
- self.prefill = prefill
391
-
392
- # Initialization and Assertions
393
- if isinstance(self.model.params, list):
394
- # During training, model.params is a list
395
- log.debug(
396
- f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
397
- )
398
- params = self.config
399
- else:
400
- params = self.model.params
401
- if isinstance(prompt_tokens, list):
402
- prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
403
- if prompt_tokens.ndim == 1:
404
- prompt_tokens = prompt_tokens.view(1, -1)
405
- else:
406
- assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
407
- batch_size, prompt_len = prompt_tokens.shape
408
- total_len = min(params.max_seq_len, max_gen_len + prompt_len)
409
- if max_gen_len + prompt_len > params.max_seq_len:
410
- log.warning(
411
- f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
412
- )
413
- max_gen_len = params.max_seq_len - prompt_len
414
-
415
- if context_mask is not None:
416
- context_mask = context_mask.to(dtype=torch.bool)
417
- if context_mask.ndim == 2:
418
- assert (
419
- context_mask.shape[0] == batch_size
420
- ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
421
- # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len]
422
- context_mask = context_mask.view(batch_size, 1, 1, -1)
423
-
424
- if num_gen_seq > 1:
425
- assert (
426
- batch_size == 1
427
- ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
428
- log.debug(f"Generating {num_gen_seq} sequences with the same prompt")
429
- assert (
430
- num_gen_seq <= params.max_batch_size
431
- ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
432
- # repeat the prompt tokens for num_gen_seq times
433
- prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
434
- assert prompt_tokens.shape == (
435
- num_gen_seq,
436
- prompt_len,
437
- ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
438
- batch_size = len(prompt_tokens)
439
-
440
- # create an empty tensor of the expected final shape and fill in the current tokens
441
- empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
442
- empty[:, :prompt_len] = prompt_tokens
443
- seq = empty
444
- input_pos = torch.arange(0, prompt_len, device="cuda")
445
-
446
- if verbose:
447
- prefill_start = time.time()
448
-
449
- if images is not None:
450
- images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16)
451
- prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images)
452
- else:
453
- prompt_token_embeddings = None
454
-
455
- if context is not None:
456
- context = context.to(device=prompt_tokens.device, dtype=self.precision)
457
-
458
- # Prefill stage
459
- next_token = self.prefill(
460
- self.model,
461
- input_pos=input_pos,
462
- tokens=prompt_tokens if prompt_token_embeddings is None else None,
463
- token_embeddings=prompt_token_embeddings,
464
- temperature=temperature,
465
- top_k=top_k,
466
- top_p=top_p,
467
- context=context,
468
- context_mask=context_mask,
469
- )
470
- if verbose:
471
- prefill_time = time.time() - prefill_start
472
-
473
- seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
474
- input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda")
475
- stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens
476
- stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda")
477
-
478
- if verbose:
479
- decode_start = time.time()
480
- # Decode stage
481
- generated_tokens = decode_n_tokens(
482
- self.model,
483
- next_token.view(batch_size, -1),
484
- input_pos,
485
- max_gen_len - 1,
486
- temperature=temperature,
487
- top_k=top_k,
488
- top_p=top_p,
489
- stop_tokens=stop_tokens,
490
- decode_one_token_function=self.decode_one_token,
491
- context=context,
492
- context_mask=context_mask,
493
- )
494
- gen_len = len(generated_tokens)
495
- if verbose:
496
- decode_time = time.time() - decode_start
497
- prefill_throughput = prompt_len / prefill_time
498
- decode_throughput = gen_len / decode_time
499
- log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s")
500
- log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")
501
-
502
- generated_tokens = torch.cat(generated_tokens, dim=1)
503
-
504
- log.debug(f"generated_tokens: {generated_tokens.shape}")
505
- seq = seq[:, : prompt_len + 1 + gen_len]
506
- seq[:, prompt_len + 1 :] = generated_tokens
507
- if not echo:
508
- seq = seq[:, prompt_len:]
509
-
510
- torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
511
-
512
- return seq, None
513
-
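Since the docstring above only hints at what top-p does, here is a minimal, self-contained sketch of nucleus sampling on a batch of logits. It illustrates the technique referenced by the `top_p` argument; it is not the repository's own sampling helper (that lives in its gpt-fast-style code path).

import torch

def sample_top_p(logits: torch.Tensor, top_p: float, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id from the smallest set of tokens whose cumulative probability exceeds top_p."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the cumulative mass (excluding the current token) already exceeds top_p.
    mask = cumulative - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)

# Example: one (batch=1, vocab=8) logits row.
logits = torch.randn(1, 8)
print(sample_top_p(logits, top_p=0.8).item())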
514
- def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
515
- """
516
- Embed vision and language features into a combined representation.
517
-
518
- Args:
519
- input_ids (torch.Tensor): Input token IDs.
520
- images (torch.Tensor): Input images.
521
-
522
- Returns:
523
- torch.Tensor: Combined vision-language features.
524
-
525
- Raises:
526
- AssertionError: If vision encoder or mm projector is not initialized,
527
- or if dimensions mismatch.
528
- """
529
- # Ensure vision encoder and mm projector are initialized
530
- assert self.vision_encoder is not None
531
- assert self.mm_projector is not None
532
-
533
- # Get image token ID and validate it
534
- image_token_id = self.vision_encoder.image_token_id
535
- assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"
536
-
537
- # Identify text and image locations in the input
538
- text_locations = input_ids != image_token_id
539
- image_locations = input_ids == image_token_id
540
-
541
- # Process text features
542
- text_features = self.model.tok_embeddings(input_ids[text_locations])
543
-
544
- # Process image features
545
- images = images.to(device=text_features.device, dtype=text_features.dtype)
546
- vit_outputs = self.vision_encoder(images)
547
- image_features = self.mm_projector(vit_outputs)
548
-
549
- # Get dimensions
550
- B, seq_len = input_ids.shape
551
- N_total = B * seq_len
552
- N_txt, D_txt = text_features.shape
553
- N_img, N_patch, D_img = image_features.shape
554
-
555
- # Reshape image features
556
- image_features = image_features.reshape(N_img * N_patch, D_img)
557
-
558
- # Validate dimensions
559
- assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
560
- assert (
561
- N_total == N_txt + N_img * N_patch
562
- ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}"
563
-
564
- # Combine text and image features
565
- combined_features = torch.empty(
566
- (B, seq_len, D_txt),
567
- dtype=text_features.dtype,
568
- device=text_features.device,
569
- )
570
- combined_features[text_locations, :] = text_features
571
- combined_features[image_locations, :] = image_features
572
-
573
- return combined_features
574
-
575
- def state_dict(self, *args, **kwargs):
576
- """
577
- Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
578
- """
579
- state_dict = super().state_dict(*args, **kwargs)
580
- return process_state_dict(state_dict)
581
-
582
- def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
583
- """
584
- Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by
585
- TransformerEngine for FP8).
586
- """
587
- state_dict = process_state_dict(state_dict)
588
- missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
589
- actual_missing_keys = []
590
- for key in missing_keys:
591
- if not any(substring in key for substring in substrings_to_ignore):
592
- actual_missing_keys.append(key)
593
- if strict:
594
- if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
595
- raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
596
- return _IncompatibleKeys(actual_missing_keys, unexpected_keys)
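For context, the "_extra_state" filtering referenced in state_dict()/load_state_dict() above typically amounts to dropping TransformerEngine FP8 bookkeeping keys and optionally stripping a checkpoint prefix (the call at the top of this diff passes prefix_to_remove="model."). The sketch below is a plausible shape for such a helper under those assumptions, not the repository's actual process_state_dict implementation.

from typing import Any, Dict

substrings_to_ignore = ["_extra_state"]  # assumption: only the TransformerEngine FP8 keys are ignored

def process_state_dict(state_dict: Dict[str, Any], prefix_to_remove: str = None) -> Dict[str, Any]:
    """Drop ignorable keys and optionally strip a prefix from the remaining key names."""
    processed = {}
    for key, value in state_dict.items():
        if any(substring in key for substring in substrings_to_ignore):
            continue  # e.g. TransformerEngine "_extra_state" bookkeeping entries
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove):]
        processed[key] = value
    return processed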
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4c860c42a1c3d8adc417e9593892491d0803fe51 DELETED
@@ -1,113 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import collections.abc as abc
17
- import dataclasses
18
- import logging
19
- from typing import Any
20
-
21
- import attrs
22
-
23
- from .lazy_registry import _convert_target_to_string, locate
24
-
25
- __all__ = ["dump_dataclass", "instantiate"]
26
-
27
-
28
- def is_dataclass_or_attrs(target):
29
- return dataclasses.is_dataclass(target) or attrs.has(target)
30
-
31
-
32
- def dump_dataclass(obj: Any):
33
- """
34
- Dump a dataclass recursively into a dict that can be later instantiated.
35
-
36
- Args:
37
- obj: a dataclass object
38
-
39
- Returns:
40
- dict
41
- """
42
- assert dataclasses.is_dataclass(obj) and not isinstance(
43
- obj, type
44
- ), "dump_dataclass() requires an instance of a dataclass."
45
- ret = {"_target_": _convert_target_to_string(type(obj))}
46
- for f in dataclasses.fields(obj):
47
- v = getattr(obj, f.name)
48
- if dataclasses.is_dataclass(v):
49
- v = dump_dataclass(v)
50
- if isinstance(v, (list, tuple)):
51
- v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
52
- ret[f.name] = v
53
- return ret
54
-
55
-
56
- def instantiate(cfg, *args, **kwargs):
57
- """
58
- Recursively instantiate objects defined in dictionaries by
59
- "_target_" and arguments.
60
-
61
- Args:
62
- cfg: a dict-like object with "_target_" that defines the caller, and
63
- other keys that define the arguments
64
- args: Optional positional parameters pass-through.
65
- kwargs: Optional named parameters pass-through.
66
-
67
- Returns:
68
- object instantiated by cfg
69
- """
70
- from omegaconf import DictConfig, ListConfig, OmegaConf
71
-
72
- if isinstance(cfg, ListConfig):
73
- lst = [instantiate(x) for x in cfg]
74
- return ListConfig(lst, flags={"allow_objects": True})
75
- if isinstance(cfg, list):
76
- # Specialize for list, because many classes take
77
- # list[objects] as arguments, such as ResNet, DatasetMapper
78
- return [instantiate(x) for x in cfg]
79
-
80
- # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
81
- # instantiate it to the actual dataclass.
82
- if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
83
- return OmegaConf.to_object(cfg)
84
-
85
- if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
86
- # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
87
- # but faster: https://github.com/facebookresearch/hydra/issues/1200
88
- cfg = {k: instantiate(v) for k, v in cfg.items()}
89
- cls = cfg.pop("_target_")
90
- cls = instantiate(cls)
91
-
92
- if isinstance(cls, str):
93
- cls_name = cls
94
- cls = locate(cls_name)
95
- assert cls is not None, cls_name
96
- else:
97
- try:
98
- cls_name = cls.__module__ + "." + cls.__qualname__
99
- except Exception:
100
- # target could be anything, so the above could fail
101
- cls_name = str(cls)
102
- assert callable(cls), f"_target_ {cls} does not define a callable object"
103
- try:
104
- # override config with kwargs
105
- instantiate_kwargs = {}
106
- instantiate_kwargs.update(cfg)
107
- instantiate_kwargs.update(kwargs)
108
- return cls(*args, **instantiate_kwargs)
109
- except TypeError:
110
- logger = logging.getLogger(__name__)
111
- logger.error(f"Error when instantiating {cls_name}!")
112
- raise
113
- return cfg  # return as-is if we don't know what to do with it
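As a small usage illustration of the instantiate() helper above (assuming omegaconf is installed, since the function imports it unconditionally, and assuming locate() resolves dotted standard-library paths, as its name suggests; the target here is an arbitrary callable):

# "_target_" names any locatable callable; the remaining keys become keyword arguments.
cfg = {"_target_": "datetime.timedelta", "days": 2, "hours": 3}
delta = instantiate(cfg)
print(delta.total_seconds())  # 183600.0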
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/4de12fae686821ebf94aec3420719e6432856cf4 DELETED
@@ -1,421 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import copy
17
- from typing import Callable, List, Optional
18
-
19
- from .ar_config_base_model import ModelConfig
20
- from .ar_config_base_tokenizer import (
21
- TextTokenizerConfig,
22
- TokenizerConfig,
23
- VideoTokenizerConfig,
24
- create_discrete_video_fsq_tokenizer_state_dict_config,
25
- )
26
- from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
27
- from .ar_tokenizer_text_tokenizer import TextTokenizer
28
- from .log import log
29
- from .lazy_config_init import LazyCall as L
30
-
31
- # Common architecture specifications
32
- BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
33
- COSMOS_ARCHITECTURES = {
34
- "4b": {
35
- "n_layers": 16,
36
- "dim": 4096,
37
- "n_heads": 32,
38
- },
39
- "12b": {
40
- "n_layers": 40,
41
- "dim": 5120,
42
- "n_heads": 32,
43
- "head_dim": 128,
44
- },
45
- }
46
-
47
- COSMOS_YARN_CONFIG = {
48
- "original_latent_shape": [3, 40, 64],
49
- "apply_yarn": True,
50
- "yarn_beta_fast": 4,
51
- "yarn_beta_slow": 1,
52
- "yarn_scale": 2,
53
- }
54
-
55
- # Llama3 architecture specifications for different model sizes
56
- LLAMA3_ARCHITECTURES = {
57
- "8b": {
58
- "n_layers": 32,
59
- "dim": 4096,
60
- "n_heads": 32,
61
- "ffn_hidden_size": 14336,
62
- },
63
- }
64
- # Llama3.1 uses YaRN for long context support (context of 128k tokens)
65
- LLAMA_YARN_CONFIG = {
66
- "apply_yarn": True,
67
- "yarn_scale": 8,
68
- "yarn_beta_fast": 4,
69
- "yarn_beta_slow": 1,
70
- }
71
-
72
- # Mistral architecture specifications for different model sizes
73
- MISTRAL_ARCHITECTURES = {
74
- "12b": {
75
- "n_layers": 40,
76
- "dim": 5120,
77
- "n_heads": 32,
78
- "ffn_hidden_size": 14336,
79
- "head_dim": 128,
80
- },
81
- }
82
-
83
- PIXTRAL_VISION_ARCHITECTURES = {
84
- "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
85
- }
86
-
87
-
88
- def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
89
- """
90
- Get the model architecture specifications for the given model size, model family and pretrained status.
91
-
92
- Args:
93
- model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
94
- model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
95
- pretrained (bool): Whether to load pretrained weights.
96
-
97
- Returns:
98
- dict: A dictionary containing the model architecture specifications.
99
- """
100
- arch_specs = copy.deepcopy(BASE_CONFIG)
101
- model_size = model_size.lower()
102
- if model_family.startswith("cosmos"):
103
- arch_specs.update(COSMOS_ARCHITECTURES[model_size])
104
- elif model_family.startswith("llama"):
105
- arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
106
- elif model_family in ["mistral", "pixtral"]:
107
- arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
108
- if model_family == "pixtral":
109
- arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
110
- else:
111
- raise ValueError(f"Model family {model_family} is not supported.")
112
-
113
- if pretrained:
114
- if model_family == "cosmos":
115
- if model_size == "12b":
116
- arch_specs.update(COSMOS_YARN_CONFIG)
117
- log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
118
- else:
119
- pass
120
- elif model_family in ["llama", "llama3"]:
121
- pretrained_specs = {
122
- "rope_theta": 500000,
123
- "max_seq_len": 8192,
124
- "vocab_size": 128256,
125
- }
126
- arch_specs.update(pretrained_specs)
127
- elif model_family == "llama3.1":
128
- pretrained_specs = {
129
- "rope_theta": 500000,
130
- "max_seq_len": 131072,
131
- "original_seq_len": 8192,
132
- "vocab_size": 128256,
133
- **LLAMA_YARN_CONFIG,
134
- }
135
- arch_specs.update(pretrained_specs)
136
- elif model_family == "mistral":
137
- assert model_size == "12b", "We only support Mistral-Nemo-12B model."
138
- pretrained_specs = {
139
- "rope_theta": 1000000,
140
- "max_seq_len": 128000,
141
- "vocab_size": 131072,
142
- }
143
- arch_specs.update(pretrained_specs)
144
- elif model_family == "pixtral":
145
- assert model_size == "12b", "We only support Pixtral 12B model."
146
- pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
147
- arch_specs.update(pretrained_specs)
148
- else:
149
- raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
150
-
151
- return arch_specs
152
-
153
-
154
- def create_text_model_config(
155
- model_ckpt_path: str,
156
- tokenizer_path: str,
157
- model_family: str = "mistral",
158
- model_size: str = "12b",
159
- is_instruct_model: bool = True,
160
- max_seq_len: int = None,
161
- max_batch_size: int = 1,
162
- rope_dim: str = "1D",
163
- add_special_tokens: bool = True,
164
- pytorch_rope_version: str = None,
165
- ) -> dict:
166
- """Create a text model for training or inference.
167
- Args:
168
- model_ckpt_path (str): Path to the model checkpoint.
169
- tokenizer_path (str): Path to the tokenizer folder.
170
- model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
171
- model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
172
- is_instruct_model (bool): Whether the model is an instruct model.
173
- pytorch_rope_version (str): Optional override for the PyTorch RoPE implementation version.
174
- max_seq_len (int): Maximum sequence length.
175
- max_batch_size (int): Maximum batch size.
176
- rope_dim (str): RoPE dimension. Choices: "1D", "3D".
177
- add_special_tokens (bool): Whether to add special tokens.
178
- Returns:
179
- dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
180
- """
181
- # Model size specific parameters
182
- model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
183
- if max_seq_len is not None:
184
- # Override the max_seq_len if provided
185
- model_arch_specs["max_seq_len"] = max_seq_len
186
- if pytorch_rope_version is not None:
187
- model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
188
- model_config = ModelConfig(
189
- max_batch_size=max_batch_size,
190
- precision="bfloat16",
191
- ckpt_path=model_ckpt_path,
192
- use_qk_normalization=False,
193
- rope_dim=rope_dim,
194
- **model_arch_specs,
195
- )
196
-
197
- tokenizer_config = TokenizerConfig(
198
- text_tokenizer=TextTokenizerConfig(
199
- config=L(TextTokenizer)(
200
- model_family=model_family,
201
- is_instruct_model=is_instruct_model,
202
- local_path=tokenizer_path,
203
- ),
204
- data_key="text",
205
- tokenizer_offset=model_config.vocab_size,
206
- tokenize_here=False,
207
- vocab_size=model_config.vocab_size,
208
- ),
209
- seq_len=model_config.max_seq_len,
210
- training_type="text_only",
211
- add_special_tokens=add_special_tokens,
212
- )
213
- return model_config, tokenizer_config
214
-
215
-
216
- def create_vision_language_model_config(
217
- model_ckpt_path: str,
218
- tokenizer_ckpt_path: str,
219
- model_family: str = "pixtral",
220
- model_size: str = "12b",
221
- is_instruct_model: bool = True,
222
- max_batch_size: int = 1,
223
- rope_dim: str = "1D",
224
- add_special_tokens: bool = True,
225
- max_seq_len: int = None,
226
- vision_encoder_in_channels: int = 3,
227
- fuse_qkv: bool = False,
228
- pytorch_rope_version: str = None,
229
- ) -> dict:
230
- """Create a vision-language model for training or inference.
231
- Args:
232
- model_ckpt_path (str): Path to the model checkpoint.
233
- tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
234
- model_family (str): Model family. Choices: "pixtral".
235
- model_size (str): Model size. Choices: "12b".
236
- is_instruct_model (bool): Whether the model is an instruct model.
237
- rope_dim (str): RoPE dimension. Choices: "1D".
238
- add_special_tokens (bool): Whether to add special tokens.
239
- max_seq_len (int): Maximum sequence length.
240
- vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Defaults to 3; it can be set to an integer larger than 3, e.g. 4 if you have 4-channel images where the last channel is a binary mask.
241
- fuse_qkv (bool): Whether to fuse the QKV linear layers.
242
- Returns:
243
- dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
244
- """
245
- # Model size specific parameters
246
- model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
247
- if max_seq_len is not None:
248
- # Override the max_seq_len if provided
249
- model_arch_specs["max_seq_len"] = max_seq_len
250
- if pytorch_rope_version is not None:
251
- model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
252
-
253
- model_config = ModelConfig(
254
- max_batch_size=max_batch_size,
255
- precision="bfloat16",
256
- ckpt_path=model_ckpt_path,
257
- use_qk_normalization=False,
258
- rope_dim=rope_dim,
259
- vision_encoder_in_channels=vision_encoder_in_channels,
260
- fuse_qkv=fuse_qkv,
261
- **model_arch_specs,
262
- )
263
- # Vision-language tokenizer
264
- tokenizer_config = TokenizerConfig(
265
- text_tokenizer=TextTokenizerConfig(
266
- config=L(ImageTextTokenizer)(
267
- model_family=model_family,
268
- is_instruct_model=is_instruct_model,
269
- image_processor_path=tokenizer_ckpt_path,
270
- tokenizer_path=tokenizer_ckpt_path,
271
- ),
272
- data_key="image_text_interleaved",
273
- tokenizer_offset=model_config.vocab_size,
274
- tokenize_here=False,
275
- vocab_size=model_config.vocab_size,
276
- ),
277
- seq_len=model_config.max_seq_len,
278
- training_type="image_text_interleaved",
279
- add_special_tokens=add_special_tokens,
280
- )
281
- return model_config, tokenizer_config
282
-
283
-
284
- def create_video2world_model_config(
285
- model_ckpt_path: str,
286
- tokenizer_ckpt_path: str,
287
- model_family: str = "cosmos",
288
- model_size: str = "4b",
289
- pixel_chunk_duration: int = 9,
290
- num_video_frames: int = 36,
291
- compression_ratio: List[int] = [8, 16, 16],
292
- original_seq_len: int = 8192,
293
- num_condition_latents_t: int = 1,
294
- num_tokens_to_ignore: int = -1,
295
- batch_size: int = 2,
296
- video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
297
- rope_dim: str = "3D",
298
- add_special_tokens: bool = True,
299
- video_height: int = 384,
300
- video_width: int = 640,
301
- use_qk_normalization: bool = True,
302
- insert_cross_attn: bool = False,
303
- insert_cross_attn_every_k_layers: int = 1,
304
- context_dim: int = 1024,
305
- training_type: str = "video_to_video",
306
- pad_to_multiple_of: Optional[int] = 64,
307
- vocab_size: int = 64000,
308
- apply_abs_pos_emb: bool = False,
309
- ) -> dict:
310
- """Create a video-to-world model config.
311
- Args:
312
- model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
313
- model_size (str): Model size. Choices: "1b", "8b", "3b".
314
- pixel_chunk_duration (int): Number of frames in each chunk.
315
- num_video_frames (int): Number of video frames.
316
- compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
317
- original_seq_len (int): Original sequence length.
318
- apply_yarn (bool): Whether to apply YaRN for long context scaling.
319
- yarn_beta_fast (Optional[int]): Fast beta for YaRN.
320
- yarn_beta_slow (Optional[int]): Slow beta for YaRN.
321
- yarn_scale (Optional[int]): Scale factor for ctx extension.
322
- use_qk_normalization (bool): Whether to use Query-Key normalization.
323
- training_type (str): Type of training task.
324
- batch_size (int): Batch size.
325
- video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
326
- video_tokenizer_version (str): Version of the video tokenizer.
327
- num_condition_latents_t (int): Number of conditioning latent channels
328
- num_tokens_to_ignore (int): Number of tokens to ignore in the loss. When non-negative, this takes precedence over num_condition_latents_t.
329
- video_height (int): Height of the video frame. Defaults to 384.
330
- video_width (int): Width of the video frame. Defaults to 640.
331
- rope_dim (str): RoPE dimension. Choices: "1D", "3D".
332
- add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
333
- pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
334
- vocab_size (int): Vocabulary size.
335
- apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
336
- Returns:
337
- dict: A dictionary containing the model configuration representing the model object, can be instantiated.
338
- """
339
- assert (
340
- pixel_chunk_duration % compression_ratio[0] == 1
341
- ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
342
- latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
343
- latent_height = video_height // compression_ratio[1]
344
- latent_width = video_width // compression_ratio[2]
345
- # Do some math to compute the video latent shape and sequence length
346
- assert (
347
- num_video_frames % pixel_chunk_duration == 0
348
- ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
349
- video_latent_shape = [
350
- num_video_frames // pixel_chunk_duration * latent_chunk_duration,
351
- latent_height,
352
- latent_width,
353
- ]
354
- # product of video_latent_shape
355
- num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
356
- if add_special_tokens:
357
- seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
358
- seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
359
- # for text to video, we need to add <bov> token to indicate the start of the video
360
- elif training_type == "text_to_video":
361
- seq_len = num_token_video_latent + 1
362
- else:
363
- seq_len = num_token_video_latent
364
-
365
- if seq_len % pad_to_multiple_of != 0:
366
- # Round up to the nearest multiple of pad_to_multiple_of
367
- seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
368
-
369
- # Model size specific parameters
370
- model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
371
-
372
- # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
373
- # If num_tokens_to_ignore is specified, use it.
374
- # Else compute it from num_condition_latents_t
375
- if num_tokens_to_ignore < 0:
376
- num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
377
- if not add_special_tokens and num_condition_latents_t > 0:
378
- # If there are no special tokens (bov), do a -1 so that you can compute the loss
379
- # from the first token of the next chunk
380
- num_tokens_to_ignore -= 1
381
-
382
- model_config = ModelConfig(
383
- video_height=video_height,
384
- video_width=video_width,
385
- max_seq_len=seq_len,
386
- max_batch_size=batch_size,
387
- precision="bfloat16",
388
- ckpt_path=model_ckpt_path,
389
- use_qk_normalization=use_qk_normalization,
390
- vocab_size=64000,
391
- original_seq_len=original_seq_len,
392
- video_latent_shape=video_latent_shape,
393
- num_video_frames=num_video_frames,
394
- rope_dim=rope_dim,
395
- pad_to_multiple_of=pad_to_multiple_of,
396
- insert_cross_attn=insert_cross_attn,
397
- insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
398
- context_dim=context_dim,
399
- apply_abs_pos_emb=apply_abs_pos_emb,
400
- **model_arch_specs,
401
- )
402
-
403
- video_tokenizer_config = video_tokenizer_config_creator(
404
- tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
405
- )
406
- tokenizer_config = TokenizerConfig(
407
- text_tokenizer=None,
408
- video_tokenizer=VideoTokenizerConfig(
409
- config=video_tokenizer_config,
410
- data_key="video",
411
- tokenizer_offset=0, # There are no text embeddings in the model. Note this only applies when the model is trained from scratch; if a text-pretrained model is used, the offset would be the vocab size of the text tokenizer.
412
- tokenize_here=True,
413
- max_seq_len=num_token_video_latent,
414
- vocab_size=vocab_size,
415
- ),
416
- seq_len=seq_len,
417
- training_type=training_type,
418
- add_special_tokens=add_special_tokens,
419
- pad_to_multiple_of=pad_to_multiple_of,
420
- )
421
- return model_config, tokenizer_config
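To make the latent-shape and sequence-length arithmetic above concrete, here is the default configuration worked through by hand (a sketch that mirrors the code above, not an independent source of truth):

# Defaults: pixel_chunk_duration=9, num_video_frames=36, compression_ratio=[8, 16, 16],
# video_height=384, video_width=640, add_special_tokens=True, pad_to_multiple_of=64.
latent_chunk_duration = (9 - 1) // 8 + 1        # 2 latent frames per 9-frame pixel chunk
latent_height = 384 // 16                       # 24
latent_width = 640 // 16                        # 40
video_latent_shape = [36 // 9 * 2, 24, 40]      # [8, 24, 40]
num_token_video_latent = 8 * 24 * 40            # 7680 video tokens
seq_len = ((7680 + 3) + 63) // 64 * 64          # 7744 after adding 3 special tokens and rounding up to 64
print(video_latent_shape, num_token_video_latent, seq_len)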
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/53dea6ed871052e987bf5094f869778412202323 DELETED
@@ -1,360 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import json
18
- import math
19
- import os
20
- from pathlib import Path
21
- from typing import List
22
-
23
- import numpy as np
24
- import torch
25
- import torchvision
26
- from PIL import Image
27
-
28
- from .ar_config_inference_inference_config import SamplingConfig
29
- from .log import log
30
-
31
- _IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"]
32
- _VIDEO_EXTENSIONS = [".mp4"]
33
- _SUPPORTED_CONTEXT_LEN = [1, 9] # Input frames
34
- NUM_TOTAL_FRAMES = 33
35
-
36
-
37
- def add_common_arguments(parser):
38
- """Add common command line arguments.
39
-
40
- Args:
41
- parser (ArgumentParser): Argument parser to add arguments to
42
- """
43
- parser.add_argument(
44
- "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
45
- )
46
- parser.add_argument(
47
- "--video_save_name",
48
- type=str,
49
- default="output",
50
- help="Output filename for generating a single video",
51
- )
52
- parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos")
53
- parser.add_argument(
54
- "--input_image_or_video_path",
55
- type=str,
56
- help="Input path for input image or video",
57
- )
58
- parser.add_argument(
59
- "--batch_input_path",
60
- type=str,
61
- help="Input folder containing all input images or videos",
62
- )
63
- parser.add_argument(
64
- "--num_input_frames",
65
- type=int,
66
- default=9,
67
- help="Number of input frames for world generation",
68
- choices=_SUPPORTED_CONTEXT_LEN,
69
- )
70
- parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
71
- parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling")
72
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
73
- parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")
74
- parser.add_argument(
75
- "--offload_guardrail_models",
76
- action="store_true",
77
- help="Offload guardrail models after inference",
78
- )
79
- parser.add_argument(
80
- "--offload_diffusion_decoder",
81
- action="store_true",
82
- help="Offload diffusion decoder after inference",
83
- )
84
- parser.add_argument(
85
- "--offload_ar_model",
86
- action="store_true",
87
- help="Offload AR model after inference",
88
- )
89
- parser.add_argument(
90
- "--offload_tokenizer",
91
- action="store_true",
92
- help="Offload discrete tokenizer model after inference",
93
- )
94
-
95
-
96
- def validate_args(args: argparse.Namespace, inference_type: str):
97
- """Validate command line arguments for base and video2world generation."""
98
- assert inference_type in [
99
- "base",
100
- "video2world",
101
- ], "Invalid inference_type, must be 'base' or 'video2world'"
102
- if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1:
103
- args.num_input_frames = 1
104
- log.info(f"Set num_input_frames to 1 for {args.input_type} input")
105
-
106
- if args.num_input_frames == 1:
107
- if "4B" in args.ar_model_dir:
108
- log.warning(
109
- "The failure rate for 4B model with image input is ~15%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details."
110
- )
111
- elif "5B" in args.ar_model_dir:
112
- log.warning(
113
- "The failure rate for 5B model with image input is ~7%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details."
114
- )
115
-
116
- # Validate prompt/image/video args for single or batch generation
117
- assert (
118
- args.input_image_or_video_path or args.batch_input_path
119
- ), "--input_image_or_video_path or --batch_input_path must be provided."
120
- if inference_type == "video2world" and (not args.batch_input_path):
121
- assert args.prompt, "--prompt is required for single video generation."
122
- args.data_resolution = [640, 1024]
123
-
124
- # Validate number of GPUs
125
- num_gpus = int(os.getenv("WORLD_SIZE", 1))
126
- assert num_gpus <= 1, "We support only single GPU inference for now"
127
-
128
- # Create output folder
129
- Path(args.video_save_folder).mkdir(parents=True, exist_ok=True)
130
-
131
- sampling_config = SamplingConfig(
132
- echo=True,
133
- temperature=args.temperature,
134
- top_p=args.top_p,
135
- compile_sampling=True,
136
- )
137
- return sampling_config
138
-
139
-
140
- def resize_input(video: torch.Tensor, resolution: list[int]):
141
- r"""
142
- Function to perform aspect ratio preserving resizing and center cropping.
143
- This is needed to make the video into target resolution.
144
- Args:
145
- video (torch.Tensor): Input video tensor
146
- resolution (list[int]): Data resolution
147
- Returns:
148
- Cropped video
149
- """
150
-
151
- orig_h, orig_w = video.shape[2], video.shape[3]
152
- target_h, target_w = resolution
153
-
154
- scaling_ratio = max((target_w / orig_w), (target_h / orig_h))
155
- resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
156
- video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
157
- video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
158
- return video_cropped
159
-
160
-
161
- def load_image_from_list(flist, data_resolution: List[int]) -> dict:
162
- """
163
- Function to load images from a list of image paths.
164
- Args:
165
- flist (List[str]): List of image paths
166
- data_resolution (List[int]): Data resolution
167
- Returns:
168
- Dict containing input images
169
- """
170
- all_videos = dict()
171
- for img_path in flist:
172
- ext = os.path.splitext(img_path)[1]
173
- if ext in _IMAGE_EXTENSIONS:
174
- # Read the image
175
- img = Image.open(img_path)
176
-
177
- # Convert to tensor
178
- img = torchvision.transforms.functional.to_tensor(img)
179
- static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1)
180
- static_vid = static_vid * 2 - 1
181
-
182
- log.debug(
183
- f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
184
- )
185
- static_vid = resize_input(static_vid, data_resolution)
186
- fname = os.path.basename(img_path)
187
- all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0)
188
-
189
- return all_videos
190
-
191
-
192
- def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict:
193
- """
194
- Function to read input images from a JSONL file.
195
-
196
- Args:
197
- batch_input_path (str): Path to JSONL file containing visual input paths
198
- data_resolution (list[int]): Data resolution
199
-
200
- Returns:
201
- Dict containing input images
202
- """
203
- # Read visual inputs from JSONL
204
- flist = []
205
- with open(batch_input_path, "r") as f:
206
- for line in f:
207
- data = json.loads(line.strip())
208
- flist.append(data["visual_input"])
209
-
210
- return load_image_from_list(flist, data_resolution=data_resolution)
211
-
212
-
213
- def read_input_image(input_path: str, data_resolution: List[int]) -> dict:
214
- """
215
- Function to read input image.
216
- Args:
217
- input_path (str): Path to input image
218
- data_resolution (List[int]): Data resolution
219
- Returns:
220
- Dict containing input image
221
- """
222
- flist = [input_path]
223
- return load_image_from_list(flist, data_resolution=data_resolution)
224
-
225
-
226
- def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
227
- r"""
228
- Function to read input videos.
229
- Args:
230
- batch_input_path (str): Path to JSONL file containing visual input paths
231
- data_resolution (list[int]): Data resolution
232
- Returns:
233
- Dict containing input videos
234
- """
235
- # Read visual inputs from JSONL
236
- flist = []
237
- with open(batch_input_path, "r") as f:
238
- for line in f:
239
- data = json.loads(line.strip())
240
- flist.append(data["visual_input"])
241
- return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)
242
-
243
-
244
- def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
245
- """
246
- Function to read input video.
247
- Args:
248
- input_path (str): Path to input video
249
- data_resolution (List[int]): Data resolution
250
- num_input_frames (int): Number of frames in context
251
- Returns:
252
- Dict containing input video
253
- """
254
- flist = [input_path]
255
- return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)
256
-
257
-
258
- def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict:
259
- """
260
- Function to load videos from a list of video paths.
261
- Args:
262
- flist (List[str]): List of video paths
263
- data_resolution (List[int]): Data resolution
264
- num_input_frames (int): Number of frames in context
265
- Returns:
266
- Dict containing input videos
267
- """
268
- all_videos = dict()
269
-
270
- for video_path in flist:
271
- ext = os.path.splitext(video_path)[-1]
272
- if ext in _VIDEO_EXTENSIONS:
273
- video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
274
- video = video.float() / 255.0
275
- video = video * 2 - 1
276
-
277
- # Resize the videos to the required dimension
278
- nframes_in_video = video.shape[0]
279
- if nframes_in_video < num_input_frames:
280
- fname = os.path.basename(video_path)
281
- log.warning(
282
- f"Video {fname} has {nframes_in_video} frames, less than the requried {num_input_frames} frames. Skipping."
283
- )
284
- continue
285
-
286
- video = video[-num_input_frames:, :, :, :]
287
-
288
- # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES)
289
- video = torch.cat(
290
- (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)),
291
- dim=0,
292
- )
293
-
294
- video = video.permute(0, 3, 1, 2)
295
-
296
- log.debug(
297
- f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
298
- )
299
- video = resize_input(video, data_resolution)
300
-
301
- fname = os.path.basename(video_path)
302
- all_videos[fname] = video.transpose(0, 1).unsqueeze(0)
303
-
304
- return all_videos
305
-
306
-
307
- def load_vision_input(
308
- input_type: str,
309
- batch_input_path: str,
310
- input_image_or_video_path: str,
311
- data_resolution: List[int],
312
- num_input_frames: int,
313
- ):
314
- """
315
- Function to load vision input.
316
- Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model.
317
- Args:
318
- input_type (str): Type of input
319
- batch_input_path (str): Folder containing input images or videos
320
- input_image_or_video_path (str): Path to input image or video
321
- data_resolution (List[int]): Data resolution
322
- num_input_frames (int): Number of frames in context
323
- Returns:
324
- Dict containing input videos
325
- """
326
- if batch_input_path:
327
- log.info(f"Reading batch inputs from path: {batch_input_path}")
328
- if input_type == "image" or input_type == "text_and_image":
329
- input_videos = read_input_images(batch_input_path, data_resolution=data_resolution)
330
- elif input_type == "video" or input_type == "text_and_video":
331
- input_videos = read_input_videos(
332
- batch_input_path,
333
- data_resolution=data_resolution,
334
- num_input_frames=num_input_frames,
335
- )
336
- else:
337
- raise ValueError(f"Invalid input type {input_type}")
338
- else:
339
- if input_type == "image" or input_type == "text_and_image":
340
- input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution)
341
- elif input_type == "video" or input_type == "text_and_video":
342
- input_videos = read_input_video(
343
- input_image_or_video_path,
344
- data_resolution=data_resolution,
345
- num_input_frames=num_input_frames,
346
- )
347
- else:
348
- raise ValueError(f"Invalid input type {input_type}")
349
- return input_videos
350
-
351
-
352
- def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]:
353
- """
354
- Function to convert output tensors to numpy format for saving.
355
- Args:
356
- video_batch (List[torch.Tensor]): List of output tensors
357
- Returns:
358
- List of numpy arrays
359
- """
360
- return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch]
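As a quick check of the aspect-ratio-preserving resize_input() defined above (the input shape here is illustrative; torch and torchvision are required):

import torch

# A hypothetical 720p clip in (frames, channels, height, width) layout.
video = torch.rand(33, 3, 720, 1280)
cropped = resize_input(video, [640, 1024])
# The scale factor is max(640/720, 1024/1280) ≈ 0.889, so the clip is first resized to
# roughly (640, 1138) and then center-cropped to the target (640, 1024).
print(cropped.shape)  # torch.Size([33, 3, 640, 1024])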
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/54ff4d48b535d2a1f27bbcc75c20ef16821b11e1 DELETED
@@ -1,341 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from dataclasses import dataclass
17
- from typing import Callable, Dict, Optional, Tuple, Union
18
-
19
- from .misc import misc
20
- import torch
21
- from torch import Tensor
22
-
23
- from .df_conditioner import VideoExtendCondition
24
- from .df_config_base_conditioner import VideoCondBoolConfig
25
- from .df_df_functional_batch_ops import batch_mul
26
- from .df_model_model_t2w import DiffusionT2WModel
27
- from .log import log
28
-
29
-
30
- @dataclass
31
- class VideoDenoisePrediction:
32
- x0: torch.Tensor # clean data prediction
33
- eps: Optional[torch.Tensor] = None # noise prediction
34
- logvar: Optional[torch.Tensor] = None # log variance of noise prediction; can be used as a confidence / uncertainty measure
35
- xt: Optional[torch.Tensor] = None # input to the network, before multiplying by c_in
36
- x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent
37
-
38
-
39
- class DiffusionV2WModel(DiffusionT2WModel):
40
- def __init__(self, config):
41
- super().__init__(config)
42
-
43
- def augment_conditional_latent_frames(
44
- self,
45
- condition: VideoExtendCondition,
46
- cfg_video_cond_bool: VideoCondBoolConfig,
47
- gt_latent: Tensor,
48
- condition_video_augment_sigma_in_inference: float = 0.001,
49
- sigma: Tensor = None,
50
- seed: int = 1,
51
- ) -> Union[VideoExtendCondition, Tensor]:
52
- """Augments the conditional frames with noise during inference.
53
-
54
- Args:
55
- condition (VideoExtendCondition): condition object
56
- condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor.
57
- condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network.
58
- cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
59
- gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
60
- condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
61
- sigma (Tensor): noise level for the generation region
62
- seed (int): random seed for reproducibility
63
- Returns:
64
- VideoExtendCondition: updated condition object
65
- condition_video_augment_sigma: sigma for the condition region, feed to the network
66
- augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W
67
-
68
- """
69
-
70
- # Inference only, use fixed sigma for the condition region
71
- assert (
72
- condition_video_augment_sigma_in_inference is not None
73
- ), "condition_video_augment_sigma_in_inference should be provided"
74
- augment_sigma = condition_video_augment_sigma_in_inference
75
-
76
- if augment_sigma >= sigma.flatten()[0]:
77
- # This is an inference trick! If the sampling sigma is smaller than the augment sigma, we start denoising the condition region together with the generation region.
78
- # This is achieved by setting all region as `generation`, i.e. value=0
79
- log.debug("augment_sigma larger than sigma or other frame, remove condition")
80
- condition.condition_video_indicator = condition.condition_video_indicator * 0
81
-
82
- augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs)
83
-
84
- # Now apply the augment_sigma to the gt_latent
85
-
86
- noise = misc.arch_invariant_rand(
87
- gt_latent.shape,
88
- torch.float32,
89
- self.tensor_kwargs["device"],
90
- seed,
91
- )
92
-
93
- augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None]
94
-
95
- _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma)
96
-
97
- # Multiply the whole latent with c_in_augment
98
- augment_latent_cin = batch_mul(augment_latent, c_in_augment)
99
-
100
- # Since the whole latent will be multiplied by c_in later, we divide by c_in here to cancel the effect
101
- _, _, c_in, _ = self.scaling(sigma=sigma)
102
- augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)
103
-
104
- return condition, augment_latent_cin
105
-
106
- def denoise(
107
- self,
108
- noise_x: Tensor,
109
- sigma: Tensor,
110
- condition: VideoExtendCondition,
111
- condition_video_augment_sigma_in_inference: float = 0.001,
112
- seed: int = 1,
113
- ) -> VideoDenoisePrediction:
114
- """Denoises input tensor using conditional video generation.
115
-
116
- Args:
117
- noise_x (Tensor): Noisy input tensor.
118
- sigma (Tensor): Noise level.
119
- condition (VideoExtendCondition): Condition for denoising.
120
- condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
121
- seed (int): Random seed for reproducibility
122
- Returns:
123
- VideoDenoisePrediction containing:
124
- - x0: Denoised prediction
125
- - eps: Noise prediction
126
- - logvar: Log variance of noise prediction
127
- - xt: Input before c_in multiplication
128
- - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
129
- """
130
-
131
- assert (
132
- condition.gt_latent is not None
133
- ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}"
134
- gt_latent = condition.gt_latent
135
- cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool
136
-
137
- condition_latent = gt_latent
138
-
139
- # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed
140
- condition, augment_latent = self.augment_conditional_latent_frames(
141
- condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
142
- )
143
- condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1]
144
-
145
- # Compose the model input with condition region (augment_latent) and generation region (noise_x)
146
- new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x
147
- # Call the base model
148
- denoise_pred = super().denoise(new_noise_xt, sigma, condition)
149
-
150
- x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0
151
-
152
- x0_pred = x0_pred_replaced
153
-
154
- return VideoDenoisePrediction(
155
- x0=x0_pred,
156
- eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
157
- logvar=denoise_pred.logvar,
158
- xt=new_noise_xt,
159
- x0_pred_replaced=x0_pred_replaced,
160
- )
161
-
162
- def generate_samples_from_batch(
163
- self,
164
- data_batch: Dict,
165
- guidance: float = 1.5,
166
- seed: int = 1,
167
- state_shape: Tuple | None = None,
168
- n_sample: int | None = None,
169
- is_negative_prompt: bool = False,
170
- num_steps: int = 35,
171
- condition_latent: Union[torch.Tensor, None] = None,
172
- num_condition_t: Union[int, None] = None,
173
- condition_video_augment_sigma_in_inference: float = None,
174
- add_input_frames_guidance: bool = False,
175
- x_sigma_max: Optional[torch.Tensor] = None,
176
- ) -> Tensor:
177
- """Generates video samples conditioned on input frames.
178
-
179
- Args:
180
- data_batch: Input data dictionary
181
- guidance: Classifier-free guidance scale
182
- seed: Random seed for reproducibility
183
- state_shape: Shape of output tensor (defaults to model's state shape)
184
- n_sample: Number of samples to generate (defaults to batch size)
185
- is_negative_prompt: Whether to use negative prompting
186
- num_steps: Number of denoising steps
187
- condition_latent: Conditioning frames tensor (B,C,T,H,W)
188
- num_condition_t: Number of frames to condition on
189
- condition_video_augment_sigma_in_inference: Noise level for condition augmentation
190
- add_input_frames_guidance: Whether to apply guidance to input frames
191
- x_sigma_max: Maximum noise level tensor
192
-
193
- Returns:
194
- Generated video samples tensor
195
- """
196
-
197
- if n_sample is None:
198
- input_key = self.input_data_key
199
- n_sample = data_batch[input_key].shape[0]
200
- if state_shape is None:
201
- log.debug(f"Default Video state shape is used. {self.state_shape}")
202
- state_shape = self.state_shape
203
-
204
- assert condition_latent is not None, "condition_latent should be provided"
205
-
206
- x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
207
- data_batch,
208
- guidance,
209
- is_negative_prompt=is_negative_prompt,
210
- condition_latent=condition_latent,
211
- num_condition_t=num_condition_t,
212
- condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
213
- add_input_frames_guidance=add_input_frames_guidance,
214
- seed=seed,
215
- )
216
- if x_sigma_max is None:
217
- x_sigma_max = (
218
- misc.arch_invariant_rand(
219
- (n_sample,) + tuple(state_shape),
220
- torch.float32,
221
- self.tensor_kwargs["device"],
222
- seed,
223
- )
224
- * self.sde.sigma_max
225
- )
226
-
227
- samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max)
228
- return samples
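
A small sketch of how the initial sample is drawn when x_sigma_max is not supplied: pure Gaussian noise scaled to the SDE's maximum noise level. The shapes, the sigma_max of 80.0, and the use of torch.randn in place of misc.arch_invariant_rand are illustrative assumptions:

import torch

n_sample = 2
state_shape = (16, 8, 40, 64)   # illustrative latent shape, not the model's actual state_shape
sigma_max = 80.0                # illustrative value standing in for self.sde.sigma_max

generator = torch.Generator().manual_seed(1)  # plays the role of the seed argument
x_sigma_max = torch.randn((n_sample, *state_shape), generator=generator) * sigma_max
print(x_sigma_max.shape, float(x_sigma_max.std()))
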
229
-
230
- def get_x0_fn_from_batch_with_condition_latent(
231
- self,
232
- data_batch: Dict,
233
- guidance: float = 1.5,
234
- is_negative_prompt: bool = False,
235
- condition_latent: torch.Tensor = None,
236
- num_condition_t: Union[int, None] = None,
237
- condition_video_augment_sigma_in_inference: float = None,
238
- add_input_frames_guidance: bool = False,
239
- seed: int = 1,
240
- ) -> Callable:
241
- """Creates denoising function for conditional video generation.
242
-
243
- Args:
244
- data_batch: Input data dictionary
245
- guidance: Classifier-free guidance scale
246
- is_negative_prompt: Whether to use negative prompting
247
- condition_latent: Conditioning frames tensor (B,C,T,H,W)
248
- num_condition_t: Number of frames to condition on
249
- condition_video_augment_sigma_in_inference: Noise level for condition augmentation
250
- add_input_frames_guidance: Whether to apply guidance to input frames
251
- seed: Random seed for reproducibility
252
-
253
- Returns:
254
- Function that takes noisy input and noise level and returns denoised prediction
255
- """
256
- if is_negative_prompt:
257
- condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
258
- else:
259
- condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
260
-
261
- condition.video_cond_bool = True
262
- condition = self.add_condition_video_indicator_and_video_input_mask(
263
- condition_latent, condition, num_condition_t
264
- )
265
-
266
- uncondition.video_cond_bool = False if add_input_frames_guidance else True
267
- uncondition = self.add_condition_video_indicator_and_video_input_mask(
268
- condition_latent, uncondition, num_condition_t
269
- )
270
-
271
- def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
272
- cond_x0 = self.denoise(
273
- noise_x,
274
- sigma,
275
- condition,
276
- condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
277
- seed=seed,
278
- ).x0_pred_replaced
279
- uncond_x0 = self.denoise(
280
- noise_x,
281
- sigma,
282
- uncondition,
283
- condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
284
- seed=seed,
285
- ).x0_pred_replaced
286
-
287
- return cond_x0 + guidance * (cond_x0 - uncond_x0)
288
-
289
- return x0_fn
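
The guidance combination returned by x0_fn above, as a tiny standalone function; note that in this formulation guidance = 0 recovers the conditional prediction:

import torch

def apply_guidance(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float) -> torch.Tensor:
    # Extrapolate the conditional prediction away from the unconditional one.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)

cond_x0 = torch.randn(1, 4, 8, 2, 2)
uncond_x0 = torch.randn(1, 4, 8, 2, 2)
print(apply_guidance(cond_x0, uncond_x0, guidance=1.5).shape)
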
290
-
291
- def add_condition_video_indicator_and_video_input_mask(
292
- self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
293
- ) -> VideoExtendCondition:
294
- """Adds conditioning masks to VideoExtendCondition object.
295
-
296
- Creates binary indicators and input masks for conditional video generation.
297
-
298
- Args:
299
- latent_state: Input latent tensor (B,C,T,H,W)
300
- condition: VideoExtendCondition object to update
301
- num_condition_t: Number of frames to condition on
302
-
303
- Returns:
304
- Updated VideoExtendCondition with added masks:
305
- - condition_video_indicator: Binary tensor marking condition regions
306
- - condition_video_input_mask: Input mask for network
307
- - gt_latent: Ground truth latent tensor
308
- """
309
- T = latent_state.shape[2]
310
- latent_dtype = latent_state.dtype
311
- condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
312
- latent_dtype
313
- ) # 1 for condition region
314
-
315
- # Only used at inference time to decide the condition region
316
- assert num_condition_t is not None, "num_condition_t should be provided"
317
- assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}"
318
- log.debug(
319
- f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
320
- )
321
- condition_video_indicator[:, :, :num_condition_t] += 1.0
322
-
323
- condition.gt_latent = latent_state
324
- condition.condition_video_indicator = condition_video_indicator
325
-
326
- B, C, T, H, W = latent_state.shape
327
- # Create additional input_mask channel, this will be concatenated to the input of the network
328
- # See design doc section (Implementation detail A.1 and A.2) for visualization
329
- ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
330
- zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
331
- assert condition.video_cond_bool is not None, "video_cond_bool should be set"
332
-
333
- # The input mask indicates whether the input is in the conditional region or not
334
- if condition.video_cond_bool: # Condition on given video frames
335
- condition.condition_video_input_mask = (
336
- condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
337
- )
338
- else: # Unconditional case, use for cfg
339
- condition.condition_video_input_mask = zeros_padding
340
-
341
- return condition
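
A runnable sketch of the mask construction above: the first num_condition_t latent frames are marked as the condition region, and the per-frame indicator is broadcast to a full-resolution input mask (shapes are illustrative):

import torch

B, C, T, H, W = 1, 16, 8, 4, 4
latent_state = torch.randn(B, C, T, H, W)
num_condition_t = 2

condition_video_indicator = torch.zeros(1, 1, T, 1, 1, dtype=latent_state.dtype)
condition_video_indicator[:, :, :num_condition_t] = 1.0

ones_padding = torch.ones(B, 1, T, H, W, dtype=latent_state.dtype)
zeros_padding = torch.zeros(B, 1, T, H, W, dtype=latent_state.dtype)
condition_video_input_mask = (
    condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
)
print(condition_video_input_mask[0, 0, :, 0, 0])  # tensor([1., 1., 0., 0., 0., 0., 0., 0.])
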
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/578cd9ecfca36e5376fef8da5106652c6ca85b68 DELETED
@@ -1,262 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import math
17
- from typing import Optional, Union
18
-
19
- import torch
20
- from torch import nn
21
-
22
- from .ar_module_embedding import RotaryPositionEmbedding
23
- from .ar_module_normalization import create_norm
24
-
25
-
26
- class Attention(nn.Module):
27
- """
28
- Attention layer with KV cache.
29
- """
30
-
31
- def __init__(
32
- self,
33
- n_heads: int,
34
- n_kv_heads: Union[int, None],
35
- dim: int,
36
- max_batch_size: int,
37
- max_seq_len: int,
38
- context_dim: Optional[int] = None,
39
- use_qk_normalization: bool = False,
40
- norm_type: str = "rmsnorm",
41
- norm_eps: float = 1e-5,
42
- causal_mask: Optional[bool] = True,
43
- head_dim: Optional[int] = None,
44
- fuse_qkv: bool = False,
45
- precision: str = "bfloat16",
46
- attn_type: str = "self",
47
- ):
48
- """
49
- Initializes the GQA module.
50
-
51
- Args:
52
- n_heads (int): The number of attention heads.
53
- n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads.
54
- dim (int): The dimensionality of the input and output.
55
- max_batch_size (int): The maximum batch size.
56
- max_seq_len (int): The maximum sequence length.
57
- context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None.
58
- use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False.
59
- norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm".
60
- norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
61
- causal_mask (bool, optional): Whether to use causal mask. Defaults to True.
62
- head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
63
- fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False.
64
- precision (str, optional): The precision of the module. Defaults to "bfloat16".
65
- attn_type (str, optional): The type of attention. Defaults to "self".
66
- """
67
- super().__init__()
68
- assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}"
69
- self.attn_type = attn_type
70
- context_dim = dim if context_dim is None else context_dim
71
-
72
- self.dim = dim
73
- self.context_dim = context_dim
74
- self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
75
- self.n_local_kv_heads = self.n_kv_heads
76
- self.n_local_heads = n_heads
77
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
78
- self.head_dim = dim // n_heads if head_dim is None else head_dim
79
- self.causal_mask = causal_mask
80
- self.fuse_qkv = fuse_qkv
81
- self.precision = precision
82
-
83
- if fuse_qkv:
84
- assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})"
85
- self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim
86
- self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False)
87
- # Register hook to load fused QKV weights
88
- self._register_load_state_dict_pre_hook(self.load_hook)
89
- else:
90
- self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False)
91
- self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
92
- self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
93
- self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False)
94
-
95
- self.max_batch_size = max_batch_size
96
- self.max_seq_len = max_seq_len
97
-
98
- if self.attn_type == "self":
99
- # Cache for key and value tensors
100
- self.init_kv_cache()
101
-
102
- # QK normalization layers
103
- if use_qk_normalization:
104
- self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
105
- self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
106
-
107
- self.use_qk_normalization = use_qk_normalization
108
-
109
- self.to(dtype=getattr(torch, self.precision))
110
-
111
- def load_hook(self, state_dict, prefix, *args):
112
- if prefix + "wq.weight" in state_dict:
113
- wq = state_dict.pop(prefix + "wq.weight")
114
- wk = state_dict.pop(prefix + "wk.weight")
115
- wv = state_dict.pop(prefix + "wv.weight")
116
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
117
-
118
- def init_kv_cache(self, dtype=None):
119
- cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim)
120
- if dtype is None:
121
- dtype = getattr(torch, self.precision)
122
- if self.attn_type == "self":
123
- self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
124
- self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()
125
-
126
- def forward(
127
- self,
128
- x: torch.Tensor,
129
- rope: RotaryPositionEmbedding,
130
- input_pos: torch.Tensor,
131
- mask: Optional[torch.Tensor] = None,
132
- context: Optional[torch.Tensor] = None,
133
- ):
134
- """
135
- Forward pass of GQA.
136
-
137
- Args:
138
- x: The input tensor of shape (batch_size, seq_len, dim).
139
- rope: The rotary positional embedding module.
140
- input_pos: The starting position of the current sequence.
141
- mask: The attention mask tensor.
142
- context: The context tensor of shape (batch_size, context_len, dim).
143
-
144
- Returns:
145
- The output tensor after applying GQA.
146
- """
147
- bsz, seqlen, _ = x.shape
148
-
149
- # Use one single module to handle both self-attn and cross-attn
150
- context = x if context is None else context
151
- context_len = seqlen if context is None else context.shape[1]
152
-
153
- if self.fuse_qkv:
154
- q_size = self.n_local_heads * self.head_dim
155
- kv_size = self.n_local_kv_heads * self.head_dim
156
- xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
157
- else:
158
- # Compute query, key, and value projections
159
- xq, xk, xv = self.wq(x), self.wk(context), self.wv(context)
160
-
161
- # Reshape projections
162
- xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
163
- xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
164
- xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
165
-
166
- # QK normalization
167
- if self.use_qk_normalization:
168
- xq = self.q_norm(xq)
169
- xk = self.k_norm(xk)
170
-
171
- # Apply rotary positional embeddings to queries and keys
172
- # RoPE is applied only for self- and full-attention, not for cross-attention
173
- if self.attn_type in ["self", "full"]:
174
- xq, xk = rope(xq, xk, input_pos, seqlen)
175
-
176
- xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
177
- # xq: (bs, n_local_heads, seqlen, head_dim)
178
- # xk: (bs, n_kv_heads, cache_len + context_len, head_dim)
179
- # xv: (bs, n_kv_heads, cache_len + context_len, head_dim)
180
- if self.attn_type == "self":
181
- # Update cache with current key and value tensors
182
- assert input_pos is not None
183
- self.cache_k[:bsz, :, input_pos] = xk
184
- self.cache_v[:bsz, :, input_pos] = xv
185
- keys, values = (
186
- self.cache_k[:bsz, :, :],
187
- self.cache_v[:bsz, :, :],
188
- )
189
- else:
190
- keys, values = xk, xv
191
-
192
- # Repeat keys and values if necessary
193
- keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
194
- values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
195
-
196
- # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used,
197
- # since the masking is handled outside this attention module.
198
- # For cross-attention, it's always full-attn without causal mask
199
- is_causal = False
200
- output = scaled_dot_product_attention(
201
- xq,
202
- keys,
203
- values,
204
- head_dim=self.head_dim,
205
- mask=mask,
206
- is_causal=is_causal,
207
- dropout_p=0.0,
208
- )
209
- output = output.view(bsz, seqlen, -1)
210
- output = self.wo(output)
211
- return output
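
A standalone sketch of the KV-cache update and grouped-query expansion performed in forward above; sizes are illustrative and the cache lives on CPU here rather than on the GPU in the model's precision:

import torch

bsz, n_heads, n_kv_heads, head_dim = 2, 8, 2, 16
max_seq_len, seqlen = 32, 4
n_rep = n_heads // n_kv_heads

cache_k = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
cache_v = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)

# New keys/values for the current positions, already shaped (B, H_kv, S, D).
input_pos = torch.arange(10, 10 + seqlen)
xk = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
xv = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Write the current step into the cache, then attend over the whole cache.
cache_k[:, :, input_pos] = xk
cache_v[:, :, input_pos] = xv
keys, values = cache_k, cache_v

# Expand the KV heads so every query head has a matching key/value head (GQA).
keys = keys.repeat_interleave(n_rep, dim=1)
values = values.repeat_interleave(n_rep, dim=1)
print(keys.shape)  # torch.Size([2, 8, 32, 16])
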
212
-
213
-
214
- def scaled_dot_product_attention(
215
- q: torch.Tensor,
216
- k: torch.Tensor,
217
- v: torch.Tensor,
218
- head_dim: int,
219
- mask: Optional[torch.Tensor] = None,
220
- is_causal: Optional[bool] = None,
221
- dropout_p: float = 0.0,
222
- ) -> torch.Tensor:
223
- """
224
- PyTorch's native implementation of Flash Attention 2.
225
-
226
- If `is_causal` is given, then the causal attention mask is applied accordingly:
227
- - If `is_causal` is True, the standard upper-left causal attention masking is applied.
228
- - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is
229
- provided (i.e., `mask is not None`).
230
-
231
- If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied
232
- based on the provided mask tensor:
233
- - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True,
234
- leading to the standard upper-left causal attention masking.
235
- - If an attention mask is given (i.e., `mask is not None`), the provided mask is used,
236
- and `is_causal` is set to False.
237
-
238
- Args:
239
- q (torch.Tensor): Query tensor
240
- k (torch.Tensor): Key tensor
241
- v (torch.Tensor): Value tensor
242
- head_dim (int): Dimension of each attention head
243
- mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
244
- is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None.
245
- dropout_p (float, optional): Dropout rate. Defaults to 0.0.
246
-
247
- Returns:
248
- torch.Tensor: Output tensor after applying scaled dot-product attention
249
- """
250
- scale = 1.0 / math.sqrt(head_dim)
251
- if is_causal is None:
252
- is_causal = mask is None
253
- y = torch.nn.functional.scaled_dot_product_attention(
254
- q,
255
- k,
256
- v,
257
- attn_mask=mask,
258
- dropout_p=dropout_p,
259
- scale=scale,
260
- is_causal=is_causal,
261
- )
262
- return y.transpose(1, 2).contiguous()
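
A quick sketch of the is_causal defaulting in the wrapper above (with no explicit mask the call is causal, with a mask it is not); the scale keyword assumes a reasonably recent PyTorch (2.1+):

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
scale = 1.0 / math.sqrt(q.shape[-1])

mask = None
is_causal = mask is None       # no explicit mask -> causal masking is applied
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=is_causal
)
out = out.transpose(1, 2).contiguous()  # back to (batch, seq, heads, head_dim)
print(out.shape)
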
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5877aa166d1d946b98ce604e2bd1a4284b884ae6 DELETED
@@ -1,318 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Any, Dict, List, Optional, Union
17
-
18
- import numpy as np
19
- import torch
20
- import transformers
21
- from transformers import AutoImageProcessor
22
- from transformers.image_utils import ImageInput, is_valid_image, load_image
23
-
24
- from .ar_tokenizer_text_tokenizer import TextTokenizer
25
- from .log import log
26
-
27
- # Configuration for different vision-language models
28
- IMAGE_CONFIGS = {
29
- "pixtral": {
30
- "patch_size": 16,
31
- "image_token": "[IMG]",
32
- "image_break_token": "[IMG_BREAK]",
33
- "image_end_token": "[IMG_END]",
34
- }
35
- }
36
-
37
- # Chat template for Pixtral-12B-Instruct
38
- PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}'
39
-
40
-
41
- # Copied from transformers.models.pixtral.processing_pixtral.is_url
42
- def is_url(val) -> bool:
43
- """Check if the given value is a URL."""
44
- return isinstance(val, str) and val.startswith("http")
45
-
46
-
47
- # Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url
48
- def is_image_or_image_url(elem):
49
- """Check if the given element is an image or an image URL."""
50
- return is_url(elem) or is_valid_image(elem)
51
-
52
-
53
- def load_image_list(
54
- image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None
55
- ) -> List["PIL.Image.Image"]:
56
- """
57
- Load a list of images.
58
-
59
- Args:
60
- image_list (List[Union[str, PIL.Image.Image]]): The list of images to load.
61
- timeout (Optional[float]): The timeout for loading the image.
62
-
63
- Returns:
64
- List[PIL.Image.Image]: The list of loaded images.
65
- """
66
- return [load_image(image, timeout=timeout) for image in image_list]
67
-
68
-
69
- class ImageTextTokenizer(TextTokenizer):
70
- """
71
- Image-text tokenizer class that extends the text tokenizer to support vision tokens as well.
72
- """
73
-
74
- def __init__(
75
- self,
76
- model_family: str,
77
- is_instruct_model: bool,
78
- tokenizer_path: str,
79
- image_processor_path: str,
80
- ):
81
- """
82
- Initialize the ImageTextTokenizer.
83
-
84
- Args:
85
- model_family (str): The model family.
86
- is_instruct_model (bool): Whether the model is an instruct model.
87
- s3_credential_path (str): The path to the s3 credential file. Defaults to "credentials/pbss_dir.secret".
88
-
89
- Raises:
90
- AssertionError: If the model family is not supported or if the transformers version is incompatible.
91
- """
92
- super().__init__(
93
- model_family=model_family,
94
- is_instruct_model=is_instruct_model,
95
- local_path=tokenizer_path,
96
- )
97
- assert model_family in ["pixtral"], f"Unsupported model family: {model_family}"
98
- if model_family == "pixtral":
99
- # Need transformers>=4.45.0
100
- assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0"
101
- assert is_instruct_model, "Pixtral requires is_instruct_model=True"
102
- if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
103
- setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE)
104
- log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}")
105
-
106
- # Set up image-specific configurations
107
- image_config = IMAGE_CONFIGS[model_family]
108
- self.patch_size = image_config["patch_size"]
109
- self.image_token = image_config["image_token"]
110
- self.image_break_token = image_config["image_break_token"]
111
- self.image_end_token = image_config["image_end_token"]
112
-
113
- # Initialize the image processor
114
- self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path)
115
-
116
- def encode(
117
- self,
118
- text: Union[str, List[str], List[int]],
119
- *, # Enforce keyword-only arguments
120
- images: Optional[ImageInput] = None,
121
- image_kwargs: Optional[Dict[str, Any]] = None,
122
- **text_kwargs,
123
- ) -> List[int]:
124
- """
125
- Process the images and return the tokenized images and text.
126
-
127
- Args:
128
- text (`str`, `List[str]`, `List[List[str]]`):
129
- The sequence or batch of sequences to be encoded.
130
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
131
- The image or batch of images to be prepared.
132
- image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
133
- **text_kwargs: Additional keyword arguments for text processing.
134
-
135
- Returns:
136
- A dictionary with the following fields:
137
- - **input_ids** -- List of token ids to be fed to a model.
138
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
139
- - **pixel_values** -- Pixel values to be fed to a model.
140
-
141
- Raises:
142
- ValueError: If the input images are in an invalid format.
143
- """
144
-
145
- output_dict, image_inputs = {}, {}
146
- if images is not None:
147
- # Preprocess images
148
- if is_image_or_image_url(images):
149
- images = [[images]]
150
- elif isinstance(images, list) and is_image_or_image_url(images[0]):
151
- images = [images]
152
- elif (
153
- not isinstance(images, list)
154
- and not isinstance(images[0], list)
155
- and not is_image_or_image_url(images[0][0])
156
- ):
157
- raise ValueError(
158
- "Invalid input images. Please provide a single image or a list of images or a list of list of images."
159
- )
160
-
161
- # Load and process images
162
- images = [load_image_list(sample) for sample in images]
163
- image_kwargs = image_kwargs or {}
164
- image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs)
165
-
166
- # Validate image inputs
167
- assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs"
168
- assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs"
169
- assert len(image_inputs.keys()) == 2, "Exactly two keys are expected in image_inputs, got {}".format(
170
- image_inputs.keys()
171
- )
172
-
173
- # Extract pixel values and image sizes
174
- pixel_values = image_inputs["pixel_values"][0]
175
- image_sizes = image_inputs["image_sizes"][0]
176
- unique_sizes = np.unique(image_sizes, axis=0)
177
-
178
- assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes)
179
-
180
- # Convert pixel values to PyTorch tensor
181
- pixel_values = np.asarray(pixel_values)
182
- pixel_values = torch.from_numpy(pixel_values)
183
- output_dict["pixel_values"] = pixel_values
184
- output_dict["image_sizes"] = image_sizes
185
-
186
- # Expand image tokens in text
187
- if image_inputs.get("pixel_values") is not None:
188
- replace_strings = []
189
- # Calculate the number of tokens needed for each image and create a placeholder
190
- for image_size in image_sizes:
191
- height, width = image_size
192
- num_height_tokens = height // self.patch_size
193
- num_width_tokens = width // self.patch_size
194
- replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
195
- # Flatten list
196
- replace_tokens = [item for sublist in replace_tokens for item in sublist]
197
- replace_tokens[-1] = self.image_end_token
198
- replace_str = "".join(replace_tokens)
199
- replace_strings.append(replace_str)
200
- text = text.replace(self.image_token, "<placeholder>", 1)
201
-
202
- # Replace placeholders with actual image token sequences
203
- while "<placeholder>" in text:
204
- replace_str = replace_strings.pop(0)
205
- text = text.replace("<placeholder>", replace_str, 1)
206
-
207
- # Encode the text
208
- text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs)
209
-
210
- output_dict["input_ids"] = text_inputs
211
- return output_dict
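
A small, pure-Python sketch of the placeholder expansion performed in encode above: each image becomes a grid of [IMG] tokens with one [IMG_BREAK] per row and a final [IMG_END]; the 32x48 image and 16-pixel patch are made-up numbers:

patch_size = 16
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"

height, width = 32, 48
num_height_tokens = height // patch_size   # 2 rows of patches
num_width_tokens = width // patch_size     # 3 patches per row

replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten
replace_tokens[-1] = image_end_token
print("".join(replace_tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
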
212
-
213
- def apply_chat_template(
214
- self,
215
- conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
216
- *,
217
- images: Optional[ImageInput] = None,
218
- image_kwargs: Optional[Dict[str, Any]] = None,
219
- add_generation_prompt: bool = False,
220
- tokenize: bool = True,
221
- padding: bool = False,
222
- truncation: bool = False,
223
- max_length: Optional[int] = None,
224
- return_tensors: Optional[str] = None,
225
- return_dict: bool = True,
226
- return_assistant_tokens_mask: bool = False,
227
- generation_prefix: str = "",
228
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
229
- **kwargs,
230
- ):
231
- """
232
- Apply the chat template to the conversation.
233
-
234
- Args:
235
- conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process.
236
- images (Optional[ImageInput]): Images to include in the conversation.
237
- image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
238
- add_generation_prompt (bool): Whether to add a generation prompt.
239
- tokenize (bool): Whether to tokenize the output.
240
- padding (bool): Whether to pad the output.
241
- truncation (bool): Whether to truncate the output.
242
- max_length (Optional[int]): Maximum length of the output.
243
- return_tensors (Optional[str]): The type of tensors to return.
244
- return_dict (bool): Whether to return a dictionary.
245
- return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask.
246
- generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "".
247
- tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
248
- **kwargs: Additional keyword arguments.
249
-
250
- Returns:
251
- The processed conversation with applied chat template.
252
-
253
- Raises:
254
- AssertionError: If return_dict is False or if the conversation format is invalid.
255
- """
256
- assert return_dict, "return_dict must be True for ImageTextTokenizer"
257
- assert isinstance(conversation, list), "conversation must be a list"
258
- if isinstance(conversation[0], list):
259
- assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation)
260
- conversation = conversation[0]
261
-
262
- # Extract images from the conversation if not provided
263
- if images is None:
264
- images = []
265
- for msg in conversation:
266
- if msg.get("images", None) is not None:
267
- images = images + (msg["images"])
268
- images = load_image_list(images)
269
- # In case the input does not have images, will ignore
270
- # Useful in feeding VLM inputs with and without images
271
- if isinstance(images, list) and len(images) == 0:
272
- images = None
273
-
274
- # Apply the chat template to the text
275
- text = super().apply_chat_template(
276
- conversation,
277
- tokenize=False,
278
- add_generation_prompt=add_generation_prompt,
279
- padding=padding,
280
- truncation=truncation,
281
- max_length=max_length,
282
- return_tensors=return_tensors,
283
- return_dict=False,
284
- return_assistant_tokens_mask=return_assistant_tokens_mask,
285
- generation_prefix=generation_prefix,
286
- tokenizer_kwargs=tokenizer_kwargs,
287
- **kwargs,
288
- )
289
-
290
- if tokenizer_kwargs is None:
291
- tokenizer_kwargs = {}
292
-
293
- # Encode the text and images
294
- output = self.encode(
295
- text,
296
- images=images,
297
- image_kwargs=image_kwargs,
298
- tokenize=tokenize,
299
- padding=padding,
300
- truncation=truncation,
301
- max_length=max_length,
302
- add_special_tokens=False,
303
- return_tensors=return_tensors,
304
- **tokenizer_kwargs,
305
- )
306
- return output
307
-
308
- @property
309
- def model_input_names(self):
310
- """
311
- Get the combined model input names from both the text tokenizer and image processor.
312
-
313
- Returns:
314
- List[str]: A list of unique input names.
315
- """
316
- tokenizer_input_names = self.tokenizer.model_input_names
317
- image_processor_input_names = self.image_processor.model_input_names
318
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5d1bc4c8a22a942736ae6b73a4ebb21da4980adc DELETED
@@ -1,117 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import numpy as np
17
- import torch
18
- from pytorch_retinaface.utils.nms.py_cpu_nms import py_cpu_nms
19
-
20
- from .log import log
21
-
22
-
23
- # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
24
- def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k):
25
- """Filter boxes based on confidence score and remove overlapping boxes using NMS."""
26
- # Keep detections with confidence above threshold
27
- inds = np.where(scores > confidence_threshold)[0]
28
- boxes = boxes[inds]
29
- scores = scores[inds]
30
-
31
- # Sort by confidence and keep top K detections
32
- order = scores.argsort()[::-1][:top_k]
33
- boxes = boxes[order]
34
- scores = scores[order]
35
-
36
- # Run non-maximum-suppression (NMS) to remove overlapping boxes
37
- dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
38
- keep = py_cpu_nms(dets, nms_threshold)
39
- dets = dets[keep, :]
40
- dets = dets[:keep_top_k, :]
41
- boxes = dets[:, :-1]
42
- return boxes
43
-
44
-
45
- # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs
46
- def decode_batch(loc, priors, variances):
47
- """Decode batched locations from predictions using priors and variances.
48
-
49
- Args:
50
- loc (tensor): Batched location predictions for loc layers.
51
- Shape: [batch_size, num_priors, 4]
52
- priors (tensor): Prior boxes in center-offset form.
53
- Shape: [num_priors, 4]
54
- variances (list[float]): Variances of prior boxes.
55
-
56
- Return:
57
- Decoded batched bounding box predictions
58
- Shape: [batch_size, num_priors, 4]
59
- """
60
- batch_size = loc.size(0)
61
- priors = priors.unsqueeze(0).expand(batch_size, -1, -1)
62
-
63
- boxes = torch.cat(
64
- (
65
- priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
66
- priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]),
67
- ),
68
- dim=2,
69
- )
70
-
71
- boxes[:, :, :2] -= boxes[:, :, 2:] / 2
72
- boxes[:, :, 2:] += boxes[:, :, :2]
73
- return boxes
74
-
75
-
76
- # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
77
- def _check_keys(model, pretrained_state_dict):
78
- ckpt_keys = set(pretrained_state_dict.keys())
79
- model_keys = set(model.state_dict().keys())
80
- used_pretrained_keys = model_keys & ckpt_keys
81
- unused_pretrained_keys = ckpt_keys - model_keys
82
- missing_keys = model_keys - ckpt_keys
83
- log.debug("Missing keys:{}".format(len(missing_keys)))
84
- log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys)))
85
- log.debug("Used keys:{}".format(len(used_pretrained_keys)))
86
- assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
87
- return True
88
-
89
-
90
- # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
91
- def _remove_prefix(state_dict, prefix):
92
- """Old versions of the model are stored with all parameter names sharing the common prefix 'module.'"""
93
- log.debug("Removing prefix '{}'".format(prefix))
94
-
95
- def f(x):
96
- return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
97
-
98
- return {f(key): value for key, value in state_dict.items()}
99
-
100
-
101
- # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
102
- def load_model(model, pretrained_path, load_to_cpu):
103
- log.debug("Loading pretrained model from {}".format(pretrained_path))
104
- if load_to_cpu:
105
- pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True)
106
- else:
107
- device = torch.cuda.current_device()
108
- pretrained_dict = torch.load(
109
- pretrained_path, map_location=lambda storage, loc: storage.cuda(device), weights_only=True
110
- )
111
- if "state_dict" in pretrained_dict.keys():
112
- pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.")
113
- else:
114
- pretrained_dict = _remove_prefix(pretrained_dict, "module.")
115
- _check_keys(model, pretrained_dict)
116
- model.load_state_dict(pretrained_dict, strict=False)
117
- return model
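
A toy illustration of the 'module.' prefix stripping and key check performed by load_model above, using a throwaway two-layer network instead of a RetinaFace checkpoint:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Simulate a checkpoint saved from a DataParallel/DistributedDataParallel wrapper.
wrapped_state = {"module." + k: v for k, v in model.state_dict().items()}

def remove_prefix(state_dict, prefix):
    strip = lambda key: key.split(prefix, 1)[-1] if key.startswith(prefix) else key
    return {strip(k): v for k, v in state_dict.items()}

cleaned = remove_prefix(wrapped_state, "module.")
missing = set(model.state_dict().keys()) - set(cleaned.keys())
assert not missing, f"missing keys after prefix stripping: {missing}"
model.load_state_dict(cleaned, strict=False)
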
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/5e5a5244c87516121f3e7686c924f8b1c66cd772 DELETED
@@ -1,360 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional
17
-
18
- import torch
19
- from einops import rearrange
20
-
21
- from .ar_tokenizer_quantizers import FSQuantizer
22
-
23
- # Make sure the JIT model outputs consistently during consecutive calls
24
- # Check here: https://github.com/pytorch/pytorch/issues/74534
25
- torch._C._jit_set_texpr_fuser_enabled(False)
26
-
27
-
28
- def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule:
29
- """Loads a torch.jit.ScriptModule from a filepath.
30
-
31
- Args:
32
- jit_filepath: The filepath to the JIT-compiled model.
33
- device: The device to load the model onto, default=cuda.
34
- Returns:
35
- The JIT compiled model loaded to device and on eval mode.
36
- """
37
- # Make sure the JIT model outputs consistently during consecutive calls
38
- # Check here: https://github.com/pytorch/pytorch/issues/74534
39
- torch._C._jit_set_texpr_fuser_enabled(False)
40
-
41
- model = torch.jit.load(jit_filepath)
42
- return model.eval().to(device)
43
-
44
-
45
- class BaseDiscreteVideoFSQTokenizer(torch.nn.Module):
46
- """
47
- A base class for Discrete Video FSQ Tokenizers that handles data type conversions and normalization
48
- using provided mean and standard deviation values for latent space representation.
49
- Derived classes should load pre-trained encoder and decoder components into the encoder and decoder attributes.
50
-
51
- Attributes:
52
- encoder (Module | Callable): Encoder loaded from storage.
53
- decoder (Module | Callable): Decoder loaded from storage.
54
- dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
55
-
56
- Args:
57
- name (str): Name of the model, used for differentiating cache file paths.
58
- latent_ch (int, optional): Number of latent channels (default is 6).
59
- is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
60
- pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
61
- latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
62
- max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
63
- levels (list[int]): The levels defined in the FSQ quantizer.
64
- compression_ratio (list[int]): The compression factor for (T, H, W).
65
- """
66
-
67
- def __init__(
68
- self,
69
- name: str,
70
- latent_ch: int = 6,
71
- is_bf16: bool = True,
72
- pixel_chunk_duration: int = 25,
73
- latent_chunk_duration: int = 4,
74
- max_enc_batch_size: int = 8,
75
- max_dec_batch_size: int = 4,
76
- levels: list[int] = [8, 8, 8, 5, 5, 5],
77
- compression_ratio: list[int] = [8, 16, 16],
78
- ):
79
- super().__init__()
80
- self.channel = latent_ch
81
- self.name = name
82
- dtype = torch.bfloat16 if is_bf16 else torch.float32
83
- self.dtype = dtype
84
- self.pixel_chunk_duration = pixel_chunk_duration
85
- self.latent_chunk_duration = latent_chunk_duration
86
- self.max_enc_batch_size = max_enc_batch_size
87
- self.max_dec_batch_size = max_dec_batch_size
88
- self.levels = levels
89
- self.compress_ratio = compression_ratio
90
- self.fsq_quantizer = FSQuantizer(levels)
91
-
92
- @property
93
- def latent_ch(self) -> int:
94
- """
95
- Returns the number of latent channels in the tokenizer.
96
- """
97
- return self.channel
98
-
99
- @torch.no_grad()
100
- def encode(self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
101
- B, C, T, H, W = state.shape
102
- if pixel_chunk_duration is None:
103
- # Use the default pixel chunk duration and latent chunk duration
104
- pixel_chunk_duration = self.pixel_chunk_duration
105
- latent_chunk_duration = self.latent_chunk_duration
106
- else:
107
- # Update the latent chunk duration based on the given pixel chunk duration
108
- latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
109
-
110
- assert (
111
- T % pixel_chunk_duration == 0
112
- ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}"
113
- state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)
114
-
115
- # use max_enc_batch_size to avoid OOM
116
- if state.shape[0] > self.max_enc_batch_size:
117
- quantized_out_list = []
118
- indices_list = []
119
- for i in range(0, state.shape[0], self.max_enc_batch_size):
120
- indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype))
121
- quantized_out_list.append(quantized_out)
122
- indices_list.append(indices)
123
- quantized_out = torch.cat(quantized_out_list, dim=0)
124
- indices = torch.cat(indices_list, dim=0)
125
- else:
126
- indices, quantized_out, _ = self.encoder(state.to(self.dtype))
127
- assert quantized_out.shape[2] == latent_chunk_duration
128
- return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange(
129
- indices, "(b n) t h w -> b (n t) h w", b=B
130
- )
131
-
132
- @torch.no_grad()
133
- def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
134
- B, T, _, _ = indices.shape
135
- if pixel_chunk_duration is None:
136
- pixel_chunk_duration = self.pixel_chunk_duration
137
- latent_chunk_duration = self.latent_chunk_duration
138
- else:
139
- latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
140
- assert (
141
- T % latent_chunk_duration == 0
142
- ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}"
143
- indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration)
144
-
145
- # use max_dec_batch_size to avoid OOM
146
- if indices.shape[0] > self.max_dec_batch_size:
147
- state = []
148
- for i in range(0, indices.shape[0], self.max_dec_batch_size):
149
- state.append(self.decoder(indices[i : i + self.max_dec_batch_size]))
150
- state = torch.cat(state, dim=0)
151
- else:
152
- state = self.decoder(indices)
153
-
154
- assert state.shape[2] == pixel_chunk_duration
155
- return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)
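
A quick sketch of the temporal chunking used by encode/decode above: a long video is split into fixed-length chunks that are folded into the batch dimension, then folded back; the 50-frame clip and 25-frame chunks are illustrative:

import torch
from einops import rearrange

B, C, T, H, W = 1, 3, 50, 32, 32
pixel_chunk_duration = 25
video = torch.randn(B, C, T, H, W)

chunks = rearrange(video, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)
print(chunks.shape)                        # torch.Size([2, 3, 25, 32, 32])

restored = rearrange(chunks, "(b n) c t h w -> b c (n t) h w", b=B)
print(torch.equal(video, restored))        # True
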
156
-
157
- def reset_dtype(self, *args, **kwargs):
158
- """
159
- Resets the data type of the encoder and decoder to the model's default data type.
160
-
161
- Args:
162
- *args, **kwargs: Unused, present to allow flexibility in method calls.
163
- """
164
- del args, kwargs
165
- self.decoder.to(self.dtype)
166
- self.encoder.to(self.dtype)
167
-
168
-
169
- class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer):
170
- """
171
- A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder
172
- and decoder components from a remote store, and handles data type conversions and normalization
173
- using provided mean and standard deviation values for latent space representation.
174
-
175
- Attributes:
176
- encoder (Module): The JIT compiled encoder loaded from storage.
177
- decoder (Module): The JIT compiled decoder loaded from storage.
178
- dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
179
-
180
- Args:
181
- enc_fp (str): File path to the encoder's JIT file on the remote store.
182
- dec_fp (str): File path to the decoder's JIT file on the remote store.
183
- name (str): Name of the model, used for differentiating cache file paths.
184
- latent_ch (int, optional): Number of latent channels (default is 6).
185
- is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
186
- pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
187
- latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
188
- max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
189
- levels (list[int]): The levels defined in the FSQ quantizer.
190
- compression_ratio (list[int]): The compression factor for (T, H, W).
191
- """
192
-
193
- def __init__(
194
- self,
195
- enc_fp: str,
196
- dec_fp: str,
197
- name: str,
198
- latent_ch: int = 6,
199
- is_bf16: bool = True,
200
- pixel_chunk_duration: int = 25,
201
- latent_chunk_duration: int = 4,
202
- max_enc_batch_size: int = 8,
203
- max_dec_batch_size: int = 4,
204
- levels: list[int] = [8, 8, 8, 5, 5, 5],
205
- compression_ratio: list[int] = [8, 16, 16],
206
- ):
207
- super().__init__(
208
- name,
209
- latent_ch,
210
- is_bf16,
211
- pixel_chunk_duration,
212
- latent_chunk_duration,
213
- max_enc_batch_size,
214
- max_dec_batch_size,
215
- levels,
216
- compression_ratio,
217
- )
218
-
219
- self.load_encoder(enc_fp)
220
- self.load_decoder(dec_fp)
221
-
222
- def load_encoder(self, enc_fp: str) -> None:
223
- """
224
- Load the encoder from the remote store.
225
-
226
- Args:
227
- - enc_fp (str): File path to the encoder's JIT file on the remote store.
228
- """
229
- self.encoder = load_jit_model(enc_fp, device="cuda")
230
- self.encoder.eval()
231
- for param in self.encoder.parameters():
232
- param.requires_grad = False
233
- self.encoder.to(self.dtype)
234
-
235
- def load_decoder(self, dec_fp: str) -> None:
236
- """
237
- Load the decoder from the remote store.
238
-
239
- Args:
240
- - dec_fp (str): File path to the decoder's JIT file on the remote store.
241
- """
242
- self.decoder = load_jit_model(dec_fp, device="cuda")
243
- self.decoder.eval()
244
- for param in self.decoder.parameters():
245
- param.requires_grad = False
246
- self.decoder.to(self.dtype)
247
-
248
-
249
- class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
250
- """
251
- A Discrete Video FSQ Tokenizer that loads weights from a pre-trained JITed encoder into an nn.Module
252
- (so the encoder can be torch.compile()-d) together with a JITed decoder, which can also be torch.compiled.
253
- It handles data type conversions and normalization using provided mean and standard deviation values for
254
- the latent space representation.
255
-
256
- Attributes:
257
- tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints
258
- encoder (Callable): tokenizer_module's encode method
259
- decoder (Callable): tokenizer_module's decode method
260
- dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
261
-
262
- Args:
263
- enc_fp (str): File path to the encoder's JIT file on the remote store.
264
- dec_fp (str): File path to the decoder's JIT file on the remote store.
265
- tokenizer_module (Module): Tokenizer module that will have its weights loaded
266
- name (str): Name of the model, used for differentiating cache file paths.
267
- latent_ch (int, optional): Number of latent channels (default is 6).
268
- is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
269
- pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
270
- latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
271
- max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
272
- levels (list[int]): The levels defined in the FSQ quantizer.
273
- compression_ratio (list[int]): The compression factor for (T, H, W).
274
- """
275
-
276
- def __init__(
277
- self,
278
- enc_fp: str,
279
- dec_fp: str,
280
- tokenizer_module: torch.nn.Module,
281
- name: str,
282
- latent_ch: int = 6,
283
- is_bf16: bool = True,
284
- pixel_chunk_duration: int = 25,
285
- latent_chunk_duration: int = 4,
286
- max_enc_batch_size: int = 8,
287
- max_dec_batch_size: int = 4,
288
- levels: list[int] = [8, 8, 8, 5, 5, 5],
289
- compression_ratio: list[int] = [8, 16, 16],
290
- ):
291
- super().__init__(
292
- name,
293
- latent_ch,
294
- is_bf16,
295
- pixel_chunk_duration,
296
- latent_chunk_duration,
297
- max_enc_batch_size,
298
- max_dec_batch_size,
299
- levels,
300
- compression_ratio,
301
- )
302
-
303
- self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)
304
-
305
- def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
306
- """
307
- Load the encoder and decoder from the remote store.
308
-
309
- Args:
310
- - enc_fp (str): File path to the encoder's JIT file on the remote store.
311
- - dec_fp (str): File path to the decoder's JIT file on the remote store.
312
- - tokenizer_module (Module): Tokenizer module that was used to create JIT checkpoints
313
- """
314
- self.decoder = load_jit_model(dec_fp)
315
-
316
- self.decoder.eval()
317
- for param in self.decoder.parameters():
318
- param.requires_grad = False
319
- self.decoder.to(self.dtype)
320
-
321
- encoder_sd = load_jit_model(enc_fp).state_dict()
322
-
323
- del tokenizer_module.post_quant_conv
324
- del tokenizer_module.decoder
325
-
326
- state_dict = {
327
- k: v
328
- for k, v in (encoder_sd).items()
329
- # Variables captured by JIT
330
- if k
331
- not in (
332
- "encoder.patcher3d.wavelets",
333
- "encoder.patcher3d._arange",
334
- "encoder.patcher3d.patch_size_buffer",
335
- "quantizer._levels",
336
- "quantizer._basis",
337
- "quantizer.implicit_codebook",
338
- )
339
- }
340
-
341
- tokenizer_module.load_state_dict(state_dict)
342
-
343
- tokenizer_module.eval()
344
- for param in tokenizer_module.parameters():
345
- param.requires_grad = False
346
- tokenizer_module.to(self.dtype)
347
-
348
- self.tokenizer_module = tokenizer_module
349
- self.encoder = self.tokenizer_module.encode
350
-
351
- def reset_dtype(self, *args, **kwargs):
352
- """
353
- Resets the data type of the encoder and decoder to the model's default data type.
354
-
355
- Args:
356
- *args, **kwargs: Unused, present to allow flexibility in method calls.
357
- """
358
- del args, kwargs
359
- self.decoder.to(self.dtype)
360
- self.tokenizer_module.to(self.dtype)
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/602ea1cb383d8263be06829a466cfb3ba9f97856 DELETED
@@ -1,52 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from omegaconf import DictConfig, OmegaConf
17
-
18
-
19
- class CustomSimpleNamespace:
20
- """
21
- A simple namespace class that supports both attribute-style and dictionary-style access.
22
- """
23
-
24
- def __init__(self, d):
25
- self._d = d
26
-
27
- def __getattr__(self, attr):
28
- # Attribute-style access: config.key
29
- try:
30
- return self._d[attr]
31
- except KeyError:
32
- raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'")
33
-
34
- def __getitem__(self, key):
35
- # Dictionary-style access: config['key']
36
- return self._d[key]
37
-
38
-
39
- def maybe_convert_to_namespace(config):
40
- """
41
- This function cast a OmegaConf's DictConfig or a standard dict to CustomSimpleNamespace, which supports both
42
- attribute-style and dictionary-style access.
43
- Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile.
44
- """
45
- # If input is OmegaConf's DictConfig, convert to a standard dict
46
- if isinstance(config, DictConfig):
47
- config = OmegaConf.to_container(config, resolve=True)
48
-
49
- if isinstance(config, dict):
50
- return CustomSimpleNamespace(config)
51
- else:
52
- return config
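
An illustrative round trip for the helper above: an OmegaConf DictConfig is flattened to a plain dict and wrapped so that both attribute-style and dictionary-style access work (the namespace class is re-declared here only to keep the sketch self-contained):

from omegaconf import OmegaConf

class _SimpleNamespace:
    # Minimal copy of the helper above, just for this sketch.
    def __init__(self, d):
        self._d = d
    def __getattr__(self, attr):
        try:
            return self._d[attr]
        except KeyError:
            raise AttributeError(attr)
    def __getitem__(self, key):
        return self._d[key]

cfg = OmegaConf.create({"hidden_dim": 512, "num_layers": 8})
plain = OmegaConf.to_container(cfg, resolve=True)  # DictConfig is not torch.compile friendly
ns = _SimpleNamespace(plain)
print(ns.hidden_dim, ns["num_layers"])  # 512 8
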
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/61f10fe07227a01d582e17f89a9b5089aa506006 DELETED
@@ -1,88 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
-
20
- def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
21
- """
22
- Creates the specified normalization layer based on the norm_type.
23
- Adapted from torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
24
-
25
- Args:
26
- norm_type (str): The type of normalization layer to create.
27
- Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
28
- dim (int): The dimension of the normalization layer.
29
- eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
30
-
31
- Returns:
32
- The created normalization layer.
33
-
34
- Raises:
35
- NotImplementedError: If an unknown norm_type is provided.
36
- """
37
- norm_type = norm_type.lower() # Normalize to lowercase
38
-
39
- if norm_type == "layernorm":
40
- return nn.LayerNorm(dim, eps=eps, bias=False)
41
- elif norm_type == "np_layernorm":
42
- return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
43
- elif norm_type == "rmsnorm":
44
- return RMSNorm(dim, eps=eps, compile=False)
45
- elif norm_type == "compiled_rmsnorm":
46
- return RMSNorm(dim, eps=eps, compile=True)
47
- elif norm_type == "fused_rmsnorm":
48
- raise NotImplementedError("Fused RMSNorm is not supported yet.")
49
- else:
50
- raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
51
-
52
-
53
- class RMSNorm(nn.Module):
54
- """
55
- Initialize the RMSNorm normalization layer.
56
- Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
57
-
58
- Args:
59
- dim (int): The dimension of the input tensor.
60
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
61
- compile (bool, optional): Whether to compile the forward function. Default is False.
62
-
63
- Attributes:
64
- eps (float): A small value added to the denominator for numerical stability.
65
- weight (nn.Parameter): Learnable scaling parameter.
66
-
67
- """
68
-
69
- def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
70
- super().__init__()
71
- self.eps = eps
72
- self.weight = nn.Parameter(torch.ones(dim))
73
- self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm
74
-
75
- @staticmethod
76
- def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
77
- def _norm(x, eps):
78
- # Computes the root-mean-square norm of the input tensor.
79
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
80
-
81
- output = _norm(x.float(), eps).type_as(x)
82
- return output * weight
83
-
84
- def forward(self, x: torch.Tensor):
85
- return self.rmsnorm_fn(x, self.weight, self.eps)
86
-
87
- def reset_parameters(self):
88
- torch.nn.init.ones_(self.weight)
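
A quick sanity check for the deleted normalization blob above, assuming `create_norm` and `RMSNorm` from this blob are in scope. It verifies that the layer computes x * rsqrt(mean(x^2) + eps) scaled by the learnable weight, which is what `compute_rmsnorm` implements.

    import torch

    norm = create_norm("rmsnorm", dim=8)   # returns RMSNorm(dim=8, eps=1e-6, compile=False)
    x = torch.randn(2, 4, 8)
    y = norm(x)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
    assert torch.allclose(y, ref, atol=1e-6)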
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/68e9cbb58aa1a39cd62c15a01b3e6526a49b66b0 DELETED
@@ -1,728 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import importlib
18
- from contextlib import contextmanager
19
- from typing import List, NamedTuple, Optional, Tuple
20
-
21
- from .misc import misc
22
- import einops
23
- import imageio
24
- import numpy as np
25
- import torch
26
- import torchvision.transforms.functional as transforms_F
27
-
28
- from .df_model_model_t2w import DiffusionT2WModel
29
- from .df_model_model_v2w import DiffusionV2WModel
30
- from .log import log
31
- from .config_helper import get_config_module, override
32
- from .io import load_from_fileobj
33
-
34
- TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
35
- if TORCH_VERSION >= (1, 11):
36
- from torch.ao import quantization
37
- from torch.ao.quantization import FakeQuantizeBase, ObserverBase
38
- elif (
39
- TORCH_VERSION >= (1, 8)
40
- and hasattr(torch.quantization, "FakeQuantizeBase")
41
- and hasattr(torch.quantization, "ObserverBase")
42
- ):
43
- from torch import quantization
44
- from torch.quantization import FakeQuantizeBase, ObserverBase
45
-
46
- DEFAULT_AUGMENT_SIGMA = 0.001
47
-
48
-
49
- def add_common_arguments(parser):
50
- """Add common command line arguments for text2world and video2world generation.
51
-
52
- Args:
53
- parser (ArgumentParser): Argument parser to add arguments to
54
-
55
- The arguments include:
56
- - checkpoint_dir: Base directory containing model weights
57
- - tokenizer_dir: Directory containing tokenizer weights
58
- - video_save_name: Output video filename for single video generation
59
- - video_save_folder: Output directory for batch video generation
60
- - prompt: Text prompt for single video generation
61
- - batch_input_path: Path to JSONL file with input prompts for batch video generation
62
- - negative_prompt: Text prompt describing undesired attributes
63
- - num_steps: Number of diffusion sampling steps
64
- - guidance: Classifier-free guidance scale
65
- - num_video_frames: Number of frames to generate
66
- - height/width: Output video dimensions
67
- - fps: Output video frame rate
68
- - seed: Random seed for reproducibility
69
- - Various model offloading flags
70
- """
71
- parser.add_argument(
72
- "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
73
- )
74
- parser.add_argument(
75
- "--tokenizer_dir",
76
- type=str,
77
- default="Cosmos-1.0-Tokenizer-CV8x8x8",
78
- help="Tokenizer weights directory relative to checkpoint_dir",
79
- )
80
- parser.add_argument(
81
- "--video_save_name",
82
- type=str,
83
- default="output",
84
- help="Output filename for generating a single video",
85
- )
86
- parser.add_argument(
87
- "--video_save_folder",
88
- type=str,
89
- default="outputs/",
90
- help="Output folder for generating a batch of videos",
91
- )
92
- parser.add_argument(
93
- "--prompt",
94
- type=str,
95
- help="Text prompt for generating a single video",
96
- )
97
- parser.add_argument(
98
- "--batch_input_path",
99
- type=str,
100
- help="Path to a JSONL file of input prompts for generating a batch of videos",
101
- )
102
- parser.add_argument(
103
- "--negative_prompt",
104
- type=str,
105
- default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
106
- "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
107
- "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
108
- "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special "
109
- "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and "
110
- "flickering. Overall, the video is of poor quality.",
111
- help="Negative prompt for the video",
112
- )
113
- parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
114
- parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value")
115
- parser.add_argument(
116
- "--num_video_frames", type=int, default=121, choices=[121], help="Number of video frames to sample"
117
- )
118
- parser.add_argument("--height", type=int, default=704, help="Height of video to sample")
119
- parser.add_argument("--width", type=int, default=1280, help="Width of video to sample")
120
- parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
121
- parser.add_argument("--seed", type=int, default=1, help="Random seed")
122
- parser.add_argument(
123
- "--disable_prompt_upsampler",
124
- action="store_true",
125
- help="Disable prompt upsampling",
126
- )
127
- parser.add_argument(
128
- "--offload_diffusion_transformer",
129
- action="store_true",
130
- help="Offload DiT after inference",
131
- )
132
- parser.add_argument(
133
- "--offload_tokenizer",
134
- action="store_true",
135
- help="Offload tokenizer after inference",
136
- )
137
- parser.add_argument(
138
- "--offload_text_encoder_model",
139
- action="store_true",
140
- help="Offload text encoder model after inference",
141
- )
142
- parser.add_argument(
143
- "--offload_prompt_upsampler",
144
- action="store_true",
145
- help="Offload prompt upsampler after inference",
146
- )
147
- parser.add_argument(
148
- "--offload_guardrail_models",
149
- action="store_true",
150
- help="Offload guardrail models after inference",
151
- )
152
-
153
-
154
- def validate_args(args: argparse.Namespace, inference_type: str) -> None:
155
- """Validate command line arguments for text2world and video2world generation."""
156
- assert inference_type in [
157
- "text2world",
158
- "video2world",
159
- ], "Invalid inference_type, must be 'text2world' or 'video2world'"
160
-
161
- # Validate prompt/image/video args for single or batch generation
162
- if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler):
163
- assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided."
164
- if inference_type == "video2world" and not args.batch_input_path:
165
- assert (
166
- args.input_image_or_video_path
167
- ), "--input_image_or_video_path must be provided for single video generation."
168
-
169
-
170
- class _IncompatibleKeys(
171
- NamedTuple(
172
- "IncompatibleKeys",
173
- [
174
- ("missing_keys", List[str]),
175
- ("unexpected_keys", List[str]),
176
- ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
177
- ],
178
- )
179
- ):
180
- pass
181
-
182
-
183
- def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
184
- """Load a model checkpoint with non-strict matching, handling shape mismatches.
185
-
186
- Args:
187
- model (torch.nn.Module): Model to load weights into
188
- checkpoint_state_dict (dict): State dict from checkpoint
189
-
190
- Returns:
191
- _IncompatibleKeys: Named tuple containing:
192
- - missing_keys: Keys present in model but missing from checkpoint
193
- - unexpected_keys: Keys present in checkpoint but not in model
194
- - incorrect_shapes: Keys with mismatched tensor shapes
195
-
196
- The function handles special cases like:
197
- - Uninitialized parameters
198
- - Quantization observers
199
- - TransformerEngine FP8 states
200
- """
201
- # workaround https://github.com/pytorch/pytorch/issues/24139
202
- model_state_dict = model.state_dict()
203
- incorrect_shapes = []
204
- for k in list(checkpoint_state_dict.keys()):
205
- if k in model_state_dict:
206
- if "_extra_state" in k: # Key introduced by TransformerEngine for FP8
207
- log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
208
- continue
209
- model_param = model_state_dict[k]
210
- # Allow mismatch for uninitialized parameters
211
- if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
212
- continue
213
- if not isinstance(model_param, torch.Tensor):
214
- raise ValueError(
215
- f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not."
216
- )
217
-
218
- shape_model = tuple(model_param.shape)
219
- shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
220
- if shape_model != shape_checkpoint:
221
- has_observer_base_classes = (
222
- TORCH_VERSION >= (1, 8)
223
- and hasattr(quantization, "ObserverBase")
224
- and hasattr(quantization, "FakeQuantizeBase")
225
- )
226
- if has_observer_base_classes:
227
- # Handle the special case of quantization per channel observers,
228
- # where buffer shape mismatches are expected.
229
- def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
230
- # foo.bar.param_or_buffer_name -> [foo, bar]
231
- key_parts = key.split(".")[:-1]
232
- cur_module = model
233
- for key_part in key_parts:
234
- cur_module = getattr(cur_module, key_part)
235
- return cur_module
236
-
237
- cls_to_skip = (
238
- ObserverBase,
239
- FakeQuantizeBase,
240
- )
241
- target_module = _get_module_for_key(model, k)
242
- if isinstance(target_module, cls_to_skip):
243
- # Do not remove modules with expected shape mismatches
244
- # from the state_dict loading. They have special logic
245
- # in _load_from_state_dict to handle the mismatches.
246
- continue
247
-
248
- incorrect_shapes.append((k, shape_checkpoint, shape_model))
249
- checkpoint_state_dict.pop(k)
250
- incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
251
- # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
252
- missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
253
- unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
254
- return _IncompatibleKeys(
255
- missing_keys=missing_keys,
256
- unexpected_keys=unexpected_keys,
257
- incorrect_shapes=incorrect_shapes,
258
- )
259
-
260
-
261
- @contextmanager
262
- def skip_init_linear():
263
- # skip init of nn.Linear
264
- orig_reset_parameters = torch.nn.Linear.reset_parameters
265
- torch.nn.Linear.reset_parameters = lambda x: x
266
- xavier_uniform_ = torch.nn.init.xavier_uniform_
267
- torch.nn.init.xavier_uniform_ = lambda x: x
268
- yield
269
- torch.nn.Linear.reset_parameters = orig_reset_parameters
270
- torch.nn.init.xavier_uniform_ = xavier_uniform_
271
-
272
-
273
- def load_model_by_config(
274
- config_job_name,
275
- config_file="projects/cosmos_video/config/config.py",
276
- model_class=DiffusionT2WModel,
277
- ):
278
- config_module = get_config_module(config_file)
279
- config = importlib.import_module(config_module).make_config()
280
-
281
- config = override(config, ["--", f"experiment={config_job_name}"])
282
-
283
- # Check that the config is valid
284
- config.validate()
285
- # Freeze the config so developers don't change it during training.
286
- config.freeze() # type: ignore
287
-
288
- # Initialize model
289
- with skip_init_linear():
290
- model = model_class(config.model)
291
- return model
292
-
293
-
294
- def load_network_model(model: DiffusionT2WModel, ckpt_path: str):
295
- with skip_init_linear():
296
- model.set_up_model()
297
- net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
298
- log.debug(non_strict_load_model(model.model, net_state_dict))
299
- model.cuda()
300
-
301
-
302
- def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str):
303
- with skip_init_linear():
304
- model.set_up_tokenizer(tokenizer_dir)
305
- model.cuda()
306
-
307
-
308
- def prepare_data_batch(
309
- height: int,
310
- width: int,
311
- num_frames: int,
312
- fps: int,
313
- prompt_embedding: torch.Tensor,
314
- negative_prompt_embedding: Optional[torch.Tensor] = None,
315
- ):
316
- """Prepare input batch tensors for video generation.
317
-
318
- Args:
319
- height (int): Height of video frames
320
- width (int): Width of video frames
321
- num_frames (int): Number of frames to generate
322
- fps (int): Frames per second
323
- prompt_embedding (torch.Tensor): Encoded text prompt embeddings
324
- negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings
325
-
326
- Returns:
327
- dict: Batch dictionary containing:
328
- - video: Zero tensor of target video shape
329
- - t5_text_mask: Attention mask for text embeddings
330
- - image_size: Target frame dimensions
331
- - fps: Target frame rate
332
- - num_frames: Number of frames
333
- - padding_mask: Frame padding mask
334
- - t5_text_embeddings: Prompt embeddings
335
- - neg_t5_text_embeddings: Negative prompt embeddings (if provided)
336
- - neg_t5_text_mask: Mask for negative embeddings (if provided)
337
- """
338
- # Create base data batch
339
- data_batch = {
340
- "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(),
341
- "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(),
342
- "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(),
343
- "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(),
344
- "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(),
345
- "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(),
346
- }
347
-
348
- # Handle text embeddings
349
-
350
- t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda()
351
- data_batch["t5_text_embeddings"] = t5_embed
352
-
353
- if negative_prompt_embedding is not None:
354
- neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda()
355
- data_batch["neg_t5_text_embeddings"] = neg_t5_embed
356
- data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda()
357
-
358
- return data_batch
359
-
360
-
361
- def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames):
362
- """Prepare complete input batch for video generation including latent dimensions.
363
-
364
- Args:
365
- model: Diffusion model instance
366
- prompt_embedding (torch.Tensor): Text prompt embeddings
367
- negative_prompt_embedding (torch.Tensor): Negative prompt embeddings
368
- height (int): Output video height
369
- width (int): Output video width
370
- fps (int): Output video frame rate
371
- num_video_frames (int): Number of frames to generate
372
-
373
- Returns:
374
- tuple:
375
- - data_batch (dict): Complete model input batch
376
- - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression
377
- """
378
- raw_video_batch = prepare_data_batch(
379
- height=height,
380
- width=width,
381
- num_frames=num_video_frames,
382
- fps=fps,
383
- prompt_embedding=prompt_embedding,
384
- negative_prompt_embedding=negative_prompt_embedding,
385
- )
386
- state_shape = [
387
- model.tokenizer.channel,
388
- model.tokenizer.get_latent_num_frames(num_video_frames),
389
- height // model.tokenizer.spatial_compression_factor,
390
- width // model.tokenizer.spatial_compression_factor,
391
- ]
392
- return raw_video_batch, state_shape
393
-
394
-
395
- def generate_world_from_text(
396
- model: DiffusionT2WModel,
397
- state_shape: list[int],
398
- is_negative_prompt: bool,
399
- data_batch: dict,
400
- guidance: float,
401
- num_steps: int,
402
- seed: int,
403
- ):
404
- """Generate video from text prompt using diffusion model.
405
-
406
- Args:
407
- model (DiffusionT2WModel): Text-to-video diffusion model
408
- state_shape (list[int]): Latent state dimensions [C,T,H,W]
409
- is_negative_prompt (bool): Whether negative prompt is provided
410
- data_batch (dict): Model input batch with embeddings
411
- guidance (float): Classifier-free guidance scale
412
- num_steps (int): Number of diffusion sampling steps
413
- seed (int): Random seed for reproducibility
414
-
415
- Returns:
416
- np.ndarray: Generated video frames [T,H,W,C], range [0,255]
417
-
418
- The function:
419
- 1. Initializes random latent with maximum noise
420
- 2. Performs guided diffusion sampling
421
- 3. Decodes latents to pixel space
422
- """
423
- x_sigma_max = (
424
- misc.arch_invariant_rand(
425
- (1,) + tuple(state_shape),
426
- torch.float32,
427
- model.tensor_kwargs["device"],
428
- seed,
429
- )
430
- * model.sde.sigma_max
431
- )
432
-
433
- # Generate video
434
- sample = model.generate_samples_from_batch(
435
- data_batch,
436
- guidance=guidance,
437
- state_shape=state_shape,
438
- num_steps=num_steps,
439
- is_negative_prompt=is_negative_prompt,
440
- seed=seed,
441
- x_sigma_max=x_sigma_max,
442
- )
443
-
444
- return sample
445
-
446
-
447
- def generate_world_from_video(
448
- model: DiffusionV2WModel,
449
- state_shape: list[int],
450
- is_negative_prompt: bool,
451
- data_batch: dict,
452
- guidance: float,
453
- num_steps: int,
454
- seed: int,
455
- condition_latent: torch.Tensor,
456
- num_input_frames: int,
457
- ) -> Tuple[np.array, list, list]:
458
- """Generate video using a conditioning video/image input.
459
-
460
- Args:
461
- model (DiffusionV2WModel): The diffusion model instance
462
- state_shape (list[int]): Shape of the latent state [C,T,H,W]
463
- is_negative_prompt (bool): Whether negative prompt is provided
464
- data_batch (dict): Batch containing model inputs including text embeddings
465
- guidance (float): Classifier-free guidance scale for sampling
466
- num_steps (int): Number of diffusion sampling steps
467
- seed (int): Random seed for generation
468
- condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
469
- num_input_frames (int): Number of input frames
470
-
471
- Returns:
472
- np.array: Generated video frames in shape [T,H,W,C], range [0,255]
473
- """
474
- assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported"
475
- augment_sigma = DEFAULT_AUGMENT_SIGMA
476
-
477
- if condition_latent.shape[2] < state_shape[1]:
478
- # Padding condition latent to state shape
479
- b, c, t, h, w = condition_latent.shape
480
- condition_latent = torch.cat(
481
- [
482
- condition_latent,
483
- condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
484
- ],
485
- dim=2,
486
- ).contiguous()
487
- num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)
488
-
489
- x_sigma_max = (
490
- misc.arch_invariant_rand(
491
- (1,) + tuple(state_shape),
492
- torch.float32,
493
- model.tensor_kwargs["device"],
494
- seed,
495
- )
496
- * model.sde.sigma_max
497
- )
498
-
499
- sample = model.generate_samples_from_batch(
500
- data_batch,
501
- guidance=guidance,
502
- state_shape=state_shape,
503
- num_steps=num_steps,
504
- is_negative_prompt=is_negative_prompt,
505
- seed=seed,
506
- condition_latent=condition_latent,
507
- num_condition_t=num_of_latent_condition,
508
- condition_video_augment_sigma_in_inference=augment_sigma,
509
- x_sigma_max=x_sigma_max,
510
- )
511
- return sample
512
-
513
-
514
- def read_video_or_image_into_frames_BCTHW(
515
- input_path: str,
516
- input_path_format: str = "mp4",
517
- H: int = None,
518
- W: int = None,
519
- normalize: bool = True,
520
- max_frames: int = -1,
521
- also_return_fps: bool = False,
522
- ) -> torch.Tensor:
523
- """Read video or image file and convert to tensor format.
524
-
525
- Args:
526
- input_path (str): Path to input video/image file
527
- input_path_format (str): Format of input file (default: "mp4")
528
- H (int, optional): Height to resize frames to
529
- W (int, optional): Width to resize frames to
530
- normalize (bool): Whether to normalize pixel values to [-1,1] (default: True)
531
- max_frames (int): Maximum number of frames to read (-1 for all frames)
532
- also_return_fps (bool): Whether to return fps along with frames
533
-
534
- Returns:
535
- torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested
536
- """
537
- log.debug(f"Reading video from {input_path}")
538
-
539
- loaded_data = load_from_fileobj(input_path, format=input_path_format)
540
- frames, meta_data = loaded_data
541
- if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
542
- frames = np.array(frames[0]) # HWC, [0,255]
543
- if frames.shape[-1] > 3: # RGBA, set the transparent to white
544
- # Separate the RGB and Alpha channels
545
- rgb_channels = frames[..., :3]
546
- alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1]
547
-
548
- # Create a white background
549
- white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB
550
-
551
- # Blend the RGB channels with the white background based on the alpha channel
552
- frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
553
- np.uint8
554
- )
555
- frames = [frames]
556
- fps = 0
557
- else:
558
- fps = int(meta_data.get("fps"))
559
- if max_frames != -1:
560
- frames = frames[:max_frames]
561
- input_tensor = np.stack(frames, axis=0)
562
- input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
563
- if normalize:
564
- input_tensor = input_tensor / 128.0 - 1.0
565
- input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW
566
- log.debug(f"Raw data shape: {input_tensor.shape}")
567
- if H is not None and W is not None:
568
- input_tensor = transforms_F.resize(
569
- input_tensor,
570
- size=(H, W), # type: ignore
571
- interpolation=transforms_F.InterpolationMode.BICUBIC,
572
- antialias=True,
573
- )
574
- input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
575
- if normalize:
576
- input_tensor = input_tensor.to("cuda")
577
- log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
578
- if also_return_fps:
579
- return input_tensor, fps
580
- return input_tensor
581
-
582
-
583
- def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int:
584
- """This function computes the number of latent frames given the number of input frames.
585
- Args:
586
- model (DiffusionV2WModel): video generation model
587
- num_input_frames (int): number of input frames
588
- downsample_factor (int): downsample factor for temporal reduce
589
- Returns:
590
- int: number of latent frames
591
- """
592
- num_latent_frames = (
593
- num_input_frames
594
- // model.tokenizer.video_vae.pixel_chunk_duration
595
- * model.tokenizer.video_vae.latent_chunk_duration
596
- )
597
- if num_input_frames % model.tokenizer.video_vae.latent_chunk_duration == 1:
598
- num_latent_frames += 1
599
- elif num_input_frames % model.tokenizer.video_vae.latent_chunk_duration > 1:
600
- assert (
601
- num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1
602
- ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}"
603
- num_latent_frames += (
604
- 1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor
605
- )
606
-
607
- return num_latent_frames
608
-
609
-
610
- def create_condition_latent_from_input_frames(
611
- model: DiffusionV2WModel,
612
- input_frames: torch.Tensor,
613
- num_frames_condition: int = 25,
614
- ):
615
- """Create condition latent for video generation from input frames.
616
-
617
- Takes the last num_frames_condition frames from input as conditioning.
618
-
619
- Args:
620
- model (DiffusionV2WModel): Video generation model
621
- input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1]
622
- num_frames_condition (int): Number of frames to use for conditioning
623
-
624
- Returns:
625
- tuple: (condition_latent, encode_input_frames) where:
626
- - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
627
- - encode_input_frames (torch.Tensor): Padded input frames used for encoding
628
- """
629
- B, C, T, H, W = input_frames.shape
630
- num_frames_encode = (
631
- model.tokenizer.pixel_chunk_duration
632
- ) # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1
633
- log.debug(
634
- f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
635
- )
636
-
637
- log.debug(
638
- f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
639
- )
640
-
641
- assert (
642
- input_frames.shape[2] >= num_frames_condition
643
- ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}"
644
- assert (
645
- num_frames_encode >= num_frames_condition
646
- ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}"
647
-
648
- # Put the conditional frames at the beginning of the video, and pad the end with zeros
649
- condition_frames = input_frames[:, :, -num_frames_condition:]
650
- padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
651
- encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)
652
-
653
- log.debug(
654
- f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
655
- )
656
- latent = model.encode(encode_input_frames)
657
- return latent, encode_input_frames
658
-
659
-
660
- def get_condition_latent(
661
- model: DiffusionV2WModel,
662
- input_image_or_video_path: str,
663
- num_input_frames: int = 1,
664
- state_shape: list[int] = None,
665
- ):
666
- """Get condition latent from input image/video file.
667
-
668
- Args:
669
- model (DiffusionV2WModel): Video generation model
670
- input_image_or_video_path (str): Path to conditioning image/video
671
- num_input_frames (int): Number of input frames for video2world prediction
672
-
673
- Returns:
674
- tuple: (condition_latent, input_frames) where:
675
- - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
676
- - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W]
677
- """
678
- if state_shape is None:
679
- state_shape = model.state_shape
680
- assert num_input_frames > 0, "num_input_frames must be greater than 0"
681
-
682
- H, W = (
683
- state_shape[-2] * model.tokenizer.spatial_compression_factor,
684
- state_shape[-1] * model.tokenizer.spatial_compression_factor,
685
- )
686
-
687
- input_path_format = input_image_or_video_path.split(".")[-1]
688
- input_frames = read_video_or_image_into_frames_BCTHW(
689
- input_image_or_video_path,
690
- input_path_format=input_path_format,
691
- H=H,
692
- W=W,
693
- )
694
-
695
- condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
696
- condition_latent = condition_latent.to(torch.bfloat16)
697
-
698
- return condition_latent
699
-
700
-
701
- def check_input_frames(input_path: str, required_frames: int) -> bool:
702
- """Check if input video/image has sufficient frames.
703
-
704
- Args:
705
- input_path: Path to input video or image
706
- required_frames: Number of required frames
707
-
708
- Returns:
709
- True if the input has the required number of frames, False otherwise
710
- """
711
- if input_path.endswith((".jpg", ".jpeg", ".png")):
712
- if required_frames > 1:
713
- log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
714
- return False
715
- return True # Let the pipeline handle image loading
716
- # For video input
717
- try:
718
- vid = imageio.get_reader(input_path, "ffmpeg")
719
- frame_count = vid.count_frames()
720
-
721
- if frame_count < required_frames:
722
- log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
723
- return False
724
- else:
725
- return True
726
- except Exception as e:
727
- log.error(f"Error reading video file {input_path}: {e}")
728
- return False
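
A hedged sketch of how the deleted inference helpers above fit together on the CLI path. The function names (`add_common_arguments`, `validate_args`) come from the blob itself; the prompt string and step count are illustrative only.

    import argparse

    parser = argparse.ArgumentParser()
    add_common_arguments(parser)          # registers --prompt, --num_steps, --guidance, offload flags, ...
    args = parser.parse_args(["--prompt", "a drone shot over a canyon", "--num_steps", "20"])
    validate_args(args, inference_type="text2world")   # passes: a prompt is provided
    print(args.guidance, args.num_video_frames)        # 7.0  121 (the defaults registered above)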
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/69f477ced9dfe59deda742bc507addf7d7268bdf DELETED
@@ -1,223 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import collections
19
- import collections.abc
20
- import ctypes
21
- import functools
22
- import os
23
- from datetime import timedelta
24
- from typing import Any, Callable, Optional
25
-
26
- import pynvml
27
- import torch
28
- import torch.distributed as dist
29
-
30
- from .log import log
31
- from .device import Device
32
-
33
-
34
- def init() -> int | None:
35
- """Initialize distributed training."""
36
- # Set GPU affinity.
37
- pynvml.nvmlInit()
38
- local_rank = int(os.getenv("LOCAL_RANK", 0))
39
- device = Device(local_rank)
40
- os.sched_setaffinity(0, device.get_cpu_affinity())
41
- # Set up NCCL communication.
42
- os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
43
- os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
44
- if dist.is_available():
45
- if dist.is_initialized():
46
- return torch.cuda.current_device()
47
- torch.cuda.set_device(local_rank)
48
- # Get the timeout value from environment variable
49
- timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
50
- # Convert the timeout to an integer (if it isn't already) and then to a timedelta
51
- timeout_timedelta = timedelta(seconds=int(timeout_seconds))
52
- dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
53
- log.critical(
54
- f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
55
- rank0_only=False,
56
- )
57
- # Increase the L2 fetch granularity for faster speed.
58
- _libcudart = ctypes.CDLL("libcudart.so")
59
- # Set device limit on the current device.
60
- p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
61
- _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
62
- _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
63
- log.info(f"Training with {get_world_size()} GPUs.")
64
-
65
-
66
- def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
67
- """Get the rank (GPU device) of the worker.
68
-
69
- Returns:
70
- rank (int): The rank of the worker.
71
- """
72
- rank = 0
73
- if dist.is_available() and dist.is_initialized():
74
- rank = dist.get_rank(group)
75
- return rank
76
-
77
-
78
- def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
79
- """Get world size. How many GPUs are available in this job.
80
-
81
- Returns:
82
- world_size (int): The total number of GPUs available in this job.
83
- """
84
- world_size = 1
85
- if dist.is_available() and dist.is_initialized():
86
- world_size = dist.get_world_size(group)
87
- return world_size
88
-
89
-
90
- def is_rank0() -> bool:
91
- """Check if current process is the master GPU.
92
-
93
- Returns:
94
- (bool): True if this function is called from the master GPU, else False.
95
- """
96
- return get_rank() == 0
97
-
98
-
99
- def rank0_only(func: Callable) -> Callable:
100
- """Apply this function only to the master GPU.
101
-
102
- Example usage:
103
- @rank0_only
104
- def func(x):
105
- return x + 3
106
-
107
- Args:
108
- func (Callable): a function.
109
-
110
- Returns:
111
- (Callable): A function wrapper executing the function only on the master GPU.
112
- """
113
-
114
- @functools.wraps(func)
115
- def wrapper(*args, **kwargs): # noqa: ANN202
116
- if is_rank0():
117
- return func(*args, **kwargs)
118
- else:
119
- return None
120
-
121
- return wrapper
122
-
123
-
124
- def barrier() -> None:
125
- """Barrier for all GPUs."""
126
- if dist.is_available() and dist.is_initialized():
127
- dist.barrier()
128
-
129
-
130
- class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
131
- """This extends torch.nn.parallel.DistributedDataParallel with .training_step().
132
-
133
- This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that
134
- model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
135
- model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
136
- training_step), allowing us to preserve the function names and signatures.
137
- """
138
-
139
- def __init__(self, model: torch.nn.Module, *args, **kwargs):
140
- super().__init__(model, *args, **kwargs)
141
-
142
- def training_step(self, *args, **kwargs) -> Any:
143
- # Cache the original model.forward() method.
144
- original_forward = self.module.forward
145
-
146
- def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202
147
- # Unpatch immediately before calling training_step() because itself may want to call the real forward.
148
- self.module.forward = original_forward
149
- # The actual .training_step().
150
- return self.module.training_step(*_args, **_kwargs)
151
-
152
- # Patch the original_module's forward so we can redirect the arguments back to the real method.
153
- self.module.forward = wrapped_training_step
154
- # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
155
- # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed.
156
- return self(*args, **kwargs)
157
-
158
-
159
- def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
160
- """Aggregate the list of data batches from all devices and process the results.
161
-
162
- This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler.
163
- It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
164
- in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
165
- created before calling dis.all_gather().
166
-
167
- Args:
168
- data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
169
- leaf entries are tensors.
170
-
171
- Returns:
172
- data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
173
- leaf entries are concatenated tensors.
174
- """
175
- if isinstance(data_batches[0], torch.Tensor):
176
- # Concatenate the local data batches.
177
- data_concat = torch.cat(data_batches, dim=0) # type: ignore
178
- # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
179
- max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
180
- dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
181
- if len(data_concat) < max_num_local_samples:
182
- assert len(data_concat) + 1 == max_num_local_samples
183
- dummy = torch.empty_like(data_concat[:1])
184
- data_concat = torch.cat([data_concat, dummy], dim=0)
185
- dummy_count = torch.tensor(1, device="cuda")
186
- else:
187
- dummy_count = torch.tensor(0, device="cuda")
188
- # Get all concatenated batches from all ranks and concatenate again.
189
- dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
190
- data_concat = all_gather_tensor(data_concat.contiguous())
191
- data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
192
- # Remove the dummy samples.
193
- if dummy_count > 0:
194
- data_collate = data_collate[:-dummy_count]
195
- elif isinstance(data_batches[0], collections.abc.Mapping):
196
- data_collate = dict()
197
- for key in data_batches[0].keys():
198
- data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore
199
- else:
200
- raise TypeError
201
- return data_collate
202
-
203
-
204
- @torch.no_grad()
205
- def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
206
- """Gather the corresponding tensor from all GPU devices to a list.
207
-
208
- Args:
209
- tensor (torch.Tensor): Pytorch tensor.
210
-
211
- Returns:
212
- tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices.
213
- """
214
- tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
215
- dist.all_gather(tensor_list, tensor)
216
- return tensor_list
217
-
218
-
219
- def broadcast(tensor, src, group=None, async_op=False):
220
- world_size = get_world_size()
221
- if world_size < 2:
222
- return tensor
223
- dist.broadcast(tensor, src=src, group=group, async_op=async_op)
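
A minimal sketch exercising the deleted distributed helpers above in a single-process run, where torch.distributed is not initialized and the fallbacks apply (the helper names come from the blob; the checkpoint path is illustrative).

    @rank0_only
    def save_checkpoint(path):
        print(f"saving to {path}")

    assert get_world_size() == 1     # dist not initialized -> world size falls back to 1
    assert is_rank0()                # rank falls back to 0
    save_checkpoint("out/ckpt.pt")   # executes, since this process is treated as rank 0
    barrier()                        # no-op when no process group is initialized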
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/6bb055d8b2ddd78f626f08bb78f9434de5aef511 DELETED
@@ -1,276 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import ast
17
- import builtins
18
- import collections.abc as abc
19
- import importlib
20
- import inspect
21
- import os
22
- import uuid
23
- from collections import OrderedDict
24
- from contextlib import contextmanager
25
- from dataclasses import is_dataclass
26
- from typing import Any, Dict, List, Tuple, Union
27
-
28
- import attrs
29
- import yaml
30
- from omegaconf import DictConfig, ListConfig, OmegaConf
31
-
32
- from .lazy_file_io import PathManager
33
- from .lazy_registry import _convert_target_to_string
34
-
35
- __all__ = ["LazyCall", "LazyConfig"]
36
-
37
-
38
- def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]:
39
- return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
40
-
41
-
42
- def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
43
- return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
44
-
45
-
46
- def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]:
47
- if isinstance(obj, dict):
48
- return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
49
- elif isinstance(obj, list):
50
- return [sort_recursive(item) for item in obj]
51
- return obj
52
-
53
-
54
- yaml.add_representer(OrderedDict, dict_representer)
55
-
56
-
57
- def get_default_params(cls_or_func):
58
- if callable(cls_or_func):
59
- # inspect signature for function
60
- signature = inspect.signature(cls_or_func)
61
- else:
62
- # inspect signature for class
63
- signature = inspect.signature(cls_or_func.__init__)
64
- params = signature.parameters
65
- default_params = {
66
- name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
67
- }
68
- return default_params
69
-
70
-
71
- class LazyCall:
72
- """
73
- Wrap a callable so that when it's called, the call will not be executed,
74
- but returns a dict that describes the call.
75
-
76
- LazyCall object has to be called with only keyword arguments. Positional
77
- arguments are not yet supported.
78
-
79
- Examples:
80
- ::
81
- # from detectron2.config import instantiate, LazyCall
82
-
83
- layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
84
- layer_cfg.out_channels = 64 # can edit it afterwards
85
- layer = instantiate(layer_cfg)
86
- """
87
-
88
- def __init__(self, target):
89
- if not (callable(target) or isinstance(target, (str, abc.Mapping))):
90
- raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
91
- self._target = target
92
-
93
- def __call__(self, **kwargs):
94
- if is_dataclass(self._target) or attrs.has(self._target):
95
- # omegaconf object cannot hold dataclass type
96
- # https://github.com/omry/omegaconf/issues/784
97
- target = _convert_target_to_string(self._target)
98
- else:
99
- target = self._target
100
- kwargs["_target_"] = target
101
-
102
- _final_params = get_default_params(self._target)
103
- _final_params.update(kwargs)
104
-
105
- return DictConfig(content=_final_params, flags={"allow_objects": True})
106
-
107
-
108
- def _visit_dict_config(cfg, func):
109
- """
110
- Apply func recursively to all DictConfig in cfg.
111
- """
112
- if isinstance(cfg, DictConfig):
113
- func(cfg)
114
- for v in cfg.values():
115
- _visit_dict_config(v, func)
116
- elif isinstance(cfg, ListConfig):
117
- for v in cfg:
118
- _visit_dict_config(v, func)
119
-
120
-
121
- def _validate_py_syntax(filename):
122
- # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
123
- with PathManager.open(filename, "r") as f:
124
- content = f.read()
125
- try:
126
- ast.parse(content)
127
- except SyntaxError as e:
128
- raise SyntaxError(f"Config file {filename} has syntax error!") from e
129
-
130
-
131
- def _cast_to_config(obj):
132
- # if given a dict, return DictConfig instead
133
- if isinstance(obj, dict):
134
- return DictConfig(obj, flags={"allow_objects": True})
135
- return obj
136
-
137
-
138
- _CFG_PACKAGE_NAME = "detectron2._cfg_loader"
139
- """
140
- A namespace to put all imported config into.
141
- """
142
-
143
-
144
- def _random_package_name(filename):
145
- # generate a random package name when loading config files
146
- return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
147
-
148
-
149
- @contextmanager
150
- def _patch_import():
151
- """
152
- Enhance relative import statements in config files, so that they:
153
- 1. locate files purely based on relative location, regardless of packages.
154
- e.g. you can import file without having __init__
155
- 2. do not cache modules globally; modifications of module states has no side effect
156
- 3. support other storage system through PathManager, so config files can be in the cloud
157
- 4. imported dict are turned into omegaconf.DictConfig automatically
158
- """
159
- old_import = builtins.__import__
160
-
161
- def find_relative_file(original_file, relative_import_path, level):
162
- # NOTE: "from . import x" is not handled. Because then it's unclear
163
- # if such import should produce `x` as a python module or DictConfig.
164
- # This can be discussed further if needed.
165
- relative_import_err = """
166
- Relative import of directories is not allowed within config files.
167
- Within a config file, relative import can only import other config files.
168
- """.replace(
169
- "\n", " "
170
- )
171
- if not len(relative_import_path):
172
- raise ImportError(relative_import_err)
173
-
174
- cur_file = os.path.dirname(original_file)
175
- for _ in range(level - 1):
176
- cur_file = os.path.dirname(cur_file)
177
- cur_name = relative_import_path.lstrip(".")
178
- for part in cur_name.split("."):
179
- cur_file = os.path.join(cur_file, part)
180
- if not cur_file.endswith(".py"):
181
- cur_file += ".py"
182
- if not PathManager.isfile(cur_file):
183
- cur_file_no_suffix = cur_file[: -len(".py")]
184
- if PathManager.isdir(cur_file_no_suffix):
185
- raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
186
- else:
187
- raise ImportError(
188
- f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist."
189
- )
190
- return cur_file
191
-
192
- def new_import(name, globals=None, locals=None, fromlist=(), level=0):
193
- if (
194
- # Only deal with relative imports inside config files
195
- level != 0
196
- and globals is not None
197
- and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
198
- ):
199
- cur_file = find_relative_file(globals["__file__"], name, level)
200
- _validate_py_syntax(cur_file)
201
- spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
202
- module = importlib.util.module_from_spec(spec)
203
- module.__file__ = cur_file
204
- with PathManager.open(cur_file) as f:
205
- content = f.read()
206
- exec(compile(content, cur_file, "exec"), module.__dict__)
207
- for name in fromlist: # turn imported dict into DictConfig automatically
208
- val = _cast_to_config(module.__dict__[name])
209
- module.__dict__[name] = val
210
- return module
211
- return old_import(name, globals, locals, fromlist=fromlist, level=level)
212
-
213
- builtins.__import__ = new_import
214
- yield new_import
215
- builtins.__import__ = old_import
216
-
217
-
218
- class LazyConfig:
219
- """
220
- Provide methods to save, load, and overrides an omegaconf config object
221
- which may contain definition of lazily-constructed objects.
222
- """
223
-
224
- @staticmethod
225
- def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
226
- """
227
- Load a config file.
228
-
229
- Args:
230
- filename: absolute path or relative path w.r.t. the current working directory
231
- keys: keys to load and return. If not given, return all keys
232
- (whose values are config objects) in a dict.
233
- """
234
- has_keys = keys is not None
235
- filename = filename.replace("/./", "/") # redundant
236
- if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
237
- raise ValueError(f"Config file {filename} has to be a python or yaml file.")
238
- if filename.endswith(".py"):
239
- _validate_py_syntax(filename)
240
-
241
- with _patch_import():
242
- # Record the filename
243
- module_namespace = {
244
- "__file__": filename,
245
- "__package__": _random_package_name(filename),
246
- }
247
- with PathManager.open(filename) as f:
248
- content = f.read()
249
- # Compile first with filename to:
250
- # 1. make filename appears in stacktrace
251
- # 2. make load_rel able to find its parent's (possibly remote) location
252
- exec(compile(content, filename, "exec"), module_namespace)
253
-
254
- ret = module_namespace
255
- else:
256
- with PathManager.open(filename) as f:
257
- obj = yaml.unsafe_load(f)
258
- ret = OmegaConf.create(obj, flags={"allow_objects": True})
259
-
260
- if has_keys:
261
- if isinstance(keys, str):
262
- return _cast_to_config(ret[keys])
263
- else:
264
- return tuple(_cast_to_config(ret[a]) for a in keys)
265
- else:
266
- if filename.endswith(".py"):
267
- # when not specified, only load those that are config objects
268
- ret = DictConfig(
269
- {
270
- name: _cast_to_config(value)
271
- for name, value in ret.items()
272
- if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
273
- },
274
- flags={"allow_objects": True},
275
- )
276
- return ret
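
For context, a small sketch of the LazyCall pattern documented in the blob above. It only builds the lazy DictConfig; `instantiate` is not defined in this blob and is assumed to live in a sibling module, so it is not called here.

    import torch.nn as nn

    layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32, kernel_size=3)
    layer_cfg.out_channels = 64                  # the returned DictConfig can still be edited afterwards
    print(layer_cfg["_target_"])                 # the wrapped callable, recorded for later instantiation
    print(layer_cfg.stride, layer_cfg.padding)   # defaults captured from the Conv2d signature: 1 0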
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/73755631ed6b97ebf773b3941fc0f6d1621761f7 DELETED
@@ -1,231 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from dataclasses import dataclass
17
- from typing import Callable, Dict, Optional, Tuple
18
-
19
- import torch
20
- from torch import Tensor
21
-
22
- from .df_conditioner import BaseVideoCondition
23
- from .df_df_functional_batch_ops import batch_mul
24
- from .df_df_module_res_sampler import COMMON_SOLVER_OPTIONS
25
- from .df_model_model_t2w import DiffusionT2WModel as VideoDiffusionModel
26
- from .lazy_config_init import instantiate as lazy_instantiate
27
-
28
-
29
- @dataclass
30
- class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
31
- # latent_condition will concat to the input of network, along channel dim;
32
- # cfg will make latent_condition all zero padding.
33
- latent_condition: Optional[torch.Tensor] = None
34
- latent_condition_sigma: Optional[torch.Tensor] = None
35
-
36
-
37
- class LatentDiffusionDecoderModel(VideoDiffusionModel):
38
- def __init__(self, config):
39
- super().__init__(config)
40
- """
41
- latent_corruptor: the corruption module is used to corrupt the latents. It adds Gaussian noise to the latents.
42
- pixel_corruptor: the corruption module is used to corrupt the pixels. It applies a Gaussian blur kernel to pixels in a temporally consistent way.
43
- tokenizer_corruptor: the corruption module is used to simulate tokenizer reconstruction errors.
44
-
45
- diffusion decoder noise augmentation pipeline for continuous token condition model:
46
- condition: GT_video [T, H, W]
47
- -> tokenizer_corruptor~(8x8x8) encode -> latent_corruptor -> tokenizer_corruptor~(8x8x8) decode
48
- -> pixel corruptor
49
- -> tokenizer~(1x8x8) encode -> condition [T, H/8, W/8]
50
- GT: GT_video [T, H, W] -> tokenizer~(1x8x8) -> x_t [T, H/8, W/8].
51
-
52
- diffusion decoder noise augmentation pipeline for discrete token condition model:
53
- condition: GT_video [T, H, W]
54
- -> pixel corruptor
55
- -> discrete tokenizer encode -> condition [T, T/8, H/16, W/16]
56
- GT: GT_video [T, H, W] -> tokenizer~(8x8x8) -> x_t [T, T/8, H/8, W/8].
57
-
58
- """
59
- self.latent_corruptor = lazy_instantiate(config.latent_corruptor)
60
- self.pixel_corruptor = lazy_instantiate(config.pixel_corruptor)
61
- self.tokenizer_corruptor = lazy_instantiate(config.tokenizer_corruptor)
62
-
63
- if self.latent_corruptor:
64
- self.latent_corruptor.to(**self.tensor_kwargs)
65
- if self.pixel_corruptor:
66
- self.pixel_corruptor.to(**self.tensor_kwargs)
67
-
68
- if self.tokenizer_corruptor:
69
- if hasattr(self.tokenizer_corruptor, "reset_dtype"):
70
- self.tokenizer_corruptor.reset_dtype()
71
- else:
72
- assert self.pixel_corruptor is not None
73
-
74
- self.diffusion_decoder_cond_sigma_low = config.diffusion_decoder_cond_sigma_low
75
- self.diffusion_decoder_cond_sigma_high = config.diffusion_decoder_cond_sigma_high
76
- self.diffusion_decoder_corrupt_prob = config.diffusion_decoder_corrupt_prob
77
- if hasattr(config, "condition_on_tokenizer_corruptor_token"):
78
- self.condition_on_tokenizer_corruptor_token = config.condition_on_tokenizer_corruptor_token
79
- else:
80
- self.condition_on_tokenizer_corruptor_token = False
81
-
82
- def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool:
83
- """We handle two types of data_batch. One comes from a joint_dataloader, where "dataset_name" can be used to differentiate image_batch and video_batch.
84
- The other comes from a dataloader which we by default assume to be video data for video model training.
85
- """
86
- is_image = self.input_image_key in data_batch
87
- is_video = self.input_data_key in data_batch
88
- assert (
89
- is_image != is_video
90
- ), "Only one of the input_image_key or input_data_key should be present in the data_batch."
91
- return is_image
92
-
93
- def get_x0_fn_from_batch(
94
- self,
95
- data_batch: Dict,
96
- guidance: float = 1.5,
97
- is_negative_prompt: bool = False,
98
- apply_corruptor: bool = True,
99
- corrupt_sigma: float = 1.5,
100
- preencode_condition: bool = False,
101
- ) -> Callable:
102
- """
103
- Generates a callable function `x0_fn` based on the provided data batch and guidance factor.
104
-
105
- This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states.
106
-
107
- Args:
108
- - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner`
109
- - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5.
110
- - is_negative_prompt (bool): if True, use the negative-prompt T5 embedding for the unconditional branch
111
-
112
- Returns:
113
- - Tuple[Callable, torch.Tensor]: a function `x0_fn(noise_x, sigma)` that returns the x0 prediction, together with the corrupted pixel tensor used to form the latent condition
114
-
115
- The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence.
116
- """
117
- input_key = self.input_data_key # by default it is video key
118
- # Latent state
119
- raw_state = data_batch[input_key]
120
-
121
- if self.condition_on_tokenizer_corruptor_token:
122
- if preencode_condition:
123
- latent_condition = raw_state.to(torch.int32).contiguous()
124
- corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition[:, 0])
125
- else:
126
- corrupted_pixel = (
127
- self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state
128
- )
129
- latent_condition = self.tokenizer_corruptor.encode(corrupted_pixel)
130
- latent_condition = latent_condition[1] if isinstance(latent_condition, tuple) else latent_condition
131
- corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition)
132
- latent_condition = latent_condition.unsqueeze(1)
133
- else:
134
- if preencode_condition:
135
- latent_condition = raw_state
136
- corrupted_pixel = self.decode(latent_condition)
137
- else:
138
- corrupted_pixel = (
139
- self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state
140
- )
141
- latent_condition = self.encode(corrupted_pixel).contiguous()
142
-
143
- sigma = (
144
- torch.rand((latent_condition.shape[0],)).to(**self.tensor_kwargs) * corrupt_sigma
145
- ) # small value to indicate clean video
146
- _, _, _, c_noise_cond = self.scaling(sigma=sigma)
147
- if corrupt_sigma != self.diffusion_decoder_cond_sigma_low and self.diffusion_decoder_corrupt_prob > 0:
148
- noise = batch_mul(sigma, torch.randn_like(latent_condition))
149
- latent_condition = latent_condition + noise
150
- data_batch["latent_condition_sigma"] = batch_mul(torch.ones_like(latent_condition[:, 0:1, ::]), c_noise_cond)
151
- data_batch["latent_condition"] = latent_condition
152
- if is_negative_prompt:
153
- condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
154
- else:
155
- condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
156
-
157
- def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
158
- cond_x0 = self.denoise(noise_x, sigma, condition).x0
159
- uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
160
- return cond_x0 + guidance * (cond_x0 - uncond_x0)
161
-
162
- return x0_fn, corrupted_pixel
163
-
164
- def generate_samples_from_batch(
165
- self,
166
- data_batch: Dict,
167
- guidance: float = 1.5,
168
- seed: int = 1,
169
- state_shape: Tuple | None = None,
170
- n_sample: int | None = None,
171
- is_negative_prompt: bool = False,
172
- num_steps: int = 35,
173
- solver_option: COMMON_SOLVER_OPTIONS = "2ab",
174
- sigma_min: float = 0.02,
175
- apply_corruptor: bool = False,
176
- return_recon_x: bool = False,
177
- corrupt_sigma: float = 0.01,
178
- preencode_condition: bool = False,
179
- ) -> Tensor:
180
- """
181
- Generate samples from the batch. Based on the given batch, it automatically determines whether to generate image or video samples.
182
- Args:
183
- data_batch (dict): raw data batch drawn from the training data loader.
184
- guidance (float): guidance weight for classifier-free guidance.
186
- seed (int): random seed
187
- state_shape (tuple): shape of the state, default to self.state_shape if not provided
188
- n_sample (int): number of samples to generate
189
- is_negative_prompt (bool): if True, use the negative-prompt T5 embedding for the unconditional branch
190
- num_steps (int): number of steps for the diffusion process
191
- solver_option (str): differential equation solver option; defaults to "2ab" (multistep solver)
192
- preencode_condition (bool): use a pre-computed latent condition if True, saving the tokenizer's inference time and memory.
193
- """
194
- if not preencode_condition:
195
- self._normalize_video_databatch_inplace(data_batch)
196
- self._augment_image_dim_inplace(data_batch)
197
- is_image_batch = False
198
- if n_sample is None:
199
- input_key = self.input_image_key if is_image_batch else self.input_data_key
200
- n_sample = data_batch[input_key].shape[0]
201
- if state_shape is None:
202
- if is_image_batch:
203
- state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W
204
-
205
- x0_fn, recon_x = self.get_x0_fn_from_batch(
206
- data_batch,
207
- guidance,
208
- is_negative_prompt=is_negative_prompt,
209
- apply_corruptor=apply_corruptor,
210
- corrupt_sigma=corrupt_sigma,
211
- preencode_condition=preencode_condition,
212
- )
213
- generator = torch.Generator(device=self.tensor_kwargs["device"])
214
- generator.manual_seed(seed)
215
- x_sigma_max = (
216
- torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max
217
- )
218
-
219
- samples = self.sampler(
220
- x0_fn,
221
- x_sigma_max,
222
- num_steps=num_steps,
223
- sigma_min=sigma_min,
224
- sigma_max=self.sde.sigma_max,
225
- solver_option=solver_option,
226
- )
227
-
228
- if return_recon_x:
229
- return samples, recon_x
230
- else:
231
- return samples
 
 
 
 
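For reference, a minimal illustrative sketch of the guided-x0 combination used by the x0_fn defined above (the helper function below is hypothetical and not part of the deleted file):

import torch

def guided_x0(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float) -> torch.Tensor:
    # Same combination rule as x0_fn above: push the conditional prediction further
    # along the direction that separates it from the unconditional prediction.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)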
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/77c3f88ca85134e689203e9ac157673c42edb0b3 DELETED
@@ -1,131 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
-
18
- from .misc import misc
19
- import torch
20
- from peft import PeftModel
21
- from transformers import AutoModelForCausalLM, AutoTokenizer
22
-
23
- from .guardrail_aegis_categories import UNSAFE_CATEGORIES
24
- from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
25
- from .log import log
26
-
27
- SAFE = misc.Color.green("SAFE")
28
- UNSAFE = misc.Color.red("UNSAFE")
29
-
30
- DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"
31
-
32
-
33
- class Aegis(ContentSafetyGuardrail):
34
- def __init__(
35
- self,
36
- checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
37
- device="cuda" if torch.cuda.is_available() else "cpu",
38
- ) -> None:
39
- self.checkpoint_dir = checkpoint_dir
40
- self.device = device
41
- self.dtype = torch.bfloat16
42
- base_model_id = "meta-llama/LlamaGuard-7b"
43
- aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
44
- base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
45
- self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
46
- self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=self.checkpoint_dir)
47
- self.model.to(self.device, dtype=self.dtype).eval()
48
-
49
- def get_moderation_prompt(self, user_prompt: str) -> str:
50
- """Create the moderation prompt for the Aegis model."""
51
- full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
52
-
53
- <BEGIN UNSAFE CONTENT CATEGORIES>
54
- {"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])}
55
- <END UNSAFE CONTENT CATEGORIES>
56
-
57
- <BEGIN CONVERSATION>
58
-
59
- """
60
- full_prompt += f"User: {user_prompt}\n\n"
61
- full_prompt += """<END CONVERSATION>
62
-
63
- Provide your safety assessment for User in the above conversation:
64
- - First line must read 'safe' or 'unsafe'.
65
- - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
66
-
67
- return full_prompt
68
-
69
- def get_aegis_block_message(self, moderation_output: str) -> str:
70
- """Extract the blocked category and reason from the Aegis model output."""
71
- block_msg = "Prompt blocked by Aegis."
72
- try:
73
- # Extract blocked category and reason
74
- if len(moderation_output.split("\n")) >= 2:
75
- moderation_output_list = moderation_output.split("\n")
76
- blocked_category = int(moderation_output_list[1].lstrip("O")) - 1
77
- if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
78
- blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
79
- blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
80
- block_msg = f"{blocked_category_name}: {blocked_reason}"
81
- except Exception as e:
82
- log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
83
- return block_msg
84
-
85
- def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
86
- """Filter the Aegis model output and return the safety status and message."""
87
- full_prompt = self.get_moderation_prompt(prompt)
88
- inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
89
- output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
90
- prompt_len = inputs["input_ids"].shape[-1]
91
- moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
92
-
93
- if "unsafe" in moderation_output.lower():
94
- block_msg = self.get_aegis_block_message(moderation_output)
95
- return False, block_msg
96
- else:
97
- return True, ""
98
-
99
- def is_safe(self, prompt: str) -> tuple[bool, str]:
100
- """Check if the input prompt is safe according to the Aegis model."""
101
- try:
102
- return self.filter_aegis_output(prompt)
103
- except Exception as e:
104
- log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
105
- return True, "Unexpected error occurred when running Aegis guardrail."
106
-
107
-
108
- def parse_args():
109
- parser = argparse.ArgumentParser()
110
- parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
111
- parser.add_argument(
112
- "--checkpoint_dir",
113
- type=str,
114
- help="Path to the Aegis checkpoint folder",
115
- default=DEFAULT_CHECKPOINT_DIR,
116
- )
117
- return parser.parse_args()
118
-
119
-
120
- def main(args):
121
- aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
122
- runner = GuardrailRunner(safety_models=[aegis])
123
- with misc.timer("aegis safety check"):
124
- safety, message = runner.run_safety_check(args.prompt)
125
- log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
126
- log.info(f"Message: {message}") if not safety else None
127
-
128
-
129
- if __name__ == "__main__":
130
- args = parse_args()
131
- main(args)
 
 
 
 
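For reference, an illustrative sketch of how get_aegis_block_message above maps a moderation output such as "unsafe\nO3" to an index into UNSAFE_CATEGORIES (the example output string is hypothetical):

moderation_output = "unsafe\nO3"                  # hypothetical Aegis output
lines = moderation_output.split("\n")
blocked_category = int(lines[1].lstrip("O")) - 1  # "O3" -> index 2 (third category)
print(blocked_category)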
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7b5c6e553583e8047a37aea5e4925df659426ea2 DELETED
@@ -1,196 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import importlib
17
- import os
18
- import pkgutil
19
- import sys
20
- from dataclasses import fields as dataclass_fields
21
- from dataclasses import is_dataclass
22
- from typing import Any, Dict, Optional
23
-
24
- import attr
25
- import attrs
26
- from hydra import compose, initialize
27
- from hydra.core.config_store import ConfigStore
28
- from omegaconf import DictConfig, OmegaConf
29
-
30
- from .log import log
31
- from .config import Config
32
-
33
-
34
- def is_attrs_or_dataclass(obj) -> bool:
35
- """
36
- Check if the object is an instance of an attrs class or a dataclass.
37
-
38
- Args:
39
- obj: The object to check.
40
-
41
- Returns:
42
- bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
43
- """
44
- return is_dataclass(obj) or attr.has(type(obj))
45
-
46
-
47
- def get_fields(obj):
48
- """
49
- Get the fields of an attrs class or a dataclass.
50
-
51
- Args:
52
- obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
53
-
54
- Returns:
55
- list: A list of field names.
56
-
57
- Raises:
58
- ValueError: If the object is neither an attrs class nor a dataclass.
59
- """
60
- if is_dataclass(obj):
61
- return [field.name for field in dataclass_fields(obj)]
62
- elif attr.has(type(obj)):
63
- return [field.name for field in attr.fields(type(obj))]
64
- else:
65
- raise ValueError("The object is neither an attrs class nor a dataclass.")
66
-
67
-
68
- def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
69
- """
70
- :param config: the instance of class `Config` (usually from `make_config`)
71
- :param overrides: list of overrides for config
72
- :return: the composed instance of class `Config`
73
- """
74
- # Store the class of the config for reconstruction after overriding.
75
- # config_class = type(config)
76
-
77
- # Convert Config object to a DictConfig object
78
- config_dict = attrs.asdict(config)
79
- config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
80
- # Enforce "--" separator between the script arguments and overriding configs.
81
- if overrides:
82
- if overrides[0] != "--":
83
- raise ValueError('Hydra config overrides must be separated with a "--" token.')
84
- overrides = overrides[1:]
85
- # Use Hydra to handle overrides
86
- cs = ConfigStore.instance()
87
- cs.store(name="config", node=config_omegaconf)
88
- with initialize(version_base=None):
89
- config_omegaconf = compose(config_name="config", overrides=overrides)
90
- OmegaConf.resolve(config_omegaconf)
91
-
92
- def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
93
- """
94
- Construct an instance of the same type as ref_instance from the provided dictionary, primitive data, or unstructured data.
95
-
96
- Args:
97
- ref_instance: The reference instance to determine the type and fields when needed
98
- kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
99
-
100
- Returns:
101
- Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
102
-
103
- Raises:
104
- AssertionError: If the fields do not match or if extra keys are found.
105
- Exception: If there is an error constructing the new instance.
106
- """
107
- is_type = is_attrs_or_dataclass(ref_instance)
108
- if not is_type:
109
- return kwargs
110
- else:
111
- ref_fields = set(get_fields(ref_instance))
112
- assert isinstance(kwargs, dict) or isinstance(
113
- kwargs, DictConfig
114
- ), "kwargs must be a dictionary or a DictConfig"
115
- keys = set(kwargs.keys())
116
-
117
- # ref_fields must equal to or include all keys
118
- extra_keys = keys - ref_fields
119
- assert ref_fields == keys or keys.issubset(
120
- ref_fields
121
- ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
122
-
123
- resolved_kwargs: Dict[str, Any] = {}
124
- for f in keys:
125
- resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
126
- try:
127
- new_instance = type(ref_instance)(**resolved_kwargs)
128
- except Exception as e:
129
- log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
130
- log.error(e)
131
- raise e
132
- return new_instance
133
-
134
- config = config_from_dict(config, config_omegaconf)
135
-
136
- return config
137
-
138
-
139
- def get_config_module(config_file: str) -> str:
140
- if not config_file.endswith(".py"):
141
- log.error("Config file cannot be specified as module.")
142
- log.error("Please provide the path to the Python config file (relative to the Cosmos root).")
143
- assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found."
144
- # Convert to importable module format.
145
- config_module = config_file.replace("/", ".").replace(".py", "")
146
- return config_module
147
-
148
-
149
- def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
150
- """
151
- Import all modules from the specified package path recursively.
152
-
153
- This function is typically used in conjunction with Hydra to ensure that all modules
154
- within a specified package are imported, which is necessary for registering configurations.
155
-
156
- Example usage:
157
- ```python
158
- import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
159
- ```
160
-
161
- Args:
162
- package_path (str): The dotted path to the package from which to import all modules.
163
- reload (bool): Flag to determine whether to reload modules if they're already imported.
164
- skip_underscore (bool): If True, skips importing modules that start with an underscore.
165
- """
166
- log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
167
- package = importlib.import_module(package_path)
168
- package_directory = package.__path__
169
-
170
- def import_modules_recursively(directory: str, prefix: str) -> None:
171
- """
172
- Recursively imports or reloads all modules in the given directory.
173
-
174
- Args:
175
- directory (str): The file system path to the current package directory.
176
- prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config').
177
- """
178
- for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
179
- if skip_underscore and module_name.startswith("_"):
180
- log.debug(f"Skipping module {module_name} as it starts with an underscore")
181
- continue
182
-
183
- full_module_name = f"{prefix}.{module_name}"
184
- log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
185
-
186
- if full_module_name in sys.modules and reload:
187
- importlib.reload(sys.modules[full_module_name])
188
- else:
189
- importlib.import_module(full_module_name)
190
-
191
- if is_pkg:
192
- sub_package_directory = os.path.join(directory, module_name)
193
- import_modules_recursively(sub_package_directory, full_module_name)
194
-
195
- for directory in package_directory:
196
- import_modules_recursively(directory, package_path)
 
 
 
 
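For reference, a minimal sketch of the recursive idea behind config_from_dict above, shown on plain dataclasses (the Optim/Cfg classes are hypothetical placeholders, not part of the deleted file):

from dataclasses import dataclass, field

@dataclass
class Optim:
    lr: float = 1e-4

@dataclass
class Cfg:
    optim: Optim = field(default_factory=Optim)
    seed: int = 0

ref = Cfg()
kwargs = {"optim": {"lr": 3e-4}, "seed": 7}
# Rebuild an instance of the same type field by field, as config_from_dict does:
rebuilt = Cfg(optim=Optim(**kwargs["optim"]), seed=kwargs["seed"])
print(rebuilt)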
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7bebf08cef2869c85553980bf81851635dd74f7e DELETED
@@ -1,108 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import List, Tuple, Union
17
-
18
- import torch
19
- import transformers
20
- from transformers import T5EncoderModel, T5TokenizerFast
21
-
22
- from .log import log
23
-
24
- transformers.logging.set_verbosity_error()
25
-
26
-
27
- class CosmosT5TextEncoder(torch.nn.Module):
28
- """Handles T5 text encoding operations."""
29
-
30
- def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"):
31
- """Initializes the T5 tokenizer and encoder.
32
-
33
- Args:
34
- model_name: The name of the T5 model to use.
35
- device: The device to use for computations.
36
- """
37
- super().__init__()
38
- try:
39
- self.tokenizer = T5TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
40
- self.text_encoder = T5EncoderModel.from_pretrained(model_name, cache_dir=cache_dir).to(device)
41
- except Exception as e:
42
- log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}")
43
- self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
44
- self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device)
45
- self.text_encoder.eval()
46
- self.device = device
47
-
48
- @torch.inference_mode()
49
- def encode_prompts(
50
- self, prompts: Union[str, List[str]], max_length: int = 512
51
- ) -> Tuple[torch.Tensor, torch.Tensor]:
52
- """Encodes text prompts into hidden state representations using a T5 encoder.
53
-
54
- This function tokenizes the input prompts, processes them through a T5 text encoder,
55
- and returns the last hidden states. The encoded outputs beyond the actual sequence
56
- length are zero-padded. All prompts in a batch are padded to max_length.
57
-
58
- Args:
59
- prompts: Input text to encode. Can be a single string or a list of strings.
60
- max_length: Maximum sequence length for tokenization and padding. Longer
61
- sequences will be truncated. Defaults to 512.
62
- return_mask: If True, returns the attention mask along with encoded text.
63
- Defaults to False.
64
-
65
- Returns:
66
- If return_mask is False:
67
- torch.Tensor: Encoded text embeddings of shape (batch_size, max_length, hidden_size).
68
- If return_mask is True:
69
- tuple[torch.Tensor, torch.Tensor]: A tuple containing:
70
- - Encoded text embeddings of shape (batch_size, max_length, hidden_size)
71
- - Attention mask of shape (batch_size, max_length) as boolean tensor
72
-
73
- Raises:
74
- ValueError: If the input prompts list is empty.
75
-
76
- Example:
77
- >>> encoder = CosmosT5TextEncoder()
78
- >>> prompts = ["Hello world", "Another example"]
79
- >>> embeddings = encoder.encode_prompts(prompts, max_length=128)
80
- """
81
- if isinstance(prompts, str):
82
- prompts = [prompts]
83
-
84
- if not prompts:
85
- raise ValueError("The input prompt list is empty.")
86
-
87
- batch_encoding = self.tokenizer.batch_encode_plus(
88
- prompts,
89
- return_tensors="pt",
90
- truncation=True,
91
- padding="max_length",
92
- max_length=max_length,
93
- return_length=True,
94
- return_offsets_mapping=False,
95
- )
96
-
97
- input_ids = batch_encoding.input_ids.to(self.device)
98
- attn_mask = batch_encoding.attention_mask.to(self.device)
99
-
100
- outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask)
101
-
102
- encoded_text = outputs.last_hidden_state
103
- lengths = attn_mask.sum(dim=1).cpu()
104
-
105
- for batch_id in range(encoded_text.shape[0]):
106
- encoded_text[batch_id][lengths[batch_id] :] = 0
107
-
108
- return encoded_text, attn_mask
 
 
 
 
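For reference, an illustrative sketch of the zero-padding step at the end of encode_prompts above, which masks encoder states past each prompt's true length (dummy tensors, not part of the deleted file):

import torch

encoded_text = torch.randn(2, 6, 4)              # (batch, max_length, hidden)
attn_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                          [1, 1, 1, 1, 1, 0]])   # 1 = real token, 0 = padding
lengths = attn_mask.sum(dim=1)
for batch_id in range(encoded_text.shape[0]):
    encoded_text[batch_id][lengths[batch_id]:] = 0   # zero out padded positions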
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/7c09eb428a97927d5f0407e2328a3f43afbf38fc DELETED
@@ -1,72 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import pydoc
17
- from typing import Any
18
-
19
- """
20
- `locate` provides ways to map a string (typically found
21
- in config files) to callable objects.
22
- """
23
-
24
- __all__ = ["locate"]
25
-
26
-
27
- def _convert_target_to_string(t: Any) -> str:
28
- """
29
- Inverse of ``locate()``.
30
-
31
- Args:
32
- t: any object with ``__module__`` and ``__qualname__``
33
- """
34
- module, qualname = t.__module__, t.__qualname__
35
-
36
- # Compress the path to this object, e.g. ``module.submodule._impl.class``
37
- # may become ``module.submodule.class``, if the latter also resolves to the same
38
- # object. This simplifies the string, and also is less affected by moving the
39
- # class implementation.
40
- module_parts = module.split(".")
41
- for k in range(1, len(module_parts)):
42
- prefix = ".".join(module_parts[:k])
43
- candidate = f"{prefix}.{qualname}"
44
- try:
45
- if locate(candidate) is t:
46
- return candidate
47
- except ImportError:
48
- pass
49
- return f"{module}.{qualname}"
50
-
51
-
52
- def locate(name: str) -> Any:
53
- """
54
- Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
55
- such as "module.submodule.class_name".
56
-
57
- Raise Exception if it cannot be found.
58
- """
59
- obj = pydoc.locate(name)
60
-
61
- # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
62
- # by pydoc.locate. Try a private function from hydra.
63
- if obj is None:
64
- try:
65
- # from hydra.utils import get_method - will print many errors
66
- from hydra.utils import _locate
67
- except ImportError as e:
68
- raise ImportError(f"Cannot dynamically locate object {name}!") from e
69
- else:
70
- obj = _locate(name) # it raises if fails
71
-
72
- return obj
 
 
 
 
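For reference, locate above wraps pydoc.locate (with a Hydra fallback); a minimal standard-library sketch of the string-to-object mapping it performs:

import pydoc

cls = pydoc.locate("collections.OrderedDict")   # dotted path -> object
assert cls is not None and cls.__name__ == "OrderedDict"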
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/859eb6498143e5b063dbc888dca7748a07cfda9d DELETED
@@ -1,45 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import re
18
-
19
- from .log import log
20
-
21
-
22
- def read_keyword_list_from_dir(folder_path: str) -> list[str]:
23
- """Read keyword list from all files in a folder."""
24
- output_list = []
25
- file_list = []
26
- # Get list of files in the folder
27
- for file in os.listdir(folder_path):
28
- if os.path.isfile(os.path.join(folder_path, file)):
29
- file_list.append(file)
30
-
31
- # Process each file
32
- for file in file_list:
33
- file_path = os.path.join(folder_path, file)
34
- try:
35
- with open(file_path, "r") as f:
36
- output_list.extend([line.strip() for line in f.readlines()])
37
- except Exception as e:
38
- log.error(f"Error reading file {file}: {str(e)}")
39
-
40
- return output_list
41
-
42
-
43
- def to_ascii(prompt: str) -> str:
44
- """Convert prompt to ASCII."""
45
- return re.sub(r"[^\x00-\x7F]+", " ", prompt)
 
 
 
 
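For reference, an illustrative example of the to_ascii regex above, which collapses each run of non-ASCII characters into a single space:

import re

print(re.sub(r"[^\x00-\x7F]+", " ", "café ☕ time"))   # "é" and "☕" each become a space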
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/8929f3a211707ad09f7c25b6b6e305360a42d6be DELETED
@@ -1,358 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import os
18
- from abc import ABC
19
- from typing import Any
20
-
21
- import numpy as np
22
- import torch
23
-
24
- from .t5_text_encoder import CosmosT5TextEncoder
25
- from .guardrail_common_presets import presets as guardrail_presets
26
-
27
-
28
- class BaseWorldGenerationPipeline(ABC):
29
- def __init__(
30
- self,
31
- inference_type: str | None = None,
32
- checkpoint_dir: str | None = None,
33
- checkpoint_name: str | None = None,
34
- has_text_input: bool = False,
35
- offload_network: bool = False,
36
- offload_tokenizer: bool = False,
37
- offload_text_encoder_model: bool = False,
38
- offload_guardrail_models: bool = False,
39
- ):
40
- """Initialize base world generation pipeline.
41
-
42
- This abstract base class provides core functionality for world generation models including:
43
- - Model loading and initialization
44
- - Text encoding and embedding
45
- - Safety checks and content filtering
46
- - Memory management through model offloading
47
-
48
- Args:
49
- inference_type: The type of inference pipeline ("text2world" or "video2world")
50
- checkpoint_dir: Root directory containing model checkpoints
51
- checkpoint_name: Name of the specific checkpoint file to load
52
- has_text_input: Whether the pipeline takes text input for world generation
53
- offload_network: If True, moves main model to CPU after inference
54
- offload_tokenizer: If True, moves tokenizer to CPU after use
55
- offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding
56
- offload_guardrail_models: If True, moves safety models to CPU after checks
57
- """
58
- self.inference_type = inference_type
59
- self.checkpoint_dir = checkpoint_dir
60
- self.checkpoint_name = checkpoint_name
61
- self.guardrail_dir = "Cosmos-1.0-Guardrail"
62
- self.has_text_input = has_text_input
63
-
64
- # Add offloading flags
65
- self.offload_network = offload_network
66
- self.offload_tokenizer = offload_tokenizer
67
- self.offload_text_encoder_model = offload_text_encoder_model
68
- self.offload_guardrail_models = offload_guardrail_models
69
-
70
- # Initialize model instances
71
- self.text_guardrail = None
72
- self.video_guardrail = None
73
- self.text_encoder = None
74
- self.model = None
75
-
76
- self._load_model()
77
-
78
- if not self.offload_text_encoder_model:
79
- self._load_text_encoder_model()
80
- if not self.offload_guardrail_models:
81
- if self.has_text_input:
82
- self._load_text_guardrail()
83
- self._load_video_guardrail()
84
- if not self.offload_network:
85
- self._load_network()
86
- if not self.offload_tokenizer:
87
- self._load_tokenizer()
88
-
89
- def _load_tokenizer(self):
90
- pass
91
-
92
- def _load_network(self):
93
- pass
94
-
95
- def _load_model(self, checkpoint_name: str) -> Any:
96
- """Load the world generation model from a checkpoint.
97
-
98
- This abstract method must be implemented by subclasses to load their specific
99
- model architecture and weights.
100
-
101
- Args:
102
- checkpoint_name: Path to the model checkpoint file
103
-
104
- Returns:
105
- The loaded model instance
106
-
107
- Raises:
108
- NotImplementedError: Must be implemented by subclasses
109
- """
110
- pass
111
-
112
- def _load_text_encoder_model(self):
113
- """Load the T5 text encoder model.
114
-
115
- Initializes and loads the T5 encoder model used for converting text prompts
116
- into embeddings that condition the world generation model.
117
-
118
- Returns:
119
- Loaded T5 text encoder model instance
120
- """
121
- self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)
122
-
123
- def _load_text_guardrail(self):
124
- """Load text safety classifier models.
125
-
126
- Initializes models used for checking input prompts against safety policies.
127
- Models are loaded from the specified guardrail directory.
128
- """
129
- self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
130
- checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
131
- )
132
-
133
- def _load_video_guardrail(self):
134
- """Load video safety classifier models.
135
-
136
- Initializes models used for validating generated video content against
137
- safety policies. Models are loaded from the specified guardrail directory.
138
- """
139
- self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
140
- checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
141
- )
142
-
143
- def _offload_network(self):
144
- if self.model.model:
145
- del self.model.model
146
- self.model.model = None
147
- gc.collect()
148
- torch.cuda.empty_cache()
149
-
150
- def _offload_tokenizer(self):
151
- if self.model.tokenizer:
152
- del self.model.tokenizer
153
- self.model.tokenizer = None
154
- gc.collect()
155
- torch.cuda.empty_cache()
156
-
157
- def _offload_guardrail_models(self):
158
- """Offload safety classifier models to reduce memory usage.
159
-
160
- Moves safety models to CPU and clears GPU memory if they are no longer needed.
161
- This helps manage memory when processing multiple inputs sequentially.
162
- """
163
- if self.text_guardrail:
164
- del self.text_guardrail
165
- self.text_guardrail = None
166
- if self.video_guardrail:
167
- del self.video_guardrail
168
- self.video_guardrail = None
169
- gc.collect()
170
- torch.cuda.empty_cache()
171
-
172
- def _offload_text_encoder_model(self):
173
- """Offload T5 text encoder to reduce memory usage.
174
-
175
- Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete.
176
- This helps manage memory when processing multiple inputs sequentially.
177
- """
178
- if self.text_encoder:
179
- del self.text_encoder
180
- self.text_encoder = None
181
- gc.collect()
182
- torch.cuda.empty_cache()
183
-
184
- def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
185
- """Generate world latents using the model.
186
-
187
- This abstract method must be implemented by subclasses to define their specific
188
- generation process.
189
-
190
- Args:
191
- *args: Variable positional arguments for model inference
192
- **kwargs: Variable keyword arguments for model inference
193
-
194
- Returns:
195
- torch.Tensor: Generated world representation tensor
196
- """
197
- pass
198
-
199
- def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
200
- """Generate world representation with memory management.
201
-
202
- Handles loading the model before inference and offloading afterward if enabled.
203
- This helps minimize GPU memory usage during inference.
204
-
205
- Args:
206
- *args: Arguments passed to _run_model
207
- **kwargs: Keyword arguments passed to _run_model
208
-
209
- Returns:
210
- np.ndarray: Generated world representation as numpy array
211
- """
212
- pass
213
-
214
- def _run_guardrail_on_prompt(self, prompt: str) -> bool:
215
- """Check if prompt meets safety requirements.
216
-
217
- Validates the input prompt against safety policies using loaded guardrail models.
218
-
219
- Args:
220
- prompt: Raw text prompt to validate
221
-
222
- Returns:
223
- bool: True if prompt passes all safety checks, False otherwise
224
- """
225
- return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)
226
-
227
- def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
228
- """Check prompt safety with memory management.
229
-
230
- Validates prompt safety while handling model loading/offloading to manage memory.
231
-
232
- Args:
233
- prompt: Raw text prompt to validate
234
-
235
- Returns:
236
- bool: True if prompt passes all safety checks, False otherwise
237
- """
238
- if self.offload_guardrail_models:
239
- self._load_text_guardrail()
240
-
241
- is_safe = self._run_guardrail_on_prompt(prompt)
242
-
243
- if self.offload_guardrail_models:
244
- self._offload_guardrail_models()
245
-
246
- return is_safe
247
-
248
- def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
249
- """Check if video meets safety requirements.
250
-
251
- Validates generated video content against safety policies using guardrail models.
252
-
253
- Args:
254
- video: Video frames to validate
255
-
256
- Returns:
257
- np.ndarray: Processed video if safe, None if unsafe
258
- """
259
- return guardrail_presets.run_video_guardrail(video, self.video_guardrail)
260
-
261
- def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
262
- """Check if generated video meets safety requirements.
263
-
264
- Args:
265
- video: Video frames to validate
266
-
267
- Returns:
268
- np.ndarray: Processed video frames if safe, None otherwise
269
-
270
- Note:
271
- Guardrail models are offloaded after checks if enabled.
272
- """
273
- if self.offload_guardrail_models:
274
- self._load_video_guardrail()
275
-
276
- video = self._run_guardrail_on_video(video)
277
-
278
- if self.offload_guardrail_models:
279
- self._offload_guardrail_models()
280
- return video
281
-
282
- def _run_text_embedding_on_prompt(
283
- self, prompts: list[str], **kwargs: Any
284
- ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
285
- """Convert text prompts to embeddings.
286
-
287
- Processes text prompts into embedding tensors that condition the generation model.
288
-
289
- Args:
290
- prompts: List of text prompts to encode
291
- **kwargs: Additional arguments for text encoding
292
-
293
- Returns:
294
- tuple containing:
295
- - List of text embedding tensors for each prompt
296
- - List of attention masks for each embedding
297
- """
298
-
299
- embeddings = []
300
- masks = []
301
- for prompt in prompts:
302
- embedding, mask = self.text_encoder.encode_prompts(
303
- [prompt],
304
- **kwargs,
305
- )
306
- embeddings.append(embedding)
307
- masks.append(mask)
308
-
309
- return embeddings, masks
310
-
311
- def _run_text_embedding_on_prompt_with_offload(
312
- self, prompts: list[str], **kwargs: Any
313
- ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
314
- """Convert text prompt into embeddings using T5 encoder.
315
-
316
- Args:
317
- prompt: Processed and validated text prompt
318
-
319
- Returns:
320
- Text embedding tensor to condition diffusion model
321
-
322
- Note:
323
- T5 model is offloaded after encoding if enabled.
324
- """
325
- if self.offload_text_encoder_model:
326
- self._load_text_encoder_model()
327
-
328
- embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)
329
-
330
- if self.offload_text_encoder_model:
331
- self._offload_text_encoder_model()
332
- return embeddings, masks
333
-
334
- def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
335
- """Decode model outputs into final world representation.
336
-
337
- This abstract method must be implemented by subclasses to convert raw model
338
- outputs into their specific world representation format.
339
-
340
- Args:
341
- samples: Raw output tensor from the generation model
342
-
343
- Returns:
344
- np.ndarray: Decoded world representation
345
- """
346
- pass
347
-
348
- def generate(self, *args: Any, **kwargs: Any):
349
- """Generate world representation.
350
-
351
- This abstract method must be implemented by subclasses to define their end-to-end
352
- world generation process, from validated inputs to the final decoded output.
353
-
354
- Args:
355
- *args: Variable positional arguments for model inference
356
- **kwargs: Variable keyword arguments for model inference
357
- """
358
- pass
 
 
 
 
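For reference, a minimal sketch of the load-run-offload pattern used by the *_with_offload helpers above (the generic wrapper below is hypothetical, not part of the deleted file):

def run_with_optional_offload(load_fn, run_fn, offload_fn, offload_enabled, *args, **kwargs):
    # Load the model only when offloading is enabled (otherwise it stays resident),
    # run the actual work, then free GPU memory again if requested.
    if offload_enabled:
        load_fn()
    result = run_fn(*args, **kwargs)
    if offload_enabled:
        offload_fn()
    return result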
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9586934f8c1949d734b4ea3080135d2769ec481a DELETED
@@ -1,333 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Callable, Tuple
17
-
18
- import torch
19
-
20
- from .df_df_functional_batch_ops import batch_mul
21
-
22
-
23
- def phi1(t: torch.Tensor) -> torch.Tensor:
24
- """
25
- Compute the first order phi function: (exp(t) - 1) / t.
26
-
27
- Args:
28
- t: Input tensor.
29
-
30
- Returns:
31
- Tensor: Result of phi1 function.
32
- """
33
- input_dtype = t.dtype
34
- t = t.to(dtype=torch.float64)
35
- return (torch.expm1(t) / t).to(dtype=input_dtype)
36
-
37
-
38
- def phi2(t: torch.Tensor) -> torch.Tensor:
39
- """
40
- Compute the second order phi function: (phi1(t) - 1) / t.
41
-
42
- Args:
43
- t: Input tensor.
44
-
45
- Returns:
46
- Tensor: Result of phi2 function.
47
- """
48
- input_dtype = t.dtype
49
- t = t.to(dtype=torch.float64)
50
- return ((phi1(t) - 1.0) / t).to(dtype=input_dtype)
51
-
52
-
53
- def res_x0_rk2_step(
54
- x_s: torch.Tensor,
55
- t: torch.Tensor,
56
- s: torch.Tensor,
57
- x0_s: torch.Tensor,
58
- s1: torch.Tensor,
59
- x0_s1: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """
62
- Perform a residual-based 2nd order Runge-Kutta step.
63
-
64
- Args:
65
- x_s: Current state tensor.
66
- t: Target time tensor.
67
- s: Current time tensor.
68
- x0_s: Prediction at current time.
69
- s1: Intermediate time tensor.
70
- x0_s1: Prediction at intermediate time.
71
-
72
- Returns:
73
- Tensor: Updated state tensor.
74
-
75
- Raises:
76
- AssertionError: If step size is too small.
77
- """
78
- s = -torch.log(s)
79
- t = -torch.log(t)
80
- m = -torch.log(s1)
81
-
82
- dt = t - s
83
- assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
84
- assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
85
-
86
- c2 = (m - s) / dt
87
- phi1_val, phi2_val = phi1(-dt), phi2(-dt)
88
-
89
- # Handle edge case where t = s = m
90
- b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
91
- b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
92
-
93
- return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1))
94
-
95
-
96
- def reg_x0_euler_step(
97
- x_s: torch.Tensor,
98
- s: torch.Tensor,
99
- t: torch.Tensor,
100
- x0_s: torch.Tensor,
101
- ) -> Tuple[torch.Tensor, torch.Tensor]:
102
- """
103
- Perform a regularized Euler step based on x0 prediction.
104
-
105
- Args:
106
- x_s: Current state tensor.
107
- s: Current time tensor.
108
- t: Target time tensor.
109
- x0_s: Prediction at current time.
110
-
111
- Returns:
112
- Tuple[Tensor, Tensor]: Updated state tensor and current prediction.
113
- """
114
- coef_x0 = (s - t) / s
115
- coef_xs = t / s
116
- return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s
117
-
118
-
119
- def reg_eps_euler_step(
120
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor
121
- ) -> Tuple[torch.Tensor, torch.Tensor]:
122
- """
123
- Perform a regularized Euler step based on epsilon prediction.
124
-
125
- Args:
126
- x_s: Current state tensor.
127
- s: Current time tensor.
128
- t: Target time tensor.
129
- eps_s: Epsilon prediction at current time.
130
-
131
- Returns:
132
- Tuple[Tensor, Tensor]: Updated state tensor and current x0 prediction.
133
- """
134
- return x_s + batch_mul(eps_s, t - s), x_s + batch_mul(eps_s, 0 - s)
135
-
136
-
137
- def rk1_euler(
138
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
139
- ) -> Tuple[torch.Tensor, torch.Tensor]:
140
- """
141
- Perform a first-order Runge-Kutta (Euler) step.
142
-
143
- Recommended for diffusion models with guidance or for undertrained models.
144
- Usually more stable, at the cost of slightly slower convergence.
145
-
146
- Args:
147
- x_s: Current state tensor.
148
- s: Current time tensor.
149
- t: Target time tensor.
150
- x0_fn: Function to compute x0 prediction.
151
-
152
- Returns:
153
- Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction.
154
- """
155
- x0_s = x0_fn(x_s, s)
156
- return reg_x0_euler_step(x_s, s, t, x0_s)
157
-
158
-
159
- def rk2_mid_stable(
160
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
161
- ) -> Tuple[torch.Tensor, torch.Tensor]:
162
- """
163
- Perform a stable second-order Runge-Kutta (midpoint) step.
164
-
165
- Args:
166
- x_s: Current state tensor.
167
- s: Current time tensor.
168
- t: Target time tensor.
169
- x0_fn: Function to compute x0 prediction.
170
-
171
- Returns:
172
- Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction.
173
- """
174
- s1 = torch.sqrt(s * t)
175
- x_s1, _ = rk1_euler(x_s, s, s1, x0_fn)
176
-
177
- x0_s1 = x0_fn(x_s1, s1)
178
- return reg_x0_euler_step(x_s, s, t, x0_s1)
179
-
180
-
181
- def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]:
182
- """
183
- Perform a second-order Runge-Kutta (midpoint) step.
184
-
185
- Args:
186
- x_s: Current state tensor.
187
- s: Current time tensor.
188
- t: Target time tensor.
189
- x0_fn: Function to compute x0 prediction.
190
-
191
- Returns:
192
- Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction.
193
- """
194
- s1 = torch.sqrt(s * t)
195
- x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn)
196
-
197
- x0_s1 = x0_fn(x_s1, s1)
198
-
199
- return res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1), x0_s1
200
-
201
-
202
- def rk_2heun_naive(
203
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
204
- ) -> Tuple[torch.Tensor, torch.Tensor]:
205
- """
206
- Perform a naive second-order Runge-Kutta (Heun's method) step.
207
- Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis
208
- Recommended for diffusion models without guidance and a relatively large NFE
209
-
210
- Args:
211
- x_s: Current state tensor.
212
- s: Current time tensor.
213
- t: Target time tensor.
214
- x0_fn: Function to compute x0 prediction.
215
-
216
- Returns:
217
- Tuple[Tensor, Tensor]: Updated state tensor and current state.
218
- """
219
- x_t, x0_s = rk1_euler(x_s, s, t, x0_fn)
220
- eps_s = batch_mul(1.0 / s, x_t - x0_s)
221
- x0_t = x0_fn(x_t, t)
222
- eps_t = batch_mul(1.0 / t, x_t - x0_t)
223
-
224
- avg_eps = (eps_s + eps_t) / 2
225
-
226
- return reg_eps_euler_step(x_s, s, t, avg_eps)
227
-
228
-
229
- def rk_2heun_edm(
230
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
231
- ) -> Tuple[torch.Tensor, torch.Tensor]:
232
- """
233
- Perform a naive second-order Runge-Kutta (Heun's method) step.
234
- Impl based on the EDM second-order Heun method
235
-
236
- Args:
237
- x_s: Current state tensor.
238
- s: Current time tensor.
239
- t: Target time tensor.
240
- x0_fn: Function to compute x0 prediction.
241
-
242
- Returns:
243
- Tuple[Tensor, Tensor]: Updated state tensor and current state.
244
- """
245
- x_t, x0_s = rk1_euler(x_s, s, t, x0_fn)
246
- x0_t = x0_fn(x_t, t)
247
-
248
- avg_x0 = (x0_s + x0_t) / 2
249
-
250
- return reg_x0_euler_step(x_s, s, t, avg_x0)
251
-
252
-
253
- def rk_3kutta_naive(
254
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable
255
- ) -> Tuple[torch.Tensor, torch.Tensor]:
256
- """
257
- Perform a naive third-order Runge-Kutta step.
258
- Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis
259
- Recommended for diffusion models without guidance and a relatively large NFE
260
-
261
- Args:
262
- x_s: Current state tensor.
263
- s: Current time tensor.
264
- t: Target time tensor.
265
- x0_fn: Function to compute x0 prediction.
266
-
267
- Returns:
268
- Tuple[Tensor, Tensor]: Updated state tensor and current state.
269
- """
270
- c2, c3 = 0.5, 1.0
271
- a31, a32 = -1.0, 2.0
272
- b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6
273
-
274
- delta = t - s
275
-
276
- s1 = c2 * delta + s
277
- s2 = c3 * delta + s
278
- x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn)
279
- eps_s = batch_mul(1.0 / s, x_s - x0_s)
280
- x0_s1 = x0_fn(x_s1, s1)
281
- eps_s1 = batch_mul(1.0 / s1, x_s1 - x0_s1)
282
-
283
- _eps = a31 * eps_s + a32 * eps_s1
284
- x_s2, _ = reg_eps_euler_step(x_s, s, s2, _eps)
285
-
286
- x0_s2 = x0_fn(x_s2, s2)
287
- eps_s2 = batch_mul(1.0 / s2, x_s2 - x0_s2)
288
-
289
- avg_eps = b1 * eps_s + b2 * eps_s1 + b3 * eps_s2
290
- return reg_eps_euler_step(x_s, s, t, avg_eps)
291
-
292
-
293
- # key : order + name
294
- RK_FNs = {
295
- "1euler": rk1_euler,
296
- "2mid": rk2_mid,
297
- "2mid_stable": rk2_mid_stable,
298
- "2heun_edm": rk_2heun_edm,
299
- "2heun_naive": rk_2heun_naive,
300
- "3kutta_naive": rk_3kutta_naive,
301
- }
302
-
303
-
304
- def get_runge_kutta_fn(name: str) -> Callable:
305
- """
306
- Get the specified Runge-Kutta function.
307
-
308
- Args:
309
- name: Name of the Runge-Kutta method.
310
-
311
- Returns:
312
- Callable: The specified Runge-Kutta function.
313
-
314
- Raises:
315
- RuntimeError: If the specified method is not supported.
316
- """
317
- if name in RK_FNs:
318
- return RK_FNs[name]
319
- methods = "\n\t".join(RK_FNs.keys())
320
- raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{methods}")
321
-
322
-
323
- def is_runge_kutta_fn_supported(name: str) -> bool:
324
- """
325
- Check if the specified Runge-Kutta function is supported.
326
-
327
- Args:
328
- name: Name of the Runge-Kutta method.
329
-
330
- Returns:
331
- bool: True if the method is supported, False otherwise.
332
- """
333
- return name in RK_FNs
 
 
 
 
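For reference, a minimal numeric sketch of the reg_x0_euler_step update rule above, written for scalar sigmas (the demo function and values are illustrative, not part of the deleted file):

import torch

def reg_x0_euler_step_demo(x_s: torch.Tensor, s: float, t: float, x0_s: torch.Tensor) -> torch.Tensor:
    # Same update as reg_x0_euler_step above: interpolate between the x0 prediction
    # and the current state with weights (s - t)/s and t/s.
    return (s - t) / s * x0_s + t / s * x_s

x_s = torch.randn(4)
print(reg_x0_euler_step_demo(x_s, s=1.0, t=0.5, x0_s=torch.zeros(4)))   # halfway toward x0 = 0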
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9861ef45253f4932a362923bdb6f07fd1b39666b DELETED
@@ -1,322 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from collections import defaultdict
17
- from typing import Optional
18
-
19
- import torch
20
- from einops import rearrange
21
-
22
- from .ar_config_base_tokenizer import TokenizerConfig
23
- from .lazy_config_init import instantiate as lazy_instantiate
24
-
25
-
26
- def update_vocab_size(
27
- existing_vocab_size,
28
- to_be_added_vocab_size,
29
- training_type,
30
- add_special_tokens,
31
- video_special_tokens={},
32
- ):
33
- # New vocab size
34
- if add_special_tokens:
35
- existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens)
36
- # For text_to_video, we add one <bov> special token at the beginning of the video
37
- elif training_type == "text_to_video":
38
- existing_vocab_size += to_be_added_vocab_size + 1
39
- else:
40
- existing_vocab_size += to_be_added_vocab_size
41
- return existing_vocab_size
42
-
43
-
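
A small worked example of the vocabulary bookkeeping above (all numbers and token names are placeholders; only the arithmetic is taken from the function):

    # Hedged worked example for update_vocab_size.
    placeholder_video_special_tokens = {"bov": 0, "eov": 1, "vid_pad": 2}  # only the count (3) matters here

    # With special tokens enabled: 64000 + 64000 + 3 special tokens = 128003
    vocab_a = update_vocab_size(64000, 64000, "video_to_video", True, placeholder_video_special_tokens)

    # text_to_video without special tokens adds a single begin-of-video token: 64000 + 64000 + 1 = 128001
    vocab_b = update_vocab_size(64000, 64000, "text_to_video", False)
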
44
- class DiscreteMultimodalTokenizer:
45
- def __init__(self, tokenizer_config: TokenizerConfig):
46
- self.tokenizer_config = tokenizer_config
47
- self.vocab_size = 0
48
- self.total_seq_len = tokenizer_config.seq_len
49
- self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of
50
- self.training_type = tokenizer_config.training_type
51
- assert self.training_type in [
52
- "text_only",
53
- "text_to_video",
54
- "video_to_video",
55
- "image_text_interleaved",
56
- ], f"{self.training_type} not supported"
57
-
58
- self._build_text_tokenizer()
59
- self._build_video_tokenizer()
60
-
61
- def _build_text_tokenizer(self):
62
- r"""Function to initialize the text tokenizer model."""
63
- if self.tokenizer_config.text_tokenizer is not None:
64
- self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config)
65
- self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size
66
- else:
67
- self.text_tokenizer = None
68
-
69
- def _build_video_tokenizer(self):
70
- r"""Function to initialize the video tokenizer model."""
71
- if self.tokenizer_config.video_tokenizer is not None:
72
- self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config)
73
- self.video_tokenizer = self.video_tokenizer.to("cuda")
74
- self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size
75
- special_token_offset = (
76
- self.tokenizer_config.video_tokenizer.tokenizer_offset
77
- + self.tokenizer_config.video_tokenizer.vocab_size
78
- )
79
- self.video_special_tokens = {
80
- "<|begin_of_video|>": special_token_offset,
81
- "<|end_of_video|>": special_token_offset + 1,
82
- "<|pad_token_video|>": special_token_offset + 2,
83
- }
84
-
85
- self.vocab_size = update_vocab_size(
86
- existing_vocab_size=self.vocab_size,
87
- to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size,
88
- training_type=self.training_type,
89
- add_special_tokens=self.tokenizer_config.add_special_tokens,
90
- video_special_tokens=self.video_special_tokens,
91
- )
92
- else:
93
- self.video_tokenizer = None
94
-
95
- @property
96
- def pad_id(self):
97
- r"""Returns the pad_id."""
98
-
99
- if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
100
- pad_id = self.text_tokenizer.pad_id
101
- elif self.training_type in ["text_to_video", "video_to_video"]:
102
- pad_id = self.video_special_tokens["<|pad_token_video|>"]
103
- else:
104
- raise ValueError(f"training_type {self.training_type} not defined")
105
- return pad_id
106
-
107
- @property
108
- def ignore_index(self):
109
- r"""Returns which token should be ignored during loss computation."""
110
- if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
111
- if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id:
112
- # If the PAD token is the same as the EOS token, we do not ignore it during loss
113
- # computation, since we want the model to be able to predict EOS tokens in inference.
114
- # The PyTorch default ignore_index for the cross-entropy loss is -100.
115
- ignore_index = -100
116
- else:
117
- ignore_index = self.text_tokenizer.pad_id
118
- elif self.training_type in ["text_to_video", "video_to_video"]:
119
- ignore_index = self.pad_id
120
- else:
121
- raise ValueError(f"training_type {self.training_type} not defined")
122
- return ignore_index
123
-
124
- @property
125
- def stop_tokens(self):
126
- r"""Returns the stop tokens."""
127
- if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
128
- stop_tokens = self.text_tokenizer.stop_tokens
129
- elif self.training_type in ["text_to_video", "video_to_video"]:
130
- stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]])
131
- else:
132
- raise ValueError(f"training_type {self.training_type} not defined")
133
- return stop_tokens
134
-
135
- def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1):
136
- r"""Function to tokenize text.
137
- Args:
138
- raw_text (list[str]): List of input strings
139
- max_text_seq_len (int): Maximum sequence length returned by text tokenizer
140
- Returns:
141
- text_tokens (list[list[int]]): List of text tokens
142
- """
143
-
144
- batch_size = len(raw_text)
145
- text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)]
146
-
147
- # Clipping the text tokens so that the sequence length does not exceed max_text_seq_len
148
- if max_text_seq_len > -1:
149
- for i in range(len(text_tokens)):
150
- if len(text_tokens[i]) > max_text_seq_len:
151
- # Simply clip and add end of seq token
152
- text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id]
153
- return text_tokens
154
-
155
- def _tokenize_class(self, cls_labels: list[str]):
156
- r"""Function to tokenize the class label.
157
- Args:
158
- cls_labels (list[str]): List of class indices
159
- Returns:
160
- class_tokens (list[list[int]]): List of class tokens
161
- """
162
-
163
- # tokenizer_offset tells what offset should be added to the tokens.
164
- # This is needed for vocab expansion.
165
- class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels]
166
-
167
- return class_tokens
168
-
169
- def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None):
170
- r"""Function to tokenize video.
171
- Args:
172
- videos (torch.Tensor): Input video data tensor
173
- pixel_chunk_duration (Optional[int]): Pixel chunk duration. If provided, we pass it to the video tokenizer.
174
- Returns:
175
- video_tokens (list[list[int]]): List of video tokens
176
- """
177
-
178
- video_tokens = []
179
- batch_size = videos.shape[0]
180
-
181
- quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration)
182
- indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1))
183
-
184
- # Flatten the indices
185
- indices = rearrange(indices, "B T H W -> B (T H W)")
186
-
187
- # tokenizer_offset tells what offset should be added to the tokens.
188
- # This is needed for vocab expansion.
189
- indices += self.tokenizer_config.video_tokenizer.tokenizer_offset
190
-
191
- # Add begin and end of video tokens
192
- bov_token = self.video_special_tokens["<|begin_of_video|>"]
193
- eov_token = self.video_special_tokens["<|end_of_video|>"]
194
-
195
- # Append bov and eov tokens
196
- if self.tokenizer_config.add_special_tokens:
197
- for i in range(batch_size):
198
- video_tokens.append([bov_token] + indices[i].tolist() + [eov_token])
199
- else:
200
- if self.training_type == "text_to_video":
201
- for i in range(batch_size):
202
- video_tokens.append([bov_token] + indices[i].tolist())
203
- else:
204
- for i in range(batch_size):
205
- video_tokens.append(indices[i].tolist())
206
- assert (
207
- len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len
208
- ), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}"
209
-
210
- return video_tokens
211
-
212
- def tokenize(self, data_batch: dict):
213
- r"""Function to tokenize data_dict.
214
- Args:
215
- data_batch (dict): Input data dict
216
- Returns:
217
- tokens (torch.LongTensor): Token tensor dict
218
- """
219
-
220
- if (
221
- self.training_type in ["text_only", "image_text_interleaved"]
222
- and not self.tokenizer_config.text_tokenizer.tokenize_here
223
- ):
224
- # In case of pre-computed tokens, just return the data_batch
225
- return data_batch["tokens"], None
226
-
227
- # Online tokenization
228
- tokens = []
229
- token_boundaries = defaultdict(list)
230
-
231
- # Obtain maximum sequence length
232
- max_text_seq_len = -1
233
- max_visual_seq_len = -1
234
-
235
- if self.training_type in ["text_to_video", "video_to_video"]:
236
- max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len
237
-
238
- # If max visual sequence length is specified, make sure that text is clipped so that
239
- # the full video/image is always seen.
240
- if max_visual_seq_len > -1:
241
- if self.tokenizer_config.add_special_tokens:
242
- max_visual_seq_len = max_visual_seq_len + 2 # Two special tokens are for the [bov, eov] or [boi, eoi] pair
243
- elif self.training_type == "text_to_video":
244
- max_visual_seq_len = max_visual_seq_len + 1
245
- else:
246
- max_visual_seq_len = max_visual_seq_len
247
- assert (
248
- max_visual_seq_len <= self.total_seq_len
249
- ), f"max_visual_seq_len ({max_visual_seq_len}) is greater than the total sequence length ({self.total_seq_len})"
250
- max_text_seq_len = self.total_seq_len - max_visual_seq_len
251
-
252
- # Tokenize the text
253
- if (
254
- "text" in self.training_type
255
- and self.text_tokenizer is not None
256
- and self.tokenizer_config.text_tokenizer.tokenize_here
257
- ):
258
- key = self.tokenizer_config.text_tokenizer.data_key
259
- assert key in data_batch, f"Key {key} should be present in data for text tokenizer"
261
- batch_size = len(data_batch[key])
261
- tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len)
262
-
263
- for i in range(batch_size):
264
- token_boundaries["text"].append((0, len(tokens[i])))
265
- else:
266
- tokens = []
267
- batch_size = None
268
-
269
- # Tokenize the class label
270
- if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None:
271
- key = self.tokenizer_config.class_tokenizer.data_key
272
- assert key in data_batch, f"Key {key} should be present in data for class tokenizer"
273
- batch_size = len(data_batch[key]) if batch_size is None else batch_size
274
- tokens_class = self._tokenize_class(data_batch[key])
275
- if len(tokens) == 0:
276
- tokens = tokens_class
277
- for i in range(batch_size):
278
- token_boundaries["class"].append((0, len(tokens[i])))
279
- else:
280
- for i in range(batch_size):
281
- token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i])))
282
- tokens[i] = tokens[i] + tokens_class[i]
283
-
284
- # Tokenize the video
285
- if self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here:
286
- key = self.tokenizer_config.video_tokenizer.data_key
287
- assert key in data_batch, f"Key {key} should be present in data for video tokenizer"
288
- batch_size = len(data_batch[key]) if batch_size is None else batch_size
289
-
290
- pixel_chunk_duration = (
291
- None # If not specified, we assume it's a video dataset and use the default chunk duration
292
- )
293
- dataset_name = data_batch.get("dataset_name", None)
294
- if dataset_name is not None and dataset_name.startswith("image"):
295
- # If it's an image dataset, we use a pixel chunk duration of 1
296
- pixel_chunk_duration = 1
297
- tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration)
298
- if len(tokens) == 0:
299
- tokens = tokens_video
300
- for i in range(batch_size):
301
- token_boundaries["video"].append((0, len(tokens[i])))
302
- # [B,] each entry is (0, len(tokens[i]))
303
- else:
304
- for i in range(batch_size):
305
- token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i])))
306
- tokens[i] = tokens[i] + tokens_video[i]
307
-
308
- # Combine the tokens and do padding
309
- max_seq_len_in_batch = max([len(token) for token in tokens])
310
- if self.pad_to_multiple_of is not None:
311
- # Pad the sequence length to the nearest multiple of pad_to_multiple_of
312
- max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of
313
- pad_to_len = min(max_seq_len_in_batch, self.total_seq_len)
314
- for i in range(len(tokens)):
315
- if len(tokens[i]) < pad_to_len:
316
- tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i]))
317
- else:
318
- tokens[i] = tokens[i][0:pad_to_len]
319
-
320
- # Convert it to long tensor
321
- tokens = torch.LongTensor(tokens)
322
- return tokens, token_boundaries
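
A hedged sketch of the expected call pattern for the tokenizer above, assuming its definitions are importable. The pre-built `TokenizerConfig` instance, the `"video"` data key, GPU availability, and the B x C x T x H x W input shape are assumptions; real configurations come from the repository's config system.

    # Hedged sketch: tokenize a video batch into discrete tokens plus boundary metadata.
    import torch

    tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)       # assumed pre-built TokenizerConfig
    data_batch = {"video": torch.randn(1, 3, 33, 256, 256).cuda()}  # assumed data key and input shape
    tokens, token_boundaries = tokenizer.tokenize(data_batch)
    # tokens: LongTensor [B, padded_seq_len], padded with tokenizer.pad_id up to at most total_seq_len
    # token_boundaries["video"]: per-sample (start, end) indices of the video span
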
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9918ab7cc8f55dc0c159b58c158d3556b6819acd DELETED
@@ -1,317 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Any, Dict, List, Optional, Union
17
-
18
- import numpy as np
19
- import torch
20
- from transformers import AutoTokenizer
21
-
22
- from .log import log
23
-
24
-
25
- def get_tokenizer_path(model_family: str, is_instruct_model: bool = False):
26
- """
27
- Get the tokenizer path from the model family and instruct model flag.
28
- Args:
29
- model_family (str): The model family.
30
- is_instruct_model (bool): Whether the model is an instruct model.
31
- Returns:
32
- str: The Hugging Face Hub path for the tokenizer.
33
- """
34
- model_family = model_family.lower()
35
- if model_family == "mistral":
36
- return "mistralai/Mistral-Nemo-Instruct-2407"
37
- else:
38
- assert model_family in ["llama3", "llama3.1"]
39
- if model_family == "llama3":
40
- model_path = "meta-llama/Meta-Llama-3-8B"
41
- elif model_family == "llama3.1":
42
- model_path = "meta-llama/Llama-3.1-8B"
43
- else:
44
- raise ValueError(f"Unsupported model family: {model_family}")
45
- suffix = "-Instruct" if is_instruct_model else ""
46
- model_path = f"{model_path}{suffix}"
47
- return model_path
48
-
49
-
50
- class TextTokenizer:
51
- """
52
- Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based).
53
- """
54
-
55
- def __init__(
56
- self,
57
- model_family: str,
58
- is_instruct_model: bool,
59
- local_path: Optional[str] = None,
60
- ):
61
- """
62
- Initialize the TextTokenizer.
63
- Args:
64
- model_family (str): The model family.
65
- is_instruct_model (bool): Whether the model is an instruct model.
66
- local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path.
67
- """
68
- if local_path is None:
69
- tokenizer_path = get_tokenizer_path(model_family, is_instruct_model)
70
- else:
71
- tokenizer_path = local_path
72
-
73
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
74
- self.stop_tokens = {
75
- self.tokenizer.eos_token_id,
76
- }
77
- self.model_family = model_family
78
- self.is_instruct_model = is_instruct_model
79
- self.eos_id = self.tokenizer.eos_token_id
80
- if self.tokenizer.pad_token is None:
81
- if model_family.startswith("llama"):
82
- self.pad_id = 128004 # "<|finetune_right_pad_id|>"
83
- elif model_family == "mistral":
84
- self.pad_id = 10 # "<pad>"
85
- elif model_family == "pixtral":
86
- self.pad_id = 11 # "<pad>"
87
- else:
88
- raise ValueError(f"pad_id not defined for model_family {model_family}")
89
- else:
90
- self.pad_id = self.tokenizer.pad_token_id
91
-
92
- def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]:
93
- """
94
- Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`.
95
-
96
- Args:
97
- text (`str`):
98
- The sequence to be encoded.
99
- add_special_tokens (`bool`, *optional*, defaults to `False`):
100
- Whether or not to add the special tokens associated with the corresponding model.
101
- Returns:
102
- `List[str]`: The list of tokens.
103
- """
104
- return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
105
-
106
- def encode(
107
- self,
108
- text: Union[str, List[str], List[int]],
109
- *, # Enforce keyword-only arguments
110
- add_special_tokens: bool = True,
111
- padding: Union[bool, str] = False,
112
- truncation: Union[bool, str] = None,
113
- max_length: Optional[int] = None,
114
- stride: int = 0,
115
- return_tensors: Optional[str] = None,
116
- **kwargs,
117
- ) -> List[int]:
118
- """
119
- Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
120
-
121
- Args:
122
- text (`str`, `List[str]` or `List[int]`):
123
- The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
124
- `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
125
- method).
126
- add_special_tokens (`bool`, *optional*, defaults to `True`):
127
- Whether or not to add special tokens when encoding the sequences. This will use the underlying
128
- `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
129
- automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens
130
- automatically.
131
- padding (`bool`, `str`, *optional*, defaults to `False`):
132
- Activates and controls padding. Accepts the following values:
133
-
134
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
135
- sequence if provided).
136
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
137
- acceptable input length for the model if that argument is not provided.
138
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
139
- lengths).
140
- truncation (`bool`, `str`, *optional*, defaults to `False`):
141
- Activates and controls truncation. Accepts the following values:
142
-
143
- - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
144
- to the maximum acceptable input length for the model if that argument is not provided. This will
145
- truncate token by token, removing a token from the longest sequence in the pair if a pair of
146
- sequences (or a batch of pairs) is provided.
147
- - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
148
- maximum acceptable input length for the model if that argument is not provided. This will only
149
- truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
150
- - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
151
- maximum acceptable input length for the model if that argument is not provided. This will only
152
- truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
153
- - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
154
- greater than the model maximum admissible input size).
155
- max_length (`int`, *optional*):
156
- Controls the maximum length to use by one of the truncation/padding parameters.
157
-
158
- If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
159
- is required by one of the truncation/padding parameters. If the model has no specific maximum input
160
- length (like XLNet) truncation/padding to a maximum length will be deactivated.
161
- stride (`int`, *optional*, defaults to 0):
162
- If set to a number along with `max_length`, the overflowing tokens returned when
163
- `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
164
- returned to provide some overlap between truncated and overflowing sequences. The value of this
165
- argument defines the number of overlapping tokens.
166
- is_split_into_words (`bool`, *optional*, defaults to `False`):
167
- Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
168
- tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
169
- which it will tokenize. This is useful for NER or token classification.
170
- pad_to_multiple_of (`int`, *optional*):
171
- If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
172
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
173
- `>= 7.5` (Volta).
174
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
175
- If set, will return tensors instead of list of python integers. Acceptable values are:
176
-
177
- - `'tf'`: Return TensorFlow `tf.constant` objects.
178
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
179
- - `'np'`: Return Numpy `np.ndarray` objects.
180
- """
181
- return self.tokenizer.encode(
182
- text,
183
- add_special_tokens=add_special_tokens,
184
- padding=padding,
185
- truncation=truncation,
186
- max_length=max_length,
187
- stride=stride,
188
- return_tensors=return_tensors,
189
- )
190
-
191
- def decode(
192
- self,
193
- token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
194
- *, # Enforce keyword-only arguments
195
- skip_special_tokens: bool = False,
196
- clean_up_tokenization_spaces: bool = None,
197
- **kwargs,
198
- ) -> str:
199
- """
200
- Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
201
- tokens and clean up tokenization spaces.
202
-
203
- Args:
204
- token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
205
- List of tokenized input ids. Can be obtained using the `__call__` method.
206
- skip_special_tokens (`bool`, *optional*, defaults to `False`):
207
- Whether or not to remove special tokens in the decoding.
208
- clean_up_tokenization_spaces (`bool`, *optional*):
209
- Whether or not to clean up the tokenization spaces. If `None`, will default to
210
- `self.clean_up_tokenization_spaces`.
211
- kwargs (additional keyword arguments, *optional*):
212
- Will be passed to the underlying model specific decode method.
213
-
214
- Returns:
215
- `str`: The decoded sentence.
216
- """
217
- return self.tokenizer.decode(
218
- token_ids,
219
- skip_special_tokens=skip_special_tokens,
220
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
221
- **kwargs,
222
- )
223
-
224
- def apply_chat_template(
225
- self,
226
- conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
227
- *,
228
- add_generation_prompt: bool = False,
229
- tokenize: bool = True,
230
- padding: bool = False,
231
- truncation: bool = False,
232
- max_length: Optional[int] = None,
233
- return_tensors: Optional[str] = None,
234
- return_dict: bool = False,
235
- return_assistant_tokens_mask: bool = False,
236
- generation_prefix: str = "",
237
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
238
- **kwargs,
239
- ):
240
- """
241
- Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
242
- ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting.
243
-
244
- More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template
245
-
246
- Args:
247
- conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
248
- with "role" and "content" keys, representing the chat history so far.
249
- add_generation_prompt (bool, *optional*):
250
- If this is set, a prompt with the token(s) that indicate
251
- the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
252
- Note that this argument will be passed to the chat template, and so it must be supported in the
253
- template for this argument to have any effect.
254
- continue_final_message (bool, *optional*):
255
- If this is set, the chat will be formatted so that the final
256
- message in the chat is open-ended, without any EOS tokens. The model will continue this message
257
- rather than starting a new one. This allows you to "prefill" part of
258
- the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
259
- tokenize (`bool`, defaults to `True`):
260
- Whether to tokenize the output. If `False`, the output will be a string.
261
- padding (`bool`, defaults to `False`):
262
- Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
263
- truncation (`bool`, defaults to `False`):
264
- Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
265
- max_length (`int`, *optional*):
266
- Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
267
- not specified, the tokenizer's `max_length` attribute will be used as a default.
268
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
269
- If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
270
- values are:
271
- - `'tf'`: Return TensorFlow `tf.Tensor` objects.
272
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
273
- - `'np'`: Return NumPy `np.ndarray` objects.
274
- - `'jax'`: Return JAX `jnp.ndarray` objects.
275
- return_dict (`bool`, defaults to `False`):
276
- Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
277
- generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "".
278
- tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
279
- return_assistant_tokens_mask (`bool`, defaults to `False`):
280
- Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
281
- the mask will contain 1. For user and system tokens, the mask will contain 0.
282
- This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
283
- **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
284
-
285
- Returns:
286
- `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
287
- output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
288
- set, will return a dict of tokenizer outputs instead.
289
- """
290
- if not self.is_instruct_model:
291
- raise ValueError(
292
- "apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor."
293
- )
294
- # Since generation_prefix is added to the text in the end, ensure that the setting is correct
295
- if generation_prefix:
296
- assert not tokenize, "tokenize must be False when generation_prefix is provided."
297
- assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided."
298
- formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template(
299
- conversation,
300
- add_generation_prompt=add_generation_prompt,
301
- tokenize=tokenize,
302
- padding=padding,
303
- truncation=truncation,
304
- max_length=max_length,
305
- return_tensors=return_tensors,
306
- return_dict=return_dict,
307
- return_assistant_tokens_mask=return_assistant_tokens_mask,
308
- tokenizer_kwargs=tokenizer_kwargs,
309
- **kwargs,
310
- )
311
- if generation_prefix:
312
- formatted_text: str = formatted_text + generation_prefix
313
- log.debug(
314
- f"Adding generation prefix: {generation_prefix} to the formatted text\n"
315
- f"Formatted text: {formatted_text}"
316
- )
317
- return formatted_text
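
A hedged sketch of typical TextTokenizer usage, assuming the class above is importable. Downloading the `meta-llama` tokenizer requires accepting its license on Hugging Face, so `local_path` may be needed in practice; the prompt text and generation prefix are illustrative.

    # Hedged sketch: encode/decode plain text and format a chat prompt with a generation prefix.
    tokenizer = TextTokenizer(model_family="llama3.1", is_instruct_model=True)

    ids = tokenizer.encode("Describe the next camera motion.", add_special_tokens=True)
    text = tokenizer.decode(ids, skip_special_tokens=True)

    conversation = [{"role": "user", "content": "Describe the next camera motion."}]
    prompt = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,                 # required when generation_prefix is non-empty
        generation_prefix=" The camera",
    )
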
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9bd252316a4bd6fb3a8f8a1c29a8e9ac44ac76fe DELETED
@@ -1,60 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- import os
3
-
4
- from omegaconf import DictConfig, OmegaConf
5
-
6
- from .lazy_instantiate import instantiate
7
- from .lazy import LazyCall, LazyConfig
8
- from .lazy_omegaconf_patch import to_object
9
-
10
- OmegaConf.to_object = to_object
11
-
12
- PLACEHOLDER = None
13
- LazyDict = DictConfig
14
-
15
- __all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"]
16
-
17
-
18
- DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
19
-
20
-
21
- def fixup_module_metadata(module_name, namespace, keys=None):
22
- """
23
- Fix the __qualname__ of module members to be their exported api name, so
24
- when they are referenced in docs, sphinx can find them. Reference:
25
- https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
26
- """
27
- if not DOC_BUILDING:
28
- return
29
- seen_ids = set()
30
-
31
- def fix_one(qualname, name, obj):
32
- # avoid infinite recursion (relevant when using
33
- # typing.Generic, for example)
34
- if id(obj) in seen_ids:
35
- return
36
- seen_ids.add(id(obj))
37
-
38
- mod = getattr(obj, "__module__", None)
39
- if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
40
- obj.__module__ = module_name
41
- # Modules, unlike everything else in Python, put fully-qualitied
42
- # names into their __name__ attribute. We check for "." to avoid
43
- # rewriting these.
44
- if hasattr(obj, "__name__") and "." not in obj.__name__:
45
- obj.__name__ = name
46
- obj.__qualname__ = qualname
47
- if isinstance(obj, type):
48
- for attr_name, attr_value in obj.__dict__.items():
49
- fix_one(objname + "." + attr_name, attr_name, attr_value)
50
-
51
- if keys is None:
52
- keys = namespace.keys()
53
- for objname in keys:
54
- if not objname.startswith("_"):
55
- obj = namespace[objname]
56
- fix_one(objname, objname, obj)
57
-
58
-
59
- fixup_module_metadata(__name__, globals(), __all__)
60
- del fixup_module_metadata
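
A hedged sketch of the lazy-config pattern re-exported above (it mirrors detectron2's LazyCall/instantiate API); `torch.nn.Linear` is only an example target, not something this module prescribes.

    # Hedged sketch: describe an object lazily, then materialize it with instantiate().
    import torch.nn as nn

    cfg = LazyCall(nn.Linear)(in_features=16, out_features=4, bias=True)
    # cfg is a LazyDict (DictConfig) that can be inspected or overridden before construction:
    cfg.out_features = 8
    layer = instantiate(cfg)  # nn.Linear(in_features=16, out_features=8, bias=True) is built only here
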
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/9d565d078fbe37e1d31cf8a445a460e2bae291f1 DELETED
@@ -1,224 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import os
18
-
19
- from .misc import misc, timer
20
- import numpy as np
21
- import torch
22
- from pytorch_retinaface.data import cfg_re50
23
- from pytorch_retinaface.layers.functions.prior_box import PriorBox
24
- from pytorch_retinaface.models.retinaface import RetinaFace
25
- from torch.utils.data import DataLoader, TensorDataset
26
- from tqdm import tqdm
27
-
28
- from .guardrail_common_core import GuardrailRunner, PostprocessingGuardrail
29
- from .guardrail_common_io_utils import get_video_filepaths, read_video, save_video
30
- from .guardrail_face_blur_filter_blur_utils import pixelate_face
31
- from .guardrail_face_blur_filter_retinaface_utils import decode_batch, filter_detected_boxes, load_model
32
- from .log import log
33
-
34
- DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"
35
-
36
- # RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
37
- TOP_K = 5_000
38
- KEEP_TOP_K = 750
39
- NMS_THRESHOLD = 0.4
40
-
41
-
42
- class RetinaFaceFilter(PostprocessingGuardrail):
43
- def __init__(
44
- self,
45
- checkpoint: str = DEFAULT_RETINAFACE_CHECKPOINT,
46
- batch_size: int = 1,
47
- confidence_threshold: float = 0.7,
48
- device="cuda" if torch.cuda.is_available() else "cpu",
49
- ) -> None:
50
- """
51
- Initialize the RetinaFace model for face detection and blurring.
52
-
53
- Args:
54
- checkpoint: Path to the RetinaFace checkpoint file
55
- batch_size: Batch size for RetinaFace inference and processing
56
- confidence_threshold: Minimum confidence score to consider a face detection
57
- """
58
- self.cfg = cfg_re50
59
- self.batch_size = batch_size
60
- self.confidence_threshold = confidence_threshold
61
- self.device = device
62
- self.dtype = torch.float32
63
-
64
- # Disable loading ResNet pretrained weights
65
- self.cfg["pretrain"] = False
66
- self.net = RetinaFace(cfg=self.cfg, phase="test")
67
- cpu = self.device == "cpu"
68
-
69
- # Load from RetinaFace pretrained checkpoint
70
- self.net = load_model(self.net, checkpoint, cpu)
71
- self.net.to(self.device, dtype=self.dtype).eval()
72
-
73
- def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor:
74
- """Preprocess a sequence of frames for face detection.
75
-
76
- Args:
77
- frames: Input frames
78
-
79
- Returns:
80
- Preprocessed frames tensor
81
- """
82
- with torch.no_grad():
83
- frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype) # Shape: [T, H, W, C]
84
- frames_tensor = frames_tensor.permute(0, 3, 1, 2) # Shape: [T, C, H, W]
85
- frames_tensor = frames_tensor[:, [2, 1, 0], :, :] # RGB to BGR to match RetinaFace model input
86
- means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1)
87
- frames_tensor = frames_tensor - means # Subtract mean BGR values for each channel
88
- return frames_tensor
89
-
90
- def blur_detected_faces(
91
- self,
92
- frames: np.ndarray,
93
- batch_loc: torch.Tensor,
94
- batch_conf: torch.Tensor,
95
- prior_data: torch.Tensor,
96
- scale: torch.Tensor,
97
- min_size: tuple[int, int] = (20, 20),
98
- ) -> list[np.ndarray]:
99
- """Blur detected faces in a batch of frames using RetinaFace predictions.
100
-
101
- Args:
102
- frames: Input frames
103
- batch_loc: Batched location predictions
104
- batch_conf: Batched confidence scores
105
- prior_data: Prior boxes for the video
106
- scale: Scale factor for resizing detections
107
- min_size: Minimum size of a detected face region in pixels
108
-
109
- Returns:
110
- Processed frames with pixelated faces
111
- """
112
- with torch.no_grad():
113
- batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"])
114
- batch_boxes = batch_boxes * scale
115
-
116
- blurred_frames = []
117
- for i, boxes in enumerate(batch_boxes):
118
- boxes = boxes.detach().cpu().numpy()
119
- scores = batch_conf[i, :, 1].detach().cpu().numpy()
120
-
121
- filtered_boxes = filter_detected_boxes(
122
- boxes,
123
- scores,
124
- confidence_threshold=self.confidence_threshold,
125
- nms_threshold=NMS_THRESHOLD,
126
- top_k=TOP_K,
127
- keep_top_k=KEEP_TOP_K,
128
- )
129
-
130
- frame = frames[i]
131
- for box in filtered_boxes:
132
- x1, y1, x2, y2 = map(int, box)
133
- # Ignore bounding boxes smaller than the minimum size
134
- if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]:
135
- continue
136
- max_h, max_w = frame.shape[:2]
137
- face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)]
138
- blurred_face = pixelate_face(face_roi)
139
- frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face
140
- blurred_frames.append(frame)
141
-
142
- return blurred_frames
143
-
144
- def postprocess(self, frames: np.ndarray) -> np.ndarray:
145
- """Blur faces in a sequence of frames.
146
-
147
- Args:
148
- frames: Input frames
149
-
150
- Returns:
151
- Processed frames with pixelated faces
152
- """
153
- # Create dataset and dataloader
154
- frames_tensor = self.preprocess_frames(frames)
155
- dataset = TensorDataset(frames_tensor)
156
- dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
157
- processed_frames, processed_batches = [], []
158
-
159
- prior_data, scale = None, None
160
- for i, batch in enumerate(dataloader):
161
- batch = batch[0]
162
- h, w = batch.shape[-2:] # Batch shape: [B, C, H, W]
163
-
164
- with torch.no_grad():
165
- # Generate priors for the video
166
- if prior_data is None:
167
- priorbox = PriorBox(self.cfg, image_size=(h, w))
168
- priors = priorbox.forward()
169
- priors = priors.to(self.device, dtype=self.dtype)
170
- prior_data = priors.data
171
-
172
- # Get scale for resizing detections
173
- if scale is None:
174
- scale = torch.Tensor([w, h, w, h])
175
- scale = scale.to(self.device, dtype=self.dtype)
176
-
177
- batch_loc, batch_conf, _ = self.net(batch)
178
-
179
- # Blur detected faces in each batch of frames
180
- start_idx = i * self.batch_size
181
- end_idx = min(start_idx + self.batch_size, len(frames))
182
- processed_batches.append(
183
- self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale)
184
- )
185
-
186
- processed_frames = [frame for batch in processed_batches for frame in batch]
187
- return np.array(processed_frames)
188
-
189
-
190
- def parse_args():
191
- parser = argparse.ArgumentParser()
192
- parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
193
- parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos")
194
- parser.add_argument(
195
- "--checkpoint",
196
- type=str,
197
- help="Path to the RetinaFace checkpoint file",
198
- default=DEFAULT_RETINAFACE_CHECKPOINT,
199
- )
200
- return parser.parse_args()
201
-
202
-
203
- def main(args):
204
- filepaths = get_video_filepaths(args.input_dir)
205
- if not filepaths:
206
- log.error(f"No video files found in directory: {args.input_dir}")
207
- return
208
-
209
- face_blur = RetinaFaceFilter(checkpoint=args.checkpoint)
210
- postprocessing_runner = GuardrailRunner(postprocessors=[face_blur])
211
- os.makedirs(args.output_dir, exist_ok=True)
212
-
213
- for filepath in tqdm(filepaths):
214
- video_data = read_video(filepath)
215
- with timer("face blur filter"):
216
- frames = postprocessing_runner.postprocess(video_data.frames)
217
-
218
- output_path = os.path.join(args.output_dir, os.path.basename(filepath))
219
- save_video(output_path, frames, video_data.fps)
220
-
221
-
222
- if __name__ == "__main__":
223
- args = parse_args()
224
- main(args)
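
Beyond the CLI entry point above, the filter can be driven programmatically; a hedged sketch follows. The frame array dtype/shape (uint8 RGB, [T, H, W, 3]) is an assumption based on how `preprocess_frames` consumes it, and the checkpoint file must exist at the given path.

    # Hedged sketch: run the face-blur guardrail on an in-memory clip.
    import numpy as np

    face_blur = RetinaFaceFilter(checkpoint=DEFAULT_RETINAFACE_CHECKPOINT, batch_size=4)
    runner = GuardrailRunner(postprocessors=[face_blur])

    frames = np.zeros((8, 720, 1280, 3), dtype=np.uint8)  # placeholder RGB clip, [T, H, W, C]
    blurred = runner.postprocess(frames)                  # same shape, detected faces pixelated
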
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a209db0eba28a8d8bcb527bfbaca6f5e361ace14 DELETED
@@ -1,28 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- from dataclasses import dataclass
19
- from typing import Optional
20
-
21
- import torch
22
-
23
-
24
- @dataclass
25
- class DenoisePrediction:
26
- x0: torch.Tensor # clean data prediction
27
- eps: Optional[torch.Tensor] = None # noise prediction
28
- logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used as a confidence / uncertainty signal
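
For illustration, a minimal hedged construction of the container above (shapes are placeholders):

    # Hedged sketch: bundling a denoiser's outputs.
    import torch

    pred = DenoisePrediction(
        x0=torch.zeros(1, 4, 8, 8),   # clean-data estimate (required)
        eps=torch.zeros(1, 4, 8, 8),  # optional noise estimate
        logvar=None,                  # optional confidence / uncertainty term
    )
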
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a2496a4fa280586b62c846c54cfbbc9f8adc0331 DELETED
@@ -1,211 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import collections
19
- import collections.abc
20
- import functools
21
- import json
22
- import random
23
- import time
24
- from contextlib import ContextDecorator
25
- from typing import Any, Callable, TypeVar
26
-
27
- from .log import log
28
- import numpy as np
29
- import termcolor
30
- import torch
31
-
32
- from .distributed import distributed
33
-
34
-
35
- class misc():
36
-
37
- @staticmethod
38
- def to(
39
- data: Any,
40
- device: str | torch.device | None = None,
41
- dtype: torch.dtype | None = None,
42
- memory_format: torch.memory_format = torch.preserve_format,
43
- ) -> Any:
44
- """Recursively cast data into the specified device, dtype, and/or memory_format.
45
-
46
- The input data can be a tensor, a list of tensors, a dict of tensors.
47
- See the documentation for torch.Tensor.to() for details.
48
-
49
- Args:
50
- data (Any): Input data.
51
- device (str | torch.device): GPU device (default: None).
52
- dtype (torch.dtype): data type (default: None).
53
- memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
54
-
55
- Returns:
56
- data (Any): Data cast to the specified device, dtype, and/or memory_format.
57
- """
58
- assert (
59
- device is not None or dtype is not None or memory_format is not None
60
- ), "at least one of device, dtype, memory_format should be specified"
61
- if isinstance(data, torch.Tensor):
62
- is_cpu = (isinstance(device, str) and device == "cpu") or (
63
- isinstance(device, torch.device) and device.type == "cpu"
64
- )
65
- data = data.to(
66
- device=device,
67
- dtype=dtype,
68
- memory_format=memory_format,
69
- non_blocking=(not is_cpu),
70
- )
71
- return data
72
- elif isinstance(data, collections.abc.Mapping):
73
- return type(data)({key: misc.to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
74
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
75
- return type(data)([misc.to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
76
- else:
77
- return data
78
-
79
- @staticmethod
80
- def serialize(data: Any) -> Any:
81
- """Serialize data by hierarchically traversing through iterables.
82
-
83
- Args:
84
- data (Any): Input data.
85
-
86
- Returns:
87
- data (Any): Serialized data.
88
- """
89
- if isinstance(data, collections.abc.Mapping):
90
- return type(data)({key: misc.serialize(data[key]) for key in data})
91
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
92
- return type(data)([misc.serialize(elem) for elem in data])
93
- else:
94
- try:
95
- json.dumps(data)
96
- except TypeError:
97
- data = str(data)
98
- return data
99
-
100
- @staticmethod
101
- def set_random_seed(seed: int, by_rank: bool = False) -> None:
102
- """Set random seed. This includes random, numpy, Pytorch.
103
-
104
- Args:
105
- seed (int): Random seed.
106
- by_rank (bool): if true, each GPU will use a different random seed.
107
- """
108
- if by_rank:
109
- seed += distributed.get_rank()
110
- log.info(f"Using random seed {seed}.")
111
- random.seed(seed)
112
- np.random.seed(seed)
113
- torch.manual_seed(seed) # sets seed on the current CPU & all GPUs
114
-
115
- @staticmethod
116
- def arch_invariant_rand(
117
- shape: list[int] | tuple[int, ...], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
118
- ):
119
- """Produce a GPU-architecture-invariant randomized Torch tensor.
120
-
121
- Args:
122
- shape (list or tuple of ints): Output tensor shape.
123
- dtype (torch.dtype): Output tensor type.
124
- device (torch.device): Device holding the output.
125
- seed (int): Optional randomization seed.
126
-
127
- Returns:
128
- tensor (torch.tensor): Randomly-generated tensor.
129
- """
130
- # Create a random number generator, optionally seeded
131
- rng = np.random.RandomState(seed)
132
-
133
- # Generate random numbers using the generator
134
- random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution
135
-
136
- # Convert to torch tensor and return
137
- return torch.from_numpy(random_array).to(dtype=dtype, device=device)
138
-
139
-
140
- T = TypeVar("T", bound=Callable[..., Any])
141
-
142
-
143
- class timer(ContextDecorator): # noqa: N801
144
- """Simple timer for timing the execution of code.
145
-
146
- It can be used as either a context manager or a function decorator. The timing result will be logged upon exit.
147
-
148
- Example:
149
- def func_a():
150
- time.sleep(1)
151
- with timer("func_a"):
152
- func_a()
153
-
154
- @timer("func_b")
155
- def func_b():
156
- time.sleep(1)
157
- func_b()
158
- """
159
-
160
- def __init__(self, context: str, debug: bool = False):
161
- self.context = context
162
- self.debug = debug
163
-
164
- def __enter__(self) -> None:
165
- self.tic = time.time()
166
-
167
- def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001
168
- time_spent = time.time() - self.tic
169
- if self.debug:
170
- log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
171
- else:
172
- log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
173
-
174
- def __call__(self, func: T) -> T:
175
- @functools.wraps(func)
176
- def wrapper(*args, **kwargs): # noqa: ANN202
177
- tic = time.time()
178
- result = func(*args, **kwargs)
179
- time_spent = time.time() - tic
180
- if self.debug:
181
- log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
182
- else:
183
- log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
184
- return result
185
-
186
- return wrapper # type: ignore
187
-
188
-
189
- class Color:
190
- """A convenience class to colorize strings in the console.
191
-
192
- Example:
193
- import
194
- print(f"This is {Color.red('important')}.")
195
- """
196
-
197
- @staticmethod
198
- def red(x: str) -> str:
199
- return termcolor.colored(str(x), color="red")
200
-
201
- @staticmethod
202
- def green(x: str) -> str:
203
- return termcolor.colored(str(x), color="green")
204
-
205
- @staticmethod
206
- def cyan(x: str) -> str:
207
- return termcolor.colored(str(x), color="cyan")
208
-
209
- @staticmethod
210
- def yellow(x: str) -> str:
211
- return termcolor.colored(str(x), color="yellow")
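
A hedged sketch exercising the helpers above, assuming they are importable from this module: recursively moving a nested batch to the available device, timing a block, and colorizing console output.

    # Hedged sketch: misc.to(), the timer context manager, and Color in one place.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = {"frames": torch.randn(2, 3, 8, 8), "ids": ["clip_0", "clip_1"]}
    batch = misc.to(batch, device=device)  # tensors are moved; the list of strings passes through unchanged

    with timer("toy matmul"):
        _ = batch["frames"].flatten(1) @ batch["frames"].flatten(1).T

    print(f"Status: {Color.green('ok')}")
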
 
 
cache/models--NeverMore0123--AutoregressiveVideo2WorldGeneration/blobs/a24d1a0cbbe184ab0a2bfb5cbee13bfd327810ae DELETED
@@ -1,165 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- from typing import Any, TypeVar
19
-
20
- import attrs
21
-
22
- from .lazy_config_init import LazyDict
23
- from .misc import Color
24
-
25
- T = TypeVar("T")
26
-
27
-
28
- def _is_attrs_instance(obj: object) -> bool:
29
- """
30
- Helper function to check if an object is an instance of an attrs-defined class.
31
-
32
- Args:
33
- obj: The object to check.
34
-
35
- Returns:
36
- bool: True if the object is an instance of an attrs-defined class, False otherwise.
37
- """
38
- return hasattr(obj, "__attrs_attrs__")
39
-
40
-
41
- def make_freezable(cls: T) -> T:
42
- """
43
- A decorator that adds the capability to freeze instances of an attrs-defined class.
44
-
45
- NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
46
- to hack on a "_is_frozen" attribute.
47
-
48
- This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
49
- Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
50
- any attrs-defined objects that are attributes of the class.
51
-
52
- Usage:
53
- @make_freezable
54
- @attrs.define(slots=False)
55
- class MyClass:
56
- attribute1: int
57
- attribute2: str
58
-
59
- obj = MyClass(1, 'a')
60
- obj.freeze() # Freeze the instance
61
- obj.attribute1 = 2 # Raises AttributeError
62
-
63
- Args:
64
- cls: The class to be decorated.
65
-
66
- Returns:
67
- The decorated class with added freezing capability.
68
- """
69
-
70
- if not hasattr(cls, "__dict__"):
71
- raise TypeError(
72
- "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
73
- "class was defined with `@attrs.define(slots=False)`"
74
- )
75
-
76
- original_setattr = cls.__setattr__
77
-
78
- def setattr_override(self, key, value) -> None: # noqa: ANN001
79
- """
80
- Override __setattr__ to allow modifications during initialization
81
- and prevent modifications once the instance is frozen.
82
- """
83
- if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
84
- raise AttributeError("Cannot modify frozen instance")
85
- original_setattr(self, key, value) # type: ignore
86
-
87
- cls.__setattr__ = setattr_override # type: ignore
88
-
89
- def freeze(self: object) -> None:
90
- """
91
- Freeze the instance and all its attrs-defined attributes.
92
- """
93
- for _, value in attrs.asdict(self, recurse=False).items():
94
- if _is_attrs_instance(value) and hasattr(value, "freeze"):
95
- value.freeze()
96
- self._is_frozen = True # type: ignore
97
-
98
- cls.freeze = freeze # type: ignore
99
-
100
- return cls
101
-
102
-
103
- def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
104
- """
105
- Recursively pretty prints attrs objects with color.
106
- """
107
-
108
- assert attrs.has(obj.__class__)
109
-
110
- lines: list[str] = []
111
- for attribute in attrs.fields(obj.__class__):
112
- value = getattr(obj, attribute.name)
113
- if attrs.has(value.__class__):
114
- if use_color:
115
- lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
116
- else:
117
- lines.append(" " * indent + "* " + attribute.name + ":")
118
- lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
119
- else:
120
- if use_color:
121
- lines.append(
122
- " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
123
- )
124
- else:
125
- lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
126
- return "\n".join(lines)
127
-
128
-
129
- @make_freezable
130
- @attrs.define(slots=False)
131
- class JobConfig:
132
- # Project name.
133
- project: str = ""
134
- # Experiment name.
135
- group: str = ""
136
- # Run/job name.
137
- name: str = ""
138
-
139
- @property
140
- def path(self) -> str:
141
- return f"{self.project}/{self.group}/{self.name}"
142
-
143
-
144
- @make_freezable
145
- @attrs.define(slots=False)
146
- class Config:
147
- """Config for a job.
148
-
149
- See /README.md/Configuration System for more info.
150
- """
151
-
152
- # Model configs.
153
- model: LazyDict
154
-
155
- # Training job configs.
156
- job: JobConfig = attrs.field(factory=JobConfig)
157
-
158
- def to_dict(self) -> dict[str, Any]:
159
- return attrs.asdict(self)
160
-
161
- def validate(self) -> None:
162
- """Validate that the config has all required fields."""
163
- assert self.job.project != "", "Project name is required."
164
- assert self.job.group != "", "Group name is required."
165
- assert self.job.name != "", "Job name is required."
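
A hedged sketch of assembling, validating, and freezing a Config; the model LazyDict contents are placeholders rather than a real model configuration from the repository.

    # Hedged sketch: build a job Config, validate required fields, then freeze it.
    from omegaconf import DictConfig

    config = Config(
        model=DictConfig({"_target_": "torch.nn.Identity"}),  # placeholder LazyDict
        job=JobConfig(project="cosmos", group="ar", name="demo_run"),
    )
    config.validate()  # asserts that project/group/name are non-empty
    config.freeze()    # subsequent attribute assignments raise AttributeError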