diff --git a/ar_config_base_model.py b/ar_config_base_model.py
index d0ddd57133faf49bf99cf67231c528fbf8d543d0..f0cfe3eac295944e5b36794598a70bfacd27d705 100644
--- a/ar_config_base_model.py
+++ b/ar_config_base_model.py
@@ -17,7 +17,7 @@ from typing import Optional

 import attrs

-from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig
+from .ar_config_base_tokenizer import TokenizerConfig


 @attrs.define
diff --git a/ar_config_base_model_config.py b/ar_config_base_model_config.py
index 2676622fe75a969184d93d2cfd9891b7464f0c19..4de12fae686821ebf94aec3420719e6432856cf4 100644
--- a/ar_config_base_model_config.py
+++ b/ar_config_base_model_config.py
@@ -16,17 +16,17 @@
 import copy
 from typing import Callable, List, Optional

-from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig
-from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import (
+from .ar_config_base_model import ModelConfig
+from .ar_config_base_tokenizer import (
     TextTokenizerConfig,
     TokenizerConfig,
     VideoTokenizerConfig,
     create_discrete_video_fsq_tokenizer_state_dict_config,
 )
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer import TextTokenizer
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
+from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
+from .ar_tokenizer_text_tokenizer import TextTokenizer
+from . import log
+from .lazy_config_init import LazyCall as L

 # Common architecture specifications
 BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
diff --git a/ar_config_base_tokenizer.py b/ar_config_base_tokenizer.py
index cba1b056ae44746189d1d2ba58f35062968b629c..cd52189c10a42fd17fd90022f226de3da4a2852a 100644
--- a/ar_config_base_tokenizer.py
+++ b/ar_config_base_tokenizer.py
@@ -17,10 +17,10 @@ from typing import Optional

 import attrs

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQStateDictTokenizer
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_networks import CausalDiscreteVideoTokenizer
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .ar_tokenizer_discrete_video import DiscreteVideoFSQStateDictTokenizer
+from .ar_tokenizer_networks import CausalDiscreteVideoTokenizer
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict


 def create_discrete_video_fsq_tokenizer_state_dict_config(
diff --git a/ar_config_inference_inference_config.py b/ar_config_inference_inference_config.py
index 666b72721e1d5d0cc1b3f1e8527e402545d556f8..00be0b18e2e656c0d69d8d74298a45195530e8c4 100644
--- a/ar_config_inference_inference_config.py
+++ b/ar_config_inference_inference_config.py
@@ -17,7 +17,7 @@ from typing import Any, List, Union

 import attrs

-from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig, TokenizerConfig
+from .ar_config_base_model import ModelConfig, TokenizerConfig


 @attrs.define(slots=False)
diff --git a/ar_diffusion_decoder_config_base_conditioner.py b/ar_diffusion_decoder_config_base_conditioner.py
index d4f876dd2302d090e0b4943b1180999beddf36a6..9c52902c14f73193dc404e6526bcabcc18400d61 100644
--- a/ar_diffusion_decoder_config_base_conditioner.py
+++ b/ar_diffusion_decoder_config_base_conditioner.py
@@ -18,8 +18,8 @@ from typing import Dict, Optional

 import torch

-from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseVideoCondition, GeneralConditioner
-from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import (
+from .df_conditioner import BaseVideoCondition, GeneralConditioner
+from .df_config_base_conditioner import (
     FPSConfig,
     ImageSizeConfig,
     LatentConditionConfig,
@@ -28,8 +28,8 @@ from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import (
     PaddingMaskConfig,
     TextConfig,
 )
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict


 @dataclass
diff --git a/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py b/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py
index 531817956eb31490bb8668d6bfc8d551cc484e5a..e93fb7170a31d4bb44f491cd4bec4bcb552884d5 100644
--- a/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py
+++ b/ar_diffusion_decoder_config_config_latent_diffusion_decoder.py
@@ -17,11 +17,11 @@ from typing import Any, List

 import attrs

-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
-from AutoregressiveVideo2WorldGeneration.df_config_base_model import LatentDiffusionDecoderModelConfig
-from AutoregressiveVideo2WorldGeneration.df_config_registry import register_configs
-from AutoregressiveVideo2WorldGeneration import config
-from AutoregressiveVideo2WorldGeneration.config_helper import import_all_modules_from_package
+from .ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
+from .df_config_base_model import LatentDiffusionDecoderModelConfig
+from .df_config_registry import register_configs
+from . import config
+from .config_helper import import_all_modules_from_package


 @attrs.define(slots=False)
diff --git a/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py b/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py
index 7ef9ba0e663449568f1601905e75f4cf125824bc..7b84d08088ee5274e8792fabf439bc0aeceba7d3 100644
--- a/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py
+++ b/ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py
@@ -15,9 +15,9 @@
 from hydra.core.config_store import ConfigStore

-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_network import DiffusionDecoderGeneralDIT
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .ar_diffusion_decoder_network import DiffusionDecoderGeneralDIT
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict

 num_frames = 57

 Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict(
diff --git a/ar_diffusion_decoder_config_registry.py b/ar_diffusion_decoder_config_registry.py
index 42849b6675be25e900f395315775e0a97b32c1f0..70f83fca49b296bba0704a3ab16e5a670fdcfa38 100644
--- a/ar_diffusion_decoder_config_registry.py
+++ b/ar_diffusion_decoder_config_registry.py
@@ -15,12 +15,12 @@
 from hydra.core.config_store import ConfigStore

-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_config_base_conditioner import (
+from .ar_diffusion_decoder_config_base_conditioner import (
     VideoLatentDiffusionDecoderConditionerConfig,
 )
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
-from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
+from .ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
+from .df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
+from .lazy_config_init import LazyCall as L


 def get_cosmos_video_discrete_tokenizer_comp8x16x16(
diff --git a/ar_diffusion_decoder_inference.py b/ar_diffusion_decoder_inference.py
index 2eec6f557896bc2648296751b0d309c8c157d283..f7d78b6170fb180f5f62bea23e149cdd465aefdc 100644
--- a/ar_diffusion_decoder_inference.py
+++ b/ar_diffusion_decoder_inference.py
@@ -19,10 +19,10 @@ from typing import List

 import torch

-from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
+from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
+from .ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap
+from . import log


 def diffusion_decoder_process_tokens(
diff --git a/ar_diffusion_decoder_model.py b/ar_diffusion_decoder_model.py
index 6de3d809ab932b0f58aaf5f18134f7b2d9718ac5..73755631ed6b97ebf773b3941fc0f6d1621761f7 100644
--- a/ar_diffusion_decoder_model.py
+++ b/ar_diffusion_decoder_model.py
@@ -19,11 +19,11 @@ from typing import Callable, Dict, Optional, Tuple

 import torch
 from torch import Tensor

-from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseVideoCondition
-from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul
-from AutoregressiveVideo2WorldGeneration.df_df_module_res_sampler import COMMON_SOLVER_OPTIONS
-from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel as VideoDiffusionModel
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate
+from .df_conditioner import BaseVideoCondition
+from .df_df_functional_batch_ops import batch_mul
+from .df_df_module_res_sampler import COMMON_SOLVER_OPTIONS
+from .df_model_model_t2w import DiffusionT2WModel as VideoDiffusionModel
+from .lazy_config_init import instantiate as lazy_instantiate

 @dataclass
diff --git a/ar_diffusion_decoder_network.py b/ar_diffusion_decoder_network.py
index 5208dce7369026ad81b6908a3f3a5ecbf5fa4ad2..ab621a7eb9b1bd87e681111297aad8382b40693e 100644
--- a/ar_diffusion_decoder_network.py
+++ b/ar_diffusion_decoder_network.py
@@ -20,8 +20,8 @@ from einops import rearrange
 from torch import nn
 from torchvision import transforms

-from AutoregressiveVideo2WorldGeneration.df_module_blocks import PatchEmbed
-from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT
+from .df_module_blocks import PatchEmbed
+from .df_network_general_dit import GeneralDIT


 class DiffusionDecoderGeneralDIT(GeneralDIT):
diff --git a/ar_model.py b/ar_model.py
index 9527de29c60867fd95c4a1fd8a96e5eb15d7073e..4a13a8fde58e7852b683112be63eaed44e1f143f 100644
--- a/ar_model.py
+++ b/ar_model.py
@@ -19,24 +19,24 @@
 import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch
 from safetensors.torch import load_file
 from torch.nn.modules.module import _IncompatibleKeys

-from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig
-from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig
-from AutoregressiveVideo2WorldGeneration.ar_module_mm_projector import MultimodalProjector
-from AutoregressiveVideo2WorldGeneration.ar_network_transformer import Transformer
-from AutoregressiveVideo2WorldGeneration.ar_network_vit import VisionTransformer, get_vit_config
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
-from AutoregressiveVideo2WorldGeneration.ar_utils_checkpoint import (
+from .ar_config_base_model import ModelConfig
+from .ar_config_base_tokenizer import TokenizerConfig
+from .ar_module_mm_projector import MultimodalProjector
+from .ar_network_transformer import Transformer
+from .ar_network_vit import VisionTransformer, get_vit_config
+from .ar_tokenizer_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
+from .ar_utils_checkpoint import (
     get_partial_state_dict,
     process_state_dict,
     substrings_to_ignore,
 )
-from AutoregressiveVideo2WorldGeneration.ar_utils_sampling import decode_n_tokens, decode_one_token, prefill
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_utils_sampling import decode_n_tokens, decode_one_token, prefill
+from . import log

 class AutoRegressiveModel(torch.nn.Module):
diff --git a/ar_module_attention.py b/ar_module_attention.py
index bf13847edfa3881494401debdd65890beb59b2e3..578cd9ecfca36e5376fef8da5106652c6ca85b68 100644
--- a/ar_module_attention.py
+++ b/ar_module_attention.py
@@ -19,8 +19,8 @@ from typing import Optional, Union

 import torch
 from torch import nn

-from AutoregressiveVideo2WorldGeneration.ar_module_embedding import RotaryPositionEmbedding
-from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm
+from .ar_module_embedding import RotaryPositionEmbedding
+from .ar_module_normalization import create_norm

 class Attention(nn.Module):
diff --git a/ar_network_transformer.py b/ar_network_transformer.py
index bc48dd0723f084c474b55206a2c44ede297708c2..aeed1540fa1a3026c436873a872eaf3576ed6a09 100644
--- a/ar_network_transformer.py
+++ b/ar_network_transformer.py
@@ -19,17 +19,17 @@
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _IncompatibleKeys

-from AutoregressiveVideo2WorldGeneration.ar_module_attention import Attention
-from AutoregressiveVideo2WorldGeneration.ar_module_embedding import (
+from .ar_module_attention import Attention
+from .ar_module_embedding import (
     RotaryPositionEmbeddingPytorchV1,
     RotaryPositionEmbeddingPytorchV2,
     SinCosPosEmbAxisTE,
 )
-from AutoregressiveVideo2WorldGeneration.ar_module_mlp import MLP
-from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm
-from AutoregressiveVideo2WorldGeneration.ar_utils_checkpoint import process_state_dict, substrings_to_ignore
-from AutoregressiveVideo2WorldGeneration.ar_utils_misc import maybe_convert_to_namespace
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_module_mlp import MLP
+from .ar_module_normalization import create_norm
+from .ar_utils_checkpoint import process_state_dict, substrings_to_ignore
+from .ar_utils_misc import maybe_convert_to_namespace
+from . import log

 class TransformerBlock(nn.Module):
diff --git a/ar_network_vit.py b/ar_network_vit.py
index 5938979e6be4321ff52c518d0778601c6db7ab04..e350289be6acd54e1e089a87eeee86b92f236c32 100644
--- a/ar_network_vit.py
+++ b/ar_network_vit.py
@@ -26,9 +26,9 @@ from typing import Any, Callable, Mapping, Optional, Tuple

 import torch
 import torch.nn as nn

-from AutoregressiveVideo2WorldGeneration.ar_module_normalization import create_norm
-from AutoregressiveVideo2WorldGeneration.ar_network_transformer import TransformerBlock
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_module_normalization import create_norm
+from .ar_network_transformer import TransformerBlock
+from . import log

 def get_vit_config(model_name: str) -> Mapping[str, Any]:
diff --git a/ar_tokenizer_discrete_video.py b/ar_tokenizer_discrete_video.py
index 64a675d99a330c20f1a7de7c1dba85c58eda1afb..5e5a5244c87516121f3e7686c924f8b1c66cd772 100644
--- a/ar_tokenizer_discrete_video.py
+++ b/ar_tokenizer_discrete_video.py
@@ -18,7 +18,7 @@ from typing import Optional
 import torch
 from einops import rearrange

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_quantizers import FSQuantizer
+from .ar_tokenizer_quantizers import FSQuantizer

 # Make sure jit model output consistenly during consecutive calls
 # Check here: https://github.com/pytorch/pytorch/issues/74534
diff --git a/ar_tokenizer_image_text_tokenizer.py b/ar_tokenizer_image_text_tokenizer.py
index d0911647e3159d8eb023df1b7fb2d02a8b8be76a..5877aa166d1d946b98ce604e2bd1a4284b884ae6 100644
--- a/ar_tokenizer_image_text_tokenizer.py
+++ b/ar_tokenizer_image_text_tokenizer.py
@@ -21,8 +21,8 @@ import transformers
 from transformers import AutoImageProcessor
 from transformers.image_utils import ImageInput, is_valid_image, load_image

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer import TextTokenizer
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_tokenizer_text_tokenizer import TextTokenizer
+from . import log

 # Configuration for different vision-language models
 IMAGE_CONFIGS = {
diff --git a/ar_tokenizer_modules.py b/ar_tokenizer_modules.py
index 8f7744f24828c45d7da850d528fd936dbf2bc897..0c2f9c6280ccfa60e1ba8a38e3062e0caf99e71e 100644
--- a/ar_tokenizer_modules.py
+++ b/ar_tokenizer_modules.py
@@ -29,8 +29,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_patching import Patcher3D, UnPatcher3D
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_utils import (
+from .ar_tokenizer_patching import Patcher3D, UnPatcher3D
+from .ar_tokenizer_utils import (
     CausalNormalize,
     batch2space,
     batch2time,
@@ -41,7 +41,7 @@ from AutoregressiveVideo2WorldGeneration.ar_tokenizer_utils import (
     space2batch,
     time2batch,
 )
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 class CausalConv3d(nn.Module):
diff --git a/ar_tokenizer_networks.py b/ar_tokenizer_networks.py
index 2b465abe8bb3e2c82438660765a34129241c2883..29be4d33e5dfb6255b5db0b99bcbc4311a3faa82 100644
--- a/ar_tokenizer_networks.py
+++ b/ar_tokenizer_networks.py
@@ -18,9 +18,9 @@ from collections import namedtuple

 import torch
 from torch import nn

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_quantizers import FSQuantizer
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized
+from .ar_tokenizer_quantizers import FSQuantizer
+from . import log

 NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])
diff --git a/ar_tokenizer_quantizers.py b/ar_tokenizer_quantizers.py
index d1ce54a2071df140e229ada226b9e8852b219875..e07b51aef6f32fb39266c2f12de27c9ff87eb4d7 100644
--- a/ar_tokenizer_quantizers.py
+++ b/ar_tokenizer_quantizers.py
@@ -21,7 +21,7 @@ import torch
 import torch.nn as nn
 from einops import rearrange

-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_utils import default, pack_one, round_ste, unpack_one
+from .ar_tokenizer_utils import default, pack_one, round_ste, unpack_one

 class FSQuantizer(nn.Module):
diff --git a/ar_tokenizer_text_tokenizer.py b/ar_tokenizer_text_tokenizer.py
index f6eae73cd1504f3b2eebbbc2e071d31c4d81dccf..9918ab7cc8f55dc0c159b58c158d3556b6819acd 100644
--- a/ar_tokenizer_text_tokenizer.py
+++ b/ar_tokenizer_text_tokenizer.py
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from transformers import AutoTokenizer

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 def get_tokenizer_path(model_family: str, is_instruct_model: bool = False):
diff --git a/ar_tokenizer_tokenizer.py b/ar_tokenizer_tokenizer.py
index 7d241b3deb761cea8a97d7d97c24e7b17ded34c3..9861ef45253f4932a362923bdb6f07fd1b39666b 100644
--- a/ar_tokenizer_tokenizer.py
+++ b/ar_tokenizer_tokenizer.py
@@ -19,8 +19,8 @@ from typing import Optional

 import torch
 from einops import rearrange

-from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate
+from .ar_config_base_tokenizer import TokenizerConfig
+from .lazy_config_init import instantiate as lazy_instantiate

 def update_vocab_size(
diff --git a/ar_utils_inference.py b/ar_utils_inference.py
index 4e637121ab632753f80505dcf1a6ac960b4879e6..53dea6ed871052e987bf5094f869778412202323 100644
--- a/ar_utils_inference.py
+++ b/ar_utils_inference.py
@@ -25,8 +25,8 @@ import torch
 import torchvision
 from PIL import Image

-from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import SamplingConfig
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_config_inference_inference_config import SamplingConfig
+from . import log

 _IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"]
 _VIDEO_EXTENSIONS = [".mp4"]
diff --git a/ar_utils_sampling.py b/ar_utils_sampling.py
index b83890022922a9ad051b78a633ac605f7cc903ea..2984b57e08440bd3117de9e25e4f3cfabd619e80 100644
--- a/ar_utils_sampling.py
+++ b/ar_utils_sampling.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple

 import torch

-from AutoregressiveVideo2WorldGeneration.ar_network_transformer import Transformer
+from .ar_network_transformer import Transformer

 def sample_top_p(logits, temperature, top_p, return_probs: bool = False):
diff --git a/base.py b/base.py
index 3f80351285901d4b43029ecd13fe33190283d9cf..9d5151a33133e9331c8256891dbfce0b8e622e0a 100644
--- a/base.py
+++ b/base.py
@@ -19,9 +19,9 @@ import os

 import imageio
 import torch

-from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARBaseGenerationPipeline
-from AutoregressiveVideo2WorldGeneration.ar_utils_inference import add_common_arguments, load_vision_input, validate_args
-from AutoregressiveVideo2WorldGeneration import log
+from .world_generation_pipeline import ARBaseGenerationPipeline
+from .ar_utils_inference import add_common_arguments, load_vision_input, validate_args
+from . import log

 def parse_args():
diff --git a/base_world_generation_pipeline.py b/base_world_generation_pipeline.py
index 6e7337df0341bf0ee65f3a051470a752f3c091b8..83714775f6f5785e8efd78daf44f9be2337bb4a4 100644
--- a/base_world_generation_pipeline.py
+++ b/base_world_generation_pipeline.py
@@ -21,8 +21,8 @@ from typing import Any

 import numpy as np
 import torch

-from AutoregressiveVideo2WorldGeneration.t5_text_encoder import CosmosT5TextEncoder
-from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets
+from .t5_text_encoder import CosmosT5TextEncoder
+from . import guardrail_common_presets as guardrail_presets

 class BaseWorldGenerationPipeline(ABC):
diff --git a/config.py b/config.py
index 705ca23385d9ecb4f20b373fe9e361d29bddc3ec..a24d1a0cbbe184ab0a2bfb5cbee13bfd327810ae 100644
--- a/config.py
+++ b/config.py
@@ -19,8 +19,8 @@ from typing import Any, TypeVar

 import attrs

-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
-from AutoregressiveVideo2WorldGeneration.misc import Color
+from .lazy_config_init import LazyDict
+from .misc import Color

 T = TypeVar("T")
diff --git a/config_helper.py b/config_helper.py
index 3848d02a218ad8b9de43805e09fe1ef5f68367bc..7b5c6e553583e8047a37aea5e4925df659426ea2 100644
--- a/config_helper.py
+++ b/config_helper.py
@@ -27,8 +27,8 @@ from hydra import compose, initialize
 from hydra.core.config_store import ConfigStore
 from omegaconf import DictConfig, OmegaConf

-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.config import Config
+from . import log
+from .config import Config

 def is_attrs_or_dataclass(obj) -> bool:
diff --git a/cosmos1/models/autoregressive/nemo/cosmos.py b/cosmos1/models/autoregressive/nemo/cosmos.py
index 72e4d02dc37f7e03bbc35e6d476e748834101294..7b34fddaa3e4fce0a77f637b4c090d173aad303a 100644
--- a/cosmos1/models/autoregressive/nemo/cosmos.py
+++ b/cosmos1/models/autoregressive/nemo/cosmos.py
@@ -29,7 +29,7 @@ from nemo.lightning import OptimizerModule, io
 from nemo.lightning.base import teardown
 from torch import Tensor, nn

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 class RotaryEmbedding3D(RotaryEmbedding):
diff --git a/cosmos1/models/autoregressive/nemo/inference/general.py b/cosmos1/models/autoregressive/nemo/inference/general.py
index 420e1bef9d97f1f8b63393d8b4b7c5ba1f0ec589..d701744b43c4ad1f65063ad2970ccae417f9a298 100644
--- a/cosmos1/models/autoregressive/nemo/inference/general.py
+++ b/cosmos1/models/autoregressive/nemo/inference/general.py
@@ -34,10 +34,10 @@ from nemo.lightning import io
 from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

 from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
-from AutoregressiveVideo2WorldGeneration.ar_utils_inference import load_vision_input
+from .ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
+from .ar_utils_inference import load_vision_input
 from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 torch._C._jit_set_texpr_fuser_enabled(False)
diff --git a/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py
index 63c37fb57b9911dea642d5a69ce9aa67d354a2ec..6c258fa32cf540a74cc056037ed2c5c0f3cc201c 100644
--- a/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py
+++ b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py
@@ -23,8 +23,8 @@ from huggingface_hub import snapshot_download
 from nemo.collections.nlp.data.language_modeling.megatron import indexed_dataset

 from cosmos1.models.autoregressive.nemo.utils import read_input_videos
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
-from AutoregressiveVideo2WorldGeneration import log
+from .ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
+from . import log

 TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]
 DATA_RESOLUTION_SUPPORTED = [640, 1024]
diff --git a/cosmos1/models/autoregressive/nemo/utils.py b/cosmos1/models/autoregressive/nemo/utils.py
index 14b679ef8f4140fc27813769bbd7fac8972d4b2b..2ebb94ab6aa8bb53a1eabc24bb6787850ae8532b 100644
--- a/cosmos1/models/autoregressive/nemo/utils.py
+++ b/cosmos1/models/autoregressive/nemo/utils.py
@@ -23,16 +23,16 @@ import torch
 import torchvision
 from huggingface_hub import snapshot_download

-from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
-from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel
-from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import (
+from .ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
+from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
+from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
+from .df_inference_inference_utils import (
     load_network_model,
     load_tokenizer_model,
     skip_init_linear,
 )
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.config_helper import get_config_module, override
+from . import log
+from .config_helper import get_config_module, override

 TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]
 DATA_RESOLUTION_SUPPORTED = [640, 1024]
diff --git a/cosmos1/models/diffusion/config/config.py b/cosmos1/models/diffusion/config/config.py
index eb38e3850f95535e6b8c39a31b6715012d5df684..249725d810803ed3114223ebe9d2e5cb393b862b 100644
--- a/cosmos1/models/diffusion/config/config.py
+++ b/cosmos1/models/diffusion/config/config.py
@@ -17,10 +17,10 @@ from typing import Any, List

 import attrs

-from AutoregressiveVideo2WorldGeneration.df_config_base_model import DefaultModelConfig
-from AutoregressiveVideo2WorldGeneration.df_config_registry import register_configs
+from .df_config_base_model import DefaultModelConfig
+from .df_config_registry import register_configs
 from AutoregressiveVideo2WorldGeneration import config
-from AutoregressiveVideo2WorldGeneration.config_helper import import_all_modules_from_package
+from .config_helper import import_all_modules_from_package

 @attrs.define(slots=False)
diff --git a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py
index 17565632abbb356766f552012357acbc5aba5e42..ecd066a54e6e302aae82794b13e7cc271a331700 100644
--- a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py
+++ b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py
@@ -15,7 +15,7 @@
 from hydra.core.config_store import ConfigStore

-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .lazy_config_init import LazyDict

 Cosmos_1_0_Diffusion_Text2World_7B: LazyDict = LazyDict(
     dict(
diff --git a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py
index 404c23aa2017ec433d5ead2b2ffaf0333bc7340d..5eff6ca4acbeacfb9d3b08e9560342ee9474e0f8 100644
--- a/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py
+++ b/cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py
@@ -16,8 +16,8 @@ from hydra.core.config_store import ConfigStore

 from cosmos1.models.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict

 Cosmos_1_0_Diffusion_Video2World_7B: LazyDict = LazyDict(
     dict(
diff --git a/cosmos1/models/diffusion/inference/text2world.py b/cosmos1/models/diffusion/inference/text2world.py
index 503c8177336d589db2f7098d9ff3afefb7e076bf..90413098bf45d120bbba7a11801c2d885dcf7339 100644
--- a/cosmos1/models/diffusion/inference/text2world.py
+++ b/cosmos1/models/diffusion/inference/text2world.py
@@ -16,13 +16,13 @@
 import argparse
 import os

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch

-from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import add_common_arguments, validate_args
+from .df_inference_inference_utils import add_common_arguments, validate_args
 from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file, save_video
+from . import log
+from .io import read_prompts_from_file, save_video

 torch.enable_grad(False)
diff --git a/cosmos1/models/diffusion/inference/video2world.py b/cosmos1/models/diffusion/inference/video2world.py
index 2afe6559dfc0e1467880a12be7e74371ca36c83f..382dfcb7b9b80ee6c23882b75235b14038a9443d 100644
--- a/cosmos1/models/diffusion/inference/video2world.py
+++ b/cosmos1/models/diffusion/inference/video2world.py
@@ -16,13 +16,13 @@
 import argparse
 import os

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch

-from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import add_common_arguments, check_input_frames, validate_args
+from .df_inference_inference_utils import add_common_arguments, check_input_frames, validate_args
 from cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file, save_video
+from . import log
+from .io import read_prompts_from_file, save_video

 torch.enable_grad(False)
diff --git a/cosmos1/models/diffusion/inference/world_generation_pipeline.py b/cosmos1/models/diffusion/inference/world_generation_pipeline.py
index 6ca5e63aa3883403dfe79999bd53653c6771e7ed..d998dfb1ec5d026e3965cfc57b5cc984e05d3b52 100644
--- a/cosmos1/models/diffusion/inference/world_generation_pipeline.py
+++ b/cosmos1/models/diffusion/inference/world_generation_pipeline.py
@@ -20,8 +20,8 @@ from typing import Any, Optional
 import numpy as np
 import torch

-from AutoregressiveVideo2WorldGeneration.base_world_generation_pipeline import BaseWorldGenerationPipeline
-from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import (
+from .base_world_generation_pipeline import BaseWorldGenerationPipeline
+from .df_inference_inference_utils import (
     generate_world_from_text,
     generate_world_from_video,
     get_condition_latent,
@@ -30,8 +30,8 @@ from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import (
     load_network_model,
     load_tokenizer_model,
 )
-from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel
-from AutoregressiveVideo2WorldGeneration.df_model_model_v2w import DiffusionV2WModel
+from .df_model_model_t2w import DiffusionT2WModel
+from .df_model_model_v2w import DiffusionV2WModel
 from cosmos1.models.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import (
     create_prompt_upsampler,
     run_chat_completion,
@@ -43,7 +43,7 @@ from cosmos1.models.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import (
     run_chat_completion as run_chat_completion_vlm,
 )
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 MODEL_NAME_DICT = {
     "Cosmos-1.0-Diffusion-7B-Text2World": "Cosmos_1_0_Diffusion_Text2World_7B",
diff --git a/cosmos1/models/diffusion/nemo/inference/general.py b/cosmos1/models/diffusion/nemo/inference/general.py
index 4328173fe6929ee72c633257360226b270a2f23e..8c52e3abfb395dc13ec5ddc22ce562dccabb7a78 100644
--- a/cosmos1/models/diffusion/nemo/inference/general.py
+++ b/cosmos1/models/diffusion/nemo/inference/general.py
@@ -37,7 +37,7 @@ from transformers import T5EncoderModel, T5TokenizerFast

 from cosmos1.models.diffusion.nemo.inference.inference_utils import process_prompt, save_video
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 EXAMPLE_PROMPT = (
     "The teal robot is cooking food in a kitchen. Steam rises from a simmering pot "
diff --git a/cosmos1/models/diffusion/nemo/inference/inference_utils.py b/cosmos1/models/diffusion/nemo/inference/inference_utils.py
index f10c8a67765f33cf86b8af32cf571931a7c16242..1ffc9b1c85536ac2df73829a267da4794413e9cc 100644
--- a/cosmos1/models/diffusion/nemo/inference/inference_utils.py
+++ b/cosmos1/models/diffusion/nemo/inference/inference_utils.py
@@ -19,18 +19,18 @@
 import imageio
 import numpy as np
 import torch

-from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel
+from .ar_model import AutoRegressiveModel
 from cosmos1.models.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import (
     create_prompt_upsampler,
     run_chat_completion,
 )
-from AutoregressiveVideo2WorldGeneration.guardrail_common_presets import (
+from .guardrail_common_presets import (
     create_text_guardrail_runner,
     create_video_guardrail_runner,
     run_text_guardrail,
     run_video_guardrail,
 )
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 def get_upsampled_prompt(
diff --git a/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py
index 1dad9f347f4e4166a73d217ed6406bff43ad551e..eeb08cde550949b08444bd5191d87b52896c82e7 100644
--- a/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py
+++ b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py
@@ -27,7 +27,7 @@ from nemo.collections.diffusion.models.model import DiT7BConfig
 from tqdm import tqdm
 from transformers import T5EncoderModel, T5TokenizerFast

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 def get_parser():
diff --git a/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py b/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py
index 0ffaa9f2a1b591196b82cae9659a64ac3d746b62..227d871f8d1ff698587071d4ba850c3f46386dad 100644
--- a/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py
+++ b/cosmos1/models/diffusion/networks/general_dit_video_conditioned.py
@@ -19,10 +19,10 @@
 import torch
 from einops import rearrange
 from torch import nn

-from AutoregressiveVideo2WorldGeneration.df_conditioner import DataType
-from AutoregressiveVideo2WorldGeneration.df_module_blocks import TimestepEmbedding, Timesteps
-from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT
-from AutoregressiveVideo2WorldGeneration import log
+from .df_conditioner import DataType
+from .df_module_blocks import TimestepEmbedding, Timesteps
+from .df_network_general_dit import GeneralDIT
+from . import log

 class VideoExtendGeneralDIT(GeneralDIT):
diff --git a/cosmos1/models/diffusion/prompt_upsampler/inference.py b/cosmos1/models/diffusion/prompt_upsampler/inference.py
index d33fb465d4f7a633ddafb05ab1e50bf46b59f1d7..4252b3758a2d49c5809dd963f7ae403209cbff7b 100644
--- a/cosmos1/models/diffusion/prompt_upsampler/inference.py
+++ b/cosmos1/models/diffusion/prompt_upsampler/inference.py
@@ -17,9 +17,9 @@ from typing import List, Optional, TypedDict

 import torch

-from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
-from AutoregressiveVideo2WorldGeneration.ar_tokenizer_text_tokenizer import TextTokenizer
+from .ar_model import AutoRegressiveModel
+from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
+from .ar_tokenizer_text_tokenizer import TextTokenizer

 class ChatPrediction(TypedDict, total=False):
diff --git a/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py b/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py
index 7073af4819595aa0080d094ab3e8b2c161b6a01e..d4af83acc69280dd63ef9adb3d4d987cafa9820c 100644
--- a/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py
+++ b/cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py
@@ -23,11 +23,11 @@
 import argparse
 import os
 import re

-from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_text_model_config
-from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel
+from .ar_config_base_model_config import create_text_model_config
+from .ar_model import AutoRegressiveModel
 from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
 from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets
-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel:
diff --git a/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py b/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py
index 2c0ff29486fcb7c18652f4c19f9ff76ba8dc5dbc..71d7763608e72c241c3277997905c46a19e7278c 100644
--- a/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py
+++ b/cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py
@@ -26,12 +26,12 @@ from math import ceil

 from PIL import Image

-from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_vision_language_model_config
-from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel
+from .ar_config_base_model_config import create_vision_language_model_config
+from .ar_model import AutoRegressiveModel
 from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
 from AutoregressiveVideo2WorldGeneration import guardrail_common_presets as guardrail_presets
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.io import load_from_fileobj
+from . import log
+from .io import load_from_fileobj

 def create_vlm_prompt_upsampler(
diff --git a/df_conditioner.py b/df_conditioner.py
index 0f101b3030da778cc8ea8ed46043ac6acd692e3f..4146fad65c365a8c4fd6903a0ea33860142f64f5 100644
--- a/df_conditioner.py
+++ b/df_conditioner.py
@@ -23,9 +23,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

-from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate
+from .df_df_functional_batch_ops import batch_mul
+from . import log
+from .lazy_config_init import instantiate

 class BaseConditionEntry(nn.Module):
diff --git a/df_config_base_conditioner.py b/df_config_base_conditioner.py
index 2a5845d0f256991b7b486d6bfdd19f7632c39043..c5f3a2c1d6d14ad0cbde6bbfe9e1c3cf6b71519a 100644
--- a/df_config_base_conditioner.py
+++ b/df_config_base_conditioner.py
@@ -18,9 +18,9 @@ from typing import Dict, List, Optional

 import attrs
 import torch

-from AutoregressiveVideo2WorldGeneration.df_conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .df_conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict

 @attrs.define(slots=False)
diff --git a/df_config_base_model.py b/df_config_base_model.py
index 42d9e551d4460aa5e694052f3df7ad9c66ca9b45..7b41f7fb8cd4cf73d89b3f6d550dfc2d19fbe254 100644
--- a/df_config_base_model.py
+++ b/df_config_base_model.py
@@ -17,7 +17,7 @@ from typing import List

 import attrs

-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .lazy_config_init import LazyDict

 @attrs.define(slots=False)
diff --git a/df_config_base_net.py b/df_config_base_net.py
index 5b843163eee1e4cd7d6fcc347b72148a3f14253a..5a621c4738ef3f335dc4b97675e93f2290d0c45a 100644
--- a/df_config_base_net.py
+++ b/df_config_base_net.py
@@ -15,9 +15,9 @@
 import copy

-from AutoregressiveVideo2WorldGeneration.df_network_general_dit import GeneralDIT
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyDict
+from .df_network_general_dit import GeneralDIT
+from .lazy_config_init import LazyCall as L
+from .lazy_config_init import LazyDict

 FADITV2Config: LazyDict = L(GeneralDIT)(
     max_img_h=240,
diff --git a/df_config_base_tokenizer.py b/df_config_base_tokenizer.py
index 7c6363ead8e27335b2c6d76be8341e1d9b326c06..c2b2aedf029512a0b00fd5bcebeb636db245a61f 100644
--- a/df_config_base_tokenizer.py
+++ b/df_config_base_tokenizer.py
@@ -15,8 +15,8 @@
 import omegaconf

-from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import LazyCall as L
+from .df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
+from .lazy_config_init import LazyCall as L

 TOKENIZER_OPTIONS = {}
diff --git a/df_config_registry.py b/df_config_registry.py
index aeba6024f7a858bc9d70756a7839cfb3d8add54a..c84b1d285f28ec6b1b3b10852fcf6e8dbf1587e5 100644
--- a/df_config_registry.py
+++ b/df_config_registry.py
@@ -15,13 +15,13 @@
 from hydra.core.config_store import ConfigStore

-from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import (
+from .df_config_base_conditioner import (
     BaseVideoConditionerConfig,
     VideoConditionerFpsSizePaddingConfig,
     VideoExtendConditionerConfig,
 )
-from AutoregressiveVideo2WorldGeneration.df_config_base_net import FADITV2_14B_Config, FADITV2Config
-from AutoregressiveVideo2WorldGeneration.df_config_base_tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8
+from .df_config_base_net import FADITV2_14B_Config, FADITV2Config
+from .df_config_base_tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8

 def register_net(cs):
diff --git a/df_df_functional_multi_step.py b/df_df_functional_multi_step.py
index 954075c5aa4028ca5ea21c40208e788cee136ad0..3c5a1dbe30558d9e7e97ad64304161c4e61a00f5 100644
--- a/df_df_functional_multi_step.py
+++ b/df_df_functional_multi_step.py
@@ -21,7 +21,7 @@ from typing import Callable, List, Tuple

 import torch

-from AutoregressiveVideo2WorldGeneration.df_df_functional_runge_kutta import reg_x0_euler_step, res_x0_rk2_step
+from .df_df_functional_runge_kutta import reg_x0_euler_step, res_x0_rk2_step

 def order2_fn(
diff --git a/df_df_functional_runge_kutta.py b/df_df_functional_runge_kutta.py
index ef9be1adb343c40d3dc9269bb2bf346d70a7a02b..9586934f8c1949d734b4ea3080135d2769ec481a 100644
--- a/df_df_functional_runge_kutta.py
+++ b/df_df_functional_runge_kutta.py
@@ -17,7 +17,7 @@ from typing import Callable, Tuple

 import torch

-from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul
+from .df_df_functional_batch_ops import batch_mul

 def phi1(t: torch.Tensor) -> torch.Tensor:
diff --git a/df_df_module_res_sampler.py b/df_df_module_res_sampler.py
index afdd85ec8cc975b24f0b2309298ca2192bc46a03..184d60dec6f9b0326dc0aa1a3d9b89c06fa7566e 100644
--- a/df_df_module_res_sampler.py
+++ b/df_df_module_res_sampler.py
@@ -28,9 +28,9 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import attrs
 import torch

-from AutoregressiveVideo2WorldGeneration.df_df_functional_multi_step import get_multi_step_fn, is_multi_step_fn_supported
-from AutoregressiveVideo2WorldGeneration.df_df_functional_runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
-from AutoregressiveVideo2WorldGeneration.config import make_freezable
+from .df_df_functional_multi_step import get_multi_step_fn, is_multi_step_fn_supported
+from .df_df_functional_runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
+from .config import make_freezable

 COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"]
diff --git a/df_inference_inference_utils.py b/df_inference_inference_utils.py
index 2f5518c1512707ae3dffcb7b238025b859e10685..68e9cbb58aa1a39cd62c15a01b3e6526a49b66b0 100644
--- a/df_inference_inference_utils.py
+++ b/df_inference_inference_utils.py
@@ -18,18 +18,18 @@
 import importlib
 from contextlib import contextmanager
 from typing import List, NamedTuple, Optional, Tuple

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import einops
 import imageio
 import numpy as np
 import torch
 import torchvision.transforms.functional as transforms_F

-from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel
-from AutoregressiveVideo2WorldGeneration.df_model_model_v2w import DiffusionV2WModel
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.config_helper import get_config_module, override
-from AutoregressiveVideo2WorldGeneration.io import load_from_fileobj
+from .df_model_model_t2w import DiffusionT2WModel
+from .df_model_model_v2w import DiffusionV2WModel
+from . import log
+from .config_helper import get_config_module, override
+from .io import load_from_fileobj

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION >= (1, 11):
diff --git a/df_model_model_t2w.py b/df_model_model_t2w.py
index 458c1e35c207948ecf715fd93d0972630c79ed80..e0b49f26ced373c9f002ec2ca0adfe674126ebf5 100644
--- a/df_model_model_t2w.py
+++ b/df_model_model_t2w.py
@@ -15,19 +15,19 @@
 from typing import Callable, Dict, Optional, Tuple

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch
 from torch import Tensor

-from AutoregressiveVideo2WorldGeneration.df_conditioner import CosmosCondition
-from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul
-from AutoregressiveVideo2WorldGeneration.df_df_module_denoiser_scaling import EDMScaling
-from AutoregressiveVideo2WorldGeneration.df_df_module_res_sampler import COMMON_SOLVER_OPTIONS, Sampler
-from AutoregressiveVideo2WorldGeneration.df_df_types import DenoisePrediction
-from AutoregressiveVideo2WorldGeneration.df_module_blocks import FourierFeatures
-from AutoregressiveVideo2WorldGeneration.df_module_pretrained_vae import BaseVAE
-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.lazy_config_init import instantiate as lazy_instantiate
+from .df_conditioner import CosmosCondition
+from .df_df_functional_batch_ops import batch_mul
+from .df_df_module_denoiser_scaling import EDMScaling
+from .df_df_module_res_sampler import COMMON_SOLVER_OPTIONS, Sampler
+from .df_df_types import DenoisePrediction
+from .df_module_blocks import FourierFeatures
+from .df_module_pretrained_vae import BaseVAE
+from . import log
+from .lazy_config_init import instantiate as lazy_instantiate

 class EDMSDE:
diff --git a/df_model_model_v2w.py b/df_model_model_v2w.py
index 6fdbf21dda8f909a9611ca1780b1b0c04fb99e1c..54ff4d48b535d2a1f27bbcc75c20ef16821b11e1 100644
--- a/df_model_model_v2w.py
+++ b/df_model_model_v2w.py
@@ -16,15 +16,15 @@
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional, Tuple, Union

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch
 from torch import Tensor

-from AutoregressiveVideo2WorldGeneration.df_conditioner import VideoExtendCondition
-from AutoregressiveVideo2WorldGeneration.df_config_base_conditioner import VideoCondBoolConfig
-from AutoregressiveVideo2WorldGeneration.df_df_functional_batch_ops import batch_mul
-from AutoregressiveVideo2WorldGeneration.df_model_model_t2w import DiffusionT2WModel
-from AutoregressiveVideo2WorldGeneration import log
+from .df_conditioner import VideoExtendCondition
+from .df_config_base_conditioner import VideoCondBoolConfig
+from .df_df_functional_batch_ops import batch_mul
+from .df_model_model_t2w import DiffusionT2WModel
+from . import log

 @dataclass
diff --git a/df_module_blocks.py b/df_module_blocks.py
index 4d14bf25ad4b22c81be7e1f365c0aa4079536583..e7ef8a7e7acb9ea34133f1b4f8892f66f5f73e76 100644
--- a/df_module_blocks.py
+++ b/df_module_blocks.py
@@ -22,8 +22,8 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn

-from AutoregressiveVideo2WorldGeneration.df_module_attention import Attention, GPT2FeedForward
-from AutoregressiveVideo2WorldGeneration import log
+from .df_module_attention import Attention, GPT2FeedForward
+from . import log

 def modulate(x, shift, scale):
diff --git a/df_module_position_embedding.py b/df_module_position_embedding.py
index cd63a1e3f3921928208cbce42633e41cf8429f9a..8f5b386748b11785d008629c567e2c1e2155342b 100644
--- a/df_module_position_embedding.py
+++ b/df_module_position_embedding.py
@@ -19,8 +19,8 @@ import torch
 from einops import rearrange, repeat
 from torch import nn

-from AutoregressiveVideo2WorldGeneration.df_module_attention import normalize
-from AutoregressiveVideo2WorldGeneration.df_module_timm import trunc_normal_
+from .df_module_attention import normalize
+from .df_module_timm import trunc_normal_

 class VideoPositionEmb(nn.Module):
diff --git a/df_network_general_dit.py b/df_network_general_dit.py
index 90d40237c58de17818c10a8f41cbe3d3df2d0b64..f447869d58c3341a7ee431827f3e2e0e60bdb8ad 100644
--- a/df_network_general_dit.py
+++ b/df_network_general_dit.py
@@ -24,17 +24,17 @@
 from einops import rearrange
 from torch import nn
 from torchvision import transforms

-from AutoregressiveVideo2WorldGeneration.df_conditioner import DataType
-from AutoregressiveVideo2WorldGeneration.df_module_attention import get_normalization
-from AutoregressiveVideo2WorldGeneration.df_module_blocks import (
+from .df_conditioner import DataType
+from .df_module_attention import get_normalization
+from .df_module_blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
     PatchEmbed,
     TimestepEmbedding,
     Timesteps,
 )
-from AutoregressiveVideo2WorldGeneration.df_module_position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
-from AutoregressiveVideo2WorldGeneration import log
+from .df_module_position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
+from . import log

 class GeneralDIT(nn.Module):
diff --git a/distributed.py b/distributed.py
index d3ac63b7401852597d528cf99167d5995b6f6339..69f477ced9dfe59deda742bc507addf7d7268bdf 100644
--- a/distributed.py
+++ b/distributed.py
@@ -27,8 +27,8 @@ import pynvml
 import torch
 import torch.distributed as dist

-from AutoregressiveVideo2WorldGeneration import log
-from AutoregressiveVideo2WorldGeneration.device import Device
+from . import log
+from .device import Device

 def init() -> int | None:
diff --git a/download_diffusion.py b/download_diffusion.py
index cdd94176ea4500b600881c99d2e5b0cff83b5003..60650ce74ba3fa8a1dbad94978fd3ebe610b0d02 100644
--- a/download_diffusion.py
+++ b/download_diffusion.py
@@ -18,7 +18,7 @@ from pathlib import Path

 from huggingface_hub import snapshot_download

-from AutoregressiveVideo2WorldGeneration.convert_pixtral_ckpt import convert_pixtral_checkpoint
+from .convert_pixtral_ckpt import convert_pixtral_checkpoint

 def parse_args():
diff --git a/guardrail_aegis.py b/guardrail_aegis.py
index eef53731deccf32c23d09943a2e9251b82d71f68..77c3f88ca85134e689203e9ac157673c42edb0b3 100644
--- a/guardrail_aegis.py
+++ b/guardrail_aegis.py
@@ -15,14 +15,14 @@
 import argparse

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from AutoregressiveVideo2WorldGeneration.guardrail_aegis_categories import UNSAFE_CATEGORIES
-from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
-from AutoregressiveVideo2WorldGeneration import log
+from .guardrail_aegis_categories import UNSAFE_CATEGORIES
+from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
+from . import log

 SAFE = misc.Color.green("SAFE")
 UNSAFE = misc.Color.red("UNSAFE")
diff --git a/guardrail_blocklist.py b/guardrail_blocklist.py
index b73595216ab9d7ce9615a994d38c8f59bb95ded6..46385211d438d1953e9ba21376680dc2c42db01c 100644
--- a/guardrail_blocklist.py
+++ b/guardrail_blocklist.py
@@ -19,13 +19,13 @@
 import re
 import string
 from difflib import SequenceMatcher

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import nltk
 from better_profanity import profanity

-from AutoregressiveVideo2WorldGeneration.guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
-from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
-from AutoregressiveVideo2WorldGeneration import log
+from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii
+from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
+from . import log

 DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist"
 CENSOR = misc.Color.red("*")
diff --git a/guardrail_blocklist_utils.py b/guardrail_blocklist_utils.py
index 481fcbad2d64f49aeb33f88eea52b638573ef040..859eb6498143e5b063dbc888dca7748a07cfda9d 100644
--- a/guardrail_blocklist_utils.py
+++ b/guardrail_blocklist_utils.py
@@ -16,7 +16,7 @@
 import os
 import re

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 def read_keyword_list_from_dir(folder_path: str) -> list[str]:
diff --git a/guardrail_common_core.py b/guardrail_common_core.py
index b094f70d4c04058c0c61dd45dcfa0292c9b6c23f..e4916c3379353f577a811def9a1d29f2e0a48708 100644
--- a/guardrail_common_core.py
+++ b/guardrail_common_core.py
@@ -17,7 +17,7 @@ from typing import Any, Tuple

 import numpy as np

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 class ContentSafetyGuardrail:
diff --git a/guardrail_common_io_utils.py b/guardrail_common_io_utils.py
index 1a655f5d05e2af9b8db59eab9850d26fb08e1c52..148897d5cae9165673cb74e336548c71adb261b1 100644
--- a/guardrail_common_io_utils.py
+++ b/guardrail_common_io_utils.py
@@ -19,7 +19,7 @@ from dataclasses import dataclass
 import imageio
 import numpy as np

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 @dataclass
diff --git a/guardrail_common_presets.py b/guardrail_common_presets.py
index 2d00856dd5d6da7596f6b6f910e688c1ded47a3d..00dc8c5edb503d273bbaaf91428072d569d6420c 100644
--- a/guardrail_common_presets.py
+++ b/guardrail_common_presets.py
@@ -17,12 +17,12 @@
 import os

 import numpy as np

-from AutoregressiveVideo2WorldGeneration.guardrail_aegis import Aegis
-from AutoregressiveVideo2WorldGeneration.guardrail_blocklist import Blocklist
-from AutoregressiveVideo2WorldGeneration.guardrail_common_core import GuardrailRunner
-from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter import RetinaFaceFilter
-from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter import VideoContentSafetyFilter
-from AutoregressiveVideo2WorldGeneration import log
+from .guardrail_aegis import Aegis
+from .guardrail_blocklist import Blocklist
+from .guardrail_common_core import GuardrailRunner
+from .guardrail_face_blur_filter import RetinaFaceFilter
+from .guardrail_video_content_safety_filter import VideoContentSafetyFilter
+from . import log

 def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
diff --git a/guardrail_face_blur_filter.py b/guardrail_face_blur_filter.py
index c97cf164990767b999a0e9a0e5a96547de73f810..9d565d078fbe37e1d31cf8a445a460e2bae291f1 100644
--- a/guardrail_face_blur_filter.py
+++ b/guardrail_face_blur_filter.py
@@ -16,7 +16,7 @@
 import argparse
 import os

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import numpy as np
 import torch
 from pytorch_retinaface.data import cfg_re50
@@ -25,11 +25,11 @@ from pytorch_retinaface.models.retinaface import RetinaFace
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm

-from AutoregressiveVideo2WorldGeneration.guardrail_common_core import GuardrailRunner, PostprocessingGuardrail
-from AutoregressiveVideo2WorldGeneration.guardrail_common_io_utils import get_video_filepaths, read_video, save_video
-from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter_blur_utils import pixelate_face
-from AutoregressiveVideo2WorldGeneration.guardrail_face_blur_filter_retinaface_utils import decode_batch, filter_detected_boxes, load_model
-from AutoregressiveVideo2WorldGeneration import log
+from .guardrail_common_core import GuardrailRunner, PostprocessingGuardrail
+from .guardrail_common_io_utils import get_video_filepaths, read_video, save_video
+from .guardrail_face_blur_filter_blur_utils import pixelate_face
+from .guardrail_face_blur_filter_retinaface_utils import decode_batch, filter_detected_boxes, load_model
+from . import log

 DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"
diff --git a/guardrail_face_blur_filter_retinaface_utils.py b/guardrail_face_blur_filter_retinaface_utils.py
index c3c373986fc0c04720671857d8d1f4fa8ef73ec0..5d1bc4c8a22a942736ae6b73a4ebb21da4980adc 100644
--- a/guardrail_face_blur_filter_retinaface_utils.py
+++ b/guardrail_face_blur_filter_retinaface_utils.py
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from pytorch_retinaface.utils.nms.py_cpu_nms import py_cpu_nms

-from AutoregressiveVideo2WorldGeneration import log
+from . import log

 # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
diff --git a/guardrail_video_content_safety_filter.py b/guardrail_video_content_safety_filter.py
index 4f541c94c0c83d55d159f795ba30e9490eea442d..072076fb853aec819a7298df83e26338e0cb4c3a 100644
--- a/guardrail_video_content_safety_filter.py
+++ b/guardrail_video_content_safety_filter.py
@@ -18,15 +18,15 @@
 import json
 import os
 from typing import Iterable, Tuple, Union

-from AutoregressiveVideo2WorldGeneration import misc
+from . import misc
 import torch
 from PIL import Image

-from AutoregressiveVideo2WorldGeneration.guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
-from AutoregressiveVideo2WorldGeneration.guardrail_common_io_utils import get_video_filepaths, read_video
-from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter_model import ModelConfig, VideoSafetyModel
-from AutoregressiveVideo2WorldGeneration.guardrail_video_content_safety_filter_vision_encoder import SigLIPEncoder
-from AutoregressiveVideo2WorldGeneration import log
+from .guardrail_common_core import ContentSafetyGuardrail, GuardrailRunner
+from .guardrail_common_io_utils import get_video_filepaths, read_video
+from .guardrail_video_content_safety_filter_model import ModelConfig, VideoSafetyModel
+from .guardrail_video_content_safety_filter_vision_encoder import SigLIPEncoder
+from . import log

 DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter"
diff --git a/guardrail_video_content_safety_filter_model.py b/guardrail_video_content_safety_filter_model.py
index 88e8af353c4fd4f3e5ad5961109d5c712994c43f..d4ccc005d90bad4d029bf9fdc9a66450fa9f7049 100644
--- a/guardrail_video_content_safety_filter_model.py
+++ b/guardrail_video_content_safety_filter_model.py
@@ -17,7 +17,7 @@
 import attrs
 import torch
 import torch.nn
as nn -from AutoregressiveVideo2WorldGeneration.config import make_freezable +from .config import make_freezable @make_freezable diff --git a/lazy.py b/lazy.py index 00194049dda47884b1ba3d18034ef22ae7bb4ff2..7be2848733fe33c1513756b850c4da250a790008 100644 --- a/lazy.py +++ b/lazy.py @@ -29,8 +29,8 @@ import attrs import yaml from omegaconf import DictConfig, ListConfig, OmegaConf -from AutoregressiveVideo2WorldGeneration.lazy_file_io import PathManager -from AutoregressiveVideo2WorldGeneration.lazy_registry import _convert_target_to_string +from .lazy_file_io import PathManager +from .lazy_registry import _convert_target_to_string __all__ = ["LazyCall", "LazyConfig"] diff --git a/lazy_config_init.py b/lazy_config_init.py index e3041a86a039156c194acab286a9e15366db58c0..9bd252316a4bd6fb3a8f8a1c29a8e9ac44ac76fe 100644 --- a/lazy_config_init.py +++ b/lazy_config_init.py @@ -3,9 +3,9 @@ import os from omegaconf import DictConfig, OmegaConf -from AutoregressiveVideo2WorldGeneration.lazy_instantiate import instantiate -from AutoregressiveVideo2WorldGeneration.lazy import LazyCall, LazyConfig -from AutoregressiveVideo2WorldGeneration.lazy_omegaconf_patch import to_object +from .lazy_instantiate import instantiate +from .lazy import LazyCall, LazyConfig +from .lazy_omegaconf_patch import to_object OmegaConf.to_object = to_object diff --git a/lazy_instantiate.py b/lazy_instantiate.py index 2486ca065bcca28cdb6bcc717af2a66c451c361c..4c860c42a1c3d8adc417e9593892491d0803fe51 100644 --- a/lazy_instantiate.py +++ b/lazy_instantiate.py @@ -20,7 +20,7 @@ from typing import Any import attrs -from AutoregressiveVideo2WorldGeneration.lazy_registry import _convert_target_to_string, locate +from .lazy_registry import _convert_target_to_string, locate __all__ = ["dump_dataclass", "instantiate"] diff --git a/misc.py b/misc.py index 401b9edf9a000c8b0c4ba71d74ad4e30ea3a2685..a2496a4fa280586b62c846c54cfbbc9f8adc0331 100644 --- a/misc.py +++ b/misc.py @@ -29,7 +29,7 @@ import numpy as np import termcolor import torch -from AutoregressiveVideo2WorldGeneration import distributed +from .distributed import distributed class misc(): diff --git a/t5_text_encoder.py b/t5_text_encoder.py index 37f19bad0b2538097e378becbc7e88262f99ee45..7bebf08cef2869c85553980bf81851635dd74f7e 100644 --- a/t5_text_encoder.py +++ b/t5_text_encoder.py @@ -19,7 +19,7 @@ import torch import transformers from transformers import T5EncoderModel, T5TokenizerFast -from AutoregressiveVideo2WorldGeneration import log +from .log import log transformers.logging.set_verbosity_error() diff --git a/video2world.py b/video2world.py index 49aa113abdbe5ad2ae5ccd72ecfda18be5d67000..d01c331fe900cfc84cea41fa9f7e1c71da530ec5 100644 --- a/video2world.py +++ b/video2world.py @@ -19,10 +19,10 @@ import os import imageio import torch -from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARVideo2WorldGenerationPipeline -from AutoregressiveVideo2WorldGeneration.ar_utils_inference import add_common_arguments, load_vision_input, validate_args -from AutoregressiveVideo2WorldGeneration import log -from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file +from .world_generation_pipeline import ARVideo2WorldGenerationPipeline +from .ar_utils_inference import add_common_arguments, load_vision_input, validate_args +from .log import log +from .io import read_prompts_from_file def parse_args(): diff --git a/video2world_hf.py b/video2world_hf.py index 470b974557b67b36497a5c3c8cf7267bd9dfcc2d..1f41a4225dcea325c5ea283e51e09477ee1d0e6d 100644 --- 
a/video2world_hf.py +++ b/video2world_hf.py @@ -19,10 +19,10 @@ import os import imageio import torch -from AutoregressiveVideo2WorldGeneration.world_generation_pipeline import ARVideo2WorldGenerationPipeline -from AutoregressiveVideo2WorldGeneration.ar_utils_inference import load_vision_input, validate_args -from AutoregressiveVideo2WorldGeneration import log -from AutoregressiveVideo2WorldGeneration.io import read_prompts_from_file +from .world_generation_pipeline import ARVideo2WorldGenerationPipeline +from .ar_utils_inference import load_vision_input, validate_args +from .log import log +from .io import read_prompts_from_file # from download_autoregressive import main as download_autoregressive from transformers import PreTrainedModel, PretrainedConfig diff --git a/world_generation_pipeline.py b/world_generation_pipeline.py index 4273983a3ac63666931eb4b2260c5fd8eb062ff5..1e300540d3a022a74d708a0df0f04204a895b189 100644 --- a/world_generation_pipeline.py +++ b/world_generation_pipeline.py @@ -17,30 +17,30 @@ import gc import os from typing import List, Optional, Tuple -from AutoregressiveVideo2WorldGeneration import misc +from .misc import misc import numpy as np import torch from einops import rearrange -from AutoregressiveVideo2WorldGeneration.ar_config_base_model_config import create_video2world_model_config -from AutoregressiveVideo2WorldGeneration.ar_config_base_tokenizer import TokenizerConfig -from AutoregressiveVideo2WorldGeneration.ar_config_inference_inference_config import ( +from .ar_config_base_model_config import create_video2world_model_config +from .ar_config_base_tokenizer import TokenizerConfig +from .ar_config_inference_inference_config import ( DataShapeConfig, DiffusionDecoderSamplingConfig, InferenceConfig, SamplingConfig, ) -from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_inference import diffusion_decoder_process_tokens -from AutoregressiveVideo2WorldGeneration.ar_diffusion_decoder_model import LatentDiffusionDecoderModel -from AutoregressiveVideo2WorldGeneration.ar_model import AutoRegressiveModel -from AutoregressiveVideo2WorldGeneration.ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving -from AutoregressiveVideo2WorldGeneration.base_world_generation_pipeline import BaseWorldGenerationPipeline -from AutoregressiveVideo2WorldGeneration.df_inference_inference_utils import ( +from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens +from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel +from .ar_model import AutoRegressiveModel +from .ar_utils_inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving +from .base_world_generation_pipeline import BaseWorldGenerationPipeline +from .df_inference_inference_utils import ( load_model_by_config, load_network_model, load_tokenizer_model, ) -from AutoregressiveVideo2WorldGeneration import log +from .log import log def detect_model_size_from_ckpt_path(ckpt_path: str) -> str: