diff --git a/cosmos1/models/autoregressive/configs/base/model.py b/ar_config_base_model.py similarity index 98% rename from cosmos1/models/autoregressive/configs/base/model.py rename to ar_config_base_model.py index e9f1c0f3fb45d85fab1c2259090bdc794aee7b44..d0ddd57133faf49bf99cf67231c528fbf8d543d0 100644 --- a/cosmos1/models/autoregressive/configs/base/model.py +++ b/ar_config_base_model.py @@ -17,7 +17,7 @@ from typing import Optional import attrs -from cosmos1.models.autoregressive.configs.base.tokenizer import TokenizerConfig +from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig @attrs.define diff --git a/cosmos1/models/autoregressive/configs/base/model_config.py b/ar_config_base_model_config.py similarity index 97% rename from cosmos1/models/autoregressive/configs/base/model_config.py rename to ar_config_base_model_config.py index 7c16be1b6d30426cb9af3498aa2d51fa7c451696..2676622fe75a969184d93d2cfd9891b7464f0c19 100644 --- a/cosmos1/models/autoregressive/configs/base/model_config.py +++ b/ar_config_base_model_config.py @@ -16,17 +16,17 @@ import copy from typing import Callable, List, Optional -from cosmos1.models.autoregressive.configs.base.model import ModelConfig -from cosmos1.models.autoregressive.configs.base.tokenizer import ( +from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig +from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import ( TextTokenizerConfig, TokenizerConfig, VideoTokenizerConfig, create_discrete_video_fsq_tokenizer_state_dict_config, ) -from cosmos1.models.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer -from cosmos1.models.autoregressive.tokenizer.text_tokenizer import TextTokenizer -from cosmos1.utils import log -from cosmos1.utils.lazy_config import LazyCall as L +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_image_text_tokenizer import ImageTextTokenizer +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer 
import TextTokenizer +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L # Common architecture specifications BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336} diff --git a/cosmos1/models/autoregressive/configs/base/tokenizer.py b/ar_config_base_tokenizer.py similarity index 93% rename from cosmos1/models/autoregressive/configs/base/tokenizer.py rename to ar_config_base_tokenizer.py index 3b8609614eee2921504dba117e3e89f710ba346a..cba1b056ae44746189d1d2ba58f35062968b629c 100644 --- a/cosmos1/models/autoregressive/configs/base/tokenizer.py +++ b/ar_config_base_tokenizer.py @@ -17,10 +17,10 @@ from typing import Optional import attrs -from cosmos1.models.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer -from cosmos1.models.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQStateDictTokenizer +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_networks import CausalDiscreteVideoTokenizer +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict def create_discrete_video_fsq_tokenizer_state_dict_config( diff --git a/cosmos1/models/autoregressive/configs/inference/inference_config.py b/ar_config_inference_inference_config.py similarity index 97% rename from cosmos1/models/autoregressive/configs/inference/inference_config.py rename to ar_config_inference_inference_config.py index 6ff2ee93cbeb016f5c472952cb0dbcd1bab4e3fc..666b72721e1d5d0cc1b3f1e8527e402545d556f8 100644 --- a/cosmos1/models/autoregressive/configs/inference/inference_config.py +++ b/ar_config_inference_inference_config.py @@ -17,7 +17,7 @@ 
from typing import Any, List, Union import attrs -from cosmos1.models.autoregressive.configs.base.model import ModelConfig, TokenizerConfig +from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig, TokenizerConfig @attrs.define(slots=False) diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py b/ar_diffusion_decoder_config_base_conditioner.py similarity index 85% rename from cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py rename to ar_diffusion_decoder_config_base_conditioner.py index 8124ee46e4be5608c83df6459f665dad6e9642a7..d4f876dd2302d090e0b4943b1180999beddf36a6 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py +++ b/ar_diffusion_decoder_config_base_conditioner.py @@ -18,8 +18,8 @@ from typing import Dict, Optional import torch -from cosmos1.models.diffusion.conditioner import BaseVideoCondition, GeneralConditioner -from cosmos1.models.diffusion.config.base.conditioner import ( +from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseVideoCondition, GeneralConditioner +from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import ( FPSConfig, ImageSizeConfig, LatentConditionConfig, @@ -28,8 +28,8 @@ from cosmos1.models.diffusion.config.base.conditioner import ( PaddingMaskConfig, TextConfig, ) -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict @dataclass diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py b/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py similarity index 81% rename from cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py rename to ar_diffusion_decoder_config_config_latent_diffusion_decoder.py 
index f820b76d12107d700d33f7e19d4aae62049e7e31..531817956eb31490bb8668d6bfc8d551cc484e5a 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py +++ b/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py @@ -17,11 +17,11 @@ from typing import Any, List import attrs -from cosmos1.models.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs -from cosmos1.models.diffusion.config.base.model import LatentDiffusionDecoderModelConfig -from cosmos1.models.diffusion.config.registry import register_configs -from cosmos1.utils import config -from cosmos1.utils.config_helper import import_all_modules_from_package +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_config_registry import register_configs as register_dd_configs +from AutoregressiveVideo2WorldGeneration.df_config_base_model import LatentDiffusionDecoderModelConfig +from AutoregressiveVideo2WorldGeneration.df_config_registry import register_configs +from AutoregressiveVideo2WorldGeneration import config +from AutoregressiveVideo2WorldGeneration.config_helper import import_all_modules_from_package @attrs.define(slots=False) diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py b/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py similarity index 92% rename from cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py rename to ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py index ad296f42b5317ca2a4a26e21ca32bf1d952566d1..7ef9ba0e663449568f1601905e75f4cf125824bc 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py +++ b/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py @@ -15,9 +15,9 @@ from hydra.core.config_store import ConfigStore -from cosmos1.models.autoregressive.diffusion_decoder.network import 
DiffusionDecoderGeneralDIT -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_network import DiffusionDecoderGeneralDIT +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict num_frames = 57 Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict( diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/registry.py b/ar_diffusion_decoder_config_registry.py similarity index 90% rename from cosmos1/models/autoregressive/diffusion_decoder/config/registry.py rename to ar_diffusion_decoder_config_registry.py index b835fc06da23f9aef4bab0ad5aa9dc9b9a2b43d4..42849b6675be25e900f395315775e0a97b32c1f0 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/config/registry.py +++ b/ar_diffusion_decoder_config_registry.py @@ -15,12 +15,12 @@ from hydra.core.config_store import ConfigStore -from cosmos1.models.autoregressive.diffusion_decoder.config.base.conditioner import ( +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_config_base_conditioner import ( VideoLatentDiffusionDecoderConditionerConfig, ) -from cosmos1.models.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer -from cosmos1.models.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer -from cosmos1.utils.lazy_config import LazyCall as L +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer +from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L def get_cosmos_video_discrete_tokenizer_comp8x16x16( diff --git a/cosmos1/models/autoregressive/diffusion_decoder/inference.py 
b/ar_diffusion_decoder_inference.py similarity index 92% rename from cosmos1/models/autoregressive/diffusion_decoder/inference.py rename to ar_diffusion_decoder_inference.py index b923840956955a1913263d4a8151d14993a75b71..2eec6f557896bc2648296751b0d309c8c157d283 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/inference.py +++ b/ar_diffusion_decoder_inference.py @@ -19,10 +19,10 @@ from typing import List import torch -from cosmos1.models.autoregressive.configs.inference.inference_config import DiffusionDecoderSamplingConfig -from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel -from cosmos1.models.autoregressive.diffusion_decoder.utils import linear_blend_video_list, split_with_overlap -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import DiffusionDecoderSamplingConfig +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap +from AutoregressiveVideo2WorldGeneration import log def diffusion_decoder_process_tokens( diff --git a/cosmos1/models/autoregressive/diffusion_decoder/model.py b/ar_diffusion_decoder_model.py similarity index 95% rename from cosmos1/models/autoregressive/diffusion_decoder/model.py rename to ar_diffusion_decoder_model.py index 50f4ea81a5ade75e18cccb240f599aa4a0c789cd..6de3d809ab932b0f58aaf5f18134f7b2d9718ac5 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/model.py +++ b/ar_diffusion_decoder_model.py @@ -19,11 +19,11 @@ from typing import Callable, Dict, Optional, Tuple import torch from torch import Tensor -from cosmos1.models.diffusion.conditioner import BaseVideoCondition -from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos1.models.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS -from 
cosmos1.models.diffusion.model.model_t2w import DiffusionT2WModel as VideoDiffusionModel -from cosmos1.utils.lazy_config import instantiate as lazy_instantiate +from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseVideoCondition +from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul +from AutoregressiveVideo2WorldGeneration.df_df_module_res_sampler import COMMON_SOLVER_OPTIONS +from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel as VideoDiffusionModel +from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate @dataclass diff --git a/cosmos1/models/autoregressive/diffusion_decoder/network.py b/ar_diffusion_decoder_network.py similarity index 97% rename from cosmos1/models/autoregressive/diffusion_decoder/network.py rename to ar_diffusion_decoder_network.py index e3c114d520677a67e15fb126aad82250c7bbec44..5208dce7369026ad81b6908a3f3a5ecbf5fa4ad2 100644 --- a/cosmos1/models/autoregressive/diffusion_decoder/network.py +++ b/ar_diffusion_decoder_network.py @@ -20,8 +20,8 @@ from einops import rearrange from torch import nn from torchvision import transforms -from cosmos1.models.diffusion.module.blocks import PatchEmbed -from cosmos1.models.diffusion.networks.general_dit import GeneralDIT +from AutoregressiveVideo2WorldGeneration.df_module_blocks import PatchEmbed +from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT class DiffusionDecoderGeneralDIT(GeneralDIT): diff --git a/cosmos1/models/autoregressive/diffusion_decoder/utils.py b/ar_diffusion_decoder_utils.py similarity index 100% rename from cosmos1/models/autoregressive/diffusion_decoder/utils.py rename to ar_diffusion_decoder_utils.py diff --git a/cosmos1/models/autoregressive/model.py b/ar_model.py similarity index 96% rename from cosmos1/models/autoregressive/model.py rename to ar_model.py index 
195de093d8fc4cba98b1255fce9bcbe2800c75e7..9527de29c60867fd95c4a1fd8a96e5eb15d7073e 100644 --- a/cosmos1/models/autoregressive/model.py +++ b/ar_model.py @@ -19,23 +19,24 @@ import time from pathlib import Path from typing import Any, Dict, List, Optional, Set +from AutoregressiveVideo2WorldGeneration import misc import torch from safetensors.torch import load_file from torch.nn.modules.module import _IncompatibleKeys -from cosmos1.models.autoregressive.configs.base.model import ModelConfig -from cosmos1.models.autoregressive.configs.base.tokenizer import TokenizerConfig -from cosmos1.models.autoregressive.modules.mm_projector import MultimodalProjector -from cosmos1.models.autoregressive.networks.transformer import Transformer -from cosmos1.models.autoregressive.networks.vit import VisionTransformer, get_vit_config -from cosmos1.models.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size -from cosmos1.models.autoregressive.utils.checkpoint import ( +from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig +from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig +from AutoregressiveVideo2WorldGeneration.ar_module_mm_projector import MultimodalProjector +from AutoregressiveVideo2WorldGeneration.ar_network_transformer import Transformer +from AutoregressiveVideo2WorldGeneration.ar_network_vit import VisionTransformer, get_vit_config +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size +from AutoregressiveVideo2WorldGeneration.ar_utils_checkpoint import ( get_partial_state_dict, process_state_dict, substrings_to_ignore, ) -from cosmos1.models.autoregressive.utils.sampling import decode_n_tokens, decode_one_token, prefill -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.ar_utils_sampling import decode_n_tokens, decode_one_token, prefill +from AutoregressiveVideo2WorldGeneration import 
log class AutoRegressiveModel(torch.nn.Module): diff --git a/cosmos1/models/autoregressive/modules/attention.py b/ar_module_attention.py similarity index 98% rename from cosmos1/models/autoregressive/modules/attention.py rename to ar_module_attention.py index 78d15826c96a449308747af75ace0b3e82043f76..bf13847edfa3881494401debdd65890beb59b2e3 100644 --- a/cosmos1/models/autoregressive/modules/attention.py +++ b/ar_module_attention.py @@ -19,8 +19,8 @@ from typing import Optional, Union import torch from torch import nn -from cosmos1.models.autoregressive.modules.embedding import RotaryPositionEmbedding -from cosmos1.models.autoregressive.modules.normalization import create_norm +from AutoregressiveVideo2WorldGeneration.ar_module_embedding import RotaryPositionEmbedding +from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm class Attention(nn.Module): diff --git a/cosmos1/models/autoregressive/modules/embedding.py b/ar_module_embedding.py similarity index 100% rename from cosmos1/models/autoregressive/modules/embedding.py rename to ar_module_embedding.py diff --git a/cosmos1/models/autoregressive/modules/mlp.py b/ar_module_mlp.py similarity index 100% rename from cosmos1/models/autoregressive/modules/mlp.py rename to ar_module_mlp.py diff --git a/cosmos1/models/autoregressive/modules/mm_projector.py b/ar_module_mm_projector.py similarity index 100% rename from cosmos1/models/autoregressive/modules/mm_projector.py rename to ar_module_mm_projector.py diff --git a/cosmos1/models/autoregressive/modules/normalization.py b/ar_module_normalization.py similarity index 100% rename from cosmos1/models/autoregressive/modules/normalization.py rename to ar_module_normalization.py diff --git a/cosmos1/models/autoregressive/networks/transformer.py b/ar_network_transformer.py similarity index 97% rename from cosmos1/models/autoregressive/networks/transformer.py rename to ar_network_transformer.py index 
66cb8ff6ab49eb0e3b12fc50757e107e27c3e599..bc48dd0723f084c474b55206a2c44ede297708c2 100644 --- a/cosmos1/models/autoregressive/networks/transformer.py +++ b/ar_network_transformer.py @@ -19,17 +19,17 @@ import torch import torch.nn as nn from torch.nn.modules.module import _IncompatibleKeys -from cosmos1.models.autoregressive.modules.attention import Attention -from cosmos1.models.autoregressive.modules.embedding import ( +from AutoregressiveVideo2WorldGeneration.ar_module_attention import Attention +from AutoregressiveVideo2WorldGeneration.ar_module_embedding import ( RotaryPositionEmbeddingPytorchV1, RotaryPositionEmbeddingPytorchV2, SinCosPosEmbAxisTE, ) -from cosmos1.models.autoregressive.modules.mlp import MLP -from cosmos1.models.autoregressive.modules.normalization import create_norm -from cosmos1.models.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore -from cosmos1.models.autoregressive.utils.misc import maybe_convert_to_namespace -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.ar_module_mlp import MLP +from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm +from AutoregressiveVideo2WorldGeneration.ar_utils_checkpoint import process_state_dict, substrings_to_ignore +from AutoregressiveVideo2WorldGeneration.ar_utils_misc import maybe_convert_to_namespace +from AutoregressiveVideo2WorldGeneration import log class TransformerBlock(nn.Module): diff --git a/cosmos1/models/autoregressive/networks/vit.py b/ar_network_vit.py similarity index 98% rename from cosmos1/models/autoregressive/networks/vit.py rename to ar_network_vit.py index 25d0c4850bafdfaba29c2abc14b2da05578ee23e..5938979e6be4321ff52c518d0778601c6db7ab04 100644 --- a/cosmos1/models/autoregressive/networks/vit.py +++ b/ar_network_vit.py @@ -26,9 +26,9 @@ from typing import Any, Callable, Mapping, Optional, Tuple import torch import torch.nn as nn -from cosmos1.models.autoregressive.modules.normalization import 
create_norm -from cosmos1.models.autoregressive.networks.transformer import TransformerBlock -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm +from AutoregressiveVideo2WorldGeneration.ar_network_transformer import TransformerBlock +from AutoregressiveVideo2WorldGeneration import log def get_vit_config(model_name: str) -> Mapping[str, Any]: diff --git a/cosmos1/models/autoregressive/tokenizer/discrete_video.py b/ar_tokenizer_discrete_video.py similarity index 99% rename from cosmos1/models/autoregressive/tokenizer/discrete_video.py rename to ar_tokenizer_discrete_video.py index 477c9ac9f80832448f73379bbd6c67e29a2f40da..64a675d99a330c20f1a7de7c1dba85c58eda1afb 100644 --- a/cosmos1/models/autoregressive/tokenizer/discrete_video.py +++ b/ar_tokenizer_discrete_video.py @@ -18,7 +18,7 @@ from typing import Optional import torch from einops import rearrange -from cosmos1.models.autoregressive.tokenizer.quantizers import FSQuantizer +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_quantizers import FSQuantizer # Make sure jit model output consistenly during consecutive calls # Check here: https://github.com/pytorch/pytorch/issues/74534 diff --git a/cosmos1/models/autoregressive/tokenizer/image_text_tokenizer.py b/ar_tokenizer_image_text_tokenizer.py similarity index 99% rename from cosmos1/models/autoregressive/tokenizer/image_text_tokenizer.py rename to ar_tokenizer_image_text_tokenizer.py index 3bbc2c82d7ac2ed01b45c27d35d7c9071c696e1f..d0911647e3159d8eb023df1b7fb2d02a8b8be76a 100644 --- a/cosmos1/models/autoregressive/tokenizer/image_text_tokenizer.py +++ b/ar_tokenizer_image_text_tokenizer.py @@ -21,8 +21,8 @@ import transformers from transformers import AutoImageProcessor from transformers.image_utils import ImageInput, is_valid_image, load_image -from cosmos1.models.autoregressive.tokenizer.text_tokenizer import TextTokenizer -from cosmos1.utils import log +from 
AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer import TextTokenizer +from AutoregressiveVideo2WorldGeneration import log # Configuration for different vision-language models IMAGE_CONFIGS = { diff --git a/cosmos1/models/autoregressive/tokenizer/modules.py b/ar_tokenizer_modules.py similarity index 99% rename from cosmos1/models/autoregressive/tokenizer/modules.py rename to ar_tokenizer_modules.py index 290c145380129040f471899b924a9a93d389c73b..8f7744f24828c45d7da850d528fd936dbf2bc897 100644 --- a/cosmos1/models/autoregressive/tokenizer/modules.py +++ b/ar_tokenizer_modules.py @@ -29,8 +29,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from cosmos1.models.autoregressive.tokenizer.patching import Patcher3D, UnPatcher3D -from cosmos1.models.autoregressive.tokenizer.utils import ( +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_patching import Patcher3D, UnPatcher3D +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_utils import ( CausalNormalize, batch2space, batch2time, @@ -41,7 +41,7 @@ from cosmos1.models.autoregressive.tokenizer.utils import ( space2batch, time2batch, ) -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log class CausalConv3d(nn.Module): diff --git a/cosmos1/models/autoregressive/tokenizer/networks.py b/ar_tokenizer_networks.py similarity index 90% rename from cosmos1/models/autoregressive/tokenizer/networks.py rename to ar_tokenizer_networks.py index f25ac36f54179dd77e2d2177895c64dab258ff71..2b465abe8bb3e2c82438660765a34129241c2883 100644 --- a/cosmos1/models/autoregressive/tokenizer/networks.py +++ b/ar_tokenizer_networks.py @@ -18,9 +18,9 @@ from collections import namedtuple import torch from torch import nn -from cosmos1.models.autoregressive.tokenizer.modules import CausalConv3d, DecoderFactorized, EncoderFactorized -from cosmos1.models.autoregressive.tokenizer.quantizers import FSQuantizer -from cosmos1.utils import log +from 
AutoregressiveVideo2WorldGeneration.ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_quantizers import FSQuantizer +from AutoregressiveVideo2WorldGeneration import log NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) diff --git a/cosmos1/models/autoregressive/tokenizer/patching.py b/ar_tokenizer_patching.py similarity index 100% rename from cosmos1/models/autoregressive/tokenizer/patching.py rename to ar_tokenizer_patching.py diff --git a/cosmos1/models/autoregressive/tokenizer/quantizers.py b/ar_tokenizer_quantizers.py similarity index 98% rename from cosmos1/models/autoregressive/tokenizer/quantizers.py rename to ar_tokenizer_quantizers.py index 589204cd9e69b734313a918df7e99c748ae7f6e5..d1ce54a2071df140e229ada226b9e8852b219875 100644 --- a/cosmos1/models/autoregressive/tokenizer/quantizers.py +++ b/ar_tokenizer_quantizers.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn from einops import rearrange -from cosmos1.models.autoregressive.tokenizer.utils import default, pack_one, round_ste, unpack_one +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_utils import default, pack_one, round_ste, unpack_one class FSQuantizer(nn.Module): diff --git a/cosmos1/models/autoregressive/tokenizer/text_tokenizer.py b/ar_tokenizer_text_tokenizer.py similarity index 99% rename from cosmos1/models/autoregressive/tokenizer/text_tokenizer.py rename to ar_tokenizer_text_tokenizer.py index 797457192fda248a1dfcfbe5a08298e9a48036df..f6eae73cd1504f3b2eebbbc2e071d31c4d81dccf 100644 --- a/cosmos1/models/autoregressive/tokenizer/text_tokenizer.py +++ b/ar_tokenizer_text_tokenizer.py @@ -19,7 +19,7 @@ import numpy as np import torch from transformers import AutoTokenizer -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log def get_tokenizer_path(model_family: str, is_instruct_model: bool = False): diff --git 
a/cosmos1/models/autoregressive/tokenizer/tokenizer.py b/ar_tokenizer_tokenizer.py similarity index 98% rename from cosmos1/models/autoregressive/tokenizer/tokenizer.py rename to ar_tokenizer_tokenizer.py index 6bda2565b9418eb0a63dc891dd6f8412434359c6..7d241b3deb761cea8a97d7d97c24e7b17ded34c3 100644 --- a/cosmos1/models/autoregressive/tokenizer/tokenizer.py +++ b/ar_tokenizer_tokenizer.py @@ -19,8 +19,8 @@ from typing import Optional import torch from einops import rearrange -from cosmos1.models.autoregressive.configs.base.tokenizer import TokenizerConfig -from cosmos1.utils.lazy_config import instantiate as lazy_instantiate +from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig +from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate def update_vocab_size( diff --git a/cosmos1/models/autoregressive/tokenizer/utils.py b/ar_tokenizer_utils.py similarity index 100% rename from cosmos1/models/autoregressive/tokenizer/utils.py rename to ar_tokenizer_utils.py diff --git a/cosmos1/models/autoregressive/utils/checkpoint.py b/ar_utils_checkpoint.py similarity index 100% rename from cosmos1/models/autoregressive/utils/checkpoint.py rename to ar_utils_checkpoint.py diff --git a/cosmos1/models/autoregressive/utils/inference.py b/ar_utils_inference.py similarity index 98% rename from cosmos1/models/autoregressive/utils/inference.py rename to ar_utils_inference.py index 150d41efce6a6668d3cf9f1a60d5bcf886d12931..4e637121ab632753f80505dcf1a6ac960b4879e6 100644 --- a/cosmos1/models/autoregressive/utils/inference.py +++ b/ar_utils_inference.py @@ -25,8 +25,8 @@ import torch import torchvision from PIL import Image -from cosmos1.models.autoregressive.configs.inference.inference_config import SamplingConfig -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import SamplingConfig +from AutoregressiveVideo2WorldGeneration import log _IMAGE_EXTENSIONS = 
[".png", ".jpg", ".jpeg", "webp"] _VIDEO_EXTENSIONS = [".mp4"] diff --git a/cosmos1/models/autoregressive/utils/misc.py b/ar_utils_misc.py similarity index 100% rename from cosmos1/models/autoregressive/utils/misc.py rename to ar_utils_misc.py diff --git a/cosmos1/models/autoregressive/utils/sampling.py b/ar_utils_sampling.py similarity index 98% rename from cosmos1/models/autoregressive/utils/sampling.py rename to ar_utils_sampling.py index 91ba0e7abef2ffee4e57f7cca2b8ddbca25c27e3..b83890022922a9ad051b78a633ac605f7cc903ea 100644 --- a/cosmos1/models/autoregressive/utils/sampling.py +++ b/ar_utils_sampling.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple import torch -from cosmos1.models.autoregressive.networks.transformer import Transformer +from AutoregressiveVideo2WorldGeneration.ar_network_transformer import Transformer def sample_top_p(logits, temperature, top_p, return_probs: bool = False): diff --git a/cosmos1/models/autoregressive/inference/base.py b/base.py similarity index 93% rename from cosmos1/models/autoregressive/inference/base.py rename to base.py index 5d756f8c06a0d6c7355774a41dd250079fbe085a..3f80351285901d4b43029ecd13fe33190283d9cf 100644 --- a/cosmos1/models/autoregressive/inference/base.py +++ b/base.py @@ -19,9 +19,9 @@ import os import imageio import torch -from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARBaseGenerationPipeline -from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARBaseGenerationPipeline +from AutoregressiveVideo2WorldGeneration.ar_utils_inference import add_common_arguments, load_vision_input, validate_args +from AutoregressiveVideo2WorldGeneration import log def parse_args(): diff --git a/cosmos1/models/common/base_world_generation_pipeline.py b/base_world_generation_pipeline.py similarity index 98% rename from 
cosmos1/models/common/base_world_generation_pipeline.py rename to base_world_generation_pipeline.py index 7475e8d0272b3098d905d7694d758a5a79c8bad5..6e7337df0341bf0ee65f3a051470a752f3c091b8 100644 --- a/cosmos1/models/common/base_world_generation_pipeline.py +++ b/base_world_generation_pipeline.py @@ -21,8 +21,8 @@ from typing import Any import numpy as np import torch -from cosmos1.models.common.t5_text_encoder import CosmosT5TextEncoder -from cosmos1.models.guardrail.common import presets as guardrail_presets +from AutoregressiveVideo2WorldGeneration.t5_text_encoder import CosmosT5TextEncoder +from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets class BaseWorldGenerationPipeline(ABC): diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..e143a7077e0433746e506ba00f1b435c87e67b66 --- /dev/null +++ b/config.json @@ -0,0 +1,10 @@ +{ + "architectures": [ + "ARVideo2World" + ], + "auto_map": { + "AutoConfig": "video2world_hf.ARVideo2WorldConfig", + "AutoModel": "video2world_hf.ARVideo2World" + }, + "model_type": "AutoModel" +} \ No newline at end of file diff --git a/cosmos1/utils/config.py b/config.py similarity index 97% rename from cosmos1/utils/config.py rename to config.py index 7b4cb6bb03330c3701a850bbd50523331980d0f5..705ca23385d9ecb4f20b373fe9e361d29bddc3ec 100644 --- a/cosmos1/utils/config.py +++ b/config.py @@ -19,8 +19,8 @@ from typing import Any, TypeVar import attrs -from cosmos1.utils.lazy_config import LazyDict -from cosmos1.utils.misc import Color +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict +from AutoregressiveVideo2WorldGeneration.misc import Color T = TypeVar("T") diff --git a/cosmos1/utils/config_helper.py b/config_helper.py similarity index 98% rename from cosmos1/utils/config_helper.py rename to config_helper.py index d6f4e169663c2bfe4eb09ce2c8571c88f2127450..3848d02a218ad8b9de43805e09fe1ef5f68367bc 100644 --- 
a/cosmos1/utils/config_helper.py +++ b/config_helper.py @@ -27,8 +27,8 @@ from hydra import compose, initialize from hydra.core.config_store import ConfigStore from omegaconf import DictConfig, OmegaConf -from cosmos1.utils import log -from cosmos1.utils.config import Config +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.config import Config def is_attrs_or_dataclass(obj) -> bool: diff --git a/cosmos1/scripts/convert_pixtral_ckpt.py b/convert_pixtral_ckpt.py similarity index 100% rename from cosmos1/scripts/convert_pixtral_ckpt.py rename to convert_pixtral_ckpt.py diff --git a/cosmos1/models/autoregressive/nemo/cosmos.py b/cosmos1/models/autoregressive/nemo/cosmos.py index b4f34ce8929d2aedc3f345420d91d22081a8dddd..72e4d02dc37f7e03bbc35e6d476e748834101294 100644 --- a/cosmos1/models/autoregressive/nemo/cosmos.py +++ b/cosmos1/models/autoregressive/nemo/cosmos.py @@ -29,7 +29,7 @@ from nemo.lightning import OptimizerModule, io from nemo.lightning.base import teardown from torch import Tensor, nn -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log class RotaryEmbedding3D(RotaryEmbedding): diff --git a/cosmos1/models/autoregressive/nemo/inference/general.py b/cosmos1/models/autoregressive/nemo/inference/general.py index dbc34e431da44065ecc4a484a0ed3e1acfbb464c..420e1bef9d97f1f8b63393d8b4b7c5ba1f0ec589 100644 --- a/cosmos1/models/autoregressive/nemo/inference/general.py +++ b/cosmos1/models/autoregressive/nemo/inference/general.py @@ -34,10 +34,10 @@ from nemo.lightning import io from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model -from cosmos1.models.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer -from cosmos1.models.autoregressive.utils.inference import load_vision_input -from cosmos1.models.guardrail.common import presets as guardrail_presets -from cosmos1.utils import 
log +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer +from AutoregressiveVideo2WorldGeneration.ar_utils_inference import load_vision_input +from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets +from AutoregressiveVideo2WorldGeneration import log torch._C._jit_set_texpr_fuser_enabled(False) diff --git a/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py index 28d531f0dc2a0e03abc13c82e54037fd9fc36a51..63c37fb57b9911dea642d5a69ce9aa67d354a2ec 100644 --- a/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py +++ b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py @@ -23,8 +23,8 @@ from huggingface_hub import snapshot_download from nemo.collections.nlp.data.language_modeling.megatron import indexed_dataset from cosmos1.models.autoregressive.nemo.utils import read_input_videos -from cosmos1.models.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer +from AutoregressiveVideo2WorldGeneration import log TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16] DATA_RESOLUTION_SUPPORTED = [640, 1024] diff --git a/cosmos1/models/autoregressive/nemo/utils.py b/cosmos1/models/autoregressive/nemo/utils.py index 755986e9dff6d568f4197610d963d07b704ee7e1..14b679ef8f4140fc27813769bbd7fac8972d4b2b 100644 --- a/cosmos1/models/autoregressive/nemo/utils.py +++ b/cosmos1/models/autoregressive/nemo/utils.py @@ -23,16 +23,16 @@ import torch import torchvision from huggingface_hub import snapshot_download -from cosmos1.models.autoregressive.configs.inference.inference_config import DiffusionDecoderSamplingConfig -from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens -from 
cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel -from cosmos1.models.diffusion.inference.inference_utils import ( +from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import DiffusionDecoderSamplingConfig +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel +from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import ( load_network_model, load_tokenizer_model, skip_init_linear, ) -from cosmos1.utils import log -from cosmos1.utils.config_helper import get_config_module, override +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.config_helper import get_config_module, override TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16] DATA_RESOLUTION_SUPPORTED = [640, 1024] diff --git a/cosmos1/models/diffusion/config/config.py b/cosmos1/models/diffusion/config/config.py index 514e47022e728e885000505f728dd5a0056d3e0d..eb38e3850f95535e6b8c39a31b6715012d5df684 100644 --- a/cosmos1/models/diffusion/config/config.py +++ b/cosmos1/models/diffusion/config/config.py @@ -17,10 +17,10 @@ from typing import Any, List import attrs -from cosmos1.models.diffusion.config.base.model import DefaultModelConfig -from cosmos1.models.diffusion.config.registry import register_configs -from cosmos1.utils import config -from cosmos1.utils.config_helper import import_all_modules_from_package +from AutoregressiveVideo2WorldGeneration.df_config_base_model import DefaultModelConfig +from AutoregressiveVideo2WorldGeneration.df_config_registry import register_configs +from AutoregressiveVideo2WorldGeneration import config +from AutoregressiveVideo2WorldGeneration.config_helper import import_all_modules_from_package @attrs.define(slots=False) diff --git 
a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py index 5b6fe2a7ffd63536a6cce293cc38b471ae21eb18..17565632abbb356766f552012357acbc5aba5e42 100644 --- a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py +++ b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py @@ -15,7 +15,7 @@ from hydra.core.config_store import ConfigStore -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict Cosmos_1_0_Diffusion_Text2World_7B: LazyDict = LazyDict( dict( diff --git a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py index 39f90475c23dd6983f3456bb614f779a63b1626b..404c23aa2017ec433d5ead2b2ffaf0333bc7340d 100644 --- a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py +++ b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py @@ -16,8 +16,8 @@ from hydra.core.config_store import ConfigStore from cosmos1.models.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict Cosmos_1_0_Diffusion_Video2World_7B: LazyDict = LazyDict( dict( diff --git a/cosmos1/models/diffusion/inference/text2world.py b/cosmos1/models/diffusion/inference/text2world.py index 4faaa75e4fd15c6b913fbe37b1bd0d9e6db8fb87..503c8177336d589db2f7098d9ff3afefb7e076bf 100644 --- a/cosmos1/models/diffusion/inference/text2world.py +++ b/cosmos1/models/diffusion/inference/text2world.py @@ -16,12 +16,13 @@ import argparse import os +from AutoregressiveVideo2WorldGeneration import misc 
import torch -from cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args +from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import add_common_arguments, validate_args from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline -from cosmos1.utils import log, misc -from cosmos1.utils.io import read_prompts_from_file, save_video +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file, save_video torch.enable_grad(False) diff --git a/cosmos1/models/diffusion/inference/video2world.py b/cosmos1/models/diffusion/inference/video2world.py index cd495f7f75a4d1f20d9550e2015cd072ae310734..2afe6559dfc0e1467880a12be7e74371ca36c83f 100644 --- a/cosmos1/models/diffusion/inference/video2world.py +++ b/cosmos1/models/diffusion/inference/video2world.py @@ -16,12 +16,13 @@ import argparse import os +from AutoregressiveVideo2WorldGeneration import misc import torch -from cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, check_input_frames, validate_args +from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import add_common_arguments, check_input_frames, validate_args from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline -from cosmos1.utils import log, misc -from cosmos1.utils.io import read_prompts_from_file, save_video +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file, save_video torch.enable_grad(False) diff --git a/cosmos1/models/diffusion/inference/world_generation_pipeline.py b/cosmos1/models/diffusion/inference/world_generation_pipeline.py index 47a8a522f3f2646813a9c9e50590aedd7759d002..6ca5e63aa3883403dfe79999bd53653c6771e7ed 100644 --- a/cosmos1/models/diffusion/inference/world_generation_pipeline.py +++ 
b/cosmos1/models/diffusion/inference/world_generation_pipeline.py @@ -20,8 +20,8 @@ from typing import Any, Optional import numpy as np import torch -from cosmos1.models.common.base_world_generation_pipeline import BaseWorldGenerationPipeline -from cosmos1.models.diffusion.inference.inference_utils import ( +from AutoregressiveVideo2WorldGeneration.base_world_generation_pipeline import BaseWorldGenerationPipeline +from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import ( generate_world_from_text, generate_world_from_video, get_condition_latent, @@ -30,8 +30,8 @@ from cosmos1.models.diffusion.inference.inference_utils import ( load_network_model, load_tokenizer_model, ) -from cosmos1.models.diffusion.model.model_t2w import DiffusionT2WModel -from cosmos1.models.diffusion.model.model_v2w import DiffusionV2WModel +from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel +from AutoregressiveVideo2WorldGeneration.df_model_model_v2w import DiffusionV2WModel from cosmos1.models.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import ( create_prompt_upsampler, run_chat_completion, @@ -43,7 +43,7 @@ from cosmos1.models.diffusion.prompt_upsampler.video2world_prompt_upsampler_infe from cosmos1.models.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import ( run_chat_completion as run_chat_completion_vlm, ) -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log MODEL_NAME_DICT = { "Cosmos-1.0-Diffusion-7B-Text2World": "Cosmos_1_0_Diffusion_Text2World_7B", diff --git a/cosmos1/models/diffusion/nemo/inference/general.py b/cosmos1/models/diffusion/nemo/inference/general.py index 397c60febed717ded44f54497edd1ebecbb7382c..4328173fe6929ee72c633257360226b270a2f23e 100644 --- a/cosmos1/models/diffusion/nemo/inference/general.py +++ b/cosmos1/models/diffusion/nemo/inference/general.py @@ -37,7 +37,7 @@ from nemo.collections.diffusion.sampler.cosmos.cosmos_diffusion_pipeline 
import from transformers import T5EncoderModel, T5TokenizerFast from cosmos1.models.diffusion.nemo.inference.inference_utils import process_prompt, save_video -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log EXAMPLE_PROMPT = ( "The teal robot is cooking food in a kitchen. Steam rises from a simmering pot " diff --git a/cosmos1/models/diffusion/nemo/inference/inference_utils.py b/cosmos1/models/diffusion/nemo/inference/inference_utils.py index de95a04d418aace14944ddc636247ab9cbb47848..f10c8a67765f33cf86b8af32cf571931a7c16242 100644 --- a/cosmos1/models/diffusion/nemo/inference/inference_utils.py +++ b/cosmos1/models/diffusion/nemo/inference/inference_utils.py @@ -19,18 +19,18 @@ import imageio import numpy as np import torch -from cosmos1.models.autoregressive.model import AutoRegressiveModel +from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel from cosmos1.models.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import ( create_prompt_upsampler, run_chat_completion, ) -from cosmos1.models.guardrail.common.presets import ( +from AutoregressiveVideo2WorldGeneration.guardrail_common_presets import ( create_text_guardrail_runner, create_video_guardrail_runner, run_text_guardrail, run_video_guardrail, ) -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log def get_upsampled_prompt( diff --git a/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py index 7b4e85a48b1cbbeb0835bfebaa5acc4fc7579841..1dad9f347f4e4166a73d217ed6406bff43ad551e 100644 --- a/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py +++ b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py @@ -27,7 +27,7 @@ from nemo.collections.diffusion.models.model import DiT7BConfig from tqdm import tqdm from transformers import T5EncoderModel, T5TokenizerFast -from cosmos1.utils import log +from 
AutoregressiveVideo2WorldGeneration import log def get_parser(): diff --git a/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py b/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py index f27bcabe29cd87fc17b0f57f9e56b2a0f1bb3959..0ffaa9f2a1b591196b82cae9659a64ac3d746b62 100644 --- a/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py +++ b/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py @@ -19,10 +19,10 @@ import torch from einops import rearrange from torch import nn -from cosmos1.models.diffusion.conditioner import DataType -from cosmos1.models.diffusion.module.blocks import TimestepEmbedding, Timesteps -from cosmos1.models.diffusion.networks.general_dit import GeneralDIT -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.df_conditioner import DataType +from AutoregressiveVideo2WorldGeneration.df_module_blocks import TimestepEmbedding, Timesteps +from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT +from AutoregressiveVideo2WorldGeneration import log class VideoExtendGeneralDIT(GeneralDIT): diff --git a/cosmos1/models/diffusion/prompt_upsampler/inference.py b/cosmos1/models/diffusion/prompt_upsampler/inference.py index b022dd14933e4e0d5abe5f47f61cd29675101d49..d33fb465d4f7a633ddafb05ab1e50bf46b59f1d7 100644 --- a/cosmos1/models/diffusion/prompt_upsampler/inference.py +++ b/cosmos1/models/diffusion/prompt_upsampler/inference.py @@ -17,9 +17,9 @@ from typing import List, Optional, TypedDict import torch -from cosmos1.models.autoregressive.model import AutoRegressiveModel -from cosmos1.models.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer -from cosmos1.models.autoregressive.tokenizer.text_tokenizer import TextTokenizer +from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel +from AutoregressiveVideo2WorldGeneration.ar_tokenizer_image_text_tokenizer import ImageTextTokenizer +from 
AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer import TextTokenizer class ChatPrediction(TypedDict, total=False): diff --git a/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py b/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py index 44a83e3a364b4278fabce4102b3e331048bba455..7073af4819595aa0080d094ab3e8b2c161b6a01e 100644 --- a/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py +++ b/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py @@ -23,11 +23,11 @@ import argparse import os import re -from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config -from cosmos1.models.autoregressive.model import AutoRegressiveModel +from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_text_model_config +from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion -from cosmos1.models.guardrail.common import presets as guardrail_presets -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets +from AutoregressiveVideo2WorldGeneration import log def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel: diff --git a/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py b/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py index beff9f2ca574afbec4d5157ec22d0c62ceb4a64f..2c0ff29486fcb7c18652f4c19f9ff76ba8dc5dbc 100644 --- a/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py +++ b/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py @@ -26,12 +26,12 @@ from math import ceil from PIL import Image -from cosmos1.models.autoregressive.configs.base.model_config import 
create_vision_language_model_config -from cosmos1.models.autoregressive.model import AutoRegressiveModel +from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_vision_language_model_config +from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion -from cosmos1.models.guardrail.common import presets as guardrail_presets -from cosmos1.utils import log -from cosmos1.utils.io import load_from_fileobj +from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.io import load_from_fileobj def create_vlm_prompt_upsampler( diff --git a/cosmos1/utils/misc.py b/cosmos1/utils/misc.py deleted file mode 100644 index 8b0c6d66669dc220db0c24f3b14d33acc0a0c512..0000000000000000000000000000000000000000 --- a/cosmos1/utils/misc.py +++ /dev/null @@ -1,207 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import collections -import collections.abc -import functools -import json -import random -import time -from contextlib import ContextDecorator -from typing import Any, Callable, TypeVar - -import numpy as np -import termcolor -import torch - -from cosmos1.utils import distributed, log - - -def to( - data: Any, - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - memory_format: torch.memory_format = torch.preserve_format, -) -> Any: - """Recursively cast data into the specified device, dtype, and/or memory_format. - - The input data can be a tensor, a list of tensors, a dict of tensors. - See the documentation for torch.Tensor.to() for details. - - Args: - data (Any): Input data. - device (str | torch.device): GPU device (default: None). - dtype (torch.dtype): data type (default: None). - memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). - - Returns: - data (Any): Data cast to the specified device, dtype, and/or memory_format. - """ - assert ( - device is not None or dtype is not None or memory_format is not None - ), "at least one of device, dtype, memory_format should be specified" - if isinstance(data, torch.Tensor): - is_cpu = (isinstance(device, str) and device == "cpu") or ( - isinstance(device, torch.device) and device.type == "cpu" - ) - data = data.to( - device=device, - dtype=dtype, - memory_format=memory_format, - non_blocking=(not is_cpu), - ) - return data - elif isinstance(data, collections.abc.Mapping): - return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) - else: - return data - - -def serialize(data: Any) -> Any: - """Serialize data by hierarchically traversing through iterables. 
- - Args: - data (Any): Input data. - - Returns: - data (Any): Serialized data. - """ - if isinstance(data, collections.abc.Mapping): - return type(data)({key: serialize(data[key]) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([serialize(elem) for elem in data]) - else: - try: - json.dumps(data) - except TypeError: - data = str(data) - return data - - -def set_random_seed(seed: int, by_rank: bool = False) -> None: - """Set random seed. This includes random, numpy, Pytorch. - - Args: - seed (int): Random seed. - by_rank (bool): if true, each GPU will use a different random seed. - """ - if by_rank: - seed += distributed.get_rank() - log.info(f"Using random seed {seed}.") - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) # sets seed on the current CPU & all GPUs - - -def arch_invariant_rand( - shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None -): - """Produce a GPU-architecture-invariant randomized Torch tensor. - - Args: - shape (list or tuple of ints): Output tensor shape. - dtype (torch.dtype): Output tensor type. - device (torch.device): Device holding the output. - seed (int): Optional randomization seed. - - Returns: - tensor (torch.tensor): Randomly-generated tensor. - """ - # Create a random number generator, optionally seeded - rng = np.random.RandomState(seed) - - # # Generate random numbers using the generator - random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution - - # Convert to torch tensor and return - return torch.from_numpy(random_array).to(dtype=dtype, device=device) - - -T = TypeVar("T", bound=Callable[..., Any]) - - -class timer(ContextDecorator): # noqa: N801 - """Simple timer for timing the execution of code. - - It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. 
- - Example: - def func_a(): - time.sleep(1) - with timer("func_a"): - func_a() - - @timer("func_b) - def func_b(): - time.sleep(1) - func_b() - """ - - def __init__(self, context: str, debug: bool = False): - self.context = context - self.debug = debug - - def __enter__(self) -> None: - self.tic = time.time() - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - time_spent = time.time() - self.tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - - def __call__(self, func: T) -> T: - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - tic = time.time() - result = func(*args, **kwargs) - time_spent = time.time() - tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - return result - - return wrapper # type: ignore - - -class Color: - """A convenience class to colorize strings in the console. 
- - Example: - import - print("This is {Color.red('important')}.") - """ - - @staticmethod - def red(x: str) -> str: - return termcolor.colored(str(x), color="red") - - @staticmethod - def green(x: str) -> str: - return termcolor.colored(str(x), color="green") - - @staticmethod - def cyan(x: str) -> str: - return termcolor.colored(str(x), color="cyan") - - @staticmethod - def yellow(x: str) -> str: - return termcolor.colored(str(x), color="yellow") diff --git a/cosmos1/utils/device.py b/device.py similarity index 100% rename from cosmos1/utils/device.py rename to device.py diff --git a/cosmos1/models/diffusion/conditioner.py b/df_conditioner.py similarity index 98% rename from cosmos1/models/diffusion/conditioner.py rename to df_conditioner.py index 15c0d9b636f6fb4b540ee1d23846d26716433c98..0f101b3030da778cc8ea8ed46043ac6acd692e3f 100644 --- a/cosmos1/models/diffusion/conditioner.py +++ b/df_conditioner.py @@ -23,9 +23,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos1.utils import log -from cosmos1.utils.lazy_config import instantiate +from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate class BaseConditionEntry(nn.Module): diff --git a/cosmos1/models/diffusion/config/base/conditioner.py b/df_config_base_conditioner.py similarity index 94% rename from cosmos1/models/diffusion/config/base/conditioner.py rename to df_config_base_conditioner.py index 3333af87b216199c38dc0c45aedec389beec536e..2a5845d0f256991b7b486d6bfdd19f7632c39043 100644 --- a/cosmos1/models/diffusion/config/base/conditioner.py +++ b/df_config_base_conditioner.py @@ -18,9 +18,9 @@ from typing import Dict, List, Optional import attrs import torch -from cosmos1.models.diffusion.conditioner import 
BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict @attrs.define(slots=False) diff --git a/cosmos1/models/diffusion/config/base/model.py b/df_config_base_model.py similarity index 95% rename from cosmos1/models/diffusion/config/base/model.py rename to df_config_base_model.py index 97b94e554f612f2f45a34443944d3cdefc1b7c82..42d9e551d4460aa5e694052f3df7ad9c66ca9b45 100644 --- a/cosmos1/models/diffusion/config/base/model.py +++ b/df_config_base_model.py @@ -17,7 +17,7 @@ from typing import List import attrs -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict @attrs.define(slots=False) diff --git a/cosmos1/models/diffusion/config/base/net.py b/df_config_base_net.py similarity index 85% rename from cosmos1/models/diffusion/config/base/net.py rename to df_config_base_net.py index 931a8ef2204209d5812b0b16a156f1a3decdb94e..5b843163eee1e4cd7d6fcc347b72148a3f14253a 100644 --- a/cosmos1/models/diffusion/config/base/net.py +++ b/df_config_base_net.py @@ -15,9 +15,9 @@ import copy -from cosmos1.models.diffusion.networks.general_dit import GeneralDIT -from cosmos1.utils.lazy_config import LazyCall as L -from cosmos1.utils.lazy_config import LazyDict +from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict FADITV2Config: LazyDict = L(GeneralDIT)( max_img_h=240, diff --git a/cosmos1/models/diffusion/config/base/tokenizer.py 
b/df_config_base_tokenizer.py similarity index 89% rename from cosmos1/models/diffusion/config/base/tokenizer.py rename to df_config_base_tokenizer.py index f03a96abd0cf23e6285f5ffbfc1e93e65718700f..7c6363ead8e27335b2c6d76be8341e1d9b326c06 100644 --- a/cosmos1/models/diffusion/config/base/tokenizer.py +++ b/df_config_base_tokenizer.py @@ -15,8 +15,8 @@ import omegaconf -from cosmos1.models.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer -from cosmos1.utils.lazy_config import LazyCall as L +from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L TOKENIZER_OPTIONS = {} diff --git a/cosmos1/models/diffusion/config/registry.py b/df_config_registry.py similarity index 87% rename from cosmos1/models/diffusion/config/registry.py rename to df_config_registry.py index 5d92360f8370b4e590e4acc02ac35ac5be35a693..aeba6024f7a858bc9d70756a7839cfb3d8add54a 100644 --- a/cosmos1/models/diffusion/config/registry.py +++ b/df_config_registry.py @@ -15,13 +15,13 @@ from hydra.core.config_store import ConfigStore -from cosmos1.models.diffusion.config.base.conditioner import ( +from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import ( BaseVideoConditionerConfig, VideoConditionerFpsSizePaddingConfig, VideoExtendConditionerConfig, ) -from cosmos1.models.diffusion.config.base.net import FADITV2_14B_Config, FADITV2Config -from cosmos1.models.diffusion.config.base.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 +from AutoregressiveVideo2WorldGeneration.df_config_base_net import FADITV2_14B_Config, FADITV2Config +from AutoregressiveVideo2WorldGeneration.df_config_base_tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 def register_net(cs): diff --git a/cosmos1/models/diffusion/diffusion/functional/batch_ops.py b/df_df_functional_batch_ops.py similarity 
index 100% rename from cosmos1/models/diffusion/diffusion/functional/batch_ops.py rename to df_df_functional_batch_ops.py diff --git a/cosmos1/models/diffusion/diffusion/functional/multi_step.py b/df_df_functional_multi_step.py similarity index 94% rename from cosmos1/models/diffusion/diffusion/functional/multi_step.py rename to df_df_functional_multi_step.py index b651c600b6ebc1afed97bd92d84e9619b393e3f0..954075c5aa4028ca5ea21c40208e788cee136ad0 100644 --- a/cosmos1/models/diffusion/diffusion/functional/multi_step.py +++ b/df_df_functional_multi_step.py @@ -21,7 +21,7 @@ from typing import Callable, List, Tuple import torch -from cosmos1.models.diffusion.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step +from AutoregressiveVideo2WorldGeneration.df_df_functional_runge_kutta import reg_x0_euler_step, res_x0_rk2_step def order2_fn( diff --git a/cosmos1/models/diffusion/diffusion/functional/runge_kutta.py b/df_df_functional_runge_kutta.py similarity index 99% rename from cosmos1/models/diffusion/diffusion/functional/runge_kutta.py rename to df_df_functional_runge_kutta.py index d07aafe41fdafa9e323079ac57a96994b365fe88..ef9be1adb343c40d3dc9269bb2bf346d70a7a02b 100644 --- a/cosmos1/models/diffusion/diffusion/functional/runge_kutta.py +++ b/df_df_functional_runge_kutta.py @@ -17,7 +17,7 @@ from typing import Callable, Tuple import torch -from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul +from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul def phi1(t: torch.Tensor) -> torch.Tensor: diff --git a/cosmos1/models/diffusion/diffusion/modules/denoiser_scaling.py b/df_df_module_denoiser_scaling.py similarity index 100% rename from cosmos1/models/diffusion/diffusion/modules/denoiser_scaling.py rename to df_df_module_denoiser_scaling.py diff --git a/cosmos1/models/diffusion/diffusion/modules/res_sampler.py b/df_df_module_res_sampler.py similarity index 96% rename from 
cosmos1/models/diffusion/diffusion/modules/res_sampler.py rename to df_df_module_res_sampler.py index 77fe1c5aefcef5f7683cc8d6bf585337a0cca41a..afdd85ec8cc975b24f0b2309298ca2192bc46a03 100644 --- a/cosmos1/models/diffusion/diffusion/modules/res_sampler.py +++ b/df_df_module_res_sampler.py @@ -28,9 +28,9 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union import attrs import torch -from cosmos1.models.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported -from cosmos1.models.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported -from cosmos1.utils.config import make_freezable +from AutoregressiveVideo2WorldGeneration.df_df_functional_multi_step import get_multi_step_fn, is_multi_step_fn_supported +from AutoregressiveVideo2WorldGeneration.df_df_functional_runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported +from AutoregressiveVideo2WorldGeneration.config import make_freezable COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] diff --git a/cosmos1/models/diffusion/diffusion/types.py b/df_df_types.py similarity index 100% rename from cosmos1/models/diffusion/diffusion/types.py rename to df_df_types.py diff --git a/cosmos1/models/diffusion/inference/inference_utils.py b/df_inference_inference_utils.py similarity index 98% rename from cosmos1/models/diffusion/inference/inference_utils.py rename to df_inference_inference_utils.py index f0edb4a949c48c6259539887d80e43158e9e0e68..2f5518c1512707ae3dffcb7b238025b859e10685 100644 --- a/cosmos1/models/diffusion/inference/inference_utils.py +++ b/df_inference_inference_utils.py @@ -18,17 +18,18 @@ import importlib from contextlib import contextmanager from typing import List, NamedTuple, Optional, Tuple +from AutoregressiveVideo2WorldGeneration import misc import einops import imageio import numpy as np import torch import torchvision.transforms.functional as transforms_F -from 
cosmos1.models.diffusion.model.model_t2w import DiffusionT2WModel -from cosmos1.models.diffusion.model.model_v2w import DiffusionV2WModel -from cosmos1.utils import log, misc -from cosmos1.utils.config_helper import get_config_module, override -from cosmos1.utils.io import load_from_fileobj +from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel +from AutoregressiveVideo2WorldGeneration.df_model_model_v2w import DiffusionV2WModel +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.config_helper import get_config_module, override +from AutoregressiveVideo2WorldGeneration.io import load_from_fileobj TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) if TORCH_VERSION >= (1, 11): diff --git a/cosmos1/models/diffusion/model/model_t2w.py b/df_model_model_t2w.py similarity index 93% rename from cosmos1/models/diffusion/model/model_t2w.py rename to df_model_model_t2w.py index c21c1489eb738471cbb17876ba535206e9afc2aa..458c1e35c207948ecf715fd93d0972630c79ed80 100644 --- a/cosmos1/models/diffusion/model/model_t2w.py +++ b/df_model_model_t2w.py @@ -15,18 +15,19 @@ from typing import Callable, Dict, Optional, Tuple +from AutoregressiveVideo2WorldGeneration import misc import torch from torch import Tensor -from cosmos1.models.diffusion.conditioner import CosmosCondition -from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos1.models.diffusion.diffusion.modules.denoiser_scaling import EDMScaling -from cosmos1.models.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler -from cosmos1.models.diffusion.diffusion.types import DenoisePrediction -from cosmos1.models.diffusion.module.blocks import FourierFeatures -from cosmos1.models.diffusion.module.pretrained_vae import BaseVAE -from cosmos1.utils import log, misc -from cosmos1.utils.lazy_config import instantiate as lazy_instantiate +from 
AutoregressiveVideo2WorldGeneration.df_conditioner import CosmosCondition +from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul +from AutoregressiveVideo2WorldGeneration.df_df_module_denoiser_scaling import EDMScaling +from AutoregressiveVideo2WorldGeneration.df_df_module_res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from AutoregressiveVideo2WorldGeneration.df_df_types import DenoisePrediction +from AutoregressiveVideo2WorldGeneration.df_module_blocks import FourierFeatures +from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import BaseVAE +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate class EDMSDE: diff --git a/cosmos1/models/diffusion/model/model_v2w.py b/df_model_model_v2w.py similarity index 97% rename from cosmos1/models/diffusion/model/model_v2w.py rename to df_model_model_v2w.py index c8998c0b941cf146bc9ac4bac512a66371974c30..6fdbf21dda8f909a9611ca1780b1b0c04fb99e1c 100644 --- a/cosmos1/models/diffusion/model/model_v2w.py +++ b/df_model_model_v2w.py @@ -16,14 +16,15 @@ from dataclasses import dataclass from typing import Callable, Dict, Optional, Tuple, Union +from AutoregressiveVideo2WorldGeneration import misc import torch from torch import Tensor -from cosmos1.models.diffusion.conditioner import VideoExtendCondition -from cosmos1.models.diffusion.config.base.conditioner import VideoCondBoolConfig -from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos1.models.diffusion.model.model_t2w import DiffusionT2WModel -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.df_conditioner import VideoExtendCondition +from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import VideoCondBoolConfig +from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul +from AutoregressiveVideo2WorldGeneration.df_model_model_t2w 
import DiffusionT2WModel +from AutoregressiveVideo2WorldGeneration import log @dataclass diff --git a/cosmos1/models/diffusion/module/attention.py b/df_module_attention.py similarity index 100% rename from cosmos1/models/diffusion/module/attention.py rename to df_module_attention.py diff --git a/cosmos1/models/diffusion/module/blocks.py b/df_module_blocks.py similarity index 99% rename from cosmos1/models/diffusion/module/blocks.py rename to df_module_blocks.py index aee7d6fe0f8cc91391e5f8fd71470288408d7003..4d14bf25ad4b22c81be7e1f365c0aa4079536583 100644 --- a/cosmos1/models/diffusion/module/blocks.py +++ b/df_module_blocks.py @@ -22,8 +22,8 @@ from einops import rearrange, repeat from einops.layers.torch import Rearrange from torch import nn -from cosmos1.models.diffusion.module.attention import Attention, GPT2FeedForward -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.df_module_attention import Attention, GPT2FeedForward +from AutoregressiveVideo2WorldGeneration import log def modulate(x, shift, scale): diff --git a/cosmos1/models/diffusion/module/position_embedding.py b/df_module_position_embedding.py similarity index 97% rename from cosmos1/models/diffusion/module/position_embedding.py rename to df_module_position_embedding.py index 9cb16bf49d9c62d10ff54770c291da87b5a9bd5c..cd63a1e3f3921928208cbce42633e41cf8429f9a 100644 --- a/cosmos1/models/diffusion/module/position_embedding.py +++ b/df_module_position_embedding.py @@ -19,8 +19,8 @@ import torch from einops import rearrange, repeat from torch import nn -from cosmos1.models.diffusion.module.attention import normalize -from cosmos1.models.diffusion.module.timm import trunc_normal_ +from AutoregressiveVideo2WorldGeneration.df_module_attention import normalize +from AutoregressiveVideo2WorldGeneration.df_module_timm import trunc_normal_ class VideoPositionEmb(nn.Module): diff --git a/cosmos1/models/diffusion/module/pretrained_vae.py b/df_module_pretrained_vae.py similarity index 100% 
rename from cosmos1/models/diffusion/module/pretrained_vae.py rename to df_module_pretrained_vae.py diff --git a/cosmos1/models/diffusion/module/timm.py b/df_module_timm.py similarity index 100% rename from cosmos1/models/diffusion/module/timm.py rename to df_module_timm.py diff --git a/cosmos1/models/diffusion/networks/general_dit.py b/df_network_general_dit.py similarity index 98% rename from cosmos1/models/diffusion/networks/general_dit.py rename to df_network_general_dit.py index 5b0bb143ac6c222a9a8b9c6690da7785e66a51f3..90d40237c58de17818c10a8f41cbe3d3df2d0b64 100644 --- a/cosmos1/models/diffusion/networks/general_dit.py +++ b/df_network_general_dit.py @@ -24,17 +24,17 @@ from einops import rearrange from torch import nn from torchvision import transforms -from cosmos1.models.diffusion.conditioner import DataType -from cosmos1.models.diffusion.module.attention import get_normalization -from cosmos1.models.diffusion.module.blocks import ( +from AutoregressiveVideo2WorldGeneration.df_conditioner import DataType +from AutoregressiveVideo2WorldGeneration.df_module_attention import get_normalization +from AutoregressiveVideo2WorldGeneration.df_module_blocks import ( FinalLayer, GeneralDITTransformerBlock, PatchEmbed, TimestepEmbedding, Timesteps, ) -from cosmos1.models.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.df_module_position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb +from AutoregressiveVideo2WorldGeneration import log class GeneralDIT(nn.Module): diff --git a/cosmos1/utils/distributed.py b/distributed.py similarity index 98% rename from cosmos1/utils/distributed.py rename to distributed.py index d0a88c0519245236af13c104754bcd83b517cae3..d3ac63b7401852597d528cf99167d5995b6f6339 100644 --- a/cosmos1/utils/distributed.py +++ b/distributed.py @@ -27,8 +27,8 @@ import pynvml import torch import torch.distributed as dist -from 
cosmos1.utils import log -from cosmos1.utils.device import Device +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.device import Device def init() -> int | None: diff --git a/cosmos1/scripts/download_autoregressive.py b/download_autoregressive.py similarity index 100% rename from cosmos1/scripts/download_autoregressive.py rename to download_autoregressive.py diff --git a/cosmos1/scripts/download_diffusion.py b/download_diffusion.py similarity index 97% rename from cosmos1/scripts/download_diffusion.py rename to download_diffusion.py index 2ebeaffc148874f325b414dabcea4e83a25c869e..cdd94176ea4500b600881c99d2e5b0cff83b5003 100644 --- a/cosmos1/scripts/download_diffusion.py +++ b/download_diffusion.py @@ -18,7 +18,7 @@ from pathlib import Path from huggingface_hub import snapshot_download -from cosmos1.scripts.convert_pixtral_ckpt import convert_pixtral_checkpoint +from AutoregressiveVideo2WorldGeneration.convert_pixtral_ckpt import convert_pixtral_checkpoint def parse_args(): diff --git a/cosmos1/models/guardrail/aegis/aegis.py b/guardrail_aegis.py similarity index 94% rename from cosmos1/models/guardrail/aegis/aegis.py rename to guardrail_aegis.py index 0ec0b462f9ba35a94b2158954888443b87db2096..eef53731deccf32c23d09943a2e9251b82d71f68 100644 --- a/cosmos1/models/guardrail/aegis/aegis.py +++ b/guardrail_aegis.py @@ -15,13 +15,14 @@ import argparse +from AutoregressiveVideo2WorldGeneration import misc import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer -from cosmos1.models.guardrail.aegis.categories import UNSAFE_CATEGORIES -from cosmos1.models.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.guardrail_aegis_categories import UNSAFE_CATEGORIES +from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner +from 
AutoregressiveVideo2WorldGeneration import log SAFE = misc.Color.green("SAFE") UNSAFE = misc.Color.red("UNSAFE") diff --git a/cosmos1/models/guardrail/aegis/categories.py b/guardrail_aegis_categories.py similarity index 100% rename from cosmos1/models/guardrail/aegis/categories.py rename to guardrail_aegis_categories.py diff --git a/cosmos1/models/guardrail/blocklist/blocklist.py b/guardrail_blocklist.py similarity index 96% rename from cosmos1/models/guardrail/blocklist/blocklist.py rename to guardrail_blocklist.py index fa3d30e0e74f162d8ea15e23aaded470f3c4ee90..b73595216ab9d7ce9615a994d38c8f59bb95ded6 100644 --- a/cosmos1/models/guardrail/blocklist/blocklist.py +++ b/guardrail_blocklist.py @@ -19,12 +19,13 @@ import re import string from difflib import SequenceMatcher +from AutoregressiveVideo2WorldGeneration import misc import nltk from better_profanity import profanity -from cosmos1.models.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii -from cosmos1.models.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii +from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner +from AutoregressiveVideo2WorldGeneration import log DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist" CENSOR = misc.Color.red("*") diff --git a/cosmos1/models/guardrail/blocklist/utils.py b/guardrail_blocklist_utils.py similarity index 96% rename from cosmos1/models/guardrail/blocklist/utils.py rename to guardrail_blocklist_utils.py index 0c721914e1372f39ab81ad213a3b65fe30adee5b..481fcbad2d64f49aeb33f88eea52b638573ef040 100644 --- a/cosmos1/models/guardrail/blocklist/utils.py +++ b/guardrail_blocklist_utils.py @@ -16,7 +16,7 @@ import os import re -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log def 
read_keyword_list_from_dir(folder_path: str) -> list[str]: diff --git a/cosmos1/models/guardrail/common/core.py b/guardrail_common_core.py similarity index 98% rename from cosmos1/models/guardrail/common/core.py rename to guardrail_common_core.py index 15c7a36f3130c33d064f206b0656cf86cd91f403..b094f70d4c04058c0c61dd45dcfa0292c9b6c23f 100644 --- a/cosmos1/models/guardrail/common/core.py +++ b/guardrail_common_core.py @@ -17,7 +17,7 @@ from typing import Any, Tuple import numpy as np -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log class ContentSafetyGuardrail: diff --git a/cosmos1/models/guardrail/common/io_utils.py b/guardrail_common_io_utils.py similarity index 98% rename from cosmos1/models/guardrail/common/io_utils.py rename to guardrail_common_io_utils.py index b027850d84e8d618e0a5bb3dc6d7cc4bb5acef66..1a655f5d05e2af9b8db59eab9850d26fb08e1c52 100644 --- a/cosmos1/models/guardrail/common/io_utils.py +++ b/guardrail_common_io_utils.py @@ -19,7 +19,7 @@ from dataclasses import dataclass import imageio import numpy as np -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log @dataclass diff --git a/cosmos1/models/guardrail/common/presets.py b/guardrail_common_presets.py similarity index 84% rename from cosmos1/models/guardrail/common/presets.py rename to guardrail_common_presets.py index 8b28c554d23066971c3ae07fc5d756e7018602c2..2d00856dd5d6da7596f6b6f910e688c1ded47a3d 100644 --- a/cosmos1/models/guardrail/common/presets.py +++ b/guardrail_common_presets.py @@ -17,12 +17,12 @@ import os import numpy as np -from cosmos1.models.guardrail.aegis.aegis import Aegis -from cosmos1.models.guardrail.blocklist.blocklist import Blocklist -from cosmos1.models.guardrail.common.core import GuardrailRunner -from cosmos1.models.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter -from cosmos1.models.guardrail.video_content_safety_filter.video_content_safety_filter import VideoContentSafetyFilter -from 
cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration.guardrail_aegis import Aegis +from AutoregressiveVideo2WorldGeneration.guardrail_blocklist import Blocklist +from AutoregressiveVideo2WorldGeneration.guardrail_common_core import GuardrailRunner +from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter import RetinaFaceFilter +from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter import VideoContentSafetyFilter +from AutoregressiveVideo2WorldGeneration import log def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: diff --git a/cosmos1/models/guardrail/face_blur_filter/face_blur_filter.py b/guardrail_face_blur_filter.py similarity index 93% rename from cosmos1/models/guardrail/face_blur_filter/face_blur_filter.py rename to guardrail_face_blur_filter.py index a4163942f69f14f10cf4fdde54eb5c1303c16782..c97cf164990767b999a0e9a0e5a96547de73f810 100644 --- a/cosmos1/models/guardrail/face_blur_filter/face_blur_filter.py +++ b/guardrail_face_blur_filter.py @@ -16,6 +16,7 @@ import argparse import os +from AutoregressiveVideo2WorldGeneration import misc import numpy as np import torch from pytorch_retinaface.data import cfg_re50 @@ -24,11 +25,11 @@ from pytorch_retinaface.models.retinaface import RetinaFace from torch.utils.data import DataLoader, TensorDataset from tqdm import tqdm -from cosmos1.models.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail -from cosmos1.models.guardrail.common.io_utils import get_video_filepaths, read_video, save_video -from cosmos1.models.guardrail.face_blur_filter.blur_utils import pixelate_face -from cosmos1.models.guardrail.face_blur_filter.retinaface_utils import decode_batch, filter_detected_boxes, load_model -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.guardrail_common_core import GuardrailRunner, PostprocessingGuardrail +from AutoregressiveVideo2WorldGeneration.guardrail_common_io_utils import 
get_video_filepaths, read_video, save_video +from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter_blur_utils import pixelate_face +from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter_retinaface_utils import decode_batch, filter_detected_boxes, load_model +from AutoregressiveVideo2WorldGeneration import log DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth" diff --git a/cosmos1/models/guardrail/face_blur_filter/blur_utils.py b/guardrail_face_blur_filter_blur_utils.py similarity index 100% rename from cosmos1/models/guardrail/face_blur_filter/blur_utils.py rename to guardrail_face_blur_filter_blur_utils.py diff --git a/cosmos1/models/guardrail/face_blur_filter/retinaface_utils.py b/guardrail_face_blur_filter_retinaface_utils.py similarity index 98% rename from cosmos1/models/guardrail/face_blur_filter/retinaface_utils.py rename to guardrail_face_blur_filter_retinaface_utils.py index 27e69cec320c28d13ea1a0443f77a565b59f24dd..c3c373986fc0c04720671857d8d1f4fa8ef73ec0 100644 --- a/cosmos1/models/guardrail/face_blur_filter/retinaface_utils.py +++ b/guardrail_face_blur_filter_retinaface_utils.py @@ -17,7 +17,7 @@ import numpy as np import torch from pytorch_retinaface.utils.nms.py_cpu_nms import py_cpu_nms -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py diff --git a/cosmos1/models/guardrail/video_content_safety_filter/video_content_safety_filter.py b/guardrail_video_content_safety_filter.py similarity index 92% rename from cosmos1/models/guardrail/video_content_safety_filter/video_content_safety_filter.py rename to guardrail_video_content_safety_filter.py index 78d812ae3a184de73d22cae5e3b55c2fe486b69e..4f541c94c0c83d55d159f795ba30e9490eea442d 100644 --- a/cosmos1/models/guardrail/video_content_safety_filter/video_content_safety_filter.py +++ 
b/guardrail_video_content_safety_filter.py @@ -18,14 +18,15 @@ import json import os from typing import Iterable, Tuple, Union +from AutoregressiveVideo2WorldGeneration import misc import torch from PIL import Image -from cosmos1.models.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos1.models.guardrail.common.io_utils import get_video_filepaths, read_video -from cosmos1.models.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel -from cosmos1.models.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner +from AutoregressiveVideo2WorldGeneration.guardrail_common_io_utils import get_video_filepaths, read_video +from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter_model import ModelConfig, VideoSafetyModel +from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter_vision_encoder import SigLIPEncoder +from AutoregressiveVideo2WorldGeneration import log DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter" diff --git a/cosmos1/models/guardrail/video_content_safety_filter/model.py b/guardrail_video_content_safety_filter_model.py similarity index 96% rename from cosmos1/models/guardrail/video_content_safety_filter/model.py rename to guardrail_video_content_safety_filter_model.py index 1f53f1cabc2ee49c1f50dc17cef237ff1b80e37d..88e8af353c4fd4f3e5ad5961109d5c712994c43f 100644 --- a/cosmos1/models/guardrail/video_content_safety_filter/model.py +++ b/guardrail_video_content_safety_filter_model.py @@ -17,7 +17,7 @@ import attrs import torch import torch.nn as nn -from cosmos1.utils.config import make_freezable +from AutoregressiveVideo2WorldGeneration.config import make_freezable @make_freezable diff --git 
a/cosmos1/models/guardrail/video_content_safety_filter/vision_encoder.py b/guardrail_video_content_safety_filter_vision_encoder.py similarity index 100% rename from cosmos1/models/guardrail/video_content_safety_filter/vision_encoder.py rename to guardrail_video_content_safety_filter_vision_encoder.py diff --git a/cosmos1/utils/io.py b/io.py similarity index 100% rename from cosmos1/utils/io.py rename to io.py diff --git a/cosmos1/utils/lazy_config/lazy.py b/lazy.py similarity index 98% rename from cosmos1/utils/lazy_config/lazy.py rename to lazy.py index 68f761d7b3762cf387623e609954de90eac4619a..00194049dda47884b1ba3d18034ef22ae7bb4ff2 100644 --- a/cosmos1/utils/lazy_config/lazy.py +++ b/lazy.py @@ -29,8 +29,8 @@ import attrs import yaml from omegaconf import DictConfig, ListConfig, OmegaConf -from cosmos1.utils.lazy_config.file_io import PathManager -from cosmos1.utils.lazy_config.registry import _convert_target_to_string +from AutoregressiveVideo2WorldGeneration.lazy_file_io import PathManager +from AutoregressiveVideo2WorldGeneration.lazy_registry import _convert_target_to_string __all__ = ["LazyCall", "LazyConfig"] diff --git a/cosmos1/utils/lazy_config/__init__.py b/lazy_config_init.py similarity index 89% rename from cosmos1/utils/lazy_config/__init__.py rename to lazy_config_init.py index cb5b0ec33f05de2f3761e4f724200c4383f481d1..e3041a86a039156c194acab286a9e15366db58c0 100644 --- a/cosmos1/utils/lazy_config/__init__.py +++ b/lazy_config_init.py @@ -3,9 +3,9 @@ import os from omegaconf import DictConfig, OmegaConf -from cosmos1.utils.lazy_config.instantiate import instantiate -from cosmos1.utils.lazy_config.lazy import LazyCall, LazyConfig -from cosmos1.utils.lazy_config.omegaconf_patch import to_object +from AutoregressiveVideo2WorldGeneration.lazy_instantiate import instantiate +from AutoregressiveVideo2WorldGeneration.lazy import LazyCall, LazyConfig +from AutoregressiveVideo2WorldGeneration.lazy_omegaconf_patch import to_object OmegaConf.to_object = 
to_object diff --git a/cosmos1/utils/lazy_config/file_io.py b/lazy_file_io.py similarity index 100% rename from cosmos1/utils/lazy_config/file_io.py rename to lazy_file_io.py diff --git a/cosmos1/utils/lazy_config/instantiate.py b/lazy_instantiate.py similarity index 97% rename from cosmos1/utils/lazy_config/instantiate.py rename to lazy_instantiate.py index 742ed3816bdbf910e92e80812f947c70236dd9d5..2486ca065bcca28cdb6bcc717af2a66c451c361c 100644 --- a/cosmos1/utils/lazy_config/instantiate.py +++ b/lazy_instantiate.py @@ -20,7 +20,7 @@ from typing import Any import attrs -from cosmos1.utils.lazy_config.registry import _convert_target_to_string, locate +from AutoregressiveVideo2WorldGeneration.lazy_registry import _convert_target_to_string, locate __all__ = ["dump_dataclass", "instantiate"] diff --git a/cosmos1/utils/lazy_config/omegaconf_patch.py b/lazy_omegaconf_patch.py similarity index 100% rename from cosmos1/utils/lazy_config/omegaconf_patch.py rename to lazy_omegaconf_patch.py diff --git a/cosmos1/utils/lazy_config/registry.py b/lazy_registry.py similarity index 100% rename from cosmos1/utils/lazy_config/registry.py rename to lazy_registry.py diff --git a/cosmos1/utils/log.py b/log.py similarity index 71% rename from cosmos1/utils/log.py rename to log.py index 822a9755338b7db34bf0941a63d712ed84123749..e0dd98e373ecf8c70cbd301d878d25578ff1a908 100644 --- a/cosmos1/utils/log.py +++ b/log.py @@ -89,37 +89,39 @@ def _rank0_only_filter(record: Any) -> bool: record["message"] = f"[RANK {_get_rank()}]" + record["message"] return not is_rank0 - -def trace(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) - - -def debug(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) - - -def info(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) - - -def success(message: str, rank0_only: bool = 
True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) - - -def warning(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) - - -def error(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) - - -def critical(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) - - -def exception(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) +class log(): + + @staticmethod + def trace(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) + + @staticmethod + def debug(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) + + @staticmethod + def info(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) + + @staticmethod + def success(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) + + @staticmethod + def warning(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) + + @staticmethod + def error(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) + + @staticmethod + def critical(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) + + @staticmethod + def exception(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: diff --git a/misc.py b/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..bc97aa3149629889dad13b87ada0023088dffb64 --- 
/dev/null +++ b/misc.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +import collections.abc +import functools +import json +import random +import time +from contextlib import ContextDecorator +from typing import Any, Callable, TypeVar + +from AutoregressiveVideo2WorldGeneration import log +import numpy as np +import termcolor +import torch + +from AutoregressiveVideo2WorldGeneration import distributed + + +class misc(): + + @staticmethod + def to( + data: Any, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + memory_format: torch.memory_format = torch.preserve_format, + ) -> Any: + """Recursively cast data into the specified device, dtype, and/or memory_format. + + The input data can be a tensor, a list of tensors, a dict of tensors. + See the documentation for torch.Tensor.to() for details. + + Args: + data (Any): Input data. + device (str | torch.device): GPU device (default: None). + dtype (torch.dtype): data type (default: None). + memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). + + Returns: + data (Any): Data cast to the specified device, dtype, and/or memory_format. 
+ """ + assert ( + device is not None or dtype is not None or memory_format is not None + ), "at least one of device, dtype, memory_format should be specified" + if isinstance(data, torch.Tensor): + is_cpu = (isinstance(device, str) and device == "cpu") or ( + isinstance(device, torch.device) and device.type == "cpu" + ) + data = data.to( + device=device, + dtype=dtype, + memory_format=memory_format, + non_blocking=(not is_cpu), + ) + return data + elif isinstance(data, collections.abc.Mapping): + return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) + else: + return data + + @staticmethod + def serialize(data: Any) -> Any: + """Serialize data by hierarchically traversing through iterables. + + Args: + data (Any): Input data. + + Returns: + data (Any): Serialized data. + """ + if isinstance(data, collections.abc.Mapping): + return type(data)({key: serialize(data[key]) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([serialize(elem) for elem in data]) + else: + try: + json.dumps(data) + except TypeError: + data = str(data) + return data + + @staticmethod + def set_random_seed(seed: int, by_rank: bool = False) -> None: + """Set random seed. This includes random, numpy, Pytorch. + + Args: + seed (int): Random seed. + by_rank (bool): if true, each GPU will use a different random seed. 
+ """ + if by_rank: + seed += distributed.get_rank() + log.info(f"Using random seed {seed}.") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # sets seed on the current CPU & all GPUs + + @staticmethod + def arch_invariant_rand( + shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None + ): + """Produce a GPU-architecture-invariant randomized Torch tensor. + + Args: + shape (list or tuple of ints): Output tensor shape. + dtype (torch.dtype): Output tensor type. + device (torch.device): Device holding the output. + seed (int): Optional randomization seed. + + Returns: + tensor (torch.tensor): Randomly-generated tensor. + """ + # Create a random number generator, optionally seeded + rng = np.random.RandomState(seed) + + # # Generate random numbers using the generator + random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution + + # Convert to torch tensor and return + return torch.from_numpy(random_array).to(dtype=dtype, device=device) + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class timer(ContextDecorator): # noqa: N801 + """Simple timer for timing the execution of code. + + It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. 
+ + Example: + def func_a(): + time.sleep(1) + with timer("func_a"): + func_a() + + @timer("func_b) + def func_b(): + time.sleep(1) + func_b() + """ + + def __init__(self, context: str, debug: bool = False): + self.context = context + self.debug = debug + + def __enter__(self) -> None: + self.tic = time.time() + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + time_spent = time.time() - self.tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + + def __call__(self, func: T) -> T: + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + tic = time.time() + result = func(*args, **kwargs) + time_spent = time.time() - tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + return result + + return wrapper # type: ignore + + +class Color: + """A convenience class to colorize strings in the console. 
+ + Example: + import + print("This is {Color.red('important')}.") + """ + + @staticmethod + def red(x: str) -> str: + return termcolor.colored(str(x), color="red") + + @staticmethod + def green(x: str) -> str: + return termcolor.colored(str(x), color="green") + + @staticmethod + def cyan(x: str) -> str: + return termcolor.colored(str(x), color="cyan") + + @staticmethod + def yellow(x: str) -> str: + return termcolor.colored(str(x), color="yellow") diff --git a/cosmos1/models/common/t5_text_encoder.py b/t5_text_encoder.py similarity index 98% rename from cosmos1/models/common/t5_text_encoder.py rename to t5_text_encoder.py index 456d2eb8fbf56a08d77a73dcd4d422b50d43dcef..37f19bad0b2538097e378becbc7e88262f99ee45 100644 --- a/cosmos1/models/common/t5_text_encoder.py +++ b/t5_text_encoder.py @@ -19,7 +19,7 @@ import torch import transformers from transformers import T5EncoderModel, T5TokenizerFast -from cosmos1.utils import log +from AutoregressiveVideo2WorldGeneration import log transformers.logging.set_verbosity_error() diff --git a/cosmos1/models/autoregressive/inference/video2world.py b/video2world.py similarity index 93% rename from cosmos1/models/autoregressive/inference/video2world.py rename to video2world.py index 4f9bd2cf9f40eefa3cbd473f10c8b7d399b4f637..49aa113abdbe5ad2ae5ccd72ecfda18be5d67000 100644 --- a/cosmos1/models/autoregressive/inference/video2world.py +++ b/video2world.py @@ -19,10 +19,10 @@ import os import imageio import torch -from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline -from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args -from cosmos1.utils import log -from cosmos1.utils.io import read_prompts_from_file +from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARVideo2WorldGenerationPipeline +from AutoregressiveVideo2WorldGeneration.ar_utils_inference import add_common_arguments, load_vision_input, 
validate_args +from AutoregressiveVideo2WorldGeneration import log +from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file def parse_args(): diff --git a/video2world_hf.py b/video2world_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..ca48cf4965295c64d2c65f6c22b38994155b890e --- /dev/null +++ b/video2world_hf.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import copy
import os

import imageio
import torch
from transformers import PretrainedConfig, PreTrainedModel

from AutoregressiveVideo2WorldGeneration import log
from AutoregressiveVideo2WorldGeneration.ar_utils_inference import load_vision_input, validate_args
from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file
from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARVideo2WorldGenerationPipeline


class ARVideo2WorldConfig(PretrainedConfig):
    """Hugging Face-style configuration for the AR video2world pipeline.

    Every knob of the generation pipeline is exposed as an attribute; any of
    them can be overridden by passing the same-named keyword argument.
    """

    model_type = "ARVideo2World"

    # One table of defaults instead of a kwargs.get() per attribute.
    _DEFAULTS = {
        "checkpoint_dir": "checkpoints",
        "ar_model_dir": "Cosmos-1.0-Autoregressive-5B-Video2World",
        "video_save_name": "output",
        "video_save_folder": "outputs/",
        "prompt": None,
        "input_type": "text_and_video",
        "input_image_or_video_path": None,
        "batch_input_path": None,
        "num_input_frames": 9,
        "temperature": 1.0,
        "top_p": 0.8,
        "seed": 0,
        # Fix: forward() reads args.data_resolution, but the original config
        # never defined it (AttributeError at inference time). [height, width];
        # default presumably matches the CLI default in
        # ar_utils_inference.add_common_arguments — TODO confirm.
        "data_resolution": [640, 1024],
        "disable_diffusion_decoder": False,
        "offload_guardrail_models": False,
        "offload_diffusion_decoder": False,
        "offload_ar_model": False,
        "offload_tokenizer": False,
        "offload_text_encoder_model": False,
    }

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for name, default in self._DEFAULTS.items():
            # copy.copy keeps mutable defaults (e.g. data_resolution) from
            # being shared across config instances.
            setattr(self, name, kwargs.get(name, copy.copy(default)))


class ARVideo2World(PreTrainedModel):
    """PreTrainedModel wrapper around ARVideo2WorldGenerationPipeline.

    The model takes a text prompt plus a conditioning image/video and writes
    the generated world video(s) to disk; it carries no trainable weights of
    its own (save_pretrained is a no-op).
    """

    config_class = ARVideo2WorldConfig

    def __init__(self, args=None):
        """Build the generation pipeline from *args* (an ARVideo2WorldConfig).

        Fix: the original signature used a mutable default
        ``args=ARVideo2WorldConfig()`` evaluated once at class-definition time
        and shared by every no-arg caller; a None sentinel creates a fresh
        config per instance instead.
        """
        if args is None:
            args = ARVideo2WorldConfig()
        super().__init__(args)
        # Inference only: disable autograd globally. Fix: the original called
        # torch.enable_grad(False), which is not a valid use of that context
        # manager and does not disable gradients.
        torch.set_grad_enabled(False)
        self.args = args

        inference_type = "video2world"  # When the inference_type is "video2world", AR model takes both text and video as input, the world generation is based on the input text prompt and video
        self.sampling_config = validate_args(args, inference_type)

        # Initialize prompted base generation model pipeline
        self.pipeline = ARVideo2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name=args.ar_model_dir,
            disable_diffusion_decoder=args.disable_diffusion_decoder,
            offload_guardrail_models=args.offload_guardrail_models,
            offload_diffusion_decoder=args.offload_diffusion_decoder,
            offload_network=args.offload_ar_model,
            offload_tokenizer=args.offload_tokenizer,
            offload_text_encoder_model=args.offload_text_encoder_model,
        )

    def forward(self, prompt, input_image_or_video_path):
        """Generate video(s) conditioned on *prompt* and the given visual input.

        Args:
            prompt: text prompt used when no batch input file is configured.
            input_image_or_video_path: path to the conditioning image/video.

        Videos are written as .mp4 files under ``args.video_save_folder``;
        nothing is returned.
        """
        args = self.args

        # Load the conditioning image(s)/video(s), keyed by file basename.
        input_videos = load_vision_input(
            input_type=args.input_type,
            batch_input_path=args.batch_input_path,
            input_image_or_video_path=input_image_or_video_path,
            data_resolution=args.data_resolution,
            num_input_frames=args.num_input_frames,
        )

        # Either a batch of prompt entries from a file, or a single entry
        # built from the call arguments.
        if args.batch_input_path:
            prompts_list = read_prompts_from_file(args.batch_input_path)
        else:
            prompts_list = [{"visual_input": input_image_or_video_path, "prompt": prompt}]

        # Fix: ensure the output folder exists before imageio writes into it.
        os.makedirs(args.video_save_folder, exist_ok=True)

        # Iterate through prompts
        for idx, prompt_entry in enumerate(prompts_list):
            video_path = prompt_entry["visual_input"]
            input_filename = os.path.basename(video_path)

            # Skip entries whose visual input failed to load above.
            if input_filename not in input_videos:
                log.critical(f"Input file {input_filename} not found, skipping prompt.")
                continue

            inp_vid = input_videos[input_filename]
            inp_prompt = prompt_entry["prompt"]

            # Generate video
            log.info(f"Run with input: {prompt_entry}")
            out_vid = self.pipeline.generate(
                inp_prompt=inp_prompt,
                inp_vid=inp_vid,
                num_input_frames=args.num_input_frames,
                seed=args.seed,
                sampling_config=self.sampling_config,
            )
            # generate() returns None when the guardrail rejects the request.
            if out_vid is None:
                log.critical("Guardrail blocked video2world generation.")
                continue

            # Single-input runs keep the configured name; batch runs are
            # numbered by their position in the batch file.
            if args.input_image_or_video_path:
                out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4")
            else:
                out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4")
            imageio.mimsave(out_vid_path, out_vid, fps=25)

            log.info(f"Saved video to {out_vid_path}")

    def save_pretrained(self, save_directory, **kwargs):
        """No-op override: this wrapper holds no weights worth serializing."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a config kwarg plus per-call overrides.

        NOTE(review): checkpoints are expected to already exist under
        config.checkpoint_dir; this method does not download them.
        """
        # Fix: the original did kwargs["config"], raising KeyError whenever no
        # config was supplied; fall back to a default config instead.
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class()
        # Remaining kwargs are treated as attribute overrides on the config.
        config.update(kwargs)
        return cls(config)
cosmos1.models.autoregressive.configs.inference.inference_config import ( +from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_video2world_model_config +from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig +from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import ( DataShapeConfig, DiffusionDecoderSamplingConfig, InferenceConfig, SamplingConfig, ) -from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens -from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel -from cosmos1.models.autoregressive.model import AutoRegressiveModel -from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving -from cosmos1.models.common.base_world_generation_pipeline import BaseWorldGenerationPipeline -from cosmos1.models.diffusion.inference.inference_utils import ( +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens +from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel +from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel +from AutoregressiveVideo2WorldGeneration.ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving +from AutoregressiveVideo2WorldGeneration.base_world_generation_pipeline import BaseWorldGenerationPipeline +from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import ( load_model_by_config, load_network_model, load_tokenizer_model, ) -from cosmos1.utils import log, misc +from AutoregressiveVideo2WorldGeneration import log def detect_model_size_from_ckpt_path(ckpt_path: str) -> str: