diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..f5edfa5630ebf12f1a78525900e1645ae48068c2
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..351b1d030951f268408c2fc519d8497399de6104
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,37 @@
+__pycache__/
+outputs/
+ckpts/
+vqgan/vqgan_pretrained/
+vqgan/vqgan_taming_ckpt/
+data/
+models/datasets/.cache/
+*.json
+output/
+tmp*
+multirun/
+.nfs*
+lightning_logs/
+static/
+archive/
+output_profile/
+logs/
+.history/
+.cache/
+output*/
+*.out
+*.parquet
+wandb/
+vqgan/
+*.csv
+.python-version
+ft_cache/
+alias.txt
+env.sh
+generated_image.png
+Untitled-1.ipynb
+*.log
+demo/old
+*.pem
+.sesskey
+icons.py
+generated/
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..9f634de96a802eeff3b2d4b695e48e80bd4b6805
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,15 @@
+[submodule "third_party/LlamaGen"]
+ path = third_party/LlamaGen
+ url = https://github.com/alexanderswerdlow/LlamaGen.git
+ branch = wip_v1
+[submodule "third_party/Lumina-mGPT"]
+ path = third_party/Lumina-mGPT
+ url = https://github.com/alexanderswerdlow/Lumina-mGPT.git
+ branch = non_causal
+[submodule "third_party/Show-o"]
+ path = third_party/Show-o
+ url = https://github.com/showlab/Show-o.git
+[submodule "third_party/1d-tokenizer"]
+ path = third_party/1d-tokenizer
+ url = https://github.com/bytedance/1d-tokenizer.git
+ branch = main
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..205770c54271fbb1f7ed57c60b7a722a26566602
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,79 @@
+# Base image with CUDA 12.6.3 and cuDNN
+FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
+
+# Set environment variables
+ARG DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1 \
+ SYSTEM=spaces \
+ AM_I_IN_A_DOCKER_CONTAINER=Yes \
+ PYTHONPATH=/home/appuser/app \
+ HF_HOME=/home/appuser/.cache \
+ TORCH_HOME=/home/appuser/.cache \
+ TMP_DIR=/home/appuser/tmp \
+ TRANSFORMERS_CACHE=/home/appuser/.cache/transformers \
+ NVIDIA_VISIBLE_DEVICES=all \
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+# Install system dependencies and set Python 3.10 as default
+RUN apt-get update && apt-get install --no-install-recommends -y \
+ build-essential \
+ python3.10 \
+ python3.10-distutils \
+ python3-pip \
+ ffmpeg \
+ libsm6 \
+ libxext6 \
+ libgl1 \
+ git \
+ openssh-client \
+ && ln -sf /usr/bin/python3.10 /usr/bin/python \
+ && ln -sf /usr/bin/pip3 /usr/bin/pip \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Install `uv`
+RUN pip install --upgrade pip \
+ && pip install uv
+
+# Create a non-root user
+RUN useradd -m -u 1000 appuser
+
+# Set working directory
+WORKDIR /home/appuser/app
+
+# Copy dependency files and install dependencies
+COPY --chown=appuser pyproject.toml uv.lock README.md ./
+RUN mkdir -p -m 0600 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
+
+RUN --mount=type=ssh uv sync --no-group dev
+RUN --mount=type=ssh uv sync --frozen --no-cache \
+ && chown -R appuser:appuser /home/appuser/app/.venv \
+ && rm -rf /root/.cache /home/appuser/.cache
+
+# Ensure non-root user has write access to cache and tmp directories
+RUN mkdir -p /home/appuser/.cache/transformers /home/appuser/tmp /home/appuser/.cache \
+ && chown -R appuser:appuser /home/appuser/.cache /home/appuser/tmp/ /home/appuser/app/
+
+RUN chmod -R 777 /tmp
+
+# Copy application code
+COPY --chown=appuser demo demo
+COPY --chown=appuser unidisc unidisc
+COPY --chown=appuser models models
+COPY --chown=appuser configs configs
+COPY --chown=appuser third_party third_party
+COPY --chown=appuser ckpts ckpts
+COPY --chown=appuser ./__* ./
+COPY --chown=appuser ./*.py ./
+COPY --chown=appuser ./archive/pytorch_model_fsdp.bin ./
+
+# Switch to non-root user
+USER appuser
+
+# Expose port for Gradio
+EXPOSE 5003
+
+# Command to run the application
+CMD ["bash", "demo/demo.sh"]
+
+# DOCKER_BUILDKIT=1 docker build --ssh default --network=host -t unidisc .
+# docker run --network=host -it -p 5003:5003 unidisc
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e15a2e105d6bb810f7692e0a98737f3e31e62410
--- /dev/null
+++ b/README.md
@@ -0,0 +1,82 @@
+
+
+

+
Unified Multimodal Discrete Diffusion
+
+[Alexander Swerdlow](https://aswerdlow.com/)
1*
+[Mihir Prabhudesai](https://mihirp1998.github.io/)
1*
+[Siddharth Gandhi](hhttps://www.ssgandhi.com/)
1
+[Deepak Pathak](https://www.cs.cmu.edu/~dpathak/)
1
+[Katerina Fragkiadaki](https://www.cs.cmu.edu/~katef/)
1
+
+
+
1 Carnegie Mellon University
+
+[](https://arxiv.org/pdf/0000.00000) [](https://unidisc.github.io/)
+
+
+
+
+
+## Hugging Face models and annotations
+
+The UniDisc checkpoints are available on [Hugging Face](https://huggingface.co/unidisc):
+* [unidisc/todo](https://huggingface.co/unidisc/todo)
+
+## Getting Started
+
+To install the dependencies, run:
+```bash
+git submodule update --init --recursive
+uv sync --no-group dev
+uv sync
+```
+
+For a more detailed installation guide, please refer to [INSTALL.md](docs/INSTALL.md).
+
+## Training
+
+See [TRAIN.md](docs/TRAIN.md) for details.
+
+## Inference
+
+
+
+
+
+Interactive demo for **TODO**.
+```
+python demo/server.py
+python demo/client_simple_fasthtml.py
+```
+
+
+## Training
+
+See [TRAINING.md](docs/TRAINING.md) for details.
+
+## Evaluation
+
+See [EVAL.md](docs/EVAL.md) for details.
+
+
+### Citation
+To cite our work, please use the following:
+```
+@article{TODO,
+ title={TODO},
+ author={TODO},
+ journal={arXiv preprint arXiv:TODO},
+ year={TODO}
+}
+```
+
+## Credits
+
+This repository is built on top of the following repositories:
+
+- [MDLM](https://github.com/kuleshov-group/mdlm)
+- [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X)
\ No newline at end of file
diff --git a/__builtins__.pyi b/__builtins__.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..53008e1d88e2938c9955d5c6ac3e98d20df2ed25
--- /dev/null
+++ b/__builtins__.pyi
@@ -0,0 +1,7 @@
+from ipdb import set_trace as st
+from decoupled_utils import start_timing as start_timing
+from decoupled_utils import end_timing as end_timing
+ENABLE_TIMING: bool
+ENABLE_TIMING_SYNC: bool
+DEVICE_BACKEND_TYPE: str
+exists = lambda v: v is not None
\ No newline at end of file
diff --git a/configs/config.yaml b/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f7936e667fa7d803be9237ffa1851a24ff2ef2f
--- /dev/null
+++ b/configs/config.yaml
@@ -0,0 +1,451 @@
+defaults:
+ - _self_
+ - /model: small
+ - /noise: loglinear
+ - /lr_scheduler: constant_warmup
+ - /experiments: []
+ # - override hydra/launcher: submitit_slurm
+
+slurm: False
+debug: False
+mode: train # train / eval
+diffusion: absorbing_state
+backbone: dit # dit / dimamba / ar
+parameterization: subs # subs / d3pm / sedd
+time_conditioning: False
+T: 0 # 0 (continuous time) / 1000
+subs_masking: False
+seed: 42
+profile: False
+# These belong in trainer.* and hydra.launcher.* but are put here for CLI convinience
+devices: ${device_count:}
+nodes: 1
+partition: ${find_partition:}
+constraint: ${find_constraint:}
+ckpt: null
+
+loader:
+ desired_global_batch_size: 512
+ global_batch_size: null
+ eval_global_batch_size: ${.global_batch_size}
+ batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+ eval_batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+ num_workers: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 16, 4)"}
+ pin_memory: True
+ persistent_workers: True
+
+sampling:
+ predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+ steps: 1000
+ max_sampling_steps: 500 # The highest level we use for sampling
+ noise_removal: True
+ num_sample_log: 2
+ semi_ar: False
+ stride_length: 1
+ num_strides: 1
+
+eval:
+ checkpoint_path: '' # Used to evaluate a checkpoint after training.
+ disable_ema: False
+ compute_generative_perplexity: False
+ perplexity_batch_size: 8
+ gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+ generate_samples: True
+ cfg: null
+ num_masking_viz_batches: 1
+ num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+ test_eval_speed: False
+ standalone_fid: False
+ visualize_data_only: false
+ val_with_train_data: false
+ max_num_fid_batches_per_device: null
+ class_conditional_fid: false
+ compute_entropy: false
+ compute_standalone_mauve: false
+ compute_standalone_entropy: false
+ compute_img_to_txt_mauve_clip: false
+ compute_img_to_txt_mauve_during_unconditional_fid: false
+ mauve_num_samples: 5000
+ mauve_divergence_curve_discretization_size: 25 # default in mauve repo
+ mauve_average_over_seeds: 3
+ mauve_scaling_factor: 5 # default in mauve repo
+ txt_conditional_fid: false
+ unconditional_fid: false
+ fid_mode: inline
+ calculate_clip_score: false
+ clean_fid_use_precomputed_stats: false
+ clean_fid_precomputed_name: null
+ clean_fid_precomputed_split: null
+ clean_fid_precomputed_res: null
+ attention_caching: false
+ set_random_gen_seed: false
+ compute_val_metrics_standalone: false
+ num_val_metrics_standalone_batches_per_device: ${eval:'max(${eval.num_val_metrics_standalone_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+ num_val_metrics_standalone_samples: -1
+ return_unweighed_sim: false
+ compute_chameleon_perplexity: false
+ global_disable_mauve: false
+ bypass_normal_validation: false
+ auto_enhance: false
+ num_auto_enhance_iter: 2
+ ar_inpainting_min_val: 0.5
+ ar_inpainting_max_val: 1.0
+ ar_inpainting_force_val: null
+
+optim:
+ weight_decay: 0
+ lr: 3e-4
+ beta1: 0.9
+ beta2: 0.999
+ eps: 1e-8
+ fused: true
+
+model:
+ use_custom_vae_config: false
+ use_custom_vae_ckpt: null
+ downscale_ratio: null
+ image_vocab_size: null
+ vae_type: null
+ use_attention_mask: false
+
+ cond_use_custom_vae_config: false
+ cond_use_custom_vae_ckpt: null
+ cond_downscale_ratio: null
+ cond_image_vocab_size: null
+ cond_vae_type: null
+ text_model: true
+
+ attn_type: flash
+ force_varlen_attn: false
+ force_cast_bf16: false
+ norm_type: layernorm
+ mup: false
+ qk_norm: false
+ distillation: false
+ force_argmax_valid_indices: false
+ use_flash_attn_3: false
+ use_spda_attn: false # Spelled wrong...
+ rope_2d: false
+ modality_embed: false
+ zero_linear_init: true
+ full_attention: true
+ use_lora: false
+ use_kv_cache: false
+ force_optimized_native_attn: false
+ use_pretrained_img_emb: true
+ use_flex_attention: false
+ add_labels: null
+ flex_attention_txt_masking_prob: null
+ flex_attention_img_masking_prob: null
+
+trainer:
+ _target_: lightning.Trainer
+ accelerator: cuda
+ num_nodes: ${nodes}
+ devices: ${devices}
+
+ # Given a desired global batch size (e.g., how many batches we see before a optim.step, summed over all nodes/gpus/accum_steps), we find the number of gradient accumulations that gets us closest given our current configuration. We assume that loader.batch_size is the largest that can fit in a single fwd/bwd.
+ accumulate_grad_batches: ${find_grad_accum:${loader.desired_global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+ gradient_clip_val: 1.0
+ precision: 'bf16'
+ max_steps: 1_000_000_000
+
+ num_epochs: 1_000_000_000
+ optimizer_cls: adamw
+ set_grads_to_none: true
+ eval_on_start: true
+ eval_decay_steps: false
+ eval_epochs: null
+ ckpt_steps: 100000
+ fsdp: false
+ force_enable_checkpointing: false
+ limit_val_batches: null
+ ckpt_every_n_minutes: 60
+ ckpt_recent_timeout_minutes: 10
+ checkpoint_all_ranks: true
+ force_null_sigma: false
+
+ log_every_n_steps: 10
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+ val_check_interval: 100
+
+ ema: 0.9999
+ antithetic_sampling: True
+ importance_sampling: False
+ sampling_eps: 1e-3
+ change_of_variables: False
+ benchmark: true
+ backward_pass: true
+ forward_pass: true
+ profile_memory: false
+ pytorch_profile: false
+ nvtx_profile: false
+ custom_ddp_bf16: true
+ log_seperate_modal_losses: true
+ use_gradient_checkpointing: false
+ text_loss_weight: null
+ img_loss_weight: null
+ disable_strict_load: false
+ attach_oom_observer_eval: false
+ find_unused_parameters: false
+ restart_on_failure: false
+ skip_early_checkpointing: true
+ log_flops: true
+ sync_timing: false
+ use_custom_ema: false
+ scale_lr_by_batch_size: false
+ tpu_eager: false
+ allow_dynamic_nodes: false
+ force_disable_signal_handler: false
+ tpu_profile: false
+ tpu_cache: false
+ enable_jax_smi: false
+ tpu_compile_debug: false
+ xla_spmd: false
+ log_grad_norm: true
+ tpu_profile_markers: true
+ compile: false
+ disable_all_checkpointing: false
+ tpu_force_mark_step: false
+ ar_shift: false
+ ar_llm_loss: false
+ ar_print_loss: false
+ chameleon_z_loss: null
+ image_mode: discrete # continuous / discrete
+ chameleon_use_ce_loss: false
+ low_precision_loss: false
+ low_precision_params: false
+ scratch: false
+ use_spmd_distributed_checkpointing: null
+ use_simple_spmd_distributed_checkpointing: false
+ load_from_state_dict: null
+ load_from_optimizer_state_dict: null
+ multimodal_batches: false
+ sync_dataloader_timing: false
+ compile_flag_pos_emb: false
+ compile_fullgraph: false
+ compile_mode: max-autotune-no-cudagraphs
+ joint_ar_nar_prob: null
+ joint_ar_nar_prob_warmup_steps: null
+ joint_ar_nar_timestep_warmup_steps: null
+ spmd_mesh: null
+ detect_anomaly: false
+ freeze_chameleon_embeddings: false
+ ckpt_model_only: false
+ use_orig_params: null
+ disable_adjust_num_warmup_steps: false
+ mask_entire_modality: null
+ iterate_dataloader_only: false
+ force_bf16_eval: false
+ disable_all_eval_generation: false
+ debug_xla_sept: false
+ ignore_text_in_unified: false
+ allow_null_sigma: false
+ disable_forward_autocast_during_eval: false
+ viz_images_only: false
+ add_label: false
+ first_token_dropout: null
+ disable_ddp_optimizer: false
+ rand_flip_ar_prob: null
+ rand_ar_modality_dropout: null
+ use_linear_warmup_cosine_annealing: false
+ no_ce_weighting: false
+ interleaved: false
+ interleaved_training_flex_attention: false
+ awr: false
+ ar_inpainting: false
+
+wandb:
+ entity: grads
+ project: ${eval:'"unidisc-debug" if ${debug} else "unidisc"'}
+ resume: ${eval:'"allow" if ${slurm} else None'}
+ id: null
+ group: null
+ job_type: null
+ name: null
+ tags:
+ - ${data.train}
+
+checkpointing_root_dir: ${oc.env:UNIDISC_CHECKPOINTING_ROOT_DIR,null}
+root_output_dir: ${oc.env:UNIDISC_ROOT_OUTPUT_DIR,outputs}
+python_orig: |
+ accelerate launch \
+ --num_machines $SLURM_NNODES \
+ --num_processes $NUM_PROCESSES \
+ --rdzv_backend c10d \
+ --main_process_ip $MASTER_ADDR \
+ --main_process_port $MASTER_PORT \
+ --machine_rank $SLURM_PROCID \
+ --mixed_precision bf16 \
+ --dynamo_backend no \
+ --enable_cpu_affinity \
+ --max_restarts 0 \
+
+mem_per_gpu: 40
+cpus_per_gpu: 8
+slurm_name: null
+timeout_min: ${partition_limit:${partition}}
+hydra:
+ run:
+ dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+ sweep:
+ dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+ subdir: ${hydra.job.id}
+ job:
+ chdir: true
+ # launcher:
+ # name: ${get_slurm_name:}
+ # # See https://hydra.cc/docs/configure_hydra/workdir/
+ # submitit_folder: ${hydra.sweep.dir}/%j
+ # nodes: ${nodes} # Number of nodes. This value is *per* node
+ # mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
+ # gpus_per_node: ${trainer.devices}
+ # partition: ${partition}
+ # constraint: ${constraint}
+ # exclude: ${exclude_nodes:}
+
+ # timeout_min: ${timeout_min}
+ # max_num_timeout: 12 # Num requeue exlcuding pre-emptions
+ # comment: aswerdlo
+ # stderr_to_stdout: true
+
+ # # Be careful with changing anything below.
+ # # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
+ # # see: https://github.com/huggingface/accelerate/issues/1918
+
+ # # The accelerate launcher w/1 initial process and then spawn 1 per GPU
+ # tasks_per_node: 1
+ # cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
+ # python: |
+ # bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+
+ # # python: "${getpythoncmd:}"
+ # # tasks_per_node: ${devices}
+ # # cpus_per_task: 8
+ # # python: 'python'
+
+ # python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
+ # signal: 'B:USR2@360'
+ # post_srun_commands:
+ # - ''
+ # - wait
+
+ # srun_args:
+ # - '--jobid $SLURM_JOB_ID'
+
+ # setup:
+ # - |
+ # export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+ # export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
+ # export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+ # export NCCL_DEBUG=INFO
+ # export NCCL_NSOCKS_PERTHREAD=4
+ # export NCCL_SOCKET_NTHREADS=2
+ # export OMP_NUM_THREADS=2
+ # export PYTHONUNBUFFERED=1
+ # export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
+ # export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
+ # export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
+ # if [ -n "$SLURM_RESTART_COUNT" ]; then
+ # export RESTART_COUNT=$SLURM_RESTART_COUNT
+ # else
+ # export RESTART_COUNT=0
+ # fi
+ # export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
+
+ # mkdir -p $LOCAL_JOB_FOLDER
+ # printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
+
+ # echo "ibstatus: $(ibstatus)"
+ # echo "ibdev2netdev: $(ibdev2netdev)"
+ # echo "rdma device: $(rdma link)"
+ # echo "environment: $(env | grep NCCL)"
+ # echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
+ # echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
+ # echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
+
+ # trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
+ # if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
+ # if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
+ # # ps auxww | grep $USER; \
+ # pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
+ # echo "Found parent PIDs: $pid"; \
+ # for p in $pid; do \
+ # echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
+ # children=$(pgrep -P $p); \
+ # echo "Children: $children"; \
+ # if [ -n "$children" ]; then \
+ # for child in $children; do \
+ # ppid=$(ps -o ppid= -p $child | tr -d " ")
+ # if [ "$ppid" -eq "$p" ]; then
+ # echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
+ # kill -USR2 $child &
+ # else
+ # echo "Skipping non-direct child process: PID $child with PPID $ppid"
+ # fi
+ # done; \
+ # echo "Sent kill signals to children of $p"; \
+ # else \
+ # echo "No children found for $p"; \
+ # fi; \
+ # done; \
+ # wait;' SIGUSR2
+
+checkpointing:
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+ save_dir: ${cwd:}/checkpoints
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+ resume_from_ckpt: true
+ resume_ckpt_path: ${cwd:}/checkpoints
+ initial_resume_ckpt_path: null
+ resume_wandb: true
+ checkpoints_total_limit: 2
+ use_automatic_naming: false
+
+
+data:
+ cache_dir: ${oc.env:HF_DATASETS_CACHE,/grogu/user/mprabhud/aswerdlo/huggingface/datasets}
+ num_proc: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 4, 16)"}
+ cond_resolution: null
+ iterable: false
+ force_disable_shuffle: false
+ pin_dataset_to_gpu: false
+ webdataset_iterable: false
+ webdataset_train_data: null
+ webdataset_val_data: null
+ webdataset_train_num_samples: null
+ webdataset_val_num_samples: null
+ webdataset_indexed: false
+ dataset_type: null
+ keep_tensordict_on_disk: false
+ use_token_dataset: false
+ use_custom_tensordict_collate: false
+ use_weighted_tensordict_sampler: false
+ enable_cuda_in_tensordict_collate: true
+ data_dir_train: null
+ data_dir_val: null
+ token_output_dir: null
+ wrap_dataloaders: true
+ force_shuffle_train: false
+ move_tensordict_to_shm: false
+ keep_hf_dataset_in_memory: false
+ use_chameleon: false
+ tokenize_vqvae_in_dataloader: false
+ force_mp_spawn: false
+ force_raw_images_in_multiple_tensordict: false
+ disable_text_modality: false
+ txt_only: false
+ disable_mask_after_eos: false
+ allow_label: false
+ split_dataset: false
+ img_token_shift: ${model.text_vocab_size}
+ zero_shot_eval_dataset: null
+ require_sample_ids: false
+ use_packing_collate: false
+ dynamic_packing_lengths: false
+ remove_txt_img_padding: false
+ add_image_gen_tokens: false
+ use_slow_tokenizer: false
+ add_image_token: false
+
+dummyarg: null
\ No newline at end of file
diff --git a/configs/config_empty.yaml b/configs/config_empty.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7a956b57208089c78871d3a8718fd30aa82e30fd
--- /dev/null
+++ b/configs/config_empty.yaml
@@ -0,0 +1,8 @@
+defaults:
+ - _self_
+ - /model: small
+ - /experiments: []
+
+# from omegaconf import OmegaConf
+# with open("config.yaml", "w") as fp:
+# OmegaConf.save(config=config, f=fp.name)
\ No newline at end of file
diff --git a/configs/experiments/ar.yaml b/configs/experiments/ar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..197515397e438ed2d1aa9a6d08a216f1d0055306
--- /dev/null
+++ b/configs/experiments/ar.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+parameterization: ar
+
+trainer:
+ ar_shift: true
+
+model:
+ full_attention: false
+ use_flex_attention: false
\ No newline at end of file
diff --git a/configs/experiments/elm.yaml b/configs/experiments/elm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e899d6d105715951d2500a23a3fe6fefd756aa3
--- /dev/null
+++ b/configs/experiments/elm.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+
+backbone: elm
+
+data:
+ tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+
+model:
+ use_lora: false
+ full_attention: true
+ model_id: apple/OpenELM-270M # apple/OpenELM-1_1B
+
+trainer:
+ use_gradient_checkpointing: false
+ sd3_compile_config: false
\ No newline at end of file
diff --git a/configs/experiments/eval_model.yaml b/configs/experiments/eval_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcfc000adbdb7d7225e8877bd47be70ae5184e91
--- /dev/null
+++ b/configs/experiments/eval_model.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+
+mode: eval
+
+loader:
+ batch_size: 16
+ eval_batch_size: 16
+
+trainer:
+ disable_all_eval_generation: false
+
+eval:
+ compute_generative_perplexity: true
+ generate_samples: true
+ num_sample_batches: 20
+ log_every_n_fid: 1
+ log_every_n_evals: 1
+ compute_standalone_mauve: true
+ mauve_num_samples: 5000
+ # mauve_divergence_curve_discretization_size: 200 # works well for our repo
+ # mauve_scaling_factor: 2 # works well for our repo
\ No newline at end of file
diff --git a/configs/experiments/eval_text.yaml b/configs/experiments/eval_text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..803b8562737293e83dce03ca945a9e72765eb802
--- /dev/null
+++ b/configs/experiments/eval_text.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+
+mode: eval
+
+sampling:
+ steps: 100
+ max_sampling_steps: 100
+
+loader:
+ batch_size: 2
+ eval_batch_size: 2
+
+trainer:
+ fsdp: false
+
+eval:
+ perplexity_batch_size: 2
+ num_masking_viz_batches: 2
+ log_every_n_evals: 1
+ num_uncond_sample_batches: 2
+ num_sample_batches: 2
+ num_random_masking: 1
+ masking_batch_size: 2
+ cfg: null
+ generate_samples: true
+ compute_generative_perplexity: false
\ No newline at end of file
diff --git a/configs/experiments/eval_text_only.yaml b/configs/experiments/eval_text_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..56d36b382efada7e1eecbdddee852bb1966dd7fb
--- /dev/null
+++ b/configs/experiments/eval_text_only.yaml
@@ -0,0 +1,30 @@
+# @package _global_
+
+mode: eval
+debug: true
+
+sampling:
+ steps: 100
+ max_sampling_steps: 100
+
+loader:
+ batch_size: 2
+ eval_batch_size: 2
+
+trainer:
+ fsdp: false
+
+model:
+ image_model_fid_eval: false
+
+eval:
+ log_every_n_evals: 1
+ perplexity_batch_size: 2
+ num_uncond_sample_batches: 2
+ num_sample_batches: 2
+ num_masking_viz_batches: -1
+ num_random_masking: -1
+ masking_batch_size: -1
+ cfg: null
+ generate_samples: true
+ compute_generative_perplexity: true
\ No newline at end of file
diff --git a/configs/experiments/eval_unified.yaml b/configs/experiments/eval_unified.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a35e9b17425e66e198ebb6ff1f07d07291a983cc
--- /dev/null
+++ b/configs/experiments/eval_unified.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+mode: eval
+devices: ${device_count:}
+
+sampling:
+ steps: 500
+ max_sampling_steps: 1000
+
+loader:
+ batch_size: 6
+ eval_batch_size: 6
+
+trainer:
+ fsdp: false
+ disable_all_eval_generation: false
+
+eval:
+ perplexity_batch_size: 6
+ num_masking_viz_batches: 12
+ log_every_n_evals: 1
+ num_uncond_sample_batches: 5
+ num_sample_batches: 2
+ num_random_masking: 3
+ masking_batch_size: 6
+ cfg: 6.0
+ generate_samples: false
\ No newline at end of file
diff --git a/configs/experiments/fid_cc12m.yaml b/configs/experiments/fid_cc12m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a3041d22631cccecde8c740bb2131102766601d2
--- /dev/null
+++ b/configs/experiments/fid_cc12m.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+data:
+ keep_hf_dataset_in_memory: true
+ aggressive_aug: false
+ n_duplicate_train: null
+ n_duplicate_val: null
+
+ tokenize_vqvae_in_dataloader: false
+ enable_cuda_in_tensordict_collate: false
+ force_mp_spawn: false
+ keep_tensordict_on_disk: false
+ move_tensordict_to_shm: false
+
+ fid_dataset: cc12m_tokens_val_256
+ image_data_train: null
+ image_data_val: null
+ data_dir_train: ${data.data_dir_val}
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+ weight: 1
+ name: ${data.fid_dataset}
\ No newline at end of file
diff --git a/configs/experiments/fid_datacomp1b.yaml b/configs/experiments/fid_datacomp1b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..541ef9af0b2d2efe9e0365e0cc2a897cdcad9c27
--- /dev/null
+++ b/configs/experiments/fid_datacomp1b.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+data:
+ keep_hf_dataset_in_memory: true
+ aggressive_aug: false
+ n_duplicate_train: null
+ n_duplicate_val: null
+
+ tokenize_vqvae_in_dataloader: false
+ enable_cuda_in_tensordict_collate: false
+ force_mp_spawn: false
+ keep_tensordict_on_disk: false
+ move_tensordict_to_shm: false
+
+ fid_dataset: datacomp1b_8_magvit_val
+ image_data_train: null
+ image_data_val: null
+ data_dir_train: ${data.data_dir_val}
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+ weight: -1
+ name: ${data.fid_dataset}
\ No newline at end of file
diff --git a/configs/experiments/fid_hf.yaml b/configs/experiments/fid_hf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8f273a2ae9cd800e8655167f4b8bd44beb4538ec
--- /dev/null
+++ b/configs/experiments/fid_hf.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+
+data:
+ disable_text_modality: false
+ keep_hf_dataset_in_memory: true
+ aggressive_aug: false
+ n_duplicate_train: null
+ n_duplicate_val: null
+ data_dir_train: []
+ data_dir_val: []
+ fid_dataset: sayakpaul/coco-30-val-2014
+ train: combined_tokens
+ val: {.train}
+ image_data_val:
+ - val: ${data.fid_dataset}
+ weight: -1
+ name: ${.val}
+ tokenize_vqvae_in_dataloader: false
+ raw_images: true
+ image_data_train:
+ - train: ${data.fid_dataset}
+ weight: -1
+ name: ${.train}
+ tokenize_vqvae_in_dataloader: false
+ raw_images: true
diff --git a/configs/experiments/jan_cub.yaml b/configs/experiments/jan_cub.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..faf3816e3707eb63be71a7100bec3fd71d71f9e4
--- /dev/null
+++ b/configs/experiments/jan_cub.yaml
@@ -0,0 +1,51 @@
+# @package _global_
+
+defaults:
+ - override /model: medium
+ - override /lr_scheduler: cosine_with_hard_restarts_schedule_with_warmup
+
+loader:
+ batch_size: 16
+ eval_batch_size: 16
+ desired_global_batch_size: 128
+ num_workers: 4
+
+trainer:
+ ckpt_steps: 5000
+ val_check_interval: 100
+ use_legacy_update_batch_fn: true
+ mask_txt_only: true
+ mask_entire_modality: 0.15
+ ema: 0.9999
+ use_custom_ema: true
+ force_enable_checkpointing: true
+ skip_early_checkpointing: false
+ force_after_eos_padding: false
+
+checkpointing:
+ checkpoints_total_limit: 20
+
+lr_scheduler:
+ num_warmup_steps: 10000
+ num_training_steps: 400000
+ num_cycles: 80
+
+data:
+ resolution: 256
+ train: cub2011_custom
+ use_weighted_tensordict_sampler: false
+
+model:
+ vae_type: titok128
+ txt_length: 18
+ img_length: 128
+ rope_2d: false
+ force_text_vocab_size: 5450
+ text_vocab_size: 5451
+ image_vocab_size: 8192
+ attn_dropout: 0.1
+
+optim:
+ lr: 1.0e-04
+ weight_decay: 0.2
+ beta2: 0.99
\ No newline at end of file
diff --git a/configs/experiments/large_maskdit_exp.yaml b/configs/experiments/large_maskdit_exp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..84188a8967682cf1678e7741d48f4f91ad44b48f
--- /dev/null
+++ b/configs/experiments/large_maskdit_exp.yaml
@@ -0,0 +1,7 @@
+# @package _global_
+
+defaults:
+ - override /model: large_maskdit
+
+
+backbone: maskdit
\ No newline at end of file
diff --git a/configs/experiments/large_scale_high_res_interleaved_inference.yaml b/configs/experiments/large_scale_high_res_interleaved_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..db6b6d4d43ea7717a048a45b251b00a0c78bbaea
--- /dev/null
+++ b/configs/experiments/large_scale_high_res_interleaved_inference.yaml
@@ -0,0 +1,51 @@
+# @package _global_
+
+debug: true
+seed: 163
+
+loader:
+ eval_batch_size: 1
+ batch_size: 1
+
+data:
+ move_tensordict_to_shm: false
+ resolution: 1024
+ disable_mask_after_eos: true
+ disable_packing: true
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+ weight: 1.0
+ name: HPDv2_image_reward_512
+
+model:
+ img_length: 4096
+ txt_length: 1024
+ length: 5120
+
+trainer:
+ compile: false
+ limit_val_batches: 2
+ fsdp: false
+ force_full_attention_mask: true
+ force_null_sigma: true
+ allow_null_sigma: true
+
+eval:
+ num_sample_batches: 1
+ num_random_masking: 0
+ num_masking_viz_batches: 0
+ limit_val_batches_manual: 1
+ num_uncond_sample_batches: 10
+ eval_large_batch: 10
+ val_with_train_data: false
+ maskgit_r_temp: 4.5
+ half_uncond: false
+ cfg: 3.0
+ return_interleaved_modalities_split: true
+ static_img_txt_demo: true
+ visualize_sample: true
+
+sampling:
+ steps: 50
+ max_sampling_steps: 50
+ predictor: "maskgit"
diff --git a/configs/experiments/large_scale_train.yaml b/configs/experiments/large_scale_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a55e2275db34419cbaf1bfb0c036a36576b3c0a1
--- /dev/null
+++ b/configs/experiments/large_scale_train.yaml
@@ -0,0 +1,151 @@
+# @package _global_
+
+defaults:
+ - vq16_t2i
+ - override /model: extra_large
+
+data:
+ train: combined_tokens
+ valid: ${.train}
+ precache: false
+ streaming: false
+ resolution: 256
+ block_size: 128
+ tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+ wrap: true
+ iterable: false
+ webdataset_iterable: false
+ webdataset_indexed: false
+ unpaired: false
+ dataset_type: null
+ tokens_flip_collate: false
+ n_val_samples: null
+ n_train_samples: null
+ n_duplicate_train: null
+ n_duplicate_val: null
+ raw_data_dir: null
+ save_train_dataloader: true
+ save_validation_dataloader: true
+ tokenizers_parallelism: false
+ token_data_dir: null
+ force_disable_shuffle: false
+ use_custom_tensordict_collate: true
+ use_weighted_tensordict_sampler: true
+ force_mp_spawn: false
+ enable_cuda_in_tensordict_collate: false
+ use_token_dataset: true
+ keep_tensordict_on_disk: true
+ move_tensordict_to_shm: false
+ add_text_to_weighted_sampler: false
+ data_dir_train:
+ # - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+ # weight: 15.0
+ # name: hpdv2
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+ weight: 1.0
+ name: pixelprose
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/journeydb_train
+ weight: 10.0
+ name: journeydb_train
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+ weight: 1.0
+ name: datacomp0
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+ weight: 1.0
+ name: datacomp1
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+ weight: 1.0
+ name: datacomp2
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_3_tokens
+ weight: 1.0
+ name: datacomp3
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+ weight: 1.0
+ name: datacomp4
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+ weight: 1.0
+ name: datacomp5
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_6_tokens
+ weight: 1.0
+ name: datacomp6
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+ weight: 1.0
+ name: dummy_1
+
+model:
+ img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+ txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+ length: ${eval:'${.txt_length} + ${.img_length}'}
+ unified_model: true
+ image_model: true
+ text_model: true
+ image_model_fid_eval: false
+ force_argmax_valid_indices: true
+ use_pretrained_img_emb: false
+ rope_2d: true
+ modality_embed: true
+ norm_type: rms
+ qk_norm: true
+ sandwich_normalization: true
+ text_vocab_size: 32001
+
+loader:
+ batch_size: 8
+ eval_batch_size: ${eval:'${.batch_size} // 2'}
+ desired_global_batch_size: 512
+ persistent_workers: true
+ pin_memory: false
+ num_workers: 0
+ num_eval_workers: 0
+eval:
+ log_every_n_evals: -1
+ log_every_n_fid: -1
+ limit_val_batches_manual: 16
+ generate_samples: true
+ compute_generative_perplexity: false
+ perplexity_batch_size: ${loader.eval_batch_size}
+ cfg: 5.0
+ num_val_metrics_standalone_samples: -1
+ num_val_metrics_standalone_batches_per_device: -1
+ auto_enhance_reward_config:
+ dfn_score: 1.0
+ laion_aesthetic_score: 1.0
+
+trainer:
+ log_flops: false
+ log_every_n_steps: 10
+ custom_ddp_bf16: true
+ log_seperate_modal_losses: true
+ limit_val_batches: 16
+ softmin_snr: 5
+ text_loss_weight: 1.0
+ img_loss_weight: 0.6
+ use_gradient_checkpointing: false
+ ckpt_steps: 20000
+ ckpt_every_n_minutes: 180
+ ckpt_recent_timeout_minutes: 10
+ use_custom_ema: false
+ ema: 0.0
+ fsdp: true
+ restart_on_failure: true
+ eval_on_start: false
+ val_check_interval: 100000000000
+ scale_lr_by_batch_size: false
+ watch_gradients: false
+ compile: true
+ mask_entire_modality: 0.15
+ compile_flag_pos_emb: true
+ multimodal_batches: true
+optim:
+ lr: 0.0001
+sampling:
+ steps: 128
+ num_sample_batches: 2
+wandb:
+ mode: online
+checkpointing:
+ checkpoints_total_limit: 10
+ use_automatic_naming: false
+lr_scheduler:
+ num_warmup_steps: 10000
\ No newline at end of file
diff --git a/configs/experiments/large_scale_train_high_res.yaml b/configs/experiments/large_scale_train_high_res.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd39dabc810ffab37fb6989764b46e82dc135763
--- /dev/null
+++ b/configs/experiments/large_scale_train_high_res.yaml
@@ -0,0 +1,39 @@
+
+# @package _global_
+
+data:
+ resolution: 512
+ data_dir_train:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+ weight: 1
+ name: HPDv2_image_reward_512
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+ weight: 2
+ name: pick_score_sac_prompts_v1_v2_v3_512
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+ weight: 0.5
+ name: datacomp1b_7_512
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/text/slimpajama6b
+ weight: 2.5
+ name: slimpajama6b
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+ weight: 1.0
+ name: gecko_eval_512
+
+trainer:
+ text_loss_weight: 1.0
+ img_loss_weight: 0.5
+ force_full_attention_mask: true
+ mask_entire_modality: 0.1
+
+loader:
+ pin_memory: false
+ num_workers: 4
+ num_eval_workers: 4
+
+lr_scheduler:
+ num_warmup_steps: 5000
+
+model:
+ linear_factor: 2
\ No newline at end of file
diff --git a/configs/experiments/large_scale_train_high_res_inference.yaml b/configs/experiments/large_scale_train_high_res_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00f18791c68220390a0b413bdf671fbd4f9160a6
--- /dev/null
+++ b/configs/experiments/large_scale_train_high_res_inference.yaml
@@ -0,0 +1,30 @@
+# @package _global_
+
+data:
+ use_token_dataset: true
+ disable_mask_after_eos: true
+ move_tensordict_to_shm: false
+
+trainer:
+ compile_flag_pos_emb: true
+ multimodal_batches: true
+ allow_null_sigma: true
+
+eval:
+ num_sample_batches: 1
+ num_random_masking: 0
+ num_masking_viz_batches: 0
+ limit_val_batches_manual: 1
+ num_uncond_sample_batches: 10
+ eval_large_batch: 10
+ val_with_train_data: false
+ maskgit_r_temp: 4.5
+ half_uncond: false
+ cfg: 3.0
+ static_img_txt_demo: true
+ visualize_sample: true
+
+sampling:
+ steps: 50
+ max_sampling_steps: 50
+ predictor: "maskgit"
diff --git a/configs/experiments/large_scale_train_high_res_interleaved.yaml b/configs/experiments/large_scale_train_high_res_interleaved.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a286b7c86092b328714f9ceadbe83caaf992ef95
--- /dev/null
+++ b/configs/experiments/large_scale_train_high_res_interleaved.yaml
@@ -0,0 +1,105 @@
+
+# @package _global_
+
+data:
+ move_tensordict_to_shm: false
+ enable_cuda_in_tensordict_collate: false
+ force_mp_spawn: false
+ resolution: 512
+ add_text_to_weighted_sampler: false
+
+ add_image_gen_tokens: true
+ use_packing_collate: true
+ dynamic_packing_lengths: true
+ remove_txt_img_padding: true
+ require_sample_ids: true
+ block_size: ${model.length}
+ disable_mask_after_eos: true
+ add_image_token: true
+ use_slow_tokenizer: true
+ force_seed: true
+
+ data_dir_train:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+ weight: 0.5
+ name: HPDv2_image_reward_v1_v2_v3 # 3593248
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+ weight: 1.0
+ name: pick_score_sac_prompts_v1_v2_v3_512 # 9330810
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+ weight: 1.0
+ name: pixelprose_tokens # 6627589
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cambrian_10m_v5
+ weight: 1.0
+ name: cambrian_10m_v5 # 8215264
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+ weight: 1.0
+ name: datacomp1b_7_512 # 23955209
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+ weight: 0.5
+ name: datacomp_1b_datacomp1b_2_tokens # 10161505
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+ weight: 0.5
+ name: datacomp_1b_datacomp1b_4_tokens # 27895717
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/mmc4_fewer_faces_v0
+ weight: 2.0
+ name: mmc4_fewer_faces_v0 # 22605524
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+ weight: 0.5
+ name: datacomp_1b_datacomp1b_5_tokens
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+ weight: 0.5
+ name: datacomp_1b_datacomp1b_0_tokens
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+ weight: 0.5
+ name: datacomp_1b_datacomp1b_1_tokens
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cosmopedia_2_v0
+ weight: 1.0
+ name: cosmopedia_v2
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/fineweb_edu_dedup_v0
+ weight: 1.0
+ name: fineweb_edu_dedup
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+ weight: 1.0
+ name: gecko_eval_512
+
+trainer:
+ text_loss_weight: 1.0
+ img_loss_weight: 0.2
+ mask_entire_modality: 0.2
+
+ force_full_attention_mask: false
+ force_full_attention_mask_loss_only: false
+ disable_all_eval_generation: true
+ interleaved: true
+ interleaved_training_flex_attention: true
+ force_convert_to_dict: true
+ val_check_interval: -1
+ use_gradient_checkpointing: true
+ disable_all_checkpointing: false
+ set_max_txt_loss_ratio: true
+ gradient_clip_val: 1.0
+ skip_early_checkpointing: false
+ bypass_load_from_state_dicts_if_resuming: true
+
+loader:
+ num_workers: 4
+ num_eval_workers: 4
+
+lr_scheduler:
+ num_warmup_steps: 5000
+
+model:
+ linear_factor: 2
+ use_flex_attention: true
+ use_spda_attn: true
+
+ length: 1536
+ txt_length: ${.length}
+ img_length: ${.length}
+
+eval:
+ generate_samples: false
+ disable_visualization: true
+
diff --git a/configs/experiments/maskgit.yaml b/configs/experiments/maskgit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c872e45079cc9aa9b09d59d15c31e880bdca756
--- /dev/null
+++ b/configs/experiments/maskgit.yaml
@@ -0,0 +1,6 @@
+# @package _global_
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 1024
+ vae_type: maskgit
\ No newline at end of file
diff --git a/configs/experiments/master_eval.yaml b/configs/experiments/master_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c7402372970da8c7fbf3c7a33440a165a36a1e69
--- /dev/null
+++ b/configs/experiments/master_eval.yaml
@@ -0,0 +1,49 @@
+# @package _global_
+
+mode: eval
+
+eval:
+ fid_samples: 4096
+ max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+ compute_generative_perplexity: true
+ generate_samples: true
+ log_every_n_fid: 1
+ log_every_n_evals: 1
+ class_conditional_fid: false
+ txt_conditional_fid: true
+ calculate_clip_score: true
+ cfg: 5
+ num_sample_batches: 2
+ compute_standalone_mauve: false
+ mauve_num_samples: -1
+ set_random_gen_seed: true
+ # gen_ppl_eval_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+ compute_img_to_txt_mauve_clip: true
+ compute_img_to_txt_mauve_during_unconditional_fid: true
+ force_eval_uncond: true
+ ablation_config: true
+ compute_val_metrics_standalone: true
+ num_val_metrics_standalone_samples: 2000
+
+trainer:
+ disable_all_eval_generation: false
+ force_after_eos_padding: true
+
+model:
+ image_model_fid_eval: true
+ use_kv_cache: ${is_ar:${parameterization}}
+
+loader:
+ batch_size: 64
+ eval_batch_size: 64
+ num_workers: 0
+ num_eval_workers: 1
+
+sampling:
+ steps: ${model.length}
+ max_sampling_steps: ${sampling.steps}
+ sampling_step_frac: null
+
+
+data:
+ fid_dataset: null
\ No newline at end of file
diff --git a/configs/experiments/mscoco_fid.yaml b/configs/experiments/mscoco_fid.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7b911c79b94c911065c0757140bf38ae3d1349d
--- /dev/null
+++ b/configs/experiments/mscoco_fid.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+
+data:
+ disable_text_modality: false
+ keep_hf_dataset_in_memory: true
+ aggressive_aug: false
+ n_duplicate_train: null
+ n_duplicate_val: null
+ data_dir_train: []
+ data_dir_val: []
+ image_data_train: ${data.image_data_val}
+ image_data_val:
+ - val: sayakpaul/coco-30-val-2014
+ weight: -1
+ name: mscoco_val
+ tokenize_vqvae_in_dataloader: false
+ raw_images: true
+
+eval:
+ compute_generative_perplexity: true
+ generate_samples: true
\ No newline at end of file
diff --git a/configs/experiments/paired_standalone_fid_eval.yaml b/configs/experiments/paired_standalone_fid_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..49007f943debe69e4cfc3b6e5a699ae2a8fb2a8f
--- /dev/null
+++ b/configs/experiments/paired_standalone_fid_eval.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+
+mode: eval
+debug: true
+
+eval:
+ fid_samples: 4096
+ max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+ compute_generative_perplexity: false
+ generate_samples: false
+ log_every_n_fid: 1
+ log_every_n_evals: 1
+ class_conditional_fid: false
+ txt_conditional_fid: true
+ calculate_clip_score: true
+ cfg: 5
+
+model:
+ image_model_fid_eval: true
+
+loader:
+ eval_batch_size: 32
+
+sampling:
+ steps: ${model.length}
+ max_sampling_steps: ${model.length}
+
+data:
+ keep_hf_dataset_in_memory: false
\ No newline at end of file
diff --git a/configs/experiments/small_scale_train.yaml b/configs/experiments/small_scale_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a39c9e49c59e330822f03d713bba07d73e6e39c6
--- /dev/null
+++ b/configs/experiments/small_scale_train.yaml
@@ -0,0 +1,187 @@
+# @package _global_
+
+defaults:
+ - vq16_magvit
+ - override /model: small
+ - override /lr_scheduler: constant_warmup_cosine_decay
+
+model:
+ img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+ txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+ length: ${eval:'${.txt_length} + ${.img_length}'}
+ image_model: true
+ text_model: true
+ unified_model: true
+ image_model_fid_eval: false
+ force_argmax_valid_indices: true
+ use_pretrained_img_emb: false
+ codebook_embed_dim: 256
+ qk_norm: true
+ norm_type: rms
+ sandwich_normalization: true
+ zero_linear_init: false
+ modality_embed: true
+ rope_2d: false
+ use_spda_attn: true
+ force_optimized_native_attn: true
+ freeze_txt_emb: false
+ add_labels: null
+ txt_dropout: null
+ text_vocab_size: 32001
+
+data:
+ train: combined_tokens
+ valid: ${.train}
+ n_duplicate_train: null
+ wrap: true
+ streaming: false
+ precache: false
+ tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+ resolution: 256
+ block_size: 128
+ n_val_samples: null
+ unpaired: false
+ n_duplicate_val: null
+ save_train_dataloader: true
+ save_validation_dataloader: true
+ iterable: false
+ webdataset_iterable: false
+ webdataset_indexed: false
+ dataset_type: null
+ tokens_flip_collate: false
+ n_train_samples: null
+ raw_data_dir: null
+ tokenizers_parallelism: false
+ token_data_dir: null
+ force_disable_shuffle: false
+ keep_tensordict_on_disk: true
+ use_custom_tensordict_collate: true
+ force_mp_spawn: false
+ enable_cuda_in_tensordict_collate: false
+ use_weighted_tensordict_sampler: true
+ fraction_txt_data: 0.0
+ tokenize_vqvae_in_dataloader: false
+ use_token_dataset: true
+ image_dataset: tglcourse/lsun_church_train
+ image_data_train: null
+ image_data_val: null
+ keep_hf_dataset_in_memory: true
+ allow_label: false
+ disable_text_modality: true
+ force_raw_train_images: false
+ aggressive_aug: true
+ allow_aug_vqvae_dataloader: true
+ move_tensordict_to_shm: false
+ data_dir_train:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+ weight: -1
+ name: datacomp1b_8_magvit_train
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+ weight: -1
+ name: cc12m_tokens_train_256
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+ weight: -1
+ name: HPDv2_image_reward_v1_v2_v3_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+ weight: -1
+ name: pick_score_sac_prompts_v1_v2_v3_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+ weight: -1
+ name: datacomp1b_0_1_6_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+ weight: -1
+ name: laion400m_magvit_part_0
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+ weight: -1
+ name: laion400m_magvit_part_1
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+ weight: 1
+ name: datacomp1b_8_magvit_val
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+ weight: 1
+ name: cc12m_tokens_val_256
+
+eval:
+ generate_samples: true
+ compute_generative_perplexity: true
+ log_every_n_evals: 10
+ log_every_n_fid: 20
+ limit_val_batches_manual: 16
+ perplexity_batch_size: ${loader.eval_batch_size}
+ num_masking_viz_batches: -1
+ cfg: null
+ class_conditional_fid: false
+ force_cfg_value: true
+ split_cfg_batches: true
+ max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+ fid_mode: clean
+ clean_fid_precomputed_name: lsun_church
+ clean_fid_precomputed_split: trainfull
+ clean_fid_precomputed_res: 256
+
+trainer:
+ log_every_n_steps: 10
+ val_check_interval: 1000
+ custom_ddp_bf16: true
+ scale_lr_by_batch_size: false
+ limit_val_batches: 16
+ use_gradient_checkpointing: false
+ log_seperate_modal_losses: true
+ softmin_snr: 5
+ text_loss_weight: 1.0
+ img_loss_weight: null
+ low_precision_loss: false
+ compile: true
+ multimodal_batches: true
+ compile_fullgraph: false
+ log_grad_norm_every_n_steps: 10
+ mask_entire_modality: 0.1
+ force_shift_image_batches: false
+ ckpt_steps: 10000
+ ckpt_every_n_minutes: -1
+ ignore_text_in_unified: false
+ disable_all_eval_generation: true
+ eval_on_start: false
+ ckpt_model_only: false
+ ema: 0.0
+ use_custom_ema: false
+ log_flops: false
+ disable_distributed_torchmetrics: true
+ restart_on_failure: true
+ force_null_sigma: true
+ allow_null_sigma: true
+ compile_flag_pos_emb: true
+ add_label: false
+ first_token_dropout: null
+ force_shift_raw_image_batches: true
+ txt_dropout: 0.1
+ force_full_attention_mask_loss_only: true
+
+optim:
+ lr: 0.0003
+ weight_decay: 0.05
+
+loader:
+ batch_size: 64
+ eval_batch_size: ${loader.batch_size}
+ num_workers: 4
+ desired_global_batch_size: 512
+ persistent_workers: true
+ pin_memory: true
+ num_eval_workers: 1
+
+sampling:
+ steps: ${model.length}
+ num_sample_batches: 2
+ max_sampling_steps: ${model.length}
+
+wandb:
+ mode: online
+
+lr_scheduler:
+ num_warmup_steps: 5000
+ num_training_steps: ${trainer.max_steps}
+
+checkpointing:
+ checkpoints_total_limit: 10
\ No newline at end of file
diff --git a/configs/experiments/small_scale_train_caching.yaml b/configs/experiments/small_scale_train_caching.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0bac00ccb4983f3e57e5aba08aa93f3824d2672
--- /dev/null
+++ b/configs/experiments/small_scale_train_caching.yaml
@@ -0,0 +1,186 @@
+# @package _global_
+
+defaults:
+ - /model: small
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 8192
+ vae_type: magvit
+ use_custom_vae_ckpt: null
+ custom_vae_name: null
+ img_length: 256
+ txt_length: 128
+ image_model: true
+ text_model: true
+ unified_model: true
+ image_model_fid_eval: false
+ force_argmax_valid_indices: true
+ use_pretrained_img_emb: false
+ codebook_embed_dim: 256
+ qk_norm: true
+ norm_type: rms
+ sandwich_normalization: true
+ zero_linear_init: false
+ modality_embed: true
+ rope_2d: false
+ use_spda_attn: true
+ force_optimized_native_attn: true
+ freeze_txt_emb: false
+ add_labels: null
+ txt_dropout: null
+ text_vocab_size: 32001
+ use_flex_attention: true
+ flex_attention_txt_masking_prob: 0.1
+ flex_attention_img_masking_prob: 0.1
+ linear_factor: 1
+data:
+ train: combined_tokens
+ valid: ${.train}
+ n_duplicate_train: null
+ wrap: true
+ streaming: false
+ precache: false
+ tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+ resolution: 256
+ block_size: 128
+ n_val_samples: null
+ unpaired: false
+ n_duplicate_val: null
+ save_train_dataloader: true
+ save_validation_dataloader: true
+ iterable: false
+ webdataset_iterable: false
+ webdataset_indexed: false
+ dataset_type: null
+ tokens_flip_collate: false
+ n_train_samples: null
+ raw_data_dir: null
+ tokenizers_parallelism: false
+ token_data_dir: null
+ force_disable_shuffle: false
+ keep_tensordict_on_disk: true
+ use_custom_tensordict_collate: true
+ force_mp_spawn: false
+ enable_cuda_in_tensordict_collate: false
+ use_weighted_tensordict_sampler: true
+ fraction_txt_data: 0.0
+ data_dir_train:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+ weight: -1
+ name: datacomp1b_8_magvit_train
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+ weight: -1
+ name: cc12m_tokens_train_256
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+ weight: -1
+ name: HPDv2_image_reward_v1_v2_v3_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+ weight: -1
+ name: pick_score_sac_prompts_v1_v2_v3_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+ weight: -1
+ name: datacomp1b_0_1_6_magvit
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+ weight: -1
+ name: laion400m_magvit_part_0
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+ weight: -1
+ name: laion400m_magvit_part_1
+ data_dir_val:
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+ weight: 1
+ name: datacomp1b_8_magvit_val
+ - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+ weight: 1
+ name: cc12m_tokens_val_256
+ tokenize_vqvae_in_dataloader: false
+ val:
+ .train: null
+ use_token_dataset: true
+ image_dataset: tglcourse/lsun_church_train
+ image_data_train: null
+ image_data_val: null
+ keep_hf_dataset_in_memory: true
+ allow_label: false
+ disable_text_modality: true
+ force_raw_train_images: false
+ aggressive_aug: true
+ allow_aug_vqvae_dataloader: true
+ move_tensordict_to_shm: false
+ force_full_attention_mask: false
+eval:
+ generate_samples: false
+ compute_generative_perplexity: false
+ log_every_n_evals: 10
+ log_every_n_fid: 20
+ limit_val_batches_manual: 16
+ perplexity_batch_size: ${loader.eval_batch_size}
+ num_masking_viz_batches: -1
+ max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+ cfg: null
+ class_conditional_fid: false
+ force_cfg_value: true
+ split_cfg_batches: true
+ fid_mode: clean
+ clean_fid_precomputed_name: lsun_church
+ clean_fid_precomputed_split: trainfull
+ clean_fid_precomputed_res: 256
+trainer:
+ log_every_n_steps: 10
+ val_check_interval: 1000
+ custom_ddp_bf16: true
+ scale_lr_by_batch_size: false
+ limit_val_batches: 16
+ use_gradient_checkpointing: false
+ log_seperate_modal_losses: true
+ softmin_snr: 5
+ text_loss_weight: 1.0
+ img_loss_weight: null
+ low_precision_loss: false
+ compile: false
+ multimodal_batches: true
+ compile_fullgraph: false
+ log_grad_norm_every_n_steps: 10
+ mask_entire_modality: 0.1
+ force_shift_image_batches: false
+ ckpt_steps: 10000
+ ckpt_every_n_minutes: -1
+ ignore_text_in_unified: false
+ disable_all_eval_generation: false
+ eval_on_start: false
+ ckpt_model_only: false
+ ema: 0.0
+ use_custom_ema: false
+ log_flops: false
+ disable_distributed_torchmetrics: true
+ restart_on_failure: true
+ force_null_sigma: true
+ allow_null_sigma: true
+ compile_flag_pos_emb: true
+ add_label: false
+ first_token_dropout: null
+ force_shift_raw_image_batches: true
+ txt_dropout: 0.1
+ disable_ddp_optimizer: true
+optim:
+ lr: 0.0003
+ weight_decay: 0.05
+loader:
+ batch_size: 64
+ eval_batch_size: ${loader.batch_size}
+ num_workers: 1
+ desired_global_batch_size: 512
+ persistent_workers: true
+ pin_memory: true
+ num_eval_workers: 1
+sampling:
+ steps: ${model.length}
+ num_sample_batches: 2
+ max_sampling_steps: ${model.length}
+wandb:
+ mode: online
+lr_scheduler:
+ num_warmup_steps: 5000
+checkpointing:
+ checkpoints_total_limit: 4
diff --git a/configs/experiments/small_text_only.yaml b/configs/experiments/small_text_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6bcedd37a1e1a45e6ecba8a650dddfce505c321b
--- /dev/null
+++ b/configs/experiments/small_text_only.yaml
@@ -0,0 +1,28 @@
+# @package _global_
+
+defaults:
+ - lsun_text8_exp_2
+ - owt_only
+ - override /model: small
+
+backbone: dit
+
+loader:
+ batch_size: 64
+
+trainer:
+ val_check_interval: 10000
+ ckpt_steps: 10000
+ softmin_snr: null
+
+optim:
+ fused: true
+ weight_decay: 0.03
+
+sampling:
+ num_sample_batches: 4
+ max_sampling_steps: 256
+
+model:
+ txt_length: 1024
+
\ No newline at end of file
diff --git a/configs/experiments/standalone_fid_eval.yaml b/configs/experiments/standalone_fid_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2286a6bb4641a7995788a1b16638f38a0a3ee4c6
--- /dev/null
+++ b/configs/experiments/standalone_fid_eval.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+
+mode: eval
+debug: true
+
+eval:
+ max_num_fid_batches_per_device: ${eval:'4096 // (${trainer.devices} * ${loader.eval_batch_size})'}
+ compute_generative_perplexity: false
+ generate_samples: false
+ log_every_n_fid: 1
+ log_every_n_evals: 1
+
+loader:
+ eval_batch_size: 32
+
+sampling:
+ steps: 500
+ max_sampling_steps: 500
diff --git a/configs/experiments/titok.yaml b/configs/experiments/titok.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca0720be95066508e5e0efe4bad02b1cf9de02f0
--- /dev/null
+++ b/configs/experiments/titok.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+
+data:
+ resolution: 256
+ downscale_ratio: 16
+
+model:
+ vae_type: titok
\ No newline at end of file
diff --git a/configs/experiments/titok_sl256.yaml b/configs/experiments/titok_sl256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9ba95f3d505f729732db4e872b7d08db8035f583
--- /dev/null
+++ b/configs/experiments/titok_sl256.yaml
@@ -0,0 +1,7 @@
+# @package _global_
+
+data:
+ resolution: 256
+
+model:
+ vae_type: titok
\ No newline at end of file
diff --git a/configs/experiments/txt_only.yaml b/configs/experiments/txt_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..816a347e926d3c0c203c08329588965aee0d9f8f
--- /dev/null
+++ b/configs/experiments/txt_only.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+
+data:
+ streaming: False
+ unpaired: false
+
+trainer:
+ img_loss_weight: null
+ text_loss_weight: null
+
+model:
+ use_pretrained_img_emb: false
+ image_model_fid_eval: false
+ unified_model: false
+ image_model: false
+ txt_length: 256
+ img_length: 0
+
+eval:
+ log_every_n_evals: -1
+ log_every_n_fid: -1
\ No newline at end of file
diff --git a/configs/experiments/unified.yaml b/configs/experiments/unified.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60068dd7bf76d8ae631192a47b9a7b2c715ab009
--- /dev/null
+++ b/configs/experiments/unified.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+
+data:
+ zero_shot_eval_dataset: "nlphuji/flickr30k"
+ precache: False
+ tokenizers_parallelism: False # parallelism causes some weird error
+ n_val_samples: 2048
+ block_size: 128
+
+model:
+ unified_model: True
+ text_model: true
+
+checkpointing:
+ resume_from_ckpt: True
+ load_from_text_model: "ckpts/unidisc-owt/model.safetensors"
+
+loader:
+ batch_size: 12
+
+trainer:
+ val_check_interval: 2000
+ log_seperate_modal_losses: true
\ No newline at end of file
diff --git a/configs/experiments/vq16.yaml b/configs/experiments/vq16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3abefcdd36b13e262b79b82b676d8d536ba00575
--- /dev/null
+++ b/configs/experiments/vq16.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 16384
+ vae_type: VQ-16
+ use_custom_vae_ckpt: null
+ custom_vae_name: null
+ img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
\ No newline at end of file
diff --git a/configs/experiments/vq16_1024.yaml b/configs/experiments/vq16_1024.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0b57163369a32372071a7b232ee3605877c2bfee
--- /dev/null
+++ b/configs/experiments/vq16_1024.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 1024
+ codebook_embed_dim: 256
+ vae_type: VQ-16
+ use_custom_vae_ckpt: ${oc.env:DIFFUSION_DATA_DIR}/ckpts/2024-07-03-01-10-53_022-VQ-16_0042000.pt
\ No newline at end of file
diff --git a/configs/experiments/vq16_magvit.yaml b/configs/experiments/vq16_magvit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbc002f78fcbf8da28ca5154cda156646902752a
--- /dev/null
+++ b/configs/experiments/vq16_magvit.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 8192
+ vae_type: magvit
+ use_custom_vae_ckpt: null
+ custom_vae_name: null
+ img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
\ No newline at end of file
diff --git a/configs/experiments/vq16_t2i.yaml b/configs/experiments/vq16_t2i.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f6bf2fa9ad0ca5ab743609fdcc3dd8c6a0db284
--- /dev/null
+++ b/configs/experiments/vq16_t2i.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+model:
+ downscale_ratio: 16
+ image_vocab_size: 16384
+ vae_type: VQ-16
+ use_custom_vae_ckpt: ${get_repo_dir:}/ckpts/vq_ds16_t2i.pt
+ custom_vae_name: _t2i
+ codebook_embed_dim: 8
+ img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
\ No newline at end of file
diff --git a/configs/experiments/webdataset.yaml b/configs/experiments/webdataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5a26b5321724f365be2f486cdf83b85bf6fa0b99
--- /dev/null
+++ b/configs/experiments/webdataset.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+data:
+ train: datacomp1b_indexed
+ valid: ${.train}
+
+ iterable: false
+ webdataset_iterable: false
+ webdataset_indexed: true
+ unpaired: false
+ dataset_type: null
+ tokens_flip_collate: false
\ No newline at end of file
diff --git a/configs/experiments/zero_shot_eval.yaml b/configs/experiments/zero_shot_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d06ee0d079e5fde2869668dbfe078ecf4451913
--- /dev/null
+++ b/configs/experiments/zero_shot_eval.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+
+mode: zero-shot-eval
+
+data:
+ # train: "nlphuji/flickr30k"
+ train: "facebook/winoground"
+ precache: False
+ tokenizers_parallelism: False # parallelism causes some weird error
+ n_val_samples: 2048
+ block_size: 128
+ disable_text_modality: false
+
+eval:
+ cfg: 5
+ compute_val_metrics_standalone: false
+ compute_img_to_txt_mauve_clip: false
+
+loader:
+ batch_size: 16
+ eval_batch_size: 16
+
+
+model:
+ unified_model: True
+ text_model: true
+ image_model: true
+ vae_type: magvit
+ force_optimized_native_attn: false
\ No newline at end of file
diff --git a/configs/lr_scheduler/constant_warmup.yaml b/configs/lr_scheduler/constant_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bfcbe3213ae1d1e568b1bca3846c1924bd62cce
--- /dev/null
+++ b/configs/lr_scheduler/constant_warmup.yaml
@@ -0,0 +1,2 @@
+_target_: transformers.get_constant_schedule_with_warmup
+num_warmup_steps: 2500
\ No newline at end of file
diff --git a/configs/lr_scheduler/constant_warmup_cosine_decay.yaml b/configs/lr_scheduler/constant_warmup_cosine_decay.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..360aac251ad9f2366d3586828afb93c58b7dc06f
--- /dev/null
+++ b/configs/lr_scheduler/constant_warmup_cosine_decay.yaml
@@ -0,0 +1,3 @@
+_target_: transformers.get_cosine_schedule_with_warmup
+num_warmup_steps: 2500
+num_training_steps: 1000000
\ No newline at end of file
diff --git a/configs/lr_scheduler/cosine_decay_warmup.yaml b/configs/lr_scheduler/cosine_decay_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e0d66b1d0d460c881e467b8b93888e193e6eb0cb
--- /dev/null
+++ b/configs/lr_scheduler/cosine_decay_warmup.yaml
@@ -0,0 +1,7 @@
+_target_: utils.CosineDecayWarmupLRScheduler
+t_in_epochs: False
+t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+warmup_prefix: True
+warmup_lr_init: 1e-6
+warmup_t: ${eval:0.1*${trainer.max_steps}}
+lr_min: 1e-6
\ No newline at end of file
diff --git a/configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml b/configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24898a4d6bcb0e8ed6f6bca49338659484fdd21f
--- /dev/null
+++ b/configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml
@@ -0,0 +1,4 @@
+_target_: transformers.get_cosine_with_hard_restarts_schedule_with_warmup
+num_warmup_steps: 2500
+num_training_steps: 1000000
+num_cycles: 1
\ No newline at end of file
diff --git a/configs/model/extra_large.yaml b/configs/model/extra_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..343ae57933e05803a2242ec4efcb5067473ba3d5
--- /dev/null
+++ b/configs/model/extra_large.yaml
@@ -0,0 +1,10 @@
+name: extra_large
+type: ddit
+hidden_size: 2048
+cond_dim: 128
+length: 1024
+n_blocks: 24
+n_heads: 16
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/large.yaml b/configs/model/large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f17744d60435cf9c0393c6c7d45978bb21c4dc7
--- /dev/null
+++ b/configs/model/large.yaml
@@ -0,0 +1,14 @@
+name: large
+type: ddit
+hidden_size: 1280
+cond_dim: 128
+length: 1024
+base_n_blocks: 28
+# We try to roughly match parameter count
+n_blocks: ${adjust_n_blocks:}
+n_heads: 20
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
+
+# 36 1280 20
\ No newline at end of file
diff --git a/configs/model/medium.yaml b/configs/model/medium.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d376aeee7e93b4f1305dc55c4bbbb4b84b453cd
--- /dev/null
+++ b/configs/model/medium.yaml
@@ -0,0 +1,12 @@
+name: medium
+type: ddit
+hidden_size: 1024
+cond_dim: 128
+length: 1024
+base_n_blocks: 24
+# We try to roughly match parameter count
+n_blocks: ${adjust_n_blocks:}
+n_heads: 16
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/small-ar.yaml b/configs/model/small-ar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa140031979e92aa7fc13f2a4fc3959c80cf48c8
--- /dev/null
+++ b/configs/model/small-ar.yaml
@@ -0,0 +1,11 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: 1024
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+causal: True
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/small.yaml b/configs/model/small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ce9703adaba2f385059bf78155db2feb27d58f1
--- /dev/null
+++ b/configs/model/small.yaml
@@ -0,0 +1,12 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: 1024
+base_n_blocks: 12
+# We try to roughly match parameter count
+n_blocks: ${adjust_n_blocks:}
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/small_imagenet.yaml b/configs/model/small_imagenet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..71eba5fad24ab6f51ad3b3698f85ecd4e1b87846
--- /dev/null
+++ b/configs/model/small_imagenet.yaml
@@ -0,0 +1,10 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: 1024
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/tiny-ar.yaml b/configs/model/tiny-ar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bef5fe891e23cdab54e031b16e643c150b9488ef
--- /dev/null
+++ b/configs/model/tiny-ar.yaml
@@ -0,0 +1,11 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+causal: True
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/tiny.yaml b/configs/model/tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34a1f136d2a33db6516bde2de6674d7e100d49d8
--- /dev/null
+++ b/configs/model/tiny.yaml
@@ -0,0 +1,10 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/xs.yaml b/configs/model/xs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b01ba7afc862eb0a9efecd2701ee3c7e372a8ef1
--- /dev/null
+++ b/configs/model/xs.yaml
@@ -0,0 +1,10 @@
+name: tiny
+type: ddit
+hidden_size: 256
+cond_dim: 128
+length: 1024
+n_blocks: 4
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/xxl.yaml b/configs/model/xxl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f3a917473db1beb18c5c1f62c3c2c98fc94e6873
--- /dev/null
+++ b/configs/model/xxl.yaml
@@ -0,0 +1,10 @@
+name: xxl
+type: ddit
+hidden_size: 4096
+cond_dim: 128
+length: 1024
+n_blocks: 30
+n_heads: 16
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/noise/ar.yaml b/configs/noise/ar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea87023b328f9d078d5b4d4715b3fbfe270d1db7
--- /dev/null
+++ b/configs/noise/ar.yaml
@@ -0,0 +1,2 @@
+type: ar
+scale: 6.0
\ No newline at end of file
diff --git a/configs/noise/linear.yaml b/configs/noise/linear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..12fffce86b73e79ad8bffe6fca4d3a9a7d8a808a
--- /dev/null
+++ b/configs/noise/linear.yaml
@@ -0,0 +1,3 @@
+type: linear
+sigma_min: 1e-3
+sigma_max: 7.0
\ No newline at end of file
diff --git a/configs/noise/loglinear.yaml b/configs/noise/loglinear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..04c914f32ffe43e52b35ea31fb4617d366983393
--- /dev/null
+++ b/configs/noise/loglinear.yaml
@@ -0,0 +1,3 @@
+type: loglinear
+sigma_min: 1e-4
+sigma_max: 20
\ No newline at end of file
diff --git a/configs/noise/polynomial.yaml b/configs/noise/polynomial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7218191733ac458ac2d9de1411fe5de7cde42a3c
--- /dev/null
+++ b/configs/noise/polynomial.yaml
@@ -0,0 +1,5 @@
+type: polynomial
+a: -3
+b: 5
+c: -4
+eps: 1e-3
\ No newline at end of file
diff --git a/constants.py b/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64071fff1b39c54edd36c03c873de0a64ef97bf
--- /dev/null
+++ b/constants.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+import os
+
+UNIDISC_DIR = Path(os.getenv("UNIDISC_DIR", Path(__file__).parent))
+LIB_DIR = UNIDISC_DIR / "third_party"
+CELEBV_DATA_DIR = Path(os.getenv("CELEBV_DATA_DIR", "/home/mprabhud/aswerdlo/repos/lib/CelebV-Text/downloaded_celebvtext"))
+SCRATCH_CELEBV_DATA_DIR = Path("/scratch/aswerdlo/sora/celebv_text/downloaded_celebvtext")
+CONFIG_PATH = os.getenv("UNIDISC_CONFIG_PATH", "configs")
+HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HF_HUB_DATASETS_TOKEN"))
+HF_DATASETS_CACHE = os.getenv("HF_DATASETS_CACHE", None)
+HF_CACHE_DIR = os.getenv("HF_HOME", None)
+
+if HF_CACHE_DIR is not None:
+ HF_CACHE_DIR = Path(HF_CACHE_DIR)
+elif HF_DATASETS_CACHE is not None:
+ HF_CACHE_DIR = Path(HF_DATASETS_CACHE).parent
+else:
+ HF_CACHE_DIR = Path("~/.cache/huggingface").expanduser()
+try:
+ if SCRATCH_CELEBV_DATA_DIR.exists():
+ CELEBV_DATA_DIR = SCRATCH_CELEBV_DATA_DIR
+except:
+ print(f"Error setting CELEBV_DATA_DIR")
diff --git a/data_defs.py b/data_defs.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f586de22d683a88f863a6dc93635bb7fe577426
--- /dev/null
+++ b/data_defs.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+from tensordict import tensorclass
+import torch
+from torch import nn
+from typing import Optional
+from unidisc.utils.tensor_utils import get_contiguous_blocks, get_interleaved_indices
+from tensordict import TensorDict
+
+@tensorclass
+class InterleavedBatch:
+ input_ids: torch.Tensor
+ modality: torch.Tensor
+ sample_ids: torch.Tensor
+ attention_mask: Optional[torch.Tensor] = None
+
+ def to_ragged_batch(self):
+ data = []
+ batch_indices, start_positions, end_positions = get_contiguous_blocks(self.sample_ids)
+ first_sample_ids = self.sample_ids[batch_indices, start_positions]
+ self.auto_batch_size_()
+ for i in range(batch_indices.shape[0]):
+ if first_sample_ids[i] == -1:
+ continue
+ data.append(self[batch_indices[i], start_positions[i]:end_positions[i]])
+
+ return TensorDict.lazy_stack(data, dim=0)
+
+ def to_elements(self):
+ data = self.to_ragged_batch()
+ new_data = []
+ for i in range(data.shape[0]):
+ new_data.append(InterleavedElement.from_raw(data[i]))
+ return TensorDict.lazy_stack(new_data, dim=0)
+
+ @classmethod
+ def custom_from_dict(cls, data: TensorDict):
+ new_dict = {}
+ for field in cls.fields():
+ if field.name in data:
+ new_dict[field.name] = data[field.name]
+
+ return cls(**new_dict)
+
+
+@tensorclass
+class InterleavedElement:
+ txt_input_ids: Optional[list[torch.Tensor]] = None
+ img_input_ids: Optional[list[torch.Tensor]] = None
+ txt: Optional[torch.Tensor] = None
+ img: Optional[torch.Tensor] = None
+ img_pos_ids: Optional[torch.Tensor] = None
+ batch_indices: Optional[torch.Tensor] = None
+ start_positions: Optional[torch.Tensor] = None
+ end_positions: Optional[torch.Tensor] = None
+ raw_data: Optional[InterleavedBatch] = None
+
+ @classmethod
+ def from_raw(cls, interleaved_batch: InterleavedBatch):
+ batch_indices, start_positions, end_positions = get_contiguous_blocks(interleaved_batch.modality[None])
+ block_modality = interleaved_batch.modality[start_positions]
+
+ img_input_ids = []
+ txt_input_ids = []
+ img_pos_ids = []
+ for i in range(batch_indices.shape[0]):
+ if block_modality[i] == 1:
+ assert len(txt_input_ids) > 0
+ img_input_ids.append(interleaved_batch.input_ids[start_positions[i]:end_positions[i]])
+ img_pos_ids.append(len(txt_input_ids) - 1)
+ else:
+ txt_input_ids.append(interleaved_batch.input_ids[start_positions[i]:end_positions[i]])
+
+ return cls(img_input_ids=img_input_ids, txt_input_ids=txt_input_ids, img_pos_ids=torch.tensor(img_pos_ids), batch_indices=batch_indices, start_positions=start_positions, end_positions=end_positions, raw_data=interleaved_batch)
+
+ def to_list(self):
+ txt_idx = 0
+ img_idx = 0
+ has_added_txt = False
+ data = []
+ modalities = []
+ while txt_idx < len(self.txt_input_ids) or img_idx < len(self.img_input_ids):
+ if not has_added_txt and txt_idx < len(self.txt_input_ids):
+ data.append(self.txt_input_ids[txt_idx])
+ modalities.append(0)
+ has_added_txt = True
+ elif img_idx < len(self.img_input_ids) and self.img_pos_ids[img_idx] == txt_idx:
+ data.append(self.img_input_ids[img_idx])
+ modalities.append(1)
+ img_idx += 1
+ else:
+ has_added_txt = False
+ txt_idx += 1
+
+ return data, modalities
\ No newline at end of file
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..70b18d08238726eb337090520564dcc7510fa950
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,678 @@
+import math
+import typing
+from pathlib import Path
+
+import tokenizers
+import torch
+import transformers
+from unidisc.datasets.sampler import WeightedDatasetSampler
+
+from models.datasets.image_datasets import TensorCollate, get_image_dataset, get_unpaired_dataset
+from models.datasets.text_datasets import Text8Tokenizer, get_text_dataset
+from torch.utils.data import default_collate
+from decoupled_utils import breakpoint_on_error, gprint, rprint, is_torch_xla_available
+from datasets import load_dataset
+
+
+def identity(x):
+ return x
+
+
+def get_dataset(dataset_name, tokenizer, *args, config=None, **kwargs):
+ rprint(f"getting dataset {dataset_name}")
+ if getattr(config.data, "unpaired", False):
+ return get_unpaired_dataset(config=config, tokenizer=tokenizer, **kwargs)
+ elif getattr(config.model, "image_model", False) or getattr(config.data, "force_image_dataset", False):
+ return get_image_dataset(config=config, tokenizer=tokenizer, **kwargs)
+ else:
+ rprint(f"getting text dataset")
+ return get_text_dataset(dataset_name, tokenizer, *args, **kwargs)
+
+def tokenize_text(tokenizer, block_size, text, return_token_type_ids=True):
+ return tokenizer(text, max_length=block_size, padding="max_length", truncation=True, add_special_tokens=True, return_attention_mask=True, return_token_type_ids=return_token_type_ids).convert_to_tensors("pt")
+
+def get_tokenizer(config):
+ if config.data.tokenizer_name_or_path is None or config.data.tokenizer_name_or_path == "None":
+ return None
+ elif config.data.tokenizer_name_or_path == "text8":
+ tokenizer = Text8Tokenizer()
+ elif config.data.tokenizer_name_or_path == "bert-base-uncased":
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+ else:
+ tokenizer_kwargs = dict()
+ if config.data.tokenizer_name_or_path == "NousResearch/Llama-2-7b-hf":
+ tokenizer_kwargs["add_eos_token"] = True
+ tokenizer_kwargs["padding_side"] = 'right'
+ rprint("Using Llama tokenizer, adding add_eos_token and setting padding_side to right")
+ if getattr(config.data, "use_slow_tokenizer", False):
+ tokenizer_kwargs["use_fast"] = False
+ tokenizer = transformers.AutoTokenizer.from_pretrained(config.data.tokenizer_name_or_path, **tokenizer_kwargs)
+
+ if getattr(config.data, "add_image_token", False):
+ special_token = ''
+ existing_id = 811
+ tmp_index = len(tokenizer)
+ tokenizer.add_special_tokens({
+ 'additional_special_tokens': [special_token]
+ }, replace_additional_special_tokens=False)
+ tokenizer._added_tokens_decoder[existing_id] = tokenizer._added_tokens_decoder.pop(tmp_index)
+ assert len(tokenizer.additional_special_tokens_ids) == 1
+ tokenizer.additional_special_tokens_ids = [existing_id]
+ tokenizer._added_tokens_encoder[''] = existing_id
+ tokenizer.total_vocab_size = tmp_index
+
+ if isinstance(tokenizer, transformers.GPT2TokenizerFast) or isinstance(tokenizer, transformers.GPT2Tokenizer):
+ tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
+ (tokenizer.bos_token, tokenizer.bos_token_id), (tokenizer.eos_token, tokenizer.eos_token_id)
+ )
+
+ # For wrapped batches:
+ # [BOS] sent1 [EOS] sent2-fragment [EOS]
+ # [BOS] sent2-fragment [EOS] sent3 [EOS]
+ if tokenizer.bos_token is None:
+ if tokenizer.cls_token is None:
+ raise AttributeError("Tokenizer must have a bos_token or " f"cls_token: {tokenizer}")
+ tokenizer.bos_token = tokenizer.cls_token
+ if tokenizer.eos_token is None:
+ if tokenizer.sep_token is None:
+ raise AttributeError("Tokenizer must have a eos_token " f"or sep_token: {tokenizer}")
+ tokenizer.eos_token = tokenizer.sep_token
+ if tokenizer.pad_token is None:
+ if config.data.tokenizer_name_or_path != "gpt2":
+ rprint(f"Adding pad token to tokenizer")
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+ assert tokenizer.padding_side == 'right'
+ assert tokenizer.truncation_side == 'right'
+
+ return tokenizer
+
+
+class SimpleDataLoader:
+ def __init__(self, dataset, batch_size=1, collate_fn=default_collate, **kwargs):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.collate_fn = collate_fn
+ self.idx = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.idx < len(self.dataset):
+ batch = []
+ for _ in range(self.batch_size):
+ if self.idx >= len(self.dataset):
+ break
+ batch.append(self.dataset[self.idx])
+ self.idx += 1
+ return self.collate_fn(batch)
+ else:
+ raise StopIteration
+
+ def __len__(self):
+ return (len(self.dataset) + self.batch_size - 1) // self.batch_size
+
+def get_zero_shot_dataloader(config, tokenizer, device=None, **kwargs):
+ if config.data.zero_shot_eval_dataset is None:
+ rprint("No zero shot eval dataset provided")
+ return None, None
+
+ dataset_name = config.data.zero_shot_eval_dataset
+ dataloader_seed = config.seed if config.mode == "eval" else 42
+ if dataset_name == "nlphuji/flickr30k":
+ data = load_dataset(dataset_name, num_proc=config.data.num_proc, cache_dir=config.data.cache_dir, streaming=config.data.streaming)
+ dataset = data["test"]
+ elif dataset_name == "facebook/winoground":
+ data = load_dataset(dataset_name, num_proc=config.data.num_proc, cache_dir=config.data.cache_dir, streaming=config.data.streaming)
+ dataset = data["test"]
+ breakpoint()
+ dl_cls = torch.utils.data.DataLoader
+ valid_loader = dl_cls(
+ dataset,
+ batch_size=config.loader.eval_batch_size,
+ num_workers=config.loader.num_eval_workers,
+ pin_memory=config.loader.pin_memory,
+ generator=torch.Generator().manual_seed(dataloader_seed),
+ persistent_workers=False,
+ **kwargs,
+ )
+ valid_loader.tokenizer = tokenizer
+ return valid_loader
+
+
+def get_dataloaders(config, tokenizer, skip_train=False, skip_valid=False, valid_seed=None, device=None, **kwargs):
+ if skip_train:
+ train_set = None
+ else:
+ _mode = getattr(config.data, "force_train_mode", "train")
+ if _mode != "train":
+ rprint(f"Forcing train mode to {_mode}")
+ train_set = get_dataset(
+ config.data.train,
+ tokenizer,
+ mode=_mode,
+ wrap=config.data.wrap,
+ cache_dir=config.data.cache_dir,
+ block_size=config.model.length,
+ num_proc=config.data.num_proc,
+ streaming=config.data.streaming,
+ config=config,
+ **kwargs,
+ )
+ if hasattr(train_set, '__len__'):
+ rprint(f"Training set len: {len(train_set)}")
+
+ if config.data.valid in ["text8", "lm1b", "ag_news"]:
+ validation_split = "test"
+ else:
+ validation_split = "validation"
+
+ if skip_valid:
+ valid_set = None
+ else:
+ valid_set = get_dataset(
+ config.data.valid,
+ tokenizer,
+ wrap=config.data.wrap,
+ mode=validation_split,
+ cache_dir=config.data.cache_dir,
+ block_size=config.model.length,
+ streaming=False,
+ num_proc=config.data.num_proc,
+ config=config,
+ **kwargs,
+ )
+ if hasattr(valid_set, '__len__'):
+ rprint(f"Validation set len: {len(valid_set)}")
+
+ dataloader_seed = config.seed if (config.mode == "eval" or is_torch_xla_available() or getattr(config.data, "force_seed", False)) else 42
+ gprint(f"Dataloader seed: {dataloader_seed}")
+
+ if skip_train:
+ train_loader = None
+ else:
+ train_kwargs = dict(drop_last=True)
+ train_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
+ dl_cls = torch.utils.data.DataLoader
+ if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
+ train_kwargs.pop("drop_last", None)
+
+ if getattr(config.loader, "disable_prefetch", False):
+ train_kwargs["prefetch_factor"] = 1
+
+ if getattr(config.data, "force_disable_shuffle", False) is False:
+ if getattr(config.data, "webdataset_iterable", False):
+ import webdataset
+ dl_cls = webdataset.WebLoader
+ train_kwargs["shuffle"] = False
+ train_kwargs["prefetch_factor"] = 8
+ elif getattr(config.data, "webdataset_indexed", False):
+ import wids
+ train_kwargs["sampler"] = wids.DistributedChunkedSampler(train_set, shuffle=True)
+ elif isinstance(train_set, torch.utils.data.IterableDataset) is False:
+ train_kwargs["shuffle"] = True
+
+ if "tokens" in config.data.train and config.data.pin_dataset_to_gpu:
+ if config.backend == 'cuda':
+ cur_mb = torch.cuda.memory_reserved() / 1e9
+ rprint(f"Moving dataloader to device {device} with: {cur_mb} GB of memory reserved")
+ train_set = train_set.to(device=device)
+ if config.backend == 'cuda':
+ cur_mb = torch.cuda.memory_reserved() / 1e9
+ rprint(f"Moved dataloader to device {device} with: {cur_mb} GB of memory reserved")
+
+ if "tokens" in config.data.train:
+ if getattr(config.data, "use_custom_tensordict_collate", False):
+ train_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
+ else:
+ train_kwargs["collate_fn"] = identity
+
+ if getattr(config.data, "use_packing_collate", False):
+ generator = torch.Generator().manual_seed(dataloader_seed)
+ token_collate = train_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
+ train_kwargs["collate_fn"] = PackingCollate(config, train_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)
+
+ if getattr(config.data, "use_weighted_tensordict_sampler", False):
+ generator = torch.Generator().manual_seed(dataloader_seed)
+ train_kwargs['sampler'] = WeightedDatasetSampler(train_set, generator=generator)
+ train_kwargs["shuffle"] = False
+ else:
+ train_kwargs["shuffle"] = True
+
+ if getattr(config.data, "use_list_collate", False):
+ train_kwargs["collate_fn"] = lambda x: x
+
+ if getattr(config.data, "force_shuffle_train", False):
+ rprint("Forcing shuffle on train dataloader")
+ train_kwargs["shuffle"] = True
+
+ if getattr(config.data, "force_disable_shuffle_train", False):
+ rprint("Forcing disable shuffle on train dataloader")
+ train_kwargs["shuffle"] = False
+
+ if getattr(config.data, "force_distributed_sampler", False):
+ import torch_xla.runtime as xr
+ train_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
+ train_set,
+ num_replicas=xr.world_size(),
+ rank=xr.global_ordinal(),
+ shuffle=True
+ )
+
+ if getattr(config.data, "use_identity_collate", False):
+ train_kwargs["collate_fn"] = lambda x: x
+
+ if train_set.__class__.__name__ == "WebLoader":
+ train_loader = train_set
+ else:
+ rprint(f"Train dataloader kwargs: {train_kwargs}")
+ train_loader = dl_cls(
+ train_set,
+ batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.batch_size,
+ num_workers=config.loader.num_workers,
+ pin_memory=config.loader.pin_memory,
+ persistent_workers=config.loader.num_workers > 0 and getattr(config.loader, "persistent_workers", True),
+ generator=train_dataloader_generator,
+ **train_kwargs,
+ )
+ train_loader.tokenizer = tokenizer
+
+ if skip_valid:
+ valid_loader = None
+ else:
+ shuffle_valid = True
+ valid_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
+ valid_kwargs = dict(drop_last=True)
+
+ dl_cls = torch.utils.data.DataLoader
+ if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
+ valid_kwargs.pop("drop_last", None)
+
+ if getattr(config.data, "force_disable_shuffle", False) is False:
+ if getattr(config.data, "webdataset_iterable", False):
+ valid_kwargs["shuffle"] = False
+ import webdataset
+ dl_cls = webdataset.WebLoader
+ elif getattr(config.data, "webdataset_indexed", False):
+ import wids
+ valid_kwargs["sampler"] = wids.DistributedChunkedSampler(valid_set, shuffle=True)
+ elif isinstance(valid_set, torch.utils.data.IterableDataset) is False and shuffle_valid:
+ valid_kwargs["shuffle"] = shuffle_valid
+
+ if "tokens" in config.data.valid:
+ if getattr(config.data, "use_custom_tensordict_collate", False):
+ valid_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
+ else:
+ valid_kwargs["collate_fn"] = identity
+
+ if getattr(config.data, "use_packing_collate", False):
+ generator = torch.Generator().manual_seed(dataloader_seed)
+ token_collate = valid_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
+ valid_kwargs["collate_fn"] = PackingCollate(config, valid_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)
+
+ if getattr(config.data, "use_weighted_tensordict_sampler", False):
+ generator = torch.Generator().manual_seed(dataloader_seed)
+ valid_kwargs['sampler'] = WeightedDatasetSampler(valid_set, generator=generator)
+
+ if getattr(config.data, "shuffle_valid", False):
+ torch.manual_seed(config.seed)
+
+ valid_kwargs["shuffle"] = getattr(config.data, "shuffle_valid", False)
+
+ if getattr(config.data, "force_distributed_sampler", False):
+ import torch_xla.runtime as xr
+ valid_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
+ valid_set,
+ num_replicas=xr.world_size(),
+ rank=xr.global_ordinal(),
+ shuffle=True
+ )
+
+ if valid_set.__class__.__name__ == "WebLoader":
+ valid_loader = valid_set
+ else:
+ rprint(f"Valid dataloader kwargs: {valid_kwargs}")
+ valid_loader = dl_cls(
+ valid_set,
+ batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.eval_batch_size,
+ num_workers=getattr(config.loader, "num_eval_workers", config.loader.num_workers),
+ pin_memory=config.loader.pin_memory,
+ generator=valid_dataloader_generator,
+ persistent_workers=False,
+ **valid_kwargs,
+ )
+ # Will be used in generative perplexity calculation
+ valid_loader.tokenizer = tokenizer
+
+ return train_loader, valid_loader
+
+
+# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
+
+
+class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
+
+ def __init__(self, *args, generator=None, **kwargs):
+ # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
+ # which should be reproducible if pl.seed_everything was called beforehand.
+ # This means that changing the seed of the experiment will also change the
+ # sampling order.
+ if generator is None:
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ generator = torch.Generator().manual_seed(seed)
+ kwargs.pop("shuffle", None)
+ super().__init__(*args, generator=generator, **kwargs)
+ self.counter = 0
+ self.restarting = False
+
+ def state_dict(self):
+ return {"random_state": self.generator.get_state(), "counter": self.counter}
+
+ def load_state_dict(self, state_dict):
+ self.generator.set_state(state_dict.get("random_state"))
+ self.counter = state_dict["counter"]
+ # self.start_counter = self.counter
+ self.restarting = True
+
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
+ # epoch, and subsequent epoch will have very few batches.
+
+ def __iter__(self) -> typing.Iterator[int]:
+ n = len(self.data_source)
+
+ self.state = self.generator.get_state()
+ indices = torch.randperm(n, generator=self.generator).tolist()
+
+ if not self.restarting:
+ self.counter = 0
+ else:
+ indices = indices[self.counter :]
+ self.restarting = False
+
+ for index in indices:
+ self.counter += 1
+ yield index
+
+ self.counter = 0
+
+
+class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.counter = 0
+ self.restarting = False
+
+ def state_dict(self):
+ return {"epoch": self.epoch, "counter": self.counter}
+
+ def load_state_dict(self, state_dict):
+ self.epoch = state_dict["epoch"]
+ self.counter = state_dict["counter"]
+ self.restarting = True
+
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
+ # epoch, and subsequent epoch will have very few batches.
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+ else:
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ if not self.restarting:
+ self.counter = 0
+ else:
+ indices = indices[self.counter :]
+ self.restarting = False
+
+ for index in indices:
+ self.counter += 1
+ yield index
+
+ self.counter = 0
+
+
+if __name__ == "__main__":
+ import os
+
+ with breakpoint_on_error():
+ from omegaconf import OmegaConf
+
+ cc12m_config = OmegaConf.create(
+ {
+ "model": {
+ "image_model": True,
+ "unified_model": True,
+ },
+ "data": {
+ "tokenizers_parallelism": False,
+ "resolution": 128,
+ "train": "pixparse/cc12m-wds",
+ "val": "pixparse/cc12m-wds",
+ "streaming": False,
+ "precache": True,
+ "tokenizer_name_or_path": "gpt2",
+ "n_val_samples": None,
+ "n_train_samples": None,
+ "block_size": 32,
+ "data_dir": "/path/to/cc12m",
+ },
+ }
+ )
+
+ imagenet_config = OmegaConf.create(
+ {
+ "model": {
+ "image_model": True,
+ },
+ "data": {
+ "resolution": 128,
+ "train": "ILSVRC/imagenet-1k",
+ "val": "ILSVRC/imagenet-1k",
+ "streaming": False,
+ "precache": True,
+ "tokenizer_name_or_path": "gpt2",
+ },
+ }
+ )
+
+ facecaption_config = OmegaConf.create(
+ {
+ "seed": 12345,
+ "model": {
+ "image_model": True,
+ },
+ "data": {
+ "resolution": 256,
+ "train": "facecaption",
+ "val": "facecaption",
+ "streaming": False,
+ "precache": False,
+ "tokenizer_name_or_path": "gpt2",
+ "cache_dir": os.environ["HF_DATASETS_CACHE"],
+ "raw_data_dir": "/grogu/user/mprabhud/data/diffusion/facecaption",
+ "block_size": 32,
+ },
+ "loader": {
+ "num_workers": 0,
+ "batch_size": 1,
+ "eval_batch_size": 1,
+ },
+ "trainer": {
+ "devices": 1,
+ "num_nodes": 1,
+ "accumulate_grad_batches": 1,
+ },
+ }
+ )
+
+ tokenizer = get_tokenizer(facecaption_config)
+ dataset = get_dataset(
+ dataset_name=facecaption_config.data.train,
+ mode="train",
+ config=facecaption_config,
+ tokenizer=tokenizer,
+ )
+ test = next(iter(dataset))
+ breakpoint()
+
+
+
+from typing import List, Dict
+import torch
+from tensordict import TensorDict
+def process_batch(batch: TensorDict):
+ if isinstance(batch, list):
+ return [process_batch(b) for b in batch]
+ else:
+ if "write_flag" in batch:
+ del batch["write_flag"]
+ if "dataset_idx" in batch:
+ del batch["dataset_idx"]
+ batch.auto_batch_size_()
+ return batch
+
+def ignore_slice(tensor, slice, padding_token_id):
+ tensor["modality"][slice] = -1
+ tensor["attention_mask"][slice] = 0
+ tensor["input_ids"][slice] = padding_token_id
+ if "sample_ids" in tensor:
+ tensor["sample_ids"][slice] = -1
+ else:
+ tensor["sample_ids"] = torch.full(tensor["input_ids"].shape, fill_value=-1, dtype=tensor["input_ids"].dtype, device=tensor["input_ids"].device)
+
+class PackingCollate:
+ def __init__(self, config, dataset, seq_length, generator, tensor_collate=None, tokenizer=None):
+ self.dataset = dataset
+ self.seq_length = seq_length
+ self.tensor_collate = tensor_collate
+ self.generator = generator
+ self.tokenizer = tokenizer
+ self.padding_token_id = tokenizer.pad_token_id
+ self.eos_token_id = tokenizer.eos_token_id
+ self.disable_packing = getattr(config.data, "disable_packing", False)
+ img_special_tokens = tokenizer("", add_special_tokens=False)['input_ids']
+ assert len(img_special_tokens) == 1
+ self.image_token_id = img_special_tokens[0]
+
+ def __call__(self, batch: TensorDict):
+ if self.tensor_collate is not None:
+ if isinstance(batch, list):
+ batch = [self.tensor_collate(b) for b in batch]
+ else:
+ batch = self.tensor_collate(batch)
+
+ B = len(batch)
+ seq_length = self.seq_length
+
+ batch = process_batch(batch)
+ assert batch[0].batch_size is None or len(batch[0].batch_size) == 1
+
+ new_batch = batch[0].new_zeros((B, seq_length))
+ ignore_slice(new_batch, slice(None, None), self.padding_token_id)
+
+ for i in range(B):
+ total_length = 0
+ sample_idx = 0
+ sample_queue = [batch[i]]
+
+ # We originally get bs number of samples but since we're packing, we probably need more so we randomly select.
+ while total_length < seq_length:
+ if self.disable_packing and sample_idx > 0:
+ break
+ if not sample_queue:
+ dataset_idx = torch.randint(len(self.dataset.datasets), (1,), generator=self.generator).item()
+ element_idx = torch.randint(len(self.dataset.datasets[dataset_idx]), (1,), generator=self.generator).item()
+ sample = self.dataset[(dataset_idx, element_idx)]
+ sample = process_batch(sample)
+ else:
+ sample = sample_queue.pop(0)
+
+ available_length = seq_length - total_length
+ if available_length < sample.shape[0] // 4:
+ if total_length > 0:
+ break
+ else:
+ continue
+
+ if "sample_ids" not in sample:
+ sequence_starts = (sample['input_ids'] == self.padding_token_id).long()
+ sample["sample_ids"] = torch.cumsum(sequence_starts, dim=0) - 1
+ processed_ids = torch.where(sample["sample_ids"] < 0, torch.zeros_like(sample["sample_ids"]), -1)
+ sample["sample_ids"] = processed_ids
+
+ if not ((sample["sample_ids"] == 0) | (sample["sample_ids"] == -1)).all():
+ assert (sample["modality"] == 0).all()
+
+ first_neg_one = (sample["sample_ids"] == -1).nonzero(as_tuple=True)[0]
+
+ if first_neg_one.numel() > 0:
+ first_neg_one = first_neg_one[0].item()
+ else:
+ assert sample["attention_mask"].all()
+ first_neg_one = len(sample["attention_mask"])
+
+ valid_slice = slice(None, min(first_neg_one, available_length))
+ new_length = min(first_neg_one, available_length)
+
+ sample["sample_ids"][valid_slice] = sample_idx
+ new_batch[i, total_length:total_length+new_length] = sample[valid_slice]
+
+ total_length += new_length
+ sample_idx += 1
+
+ if (new_batch["sample_ids"] == -1).all():
+ gprint(f"WARNING!!!! All sample ids are -1 in packing collate before ignore")
+
+ if new_batch["modality"][i, -1] == 1:
+ # Find contiguous sequence of image tokens from the end
+ modality_slice = new_batch["modality"][i]
+ is_image = modality_slice == 1
+
+ # Get indices where modality changes
+ change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1
+
+ if change_points.numel() > 0 and is_image[-1]:
+ # Get start of last contiguous image sequence
+ start_pos = change_points[-1].item()
+ assert (new_batch["modality"][i, start_pos:] == 1).all()
+ try:
+ if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] == self.image_token_id:
+ start_pos -= 1
+
+ if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] != self.eos_token_id:
+ new_batch["input_ids"][i, start_pos] = self.eos_token_id
+ new_batch["attention_mask"][i, start_pos] = 1
+ new_batch["modality"][i, start_pos] = 0
+ start_pos += 1
+
+ except IndexError:
+ print(f"WARNING!!!! ERROR IN PACKING COLLATE")
+
+ ignore_slice(new_batch[i], slice(start_pos, None), self.padding_token_id)
+
+ if (new_batch["sample_ids"] == -1).all():
+ gprint(f"WARNING!!!! All sample ids are -1 in packing collate after ignore")
+
+ return new_batch
+
diff --git a/decoupled_utils.py b/decoupled_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d792e5c8cf304846672103fe70a3aa04c2e9597
--- /dev/null
+++ b/decoupled_utils.py
@@ -0,0 +1,1053 @@
+"""
+A collection of assorted utilities for deep learning with no required dependencies aside from PyTorch, NumPy and the Python stdlib.
+
+TODO: Validate required Python version.
+"""
+import contextlib
+import builtins
+import functools
+import glob
+import hashlib
+import io
+import os
+import pickle
+import subprocess
+import sys
+import time
+import traceback
+from collections import defaultdict
+from datetime import datetime
+from functools import lru_cache, partial, wraps
+from importlib import import_module
+from importlib.util import find_spec
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Optional
+from urllib.parse import urlparse
+import inspect
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+# TPUs:
+# wait_device_ops
+# https://github.com/pytorch/xla/blob/2766f9f756e8c8150ea3b7df98762c2f82d66e39/examples/debug/train_resnet_benchmark.py#L18
+
+# From diffusers
+import importlib
+import importlib.metadata
+
+# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
+def _is_package_available(pkg_name: str, return_version: bool = False):
+ # Check if the package spec exists and grab its version to avoid importing a local directory
+ package_exists = importlib.util.find_spec(pkg_name) is not None
+ package_version = "N/A"
+ if package_exists:
+ try:
+ # Primary method to get the package version
+ package_version = importlib.metadata.version(pkg_name)
+ except importlib.metadata.PackageNotFoundError:
+ # Fallback method: Only for "torch" and versions containing "dev"
+ if pkg_name == "torch":
+ try:
+ package = importlib.import_module(pkg_name)
+ temp_version = getattr(package, "__version__", "N/A")
+ # Check if the version contains "dev"
+ if "dev" in temp_version:
+ package_version = temp_version
+ package_exists = True
+ else:
+ package_exists = False
+ except ImportError:
+ # If the package can't be imported, it's not available
+ package_exists = False
+ else:
+ # For packages other than "torch", don't attempt the fallback and set as not available
+ package_exists = False
+ if return_version:
+ return package_exists, package_version
+ else:
+ return package_exists
+
+# https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py#L281
+@lru_cache
+def is_torch_cuda_available():
+ return torch.cuda.is_available()
+
+@lru_cache
+def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
+ """
+ Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
+ the USE_TORCH_XLA to false.
+ """
+ assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
+ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
+ if not _torch_xla_available:
+ return False
+
+ import torch_xla
+
+ if check_is_gpu:
+ return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
+ elif check_is_tpu:
+ return torch_xla.runtime.device_type() == "TPU"
+
+ return True
+
+def get_available_backend():
+ if is_torch_cuda_available():
+ backend = "cuda"
+ elif is_torch_xla_available():
+ backend = "xla"
+ else:
+ backend = "cpu"
+ return backend
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ builtins.HAS_XLA_SPAWNED = False
+
+def use_dist():
+ return dist.is_available() and dist.is_initialized()
+
+def get_device():
+ return torch.device(f"cuda:{get_rank()}")
+
+def get_tpu_devices():
+ import torch_xla.core.xla_model as xm
+ return xm.get_xla_supported_devices()
+
+def get_num_nodes():
+ if is_torch_cuda_available():
+ rprint(f"Warning: get_num_nodes() is not supported for CUDA. Returning world size // num devices.")
+ return get_world_size() // get_num_devices()
+ if is_torch_xla_available():
+ from torch_xla._internal import tpu
+ return tpu.num_tpu_workers()
+
+def get_num_devices():
+ """
+ Number of physical "devices" on a single node. In CUDA, this is the number of GPUs. In XLA, this is the number of TPUs *chips* so, e.g., a v4 slice will always return 4 (even if part of a larger slice).
+ """
+ if is_torch_cuda_available():
+ return torch.cuda.device_count()
+ else:
+ return 1
+
+def get_world_size():
+ if is_torch_xla_available():
+ import torch_xla.runtime as xr
+ return xr.world_size()
+ elif use_dist():
+ return dist.get_world_size()
+ elif 'WORLD_SIZE' in os.environ:
+ return int(os.environ['WORLD_SIZE'])
+ elif is_torch_cuda_available():
+ return torch.cuda.device_count()
+ else:
+ return 1
+
+@lru_cache
+def get_xla_rank():
+ # When using spmd, these return 0 regardless of node [e.g., essentially return the local rank]
+ # import torch_xla.core.xla_model as xm
+ # return xm.get_ordinal()
+ # from accelerate import PartialState
+ # return PartialState().local_process_index
+ from torch_xla._internal import tpu
+ task_id = tpu.task_id() # Num chips
+ worker_id = tpu.worker_id() # Num workers [e.g., which node]
+ return task_id if task_id is not None else worker_id
+
+def get_rank(check_for_group: bool = False):
+ """ Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise it returns 0.
+ """
+ if os.environ.get("FORCE_MAIN_RANK", "0") == "1":
+ return 0
+ elif is_torch_xla_available():
+ if builtins.HAS_XLA_SPAWNED:
+ return get_xla_rank()
+ else:
+ return 0
+ elif use_dist():
+ return dist.get_rank()
+ elif (rank := os.environ.get("RANK", None)) is not None:
+ return int(rank) # RANK is set by torch.distributed.launch
+ elif (rank := os.environ.get("SLURM_PROCID", None)) is not None:
+ return int(rank) # SLURM_PROCID is set by SLURM
+ elif check_for_group:
+ # if neither pytorch, SLURM env vars are set
+ # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars
+ # assume global_rank is zero if undefined
+ node_rank = os.environ.get("NODE_RANK", os.environ.get("GROUP_RANK", 0))
+ local_rank = os.environ.get("LOCAL_RANK", 0)
+ return 0 if (int(node_rank) == 0 and int(local_rank) == 0) else 1
+ else:
+ return 0
+
+def is_main_process():
+ return get_rank() == 0
+
+def get_local_rank():
+ if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ return xm.get_local_ordinal()
+ else:
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+def is_local_main_process():
+ return get_local_rank() == 0
+
+def get_num_gpus() -> int:
+ return get_world_size()
+
+def rank_zero_fn(fn):
+ @wraps(fn)
+ def wrapped_fn(*args: Any, **kwargs: Any):
+ if is_main_process():
+ return fn(*args, **kwargs)
+ return None
+
+ return wrapped_fn
+
+def barrier():
+ if use_dist() or getattr(builtins, "HAS_XLA_SPAWNED", False):
+ frame = inspect.currentframe().f_back
+ filename = frame.f_code.co_filename
+ lineno = frame.f_lineno
+ code_context = inspect.getframeinfo(frame).code_context[0].strip()
+ debug_log_func(f"Before barrier - {filename}:{lineno} - {code_context}")
+ if is_torch_xla_available() and getattr(builtins, "HAS_XLA_SPAWNED", False):
+ import torch_xla.core.xla_model as xm
+ xm.rendezvous('barrier')
+ else:
+ torch.distributed.barrier()
+ debug_log_func(f"After barrier - {filename}:{lineno} - {code_context}")
+
+def get_hostname():
+ return __import__('socket').gethostname().removesuffix('.eth')
+
+def get_slurm_job_id():
+ if os.environ.get("SLURM_ARRAY_JOB_ID", None) is not None and os.environ.get("SLURM_ARRAY_TASK_ID", None) is not None:
+ job_str = f"{os.environ.get('SLURM_ARRAY_JOB_ID')}_{os.environ.get('SLURM_ARRAY_TASK_ID')}"
+ elif os.environ.get("SLURM_JOB_ID", None) is not None:
+ job_str = os.environ.get("SLURM_JOB_ID")
+ else:
+ job_str = None
+ return job_str
+
+def get_restart_str():
+ restart_str = f"r{os.environ.get('SLURM_RESTART_COUNT', '0')}_" if os.environ.get('SLURM_RESTART_COUNT', None) is not None else ""
+ if "TORCHELASTIC_RESTART_COUNT" in os.environ and os.environ.get("TORCHELASTIC_RESTART_COUNT", '0') != '0':
+ restart_str = f"r{os.environ.get('TORCHELASTIC_RESTART_COUNT', '0')}_"
+ return restart_str
+
+def get_slurm_filename_info():
+ job_str = f"{get_slurm_job_id()}__" if get_slurm_job_id() is not None else ""
+ restart_str = get_restart_str()
+ if "SLURM_NODEID" in os.environ:
+ node_str = f"{os.environ.get('SLURM_NODEID', '')}_"
+ elif "SLURM_PROCID" in os.environ:
+ node_str = f"{os.environ.get('SLURM_PROCID', '')}_"
+ else:
+ node_str = ""
+
+ return f"{job_str}{restart_str}{node_str}{get_rank()}"
+
+def get_slurm_log_prefix():
+ if "SLURM_NODEID" in os.environ:
+ slurm_nodestr = f", Node:{os.environ.get('SLURM_NODEID', 'N/A')}"
+ elif "SLURM_PROCID" in os.environ:
+ slurm_nodestr = f", Node:{os.environ.get('SLURM_PROCID', 'N/A')}"
+ else:
+ slurm_nodestr = ""
+
+ jobid = get_slurm_job_id()
+ jobid_str = f", JID:{jobid}" if jobid is not None else ""
+ restart_str = f", {get_restart_str()}" if get_restart_str() != "" else ""
+ return f"Rank:{get_rank()}{slurm_nodestr}{jobid_str}{restart_str}"
+
+def slurm_prefix_func():
+ timestamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
+ return f"{timestamp} [{get_slurm_log_prefix()}]"
+
+try:
+ from unidisc.utils.logging_utils import log_info, log_debug, set_logger, log_memory
+ set_logger(__name__)
+ info_log_func = log_info
+ debug_log_func = log_debug
+ prefix_func = None
+except Exception as e:
+ print(e)
+ def simple_print_func(*args, **kwargs):
+ print(*args)
+ info_log_func = simple_print_func
+ debug_log_func = simple_print_func
+ prefix_func = slurm_prefix_func
+
+def gprint(*args, main_process_only=False, **kwargs):
+ """
+ Prints to console + log as INFO on regardless of rank.
+ """
+ if prefix_func is not None:
+ args = (prefix_func(), *args)
+ info_log_func(*args, main_process_only=main_process_only, **kwargs)
+
+def dprint(*args, **kwargs):
+ """
+ Prints to console + log as DEBUG on regardless of rank.
+ """
+ if prefix_func is not None:
+ args = (prefix_func(), *args)
+ debug_log_func(*args, **kwargs)
+
+def rprint(*args, **kwargs):
+ """
+ Prints to console + log as INFO on main process only. All ranks also print to log file [as DEBUG] if called, but this is not required [e.g., no barrier used]
+ """
+ gprint(*args, main_process_only=True, **kwargs)
+
+def mprint(*args, **kwargs):
+ log_memory(*args, **kwargs)
+ # dprint(*args, **kwargs)
+
+log_func = rprint
+
+def process_file_prefix():
+ datetime_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S_%f")[:-5]
+ return f"{get_slurm_filename_info()}_{get_hostname()}_{datetime_str}"
+
+def get_info():
+ return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE).stdout.decode("utf-8")
+
+def print_trainable_parameters(model):
+ """
+ Prints the number of trainable parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
+ if param.requires_grad:
+ trainable_params += param.numel()
+ log_func(
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
+ )
+
+def print_params(model):
+ log_func(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
+ log_func(f"Unfrozen Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
+ log_func(f"Frozen Parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}")
+
+
+def calculate_storage_size(obj, storage_view_sizes, count_views=False):
+ if isinstance(obj, torch.Tensor):
+ storage = obj.storage()
+ storage_id = id(storage)
+ element_size = storage.element_size()
+ storage_size = storage.size() * element_size
+ view_size = obj.numel() * element_size
+
+ # We count storage size only for the first time we encounter the storage
+ if storage_id not in storage_view_sizes:
+ storage_view_sizes[storage_id] = storage_size
+ print_size = storage_size
+ else:
+ print_size = 0 if not count_views or not obj._is_view() else view_size
+
+ if count_views or not obj._is_view():
+ log_func(f"{'View' if obj._is_view() else 'Storage'} Tensor: " f"shape {obj.size()}, size {print_size / (1024 ** 2):.2f} MB")
+
+ return print_size if count_views or not obj._is_view() else 0 # Count views only if requested
+ elif isinstance(obj, dict):
+ # Recurse for dictionaries
+ return sum(calculate_storage_size(v, storage_view_sizes, count_views) for v in obj.values())
+ elif isinstance(obj, (list, tuple)):
+ # Recurse for lists or tuples
+ return sum(calculate_storage_size(item, storage_view_sizes, count_views) for item in obj)
+ elif hasattr(obj, "__dataclass_fields__"):
+ # Recurse for dataclasses based on their fields
+ fields = getattr(obj, "__dataclass_fields__")
+ return sum(calculate_storage_size(getattr(obj, f), storage_view_sizes, count_views) for f in fields)
+ else:
+ # Non-Tensor, non-dict, non-list objects are not measured
+ return 0
+
+
+def calculate_total_size(obj, count_views=False):
+ storage_view_sizes = defaultdict(int)
+ total_size = calculate_storage_size(obj, storage_view_sizes, count_views)
+ total_unique_storage_size = sum(storage_view_sizes.values())
+ log_func(f"Total unique storage size: {total_unique_storage_size / (1024 ** 2):.2f} MB")
+ if count_views: # Only add view sizes to total if requested
+ total_view_size = total_size - total_unique_storage_size
+ log_func(f"Total view size (if counted): {total_view_size / (1024 ** 2):.2f} MB")
+ else:
+ log_func(f"Total size (without counting views): {total_size / (1024 ** 2):.2f} MB")
+
+ return total_size
+
+
+def save_tensor_dict(tensor_dict: dict, path):
+ output_dict = {}
+ for k, v in tensor_dict.items():
+ if isinstance(v, Tensor):
+ if v.dtype == torch.float16 or v.dtype == torch.bfloat16:
+ output_dict[k] = v.to(dtype=torch.float32).detach().cpu().numpy()
+ else:
+ output_dict[k] = v.detach().cpu().numpy()
+ elif isinstance(v, np.ndarray):
+ output_dict[f"np_{k}"] = v
+ else:
+ output_dict[k] = v
+ np.savez_compressed(path, **output_dict)
+
+
+def load_tensor_dict(path: Path, object_keys=[]):
+ from jaxtyping import BFloat16 # TODO: Remove dependency
+ tensor_dict = {}
+ np_dict = np.load(path, allow_pickle=True)
+ for k, v in np_dict.items():
+ if k in object_keys:
+ tensor_dict[k] = v
+ elif v.dtype == BFloat16:
+ tensor_dict[k] = torch.from_numpy(v.astype(np.float32)).to(dtype=torch.bfloat16)
+ elif k.startswith("np_"):
+ tensor_dict[k.removeprefix("np_")] = v
+ else:
+ tensor_dict[k] = torch.from_numpy(v)
+ return tensor_dict
+
+
+def tensor_hash(tensor):
+ """Computes a SHA256 hash of a tensor. Useful for debugging to check equality in different places."""
+ tensor_bytes = tensor.detach().float().cpu().numpy().tobytes()
+ return hashlib.sha256(tensor_bytes).hexdigest()
+
+def module_hash(module: Optional[dict] = None, state_dict: Optional[dict] = None):
+ assert module is not None or state_dict is not None
+ state_dict = module.state_dict() if module is not None else state_dict
+ sorted_state_dict = {k: state_dict[k] for k in sorted(state_dict)}
+ params_cat = torch.cat([v.flatten() for _, v in sorted_state_dict.items()])
+ return tensor_hash(params_cat)
+
+def parameter_hash(params: list[torch.Tensor]):
+ return tensor_hash(torch.cat([p.cpu().flatten() for p in params]))
+
+def find_diff_params(state_dict_1, state_dict_2):
+ diff_keys = set(state_dict_1.keys()) ^ set(state_dict_2.keys()) # Symmetric difference to find keys not in both
+ matched_keys = set(state_dict_1.keys()) & set(state_dict_2.keys()) # Intersection to find keys in both
+
+ # Check for differences in matched keys
+ for key in matched_keys:
+ if not torch.equal(state_dict_1[key], state_dict_2[key]):
+ diff_keys.add(key)
+
+ return diff_keys
+
+
+def init_from_ckpt(module, path, ignore_keys=None, unfrozen_keys=None, strict=False, truncate=None, only_incl=None, verbose=True):
+ log_func(f"Loading {module.__class__.__name__} from checkpoint: {path}")
+ log_func(f"Strict Load: {strict}, Ignoring: {ignore_keys}, Unfreezing: {unfrozen_keys}, Truncating: {truncate}")
+
+ if ignore_keys is None:
+ ignore_keys = ()
+
+ if unfrozen_keys is None:
+ unfrozen_keys = ()
+
+ sd = torch.load(path, map_location="cpu")
+
+ if "state_dict" in sd.keys():
+ sd = sd["state_dict"]
+ elif "weight" in sd.keys():
+ sd = sd["weight"]
+
+ num_deleted = defaultdict(int)
+ for k in list(sd):
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ num_deleted[ik] += 1
+ del sd[k]
+
+ for k, v in num_deleted.items():
+ log_func(f"Deleted {v} keys due to ignore_key: {k}")
+
+ if truncate is not None:
+ for k in list(sd):
+ if k.startswith(truncate):
+ sd[k.replace(truncate, "")] = sd[k]
+ del sd[k]
+
+ num_ignored = defaultdict(int)
+ for n in module.state_dict().keys():
+ if n not in sd.keys():
+ for ik in ignore_keys:
+ if ik in n:
+ num_ignored[ik] += 1
+ else:
+ log_func(f"Missing {n}")
+
+ if only_incl is not None:
+ for k in list(sd):
+ keep = False
+ for ik in only_incl:
+ if ik in k:
+ keep = True
+ if not keep:
+ del sd[k]
+
+ for k, v in num_ignored.items():
+ log_func(f"Missing {v} keys due to ignore_key: {k}")
+
+ for n in sd.keys():
+ if n not in module.state_dict().keys():
+ log_func(f"Unexpected {n}")
+
+ checkpoint_keys = set(sd.keys())
+ current_keys = set(module.state_dict().keys())
+
+ if verbose:
+ log_func(f"Loading: {checkpoint_keys.intersection(current_keys)}")
+ else:
+ log_func(f"Loading {len(checkpoint_keys.intersection(current_keys))} keys into the model: {str(module.__class__)}")
+
+ module.load_state_dict(sd, strict=strict)
+
+ if len(unfrozen_keys) > 0:
+ for n, p in module.named_parameters():
+ p.requires_grad_ = False
+ for unfrozen_name in unfrozen_keys:
+ if unfrozen_name in n:
+ p.requires_grad_ = True
+ log_func(f"Unfreezing: {n}")
+
+ log_func(f"Restored from {path}")
+
+
+def check_gpu_memory_usage():
+ allocated = torch.cuda.memory_allocated()
+ reserved = torch.cuda.memory_reserved()
+ total_memory = torch.cuda.get_device_properties(int(get_local_rank())).total_memory
+
+ allocated_percent = (allocated / total_memory) * 100
+ reserved_percent = (reserved / total_memory) * 100
+
+ log_func(f"Allocated memory: {allocated_percent:.2f}%")
+ log_func(f"Reserved memory: {reserved_percent:.2f}%")
+ log_func(f'Available devices (CUDA_VISIBLE_DEVICES): {os.environ.get("CUDA_VISIBLE_DEVICES")}')
+
+ assert allocated_percent <= 5
+ assert reserved_percent <= 5
+
+
+def load_checkpoint_from_url(url: str, file_path: Optional[str] = None) -> Path:
+ if file_path is None:
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_path is not None:
+ filename = file_path
+
+ file_path = Path.home() / ".cache" / "pretrained_weights" / filename
+
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ if not os.path.exists(file_path):
+ log_func(f'Downloading: "{url}" to {file_path}\n')
+ torch.hub.download_url_to_file(url, file_path, progress=True)
+
+ return file_path
+
+
+# Copied from torch.profiler.profiler
+def tensorboard_trace_handler(dir_name: str, record_memory: bool = False, worker_name: Optional[str] = None, use_gzip: bool = True):
+ """
+ Outputs tracing files to directory of ``dir_name``, then that directory can be
+ directly delivered to tensorboard as logdir.
+ ``worker_name`` should be unique for each worker in distributed scenario,
+ it will be set to '[hostname]_[pid]' by default.
+ """
+ import os
+ import socket
+ import time
+
+ def handler_fn(prof: torch.profiler.profile) -> None:
+ nonlocal worker_name
+ if not os.path.isdir(dir_name):
+ try:
+ os.makedirs(dir_name, exist_ok=True)
+ except Exception as e:
+ raise RuntimeError("Can't create directory: " + dir_name) from e
+ if not worker_name:
+ worker_name = f"{socket.gethostname()}_{os.getpid()}"
+ # Use nanosecond here to avoid naming clash when exporting the trace
+ file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
+ if use_gzip:
+ file_name = file_name + ".gz"
+
+ chrome_trace_path = os.path.join(dir_name, file_name)
+ memory_trace_path = os.path.join(dir_name, "memory_timeline.html")
+ if is_main_process() and record_memory:
+ try:
+ log_func(f"Exporting memory timeline: {memory_trace_path}")
+ prof.export_memory_timeline(memory_trace_path)
+ except Exception as e:
+ log_func(f"Failed to export memory timeline: {e}")
+
+ prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=100)
+
+ log_func(f"Exporting chrome trace to {chrome_trace_path}")
+ prof.export_chrome_trace(chrome_trace_path)
+
+ return handler_fn
+
+def get_date_time_str():
+ return datetime.now().strftime("%Y_%m_%d-%H_%M_%S.%f")[:-3]
+
+def save_memory_profile(profile_dir):
+ import wandb
+ rank_postfix = f"_rank_{get_rank()}" if use_dist() else ""
+ log_func(f"Saving memory profile to {profile_dir}")
+ os.makedirs(profile_dir, exist_ok=True)
+ torch.cuda.memory._dump_snapshot(f"{profile_dir}/memory_snapshot{rank_postfix}.pickle")
+ os.system(
+ f"python -m torch.cuda._memory_viz trace_plot {profile_dir}/memory_snapshot{rank_postfix}.pickle -o {profile_dir}/memory_snapshot{rank_postfix}.html"
+ )
+ torch.cuda.memory._save_segment_usage(f"{profile_dir}/segment{rank_postfix}.svg")
+ torch.cuda.memory._save_memory_usage(f"{profile_dir}/memory{rank_postfix}.svg")
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ log_func(f"Saved memory snapshot at: {profile_dir}/memory_snapshot{rank_postfix}.pickle")
+ log_func(f"Run the following to view the snapshot:\npython -m http.server --directory {profile_dir.resolve()} 6008")
+
+ if is_main_process() and wandb.run is not None:
+ wandb.log({'profile': wandb.Html(f"{profile_dir}/memory_snapshot{rank_postfix}.html")})
+ wandb.log({'profile': wandb.Html(f"{profile_dir}/memory_timeline{rank_postfix}.html")})
+
+def print_memory(verbose: bool = False, print_output: bool = True):
+ max_cur_reserved, max_peak_reserved, max_peak_allocated, max_cur_allocated = -1, -1, -1, -1
+ max_cur_reserved_device, max_peak_reserved_device, max_peak_allocated_device, max_cur_allocated_device = -1, -1, -1, -1
+ for device in range(torch.cuda.device_count()):
+ current_reserved_memory_MB = torch.cuda.memory_reserved(device=torch.device(f'cuda:{device}')) / (1024**2)
+ peak_reserved_memory_MB = torch.cuda.max_memory_reserved(device=torch.device(f'cuda:{device}')) / (1024**2)
+ peak_allocated_memory_MB = torch.cuda.max_memory_allocated(device=torch.device(f'cuda:{device}')) / (1024**2)
+ current_allocated_memory_MB = torch.cuda.memory_allocated(device=torch.device(f'cuda:{device}')) / (1024**2)
+
+ if current_reserved_memory_MB > max_cur_reserved:
+ max_cur_reserved = current_reserved_memory_MB
+ max_cur_reserved_device = device
+
+ if peak_reserved_memory_MB > max_peak_reserved:
+ max_peak_reserved = peak_reserved_memory_MB
+ max_peak_reserved_device = device
+
+ if peak_allocated_memory_MB > max_peak_allocated:
+ max_peak_allocated = peak_allocated_memory_MB
+ max_peak_allocated_device = device
+
+ if current_allocated_memory_MB > max_cur_allocated:
+ max_cur_allocated = current_allocated_memory_MB
+ max_cur_allocated_device = device
+
+ if verbose:
+ log_func(torch.cuda.memory_summary(abbreviated=False))
+
+ if print_output:
+ log_func(f"GPU Cur Reserved: {max_cur_reserved:.2f}MB on rank {max_cur_reserved_device}, Cur Allocated: {max_cur_allocated:.2f}MB on rank {max_cur_allocated_device}, Peak Reserved: {max_peak_reserved:.2f}MB on rank {max_peak_reserved_device}, Peak Allocated: {max_peak_allocated:.2f}MB on rank {max_peak_allocated_device}")
+
+ return max_cur_reserved
+
+def print_memory_summary():
+ val = print_memory(verbose=False, print_output=False)
+ log_func(f"GPU Cur Reserved: {val:.2f}MB, {val / (torch.cuda.get_device_properties(0).total_memory / 1024**2) * 100:.2f}%")
+
+@contextlib.contextmanager
+def show_memory_usage(empty_cache: bool = True, verbose: bool = False, print_output=False, show_caller: bool = True):
+ synchronize_device()
+ if empty_cache: clear_cache()
+
+ if show_caller:
+ callers = [
+ f"{frame.function} (...{frame.filename[-10:]}:{frame.lineno})"
+ for frame in inspect.stack()
+ if "contextlib" not in frame.filename
+ ][1:5]
+
+ decorated_func_name = ", ".join([inspect.currentframe().f_back.f_code.co_name, inspect.currentframe().f_back.f_back.f_code.co_name, inspect.currentframe().f_back.f_back.f_back.f_code.co_name])
+ caller_info = decorated_func_name + " " + " -> ".join(reversed(callers))
+ log_func(f"Before context (called by {caller_info}): {print_memory(verbose, print_output=False)}MB cur reserved")
+ caller_str = f", called by {caller_info} "
+ else:
+ caller_str = ""
+
+ yield
+
+ synchronize_device()
+ if empty_cache: clear_cache()
+ log_func(f"After context{caller_str}: {print_memory(verbose, print_output=print_output)}MB cur reserved")
+
+@contextlib.contextmanager
+def profile_memory(enable: bool = True, empty_cache: bool = True):
+ with contextlib.ExitStack() as stack:
+ stack.enter_context(show_memory_usage(empty_cache=empty_cache))
+ if enable and is_main_process(): torch.cuda.memory._record_memory_history()
+ yield
+ if enable: save_memory_profile(Path(f"output/profile/{get_date_time_str()}"))
+
+def profile_memory_decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ with profile_memory():
+ return func(*args, **kwargs)
+ return wrapper
+
+class Profiler:
+ def __init__(self, output_dir, warmup_steps: int = 5, active_steps: int = 3, record_memory: bool = False, record_memory_only: bool = False):
+ self.record_memory = record_memory
+ self.profile_dir = Path(output_dir) / "profile"
+ self.profile_dir.mkdir(parents=True, exist_ok=True)
+ wait, warmup, active, repeat = 0, warmup_steps, active_steps, 0
+ self.total_steps = (wait + warmup + active) * (1 + repeat)
+ schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
+ profiler_kwargs = dict(record_shapes=True, with_stack=True)
+ if record_memory_only:
+ pass
+ else:
+ profiler_kwargs.update(
+ dict(
+ with_modules=True,
+ with_flops=True,
+ ),
+ )
+
+ self.profiler = torch.profiler.profile(
+ schedule=schedule,
+ on_trace_ready=tensorboard_trace_handler(self.profile_dir, record_memory=record_memory),
+ profile_memory=record_memory,
+ **profiler_kwargs
+ )
+ self.profiler.start()
+
+ def step(self, global_step: int):
+ self.profiler.step()
+ return global_step >= (self.total_steps - 1)
+
+ def finish(self):
+ self.profiler.stop()
+ if use_dist():
+ torch.distributed.barrier()
+ if is_main_process():
+ import wandb
+ traces = glob.glob(f"{self.profile_dir}/*.pt.trace.json*")
+ for trace in traces:
+ log_func(f"Adding {trace}")
+ wandb.save(trace, base_path=self.profile_dir, policy="now")
+
+ if use_dist():
+ torch.distributed.barrier()
+
+
+def get_pdb():
+ return import_module("pdb") if (any(["_pdbpp_path_hack" in str(p) for p in sys.path]) or find_spec("ipdb") is None) else import_module("ipdb")
+
+def _breakpoint(rank: Optional[int] = None, traceback: Optional[Any] = None):
+ if get_num_gpus() > 1:
+ if (is_main_process() if rank is None else get_rank() == rank):
+ if is_torch_xla_available():
+ from fairseq import pdb
+ pdb.set_trace()
+ else:
+ old_stdin = None
+ if isinstance(sys.stdin, io.TextIOWrapper):
+ old_stdin = sys.stdin
+ sys.stdin = open(0)
+ try:
+ log_func('Breakpoint triggered. You may need to type "up" to get to the correct frame')
+ log_func(f'Traceback: {traceback}')
+ if traceback is not None:
+ get_pdb().post_mortem(traceback)
+ else:
+ get_pdb().set_trace()
+ finally:
+ if old_stdin is not None:
+ sys.stdin.close()
+ sys.stdin = old_stdin
+ barrier()
+ else:
+ if traceback is not None:
+ get_pdb().post_mortem(traceback)
+ else:
+ log_func("Breakpoint triggered. You may need to type \"up\" to get to the correct frame")
+ get_pdb().set_trace(sys._getframe(1))
+
+def set_global_exists():
+ builtins.exists = lambda v: v is not None
+
+def set_global_breakpoint():
+ import ipdb
+
+ builtins.breakpoint = _breakpoint
+ builtins.st = ipdb.set_trace # We import st everywhere
+ builtins.ug = lambda: globals().update(locals())
+
+def set_timing_builtins(enable: bool = False, sync: bool = False):
+ builtins.start_timing = partial(start_timing, builtin=True, enable=False, sync=False)
+ builtins.end_timing = partial(end_timing, builtin=True, enable=False, sync=False)
+ builtins.ENABLE_TIMING = enable
+ builtins.ENABLE_TIMING_SYNC = sync
+
+def synchronize_device():
+ if is_torch_cuda_available():
+ torch.cuda.synchronize()
+ elif is_torch_xla_available():
+ import torch_xla
+ torch_xla.sync()
+ xm.wait_device_ops()
+
+def clear_cache():
+ if is_torch_cuda_available():
+ # This caused untold grief. Without calling `_cuda_clearCublasWorkspaces`,
+ # some model configurations would eventually cause a CUDA OOM during inference. See:
+ # https://github.com/pytorch/pytorch/issues/99835, https://github.com/pytorch/pytorch/issues/105181
+ torch._C._cuda_clearCublasWorkspaces()
+ torch._dynamo.reset()
+ import gc; gc.collect()
+ torch.cuda.empty_cache()
+ elif is_torch_xla_available():
+ rprint("Clearing cache not supported for XLA")
+
+def start_timing(message: str, enable: bool = False, sync: bool = False, builtin: bool = False):
+ if (builtin and ENABLE_TIMING) or enable:
+ if (builtin and ENABLE_TIMING_SYNC) or sync:
+ synchronize_device()
+ torch.cuda.nvtx.range_push(f"[SYNC] {message}" if ((builtin and ENABLE_TIMING_SYNC) or sync) else message)
+
+ return time.time()
+
+def end_timing(start_time: Optional[float] = None, enable: bool = False, sync: bool = False, builtin: bool = False):
+ if (builtin and ENABLE_TIMING) or enable:
+ if (builtin and ENABLE_TIMING_SYNC) or sync:
+ synchronize_device()
+ torch.cuda.nvtx.range_pop()
+
+ if start_time is not None:
+ return time.time() - start_time
+
+@contextlib.contextmanager
+def breakpoint_on_error():
+ set_global_breakpoint()
+ try:
+ yield
+ except Exception as e:
+ print("Exception...", e)
+ traceback.print_exc()
+ breakpoint(traceback=e.__traceback__)
+ raise e
+
+@contextlib.contextmanager
+def get_time_sync(enable: bool = True):
+ if enable and is_main_process():
+ synchronize_device()
+ start_time = time.time()
+ yield
+ if enable and is_main_process():
+ synchronize_device()
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time:.2f}s")
+
+def write_to_file(path: Path, text: str):
+ try:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "a") as file:
+ file.write(text + "\n")
+ except:
+ log_func(f"Could not write to {path}")
+
+
+def to(obj, device):
+ if torch.is_tensor(obj):
+ return obj.to(device)
+ if isinstance(obj, dict):
+ return {k : to(v, device) for k, v in obj.items()}
+ if isinstance(obj, tuple):
+ return tuple(to(v, device) for v in obj)
+ if isinstance(obj, list):
+ return [to(v, device) for v in obj]
+ return obj
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_num_gpus()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(to(pickle.loads(buffer), torch.device('cpu')))
+
+ return data_list
+
+def get_modules(model: torch.nn.Module, cls: Any):
+ children = list(model.children())
+ if isinstance(model, cls):
+ return [model]
+ elif len(children) == 0:
+ return []
+ else:
+ return [ci for c in children for ci in get_modules(model=c, cls=cls)]
+
+map_chars = {
+ "/" : "__",
+ " " : "_",
+}
+
+def sanitize_filename(filename: str) -> str:
+ return "".join(map_chars.get(c, c) for c in filename if c.isalnum() or map_chars.get(c, c) in (" ", ".", "_", "-", "__"))
+
+
+def hash_str_as_int(s: str):
+ return int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16) % 10**8
+
+
+def torch_to_numpy(arr: Tensor):
+ if arr.dtype == torch.bfloat16:
+ return arr.float().cpu().detach().numpy()
+ else:
+ return arr.cpu().detach().numpy()
+
+def to_numpy(arr):
+ if isinstance(arr, Tensor):
+ return torch_to_numpy(arr)
+ else:
+ return arr
+
+class try_except:
+ def __init__(self, raise_error: bool = False, write_error_to_file: bool = False, clear_cuda_cache: bool = False):
+ self.raise_error = raise_error
+ self.write_error_to_file = write_error_to_file
+ self.clear_cuda_cache = clear_cuda_cache
+
+ def __call__(self, f):
+ @functools.wraps(f)
+ def inner(*args, **kwargs):
+ with self:
+ return f(*args, **kwargs)
+ return inner
+
+ def __enter__(self):
+ if self.clear_cuda_cache:
+ clear_cache()
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is not None:
+ try:
+ log_func(f"Exception caught: {exc_type}")
+ log_func(traceback.format_exc())
+ if self.write_error_to_file:
+ timestamp = int(time.time_ns())
+ with open(f"exception_{timestamp}_{process_file_prefix()}.out", "w") as file:
+ file.write(traceback.format_exc())
+ except Exception as e:
+ print(f"Error writing to file: {e}")
+
+ if self.clear_cuda_cache:
+ clear_cache()
+
+ if self.raise_error:
+ raise exc_value
+ return True # Suppress the exception if raise_error is False
+
+def move_to(obj, device):
+ if torch.is_tensor(obj):
+ return obj.to(device=device)
+ elif isinstance(obj, dict):
+ res = {}
+ for k, v in obj.items():
+ res[k] = move_to(v, device)
+ return res
+ elif isinstance(obj, list):
+ res = []
+ for v in obj:
+ res.append(move_to(v, device))
+ return res
+ elif isinstance(obj, str):
+ return obj
+ else:
+ raise TypeError("Invalid type for move_to")
+
+
+
+import types
+def run_with_named_function(name, func, *args, **kwargs):
+ """
+ Runs the given function inside a dynamically created function with the specified name.
+ This causes the specified name to appear in the stack trace.
+
+ E.g., x = run_with_named_function(f"iter_{i}", self._ddpm_update, x, t, dt, x0=x0, x0_unmask=x0_unmask, **kwargs)
+
+ Parameters:
+ - name (str): The desired name to appear in the stack trace.
+ - func (callable): The function to execute.
+ - *args: Positional arguments to pass to func.
+ - **kwargs: Keyword arguments to pass to func.
+ """
+ def wrapper(*args, **kwargs):
+ return func(*args, **kwargs)
+ code = wrapper.__code__
+ new_code = types.CodeType(
+ code.co_argcount,
+ code.co_posonlyargcount if hasattr(code, "co_posonlyargcount") else 0,
+ code.co_kwonlyargcount,
+ code.co_nlocals,
+ code.co_stacksize,
+ code.co_flags,
+ code.co_code,
+ code.co_consts,
+ code.co_names,
+ code.co_varnames,
+ code.co_filename,
+ name, # Set the function name in the code object
+ code.co_firstlineno,
+ code.co_lnotab,
+ code.co_freevars,
+ code.co_cellvars
+ )
+ new_func = types.FunctionType(
+ new_code,
+ wrapper.__globals__,
+ name,
+ wrapper.__defaults__,
+ wrapper.__closure__,
+ )
+ return new_func(*args, **kwargs)
\ No newline at end of file
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..193f17edf5b8070b2714bccb5ad1afe0f733d07b
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,6 @@
+title: UniDisc Demo
+emoji: 🐢
+colorFrom: purple
+colorTo: blue
+sdk: docker
+pinned: false
\ No newline at end of file
diff --git a/demo/api_data_defs.py b/demo/api_data_defs.py
new file mode 100644
index 0000000000000000000000000000000000000000..639fa97966e874c2aa8f1ee0eab20758bf85b923
--- /dev/null
+++ b/demo/api_data_defs.py
@@ -0,0 +1,30 @@
+from typing import Any, Dict, List, Optional, Union
+from fastapi import FastAPI
+from pydantic import BaseModel
+from PIL import Image
+
+class ContentPart(BaseModel):
+ model_config = {"arbitrary_types_allowed": True}
+ type: str # "text" or "image_url"
+ text: Union[str, None] = None
+ image_url: Union[Dict[str, str], Image.Image, None] = None
+ is_mask: bool = False
+
+class ChatMessage(BaseModel):
+ role: str
+ content: List[ContentPart]
+
+class ChatRequest(BaseModel):
+ messages: List[ChatMessage]
+ model: str = "unidisc"
+ max_tokens: int = 1024
+ temperature: float = 0.9
+ top_p: float = 0.95
+ unmask_to_eos: bool = False # Controls masking behavior between BOS and EOS tokens
+ resolution: int = 256 # New: resolution for image (default: 256)
+ sampling_steps: int = 35 # New: number of sampling steps (default: 35)
+ maskgit_r_temp: float = 4.5 # new parameter default
+ cfg: float = 3.5 # new parameter default
+ sampler: str = "maskgit" # new parameter default
+ use_reward_models: bool = False
+ request_hash: Optional[str] = None
\ No newline at end of file
diff --git a/demo/client.py b/demo/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3bfbc4f41099a1597f5fc78bf0e4e95221421f
--- /dev/null
+++ b/demo/client.py
@@ -0,0 +1,656 @@
+from fasthtml.common import *
+from fasthtml.svg import *
+from monsterui.all import *
+from pathlib import Path
+import requests
+import base64
+from PIL import Image
+import numpy as np
+import io
+import json
+
+SHOW_DEV_BUTTONS = True
+DEMO_DIR = Path("demo")
+ADD_DEV_FORM = True
+
+DEMOS = [
+ {
+ "name": "Dog",
+ "image": DEMO_DIR / "assets" / "dog.jpg",
+ "mask": DEMO_DIR / "assets" / "dog.json",
+ "text": "A brown bulldog.",
+ },
+ {
+ "name": "Pickup Truck",
+ "image": DEMO_DIR / "assets" / "pickup.jpg",
+ "mask": DEMO_DIR / "assets" / "pickup.json",
+ "text": "A pickup truck.",
+ },
+ {
+ "name": "Taj Mahal",
+ "image": DEMO_DIR / "assets" / "tajmahal.jpg",
+ "mask": DEMO_DIR / "assets" / "tajmahal.json",
+ "text": "The in .",
+ },
+ {
+ "name": "Venice",
+ "image": DEMO_DIR / "assets" / "venice.jpg",
+ "mask": DEMO_DIR / "assets" / "venice.json",
+ "text": "A in.",
+ },
+ {
+ "name": "T2I",
+ "text": "A sits at the counter of an art-deco loungebar, drinking whisky from a tumbler glass.",
+ }
+]
+
+# Use MonsterUI's theme headers.
+app, rt = fast_app(hdrs=(Theme.blue.headers(),))
+
+def square_crop(image: Image.Image) -> Image.Image:
+ width, height = image.size
+ side = min(width, height)
+ left = (width - side) // 2
+ top = (height - side) // 2
+ right = left + side
+ bottom = top + side
+ return image.crop((left, top, right, bottom))
+
+def process(image: Image.Image, desired_resolution: int = 512) -> Image.Image:
+ cropped_image = square_crop(image.convert("RGB"))
+ return cropped_image.resize((desired_resolution, desired_resolution), Image.LANCZOS)
+
+def encode_image(file: Path | io.BytesIO | Image.Image) -> Dict[str, str]:
+ if isinstance(file, Image.Image):
+ buffered = io.BytesIO()
+ file.save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ elif isinstance(file, Path):
+ with file.open("rb") as img_file:
+ base64_str = base64.b64encode(img_file.read()).decode("utf-8")
+ else:
+ base64_str = base64.b64encode(file.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+
+def encode_array_image(array: np.ndarray) -> Dict[str, str]:
+ # Handle boolean masks more efficiently
+ if array.dtype == bool:
+ array = array.astype(np.uint8) * 255
+ im = Image.fromarray(array)
+ buffered = io.BytesIO()
+ im.save(buffered, format="JPEG", quality=95)
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+
+def get_boolean_mask(mask_data: str) -> np.ndarray:
+ """Decode compressed mask data from client"""
+ mask_info = json.loads(mask_data)
+ data = base64.b64decode(mask_info['data'])
+ width, height = mask_info['width'], mask_info['height']
+ arr = np.frombuffer(data, dtype=np.uint8)
+ bits = np.unpackbits(arr, bitorder='big')[:width * height]
+ return bits.reshape((height, width)).astype(bool)
+
+def get_input_card_params():
+ return dict(
+ header=Div(H4("Input"), Subtitle("You can mask the image, text, or both.")),
+ id="input-card",
+ title="Input",
+ )
+
+def create_input_card_content(text_content=""):
+ """Create the shared input card content structure."""
+ content = [
+ Div(id="preview-container", cls="relative flex justify-center items-center mb-4 p-4 empty:p-0 empty:mb-0"),
+ TextArea(text_content, name="user_input", id="user-input-text", cls="resize-none h-12 w-full mb-4"),
+ Input(type="file", name="uploaded_file", id="upload-image-input", cls="mb-4"),
+ Input(type="hidden", name="mask_data", id="mask-data")
+ ]
+ return content
+
+@rt("/")
+def get(session):
+ demo_cards = []
+ for demo in DEMOS:
+ if 'image' in demo:
+ demo_image_url = encode_image(process(Image.open(demo['image'])))['url']
+
+ inner_content = Div(
+ Div(
+ Loading(cls="hidden", htmx_indicator=True),
+ id=f"demo-spinner-{DEMOS.index(demo)}",
+ cls="absolute inset-0 flex items-center justify-center"
+ ),
+ Div(
+ Img(src=demo_image_url,
+ cls="w-32 h-32 object-cover rounded-md transition-opacity hover:opacity-60 cursor-pointer mb-3"),
+ cls="demo-image-container relative flex justify-center"
+ ),
+ P(demo['text'],
+ cls="mt-2 text-sm text-muted-foreground group-hover:text-foreground transition-colors text-center"),
+ cls="flex flex-col items-center p-1"
+ )
+
+ demo_card = Card(
+ inner_content,
+ cls="demo-card hover:shadow-md transition-shadow cursor-pointer w-fit mx-auto",
+ title=f"{demo['name']}",
+ hx_post=f"/load_demo/{DEMOS.index(demo)}",
+ hx_target="#input-card",
+ hx_swap="innerHTML",
+ hx_indicator=f"#demo-spinner-{DEMOS.index(demo)}"
+ )
+ demo_cards.append(demo_card)
+
+ js_script = fr"""
+ document.body.addEventListener('htmx:beforeRequest', function(ev) {{
+ const target = ev.detail.elt.querySelector('[hx-indicator]');
+ if(target) target.querySelector('.loading').classList.remove('hidden');
+ }});
+ document.body.addEventListener('htmx:afterRequest', function(ev) {{
+ const target = ev.detail.elt.querySelector('[hx-indicator]');
+ if(target) target.querySelector('.loading').classList.add('hidden');
+ }});
+
+ const demoMaskData = {json.dumps(session.get('demo_mask'))} || undefined;
+ if (typeof demoMaskData !== 'undefined' && demoMaskData !== null) {{
+ const maskInfo = JSON.parse(demoMaskData);
+ const data = atob(maskInfo.data);
+ const arr = new Uint8Array(data.length);
+ for (let i = 0; i < data.length; i++) {{
+ arr[i] = data.charCodeAt(i);
+ }}
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
+ for (let i = 0; i < arr.length * 8; i++) {{
+ const byteIndex = Math.floor(i / 8);
+ const bitIndex = 7 - (i % 8);
+ if (arr[byteIndex] & (1 << bitIndex)) {{
+ const x = i % canvas.width;
+ const y = Math.floor(i / canvas.width);
+ imageData.data[(y * canvas.width + x) * 4 + 3] = 255;
+ }}
+ }}
+ ctx.putImageData(imageData, 0, 0);
+ updateMaskData(canvas);
+ }}
+ function downloadMaskData() {{
+ const maskData = document.getElementById('mask-data').value;
+ if (!maskData) return;
+ const blob = new Blob([maskData], {{type: 'application/json'}});
+ const link = document.createElement('a');
+ link.href = URL.createObjectURL(blob);
+ link.download = `mask_${{Date.now()}}.json`;
+ link.click();
+ }}
+ function initializeCanvas(img, wrapper) {{
+ const DISPLAY_SIZE = 256; // fixed display size in pixels
+
+ // Set the preview image to the fixed size
+ img.style.width = DISPLAY_SIZE + "px";
+ img.style.height = DISPLAY_SIZE + "px";
+
+ const canvas = document.createElement('canvas');
+ // Use our fixed display size for the canvas dimensions
+ canvas.width = DISPLAY_SIZE;
+ canvas.height = DISPLAY_SIZE;
+ canvas.style.position = 'absolute';
+ canvas.style.top = '0';
+ canvas.style.left = '0';
+ canvas.style.cursor = 'crosshair';
+
+ const ctx = canvas.getContext('2d');
+ // Compute a scale factor relative to the image's natural dimensions
+ const scaleFactor = DISPLAY_SIZE / Math.max(img.naturalWidth, img.naturalHeight);
+ ctx.strokeStyle = 'black';
+ ctx.lineWidth = 35 * scaleFactor; // adjust line width proportionally
+
+ let drawing = false;
+ canvas.addEventListener('mousedown', e => {{
+ drawing = true;
+ ctx.beginPath();
+ ctx.moveTo(e.offsetX, e.offsetY);
+ }});
+ canvas.addEventListener('mousemove', e => {{
+ if (drawing) {{
+ ctx.lineTo(e.offsetX, e.offsetY);
+ ctx.stroke();
+ }}
+ }});
+ canvas.addEventListener('mouseup', e => {{
+ drawing = false;
+ updateMaskData(canvas);
+ }});
+ canvas.addEventListener('mouseleave', e => {{
+ if (drawing) {{
+ drawing = false;
+ updateMaskData(canvas);
+ }}
+ }});
+ wrapper.appendChild(canvas);
+ return canvas;
+ }}
+ function updateMaskData(canvas) {{
+ const ctx = canvas.getContext('2d');
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
+ const data = imageData.data;
+ const buffer = new Uint8Array(Math.ceil((canvas.width * canvas.height) / 8));
+ for (let i = 0; i < data.length; i += 4) {{
+ const pixelIndex = i / 4;
+ const byteIndex = Math.floor(pixelIndex / 8);
+ const bitIndex = 7 - (pixelIndex % 8);
+ if (data[i + 3] > 0) {{
+ buffer[byteIndex] |= (1 << bitIndex);
+ }}
+ }}
+ const base64 = btoa(String.fromCharCode(...buffer));
+ document.getElementById('mask-data').value = JSON.stringify({{
+ data: base64,
+ width: canvas.width,
+ height: canvas.height
+ }});
+ }}
+ function clearMask() {{
+ const canvas = document.querySelector('#preview-container canvas');
+ if (canvas) {{
+ const ctx = canvas.getContext('2d');
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
+ updateMaskData(canvas);
+ }}
+ document.getElementById('mask-data').value = '';
+ }}
+ function clearImage() {{
+ // Clear file input and preview
+ const fileInput = document.getElementById('upload-image-input');
+ fileInput.value = ''; // Reset file input
+ const previewContainer = document.getElementById('preview-container');
+ previewContainer.innerHTML = ''; // Clear canvas and image
+ document.getElementById('mask-data').value = ''; // Clear mask data
+
+ // If there's a demo image, re-initialize it
+ const demoImg = {json.dumps(session.get('demo_image', ''))};
+ if (demoImg) {{
+ const img = new Image();
+ img.onload = function() {{
+ const wrapper = document.createElement('div');
+ wrapper.style.position = 'relative';
+ wrapper.style.display = 'inline-block';
+ initializeCanvas(img, wrapper);
+ wrapper.appendChild(img);
+ previewContainer.appendChild(wrapper);
+ }};
+ img.src = demoImg;
+ }}
+ }}
+ // Helper function to square crop an image (crop centered)
+ function squareCropImage(img) {{
+ const side = Math.min(img.naturalWidth, img.naturalHeight);
+ const left = (img.naturalWidth - side) / 2;
+ const top = (img.naturalHeight - side) / 2;
+ const offCanvas = document.createElement("canvas");
+ offCanvas.width = side;
+ offCanvas.height = side;
+ const offCtx = offCanvas.getContext("2d");
+ offCtx.drawImage(img, left, top, side, side, 0, 0, side, side);
+ return offCanvas.toDataURL("image/jpeg");
+ }}
+
+ // Listen for file uploads and square crop the image before previewing
+ document.getElementById('upload-image-input').addEventListener('change', function(event) {{
+ const file = event.target.files[0];
+ if (file) {{
+ const img = new Image();
+ img.onload = function() {{
+ // Square crop the loaded image
+ const croppedDataUrl = squareCropImage(img);
+ const croppedImg = new Image();
+ croppedImg.onload = function() {{
+ const previewContainer = document.getElementById('preview-container');
+ previewContainer.innerHTML = '';
+ const wrapper = document.createElement('div');
+ wrapper.style.position = 'relative';
+ wrapper.style.display = 'inline-block';
+ croppedImg.style.display = 'block';
+ croppedImg.style.maxWidth = '100%';
+ initializeCanvas(croppedImg, wrapper);
+ wrapper.appendChild(croppedImg);
+ previewContainer.appendChild(wrapper);
+ }};
+ croppedImg.src = croppedDataUrl;
+ }};
+ img.src = URL.createObjectURL(file);
+ }}
+ }});
+ """
+
+ main_content = Container(
+ Div(
+ DivFullySpaced(
+ Style("""
+ .top-left {
+ position: absolute;
+ top: 3%;
+ left: 2%;
+ /* Additional styling as needed */
+ }
+
+ .custom_middle {
+ position: relative;
+ top: 0%;
+ left: 50%;
+ transform: translate(-50%, 0%);
+ /* Additional styling as needed */
+ }
+ """),
+ H1("UniDisc Demo", cls="text-4xl font-light tracking-tight top-left"),
+ Div(*demo_cards,
+ cls="grid grid-cols-3 gap-4 max-w-5xl custom_middle"),
+ cls="flex items-center justify-between mb-8 px-4"
+ ),
+ Form(
+ Grid(
+ Card(
+ Div(*create_input_card_content()),
+ **get_input_card_params()
+ ),
+ Card(
+ Div(id="output-content", cls="space-y-4"),
+ header=Div(H4("Output")),
+ id="output-card",
+ title="Output"
+ ),
+ cls="grid grid-cols-2 gap-6 mb-0"
+ ),
+ CardFooter(
+ Grid(
+ Button(
+ Div(
+ Span("Submit", cls="submit-text"),
+ Loading(cls="hidden h-4 w-4 animate-spin", id='loading', htmx_indicator=True),
+ cls="flex gap-2 items-center justify-center"
+ ),
+ cls=(ButtonT.primary,'w-full'),
+ hx_indicator="this .loading"
+ ),
+ Button("Clear Mask", type="button", cls=(ButtonT.primary,'w-full'), onclick="clearMask()"),
+ Button("Clear Image", type="button", cls=(ButtonT.primary,'w-full'), onclick="clearImage()"),
+ *([Button("Download Mask", type="button",
+ cls=(ButtonT.primary, 'w-full', 'dev-only'),
+ onclick="downloadMaskData()")] if SHOW_DEV_BUTTONS else []),
+ Button(
+ # DivFullySpaced(UkIcon('move-down', 20, 20, 3),"Sampling Configs"),
+ "Sampling Configs",
+ uk_toggle="target: #config-modal", id="config-modal-button", cls=(ButtonT.primary, 'w-full')
+ ),
+ cls="grid grid-cols-4 gap-2"
+ ),
+
+ ),
+ Card(
+ Grid(
+ Div(
+ LabelInput("Max Tokens", name="max_tokens", type="number", value=32, cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelSelect(
+ *Options(256, 512, 1024, selected_idx=1),
+ name="resolution",
+ label="Resolution",
+ cls="w-full",
+ ),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelInput("Sampling Steps", name="sampling_steps", type="number", value=32, cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelInput("Top P", name="top_p", type="number", value=0.95, step="0.01", cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelInput("Temperature", name="temperature", type="number", value=0.9, step="0.1", min_value="0.0", max_value="2.0", cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelInput("MaskGit R Temp", name="maskgit_r_temp", type="number", value=4.5, step="0.1", cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelInput("CFG", name="cfg", type="number", value=2.5, step="0.1", cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelSelect(
+ *Options("maskgit", "maskgit_nucleus", "ddpm_cache", selected_idx=1),
+ name="sampler",
+ label="Sampler",
+ cls="w-full",
+ ),
+ cls="space-y-1.5"
+ ),
+ *([
+ Div(
+ LabelInput("Port", name="port", type="number", value=8001, step="0.01", cls="w-full"),
+ cls="space-y-1.5"
+ ),
+ Div(
+ LabelSelect(*Options("False", "True", selected_idx=0), name="reward_models",label="Reward Models", cls="w-full",),
+ cls="space-y-1.5"
+ )
+ ] if ADD_DEV_FORM else []),
+ Hidden(name="save_mask_enabled", value="True"),
+ cls="grid grid-cols-4 gap-4",
+ ),
+ cls="mb-6",
+ title="Configuration",
+ id="config-modal",
+ hidden=True
+ ),
+ hx_swap="innerHTML",
+ hx_target="#output-content",
+ hx_post="/submit",
+ enctype="multipart/form-data",
+ cls="mb-6"
+ ),
+ ),
+ Script(js_script),
+ )
+
+ return main_content
+
+
+@rt("/load_demo/{demo_index}")
+def post(demo_index: int, session):
+ demo = DEMOS[demo_index]
+ if 'image' in demo:
+ session['demo_image'] = encode_image(process(Image.open(demo['image'])))['url']
+ if 'text' in demo:
+ session['demo_text'] = demo['text']
+
+ if 'mask' in demo and demo['mask'] and Path(demo['mask']).exists():
+ session['demo_mask'] = json.loads(Path(demo['mask']).read_text())
+ else:
+ session['demo_mask'] = None
+
+ content = create_input_card_content(
+ text_content=session['demo_text'],
+ )
+
+ mask_json = 'undefined' if not session['demo_mask'] else json.dumps(session['demo_mask'])
+ content.append(Script(fr"""
+ img = new Image();
+ img.onload = async function() {{
+ const previewContainer = document.getElementById('preview-container');
+ previewContainer.innerHTML = '';
+ const wrapper = document.createElement('div');
+ wrapper.style.position = 'relative';
+ wrapper.style.display = 'inline-block';
+
+ const canvas = initializeCanvas(img, wrapper);
+ const ctx = canvas.getContext('2d');
+
+ wrapper.appendChild(img);
+ previewContainer.appendChild(wrapper);
+
+ const dataUrl = {json.dumps(session.get('demo_image', ''))};
+ const base64Data = dataUrl.split(',')[1];
+ const byteCharacters = atob(base64Data);
+ const byteArrays = [];
+
+ for (let offset = 0; offset < byteCharacters.length; offset += 1024) {{
+ const slice = byteCharacters.slice(offset, offset + 1024);
+ const byteNumbers = new Array(slice.length);
+ for (let i = 0; i < slice.length; i++) {{
+ byteNumbers[i] = slice.charCodeAt(i);
+ }}
+ const byteArray = new Uint8Array(byteNumbers);
+ byteArrays.push(byteArray);
+ }}
+
+ const blob = new Blob(byteArrays, {{ type: 'image/jpeg' }});
+ const file = new File([blob], "demo_image.jpg", {{
+ type: 'image/jpeg',
+ lastModified: Date.now()
+ }});
+
+ const dataTransfer = new DataTransfer();
+ dataTransfer.items.add(file);
+
+ const fileInput = document.getElementById('upload-image-input');
+ fileInput.files = dataTransfer.files;
+ fileInput.dispatchEvent(new Event('change'));
+
+ const demoMaskData = {mask_json};
+ if (typeof demoMaskData !== 'undefined' && demoMaskData !== null) {{
+ const maskInfo = demoMaskData;
+ const data = atob(maskInfo.data);
+ const arr = new Uint8Array(data.length);
+ for (let i = 0; i < data.length; i++) {{
+ arr[i] = data.charCodeAt(i);
+ }}
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
+ for (let i = 0; i < arr.length * 8; i++) {{
+ const byteIndex = Math.floor(i / 8);
+ const bitIndex = 7 - (i % 8);
+ if (arr[byteIndex] & (1 << bitIndex)) {{
+ const x = i % canvas.width;
+ const y = Math.floor(i / canvas.width);
+ imageData.data[(y * canvas.width + x) * 4 + 3] = 255;
+ }}
+ }}
+ ctx.putImageData(imageData, 0, 0);
+ updateMaskData(canvas);
+ }}
+ }};
+ img.src = {json.dumps(session.get('demo_image', ''))};
+ """))
+
+ return Card(
+ Div(*content),
+ **get_input_card_params()
+ )
+
+
+@rt("/submit")
+def post(
+ req,
+ temperature: float,
+ top_p: float,
+ maskgit_r_temp: float,
+ cfg: float,
+ max_tokens: int,
+ resolution: int,
+ sampling_steps: int,
+ sampler: str,
+ user_input: str | None = None,
+ mask_data: str | None = None,
+ uploaded_file: UploadFile | None = None,
+ port: int | None = 8001,
+ reward_models: str | None = "False"
+):
+ messages = []
+ if user_input:
+ messages.append({"type": "text", "text": user_input})
+
+ current_image = None
+ if uploaded_file is not None and uploaded_file.filename != "No image":
+ current_image = process(Image.open(io.BytesIO(uploaded_file.file.read())), int(resolution))
+ img_data = encode_image(current_image)["url"]
+
+ messages.append({
+ "type": "image_url",
+ "image_url": {"url": img_data},
+ "is_mask": False
+ })
+
+ if mask_data is not None and len(mask_data) > 0:
+ mask_array = get_boolean_mask(mask_data)
+ mask_data_url = encode_array_image(mask_array)["url"]
+ messages.append({
+ "type": "image_url",
+ "image_url": {"url": mask_data_url},
+ "is_mask": True
+ })
+
+ config_payload = {
+ "max_tokens": int(max_tokens),
+ "resolution": int(resolution),
+ "sampling_steps": int(sampling_steps),
+ "top_p": float(top_p),
+ "temperature": float(temperature),
+ "maskgit_r_temp": float(maskgit_r_temp),
+ "cfg": float(cfg),
+ "sampler": sampler,
+ "use_reward_models": reward_models == "True"
+ }
+
+ payload = {
+ "messages": [{"role": "user", "content": messages}],
+ "model": "unidisc",
+ **config_payload
+ }
+
+
+ API_URL = f"http://localhost:{port}/v1/chat/completions"
+ response = requests.post(API_URL, json=payload)
+ components = []
+
+ if response.status_code == 200:
+ response_json = response.json()
+ if "choices" in response_json:
+ content = response_json["choices"][0]["message"]["content"]
+ if isinstance(content, list):
+ for part in content:
+ if part["type"] == "text":
+ components.append(Card(
+ P(part["text"], cls="p-4"),
+ cls="response-card mb-4",
+ title="Response"
+ ))
+ elif part["type"] == "image_url":
+ components.append(
+ Card(
+ Div(
+ Img(
+ src=part["image_url"]["url"],
+ cls="w-64 h-64 object-cover rounded-md"
+ ),
+ cls="flex justify-center items-center p-4"
+ ),
+ cls="response-card mb-4"
+ )
+ )
+ else:
+ components.append(Card(P(content, cls="p-4"), cls="response-card mb-4", title="Response"))
+ else:
+ components.append(Card(P(f"API Error: {response.text}"), cls="response-card destructive mb-4", title="Error"))
+
+ output_content = Div(*components, id="output-content", cls="space-y-4 flex flex-col")
+
+ return output_content
+
+
+serve(port=5003)
\ No newline at end of file
diff --git a/demo/demo.sh b/demo/demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..18b4f8bb9a71560e308d50771ea76fe3d4ab10ee
--- /dev/null
+++ b/demo/demo.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 uv run python demo/server.py experiments='[large_scale_train,large_scale_train_high_res_interleaved,eval_unified,large_scale_high_res_interleaved_inference]' \
+trainer.load_from_state_dict="/home/appuser/app/pytorch_model_fsdp.bin" &
+
+uv run python demo/client.py
\ No newline at end of file
diff --git a/demo/inference.py b/demo/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6944c1dcf3000096228be20aafdd4428ba99b3
--- /dev/null
+++ b/demo/inference.py
@@ -0,0 +1,467 @@
+from __future__ import annotations
+
+import re
+import copy
+import os
+import pickle
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+from decoupled_utils import clear_cache
+
+from model_utils import wrapped_batch_decode
+from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch
+
+os.environ["UNIDISC_FORCE_CUDNN_SPDA_CONTEXT"] = "1"
+os.environ["UNIDISC_DISABLE_APEX_RMSNORM"] = "1"
+
+import hydra
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from hydra import compose, initialize
+from image_utils import Im
+from omegaconf import OmegaConf, open_dict
+from PIL import Image
+from accelerate import PartialState
+
+import dataloader
+from decoupled_utils import (breakpoint_on_error, gprint,
+ set_global_breakpoint, set_global_exists)
+from demo.inference_utils import (convert_to_model_input, messages_to_batch, save_grid_image)
+from demo.server import ChatRequest, ChatMessage, ContentPart
+from utils import set_omega_conf_resolvers, set_torch_defaults
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+
+set_global_breakpoint() # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
+set_global_exists()
+set_omega_conf_resolvers()
+
+def get_cfg(overrides: Union[str, list[str]]):
+ with initialize(version_base=None, config_path='configs'):
+ cfg = compose(config_name='config.yaml', return_hydra_config=False, overrides=overrides)
+ return cfg
+
+def set_accelerator(config):
+ from accelerate import Accelerator
+ mixed_precision = "bf16"
+ accelerator = Accelerator(mixed_precision=mixed_precision)
+ compute_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ compute_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ compute_dtype = torch.bfloat16
+ gprint(f"Compute dtype is: {compute_dtype}")
+ with open_dict(config):
+ config.trainer.devices = accelerator.num_processes
+ config.trainer.dtype = str(compute_dtype)
+
+ return config, accelerator
+
+def setup(config=None, save_config=False, demo_type="jan", device=None, profile_memory=True):
+ if profile_memory:
+ torch.cuda.memory._record_memory_history()
+ from torchtnt.utils.oom import attach_oom_observer
+ attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)
+
+ set_torch_defaults()
+ if config is not None:
+ demo_type = "jan"
+ config_path = Path(__file__).parent / f'outputs/config_{demo_type}.pkl'
+
+ if save_config:
+ OmegaConf.resolve(config)
+ config_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(config_path, 'wb') as f:
+ pickle.dump(config, f)
+
+ yaml_config_path = config_path.with_suffix('.yaml')
+ with open(yaml_config_path, 'w') as yaml_file:
+ OmegaConf.save(config, yaml_file)
+ print(f"Saved config to {config_path}")
+ exit()
+ elif config is None:
+ with open(config_path, 'rb') as f:
+ config = pickle.load(f)
+ print(f"Loaded config from {config_path}")
+
+ if config is not None:
+ demo_type = "jan"
+
+ config, accelerator = set_accelerator(config)
+ from model import Diffusion
+ device = PartialState().device if device is None else device
+ model = Diffusion(config=config, tokenizer=dataloader.get_tokenizer(config), device=device)
+ model.set_accelerator(accelerator)
+ return partial(inference, config=config, model=model)
+
+mask_token = ""
+
+def expand_mask_tokens(messages: List[ChatMessage]) -> List[ChatMessage]:
+ # Expand -> ... (N times)
+ import re
+ def replace_match(match: re.Match) -> str:
+ # Extract the number after str:
+ # Match
+ pattern = r''
+ return re.sub(pattern, replace_match, text)
+
+ messages = copy.deepcopy(messages)
+ for message in messages:
+ for content in message.content:
+ if content.type == "text":
+ print(f"Input text: {content.text}")
+ content.text = expand_mask_in_text(content.text)
+ print(f"Expanded text: {content.text}")
+
+ return messages
+
+def get_fixed_batch(config, tokenizer, model, input_data, resolution):
+ assert len(input_data) == 2, "Input data must contain 2 messages"
+ images = [content["image_url"] for content in input_data if content['type'] == "image_url"]
+ texts = [content["text"] for content in input_data if content['type'] == "text"]
+ assert len(images) == 1
+ assert len(texts) == 1
+ _img = Im(images[0])
+ if not _img.height == _img.width:
+ _img = _img.square(resolution, resolution)
+ elif _img.height != resolution or _img.width != resolution:
+ _img = _img.resize(resolution, resolution)
+ img_image_ids = get_image_batch(config, model.get_vae(), {"img": _img.torch[None]}, model.device)
+ txt_input_ids = dataloader.tokenize_text(tokenizer, config.data.block_size, texts)
+
+ data = {}
+ seq_len = config.data.block_size + img_image_ids.shape[1] # Allow variable token length
+ data["input_ids"] = txt_input_ids["input_ids"]
+ data["attention_mask"] = txt_input_ids["attention_mask"]
+ data["modality"] = torch.full((1, seq_len,), dtype=torch.int64, fill_value=1) # assuming images
+ data["modality"][..., :data["input_ids"].shape[1]] = 0
+ data["input_ids"] = torch.cat([data["input_ids"].to(model.device), img_image_ids], dim=-1)
+ data["attention_mask"] = torch.cat([data["attention_mask"], torch.full((1, seq_len - data["attention_mask"].shape[1],), dtype=torch.bool, fill_value=True)], dim=-1).bool()
+ data["img"] = Im(images[0]).torch[None]
+ data["sample_ids"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=0)
+ for k in list(data.keys()):
+ data[k] = data[k].to(model.device)
+
+ data["input_ids"] = torch.where(
+ (data["modality"] == 1) & (data["input_ids"] != -1),
+ data["input_ids"] + config.data.img_token_shift,
+ data["input_ids"]
+ )
+
+ return data
+
+
+def inference(
+ request: ChatRequest,
+ config = None,
+ model = None,
+):
+ messages = request.messages
+ messages = expand_mask_tokens(messages)
+ input_request = copy.deepcopy(request)
+
+ with open_dict(config):
+ config.eval.top_p = request.top_p
+ config.eval.temperature = request.temperature
+ config.eval.maskgit_r_temp = request.maskgit_r_temp
+ config.eval.cfg = request.cfg
+ config.sampling.predictor = request.sampler
+ model.sampler = config.sampling.predictor
+
+ gen_img = False
+ gen_txt = False
+ resolution = request.resolution
+ print(f"messages: {messages}")
+ img_contains_mask = any(content.is_mask for msg in messages for content in msg.content)
+ input_contains_img = any(content.type == "image_url" for msg in messages for content in msg.content)
+ if img_contains_mask:
+ print(f"img_contains_mask: {img_contains_mask}")
+ recent_mask = [content.image_url for content in messages[-1].content if content.is_mask]
+ recent_img = [content.image_url for content in messages[-1].content if (content.type == "image_url" and not content.is_mask)]
+ assert len(recent_mask) == len(recent_img) == 1, "Number of masks must match number of images"
+ recent_mask = recent_mask[0]
+ recent_img = recent_img[0]
+ for msg in messages:
+ msg.content = [content for content in msg.content if not content.is_mask]
+
+ if any("" in content.text for content in messages[-1].content if content.type == "text") or (not input_contains_img):
+ print(f"Generating image: {messages[-1].content[-1].text}")
+ gen_img = True
+ messages[-1].content[-1].text = messages[-1].content[-1].text.replace("", "")
+ messages.append(ChatMessage(
+ role="assistant",
+ content=[ContentPart(
+ type="image_url",
+ image_url=Image.new("RGB", (resolution, resolution), color=(0, 0, 0))
+ )]
+ ))
+ elif not any(mask_token in content.text for content in messages[-1].content if content.type == "text") and config.trainer.interleaved and not getattr(config.eval, "static_img_txt_demo", False):
+ print(f"Generating {request.max_tokens} tokens of text")
+ gen_txt = True
+ messages.append(ChatMessage(
+ role="assistant",
+ content=[ContentPart(
+ # "authentication" is a single token in the tokenizer so this gives us exact control over the number of tokens
+ type="text", text="authentication" * request.max_tokens
+ )]
+ ))
+ elif any(mask_token in content.text for content in messages[-2].content if content.type == "text") and getattr(config.eval, "static_img_txt_demo", False):
+ print(f"Got user text input with mask tokens, generating text")
+ gen_txt = True
+
+ force_reorder = True
+ mask_eos = True
+ if force_reorder and input_contains_img:
+ image_messages = [msg for msg in messages if any(content.type == "image_url" for content in msg.content)]
+ messages = [msg for msg in messages if not any(content.type == "image_url" for content in msg.content)]
+ messages.extend(image_messages)
+ print(f"Reordered messages, images are now last")
+
+ messages = convert_to_model_input(messages)
+ print(f"input messages: {messages}")
+
+ all_special_tokens = {x.content for x in model.tokenizer.added_tokens_decoder.values()}
+ if mask_token not in all_special_tokens:
+ new_tokens = [mask_token]
+ new_tokens = list(set(new_tokens) - set(model.tokenizer.get_vocab().keys()))
+ model.tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}, replace_additional_special_tokens=False)
+ assert model.tokenizer.added_tokens_decoder[len(model.tokenizer) - 1].content == mask_token
+ print(model.tokenizer(new_tokens, add_special_tokens=False))
+
+ mask_token_id = len(model.tokenizer) - 1
+ if config.trainer.interleaved:
+ batch = messages_to_batch(config, model.tokenizer, model, messages, resolution=resolution)
+ else:
+ batch = get_fixed_batch(config, model.tokenizer, model, messages, resolution=resolution)
+
+ sampling_steps = request.sampling_steps
+ sample_modality = batch["modality"]
+ x0 = batch["input_ids"]
+ x0_unmask = torch.ones_like(x0, dtype=torch.bool)
+ txt_contains_mask = False
+ for i in range(x0.shape[0]):
+ if gen_img or img_contains_mask:
+ modality_seq = sample_modality[i]
+ changes = torch.diff(modality_seq)
+ change_points = torch.where(changes != 0)[0] + 1
+ change_points = torch.cat([torch.tensor([0], device=change_points.device), change_points, torch.tensor([len(modality_seq)], device=change_points.device)])
+
+ sequences = []
+ for start, end in zip(change_points[:-1], change_points[1:]):
+ if modality_seq[start] == 1:
+ sequences.append((start.item(), end.item()))
+
+ if sequences:
+ last_start, last_end = sequences[-1]
+ x0_unmask[i, last_start:last_end] = False
+ print(f"Masked slice: {last_start}:{last_end}")
+ else:
+ print(f"WARNING: No sequences found")
+
+ if img_contains_mask:
+ def downscale_bool(arr: np.ndarray, D: int) -> np.ndarray:
+ if len(arr.shape) == 3:
+ print(f"Converting (H, W, C) to (H, W)")
+ arr = arr.sum(axis=-1)
+ H, W = arr.shape
+ assert H % D == 0 and W % D == 0, "H and W must be divisible by D"
+ return arr.reshape(H // D, D, W // D, D).any(axis=(1, 3))
+
+ import math
+ _res = int(math.sqrt(last_end - last_start) * config.model.downscale_ratio)
+ if recent_mask.size != (_res, _res):
+ print(f"WARNING!! recent_mask.size: {recent_mask.size}, last_end - last_start: {last_end - last_start}")
+ mask_arr = downscale_bool(np.array(recent_mask.convert("RGB").resize((resolution, resolution), resample=Image.Resampling.NEAREST)).astype(np.bool_), config.model.downscale_ratio)
+ print(Im(mask_arr).save())
+ mask_arr = torch.from_numpy(mask_arr).to(x0_unmask.device).reshape(-1).nonzero().squeeze()
+ mask_arr = mask_arr + last_start
+ x0_unmask[i, last_start:last_end] = True
+ x0_unmask[i, mask_arr] = False
+
+ if gen_img and not gen_txt and getattr(config.eval, "static_img_txt_demo", False):
+ print(f"Unmasking all text positions for static_img_txt_demo: {x0_unmask[i, modality_seq == 0].sum().item()}")
+ x0_unmask[i, modality_seq == 0] = True
+ elif gen_txt and not getattr(config.eval, "static_img_txt_demo", False):
+ bos_positions = torch.where(x0[i] == model.tokenizer.bos_token_id)[0]
+ if len(bos_positions) == 0:
+ continue
+
+ last_bos = bos_positions[-2] if force_reorder else bos_positions[-1]
+ eos_positions = torch.where((x0[i] == model.tokenizer.eos_token_id) & (torch.arange(len(x0[i]), device=x0.device) > last_bos))[0]
+
+ print(f"BOS positions: {bos_positions}, EOS positions: {eos_positions}")
+ unmask_to_eos = True
+ if unmask_to_eos and len(eos_positions) > 0:
+ last_eos = eos_positions[0]
+ else:
+ last_eos = None # Mask everything after last BOS
+
+ x0_unmask[i, last_bos+1:last_eos] = False
+ if mask_eos and force_reorder:
+ x0_unmask[i, last_bos+1:last_eos+3] = False
+ print(f"Masked slice: {last_bos}:{last_eos}")
+
+ to_mask = x0[i] == mask_token_id
+ if to_mask.sum().item() > 0:
+ x0_unmask[i, to_mask] = False
+ print(f"Found {to_mask.sum().item()} text mask tokens")
+ txt_contains_mask = True
+
+
+ # Add metrics for x0_unmask[0]
+ true_indices = torch.where(x0_unmask[0])[0]
+ first_true = true_indices[0].item()
+ last_true = true_indices[-1].item()
+ total_true = x0_unmask[0].sum().item()
+
+ masked_modalities = batch["modality"][0][x0_unmask[0]]
+ zeros = (masked_modalities == 0).sum().item()
+ ones = (masked_modalities == 1).sum().item()
+
+ print(f"x0_unmask num unmasked: {total_true}, x0_unmask, first position: {first_true}, x0_unmask last position: {last_true}")
+ print(f"x0_unmask num txt (0) count: {zeros}, x0_unmask num img (1) count: {ones}")
+ print(f"Masking {((~x0_unmask) & (batch['sample_ids'] >= 0)).sum().item()} positions, modality shape: {batch['modality'].shape}")
+
+ # Find first invalid sample ID, defaulting to full length if none found
+ invalid_positions = (batch["sample_ids"][0].long() == -1).nonzero(as_tuple=True)[0]
+ first_invalid_sample_id = invalid_positions[0].item() if len(invalid_positions) > 0 else len(batch["sample_ids"][0])
+ print(f"First invalid sample ID position: {first_invalid_sample_id}")
+ row_len = save_grid_image(x0_unmask[0][:first_invalid_sample_id], "x0_unmask_viz.png")
+ save_grid_image(batch["modality"][0][:first_invalid_sample_id], "modality_viz.png", row_len=row_len)
+ _sc = batch["sample_ids"][0].clone()
+ _sc[_sc == 0] = 1
+ _sc[_sc == -1] = 0
+ save_grid_image(_sc, "sample_ids_viz.png", row_len=row_len)
+
+ if request.use_reward_models:
+ idx = 0
+ bs = 1
+ num_iter = 4
+ from tensordict import TensorDict
+ gen_batch = TensorDict.from_dict(batch, batch_size=[batch['input_ids'].shape[0]])
+ text_samples_list = []
+ img_samples_list = []
+ _gen_batch = []
+ for i in range(num_iter):
+ _gen_batch.append(gen_batch[[idx]])
+ gen_batch = torch.cat(_gen_batch, dim=0)
+
+ for j in range(num_iter):
+ _modality = gen_batch[[idx]].get("modality", None)
+ _sample_ids = gen_batch[[idx]].get("sample_ids", None)
+ if _modality is not None:
+ _modality = _modality.to(model.device)
+ if _sample_ids is not None:
+ _sample_ids = _sample_ids.to(model.device)
+ else:
+ _sample_ids = torch.zeros_like(_modality)
+ text_samples, img_samples, x = model._sample(
+ text_only=False,
+ num_steps=sampling_steps,
+ batch_size_per_gpu=bs,
+ modality=_modality,
+ sample_ids=_sample_ids,
+ x0=gen_batch["input_ids"][[idx]].to(model.device),
+ x0_unmask=x0_unmask[[idx]].to(model.device),
+ return_raw_data=True,
+ allow_interleaved_conditional=True
+ )
+ gen_batch[[idx]]['input_ids'] = x
+ text_samples_list.extend(text_samples)
+ img_samples_list.extend(img_samples)
+ print(f"Sampled {j + 1} / {num_iter}")
+
+ text_samples_list = wrapped_batch_decode(
+ model.tokenizer,
+ torch.stack(text_samples_list, dim=0),
+ clean_up_tokenization_spaces=True,
+ skip_special_tokens=True,
+ disable_mask_after_eos=True
+ )
+
+ img_samples_list = torch.cat(img_samples_list, dim=0)
+ reward_config = config.eval.auto_enhance_reward_config
+ rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
+
+ gprint(f"Avg Rewards: {rewards}")
+
+ sorted_indices = torch.argsort(rewards, descending=True).tolist()
+ sorted_text_samples = [text_samples_list[i] for i in sorted_indices]
+ sorted_img_samples = [img_samples_list[i] for i in sorted_indices]
+ sorted_avg_rewards = [rewards[i] for i in sorted_indices]
+ sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards}
+
+ txt_samples = [sorted_text_samples[0]]
+ img_samples = [Im(sorted_img_samples[0]).pil]
+ else:
+ txt_samples, img_samples = model._sample(
+ text_only=False,
+ num_steps=sampling_steps,
+ batch_size_per_gpu=1,
+ example_batch=batch,
+ sample_batch_idx=0,
+ modality=batch["modality"],
+ sample_ids=batch["sample_ids"],
+ allow_interleaved_conditional=True,
+ x0_unmask=x0_unmask,
+ x0=x0,
+ )
+
+ if not config.trainer.interleaved:
+ txt_samples = model.tokenizer.batch_decode(txt_samples[..., model.static_txt_sl], remove_special_tokens=True)
+ txt_samples[0] = txt_samples[0].replace("", "").strip()
+ img_len = (resolution // config.model.downscale_ratio)**2
+ img_samples = decode_latents(config, model.get_vae(), img_samples[..., -img_len:])
+ assert img_samples.shape[0] == 1
+ img_samples = [Im(img_samples[0]).pil]
+
+ returned_message = ChatMessage(
+ role="assistant",
+ content=[]
+ )
+ if img_contains_mask or gen_img:
+ print(f"Inference returned img_samples: {img_samples}")
+ returned_message.content.append(ContentPart(
+ type="image_url",
+ image_url=img_samples[-1]
+ ))
+
+ if txt_contains_mask or not gen_img:
+ print(f"Inference returned txt_samples: {txt_samples}")
+ last_new_txt = ""
+ for i, _txt in enumerate(txt_samples[-1].rsplit("")):
+ if len(_txt) > 0:
+ _txt = _txt.replace("", "").replace("", "").replace("", "").strip().replace(' ', ' ').replace(' ', ' ').replace(' .', '.').replace('\\n', ' ')
+ _txt = re.sub(r'[^a-zA-Z. ]', '', _txt)
+ if len(_txt) > 0:
+ last_new_txt = _txt
+
+ returned_message.content.append(ContentPart(
+ type="text",
+ text=last_new_txt
+ ))
+
+ input_request.messages.append(returned_message)
+
+ clear_cache()
+
+
+ return input_request
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+@torch.no_grad()
+def main(config=None):
+ inference = setup(config, save_config=True)
+ exit()
+ inference([{"type": "text", "text": "Hello, how are you?"}])
+
+if __name__ == "__main__":
+ with breakpoint_on_error():
+ main()
\ No newline at end of file
diff --git a/demo/inference_utils.py b/demo/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e514a36fdebc2289aff574c2b221071e02ef1c
--- /dev/null
+++ b/demo/inference_utils.py
@@ -0,0 +1,293 @@
+from __future__ import annotations
+
+import base64
+import copy
+import io
+import random
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import math
+from PIL import Image
+from image_utils import Im
+
+from decoupled_utils import gprint
+
+if TYPE_CHECKING:
+ from demo.server import ChatRequest
+
+def tensor_center_crop(tensor_image, crop_size):
+ _, _, h, w = tensor_image.shape
+
+ while h >= 2 * crop_size[0] and w >= 2 * crop_size[1]:
+ tensor_image = F.interpolate(tensor_image, size=(h // 2, w // 2), mode='area')
+ _, _, h, w = tensor_image.shape
+
+ scale = max(crop_size[0] / h, crop_size[1] / w)
+ new_h, new_w = round(h * scale), round(w * scale)
+ tensor_image = F.interpolate(tensor_image, size=(new_h, new_w), mode='bilinear')
+
+ crop_top = random.randint(0, new_h - crop_size[0])
+ crop_left = random.randint(0, new_w - crop_size[1])
+ crop_bottom = crop_top + crop_size[0]
+ crop_right = crop_left + crop_size[1]
+ return tensor_image[:, :, crop_top:crop_bottom, crop_left:crop_right]
+
+def parse_messages(messages: List[dict]) -> Tuple[List[Image.Image], List[List[dict]]]:
+ """
+ Given a list of message dicts with format:
+ [
+ {"type": "text", "text": msg},
+ {"type": "image_url", "image_url": }
+ ]
+
+ Returns:
+ - all_images: a list containing the PIL images, in the order of their appearance
+ - all_content: a nested list (single conversation) with dicts indicating message type
+ """
+ all_images: List[Image.Image] = []
+ conversation: List[dict] = []
+
+ for msg in messages:
+ if msg["type"] == "text":
+ conversation.append(msg)
+ elif msg["type"] == "image_url":
+ idx = len(all_images)
+ all_images.append(msg["image_url"])
+ _msg = copy.deepcopy(msg)
+ _msg["image_url"] = {"url": idx}
+ conversation.append(_msg)
+ else:
+ raise ValueError(f"Unsupported message type: {msg['type']}. Expected 'text' or 'image_url'.")
+
+ all_content = [conversation]
+ return all_images, all_content
+
+def messages_to_batch(config, tokenizer, model, input_data, resolution):
+ import copy
+
+ from model import get_image_batch
+ from unidisc.tokenizers.tokenize_interleaved import _has_image, preprocess
+
+ # Build conversations and extract images.
+ all_images = []
+ conversations = []
+ for item in input_data:
+ role = item["role"]
+ assert role in ["user", "assistant"]
+ role = "human" if role == "user" else "gpt"
+ if item["type"] == "image_url":
+ token = ""
+ all_images.append(item["image_url"])
+ elif item["type"] == "text":
+ token = item["text"]
+ else:
+ continue
+ if conversations and conversations[-1]["from"] == role:
+ conversations[-1]["value"] += " " + token
+ else:
+ conversations.append({"from": role, "value": token})
+
+ output_list = []
+ entry = {"id": "1", "conversations": conversations}
+ if all_images:
+ entry["image"] = {}
+ output_list.append(entry)
+ all_content = output_list
+
+ vae = model.get_vae()
+ device = model.device
+ if not all_images:
+ image_ids = None
+ else:
+ _img = torch.cat([
+ tensor_center_crop(
+ torch.from_numpy(np.array(img))[None, :].permute(0, 3, 1, 2) / 255,
+ (resolution, resolution)
+ ) for img in all_images
+ ])
+ try:
+ batch_size = 32
+ image_ids_list = []
+ for i in range(0, len(_img), batch_size):
+ batch = _img[i:i+batch_size]
+ batch_ids = get_image_batch(config, vae, {"img": batch}, device)
+ image_ids_list.append(batch_ids)
+ image_ids = torch.cat(image_ids_list)
+ except Exception as e:
+ gprint(f"{_img.shape}, {e}")
+ import traceback
+ traceback.print_exc()
+
+ all_input_ids = []
+ all_attention_masks = []
+ all_modality = []
+ assert len(all_content) == 1
+ for sources in all_content:
+ has_image = _has_image(sources)
+ sources = copy.deepcopy([sources["conversations"]])
+ _image_ids = image_ids if has_image else None
+ try:
+ print(f"Sources: {sources}")
+ data_dict = preprocess(sources, tokenizer, has_image=has_image, image_ids=_image_ids)
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ gprint(f"Error in preprocess: {e}")
+ return None, None, None
+ input_ids = data_dict["input_ids"][0]
+ attention_mask = data_dict["attention_mask"][0]
+ modality = data_dict["modality"][0]
+ if (input_ids[-2:] == tokenizer.eos_token_id).all():
+ input_ids = input_ids[:-1]
+ attention_mask = attention_mask[:-1]
+ modality = modality[:-1]
+
+ assert config.model.length >= input_ids.shape[0], f"Input ids length {input_ids.shape[0]} is greater than model length {config.model.length}"
+
+ attention_mask = attention_mask.bool()
+ print(f"Attention mask: {attention_mask.shape}, input ids: {input_ids.shape}, modality: {modality.shape}")
+
+ if modality[-1] == 1:
+ is_image = modality == 1
+ change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1
+ if change_points.numel() > 0:
+ start_pos = change_points[-1].item()
+ modality[start_pos:] = 0
+ attention_mask[start_pos:] = False
+ input_ids[start_pos:] = tokenizer.pad_token_id
+ all_input_ids.append(input_ids)
+ all_attention_masks.append(attention_mask)
+ all_modality.append(modality)
+
+ all_input_ids = torch.stack(all_input_ids)
+ all_attention_masks = torch.stack(all_attention_masks)
+ all_modality = torch.stack(all_modality)
+ all_sample_ids = torch.zeros_like(all_modality, dtype=torch.long)
+ all_sample_ids[~all_attention_masks] = -1
+ batch = {
+ "input_ids": all_input_ids,
+ "attention_mask": all_attention_masks,
+ "modality": all_modality.long(),
+ "sample_ids": all_sample_ids.long(),
+ }
+
+ for k in batch:
+ batch[k] = batch[k].to(device)
+
+ batch["input_ids"] = torch.where(
+ (batch["modality"] == 1) & (batch["input_ids"] != -1),
+ batch["input_ids"] + config.data.img_token_shift,
+ batch["input_ids"]
+ )
+
+ return batch
+
+def pil_to_base64(image: Image.Image) -> str:
+ buffered = io.BytesIO()
+ image.save(buffered, format="JPEG")
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+def convert_to_model_input(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ model_input = []
+ for msg in messages:
+ for part in msg.content:
+ if part.type == "text" and part.text:
+ model_input.append({
+ "type": "text",
+ "text": part.text,
+ "role": msg.role
+ })
+ elif part.type == "image_url" and part.image_url:
+ model_input.append({
+ "type": "image_url",
+ "image_url": part.image_url,
+ "role": msg.role
+ })
+ return model_input
+
+def convert_request_pil_to_base64(request: ChatRequest) -> ChatRequest:
+ for msg in request.messages:
+ for part in msg.content:
+ if part.type == "image_url" and isinstance(part.image_url, Image.Image):
+ buffered = io.BytesIO()
+ part.image_url.convert("RGB").save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ part.image_url = {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+ return request
+
+def convert_request_base64_to_pil(request: ChatRequest) -> ChatRequest:
+ for message in request.messages:
+ for part in message.content:
+ if part.type == "image_url" and "url" in part.image_url:
+ image_data = part.image_url["url"]
+ # Remove any data URL header, e.g. "data:image/jpeg;base64,"
+ if image_data.startswith("data:"):
+ try:
+ header, image_data = image_data.split(",", 1)
+ except ValueError as e:
+ raise ValueError(
+ f"Invalid image URL format: {image_data}"
+ ) from e
+ try:
+ decoded_bytes = base64.b64decode(image_data)
+ part.image_url = Image.open(io.BytesIO(decoded_bytes))
+ except Exception as e:
+ raise ValueError(
+ f"Error decoding or loading image. Ensure the base64 string is valid. Details: {e}"
+ ) from e
+ return request
+
+def trim_merge_messages(request: ChatRequest) -> ChatRequest:
+ # Remove empty text parts from each message
+ for msg in request.messages:
+ msg.content = [
+ part for part in msg.content
+ if not (part.type == "text" and part.text.strip() == "")
+ ]
+
+ # Remove messages with no content
+ request.messages = [
+ msg for msg in request.messages
+ if msg.content
+ ]
+
+ # Merge consecutive messages with the same role
+ merged_messages = []
+ for msg in request.messages:
+ if merged_messages and merged_messages[-1].role == msg.role:
+ merged_messages[-1].content.extend(msg.content)
+ else:
+ merged_messages.append(msg)
+
+ request.messages = merged_messages
+ return request
+
+def save_grid_image(input_arr: torch.Tensor, output_name, row_len=None):
+ # Convert to boolean then to int (0/1)
+ x0_bool = input_arr.bool().long()
+ n = x0_bool.numel()
+ if row_len is None:
+ row_len = math.ceil(math.sqrt(n))
+ rows = math.ceil(n / row_len)
+ total = rows * row_len
+ # Pad with -1 to mark padded positions
+ padded = torch.full((total,), -1, dtype=torch.long)
+ padded[:n] = x0_bool
+ grid = padded.reshape(rows, row_len)
+ # Create an RGB image: false=black, true=white, padded=red
+ image = torch.zeros((rows, row_len, 3), dtype=torch.uint8)
+ mask_true = (grid == 1)
+ mask_padding = (grid == -1)
+ image[mask_true] = torch.tensor([255, 255, 255], dtype=torch.uint8)
+ image[mask_padding] = torch.tensor([255, 0, 0], dtype=torch.uint8)
+ img = Image.fromarray(image.numpy(), mode='RGB')
+
+ from datetime import datetime
+ output = Im(img).save(datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + output_name)
+ print(f"Saved visualization to {output}")
+ return row_len
\ No newline at end of file
diff --git a/demo/misc/client_simple_gradio.py b/demo/misc/client_simple_gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..83028da9e5b82745233b65b9dd7d5e00020a7934
--- /dev/null
+++ b/demo/misc/client_simple_gradio.py
@@ -0,0 +1,318 @@
+import time
+import gradio as gr
+import requests
+import asyncio
+from pathlib import Path
+import base64
+from PIL import Image
+import numpy as np
+from typing import List, Dict, Any
+import io
+import uuid
+from demo.server import ChatRequest, ChatMessage, ContentPart
+
+API_URL = "http://localhost:8000/v1/chat/completions"
+
+# Encode a file on disk as a base64 data URL.
+def encode_image(file_path: Path) -> Dict[str, str]:
+ with file_path.open("rb") as img_file:
+ base64_str = base64.b64encode(img_file.read()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+# Convert a numpy array (or a PIL image) to a base64-encoded JPEG data URL.
+def encode_array_image(array: np.ndarray) -> Dict[str, str]:
+ im = Image.fromarray(array) if isinstance(array, np.ndarray) else array
+ buffered = io.BytesIO()
+ im.save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+def decode_image(img_data: str) -> Image:
+ base64_data = img_data.split("base64,")[1]
+ image_bytes = base64.b64decode(base64_data)
+ return Image.open(io.BytesIO(image_bytes))
+
+# Helper: compute a boolean mask from the image editor data.
+def get_boolean_mask(image_data):
+ if image_data is None:
+ return None
+ layers = image_data.get("layers", [])
+ if not layers:
+ bg = image_data.get("background")
+ if bg is not None:
+ height, width = bg.shape[:2]
+ return np.zeros((height, width), dtype=np.uint8)
+ return None
+ mask_layer = layers[0]
+ if mask_layer.shape[-1] == 4:
+ colored = mask_layer[..., 3] > 0
+ return (colored.astype(np.uint8) * 255), image_data["composite"]
+ else:
+ colored = mask_layer > 0
+ return (colored.astype(np.uint8) * 255), image_data["composite"]
+
+# Convert the stored content into a list of ContentPart objects.
+def convert_to_content_parts(raw: Any) -> List[ContentPart]:
+ if isinstance(raw, str):
+ return [ContentPart(type="text", text=raw)]
+ elif isinstance(raw, list):
+ parts = []
+ for item in raw:
+ if isinstance(item, dict):
+ parts.append(ContentPart(**item))
+ else:
+ raise ValueError(f"Unexpected list element type: {type(item)}")
+ return parts
+ elif isinstance(raw, tuple):
+ return [ContentPart(type="image_url", image_url=encode_image(Path(raw[0])))]
+ elif isinstance(raw, dict):
+ _content = raw.value if isinstance(raw, gr.Image) else raw
+ if "path" in _content:
+ return [ContentPart(type="image_url", image_url=encode_image(Path(_content["path"])))]
+ else:
+ raise ValueError(f"Expected 'path' in content dict, got: {_content}")
+ else:
+ raise ValueError(f"Unexpected content type: {type(raw)}")
+
+def add_user_msg_to_history(history: List[Dict[str, Any]], message: Dict[str, Any]) -> List[Dict[str, Any]]:
+ for file_path in message.get("files", []):
+ history.append({"role": "user", "content": {"path": file_path}})
+ if text := message.get("text"):
+ history.append({"role": "user", "content": text})
+ return history
+
+def add_assistant_msg_to_history(history: List[Dict[str, Any]], content: List[Any]) -> List[Dict[str, Any]]:
+ for item in content:
+ if isinstance(item, str):
+ history.append({"role": "assistant", "content": item})
+ elif isinstance(item, tuple):
+ img_data, _ = item
+ if isinstance(img_data, str) and img_data.startswith("data:image"):
+ image = decode_image(img_data)
+ else:
+ image = img_data
+ history.append({"role": "assistant", "content": gr.Image(value=image)})
+ return history
+
+def build_chat_request(
+ history: List[Dict[str, Any]],
+ message: Dict[str, Any],
+ model: str = "unidisc",
+ max_tokens: int = 1024,
+ temperature: float = 0.9,
+ top_p: float = 0.95,
+ resolution: int = 256,
+ sampling_steps: int = 35,
+ maskgit_r_temp: float = 4.5,
+ cfg: float = 3.5,
+ sampler: str = "maskgit"
+) -> ChatRequest:
+ messages = [ChatMessage(role=entry["role"], content=convert_to_content_parts(entry["content"])) for entry in history]
+ if "mask" in message and message.get("files"):
+ messages[-1].content.append(ContentPart(type="image_url", image_url=encode_array_image(message["mask"]), is_mask=True))
+
+ return ChatRequest(
+ messages=messages,
+ model=model,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ resolution=resolution,
+ sampling_steps=sampling_steps,
+ maskgit_r_temp=maskgit_r_temp,
+ cfg=cfg,
+ sampler=sampler
+ )
+
+async def send_request(payload: Dict[str, Any]) -> Dict[str, Any]:
+ response = await asyncio.to_thread(lambda: requests.post(API_URL, json=payload))
+ response.raise_for_status()
+ return response.json()
+
+def process_response(response: Dict[str, Any]) -> str | List[Any]:
+ choices = response.get("choices", [])
+ if not choices:
+ return ""
+ message = choices[0].get("message", {})
+ content = message.get("content", [])
+ if isinstance(content, str):
+ return content
+ result = []
+ for part in content:
+ if part.get("type") == "text":
+ result.append(part.get("text", ""))
+ elif part.get("type") == "image_url":
+ img_data = part.get("image_url", {}).get("url", "")
+ if img_data.startswith("data:image"):
+ result.append((img_data, "image"))
+ return ["\n".join(result)] if all(isinstance(item, str) for item in result) else result
+
+def save_composite_image(composite: np.ndarray, file_path: str) -> str:
+ image = Image.fromarray(composite.astype('uint8'), 'RGBA')
+ image.save(file_path)
+ return file_path
+
+def overwrite_input_img(history: List[Dict[str, Any]], message: Dict[str, Any]) -> List[Dict[str, Any]]:
+ if 'composite' in message:
+ composite_image_path = save_composite_image(message['composite'], f'/tmp/gradio/{uuid.uuid4()}.png')
+ for entry in reversed(history):
+ if not isinstance(entry['content'], str):
+ entry['content'] = gr.Image(value=composite_image_path)
+ return history
+ return history
+
+async def bot(
+ history: List[Dict[str, Any]],
+ message: Dict[str, Any],
+ max_tokens: int,
+ resolution: int,
+ sampling_steps: int,
+ top_p: float,
+ temperature: float,
+ maskgit_r_temp: float,
+ cfg: float,
+ sampler: str
+ ):
+ history = add_user_msg_to_history(history, message)
+ chat_request = build_chat_request(
+ history,
+ message,
+ max_tokens=int(max_tokens),
+ resolution=int(resolution),
+ sampling_steps=int(sampling_steps),
+ top_p=float(top_p),
+ temperature=float(temperature),
+ maskgit_r_temp=float(maskgit_r_temp),
+ cfg=float(cfg),
+ sampler=str(sampler)
+ )
+ do_overwrite_input_img = True
+ payload = chat_request.model_dump()
+ if do_overwrite_input_img:
+ history = overwrite_input_img(history, message)
+ try:
+ response = await send_request(payload)
+ content = process_response(response)
+ history = add_assistant_msg_to_history(history, content)
+ except requests.HTTPError as e:
+ history.append({"role": "assistant", "content": f"Error: {e}"})
+ return history, gr.update(value=None, interactive=True)
+
+async def handle_submit(history, message, mask_editor, max_tokens, resolution, sampling_steps, top_p, temperature, maskgit_r_temp, cfg, sampler):
+ if mask_editor is not None:
+ mask, composite = get_boolean_mask(mask_editor)
+ if mask is not None and mask.sum() > 0:
+ message["mask"] = mask
+ message["composite"] = composite
+ history_out, chat_input_update = await bot(history, message, max_tokens, resolution, sampling_steps, top_p, temperature, maskgit_r_temp, cfg, sampler)
+ return history_out, chat_input_update, gr.update(value=None), 0
+
+def square_crop(image: Image.Image) -> Image.Image:
+ width, height = image.size
+ side = min(width, height)
+ left = (width - side) // 2
+ top = (height - side) // 2
+ right = left + side
+ bottom = top + side
+ return image.crop((left, top, right, bottom))
+
+def update_image_editor(chat_input_value, image_editor_value, num_editor_updates, desired_resolution: int = 256):
+ print(f"num_editor_updates: {num_editor_updates}, chat_input_value: {chat_input_value}")
+ files = chat_input_value.get("files", [])
+ if len(files) == 0:
+ print(f"len files 0 returning image_editor_value, new num_editor_updates: {0}")
+ return image_editor_value, 0
+
+ # For some reason when you upload a file, this is called twice. We want to prevent further updates to avoid resetting masking while e.g., typing.
+ if num_editor_updates >= 2:
+ print(f"returning image_editor_value, new num_editor_updates: {num_editor_updates}")
+ return image_editor_value, num_editor_updates
+
+ file_path = files[0]
+ image = Image.open(file_path)
+ cropped_image = square_crop(image)
+ if desired_resolution > 0:
+ cropped_image = cropped_image.resize(
+ (int(desired_resolution), int(desired_resolution)), Image.LANCZOS
+ )
+
+ if (len(chat_input_value['text']) > 0 and num_editor_updates >= 0):
+ print(f"setting background,new num_editor_updates: {num_editor_updates + 1}")
+ image_editor_value["background"] = cropped_image
+ return image_editor_value, num_editor_updates + 1
+ else:
+ print(f"returning cropped_image, new num_editor_updates: {num_editor_updates + 1}")
+ return cropped_image, num_editor_updates + 1
+
+demo_examples = [
+ {"text": "This is a", "files": [str(Path("demo/assets/dog.jpg").resolve())]},
+]
+
+with gr.Blocks() as demo:
+ chatbot = gr.Chatbot(
+ elem_id="chatbot",
+ bubble_full_width=False,
+ type="messages",
+ render_markdown=False,
+ )
+ with gr.Row():
+ with gr.Column(scale=2):
+ chat_input = gr.MultimodalTextbox(
+ interactive=True,
+ file_count="multiple",
+ placeholder="Enter message or upload file...",
+ show_label=False,
+ sources=["upload"],
+ )
+ with gr.Column(scale=1):
+ image_editor = gr.ImageMask(
+ label="Mask the image",
+ brush=gr.Brush(default_size=64, colors=["#000000"], color_mode='fixed')
+ )
+
+ gr.Examples(
+ examples=demo_examples,
+ inputs=chat_input,
+ label="Try these examples"
+ )
+
+ with gr.Row():
+ max_tokens_input = gr.Number(value=32, label="Tokens to Generate", precision=0)
+ resolution_input = gr.Number(value=256, label="Resolution", precision=0)
+ sampling_steps_input = gr.Number(value=32, label="Sampling Steps", precision=0)
+ with gr.Row():
+ top_p_input = gr.Number(value=0.95, label="Top P [maskgit_nucleus only]", precision=2)
+ temperature_input = gr.Number(value=0.9, label="Temperature [maskgit_nucleus only]", precision=2)
+ with gr.Row():
+ maskgit_r_temp_input = gr.Number(value=4.5, label="MaskGit R Temp", precision=2)
+ cfg_input = gr.Number(value=2.5, label="CFG", precision=2)
+ sampler_input = gr.Dropdown(
+ choices=["maskgit", "maskgit_nucleus", "ddpm_cache"],
+ value="maskgit_nucleus",
+ label="Sampler"
+ )
+
+ # State to track the last set of files we processed for the editor.
+ num_editor_updates = gr.State(0)
+
+ # We only invoke `update_image_editor` on change, but it will no-op
+ # if no new file is present or if the file hasn't changed.
+ chat_input.change(
+ fn=update_image_editor,
+ inputs=[chat_input, image_editor, num_editor_updates, resolution_input],
+ outputs=[image_editor, num_editor_updates]
+ )
+
+ chat_input.submit(
+ handle_submit,
+ [
+ chatbot, chat_input, image_editor,
+ max_tokens_input, resolution_input, sampling_steps_input,
+ top_p_input, temperature_input, maskgit_r_temp_input,
+ cfg_input, sampler_input
+ ],
+ [chatbot, chat_input, image_editor, num_editor_updates]
+ )
+
+if __name__ == "__main__":
+ demo.launch(share=True)
diff --git a/demo/misc/client_simple_streamlit.py b/demo/misc/client_simple_streamlit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa361012431c41ceb152e36f268ff303c68c80
--- /dev/null
+++ b/demo/misc/client_simple_streamlit.py
@@ -0,0 +1,359 @@
+import streamlit as st
+import requests
+from pathlib import Path
+import base64
+from PIL import Image
+import numpy as np
+import io
+import uuid
+from streamlit_drawable_canvas import st_canvas
+from demo.api_data_defs import ChatRequest, ChatMessage, ContentPart
+from typing import Dict
+import time
+import json
+
+API_URL = "http://localhost:8000/v1/chat/completions"
+DEMO_DIR = Path("demo")
+
+def square_crop(image: Image.Image) -> Image.Image:
+ width, height = image.size
+ side = min(width, height)
+ left = (width - side) // 2
+ top = (height - side) // 2
+ right = left + side
+ bottom = top + side
+ return image.crop((left, top, right, bottom))
+
+def process(image: Image.Image, desired_resolution: int = 256) -> Image.Image:
+ cropped_image = square_crop(image.convert("RGB"))
+ return cropped_image.resize(
+ (int(desired_resolution), int(desired_resolution)), Image.LANCZOS
+ )
+
+DEMOS = [
+ {
+ "name": "Dog",
+ "image": DEMO_DIR / "assets" / "dog.jpg",
+ "mask": DEMO_DIR / "assets" / "dog.json",
+ "text": "A corgi playing in the snow",
+ },
+ {
+ "name": "Landscape",
+ "image": DEMO_DIR / "assets" / "mountain.jpg",
+ "mask": DEMO_DIR / "assets" / "mountain.json",
+ "text": "Snowy mountain peak.",
+ },
+ {
+ "name": "Architecture",
+ "image": DEMO_DIR / "assets" / "building.jpg",
+ "mask": DEMO_DIR / "assets" / "building.json",
+ "text": "Modern glass skyscraper",
+ }
+]
+
+# Custom CSS for animations and layout
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+def load_demo_assets(demo, config):
+ """Load demo assets with error handling"""
+ try:
+ st.session_state.demo_image = process(Image.open(demo["image"]), config["resolution"])
+ st.session_state.original_image = np.array(st.session_state.demo_image)
+ st.session_state.demo_text = demo["text"]
+ if demo["mask"].exists():
+ with demo["mask"].open("r") as f:
+ print(f"Loaded mask from {demo['mask']}")
+ st.session_state.initial_drawing = json.load(f)
+ breakpoint()
+ else:
+ st.warning(f"Mask not found for {demo['name']}")
+ st.session_state.initial_drawing = None
+ except Exception as e:
+ st.error(f"Failed to load {demo['name']} demo: {str(e)}")
+
+def encode_image(file: Path | io.BytesIO | Image.Image) -> Dict[str, str]:
+ if isinstance(file, Image.Image):
+ buffered = io.BytesIO()
+ file.save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ elif isinstance(file, Path):
+ with file.open("rb") as img_file:
+ base64_str = base64.b64encode(img_file.read()).decode("utf-8")
+ else:
+ base64_str = base64.b64encode(file.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+def encode_array_image(array: np.ndarray) -> Dict[str, str]:
+ im = Image.fromarray(array) if isinstance(array, np.ndarray) else array
+ buffered = io.BytesIO()
+ im.save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+def get_boolean_mask(canvas_data):
+ if canvas_data is None or canvas_data.image_data is None:
+ return None, None
+ mask_data = canvas_data.json_data.get("objects", [])
+ if not mask_data:
+ return np.zeros_like(st.session_state.original_image, dtype=np.uint8), None
+ mask = np.zeros(st.session_state.original_image.shape[:2], dtype=np.uint8)
+ for obj in mask_data:
+ if obj.get("type") == "path":
+ path = obj.get("path")
+ # Custom processing of the path could be added here
+ return mask * 255, None
+
+# Initialize session state variables
+if "demo_image" not in st.session_state:
+ st.session_state.demo_image = None
+if "demo_text" not in st.session_state:
+ st.session_state.demo_text = ""
+if "initial_drawing" not in st.session_state:
+ st.session_state.initial_drawing = None
+if "original_image" not in st.session_state:
+ st.session_state.original_image = None
+if "stroke_image" not in st.session_state:
+ st.session_state.stroke_image = None
+if "response" not in st.session_state:
+ st.session_state.response = None
+
+# Main UI title and demo selection
+st.title("Image + Text Input Demo")
+
+# Add configuration options in sidebar before any processing
+st.sidebar.header("Configuration")
+config = {
+ "max_tokens": st.sidebar.number_input("Max Tokens", value=32, min_value=1, key="max_tokens"),
+ "resolution": st.sidebar.number_input("Resolution", value=256, min_value=64, key="resolution"),
+ "sampling_steps": st.sidebar.number_input("Sampling Steps", value=32, min_value=1, key="sampling_steps"),
+ "top_p": st.sidebar.number_input("Top P", value=0.95, min_value=0.0, max_value=1.0, key="top_p"),
+ "temperature": st.sidebar.number_input("Temperature", value=0.9, min_value=0.0, max_value=2.0, key="temperature"),
+ "maskgit_r_temp": st.sidebar.number_input("MaskGit R Temp", value=4.5, min_value=0.0, key="maskgit_r_temp"),
+ "cfg": st.sidebar.number_input("CFG", value=2.5, min_value=0.0, key="cfg"),
+ "sampler": st.sidebar.selectbox(
+ "Sampler",
+ options=["maskgit", "maskgit_nucleus", "ddpm_cache"],
+ index=1,
+ key="sampler"
+ ),
+ "save_mask_enabled": True
+}
+
+st.subheader("Example Inputs")
+with st.container():
+ cols = st.columns(len(DEMOS))
+ for col, demo in zip(cols, DEMOS):
+ with col:
+ try:
+ demo_html = f"""
+
+
+
))['url']})
+
+
+
+
{demo['name']} Example
+
{demo['text']}
+
+
+ """
+ st.markdown(demo_html, unsafe_allow_html=True)
+
+ if st.button(f"Load {demo['name']}", key=f"demo_{demo['name']}"):
+ load_demo_assets(demo, config)
+
+ if not demo["image"].exists():
+ st.warning(f"Missing assets for {demo['name']}")
+
+ except Exception as e:
+ st.error(f"Error loading {demo['name']}: {str(e)}")
+
+# Layout: two columns - left for input, right for output
+col_input, col_output = st.columns(2)
+
+with col_input:
+ st.subheader("Input")
+ # st.markdown('', unsafe_allow_html=True)
+ canvas_placeholder = st.empty()
+ user_input = st.text_input(
+ "Input — \"\" denotes a mask token. \"\" denotes N.",
+ value=st.session_state.get("demo_text", "")
+ )
+ uploader_placeholder = st.empty()
+
+ # Always show uploader below canvas to allow image changes
+ with uploader_placeholder.container():
+ # Use a unique key for the uploader so it stays consistent
+ uploaded_file = st.file_uploader("Upload image", type=["png", "jpg", "jpeg"], key="uploader")
+ if uploaded_file:
+ image = process(Image.open(uploaded_file), config["resolution"])
+ st.session_state.original_image = np.array(image)
+
+ # Render canvas only when an image is available
+ if st.session_state.original_image is not None:
+ print(f"Loading canvas...")
+ with canvas_placeholder.container():
+ canvas_result = st_canvas(
+ fill_color="rgba(0,0,0,0)",
+ stroke_width=6,
+ stroke_color="#000000",
+ background_image=Image.fromarray(st.session_state.original_image),
+ initial_drawing=st.session_state.initial_drawing,
+ height=256,
+ width=256,
+ drawing_mode="freedraw",
+ key="canvas"
+ )
+ else:
+ canvas_result = None
+ canvas_placeholder.empty()
+
+ # Add save mask button conditional on flag
+ if config["save_mask_enabled"] and canvas_result is not None and canvas_result.image_data is not None:
+ if st.button("💾 Save Current Mask", help="Save drawn mask as SVG"):
+ # Generate unique filename
+ save_dir = DEMO_DIR / "assets" / "saved_masks"
+ save_dir.mkdir(exist_ok=True)
+ filename = f"mask_{uuid.uuid4().hex[:8]}.json"
+
+ json_data = json.dumps(canvas_result.json_data)
+ (save_dir / filename).write_text(json_data)
+
+ st.session_state.last_saved_mask = {
+ "path": str(save_dir / filename),
+ "timestamp": time.time()
+ }
+ st.success(f"Mask saved as {filename}")
+
+ # Show save confirmation temporarily
+ if "last_saved_mask" in st.session_state and (time.time() - st.session_state.last_saved_mask["timestamp"]) < 5:
+ st.info(f"Last saved: {Path(st.session_state.last_saved_mask['path']).name}")
+
+ # Submission button
+ if st.button("Submit"):
+ if uploaded_file or user_input or st.session_state.demo_image:
+ with st.spinner("Generating response..."):
+ start_time = time.time()
+ mask, composite = get_boolean_mask(canvas_result)
+ messages = []
+ if user_input:
+ messages.append(ContentPart(type="text", text=user_input))
+ current_image = uploaded_file if uploaded_file else st.session_state.demo_image
+ if current_image:
+ if uploaded_file:
+ img_data = encode_image(io.BytesIO(uploaded_file.getvalue()))["url"]
+ else:
+ img_data = encode_image(current_image)["url"]
+ img_part = ContentPart(
+ type="image_url",
+ image_url={"url": img_data},
+ is_mask=False
+ )
+ messages.append(img_part)
+ print(f"mask is none: {mask is None}")
+ if mask is not None:
+ mask_data = encode_array_image(mask)["url"]
+ mask_part = ContentPart(
+ type="image_url",
+ image_url={"url": mask_data},
+ is_mask=True
+ )
+ messages.append(mask_part)
+
+ # print(f"messages: {messages}")
+ payload = ChatRequest(
+ messages=[ChatMessage(role="user", content=messages)],
+ model="unidisc",
+ **config # Use the config dictionary instead of inline sidebar inputs
+ ).model_dump()
+
+ response = requests.post(API_URL, json=payload)
+ if response.status_code == 200:
+ st.session_state.response = response.json()
+ else:
+ st.error(f"API Error: {response.text}")
+
+with col_output:
+ st.subheader("Output")
+ if st.session_state.response:
+ if "choices" in st.session_state.response:
+ content = st.session_state.response["choices"][0]["message"]["content"]
+ if isinstance(content, list):
+ for part in content:
+ if part["type"] == "text":
+ st.text_input(value=part["text"], label="Unmasked Text", disabled=True)
+ elif part["type"] == "image_url":
+ st.image(part["image_url"]["url"], use_container_width=False, width=256) # Set a fixed width
+ st.markdown('', unsafe_allow_html=True)
+ else:
+ st.text_input(value=content, label="Unmasked Text", disabled=True)
diff --git a/demo/scoring/README.md b/demo/scoring/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..abd795088c848c4cda47f66074b62812538abeb3
--- /dev/null
+++ b/demo/scoring/README.md
@@ -0,0 +1,15 @@
+# Model scoring
+
+This folder contains code to test UniDisc under various configurations/checkpoints. Specifically, we generate a set of images and captions including masks for each, and then call the server under various configurations. We use a set of reward models from `model_eval.py` to score each output.
+
+Here is an example workflow:
+
+```bash
+uv run demo/scoring/generate_input.py input/v1 --num_pairs 500 --mask_txt --mask_img
+
+uv run demo/scoring/call_model.py --input_dir input/v1 --output_dir generated/v1 --num_pairs 200 --iterate_over_modes
+
+uv run accelerate launch --main_process_port $RANDOM demo/scoring/generate_rewards.py --input_dir generated/v1 --output_file rewards_v1.json --batch_size 32
+
+uv run demo/scoring/analyze_rewards.py rewards_v1.json --save_image
+```
\ No newline at end of file
diff --git a/demo/scoring/analyze_rewards.py b/demo/scoring/analyze_rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..a109d351199b72eca74378da9eab1c05436f708b
--- /dev/null
+++ b/demo/scoring/analyze_rewards.py
@@ -0,0 +1,203 @@
+from collections import defaultdict
+from pathlib import Path
+import json
+import re
+import typer
+
+typer.main.get_command_name = lambda name: name
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+# Pre-compile the regex pattern to extract the prefix.
+PREFIX_PATTERN = re.compile(r"(.+?)__pair_")
+
+def extract_prefix(folder: str) -> str:
+ """
+ Extracts the prefix from a folder name using the PREFIX_PATTERN.
+ If the pattern does not match, returns the folder name as-is.
+ """
+ match = PREFIX_PATTERN.match(folder)
+ return match.group(1) if match else folder
+
+def get_ignored_reward_keys(prefix: str) -> set[str]:
+ """
+ Returns a set of raw reward keys to ignore based on the prefix.
+ Adjust this mapping as your application logic requires.
+ """
+ if "capmask" not in prefix and "cap" in prefix:
+ return {"text_reward_model_score"}
+ elif "imgmask" not in prefix and "img" in prefix:
+ return {"laion_aesthetic_score"}
+ return set()
+
+@app.command()
+def main(
+ input_file: Path,
+ save_image: bool = False
+):
+ """
+ Reads a generated JSON rewards file and, for each dataset,
+ processes each unique prefix. For each prefix, it finds the matching examples
+ in the dataset and computes:
+ - The overall normalized reward average (normalized from 0 to 1).
+ - The normalized average for each raw reward type (ignoring certain types based on the prefix).
+
+ The prefix is extracted from each folder name by matching everything before the
+ '__pair_' substring. If the folder name does not match, the entire folder name is used.
+
+ Normalization is performed using the global minimum and maximum values for each
+ reward type (computed over all datasets and indices where the reward is not ignored).
+ """
+ try:
+ content = input_file.read_text()
+ data = json.loads(content)
+ except Exception as e:
+ typer.echo(f"Error reading JSON file: {e}")
+ raise typer.Exit(code=1)
+
+ # ------------------------------------------------------------------------
+ # First pass: Compute global normalization stats for each prefix and reward_key.
+ #
+ # For every dataset, we iterate over its folder_names. For each index,
+ # we compute the folder prefix and (using get_ignored_reward_keys) decide which
+ # reward keys to process. For each such key and index (if the index exists in the
+ # reward key's list), we update our normalization mapping.
+ #
+ # norm_stats is a dict mapping prefix -> dict mapping reward_key -> (global_min, global_max)
+ # ------------------------------------------------------------------------
+ norm_stats: dict[str, dict[str, tuple[float, float]]] = {}
+ for dataset in data.values():
+ folder_names: list[str] = dataset.get("folder_names", [])
+ raw_rewards: dict[str, list[float]] = dataset.get("raw_rewards", {})
+ for i, folder in enumerate(folder_names):
+ current_prefix = extract_prefix(folder)
+ ignore_keys = get_ignored_reward_keys(current_prefix)
+ for reward_key, values in raw_rewards.items():
+ if reward_key in ignore_keys:
+ continue
+ if i >= len(values):
+ continue
+ value = values[i]
+ if current_prefix not in norm_stats:
+ norm_stats[current_prefix] = {}
+ if reward_key not in norm_stats[current_prefix]:
+ norm_stats[current_prefix][reward_key] = (value, value)
+ else:
+ curr_min, curr_max = norm_stats[current_prefix][reward_key]
+ norm_stats[current_prefix][reward_key] = (min(curr_min, value), max(curr_max, value))
+
+ # Determine unique prefixes from all datasets.
+ unique_prefixes: set[str] = set()
+ for dataset in data.values():
+ folder_names = dataset.get("folder_names", [])
+ for folder in folder_names:
+ unique_prefixes.add(extract_prefix(folder))
+ unique_prefixes = sorted(unique_prefixes)
+
+ print(f"Found {len(unique_prefixes)} unique prefixes: {unique_prefixes}")
+
+ # ------------------------------------------------------------------------
+ # For each prefix, process and sort dataset outputs by overall normalized reward.
+ #
+ # In each dataset we find the indices with the current prefix, then for each
+ # reward key (that is not globally ignored for this prefix) we first normalize
+ # each reward value using the pre-computed min and max and then average the values.
+ # The overall average is computed (as in the original code) by summing the averages
+ # for each reward key and dividing by the total number of raw reward keys.
+ # ------------------------------------------------------------------------
+ for prefix in unique_prefixes:
+ typer.echo(f"Prefix: {prefix}")
+ dataset_outputs = [] # List of tuples: (overall_avg, output_string)
+ img_outputs = defaultdict(list)
+ for dataset_name, dataset in data.items():
+ output_lines = []
+ output_lines.append(f" Dataset: {dataset_name}")
+
+ folder_names: list[str] = dataset.get("folder_names", [])
+ folder_paths: list[str] = dataset.get("folder_paths", [])
+ raw_rewards: dict[str, list[float]] = dataset.get("raw_rewards", {})
+
+ if not folder_names:
+ output_lines.append(" No folder names provided in this dataset.")
+ dataset_outputs.append((float("-inf"), "\n".join(output_lines)))
+ continue
+
+ # Compute the indices in this dataset with the target prefix.
+ indices = [
+ idx for idx, folder in enumerate(folder_names)
+ if extract_prefix(folder) == prefix
+ ]
+
+ if save_image:
+ num_to_save = 2
+ _folder_paths = sorted([Path(p) for p in folder_paths])
+ for idx in indices[:num_to_save]:
+ img_outputs[_folder_paths[idx].name].append((dataset_name, _folder_paths[idx]))
+
+ ignore_keys = get_ignored_reward_keys(prefix)
+ reward_details = ""
+ total_norm_rewards = 0.0
+ for reward_key, values in raw_rewards.items():
+ if reward_key in ignore_keys:
+ continue
+
+ # Retrieve the global min and max for this reward key under this prefix.
+ norm_info = norm_stats.get(prefix, {}).get(reward_key)
+ if norm_info is None:
+ reward_details += f"{reward_key}: No data, "
+ continue
+ min_val, max_val = norm_info
+
+ # Normalize the values using the global min and max.
+ group_norm_values = []
+ for i in indices:
+ if i < len(values):
+ orig_value = values[i]
+ normalized = ((orig_value - min_val) / (max_val - min_val)) if max_val != min_val else 0.0
+ group_norm_values.append(normalized)
+ if group_norm_values:
+ avg_norm = sum(group_norm_values) / len(group_norm_values)
+ reward_details += f"{reward_key}: {avg_norm:.4f}, "
+ total_norm_rewards += avg_norm
+ else:
+ reward_details += f"{reward_key}: No data, "
+
+ # Compute the overall average normalized reward for sorting.
+ overall_avg = total_norm_rewards / len(raw_rewards) if raw_rewards else 0.0
+
+ reward_details = f" Avg: {overall_avg:.4f}, " + reward_details
+ output_lines.append(reward_details)
+ dataset_outputs.append((overall_avg, "\n".join(output_lines)))
+
+ # Sort the dataset outputs by overall average normalized reward (descending).
+ for avg, out in sorted(dataset_outputs, key=lambda x: x[0], reverse=True):
+ typer.echo(out)
+
+ typer.echo("-" * 40)
+
+ if save_image:
+ from unidisc.utils.viz_utils import create_text_image
+ from PIL import Image
+ from image_utils import Im
+ for k, v in img_outputs.items():
+ imgs = []
+ for _dataset_name, _folder_path in v:
+ def get_img(_img_path, _txt_path):
+ _img = Image.open(_img_path).resize((1024, 1024))
+ out = f'{_dataset_name}: {_txt_path.read_text()}'
+ txt_img = create_text_image(out, desired_width=_img.width)
+ _img = Im.concat_vertical(_img, txt_img)
+
+ input_img = None
+ if (_folder_path / "input_image.png").exists():
+ input_img = get_img(_folder_path / "input_image.png", _folder_path / "input_caption.txt")
+
+ out_img = get_img(_folder_path / "image.png", _folder_path / "caption.txt")
+ if input_img:
+ imgs.append(Im.concat_vertical(input_img, out_img))
+ else:
+ imgs.append(out_img)
+
+ Im.concat_horizontal([x.pil for x in imgs]).save(f"{k}.png")
+
+if __name__ == "__main__":
+ app()
diff --git a/demo/scoring/call_model.py b/demo/scoring/call_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb46d1780f4d094703090b5552d6cc9e83d00896
--- /dev/null
+++ b/demo/scoring/call_model.py
@@ -0,0 +1,342 @@
+import base64
+import io
+import json
+import shutil
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+import numpy as np
+import requests
+import typer
+from PIL import Image
+from tqdm import tqdm
+from image_utils import Im
+
+typer.main.get_command_name = lambda name: name
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+def square_crop(image: Image.Image) -> Image.Image:
+ """Crop the image to a square (centered)."""
+ width, height = image.size
+ side = min(width, height)
+ left = (width - side) // 2
+ top = (height - side) // 2
+ right = left + side
+ bottom = top + side
+ return image.crop((left, top, right, bottom))
+
+def process(image: Image.Image, desired_resolution: int = 512) -> Image.Image:
+ """Square-crop and resize the image."""
+ cropped_image = square_crop(image.convert("RGB"))
+ return cropped_image.resize((desired_resolution, desired_resolution), Image.LANCZOS)
+
+def encode_image(file: Path | io.BytesIO | Image.Image) -> dict:
+ """Encode an image as base64 data in a dict of the form {'url': 'data:image/jpeg;base64,...'}."""
+ if isinstance(file, Image.Image):
+ buffered = io.BytesIO()
+ file.save(buffered, format="JPEG")
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ elif isinstance(file, Path):
+ with file.open("rb") as img_file:
+ base64_str = base64.b64encode(img_file.read()).decode("utf-8")
+ else:
+ base64_str = base64.b64encode(file.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+def encode_array_image(array: np.ndarray) -> dict:
+ """Encode a mask array as base64 data in a dict of the form {'url': 'data:image/jpeg;base64,...'}."""
+ if array.dtype == bool:
+ array = array.astype(np.uint8) * 255
+ im = Image.fromarray(array)
+ buffered = io.BytesIO()
+ im.save(buffered, format="JPEG", quality=95)
+ base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return {"url": f"data:image/jpeg;base64,{base64_str}"}
+
+def call_unidisc_api(
+ image_path: Path | None,
+ caption: str | None,
+ mask_path: Path | None,
+ cfg: dict,
+) -> list:
+ """
+ Build the payload and call the UniDisc API, returning a list of
+ output pieces. Each piece is a dict with either:
+ {"type": "text", "text": "..."}
+ or {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
+ """
+ # Prepare message content as in reference code
+ messages = []
+ if caption:
+ messages.append({"type": "text", "text": caption})
+
+ if image_path and image_path.exists():
+ resolution = int(cfg.get("resolution", 512))
+ current_image = process(Image.open(image_path), resolution)
+ img_data = encode_image(current_image)["url"]
+ messages.append({
+ "type": "image_url",
+ "image_url": {"url": img_data},
+ "is_mask": False
+ })
+
+ if mask_path and mask_path.exists():
+ mask_array = np.array(Image.open(mask_path))
+ mask_data_url = encode_array_image(mask_array)["url"]
+ messages.append({
+ "type": "image_url",
+ "image_url": {"url": mask_data_url},
+ "is_mask": True
+ })
+
+ config_payload = {
+ "max_tokens": int(cfg.get("max_tokens", 32)),
+ "resolution": int(cfg.get("resolution", 512)),
+ "sampling_steps": int(cfg.get("sampling_steps", 32)),
+ "top_p": float(cfg.get("top_p", 0.95)),
+ "temperature": float(cfg.get("temperature", 0.9)),
+ "maskgit_r_temp": float(cfg.get("maskgit_r_temp", 4.5)),
+ "cfg": float(cfg.get("cfg", 2.5)),
+ "sampler": cfg.get("sampler", "maskgit_nucleus"),
+ "use_reward_models": bool(cfg.get("use_reward_models", False)),
+ }
+
+ port = cfg.get('port', 8001)
+ hostname = f"{port}" if ":" in port else f"localhost:{port}"
+
+ payload = {
+ "messages": [{"role": "user", "content": messages}],
+ "model": "unidisc",
+ **config_payload
+ }
+
+ api_url = f"http://{hostname}/v1/chat/completions"
+ response = requests.post(api_url, json=payload)
+ if response.status_code != 200:
+ return [{"type": "text", "text": f"API Error: {response.text}", "error": True}]
+
+ response_json = response.json()
+ if "choices" not in response_json:
+ return [{"type": "text", "text": f"Malformed response: {response.text}", "error": True}]
+
+ # The reference code expects "content" to be a list with items typed "text" or "image_url"
+ content = response_json["choices"][0]["message"]["content"]
+ if isinstance(content, list):
+ return content
+ else:
+ # If it's not a list, wrap it
+ return [{"type": "text", "text": content}]
+
+def decode_image_base64(url_str: str) -> Image.Image:
+ """Given a 'data:image/...;base64,...' string, return the PIL.Image."""
+ # e.g. "..."
+ base64_part = url_str.split("base64,")[-1]
+ raw = base64.b64decode(base64_part)
+ return Image.open(io.BytesIO(raw))
+
+def run_inference_for_folder(
+ folder: Path,
+ output_folder: Path,
+ cfg: dict,
+ use_image: bool,
+ use_img_mask: bool,
+ use_caption: bool,
+ use_cap_mask: bool,
+):
+ """
+ For a single folder with an image, caption, and mask, call the API,
+ then write out the returned content (images/text).
+ """
+
+ image_file = None
+ caption_file = None
+ mask_file = None
+ for f in folder.iterdir():
+ name_lower = f.name.lower()
+ if name_lower.startswith("image") and f.suffix.lower() in [".jpg", ".jpeg", ".png"]:
+ image_file = f
+ if name_lower.startswith("mask") and f.suffix.lower() == ".png":
+ mask_file = f
+ if name_lower.startswith("caption") and f.suffix.lower() in [".txt"]:
+ caption_file = f
+ if name_lower.startswith("mask_caption") and f.suffix.lower() == ".txt":
+ mask_caption_file = f
+
+ results = call_unidisc_api(
+ image_path=image_file if use_image else None,
+ caption=mask_caption_file.read_text().strip() if (mask_caption_file and use_cap_mask) else (caption_file.read_text().strip() if (caption_file and use_caption) else None),
+ mask_path=mask_file if use_img_mask else None,
+ cfg=cfg,
+ )
+ output_folder.mkdir(parents=True, exist_ok=True)
+
+ text_parts = []
+ img_count = 0
+ for i, item in enumerate(results):
+ if item["type"] == "text":
+ text_parts.append(item["text"])
+ elif item["type"] == "image_url":
+ out_img = decode_image_base64(item["image_url"]["url"])
+ out_img_name = output_folder / f"image.png"
+ out_img.save(out_img_name)
+ img_count += 1
+ if "error" in item:
+ text_parts.append(item["text"])
+
+ cfg['mode'] = f"{'img_' if use_image else ''}{'imgmask_' if use_img_mask else ''}{'cap_' if use_caption else ''}{'capmask_' if use_cap_mask else ''}"
+ cfg['use_image'] = use_image
+ cfg['use_img_mask'] = use_img_mask
+ cfg['use_caption'] = use_caption
+ cfg['use_cap_mask'] = use_cap_mask
+
+ if len(text_parts) > 0:
+ out_txt = output_folder / "caption.txt"
+ out_txt.write_text("\n".join(text_parts))
+ else:
+ shutil.copy(caption_file, output_folder / "caption.txt")
+ print(f"No text found, copied input caption to output: mode={cfg['mode']}")
+
+ if img_count == 0:
+ shutil.copy(image_file, output_folder / "image.png")
+ print(f"No image found, copied input image to output: mode={cfg['mode']}")
+
+ config_file = output_folder / "config.json"
+ config_file.write_text(json.dumps(cfg, indent=2))
+
+ input_img = (mask_file if use_img_mask else (image_file if use_image else None))
+ input_txt = (mask_caption_file if use_cap_mask else (caption_file if use_caption else None))
+
+ input_img = Im(input_img) if input_img else Im.new(h=512, w=512)
+ input_txt = input_txt.read_text().strip() if input_txt else "Empty caption"
+
+ input_img.save(output_folder / "input_image.png")
+ (output_folder / "input_caption.txt").write_text(input_txt)
+
+@app.command()
+def main(
+ input_dir: Path | None = None,
+ output_dir: Path | None = None,
+ param_file: Path | None = None,
+ num_pairs: int | None = None,
+ num_workers: int = 32,
+ batch_sleep: float = 0.2,
+ use_image: bool = False,
+ use_img_mask: bool = False,
+ use_caption: bool = False,
+ use_cap_mask: bool = False,
+ iterate_over_modes: bool = False,
+ single_config: bool = False,
+):
+ """
+ Generate datasets by calling the UniDisc API on each (image, caption, mask) triplet in input_dir.
+
+ Modified version:
+ - Queues tasks in order on a single global ThreadPoolExecutor.
+ - After queueing each batch, sleeps for a bit.
+ - Does not wait indefinitely for a batch to finish before moving on.
+ """
+
+ if use_img_mask:
+ assert use_image
+
+ if input_dir is None or output_dir is None:
+ raise ValueError("Both input_dir and output_dir must be provided.")
+
+ if param_file is not None:
+ all_configs = json.loads(param_file.read_text())
+ if not isinstance(all_configs, list):
+ raise ValueError("param_file must contain a JSON list of configs.")
+ else:
+ all_configs = []
+ for cfg in [2.5]:
+ # for sampler in ["maskgit_nucleus", "maskgit"]:
+ for port in ["babel-10-9:8000", "babel-6-29:8001"]:
+ all_configs.append(dict(port=port, cfg=cfg))
+
+ if single_config:
+ all_configs = [{'port': 'localhost:8000', 'cfg': 2.5}]
+
+ subfolders = sorted([f for f in input_dir.iterdir() if f.is_dir()])
+ if num_pairs is not None:
+ subfolders = subfolders[:num_pairs]
+
+ configs = []
+ from decoupled_utils import sanitize_filename
+ for i, cfg in enumerate(all_configs):
+ # Build config name from key/values if not explicitly provided
+ if "name" not in cfg:
+ # Sanitize values and join with underscores
+ cfg_name = "_".join(
+ f"{k}={str(v).replace('/', '_').replace(' ', '_')}"
+ for k, v in sorted(cfg.items())
+ )
+ else:
+ cfg_name = cfg["name"]
+ cfg_output_dir = output_dir / sanitize_filename(cfg_name)
+ cfg_output_dir.mkdir(parents=True, exist_ok=True)
+ configs.append((cfg, cfg_name, cfg_output_dir))
+
+ # Compute the batch size so each config receives a fair share of workers per batch.
+ # E.g., if num_workers=10 and there are 2 configs then each config gets ~5 workers in a batch.
+ batch_size = max(1, num_workers // len(configs))
+ total_folders = len(subfolders)
+ print(f"Processing {total_folders} folders across {len(configs)} configs with batch size {batch_size} per config.")
+
+ modes = []
+ if iterate_over_modes:
+ modes.append(dict(use_image=False, use_img_mask=False, use_caption=True, use_cap_mask=False)) # T2I
+ modes.append(dict(use_image=True, use_img_mask=False, use_caption=False, use_cap_mask=False)) # I2T
+ modes.append(dict(use_image=True, use_img_mask=True, use_caption=True, use_cap_mask=True)) # Both masked
+ modes.append(dict(use_image=True, use_img_mask=False, use_caption=True, use_cap_mask=True))
+ modes.append(dict(use_image=False, use_img_mask=False, use_caption=True, use_cap_mask=False))
+ else:
+ modes.append(dict(use_image=use_image, use_img_mask=use_img_mask, use_caption=use_caption, use_cap_mask=use_cap_mask))
+
+ # Use a single, global ThreadPoolExecutor so we can submit tasks in order
+ futures = []
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ for batch_start in range(0, total_folders, batch_size):
+ batch_folders = subfolders[batch_start : batch_start + batch_size]
+ for cfg, cfg_name, cfg_out in configs:
+ for folder in batch_folders:
+ for mode in modes:
+ use_image = mode['use_image']
+ use_img_mask = mode['use_img_mask']
+ use_caption = mode['use_caption']
+ use_cap_mask = mode['use_cap_mask']
+ key = ""
+ if use_img_mask:
+ key += "imgmask_"
+ elif use_image:
+ key += "img_"
+
+ if use_cap_mask:
+ key += "capmask_"
+ elif use_caption:
+ key += "cap_"
+
+ key = key.removesuffix('_')
+ folder_output = cfg_out / f"{key}__{folder.name}"
+ futures.append(
+ executor.submit(
+ run_inference_for_folder,
+ folder=folder,
+ output_folder=folder_output,
+ cfg=cfg,
+ use_image=use_image,
+ use_img_mask=use_img_mask,
+ use_caption=use_caption,
+ use_cap_mask=use_cap_mask,
+ )
+ )
+ print(f"Queued batch {batch_start} to {batch_start + len(batch_folders)}. Sleeping for {batch_sleep} seconds...")
+ time.sleep(batch_sleep)
+
+ for future in tqdm(as_completed(futures), total=len(futures), desc=f"Processing folders..."):
+ future.result()
+
+ print("All processing complete.")
+
+if __name__ == "__main__":
+ app()
\ No newline at end of file
diff --git a/demo/scoring/generate_input.py b/demo/scoring/generate_input.py
new file mode 100644
index 0000000000000000000000000000000000000000..c11083808954161f22dff57322489be7214a5362
--- /dev/null
+++ b/demo/scoring/generate_input.py
@@ -0,0 +1,108 @@
+from pathlib import Path
+import random
+from tqdm import tqdm
+import typer
+from datasets import load_dataset
+from PIL import Image
+import numpy as np
+import torch
+import transformers
+
+typer.main.get_command_name = lambda name: name
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+@app.command()
+def main(
+ output_dir: Path,
+ num_pairs: int = 100,
+ shard_start: int = 0,
+ num_shards: int = 1,
+ mask_img: bool = False,
+ mask_txt: bool = False,
+):
+ """
+ Generate a dataset of image-caption pairs from the synthetic dataset, optionally
+ masking text and/or image content.
+ """
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ base_url = "https://huggingface.co/datasets/ProGamerGov/synthetic-dataset-1m-high-quality-captions/resolve/main/data/data-{i:06d}.tar"
+ urls = [base_url.format(i=i) for i in range(shard_start, shard_start + num_shards)]
+ dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)
+
+ if mask_txt:
+ print("Initializing pipeline...")
+ pipeline = transformers.pipeline(
+ "text-generation",
+ model="meta-llama/Llama-3.2-3B-Instruct",
+ torch_dtype=torch.bfloat16,
+ device="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ print(f"Pipeline initialized on {pipeline.device}")
+
+ for idx, sample in tqdm(enumerate(dataset)):
+ if idx >= num_pairs:
+ break
+
+ pair_dir = output_dir / f"pair_{idx:06d}"
+ pair_dir.mkdir(parents=True, exist_ok=True)
+
+ img = sample['jpg'].convert('RGB')
+ if img.width != img.height:
+ min_dim = min(img.width, img.height)
+ left = (img.width - min_dim) // 2
+ top = (img.height - min_dim) // 2
+ img = img.crop((left, top, left + min_dim, top + min_dim))
+ print(f"Cropped image from {img.width}x{img.height} to {min_dim}x{min_dim}")
+
+ caption = sample['json']['short_caption']
+ original_caption = caption
+
+ if mask_txt:
+ mask_percent = random.choice([10, 20, 30, 40, 50, 60, 70, 80, 90])
+ messages = [
+ {"role": "system", "content": f"You are a helpful assistant that masks important parts of captions. Respond only with the caption where important parts are replaced with . Each is a single token so use multiple tokens to mask more text. Keep the masking natural and meaningful. For example, if the caption is 'A man in a red shirt is playing a guitar in a park', you might output 'A in is playing a in a park'. Please mask approximately {mask_percent}% of the caption."},
+ {"role": "user", "content": f"Mask important parts of this caption: {caption}"}
+ ]
+ masked_caption = pipeline(messages, max_new_tokens=200, pad_token_id=pipeline.tokenizer.eos_token_id)
+ caption = masked_caption[0]["generated_text"][-1]["content"].strip().removeprefix("'").removesuffix("'").removeprefix('"').removesuffix('"')
+
+ if "" not in caption:
+ words = original_caption.split()
+ if len(words) > 1:
+ # Choose random start & end in range of words
+ start_idx = random.randint(0, len(words) - 1)
+ end_idx = random.randint(start_idx, len(words) - 1)
+ # Replace consecutive chunk with
+ for i in range(start_idx, end_idx + 1):
+ words[i] = ""
+ caption = "".join(words)
+
+ if mask_img:
+ arr = np.zeros((img.height, img.width), dtype=np.bool_)
+ height, width = arr.shape[:2]
+
+ # Pick a random rectangle
+ rect_w = random.randint(max(1, width // 5), min(width * 9 // 10, width))
+ rect_h = random.randint(max(1, height // 5), min(height * 9 // 10, height))
+ start_x = random.randint(0, width - rect_w)
+ start_y = random.randint(0, height - rect_h)
+
+ arr[start_y:start_y + rect_h, start_x:start_x + rect_w] = True
+
+ # Convert array back to PIL image
+ mask_img = Image.fromarray(arr)
+ mask_img.save(pair_dir / "mask.png")
+
+ img.save(pair_dir / "image.jpg")
+ (pair_dir / "caption.txt").write_text(original_caption)
+ (pair_dir / "mask_caption.txt").write_text(caption)
+
+ if (idx + 1) % 10 == 0:
+ print(f"Saved {idx + 1} pairs...")
+
+ print(f"Successfully saved {num_pairs} image-caption pairs to {output_dir}")
+
+if __name__ == "__main__":
+ app()
diff --git a/demo/scoring/generate_rewards.py b/demo/scoring/generate_rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c6396ed994c6cb53898d93c7be2a4dbb9851cd
--- /dev/null
+++ b/demo/scoring/generate_rewards.py
@@ -0,0 +1,163 @@
+import json
+from pathlib import Path
+
+import torch
+import typer
+from image_utils import Im
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from accelerate.state import PartialState
+from accelerate.utils import gather_object
+from PIL import Image
+
+from decoupled_utils import set_global_breakpoint
+from model import Diffusion
+
+typer.main.get_command_name = lambda name: name
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+set_global_breakpoint()
+
+@app.command()
+def main(
+ input_dir: Path | None = None,
+ output_file: Path | None = None,
+ batch_size: int = 32,
+ resolution: int = 512,
+ num_pairs: int | None = None,
+ num_dirs: int | None = None,
+):
+ """
+ Process datasets contained in subdirectories of `input_dir`, distributed across multiple GPUs.
+ Each GPU processes complete datasets for better efficiency.
+ """
+ distributed_state = PartialState()
+ device = distributed_state.device
+ dtype = torch.bfloat16
+
+ # Initialize model without Accelerator
+ model = Diffusion(None, None, device, disable_init=True)
+ model.device = device
+ model.dtype = dtype
+
+ reward_config = OmegaConf.create({
+ "dfn_score": 1.0,
+ "hpsv2_score": 1.0,
+ "clip_score": 1.0,
+ "laion_aesthetic_score": 1.0,
+ "text_reward_model_score": 1.0
+ })
+
+ all_rewards = {}
+ # Get all dataset directories and distribute them across GPUs
+ dataset_dirs = sorted([p for p in input_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
+ if not dataset_dirs:
+ if distributed_state.is_main_process:
+ print("No dataset directories found in the input directory.")
+ raise typer.Exit()
+
+ if num_dirs is not None:
+ dataset_dirs = dataset_dirs[:num_dirs]
+
+ # Split datasets across processes
+ with distributed_state.split_between_processes(dataset_dirs) as process_dataset_dirs:
+ for ds_dir in tqdm(process_dataset_dirs, desc=f"Processing datasets (GPU {distributed_state.process_index})"):
+ if distributed_state.is_main_process:
+ print(f"Processing dataset: {ds_dir.name}")
+
+ pair_dirs = sorted([p for p in ds_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
+ if num_pairs is not None:
+ pair_dirs = pair_dirs[:num_pairs]
+ if not pair_dirs:
+ if distributed_state.is_main_process:
+ print(f" No pair subdirectories found in {ds_dir.name}, skipping.")
+ continue
+
+ images = []
+ captions = []
+ for pair_dir in sorted(pair_dirs, key=lambda p: p.name):
+ image_path = pair_dir / "image.png"
+ caption_path = pair_dir / "caption.txt"
+
+ if not (image_path.exists() and caption_path.exists()):
+ print(f" Skipping {pair_dir}: missing image.png or caption.txt")
+ continue
+
+ try:
+ img = Image.open(image_path)
+ if resolution != img.height or resolution != img.width:
+ print(f"WARNING!!! Image resolution {img.height}x{img.width} does not match resolution {resolution}x{resolution}")
+ min_dim = min(img.width, img.height)
+ left = (img.width - min_dim) // 2
+ top = (img.height - min_dim) // 2
+ img = img.crop((left, top, left + min_dim, top + min_dim))
+ img = img.resize((resolution, resolution), Image.Resampling.LANCZOS)
+ images.append(Im(img).torch.unsqueeze(0))
+ except Exception as e:
+ print(f"Error processing image {image_path}: {e}")
+ continue
+
+ try:
+ caption = caption_path.read_text().strip()
+ captions.append(caption)
+ except Exception as e:
+ print(f"Error reading caption {caption_path}: {e}")
+ continue
+
+ num_pairs = len(images)
+ if num_pairs == 0:
+ print(f"No valid pairs found in dataset {ds_dir.name}, skipping.")
+ continue
+
+ dataset_reward_batches = []
+ dataset_raw_rewards = []
+ for i in tqdm(range(0, num_pairs, batch_size), desc="Processing pairs"):
+ batch_imgs = torch.cat(images[i : i + batch_size], dim=0).to(device) / 255.0
+ batch_texts = captions[i : i + batch_size]
+ with torch.inference_mode():
+ rewards, raw_rewards = model.get_rewards(reward_config, batch_imgs, batch_texts, None, return_raw_rewards=True)
+ dataset_reward_batches.append(rewards.cpu())
+ dataset_raw_rewards.append(raw_rewards)
+
+ dataset_rewards_tensor = torch.cat(dataset_reward_batches, dim=0)
+ dataset_raw_rewards_dict = {}
+ for key in raw_rewards.keys():
+ dataset_raw_rewards_dict[key] = torch.cat(
+ [batch[key] for batch in dataset_raw_rewards], dim=0
+ )
+
+ all_rewards[ds_dir.name] = {
+ "rewards": dataset_rewards_tensor.tolist(),
+ "raw_rewards": {k: v.tolist() for k, v in dataset_raw_rewards_dict.items()},
+ "folder_names": [f.name for f in pair_dirs],
+ "folder_paths": [f.as_posix() for f in pair_dirs]
+ }
+ if distributed_state.is_main_process:
+ print(f"Finished processing {num_pairs} pairs from {ds_dir.name}")
+
+ gathered_rewards = gather_object([all_rewards])
+
+ all_keys = set()
+ all_gathered_rewards = {}
+ for i in range(len(gathered_rewards)):
+ assert len(set(gathered_rewards[i].keys()).intersection(all_keys)) == 0
+ all_keys.update(gathered_rewards[i].keys())
+ all_gathered_rewards.update(gathered_rewards[i])
+
+ gathered_rewards = all_gathered_rewards
+
+ if distributed_state.is_main_process:
+ print("All rewards:")
+ print(json.dumps(gathered_rewards, indent=2))
+
+ try:
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+ with open(output_file, "w") as f:
+ json.dump(gathered_rewards, f, indent=2)
+ print(f"Rewards saved to {output_file}")
+ except Exception as e:
+ print(f"Error saving rewards to file: {e}")
+
+
+if __name__ == "__main__":
+ app()
\ No newline at end of file
diff --git a/demo/server.py b/demo/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c60d3902eadd989122a74459ef3aceed8ba597
--- /dev/null
+++ b/demo/server.py
@@ -0,0 +1,242 @@
+import asyncio
+import base64
+import logging
+import multiprocessing as mp
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import Any, Dict, List, Union
+import random
+import json
+import hydra
+import torch
+import time
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from uvicorn import run
+
+from decoupled_utils import breakpoint_on_error
+from demo.api_data_defs import ChatMessage, ChatRequest, ContentPart
+from demo.inference_utils import (convert_request_base64_to_pil,
+ convert_request_pil_to_base64,
+ trim_merge_messages)
+from utils import set_omega_conf_resolvers
+
+logger = logging.getLogger("uvicorn.error")
+
+mp.set_start_method('spawn', force=True)
+
+set_omega_conf_resolvers()
+
+
+async def dummy_response(messages: List[Dict[str, Any]]) -> ChatRequest:
+ await asyncio.sleep(0.1)
+ response_content = []
+ for msg in messages:
+ if msg["role"] == "user":
+ for item in msg["content"]:
+ if item["type"] == "text":
+ response_content.append(ContentPart(type="text", text="Response to: " + item["text"]))
+ elif item["type"] == "image_url":
+ response_content.append(ContentPart(type="text", text="Image received and processed."))
+
+ image_path = Path("static/0457_01.jpg") # Replace with a real image path
+ if image_path.is_file():
+ with image_path.open("rb") as img_file:
+ base64_str = base64.b64encode(img_file.read()).decode('utf-8')
+ response_content.append(ContentPart(
+ type="image_url",
+ image_url={"url": f"data:image/jpeg;base64,{base64_str}"}
+ ))
+ else:
+ logger.warning(f"Image file not found at {image_path}")
+
+ return ChatRequest(messages=[ChatMessage(role="assistant", content=response_content)])
+
+def call_model(messages: List[Dict[str, Any]], inference) -> ChatRequest:
+ print(f"input messages: {messages}")
+ returned_messages = inference(messages)
+ openai_messages = convert_request_pil_to_base64(returned_messages)
+ return openai_messages
+
+def generate_response(messages: List[Dict[str, Any]], inference, dummy_response: bool = False) -> ChatRequest:
+ if dummy_response:
+ return dummy_response(messages)
+ else:
+ return call_model(messages, inference)
+
+def call(inference, request: ChatRequest):
+ try:
+ print(f"Hash: {request.request_hash}")
+ output_dir = Path(f"{Path(__file__).parent}/outputs/responses")
+ filename = output_dir / f"{request.request_hash}.json"
+
+ if request.request_hash is not None and filename.exists():
+ with open(filename, "r") as f:
+ generated_json = json.load(f)
+ print(f"Response loaded from {filename}")
+ else:
+ processed_messages = convert_request_base64_to_pil(request)
+ processed_messages = trim_merge_messages(processed_messages)
+ generated: ChatRequest = generate_response(processed_messages, inference)
+ generated_json = generated.messages[-1].model_dump()
+
+
+ if request.request_hash is not None and not filename.exists():
+ filename.parent.mkdir(parents=True, exist_ok=True)
+ with open(filename, "w") as f:
+ json.dump(generated.messages[-1].model_dump(), f, indent=2)
+
+ print(f"Response saved to {filename}")
+
+ # OpenAI format
+ return JSONResponse({
+ "id": "cmpl-000",
+ "object": "chat.completion",
+ "created": int(asyncio.get_event_loop().time()),
+ "choices": [{
+ "index": 0,
+ "message": generated_json,
+ "finish_reason": "stop"
+ }],
+ "usage": {
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0
+ }
+ })
+
+ except Exception as e:
+ from traceback import format_exc
+ logger.error(f"Error processing request: {str(e)}")
+ logger.error(format_exc())
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+def gpu_worker(gpu_id, config, request_queue, response_queue):
+ torch.cuda.set_device(gpu_id) # We use this instead of CUDA_VISIBLE_DEVICES since the user may have set.
+ from demo.inference import setup
+ inference = setup(config)
+ print(f"GPU {gpu_id} Initialized inference")
+ while True:
+ # Wait for a new request (blocking call)
+ print(f"GPU {gpu_id} Waiting for request")
+ request_data = request_queue.get()
+ print(f"GPU {gpu_id} Received request")
+ if request_data is None:
+ print(f"GPU {gpu_id} Received shutdown signal")
+ break # a way to shut down this worker gracefully
+ try:
+ # Process the request – note that this call is synchronous
+ print(f"GPU {gpu_id} Processing request")
+ start_time = time.time()
+ result = call(inference, request_data)
+ print(f"GPU {gpu_id} Finished processing request in {time.time() - start_time} seconds")
+ response_queue.put(result)
+ print(f"GPU {gpu_id} Put result in response queue")
+ except Exception as e:
+ print(f"GPU {gpu_id} Error processing request {request_data}: {e}")
+ response_queue.put(e)
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ # Check if we're in development mode
+ dev_mode = getattr(app.config, "dev_mode", False)
+ app.state.dev_mode = dev_mode
+ print(f"Dev mode: {dev_mode}")
+
+ if dev_mode:
+ # Development mode: Single synchronous GPU process
+ logging.info("Starting in DEVELOPMENT mode - synchronous operation, no multiprocessing")
+ from demo.inference import setup
+ app.state.inference = setup(app.config)
+ yield
+ else:
+ # Normal mode with worker processes
+ app.state.worker_lock = asyncio.Lock()
+ workers = []
+ num_gpus = torch.cuda.device_count()
+ logging.info(f"Number of GPUs: {num_gpus}")
+ for gpu_id in range(num_gpus):
+ req_q = mp.Queue(maxsize=1) # enforce one request at a time
+ res_q = mp.Queue()
+ p = mp.Process(target=gpu_worker, args=(gpu_id, app.config, req_q, res_q))
+ p.start()
+ workers.append({"process": p, "req_q": req_q, "res_q": res_q})
+ logging.info(f"Started worker {gpu_id}")
+
+ app.state.workers = workers
+ yield
+ # On shutdown: signal all workers to stop and join them
+ for worker in app.state.workers:
+ worker["req_q"].put(None)
+ for worker in app.state.workers:
+ worker["process"].join()
+ logger.info("Worker process joined.")
+
+
+app = FastAPI(title="Multimodal VLM Endpoint", lifespan=lifespan)
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # or ["*"] to allow all origins
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+app.state.workers = []
+logger = logging.getLogger("uvicorn")
+
+
+@app.post("/v1/chat/completions")
+async def chat_completion(request: ChatRequest):
+ if getattr(app.state, "dev_mode", False):
+ return call(app.state.inference, request)
+
+ worker = None
+ async with app.state.worker_lock:
+ while worker is None:
+ # Shuffle workers each time to distribute load
+ workers = list(enumerate(app.state.workers))
+ random.shuffle(workers)
+ for i, w in workers:
+ print(f"Trying to assign request to worker {i}")
+ try:
+ w["req_q"].put_nowait(request)
+ worker = w
+ print(f"Assigned request to worker {w['process'].name}")
+ break
+ except mp.queues.Full:
+ print(f"Worker {w['process'].name} is full")
+ continue
+ if worker is None:
+ await asyncio.sleep(0.1)
+
+ loop = asyncio.get_running_loop()
+ result = await loop.run_in_executor(None, worker["res_q"].get)
+ if isinstance(result, Exception):
+ raise HTTPException(status_code=500, detail=str(result))
+ return result
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ body = await request.body()
+ logger.error("Request body: %s", body)
+ logger.error("Validation errors: %s", exc.errors())
+ logger.error("Original body: %s", exc.body)
+ return JSONResponse(
+ status_code=422,
+ content={"detail": exc.errors(), "body": exc.body},
+ )
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+@torch.no_grad()
+def main(config):
+ with breakpoint_on_error():
+ app.config = config
+ dev_mode = getattr(config, "dev_mode", False)
+ app.state.dev_mode = dev_mode
+ run(app, host="0.0.0.0", port=getattr(config, "port", 8001))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/docs/DATA.md b/docs/DATA.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec7097f914d97d4a48948844cf24e358125df4c3
--- /dev/null
+++ b/docs/DATA.md
@@ -0,0 +1,14 @@
+See `unidisc/datasets/preprocessing` for instructions on how to preprocess datasets.
+
+We support the following datasets:
+
+- Cambrian
+- CapsFusion
+- CC12M
+- DataComp1B
+- JourneyDB
+- LAION400M
+- MMC4
+- PixelProse
+
+Additionally, we generated our own synthetic dataset and provide the [generation scripts](unidisc/datasets/preprocessing/unidisc_dataset/README.md) as well as the raw data.
\ No newline at end of file
diff --git a/docs/EVAL.md b/docs/EVAL.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b75f9b80b9157bc543be1634b34228805679155
--- /dev/null
+++ b/docs/EVAL.md
@@ -0,0 +1,11 @@
+## Eval
+To run evaluation, run the following command:
+```
+RUN_CC=1 RUN_DB=1 RUN_FLICKR=1 RUN_COCO=1 RUN_MEDIUM=1 RUN_AR=1 RUN_NAR=1 NUM_GPUS=1 bash scripts/small_scale_eval.sh
+```
+
+`RUN_CC`, `RUN_DB`, `RUN_FLICKR`, `RUN_COCO` all control which datasets to evaluate on.
+
+`RUN_MEDIUM` controls whether to run experiments on the 100M or 300M ckpts.
+
+`RUN_AR` and `RUN_NAR` control whether to run the AR and NAR ckpts.
diff --git a/docs/INSTALL.md b/docs/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d901560fa634961ec8a9d2498d0757ce0cc17be
--- /dev/null
+++ b/docs/INSTALL.md
@@ -0,0 +1,114 @@
+# Installation Guide
+
+First, if you did not clone with submodules (`--recurse-submodules`), run:
+```bash
+git submodule update --init --recursive
+```
+
+## UV (Recommended)
+
+First, install [uv](https://docs.astral.sh/uv/getting-started/installation/):
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
+
+**Note:** You may need to set `CUDA_HOME` and have it pointing to a valid CUDA 12.x installation. To use a different CUDA version, please change the sources in `pyproject.toml` (and the `torch`/`torchvision` versions). See [this guide](https://docs.astral.sh/uv/guides/integration/pytorch/) for more details.
+
+Next, run:
+```bash
+uv sync --no-group dev
+uv sync # To install all dependencies: uv sync --all-groups
+```
+
+If it succeeded, that's it! Prefix any commands with `uv run` to use the environment.
+
+E.g., `accelerate launch main.py` -> `uv run accelerate launch main.py`
+
+or
+
+`python main.py` -> `uv run python main.py`
+
+Alternatively, you can activate the environment manually and run as follows:
+```bash
+uv sync
+source .venv/bin/activate
+python main.py
+```
+
+## Pip / Anaconda / Micromamba
+
+### Step 1: Optional: Install Micromamba
+curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xvj bin/micromamba
+export MAMBA_ROOT_PREFIX="~/micromamba"
+eval "$(~/bin/micromamba shell hook -s posix)"
+alias conda='micromamba'
+
+### Step 2: Create conda environment
+`conda create -n unidisc python=3.10`
+`conda config --add channels conda-forge`
+
+If using micromamba:
+`micromamba config append channels conda-forge`
+
+If using conda:
+`conda config --set channel_priority flexible`
+
+### Step 3: Setup CUDA
+To use existing an existing cuda installation:
+`export CUDA_HOME=...`
+
+
+To install CUDA w/conda or micromamba:
+```
+conda install cuda cuda-nvcc -c nvidia/label/cuda-12.4.1
+export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"
+export CUDA_HOME=$CONDA_PREFIX`
+```
+
+### Step 4: Install PyTorch
+
+If using conda/micromamba CUDA:
+`conda install pytorch==2.4.1 pytorch-cuda=12.4 -c pytorch -c nvidia/label/cuda-12.4.1 -c nvidia`
+
+Otherwise, install from PyPI:
+`pip install torch==2.5.0 torchvision==0.20.0 --index-url https://download.pytorch.org/whl/test/cu124`
+
+### [Optional] Install Flash Attention
+```
+pip install --upgrade packaging ninja pip wheel setuptools
+pip install flash-attn --no-build-isolation
+```
+
+#### [Optional] Flash Attn 3
+```
+git clone https://github.com/Dao-AILab/flash-attention
+cd hopper; python setup.py install
+```
+
+To test: `pip install pytest; export PYTHONPATH=$PWD; pytest -q -s test_flash_attn.py`
+
+
+### Step 5: Install Other Dependencies
+
+```
+pip install -r docs/reqs/requirements.txt
+pip install -r docs/reqs/requirements_eval.txt
+pip install --force-reinstall --no-deps -r docs/reqs/forked_requirements.txt
+pip install tensordict-nightly 'git+https://github.com/huggingface/accelerate' --force-reinstall --no-deps
+pip install 'git+https://github.com/huggingface/datasets' 'git+https://github.com/huggingface/transformers'
+pip install 'git+ssh://git@github.com/alexanderswerdlow/hydra.git@working_ci#egg=hydra-core'
+pip install 'git+ssh://git@github.com/alexanderswerdlow/hydra.git@working_ci#egg=hydra-submitit-launcher&subdirectory=plugins/hydra_submitit_launcher'
+pip install 'numpy>2.0.0'
+```
+
+### Misc / Troubleshooting
+- This may be required if you don't install CUDA through conda: `conda install gcc_linux-64==12.4.0 gxx_linux-64===12.4.0`
+- Other non-forked deps [only if they show as not installed]: `pip install hydra-core webdataset`
+- Dependencies you may need for non-core code:
+
+
+```bash
+pip install flask werkzeug sentence_transformers ngrok opencv-python lpips simple_slurm typer ftfy bitsandbytes sentencepiece flask requests peft transformers deepspeed langchain langchain_groq langchain_core langchain_community langchain-openai git+https://github.com/microsoft/mup.git
+pip install fairseq --no-deps
+```
\ No newline at end of file
diff --git a/docs/TOKENIZERS.md b/docs/TOKENIZERS.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d0a3033ef7184b2f7bae1a41f5d6ff991f073a1
--- /dev/null
+++ b/docs/TOKENIZERS.md
@@ -0,0 +1,5 @@
+For all large scale experiments, we use the `vq_ds16_t2i` tokenizer from [LLaMaGen](https://github.com/FoundationVision/LlamaGen).
+
+For small-scale/scaling experiments, we use the MagViTv2 tokenizer from [Show-o](https://github.com/showlab/Show-o).
+
+For CUB200 experiments, we use the [TiTok](https://github.com/bytedance/1d-tokenizer) tokenizer. In experiments, we found this tokenizer to perform the best, however it was not released at the time of our earlier experiments.
diff --git a/docs/TRAIN.md b/docs/TRAIN.md
new file mode 100644
index 0000000000000000000000000000000000000000..faa4a45deabacd8d2056257eb44d8492507d05b3
--- /dev/null
+++ b/docs/TRAIN.md
@@ -0,0 +1,87 @@
+## Small Scale Training
+UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_train,ar]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true
+
+### Caching
+UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_caching_train]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true
+
+## Large Scale Training
+
+To train the large scale experiments, we recommend to set the following environment variables:
+```
+unset CUDA_VISIBLE_DEVICES; unset CUDA_LAUNCH_BLOCKING; unset NCCL_SOCKET_IFNAME; unset NCCL_NSOCKS_PERTHREAD; unset NCCL_SOCKET_NTHREADS; unset OMP_NUM_THREADS; unset NCCL_P2P_DISABLE; unset NCCL_P2P_LEVEL
+
+export NCCL_DEBUG=INFO
+export PYTHONUNBUFFERED=1
+export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1
+export UNIDISC_ROOT_OUTPUT_DIR="outputs"
+```
+
+We have several stages of training. You will have to load from the previous stage's checkpoint, using either `trainer.load_from_state_dict` or by setting `HYDRA_RUN_DIR_NAME` which will automatically load the latest checkpoint/optimizer/sampler states.
+
+1st stage:
+```
+accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train]' debug=true loader.batch_size=8
+```
+
+2nd stage:
+```
+accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res]' debug=true loader.batch_size=4
+```
+
+3rd stage:
+```
+accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res_interleaved]' debug=true loader.batch_size=1
+```
+
+## SLURM Training
+To train on SLURM, you have two options:
+
+1. sbatch directly from the script: `sbatch scripts/train_large_scale_slurm.sh`. Please edit the script to set the correct number of nodes, gpus per node, and other parameters.
+
+2. Use sbatch through hydra_submitit_launcher. This is recommended for smaller experiments, partially hyperparameter sweeps, although it has been setup to work for multi-node training.
+
+First, uncomment everything under `hydra.launcher` in `configs/config.yaml`. Next, uncomment this at the top of `configs/config.yaml`:
+```
+- override hydra/launcher: submitit_slurm
+```
+
+Next, run the following (change to `uv pip install` if you are using uv):
+
+```
+pip install 'git+ssh://git@github.com/alexanderswerdlow/hydra.git@working_ci#egg=hydra-core'
+pip install 'git+ssh://git@github.com/alexanderswerdlow/hydra.git@working_ci#egg=hydra-submitit-launcher&subdirectory=plugins/hydra_submitit_launcher'
+```
+
+To use hydra_submitit_launcher, append the following to any command:
+
+`devices=8 nodes=1 partition=general --multirun`
+
+You may modify `devices`/`nodes`/`partition` based on your set. See `hydra.launcher` in `configs/config.yaml` to set additional SLURM parameters.
+
+
+Both of the above methods use 1 task per node. Some SLURM clusters prefer to use 1 task per GPU. Please see this amazing guide for more details: `https://github.com/stas00/ml-engineering`. In short, be careful in making this change as there are many subtle differences (e.g., passing signals between processes, how checkpointing/requeing/error handling works, etc.)
+
+## TPU Training
+
+TODO: Add more documentation for TPU training. Our codebase is setup to use TPUs through `torch_xla`, taking advantage of SPMD.
+
+Misc TPU Notes:
+- SPMD has a very confusing setup for SPMD and pretends that each node is a single device. See `decoupled_utils.py` for some of the logic used to handle this. Moreover, getting the proper rank can only be done after spawning SPMD, so we need to handle this as a lot of code needs the device rank on import.
+
+## Attention Kernels
+The codebase currently supports the following:
+
+- PyTorch SDPA (Including CUDNN Flash, Regular Flash, etc.)
+- PyTorch FlexAttention (For interleaved/caching training)
+- Flash Attention 2/3 (With varying kernels, e.g., packed/non-packed/varlen depending on the use case)
+- TorchXLA SPMD FlashAttention (For TPU training)
+
+Generally, we use PyTorch SDPA (preferrably the CUDDN Kernel which you can force with `UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1`) and FlexAttention for all interleaved/caching training, setting `trainer.compile=true` to improve MFU. We found this to be similar in speed to Flash Attention 2, which at the time of development did not have good compile support.
+
+# Config Notes
+- `data.enable_cuda_in_tensordict_collate` requires `loader.num_workers = 0`
+- Enable `data.move_tensordict_to_shm` to speed up dataloading (keeping `data.keep_tensordict_on_disk = true`), assuming you have enough system memory. Separately, you can disable `data.keep_tensordict_on_disk`, but this will load the entire tensordict into each dataloader worker process (e.g., given 2 GPUs and 4 workers, this will load 8 tensordicts into system memory) which is not possible on larger datasets. Optionally, you can set `+data.shm_path='/path/to/data'` to use a custom path, e.g., to use a scratch disk instead of system memory.
+- You may need to set `NCCL_IB_DISABLE=1` or `NCCL_P2P_DISABLE=1` depending on the system configuration. Setting `NCCL_P2P_LEVEL=NVL` is recommended if the system has NVLink.
+- To use `data.enable_cuda_in_tensordict_collate=true`, you must also set `data.force_mp_spawn=false` and `loader.num_workers>0`.
+- Resuming from checkpoint can be done in multiple ways. Double check the log output to verify the correct checkpoint is being loaded. The tl;dr is the following: If you use `hydra_submitit_launcher` or set `HYDRA_RUN_DIR_NAME`, it will automatically load the latest checkpoint/optimizer/sampler states.
+- To resume from weights only set: `trainer.load_from_state_dict="/path/to/weights.bin"`
\ No newline at end of file
diff --git a/docs/images/banner.webp b/docs/images/banner.webp
new file mode 100644
index 0000000000000000000000000000000000000000..463d8db3af2640d5c27077205c4a93d910be7ca4
--- /dev/null
+++ b/docs/images/banner.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7821a7a1bcb23baf268e1dd85f10677f3666e9aa55bead036a94f24676806959
+size 2342332
diff --git a/docs/images/diagram.webp b/docs/images/diagram.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9f31efa53408695a9385e67fc18561df9a22c232
--- /dev/null
+++ b/docs/images/diagram.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7faa55529b3e6a6d35a4468b50992c5f30eb581644dac76dd3cd1ad232e6299e
+size 1159872
diff --git a/docs/reqs/requirements.txt b/docs/reqs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e8ead764278e2f1a187b9a599df1f5332098a0b
--- /dev/null
+++ b/docs/reqs/requirements.txt
@@ -0,0 +1,34 @@
+accelerate
+transformers
+ipdb
+ipython
+rich
+transformers
+wandb
+datasets
+h5py
+fsspec
+timm
+pandas
+ml_collections
+typer
+braceexpand
+torch-fidelity
+scikit-learn
+lovely-tensors
+torchtnt
+torchinfo
+pynvml
+diffusers
+omegaconf
+tensordict
+lightning_utilities
+torchtnt
+torchmetrics
+jaxtyping
+einops
+peft
+sentencepiece
+numpy>=2.0.0
+clean-fid
+hf_transfer
\ No newline at end of file
diff --git a/docs/reqs/requirements_eval.txt b/docs/reqs/requirements_eval.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc716bbb69587c4cb67b62f3bad1dd9a07a20f22
--- /dev/null
+++ b/docs/reqs/requirements_eval.txt
@@ -0,0 +1,7 @@
+evaluate
+mauve-text
+clean-fid
+hpsv2
+open_clip_torch
+git+https://github.com/boomb0om/text2image-benchmark
+git+https://github.com/openai/CLIP.git
\ No newline at end of file
diff --git a/docs/reqs/requirements_forked.txt b/docs/reqs/requirements_forked.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5af91a84c77e9e0670b7c0a25522035f8523f0c7
--- /dev/null
+++ b/docs/reqs/requirements_forked.txt
@@ -0,0 +1,4 @@
+git+ssh://git@github.com/alexanderswerdlow/webdataset.git@wip
+git+ssh://git@github.com/alexanderswerdlow/submitit.git
+git+ssh://git@github.com/alexanderswerdlow/image_utils.git@wip_v1
+git+https://github.com/huggingface/diffusers.git
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5bbfa4b7d3155aadc23cbfbf25803ff3c8bb5a4
--- /dev/null
+++ b/main.py
@@ -0,0 +1,1134 @@
+import os
+import sys
+from contextlib import ExitStack
+from pathlib import Path
+
+from constants import CONFIG_PATH, LIB_DIR
+sys.path.append(str(LIB_DIR / "hydra_submitit_launcher"))
+
+import builtins
+import random
+import re
+import signal
+import traceback
+from copy import deepcopy
+from datetime import datetime
+
+import hydra
+import numpy as np
+import omegaconf
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf, open_dict, read_write
+from safetensors.torch import load_file, save_file
+
+import dataloader
+from model import Diffusion
+import utils
+import wandb
+from decoupled_utils import (check_gpu_memory_usage, get_hostname,
+ get_local_rank, get_rank, get_slurm_filename_info,
+ get_slurm_log_prefix, get_tpu_devices,
+ get_world_size, gprint, is_local_main_process,
+ is_main_process, is_torch_cuda_available,
+ is_torch_xla_available, print_params,
+ process_file_prefix, profile_memory, rank_zero_fn,
+ rprint, set_global_breakpoint, set_global_exists,
+ set_timing_builtins, try_except)
+from utils import (ErrorHandler, _print_config, convert_state_dict_keys, set_omega_conf_resolvers, set_torch_defaults)
+
+# Only needed when debugging hydra
+# os.environ["HYDRA_FULL_ERROR"] = "1"
+
+set_global_breakpoint() # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
+set_global_exists()
+set_omega_conf_resolvers()
+
+if is_torch_xla_available():
+ from jax_smi import initialise_tracking
+
+def _load_from_checkpoint(config, tokenizer):
+ OmegaConf.resolve(config)
+ if "hf" in config.backbone:
+ return Diffusion(config=config, tokenizer=tokenizer).to("cuda")
+
+ return Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config)
+
+@rank_zero_fn
+def _print_batch(train_ds, valid_ds, tokenizer, k=64):
+ for dl_type, dl in [("train", train_ds), ("valid", valid_ds)]:
+ rprint(f"Printing {dl_type} dataloader batch.")
+ batch = next(iter(dl))
+ rprint("Batch input_ids.shape", batch["input_ids"].shape)
+ first = batch["input_ids"][0, :k]
+ last = batch["input_ids"][0, -k:]
+ rprint(f"First {k} tokens:", tokenizer.decode(first))
+ rprint("ids:", first)
+ rprint(f"Last {k} tokens:", tokenizer.decode(last))
+ rprint("ids:", last)
+
+
+def generate_samples(config, tokenizer):
+ rprint("Generating samples.")
+ model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
+ model.gen_ppl_metric.reset()
+ if config.eval.disable_ema:
+ rprint("Disabling EMA.")
+ model.ema = None
+ stride_length = config.sampling.stride_length
+ num_strides = config.sampling.num_strides
+ for _ in range(config.sampling.num_sample_batches):
+ if config.sampling.semi_ar:
+ _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
+ stride_length=stride_length, num_strides=num_strides, dt=1 / config.sampling.steps
+ )
+ text_samples = intermediate_samples[-1]
+ # Note: Samples generated using semi-ar method
+ # need to to be processed before computing generative perplexity
+ # since these samples contain numerous <|endoftext|> tokens
+ # and diffusion.compute_generative_perplexity() discards
+ # any text after the first EOS token.
+ else:
+ samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
+ text_samples = model.tokenizer.batch_decode(samples)
+ model.compute_generative_perplexity(text_samples)
+
+ rprint("Text samples:", text_samples)
+ if not config.sampling.semi_ar:
+ rprint("Generative perplexity:", model.gen_ppl_metric.compute())
+ return text_samples
+
+
+def instantiate_wandb(config, accelerator):
+ if is_torch_xla_available():
+ gprint("Initializing wandb for XLA")
+ if config.mode == 'eval':
+ config.wandb.project = f"{config.wandb.project}-eval"
+ elif config.mode == 'zero-shot-eval':
+ config.wandb.project = f"{config.wandb.project}-zero-shot-eval"
+
+ if config.wandb.group is not None:
+ config.wandb.group = str(config.wandb.group)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ wandb_kwargs = dict(config.wandb)
+
+ if getattr(config, "sweep_id", None) is not None:
+ rprint(f"Setting Wandb group to {config.sweep_id}")
+ wandb_kwargs["group"] = config.sweep_id
+ del wandb_kwargs["project"]
+ accelerator.init_trackers(
+ config.wandb.project, config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True), init_kwargs=dict(wandb=wandb_kwargs)
+ )
+
+ if getattr(config.trainer, "log_code", True) and is_main_process():
+ if "matrix" in get_hostname():
+ rprint(f"Not logging code to wandb on {get_hostname()}")
+ else:
+ rprint(f"Logging code to wandb from {Path(__file__).parent}")
+ try:
+ wandb.run.log_code(
+ root=str(Path(__file__).parent),
+ include_fn=lambda path: any(path.endswith(f) for f in (".py", ".yaml", ".yml", ".txt", ".md")),
+ exclude_fn=lambda path, root: any(x in os.path.relpath(path, root) for x in ("output", "multirun", "logs", "wandb")),
+ )
+ except Exception as e:
+ rprint(f"Failed to log code to wandb: {e}")
+
+ with open_dict(config):
+ try:
+ config.wandb_url = wandb.run.get_url()
+ wandb.define_metric("global_samples")
+ wandb.define_metric("effective_global_tokens")
+ wandb.define_metric("effective_global_step")
+ wandb.define_metric("train_metrics/samples")
+ wandb.define_metric("trainer/loss", step_metric="global_samples")
+ except Exception as e:
+ rprint(f"Failed to get wandb url: {e}")
+
+def instantiate_model(config, tokenizer):
+ model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
+ if config.eval.disable_ema:
+ rprint("Disabling EMA.")
+ model.ema = None
+
+ return model
+
+def gconf(config, attr):
+ return getattr(config, attr, None)
+
+
+def has_ckpt(config, attr):
+ return gconf(config, attr) is not None and utils.fsspec_exists(gconf(config, attr))
+
+
+def set_env_vars(config):
+ import torch
+ hostname = __import__("socket").gethostname()
+ rprint(f"Starting Training on {hostname}")
+ import torch
+ # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str((Path.home() / ".cache" / "torchinductor").resolve())
+
+ if not is_torch_xla_available():
+ try:
+ # Applies the equivalent of ulimit -l unlimited to this process [and children]
+ # This caused a significant amount of pain to figure out
+ import resource
+ soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)
+ resource.setrlimit(resource.RLIMIT_MEMLOCK, (hard, hard))
+ if is_local_main_process():
+ gprint(f"Successfully set RLIMIT_MEMLOCK to {hard}")
+ except ValueError as e:
+ rprint(f"Failed to set RLIMIT_MEMLOCK: {e}")
+ except resource.error as e:
+ rprint(f"Error setting RLIMIT_MEMLOCK: {e}")
+ else:
+ rprint(f"Not setting RLIMIT_MEMLOCK on XLA")
+
+ if "matrix-3-28" in hostname or "matrix-3-26" in hostname:
+ rprint(f"Disabling NCCL P2P")
+ os.environ["NCCL_P2P_DISABLE"] = "1"
+
+ if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") != "":
+ assert False, f"TORCH_DISTRIBUTED_DEBUG is set to: {os.environ.get('TORCH_DISTRIBUTED_DEBUG')}. Please unset it as it starts a gloo backend."
+
+ if config.model.use_spda_attn:
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
+ os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"
+ rprint("Setting SPDA Flags")
+
+ if config.trainer.detect_anomaly:
+ torch.autograd.set_detect_anomaly(True)
+
+def update_config_before_resolution(config):
+ import torch
+ if hasattr(config, "training"):
+ rprint(f"'training' has been refactored to 'trainer'. Please update the config.")
+
+ with open_dict(config):
+ config.output_dir = os.getcwd()
+ config.logging_dir = os.getcwd()
+ if config.model.use_kv_cache is False and config.mode == "eval" and config.loader.eval_batch_size > 1:
+ config.loader.eval_batch_size = max(config.loader.eval_batch_size, 16)
+
+ # todo revert?
+ if getattr(config.eval, 'txt_img_ratio', None) is not None:
+ # 2,1,0.5,0.25
+ tot = config.model.length
+ # if its 2:1, then distribute the tokens as 2/3, 1/3
+ # if its 1:1, then distribute the tokens as 1/2, 1/2
+ # if its 0.5:1, then distribute the tokens as 2/3, 1/3
+ # if its 0.25:1, then distribute the tokens as 1/4, 3/4
+ if config.eval.txt_img_ratio == 2:
+ # do first 2/3 tokens as text, last 1/3 as image
+ config.model.txt_length = int(tot * 2/3)
+ elif config.eval.txt_img_ratio == 1:
+ config.model.txt_length = int(tot / 2)
+ elif config.eval.txt_img_ratio == 0.5:
+ config.model.txt_length = int(tot * 2/3)
+ elif config.eval.txt_img_ratio == 0.25:
+ config.model.txt_length = int(tot / 4)
+ config.model.img_length = tot - config.model.txt_length
+ config.model.length = config.model.txt_length + config.model.img_length
+ # config.eval.attention_caching_txt_to_img_ratio = config.model.txt_length // 20
+
+ if getattr(config.eval, "varying_seq_len_ratio", False):
+ assert getattr(config.eval, "sampling_step_ratio", None) is not None, "Must set both varying_seq_len_ratio and sampling_step_ratio"
+ config.sampling.steps = int(config.model.length * config.eval.sampling_step_ratio)
+
+ if getattr(config.eval, "ablation_config", False):
+ if config.parameterization == "ar":
+ rprint(f"WARNING!!!!! FORCING AR PARAMS")
+ config.trainer.ar_shift = True
+ config.model.full_attention = False
+
+ config.data.keep_tensordict_on_disk = True
+ if is_torch_cuda_available():
+ if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["v100", "1080", "2080", "quadro", "titan"]) or torch.cuda.get_device_capability()[0] <= 7:
+ rprint(f"Using 2080Ti/V100, setting precision to fp32")
+ config.trainer.precision = "no"
+ config.model.force_optimized_native_attn = False
+ config.trainer.compile = False
+ if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["2080", "quadro"]):
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 7
+ config.loader.batch_size = config.loader.batch_size // 7
+ elif any(x.lower() in torch.cuda.get_device_name().lower() for x in ["1080", "titan"]):
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 6
+ config.loader.batch_size = config.loader.batch_size // 6
+ else:
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 2
+ config.loader.batch_size = config.loader.batch_size // 2
+ elif "a5000" in torch.cuda.get_device_name().lower() or "a4500" in torch.cuda.get_device_name().lower():
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 2
+ config.loader.batch_size = config.loader.batch_size // 2
+ else:
+ rprint(f"Found {torch.cuda.get_device_name()}")
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 2
+ config.loader.batch_size = config.loader.batch_size // 2
+
+ if getattr(config, "parametierzation", None) == "ar" and config.eval.cfg is not None:
+ config.loader.eval_batch_size = config.loader.eval_batch_size // 2
+ config.loader.batch_size = config.loader.batch_size // 2
+
+ config.loader.eval_batch_size = max(config.loader.eval_batch_size, 1)
+ config.loader.batch_size = max(config.loader.batch_size, 1)
+
+ if getattr(config, "parametierzation", None) == "ar":
+ config.trainer.compile = False
+
+ if getattr(config.sampling, "sampling_step_frac", None) is not None:
+ config.sampling.steps = int(config.model.length * config.sampling.sampling_step_frac)
+ rprint(f"Setting sampling steps to {config.sampling.steps}")
+
+ if os.environ.get("SUBMITIT_FOLDER") is not None or os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
+ rprint(f'Using submitit folder: {os.environ.get("SUBMITIT_FOLDER", "")}, setting slurm=True')
+ config.slurm = True
+
+ if (config.debug is False or os.environ.get("HYDRA_RUN_DIR_NAME", None) is not None) and torch.distributed.is_torchelastic_launched():
+ config.trainer.restart_on_failure = True
+ rprint(f"Setting restart_on_failure to True")
+
+ if config.trainer.restart_on_failure and config.mode == 'train':
+ if os.environ.get("HYDRA_RUN_DIR", None) is None and os.environ.get("HYDRA_RUN_DIR_NAME", None) is None:
+ os.environ["HYDRA_RUN_DIR"] = config.output_dir
+ rprint(f"Setting HYDRA_RUN_DIR to {os.environ['HYDRA_RUN_DIR']}")
+ else:
+ rprint(f"Not setting HYDRA_RUN_DIR, already set to {os.environ.get('HYDRA_RUN_DIR', 'N/A')}, and HYDRA_RUN_DIR_NAME is set to {os.environ.get('HYDRA_RUN_DIR_NAME', 'N/A')}")
+
+ os.environ["RESTART_FAULT_TOLERANT"] = "1"
+ rprint(f"Setting RESTART_FAULT_TOLERANT to 1")
+ elif config.trainer.restart_on_failure:
+ rprint(f"Restart_on_failure is True, but mode is not 'train', so not setting restart fault tolerant")
+
+ relevant_vars = {}
+ for key, value in os.environ.items():
+ if "SLURM" in key or "NCCL" in key or "TORCH" in key:
+ relevant_vars[key] = value
+
+ config.env_vars = relevant_vars
+
+ if config.trainer.profile_memory:
+ config.trainer.max_steps = 2
+
+ if config.debug and config.trainer.force_enable_checkpointing is False and (config.trainer.ckpt_steps is None or config.trainer.ckpt_steps > 0):
+ config.trainer.ckpt_steps = 10000
+ rprint(f"Only checkpointing every {config.trainer.ckpt_steps} steps in debug mode")
+
+ if config.loader.global_batch_size is None:
+ config.loader.global_batch_size = config.loader.batch_size * config.trainer.accumulate_grad_batches * (1 if is_torch_xla_available() else get_world_size())
+ config.loader.eval_global_batch_size = config.loader.global_batch_size
+ if config.trainer.scale_lr_by_batch_size:
+ config.optim.lr = config.optim.lr * (config.loader.global_batch_size / 512)
+ rprint(f"Setting global batch size to {config.loader.global_batch_size}, lr to {config.optim.lr}")
+
+ if config.mode != 'train':
+ config.checkpointing.resume_wandb = False
+ config.wandb.resume = None
+
+ if config.trainer.use_spmd_distributed_checkpointing is None:
+ config.trainer.use_spmd_distributed_checkpointing = is_torch_xla_available() and config.trainer.xla_spmd
+
+ if config.trainer.disable_all_eval_generation:
+ config.eval.num_masking_viz_batches=0
+ config.eval.num_uncond_sample_batches=0
+ config.eval.num_sample_batches=0
+ config.eval.num_random_masking=0
+ config.eval.generate_samples=False
+ config.trainer.log_flops=False
+ config.eval.log_every_n_evals=-1
+ config.eval.log_every_n_fid = -1
+ config.model.image_model_fid_eval = False
+ rprint("Disabling all eval generation!!!")
+
+ if os.environ.get("XLA_IR_DEBUG", "0") == "1":
+ config.trainer.tpu_profile = True
+
+ if config.checkpointing_root_dir is not None:
+ assert "checkpoints" in config.checkpointing.save_dir
+ relative_path = Path(*Path(config.checkpointing.save_dir).relative_to(config.root_output_dir).parts[1:])
+ full_checkpointing_dir = Path(config.checkpointing_root_dir) / relative_path
+ if config.checkpointing_root_dir is not None:
+ old_save_dir = Path(config.output_dir) / "checkpoints"
+ full_checkpointing_dir.mkdir(parents=True, exist_ok=True)
+ try:
+ if old_save_dir.exists():
+ rprint(f"WARNING: Cannot create symlink from {old_save_dir} to {full_checkpointing_dir} because {old_save_dir} exists.")
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ old_save_dir = Path(*old_save_dir.parts[:-1]) / f"checkpoints_{timestamp}"
+
+ old_save_dir.symlink_to(full_checkpointing_dir, target_is_directory=True)
+ rprint(f"Created softlink from {old_save_dir} to {full_checkpointing_dir}")
+
+ # Create a symlink from the parent of full_checkpointing_dir named "original" back to config.output_dir
+ original_link = full_checkpointing_dir.parent / "original_output_dir"
+ if not original_link.exists():
+ original_link.symlink_to(Path(config.output_dir).resolve(), target_is_directory=True)
+ rprint(f"Created softlink from {original_link} to {config.output_dir}")
+ else:
+ rprint(f"WARNING: Symlink {original_link} already exists. Skipping creation.")
+
+ except OSError as e:
+ rprint(f"Error creating softlinks: {e}")
+
+ assert getattr(config.data, "allow_label", False) == getattr(config.trainer, "add_label", False) == (getattr(config.model, "add_labels", None) is not None) == getattr(config.eval, "class_conditional_fid", False), f"Mismatching values: data.allow_label={config.data.allow_label}, trainer.add_label={config.trainer.add_label}, model.add_labels={config.model.add_labels}, eval.class_conditional_fid={config.eval.class_conditional_fid}"
+
+ if getattr(config.loader, "num_eval_workers", None) is not None and config.loader.num_workers == 0:
+ rprint(f"Setting num_eval_workers to 0 because num_workers is 0")
+ config.loader.num_eval_workers = 0
+
+ if config.trainer.disable_all_checkpointing:
+ gprint("-"*50)
+ gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
+ gprint("-"*50)
+ gprint(f"WARNING: DISABLING ALL CHECKPOINTING!!!!")
+ gprint("-"*50)
+ config.trainer.ckpt_steps = 100000000
+
+ if config.sampling.steps != config.sampling.max_sampling_steps:
+ rprint(f"WARNING!!!! steps {config.sampling.steps} != max_sampling_steps {config.sampling.max_sampling_steps}")
+ config.sampling.max_sampling_steps = config.sampling.steps
+
+def get_latest_ckpt(config, input_dir):
+ if input_dir is None or not Path(input_dir).exists():
+ rprint(f"Project dir {input_dir} does not exist")
+ return None
+
+ if config.trainer.xla_spmd and is_torch_xla_available():
+ rprint(f"XLA SPMD detected, using XLA checkpointing")
+ if any(Path(input_dir).iterdir()):
+ rprint(f"Found existing files/folders in {input_dir}")
+ return input_dir
+ else:
+ rprint(f"No folders found in {input_dir}")
+ return None
+
+ folders = [str(folder) for folder in Path(input_dir).iterdir() if folder.is_dir() and ((folder / "model.safetensors").exists() or (folder / "config.yaml").exists())]
+
+ if len(folders) == 0:
+ rprint(f"No folders found in {input_dir}")
+ return None
+
+ def _inner(folder):
+ return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
+
+ folders.sort(key=_inner)
+ rprint(f"Found folders: {folders}")
+ input_dir = folders[-1]
+ return input_dir
+
+def is_sweep():
+ try:
+ subdir = HydraConfig.get().sweep.subdir
+ rprint(f"Found sweep subdir: {subdir}")
+ return True
+ except omegaconf.errors.InterpolationToMissingValueError:
+ return False
+
+def get_sweep_run_name(config):
+ try:
+ subdir = HydraConfig.get().sweep.subdir
+ sweep_str = f"{subdir}_"
+ is_sweep = True
+ except omegaconf.errors.InterpolationToMissingValueError:
+ is_sweep = False
+ sweep_str = f"{os.environ.get('SLURM_JOB_ID', '')}_"
+
+ if getattr(config, "training", None) is not None and getattr(getattr(config, "training", None), "force_keys", None) is not None:
+ rprint("Using legacy keys")
+ forced_keys = set(config.training.force_keys)
+ else:
+ forced_keys = set(getattr(config.trainer, "forced_keys", []))
+
+ if is_sweep:
+ print(
+ f"Getting sweep keys: {HydraConfig.get().job.sweep_keys}, Tasks: {HydraConfig.get().overrides.task}, {getattr(config.trainer, 'forced_keys', [])}"
+ )
+ valid_keys = set(HydraConfig.get().job.sweep_keys)
+ for task in HydraConfig.get().overrides.task:
+ if task.removeprefix("+").split("=")[0] in valid_keys or task.removeprefix("+").split("=")[0] in forced_keys:
+ sweep_str += f"{task.removeprefix('+').split('=')[0].split('.')[-1]}={task.removeprefix('+').split('=')[1]}__"
+ if task.removeprefix("+").split("=")[0] in forced_keys:
+ forced_keys.remove(task.removeprefix("+").split("=")[0])
+ print(f"Forced key: {task.removeprefix('+').split('=')[0]}={task.removeprefix('+').split('=')[1]}")
+
+ for key in sorted(list(forced_keys)):
+ sweep_str += f"{key.split('.')[-1]}={OmegaConf.select(config, key)}__"
+
+ rprint(f"Sweep: {is_sweep=}, {sweep_str=}")
+ return "" if sweep_str == "" else sweep_str[:-2]
+
+def save_config_to_ckpt(config, output_dir, model):
+ with try_except(write_error_to_file=True, clear_cuda_cache=True):
+ with read_write(config):
+ with open_dict(config):
+ config.state.ckpt_step = model.global_step
+ config.state.num_evals = model.num_evals
+
+ OmegaConf.save(config=config, f=Path(output_dir) / "config.yaml")
+ rprint(f"Saved global step {model.global_step}")
+
+def determine_ckpt(config):
+ has_recent_ckpt = False
+ rprint(f"Looking at checkpoint path: {getattr(config.checkpointing, 'resume_ckpt_path', None)}")
+ if (
+ config.checkpointing.resume_from_ckpt
+ and (latest_ckpt := get_latest_ckpt(config, getattr(config.checkpointing, "resume_ckpt_path", None))) is not None
+ and (Path(latest_ckpt) / "config.yaml").exists()
+ ):
+ ckpt_path = latest_ckpt
+ has_recent_ckpt = True
+ if config.slurm:
+ config.wandb.resume = "must"
+ rprint(f"Resuming from checkpoint {ckpt_path}")
+ elif config.checkpointing.resume_from_ckpt and getattr(config.checkpointing, "initial_resume_ckpt_path", None) is not None:
+ ckpt_path = config.checkpointing.initial_resume_ckpt_path
+ rprint(f"Resuming from initial checkpoint {ckpt_path}")
+ else:
+ ckpt_path = None
+
+ if ckpt_path is not None and (config.checkpointing.resume_wandb or has_recent_ckpt):
+ loaded = OmegaConf.load(Path(ckpt_path) / "config.yaml")
+ if loaded.wandb.id is not None:
+ config.wandb.id = str(loaded.wandb.id)
+ rprint(f"Found wandb id: {config.wandb.id}")
+ else:
+ rprint(f"No wandb id found in checkpoint {ckpt_path}")
+
+ if config.checkpointing.resume_wandb and config.wandb.id is not None:
+ config.wandb.resume = "must"
+ rprint(f"Resuming wandb, setting must, run id: {config.wandb.id}")
+ elif config.slurm and config.wandb.id is None:
+ if os.environ.get("SLURM_ARRAY_TASK_COUNT", "") != "" and int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "")) > 1:
+ config.wandb.id = str(os.environ.get("SLURM_ARRAY_JOB_ID")) + f"_{os.environ.get('SLURM_ARRAY_TASK_ID')}"
+ else:
+ config.wandb.id = str(os.environ.get("SLURM_JOB_ID"))
+ rprint(f"Setting wandb id to {config.wandb.id}")
+
+ if config.checkpointing.initial_resume_ckpt_path is not None and config.checkpointing.resume_wandb:
+ assert config.wandb.id is not None
+
+ if config.ckpt is not None:
+ ckpt_path = config.ckpt
+ rprint(f"Running eval with checkpoint {ckpt_path}")
+
+ if config.wandb.id is not None:
+ config.wandb.id = str(config.wandb.id)
+
+ if config.wandb.id is None or getattr(config.trainer, "force_new_wandb_id", False):
+ config.wandb.id = wandb.util.generate_id()
+ config.wandb.resume = "allow"
+ rprint(f"Set wandb id: {config.wandb.id}")
+
+ rprint(f"Using wandb id: {config.wandb.id}")
+ subdir = get_sweep_run_name(config)
+ rprint(f"Wandb name: {config.wandb.name}, Wandb subdir: {subdir}")
+
+ if config.wandb.name == 'default':
+ config.wandb.name = None
+ else:
+ config.wandb.name = (
+ (f"{config.wandb.name}_" if config.wandb.name else "")
+ + (f"{subdir}_" if (subdir is not None and subdir != "") else "")
+ + f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
+ )
+
+ if getattr(config.wandb, "group", None) is None and subdir is not None and config.debug and os.environ.get("SLURM_ARRAY_JOB_ID", "") != "":
+ config.wandb.group = os.environ.get("SLURM_ARRAY_JOB_ID")
+ rprint(f"Wandb group: {config.wandb.group}")
+
+ return ckpt_path
+
+def run(config, tokenizer):
+ import torch
+ from accelerate import (Accelerator, DataLoaderConfiguration,
+ DDPCommunicationHookType,
+ DistributedDataParallelKwargs,
+ FullyShardedDataParallelPlugin)
+ from accelerate.state import AcceleratorState
+ from accelerate.utils import GradientAccumulationPlugin, ProjectConfiguration
+
+ set_torch_defaults(config.trainer.benchmark)
+
+ set_env_vars(config)
+ update_config_before_resolution(config)
+ ckpt_path = determine_ckpt(config)
+ OmegaConf.resolve(config)
+ if is_torch_cuda_available():
+ check_gpu_memory_usage()
+
+ if is_torch_cuda_available():
+ rprint(f"pt={torch.__version__}, cuda={torch.version.cuda}, nccl={torch.cuda.nccl.version()}")
+ rprint(f"GPU={torch.cuda.get_device_name()}, device compute capabilities={torch.cuda.get_device_capability()}, pytorch compute capabilities={torch.cuda.get_arch_list()}")
+ elif is_torch_xla_available():
+ rprint(f"XLA Devices={get_tpu_devices()}")
+
+ rprint(
+ f"Initial GROUP_RANK: {os.environ.get('GROUP_RANK', 'N/A')}, RANK: {os.environ.get('RANK', 'N/A')}, LOCAL_RANK: {os.environ.get('LOCAL_RANK', 'N/A')}, WORLD_SIZE: {os.environ.get('WORLD_SIZE', 'N/A')}, MASTER_ADDR: {os.environ.get('MASTER_ADDR', 'N/A')}, MASTER_PORT: {os.environ.get('MASTER_PORT', 'N/A')}, TORCHELASTIC_RUN_ID: {os.environ.get('TORCHELASTIC_RUN_ID', 'N/A')}, TORCHELASTIC_RESTART_COUNT: {os.environ.get('TORCHELASTIC_RESTART_COUNT', 'N/A')}, TORCHELASTIC_MAX_RESTARTS: {os.environ.get('TORCHELASTIC_MAX_RESTARTS', 'N/A')}, LOCAL_WORLD_SIZE: {os.environ.get('LOCAL_WORLD_SIZE', 'N/A')}, Elastic: {torch.distributed.is_torchelastic_launched()}"
+ )
+ rprint(f"Computed Rank: {get_rank()}, Local Rank: {get_local_rank()}, World Size: {get_world_size()}")
+
+ # This lets us have start_timing and end_timing functions and a global enable/disable
+ # We always use torch.cuda.synchronize before/after as otherwise the timing is not very meaningful
+ sync_timing = (config.trainer.nvtx_profile and getattr(config.trainer, "sync_nvtx_timing", True)) or getattr(config.trainer, "sync_timing", False)
+ set_timing_builtins(enable=config.trainer.nvtx_profile, sync=sync_timing)
+
+ num_nodes = config.trainer.num_nodes
+ with open_dict(config):
+ config.trainer = OmegaConf.merge(config.trainer, dict(mixed_precision=config.trainer.precision, log_with="wandb", log_gradients=None))
+ if getattr(config.trainer, "process_dataloader_only", False):
+ gprint("Processing dataloader only")
+ train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device="cpu", skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
+ gprint(f"Exiting after processing dataloader")
+ return
+
+ accelerator_project_config = ProjectConfiguration(
+ project_dir=config.output_dir,
+ logging_dir=config.logging_dir,
+ automatic_checkpoint_naming=config.checkpointing.use_automatic_naming,
+ save_on_each_node=False,
+ )
+
+ accelerate_kwargs = dict()
+ gradient_kwargs = dict()
+ if config.trainer.fsdp and not (config.trainer.xla_spmd and is_torch_xla_available()):
+ rprint("Using FSDP...")
+ if config.backbone == "llama" or config.backbone == "gemma":
+ os.environ["ACCELERATE_USE_FSDP"] = "true"
+ os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
+ os.environ["FSDP_BACKWARD_PREFETCH"] = "NO_PREFETCH" # Saved memory
+ os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
+ os.environ["FSDP_FORWARD_PREFETCH"] = "false"
+ os.environ["FSDP_OFFLOAD_PARAMS"] = "false"
+ os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD"
+ os.environ["FSDP_STATE_DICT_TYPE"] = "SHARDED_STATE_DICT"
+ os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
+ os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
+ fsdp_plugin = FullyShardedDataParallelPlugin()
+ else:
+ os.environ["ACCELERATE_USE_FSDP"] = "true"
+ os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" # or "SIZE_BASED_WRAP"
+ if config.backbone == "elm":
+ os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "OpenELMDecoderLayer"
+ os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
+ os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
+ else:
+ # Fastest but requires more memory: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch
+ os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
+ # See: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy
+ os.environ["FSDP_SHARDING_STRATEGY"] = "HYBRID_SHARD_ZERO2"
+ os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "DDiTBlock"
+
+ # SHARDED_STATE_DICT is a bit faster, but more complicated as later on we need to merge the shards.
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
+ FullOptimStateDictConfig, FullStateDictConfig)
+ fsdp_plugin = FullyShardedDataParallelPlugin(
+ state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+ optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), # SHARDED_STATE_DICT
+ )
+
+ if config.trainer.compile or config.trainer.use_orig_params is True:
+ # https://github.com/huggingface/transformers/pull/24591/files
+ fsdp_plugin.use_orig_params = True
+ rprint("Using orig params for FSDP. This is required for torch.compile to work.")
+
+ accelerate_kwargs["fsdp_plugin"] = fsdp_plugin
+ gradient_kwargs["sync_each_batch"] = False
+
+ if getattr(config.trainer, "fsdp_sync_each_batch", False): # Reduce memory usage: https://huggingface.co/docs/accelerate/en/concept_guides/gradient_synchronization#nosync-requires-additional-gpu-memory-when-using-fsdp
+ rprint("Using sync each batch for Chameleon")
+ gradient_kwargs["sync_each_batch"] = True
+
+ elif config.trainer.xla_spmd is False: # For XLA FSDP, we init where we normally prepare()
+ rprint("Using DDP...")
+ ddp_kwargs = DistributedDataParallelKwargs(
+ find_unused_parameters=config.trainer.find_unused_parameters,
+ comm_hook=DDPCommunicationHookType.BF16,
+ static_graph=config.trainer.accumulate_grad_batches == 1,
+ gradient_as_bucket_view=True,
+ )
+ # bucket_cap_mb=32,
+
+ # Not needed right now
+ from datetime import timedelta
+
+ from accelerate.utils import InitProcessGroupKwargs
+ init_process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
+ accelerate_kwargs["kwargs_handlers"] = [ddp_kwargs, init_process_group_kwargs]
+ else:
+ rprint(f"Did not choose DDP or FSDP.")
+
+ if config.trainer.accumulate_grad_batches <= 0:
+ gprint("WARNING!!!!!! Accumulate grad batches is <= 0, setting to 1")
+ config.trainer.accumulate_grad_batches = 1
+
+ gradient_accumulation_plugin = GradientAccumulationPlugin(
+ num_steps=config.trainer.accumulate_grad_batches,
+ adjust_scheduler=False, # We manually adjust our LR for accumulate_grad_batches
+ sync_with_dataloader=False,
+ **gradient_kwargs
+ )
+
+ if config.trainer.mixed_precision == "bf16" and (is_torch_cuda_available() and not torch.cuda.is_bf16_supported()):
+ rprint(f"No BF16 GPU found, falling back to FP16")
+ config.trainer.mixed_precision = "fp16"
+
+ if config.trainer.mixed_precision == "fp32":
+ config.trainer.mixed_precision = "no"
+ else:
+ if is_torch_xla_available():
+ os.environ["ACCELERATE_DOWNCAST_BF16"] = "true"
+
+ rprint(f"Mixed precision: {config.trainer.mixed_precision}")
+
+ if config.seed is None or getattr(config.eval, 'set_random_gen_seed', False):
+ # do not ask why, has to do something with seeds being reset by val_epoch_end so if you don't execute this code, your generations in val_epoch_end will be same across gpus
+ accelerate_kwargs["rng_types"] = []
+ rprint("No seed provided, disabling accelerate RNG synchronization")
+
+ accelerator = Accelerator(
+ mixed_precision=config.trainer.mixed_precision,
+ log_with=config.trainer.log_with,
+ project_config=accelerator_project_config,
+ gradient_accumulation_plugin=gradient_accumulation_plugin,
+ dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False, non_blocking=False),
+ **accelerate_kwargs,
+ )
+
+ gprint(f"Distributed Type: {accelerator.distributed_type}, Accelerator state: {AcceleratorState()}")
+ num_processes = AcceleratorState().num_processes
+ if getattr(config.trainer, "global_num_warmup_steps", None) is not None:
+ rprint(f"Global num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}. Applying to num_warmup_steps")
+ config.lr_scheduler.num_warmup_steps = config.trainer.global_num_warmup_steps
+
+ if getattr(config.trainer, "global_num_training_steps", None) is not None:
+ rprint(f"Global num_training_steps was: {config.lr_scheduler.num_training_steps}. Applying to num_training_steps")
+ config.lr_scheduler.num_training_steps = config.trainer.global_num_training_steps
+
+ if not config.trainer.disable_adjust_num_warmup_steps:
+ rprint(f"Original num_warmup_steps was: {config.lr_scheduler.num_warmup_steps}")
+ config.lr_scheduler.num_warmup_steps = config.lr_scheduler.num_warmup_steps * num_processes
+ rprint(f"Setting num_warmup_steps to: {config.lr_scheduler.num_warmup_steps}")
+
+ if hasattr(config.lr_scheduler, "num_training_steps"):
+ rprint(f"Original num_training_steps was: {config.lr_scheduler.num_training_steps}")
+ config.lr_scheduler.num_training_steps = config.lr_scheduler.num_training_steps * num_processes
+ rprint(f"Setting num_training_steps to: {config.lr_scheduler.num_training_steps}")
+
+ assert config.trainer.allow_dynamic_nodes or (os.environ.get("XLA_USE_SPMD", "0") == "1") or accelerator.num_processes == (
+ config.trainer.devices * num_nodes
+ ), f"Expected {config.trainer.devices * num_nodes} GPUs but got {accelerator.num_processes} processes."
+
+ compute_dtyle = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ compute_dtyle = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ compute_dtyle = torch.bfloat16
+
+ if compute_dtyle != torch.bfloat16:
+ rprint(f"WARNING!!!! Compute dtype is: {compute_dtyle}")
+ else:
+ rprint(f"Compute dtype is: {compute_dtyle}")
+
+ if is_main_process():
+ instantiate_wandb(config, accelerator)
+
+ run_cmd = get_run_cmd(config)
+ with open_dict(config):
+ config.trainer.devices = accelerator.num_processes
+ config.trainer.dtype = str(compute_dtyle)
+ if hasattr(config, "state"):
+ config.state.cmd = run_cmd
+ else:
+ config.state = OmegaConf.create(dict(cmd=run_cmd))
+
+ OmegaConf.set_readonly(config, True)
+
+ if getattr(config.trainer, "attach_oom_observer", False):
+ from torchtnt.utils.oom import attach_oom_observer
+ attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)
+ rprint(f"Attached OOM observer to {os.getcwd()}")
+ train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer, device=accelerator.device, skip_train=(config.mode == 'eval' and not config.eval.val_with_train_data))
+ model = Diffusion(config=config, tokenizer=valid_ds.tokenizer, device=accelerator.device)
+
+ if is_main_process():
+ print_params(model.backbone)
+
+ try:
+ if getattr(config.model, "image_model", False) is False:
+ _print_batch(train_ds, valid_ds, tokenizer)
+ except:
+ pass
+
+ get_ema_path = lambda x: Path(x) / "ema.ckpt"
+ SAMPLER_NAME = "weighted_dataset_sampler"
+
+ def save_model_hook(models, weights, output_dir):
+ nonlocal model, accelerator, train_ds
+
+ if is_main_process():
+ with try_except(write_error_to_file=True):
+ if getattr(model, "ema", None) is not None:
+ torch.save(accelerator.unwrap_model(model).ema.state_dict(), get_ema_path(output_dir))
+ rprint(f"Saved EMA to {get_ema_path(output_dir)}")
+
+ save_config_to_ckpt(config, output_dir, model)
+
+ with try_except(write_error_to_file=True):
+ if config.data.use_weighted_tensordict_sampler:
+ from accelerate.utils import save
+ output_sampler_file = output_dir.joinpath(f"{SAMPLER_NAME}_train.bin")
+ save(train_ds.sampler.state_dict(), output_sampler_file, save_on_each_node=False, safe_serialization=False)
+ rprint(f"Sampler state for dataloader saved in {output_sampler_file}")
+
+ initial_global_step = None
+ def load_model_hook(models, input_dir):
+ nonlocal initial_global_step, model, train_ds
+ config_path = os.path.join(input_dir, "config.yaml")
+ ckpt_config = OmegaConf.load(config_path)
+ initial_global_step = ckpt_config.state.ckpt_step
+ model.global_step = initial_global_step
+ try:
+ if hasattr(config.state, "num_evals"):
+ model.num_evals = config.state.num_evals
+ except Exception as e:
+ rprint(f"Error loading model: {e}")
+ rprint(f"Loaded global step {initial_global_step}")
+
+ state_dict = None
+ if getattr(config.checkpointing, "load_from_old_attention_format", False):
+ state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+ state_dict = convert_state_dict_keys(state_dict)
+
+ if getattr(model, "ema", None) is not None:
+ if get_ema_path(input_dir).exists():
+ rprint(f"Loading EMA from {get_ema_path(input_dir)}")
+ model.ema.load_state_dict(torch.load(get_ema_path(input_dir), map_location='cpu'))
+ else:
+ rprint(f"No EMA found, initializing EMA with state_dict")
+ if state_dict is None:
+ state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+
+ # We likely don't need the unwrap, but just to be safe
+ accelerator.unwrap_model(models[0]).load_state_dict(state_dict)
+ from models.ema import EMAModel
+ model.ema = EMAModel(accelerator.unwrap_model(models[0]).parameters(), decay=config.trainer.ema)
+
+ if config.data.use_weighted_tensordict_sampler and not is_torch_xla_available(): # and not config.eval.test_eval_speed:
+ input_sampler_file = Path(input_dir).joinpath(f"{SAMPLER_NAME}_train.bin")
+ if train_ds is not None and input_sampler_file.exists():
+ train_ds.sampler.load_state_dict(torch.load(input_sampler_file))
+ rprint("All dataloader sampler states loaded successfully")
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+ model.init_dataloader(train_ds, valid_ds)
+ model.set_accelerator(accelerator, ckpt_path)
+ model.set_callbacks()
+
+ if getattr(config.checkpointing, "load_from_text_model", None) is not None:
+ rprint(f"Loading from text model")
+ model.custom_load_checkpoint()
+
+ if getattr(config.checkpointing, "load_from_lightning_ckpt", None) is not None:
+ ckpt = torch.load(config.checkpointing.load_from_lightning_ckpt)
+ initial_global_step = ckpt["global_step"]
+ state_dict_ = {k.removeprefix("backbone."): v for k, v in ckpt["state_dict"].items() if "backbone" in k}
+ state_dict_ = {k.replace(".attn_", ".attention.attn_"): v for k, v in state_dict_.items()}
+ accelerator.unwrap_model(model.backbone).load_state_dict(state_dict_)
+
+ if config.trainer.ema > 0:
+ model.ema.load_state_dict(ckpt["ema"])
+
+ rprint(f"Loaded lightning ckpt: {config.checkpointing.load_from_lightning_ckpt}")
+
+ if initial_global_step is not None:
+ # The load_hooks are before accelerate does it's loading and it overwrites model.global_step if we set it there
+ model.global_step = initial_global_step
+ rprint(f"Set global step to {initial_global_step}")
+
+ contexts = []
+ if config.trainer.nvtx_profile:
+ contexts.append(torch.autograd.profiler.emit_nvtx(record_shapes=True))
+
+ if config.trainer.profile_memory:
+ contexts.append(profile_memory())
+
+ using_torch_elastic = torch.distributed.is_torchelastic_launched()
+ if using_torch_elastic:
+ rprint(f"Torchelastic launched: {torch.distributed.is_torchelastic_launched()}")
+ contexts.append(ErrorHandler())
+
+ with ExitStack() as stack:
+ for ctx in contexts:
+ stack.enter_context(ctx)
+
+ rprint(f"output_dir: {config.output_dir}")
+ model.to(accelerator.device)
+ if config.mode == 'train':
+ model.train()
+ elif config.mode == 'eval':
+ if config.eval.standalone_fid:
+ model.validate(None)
+ else:
+ model.validate(None)
+ elif config.mode == 'zero-shot-eval':
+ model.zero_shot_eval()
+ else:
+ raise ValueError(f"Invalid mode: {config.mode}")
+
+ accelerator.end_training()
+
+
+def get_run_cmd(config):
+ orig_argv = deepcopy(sys.argv)
+
+ prepend_argv = []
+ if "HYDRA_RUN_DIR" in os.environ:
+ prepend_argv.append(f"HYDRA_RUN_DIR='{os.environ['HYDRA_RUN_DIR']}'")
+ else:
+ prepend_argv.append(f"HYDRA_RUN_DIR='{str(Path(config.output_dir).resolve())}'")
+
+ if orig_argv[1].startswith("experiments=["):
+ orig_argv[1] = orig_argv[1].removeprefix("experiments=[").removesuffix("]")
+ orig_argv[1] = f"experiments=\'[{orig_argv[1]}]\'"
+
+ if os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1":
+ sbatch_script_path = 'scripts/slurm.sh'
+ orig_argv.pop(0)
+ orig_argv = ["sbatch", f"--nodes={os.environ.get('SLURM_NNODES', '1')}", f"--gpus-per-node={os.environ.get('SLURM_GPUS_PER_NODE', '1')}", f"--partition={os.environ.get('SLURM_JOB_PARTITION', 'all')}", sbatch_script_path] + orig_argv
+ else:
+ prepend_argv.append("accelerate launch")
+
+ full_run_cmd = " ".join(prepend_argv + orig_argv)
+ rprint(f"Full run cmd: {full_run_cmd}")
+ return full_run_cmd
+
+@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="config")
+@try_except()
+def main(config):
+ if is_sweep():
+ print(f"Checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}")
+ from unidisc.utils.slurm_requeue import check_requeue
+ check_requeue()
+ print(f"Done checking if we need to requeue for job id {os.environ['SLURM_JOB_ID']}.")
+
+ """Main entry point for trainer."""
+ import torch # Causes issue pickling if imported by default.
+ if is_torch_xla_available():
+ builtins.HAS_XLA_SPAWNED = True
+ os.environ['PJRT_DEVICE'] = 'TPU'
+
+ if config.trainer.precision == "bf16":
+ os.environ['XLA_USE_BF16'] = '1'
+
+ if config.devices == 1 and config.trainer.xla_spmd is False and config.trainer.fsdp is False:
+ os.environ['TPU_PROCESS_BOUNDS'] = '1,1,1'
+ os.environ['TPU_VISIBLE_CHIPS'] = '0'
+ gprint(f"Setting TPU_PROCESS_BOUNDS: {os.environ['TPU_PROCESS_BOUNDS']}")
+ gprint(f"Setting TPU_VISIBLE_CHIPS: {os.environ['TPU_VISIBLE_CHIPS']}")
+
+ if config.trainer.tpu_eager:
+ os.environ['XLA_USE_EAGER_DEBUG_MODE'] = '1'
+
+ if config.trainer.tpu_compile_debug:
+ os.environ['PT_XLA_DEBUG'] = '1'
+ os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
+ os.environ['XLA_IR_DEBUG'] = '1'
+ os.environ['XLA_HLO_DEBUG'] = '1'
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+ os.environ['TF_CPP_VMODULE'] = 'xla_graph_executor=5,pjrt_computation_client=3'
+
+ # We intentionally set these after to avoid import side effects
+ spmd_mesh, axis_names, num_nodes = None, None, None
+ if config.trainer.xla_spmd:
+ import torch_xla.core.xla_model as xm
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.runtime as xr
+ from accelerate import PartialState
+ from torch_xla._internal import tpu
+ auto_spmd = getattr(config.trainer, "auto_spmd", False)
+
+ xr.use_spmd(auto=auto_spmd) # Auto causes a crash
+ force_global_devices = getattr(config.trainer, "force_global_devices", None)
+ force_local_devices = getattr(config.trainer, "force_local_devices", None)
+ assert (force_global_devices is None) == (force_local_devices is None), "Must set both or neither"
+
+ if force_global_devices is not None:
+ num_global_devices = force_global_devices
+ num_local_devices = force_local_devices
+ gprint(f"Using force global devices: num_global_devices={num_global_devices}, num_local_devices={num_local_devices}")
+ else:
+ num_global_devices = xr.global_runtime_device_count()
+ num_local_devices = tpu.num_available_devices()
+ assert num_global_devices == tpu.num_expected_global_devices()
+ assert tpu.num_available_devices() == tpu.num_available_chips() == tpu.num_local_processes()
+
+ num_nodes = num_global_devices // num_local_devices
+ spmd_mesh_shape = getattr(config.trainer, "spmd_mesh", None)
+ if spmd_mesh_shape is None:
+ spmd_mesh_shape = (num_nodes, num_local_devices, 1)
+
+ if getattr(config.trainer, "force_disable_replicas", False):
+ spmd_mesh_shape = (1, num_global_devices, 1)
+ rprint(f"Forcing disable replicas: {spmd_mesh_shape}")
+
+ if auto_spmd is False:
+ if getattr(config.trainer, "spmd_multislice", None) is not None:
+ from torch_xla.distributed.spmd import HybridMesh
+ ici_mesh_shape = spmd_mesh_shape
+ dcn_mesh_shape = (config.trainer.spmd_multislice, 1, 1)
+ spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data','fsdp','tensor'))
+ rprint(f"Using multislice: {config.trainer.spmd_multislice}: {ici_mesh_shape} {dcn_mesh_shape}, {spmd_mesh.shape()}")
+ else:
+ spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), spmd_mesh_shape, ('dcn', 'fsdp', 'model'))
+ xs.set_global_mesh(spmd_mesh)
+
+ config.devices = 1
+ config.nodes = 1
+
+ with read_write(config):
+ with open_dict(config):
+ config.state = OmegaConf.create(dict(spmd_mesh=spmd_mesh_shape))
+ config.state.axis_names = axis_names
+ config.state.num_nodes = num_nodes
+ config.state.num_global_devices = num_global_devices
+ config.state.num_local_devices = num_local_devices
+ config.state.worker_ips = tpu.get_worker_ips()
+ if os.environ.get("TPU_NAME") is not None:
+ config.state.tpu_name = os.environ.get("TPU_NAME")
+
+ if config.trainer.tpu_eager:
+ import torch_xla
+ torch_xla.experimental.eager_mode(True)
+
+ if config.trainer.tpu_profile:
+ if config.trainer.tpu_profile_markers:
+ os.environ['XLA_IR_DEBUG'] = '1'
+ os.environ['XLA_HLO_DEBUG'] = '1'
+ import torch_xla.debug.profiler as xp
+ server = xp.start_server(9012)
+
+ if config.trainer.tpu_cache:
+ import torch_xla.runtime as xr
+ readonly = not is_main_process()
+ rprint(f"Initializing TPU cache with readonly={readonly}")
+ xr.initialize_cache(str((Path.home() / '.cache' / 'unidisc' / f"tpu_{get_rank()}_{get_hostname().replace('-', '_')}").resolve()), readonly=readonly)
+
+ if config.trainer.enable_jax_smi:
+ initialise_tracking()
+ rprint("Initializing jax-smi")
+
+ from unidisc.utils.logging_utils import set_logger
+ set_logger(f"{__name__} {get_slurm_log_prefix()}", Path(f"{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.out"))
+
+ if is_torch_xla_available():
+ import torch_xla.runtime as xr
+ gprint(
+ f"Computed Rank: {get_rank()}, "
+ f"Is Main Process: {is_main_process()}, "
+ f"Is Local Main Process: {is_local_main_process()}, "
+ f"XLA world size: {xr.world_size()}, "
+ f"XLA Model Ordinal: {xm.get_ordinal()}, "
+ f"XLA Global Ordinal: {xr.global_ordinal()}, "
+ f"XLA Supported Devices: {xm.get_xla_supported_devices()}, "
+ f"Accelerate Local Process Index: {PartialState().local_process_index}, "
+ f"Task ID: {tpu.task_id()}, "
+ f"Worker ID: {tpu.worker_id()} "
+ f"global device count: {xr.global_runtime_device_count()}, "
+ f"local process count: {xr.local_process_count()}, "
+ f"local device count: {xr.local_device_count()}, "
+ f"addressable device count: {xr.addressable_device_count()}, "
+ f"num_expected_global_devices: {tpu.num_expected_global_devices()}, "
+ f"num_available_devices: {tpu.num_available_devices()}, "
+ f"num_available_chips: {tpu.num_available_chips()}, "
+ f"num_local_processes: {tpu.num_local_processes()}, "
+ f"process_bounds_size: {tpu.process_bounds_size()}, "
+ f"get_worker_ips: {tpu.get_worker_ips()}, "
+ f"Computed Num Nodes: {num_nodes}, "
+ f"Specified Mesh: {spmd_mesh_shape}, "
+ f"Specified Mesh Axes: {axis_names}"
+ )
+
+ gprint(f"LIBTPU_INIT_ARGS: {os.environ.get('LIBTPU_INIT_ARGS', 'None')}")
+ gprint(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'None')}")
+
+ if getattr(config.trainer, "disable_ddp_optimizer", False):
+ torch._dynamo.config.optimize_ddp = False
+
+ if config.seed is not None:
+ if config.mode == 'eval':
+ config.seed = config.seed + 1000 * int(get_rank())
+ else:
+ config.seed = config.seed + int(get_rank())
+ np.random.seed(config.seed)
+ random.seed(config.seed)
+ torch.manual_seed(config.seed)
+ if is_torch_cuda_available():
+ # TODO: Is seed all desired? Does it set the same one on all GPUs even in multi-process?
+ torch.cuda.manual_seed_all(config.seed)
+
+ if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ xm.set_rng_state(config.seed)
+ gprint(f"Set seed: {config.seed}")
+ else:
+ rprint("No seed provided")
+
+ _print_config(config, resolve=True, save_cfg=True)
+
+ with open(f"env_vars_{get_slurm_filename_info()}_{get_hostname().replace('-', '_')}.txt", "w") as f:
+ for key, value in os.environ.items():
+ f.write(f"{key}={value}\n")
+
+ tokenizer = dataloader.get_tokenizer(config)
+
+ if "tokens" in config.data.train and (config.loader.num_workers > 0 or getattr(config.data, "force_mp_spawn", False)):
+ from torch import multiprocessing as mp
+ try:
+ rprint(f"Start already method set to: {mp.get_start_method()}")
+ except:
+ mp.set_start_method("spawn")
+ rprint(f"Start method set to: {mp.get_start_method()}")
+
+ rprint(f"Mode: {config.mode}")
+ if config.mode == "sample_eval":
+ generate_samples(config, tokenizer)
+ else:
+ try:
+ run(config, tokenizer)
+ except Exception as e:
+ rprint(f"Traceback: {traceback.format_exc()}")
+ rprint(f"Exception: {e}")
+
+ timestamp = int(__import__("time").time_ns())
+ error_filepath = f"exception_{timestamp}_{process_file_prefix()}.out"
+ with open(error_filepath, "w") as file:
+ file.write(traceback.format_exc())
+ rprint(f"See error file {Path(error_filepath).resolve()} for traceback")
+
+ if is_torch_xla_available():
+ exit(1)
+
+ if ("SLURM_JOB_ID" not in os.environ) and ("RESTART_FAULT_TOLERANT" not in os.environ) and not is_torch_xla_available():
+ gprint(f"Entering debugger")
+ breakpoint(traceback=e.__traceback__)
+ else:
+ rprint(f"Not breaking, SLURM_JOB_ID: {os.environ.get('SLURM_JOB_ID')}, RESTART_FAULT_TOLERANT: {os.environ.get('RESTART_FAULT_TOLERANT')}")
+
+ if "RESTART_FAULT_TOLERANT" in os.environ:
+ sigterm_handler = signal.getsignal(signal.SIGTERM)
+ if callable(sigterm_handler):
+ rprint(f"Calling SIGTERM handler")
+ sigterm_handler(signal.SIGTERM, None)
+
+ try:
+ if config.trainer.num_nodes > 1 and config.debug is False and is_main_process():
+ wandb.alert(title="Exception!", text=f"{e}, {traceback.format_exc()}")
+ except:
+ pass
+ raise e
+ finally:
+ pass
+
+if __name__ == "__main__":
+ main()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f520022901d0970b3537df23a9269028841a43
--- /dev/null
+++ b/model.py
@@ -0,0 +1,1670 @@
+import math
+import random
+import types
+import time
+from collections import defaultdict
+from contextlib import nullcontext
+from functools import cached_property, partial
+from contextlib import ExitStack
+
+from numpy import mask_indices
+from unidisc.utils.tensor_utils import get_contiguous_blocks, get_contiguous_blocks_per_sample, get_interleaved_indices
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate.utils import gather, gather_object
+from einops import rearrange
+from tensordict import TensorDict
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+import model_eval
+import model_setup
+import model_utils
+import utils
+from decoupled_utils import (Profiler, barrier, dprint, get_rank, get_world_size, gprint,
+ is_local_main_process, is_main_process,
+ is_torch_cuda_available, is_torch_xla_available,
+ print_memory, rprint, save_memory_profile,
+ synchronize_device, try_except, use_dist)
+from unidisc.tokenizers.image_tokenizers import (decode_latents, get_image_batch,
+ get_vae, vae_encode_image)
+from unidisc.utils.cuda_utils import sync_times
+from unidisc.utils.xla_utils import shard_output
+from model_utils import (Loss, ddprint, ema_update, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log,
+ replace_nan_dict, update_histogram, update_logs, get_block_mask)
+from unidisc.utils.trainer_utils import TrainingState, incremental_dict_update, linear_warmup
+
+is_xla_available = is_torch_xla_available()
+
+if is_xla_available:
+ import torch_xla
+ from torch_xla.distributed.spmd import XLAShardedTensor
+
+
+
+
+def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
+ return t.global_tensor if isinstance(t, XLAShardedTensor) else t
+
+class Diffusion:
+ def __init__(self, config, tokenizer, device, disable_init=False):
+ super().__init__()
+ setup_methods = [
+ 'init', 'to', 'get_params', 'get_vae', 'get_cond_vae', 'configure_optimizers',
+ '_validate_configuration', 'register_signal_handler', 'on_train_start',
+ 'optimizer_step', 'init_dataloader', 'set_accelerator', 'set_callbacks',
+ 'on_train_step_end', 'init_optimizer_lr_scheduler', 'after_backward', 'checkpoint',
+ 'print_hashes', 'shortcut_return', 'reset_validation_metrics', 'unwrap_model'
+ ]
+ for method_name in setup_methods:
+ setattr(self, method_name, types.MethodType(getattr(model_setup, method_name), self))
+
+ utils_methods = [
+ 'get_coord_plot', '_score_entropy', 'sample_subs_guidance',
+ 'restore_model_and_semi_ar_sample', '_reconstruction_loss',
+ 'restore_model_and_sample', 'get_score', '_staggered_score',
+ '_analytic_update', '_denoiser_update', '_transp_transition',
+ 'eval_retokenize', 'compute_generative_perplexity', '_d3pm_loss',
+ '_d3pm_parameterization', '_sedd_parameterization',
+ 'get_base_shapes_for_mup', 'update_histogram', '_maybe_sub_sample',
+ 'viz_images_from_dataloader', 'compute_cider'
+ ]
+ for method_name in utils_methods:
+ setattr(self, method_name, types.MethodType(getattr(model_utils, method_name), self))
+
+ eval_methods = [
+ 'get_every_n_evals', 'on_validation_epoch_start', 'sample',
+ 'predict_step', 'validation_step', 'on_validation_epoch_end',
+ 'on_validation_epoch_cleanup', '_sample_prior', '_ddpm_forward',
+ '_ddpm_update', '_ddpm_caching_update', '_sample', '_ar_sampler',
+ 'decode_batch', 'sample_transfusion', 'sample_continuous_image',
+ 'decode_sampling', '_ddpm_update_finetune_controlled_tweedie',
+ 'sample_masking', 'log_flops', "visualize_samples", "_maskgit_update",
+ "_first_hitting_update", "update_inline_fid", "compute_inline_fid",
+ "update_clean_fid", "compute_clean_fid_eval", "sample_for_fid",
+ "compute_clip_score", "mauve_store_references", "zero_shot_eval_step",
+ "zero_shot_eval_epoch_end", "get_cfg_weight", "cleanup_fid_output",
+ "calculate_chameleon_perplexity", "get_anole_data",
+ "update_img_to_txt_mauve_clip", "compute_mauve_entropy",
+ "get_top_k", "compute_entropy", "get_mauve_score", "get_valid_seq", "gather_tokens",
+ "count_valid_tokens", "compute_val_metrics_standalone", "_maskgit_nucleus_update",
+ "get_img_text_saturation_batch", "handle_interleaved_decode", "get_interleaved_image",
+ "auto_enhance", "get_clip_score", "get_dfn_score", "get_hpsv2_score", "get_model_likelihood_score",
+ "get_laion_aesthetic_score", "get_rewards", "get_chameleon_score", "clear_reward_models",
+ "get_text_likelihood_score", "get_text_reward_model_score", "save_image_text_pair"
+ ]
+ for method_name in eval_methods:
+ setattr(self, method_name, types.MethodType(getattr(model_eval, method_name), self))
+
+ if disable_init:
+ pass
+ else:
+ model_setup.init(self, config, tokenizer, device)
+
+ @cached_property
+ def xla_mesh(self):
+ import torch_xla.distributed.spmd as xs
+ return xs.get_global_mesh()
+
+ def on_train_resume(self):
+ if not is_torch_xla_available():
+ empty_device_cache()
+
+ if self.ema is not None and not self.config.trainer.use_custom_ema:
+ self.ema.restore(self.get_params(), raise_error_if_already_restored=False)
+
+ self.backbone.train()
+
+ def zero_shot_update_batch(self, batch):
+ dataset = self.config.data.train
+ if dataset is None:
+ return batch
+
+ def get_attr(attr_name):
+ return getattr(self.config.model, attr_name, None)
+
+ if dataset == "nlphuji/flickr30k":
+ # image captioning dataset
+ # above thing but order is [txt, img]
+ batch['gt_input_ids'] = batch['input_ids']
+ image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
+ image_input_ids += self.text_vocab_size
+ batch["input_ids"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), image_input_ids], dim=-1).to(self.device)
+ batch['attention_mask'] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.bool), torch.ones_like(image_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
+ batch["modality"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), torch.ones_like(image_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
+ elif dataset == "facebook/winoground":
+ # get image and text input ids
+ caption_0_input_ids = batch['caption_0_input_ids']
+ caption_1_input_ids = batch['caption_1_input_ids']
+ image_0 = batch['img_0']
+ image_1 = batch['img_1']
+ # tokenize and store captions separately
+ image_0_input_ids = vae_encode_image(self.config, self.get_vae(), image_0, self.device, get_attr("vae_type")) + self.text_vocab_size
+ image_1_input_ids = vae_encode_image(self.config, self.get_vae(), image_1, self.device, get_attr("vae_type")) + self.text_vocab_size
+ # make 4 combinat ions of image and text
+ batch['input_ids_0_0'] = torch.cat([caption_0_input_ids, image_0_input_ids], dim=-1).to(self.device)
+ batch['input_ids_0_1'] = torch.cat([caption_0_input_ids, image_1_input_ids], dim=-1).to(self.device)
+ batch['input_ids_1_0'] = torch.cat([caption_1_input_ids, image_0_input_ids], dim=-1).to(self.device)
+ batch['input_ids_1_1'] = torch.cat([caption_1_input_ids, image_1_input_ids], dim=-1).to(self.device)
+ batch['attention_mask'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.bool), torch.ones_like(image_0_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
+ batch['modality'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.int64), torch.ones_like(image_0_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
+ # elif dataset == "facebook/winoground":
+ batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
+ return batch
+
+ def update_batch(self, batch):
+ if getattr(self.config.eval, 'big_seq_len_eval', False):
+ # new batch of 8192 seq length with txt length 4096 and img length 4096s
+ N = self.config.model.length
+ new_batch = dict()
+ new_batch['input_ids'] = torch.zeros(batch['input_ids'].shape[0], N, device=self.device, dtype=batch['input_ids'].dtype)
+ new_batch['attention_mask'] = torch.ones(batch['attention_mask'].shape[0], N, device=self.device, dtype=batch['attention_mask'].dtype)
+ new_batch['modality'] = torch.zeros(batch['modality'].shape[0], N, device=self.device, dtype=batch['modality'].dtype)
+ new_batch['modality'][:, N//2:] = 1
+ new_batch['modality_mask'] = F.one_hot(new_batch['modality'], num_classes=2).to(torch.bool)
+ batch = new_batch
+ return batch
+
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ if batch is None:
+ gprint(f"Warning! Batch is None")
+ return batch
+
+ if isinstance(batch, TensorDict):
+ batch.batch_size = (batch.batch_size[0],)
+
+ if self.image_model or getattr(self.config.data, "force_image_dataset", False):
+ text_input_ids = None
+ if isinstance(batch, TensorDict) and (self.is_compiled or getattr(self.config.trainer, "force_convert_to_dict", False)):
+ batch = dict(batch.items())
+
+ if "txt_input_ids" in batch or "img_input_ids" in batch:
+ index_keys = ["img_input_ids", "txt_input_ids", "sample_ids"]
+ for key in index_keys:
+ if key in batch:
+ if isinstance(batch[key], list):
+ batch[key] = torch.stack(batch[key], dim=0)
+ batch[key] = batch[key].to(torch.int64)
+
+ index_keys = ["img_label"]
+ for key in index_keys:
+ if key in batch:
+ batch[key] = batch[key].squeeze(-1)
+
+ img_input_ids = batch.pop("img_input_ids")
+ batch["input_ids"] = img_input_ids
+ batch["attention_mask"] = torch.ones_like(img_input_ids).to(torch.bool)
+ if "txt_input_ids" in batch:
+ batch["input_ids"] = torch.cat([batch["txt_input_ids"], batch["input_ids"] + self.text_vocab_size], dim=-1)
+ batch["attention_mask"] = torch.cat([batch["txt_attention_mask"], batch["attention_mask"]], dim=-1)
+
+ batch["input_ids"] = batch["input_ids"].to(torch.int64)
+
+ if "modality" not in batch:
+ if getattr(self.config.trainer, "ignore_text_in_unified", False):
+ modality = torch.ones_like(batch["input_ids"], dtype=torch.int64)
+ else:
+ assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
+ modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
+ modality[:, -img_input_ids.shape[-1]:] = 1
+ batch["modality"] = modality
+
+ elif (self.config.trainer.multimodal_batches or continuous_mode) and \
+ not getattr(self.config.trainer, "use_legacy_update_batch_fn", False):
+
+ if "img" in batch:
+ is_image_batch = (batch["modality"] == 1).all(dim=-1)
+ image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
+ assert ((batch["modality"].sum(dim=-1) == 0) | (batch["modality"].sum(dim=-1) >= image_input_ids.shape[1])).all()
+
+ if getattr(self.config.trainer, "add_label", False):
+ assert (batch["modality"] == 1).all()
+ batch["input_ids"][:, 1:] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"][:, 1:])
+ elif image_input_ids.ndim == 3:
+ batch["img_emb"] = torch.where((batch["modality"] == 1)[:, :, None], image_input_ids, torch.nan)
+ elif (batch["input_ids"][batch["modality"] == 1] == -1).all():
+ batch["input_ids"].masked_scatter_(batch["modality"] == 1, image_input_ids)
+ else:
+ batch["input_ids"] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"])
+
+ if getattr(self.config.trainer, "force_shift_raw_image_batches", False):
+ assert not getattr(self.config.trainer, "force_shift_image_batches", False)
+ batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
+ else:
+ if getattr(self.config.trainer, "add_label", False):
+ shift_index = self.vocab_size - self.config.model.add_labels
+ batch["input_ids"] = torch.cat([batch["label"] + shift_index, batch["input_ids"]], dim=-1)
+ batch["attention_mask"] = torch.cat([torch.zeros_like(batch["label"], dtype=torch.bool), batch["attention_mask"]], dim=-1)
+ batch["modality"] = torch.cat([torch.ones_like(batch["label"], dtype=torch.int64), batch["modality"]], dim=-1)
+ assert (batch["modality"] == 1).all()
+
+ batch["input_ids"] = batch["input_ids"].to(torch.int64)
+ if "sample_ids" in batch:
+ batch["sample_ids"] = batch["sample_ids"].to(torch.int64)
+
+ if getattr(self.config.trainer, "force_shift_image_batches", False):
+ batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
+ else:
+ if continuous_mode:
+ assert False
+ else:
+ if "input_ids" in batch and not self.config.trainer.ignore_text_in_unified:
+ assert self.config.model.unified_model
+ assert "attention_mask" in batch
+ text_input_ids = batch["input_ids"]
+
+ image_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
+ image_attention_mask = torch.ones_like(image_ids).to(torch.bool)
+
+ if "cond_img" in batch:
+ cond_image_ids = get_image_batch(self.config, self.get_cond_vae(), batch, self.device, use_cond=True)
+ batch["cond_input_ids"] = cond_image_ids
+
+ if text_input_ids is not None:
+ assert batch["input_ids"].shape[1] == self.config.model.txt_length
+ assert image_ids.shape[1] == self.config.model.img_length
+ image_ids = image_ids + self.text_vocab_size
+
+ batch["input_ids"] = torch.cat([batch["input_ids"].to(self.device), image_ids], dim=-1)
+ batch["attention_mask"] = torch.cat([batch["attention_mask"].to(self.device), image_attention_mask], dim=-1).to(torch.bool)
+ assert batch["input_ids"].shape[1] == batch["attention_mask"].shape[1] == self.config.model.length
+ batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
+ batch["modality"][:, -image_ids.shape[-1]:] = 1
+ else:
+ assert self.unified_model is False
+ batch["input_ids"] = image_ids
+ batch["attention_mask"] = image_attention_mask
+ batch["modality"] = torch.ones_like(batch["input_ids"], dtype=torch.int64)
+
+ if "txt_x0_unmask" in batch and "img_x0_unmask" in batch:
+ assert not continuous_mode
+ batch["gt_img_input_ids"] = image_ids
+ batch["x0_unmask"] = torch.cat([batch["txt_x0_unmask"], batch["img_x0_unmask"]], dim=-1)
+ batch["input_ids"][~batch["x0_unmask"]] = self.mask_index
+
+ if (batch["input_ids"].shape[1] != self.config.model.length) and not self.config.trainer.ar_inpainting:
+ gprint(f"Warning! Input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}")
+ batch["input_ids"] = batch["input_ids"][:, : self.config.model.length]
+ assert False, f"input ids are not the correct length input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}"
+
+ if getattr(self.config.model, "img_cond", False):
+ assert "cond_input_ids" in batch
+ assert not continuous_mode
+
+ if "modality" in batch:
+ batch["modality"] = batch["modality"].to(torch.int64)
+ if self.config.trainer.multimodal_batches and batch["modality"].ndim == 2 and batch["modality"].shape[-1] == 1:
+ batch["modality"] = batch["modality"].repeat(1, self.config.model.length)
+ else:
+ if self.image_model and not self.config.trainer.multimodal_batches:
+ assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
+ modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
+ modality[:, self.static_img_sl] = 1
+ batch["modality"] = modality
+ elif self.config.data.txt_only:
+ batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
+
+ if "modality" in batch:
+ batch["modality"][batch["modality"] == -1] = 0
+ assert batch["modality"].min() == 0 and batch["modality"].max() == 1
+ batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
+ batch["batch_contains_img"] = (batch["modality"] == 1).any(dim=-1)
+ batch['txt_sl'] = self.txt_sl(batch)
+ batch['img_sl'] = self.img_sl(batch)
+
+ if getattr(self.config.trainer, "force_remove_img_tokens", False):
+ assert not continuous_mode
+ batch["input_ids"] = batch["input_ids"][batch['txt_sl']]
+ batch["attention_mask"] = batch["attention_mask"][batch['txt_sl']]
+
+ if getattr(self.config.trainer, "add_label", False):
+ assert getattr(self.config.model, "add_labels", False)
+ assert "label" in batch
+ batch["label"] = batch["label"].to(torch.int64)
+ assert 0 <= batch["label"].min() and batch["label"].max() < self.config.model.add_labels
+ shift_index = self.vocab_size - self.config.model.add_labels
+
+ assert batch["input_ids"].shape[-1] == self.config.model.length
+ if batch["label"].ndim == 1:
+ batch["input_ids"][:, [0]] = (batch["label"] + shift_index).unsqueeze(-1)
+ else:
+ batch["input_ids"][:, [0]] = batch["label"] + shift_index
+
+ batch["attention_mask"][:, 0] = False
+
+ if isinstance(batch, dict):
+ for key in batch.keys():
+ if isinstance(batch[key], torch.Tensor):
+ batch[key] = batch[key].to(self.device)
+ elif isinstance(batch, TensorDict):
+ assert self.config.backbone != "gemma"
+ batch = batch.to(self.device)
+
+ if getattr(self.config.trainer, "force_full_attention_mask", False):
+ batch["attention_mask"] = torch.ones_like(batch["attention_mask"], dtype=torch.bool)
+
+ batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
+
+ if self.config.data.require_sample_ids:
+ assert "sample_ids" in batch
+ batch["sample_ids"][~(batch["attention_mask"].bool())] = -1
+ batch["attention_mask"][batch["sample_ids"] == -1] = False
+
+ # Flip [txt, img] -> [img, txt]
+ # TODO: Flip by sample not batch. As we train w/~8 batches, it's for now
+ if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.parameterization == "ar" and getattr(self.config.trainer, "rand_flip_ar_prob", None) is not None:
+ assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all(), "Modality does not match img_before_txt configuration"
+ batch_flip_mask = torch.rand(batch["modality"].shape[0], device=self.device) < self.config.trainer.rand_flip_ar_prob
+ img_slice = slice(-self.config.model.img_length, None)
+ txt_slice = slice(None, self.config.model.txt_length)
+
+ for key in ["modality", "attention_mask", "input_ids"]:
+ batch[key][batch_flip_mask] = torch.cat([batch[key][batch_flip_mask][:, img_slice], batch[key][batch_flip_mask][:, txt_slice]], dim=1)
+
+ if "modality_mask" in batch:
+ batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
+
+ batch['txt_sl'] = None
+ batch['img_sl'] = None
+ batch["batch_flip_mask"] = batch_flip_mask
+
+ if self.config.trainer.interleaved and "sample_ids" not in batch:
+ batch["sample_ids"] = torch.zeros_like(batch["modality"], dtype=torch.int64)
+
+ if self.config.trainer.interleaved:
+ batch_indices, start_positions, end_positions = get_contiguous_blocks(batch["modality"])
+ interleaved_metadata = TensorDict({
+ "batch_indices": batch_indices,
+ "start_positions": start_positions,
+ "end_positions": end_positions
+ }, batch_size=[])
+ allowed_image_sizes = (64, 256, 1024, 2304, 4096)
+ block_sizes = (end_positions - start_positions).to(torch.int32)
+ is_txt_block = batch["modality"][batch_indices, start_positions] == 0
+ is_valid_img_size = torch.isin(block_sizes, torch.tensor(allowed_image_sizes, dtype=torch.int32, device=self.device))
+
+ if not ((is_txt_block | is_valid_img_size).all()):
+ gprint(f"WARNING: Found non-text block of size {block_sizes[~(is_txt_block | is_valid_img_size)]} in interleaved batch")
+
+ if isinstance(batch, TensorDict):
+ batch.batch_size = []
+ batch["interleaved_metadata"] = interleaved_metadata
+
+ return batch
+
+ def get_cond_dict(self, batch):
+ ret_dict = dict()
+ if "cond_input_ids" in batch:
+ ret_dict["x_cond"] = batch["cond_input_ids"]
+
+ if "img_label" in batch:
+ ret_dict["label"] = batch["img_label"]
+
+ if self.config.model.use_attention_mask:
+ ret_dict["attention_mask"] = batch["attention_mask"]
+
+ if self.config.trainer.multimodal_batches:
+ ret_dict["modality"] = batch["modality"]
+
+ if self.config.trainer.image_mode == "continuous":
+ ret_dict["continuous_mode"] = True
+ ret_dict["modality"] = batch["modality"]
+
+ if self.parameterization == "ar" and "modality" in batch:
+ ret_dict["modality"] = batch["modality"]
+
+ return ret_dict
+
+ def training_step(self, batch, batch_idx):
+ batch = self.update_batch(batch)
+ return self.compute_loss(batch, prefix="train", batch_idx=batch_idx)
+
+ def q_xt(self, x, move_chance, allow_move_mask=None, return_ignore_batch_mask_for_metrics=False, mask_image_square=False, mask_text_region=False, batch=None):
+ """Computes the noisy sample xt.
+
+ Args:
+ x: int torch.Tensor with shape (batch_size,
+ diffusion_model_input_length), input.
+ move_chance: float torch.Tensor with shape (batch_size, 1).
+ """
+ if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
+ num_to_mask = int(x.shape[1] * move_chance[0].item())
+ batch_size, seq_len = x.shape
+ random_indices = torch.rand(batch_size, seq_len, device=x.device).argsort(dim=1)[:, :num_to_mask]
+ xt = x.scatter(1, random_indices, self.mask_index)
+ return xt
+
+ move_indices = torch.rand(*x.shape, device=x.device) < move_chance
+
+ if mask_image_square:
+ latent_dim = int(math.sqrt(self.config.model.img_length))
+ img_move_indices = move_indices[:, self.static_img_sl].clone().reshape(move_indices.shape[0], latent_dim, latent_dim)
+ max_d = int(math.sqrt(self.config.model.img_length))
+ for b in range(move_indices.shape[0]):
+ if move_chance[b] == 1:
+ continue
+ h, w = img_move_indices[b].shape
+ d = random.randint(max_d // 2, max_d - 2)
+ i = random.randint(0, h - d)
+ j = random.randint(0, w - d)
+
+ mask = torch.zeros_like(img_move_indices[b], dtype=torch.bool)
+ mask[i:i+d, j:j+d] = True
+ move_indices[b, self.static_img_sl] = mask.reshape(-1)
+
+ if mask_text_region:
+ for b in range(x.shape[0]):
+ if move_chance[b] == 1:
+ continue
+ should_mask = torch.zeros_like(move_indices[b, self.static_txt_sl], dtype=torch.bool)
+ max_valid = (x[b] == self.tokenizer.eos_token_id).nonzero()[0, 0] if self.tokenizer.eos_token_id in x[b] else x.shape[1]
+ d = random.randint(max_valid//3, max_valid-1)
+ start = random.randint(0, max_valid - d)
+ should_mask[start:start+d] = True
+ move_indices[b, self.static_txt_sl] = should_mask
+
+ ignore_batch_mask_for_metrics = None
+ should_mask_txt, should_mask_img = None, None
+ if (mask_prob := getattr(self.config.trainer, "mask_entire_modality", None)) is not None \
+ and (mask_image_square is False and mask_text_region is False) and self.backbone.training:
+
+ assert batch is not None
+ batch_size, seq_len = x.shape
+ if getattr(self.config.trainer, "mask_txt_only", False):
+ should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob
+ should_mask_img = torch.zeros_like(should_mask_txt, device=x.device)
+ else:
+ should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
+ should_mask_img = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
+
+ if self.config.trainer.multimodal_batches:
+ if self.config.trainer.interleaved:
+ batch_indices, start_positions, end_positions = get_contiguous_blocks_per_sample(batch["modality"], batch["sample_ids"])
+
+ block_size = end_positions - start_positions
+ size_mask = block_size > 4
+ batch_indices, start_positions, end_positions = batch_indices[size_mask], start_positions[size_mask], end_positions[size_mask]
+
+
+ block_counts = torch.zeros_like(batch_indices)
+ max_num_sample_ids = torch.zeros_like(batch_indices)
+
+
+ for i in range(len(batch_indices)):
+ curr_sample_id = batch["sample_ids"][batch_indices[i], start_positions[i]]
+
+ # Find blocks before this one with same batch index and sample_id
+ prev_blocks_mask = (batch_indices[:i] == batch_indices[i]) & \
+ (batch["sample_ids"][batch_indices[:i], start_positions[:i]] == curr_sample_id)
+
+ total_in_sample = ((batch_indices == batch_indices[i]) & (batch["sample_ids"][batch_indices, start_positions] == curr_sample_id)).sum()
+
+ block_counts[i] = prev_blocks_mask.sum()
+ max_num_sample_ids[i] = total_in_sample
+
+ block_prob = (block_counts + 1) / max_num_sample_ids
+ positions = torch.arange(move_indices.shape[-1], device=move_indices.device).unsqueeze(0) # Shape: [1, N]
+ mask = (positions >= start_positions.unsqueeze(1)) & (positions < end_positions.unsqueeze(1)) # Shape: [M, N]
+ mask = mask & (torch.rand(batch_indices.shape[0], 1, device=x.device) < (mask_prob * block_prob * 2)[..., None])
+ expanded_batch_indices = batch_indices.unsqueeze(1).expand(-1, move_indices.shape[1]) # Shape: [M, N]
+
+ # True if we should manually mask the part of the sequence
+ accum = torch.zeros_like(move_indices, dtype=torch.int32) # Shape: [B, N]
+ accum.scatter_add_(0, expanded_batch_indices, mask.int()) # Accumulate counts
+ accum = accum.to(torch.bool)
+
+ move_indices = move_indices | accum
+
+ # We ignore the entire sequence if any of the blocks are fully masked
+ ignore_batch_mask_for_metrics = torch.zeros((move_indices.shape[0],), device=x.device, dtype=torch.bool)
+ ignore_batch_mask_for_metrics.scatter_add_(0, batch_indices, mask.any(dim=-1))
+ else:
+ # TODO: Be smarter about masking for interleaved
+ # To make sure that we have even masking prob, we prefer to mask less but equally
+ both_mask = should_mask_txt & should_mask_img
+ should_mask_txt = torch.where(both_mask, False, should_mask_txt)
+ should_mask_img = torch.where(both_mask, False, should_mask_img)
+ move_indices = torch.where(should_mask_txt, batch["modality_mask"][..., 0], move_indices)
+ move_indices = torch.where(should_mask_img, batch["modality_mask"][..., 1], move_indices)
+ ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
+ else:
+ both_mask = should_mask_txt & should_mask_img
+ should_mask_txt[both_mask] = False
+ should_mask_img[both_mask] = False
+ should_mask_img[batch["txt_sl"].all(dim=-1)] = False
+ move_indices[:, self.static_txt_sl] = torch.where(should_mask_txt, True, move_indices[:, self.static_txt_sl])
+ move_indices[:, self.static_img_sl] = torch.where(should_mask_img, True, move_indices[:, self.static_img_sl])
+ ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
+
+ joint_ar_nar_mask = None
+ if self.config.trainer.joint_ar_nar_prob is not None and self.training:
+ batch_size = x.shape[0]
+ current_prob = linear_warmup(
+ current_step=self.global_step,
+ warmup_steps=self.config.trainer.joint_ar_nar_prob_warmup_steps,
+ final_value=self.config.trainer.joint_ar_nar_prob,
+ initial_value=1.0
+ )
+ joint_ar_nar_mask = torch.rand(batch_size, device=x.device) < current_prob
+ move_indices = torch.where(joint_ar_nar_mask[:, None], False, move_indices)
+
+ if self.config.trainer.add_label:
+ move_indices[:, 0] = False
+
+ if self.config.trainer.first_token_dropout is not None and self.training:
+ _initial_mask = torch.rand(x.shape[0], device=x.device) < self.config.trainer.first_token_dropout
+ move_indices[:, 0] = torch.where(_initial_mask, True, move_indices[:, 0])
+ if ignore_batch_mask_for_metrics is None:
+ ignore_batch_mask_for_metrics = _initial_mask
+ else:
+ ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | _initial_mask
+
+ if allow_move_mask is not None:
+ move_indices = move_indices & allow_move_mask
+
+ if getattr(self.config.trainer, "discrete_diffusion_mode", "absorbing") == "uniform":
+ if getattr(self.config.model, "force_argmax_valid_indices", False):
+ assert self.mask_index == self.text_vocab_size - 1
+ text_random_tokens = torch.randint(0, self.text_vocab_size - 1, size=x.shape, device=x.device)
+ img_random_tokens = torch.randint(self.text_vocab_size, self.vocab_size, size=x.shape, device=x.device)
+ random_tokens = torch.where(batch["modality_mask"][..., 0], text_random_tokens, img_random_tokens)
+ assert not torch.any(random_tokens == self.mask_index)
+ else:
+ random_tokens = torch.randint(0, vocab_size, size=x.shape, device=x.device)
+ random_tokens = torch.where(random_tokens == self.mask_index, random_tokens + 1, random_tokens) # avoid mask index
+ xt = torch.where(move_indices, random_tokens, x)
+ else:
+ xt = torch.where(move_indices, self.mask_index, x)
+
+ if self.parameterization == "ar":
+ xt = x.clone()
+
+ if return_ignore_batch_mask_for_metrics:
+ return xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices
+ else:
+ return xt
+
+ def _sample_t(self, n, device):
+ if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
+ _eps_t = torch.rand(1, device=device).repeat(n)
+ else:
+ _eps_t = torch.rand(n, device=device)
+ if self.config.trainer.joint_ar_nar_timestep_warmup_steps is not None:
+ max_t = linear_warmup(
+ current_step=self.global_step,
+ warmup_steps=self.config.trainer.joint_ar_nar_timestep_warmup_steps,
+ final_value=1,
+ initial_value=0,
+ start_step=0
+ )
+ _eps_t = _eps_t * max_t
+ if max_t == 1:
+ offset = torch.arange(n, device=device) / n
+ _eps_t = (_eps_t / n + offset) % 1
+
+ elif self.antithetic_sampling:
+ offset = torch.arange(n, device=device) / n
+ _eps_t = (_eps_t / n + offset) % 1
+
+ if getattr(self.config.trainer, "force_timestep", None) is not None:
+ _eps_t[:] = self.config.trainer.force_timestep
+ elif getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
+ _eps_t[:] = self.config.eval.ar_inpainting_force_val
+
+ t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
+ if self.importance_sampling:
+ return self.noise.importance_sampling_transformation(t)
+ return t.to(torch.float32)
+
+ def _subs_parameterization(self, logits, xt, batch=None, modality=None, **kwargs):
+ # log prob at the mask index = - infinity
+ if not self.allow_slicing:
+ logits = logits.clone()
+
+ logits[..., self.mask_index] += self.neg_infinity
+ if getattr(self.config.model, "force_argmax_valid_indices", False):
+ if self.config.trainer.multimodal_batches:
+ _txt_sl = batch["txt_sl"] if modality is None else modality == 0
+ _img_sl = batch["img_sl"] if modality is None else modality == 1
+ logits[..., self.text_vocab_size:] = torch.where(_txt_sl[..., None], self.neg_infinity, logits[..., self.text_vocab_size:])
+ logits[..., :self.text_vocab_size] = torch.where(_img_sl[..., None], self.neg_infinity, logits[..., :self.text_vocab_size])
+ else:
+ logits[..., self.static_txt_sl, self.text_vocab_size:] = self.neg_infinity
+ logits[..., self.static_img_sl, :self.text_vocab_size] = self.neg_infinity
+
+ # Normalize the logits such that x.exp() is
+ # a probability distribution over vocab_size.
+ logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
+
+ if self.parameterization != "ar" and xt is not None:
+ # Apply updates directly in the logits matrix.
+ # For the logits of the unmasked tokens, set all values
+ # to -infinity except for the indices corresponding to
+ # the unmasked tokens.
+ unmasked_indices = xt != self.mask_index
+ if not self.allow_slicing:
+ logits = torch.where(unmasked_indices.unsqueeze(-1), torch.full_like(logits, self.neg_infinity), logits)
+ logits = torch.where(
+ unmasked_indices.unsqueeze(-1) & (torch.arange(logits.size(-1)).to(logits.device) == xt.unsqueeze(-1)),
+ torch.zeros_like(logits),
+ logits
+ )
+ else:
+ logits[unmasked_indices] = self.neg_infinity
+ logits[unmasked_indices, xt[unmasked_indices]] = 0
+
+ return logits
+
+ def _process_sigma(self, sigma):
+ if sigma is None:
+ assert (self.parameterization == "ar" or self.config.trainer.ar_llm_loss) or self.config.trainer.allow_null_sigma
+ return sigma
+
+ if sigma.ndim > 1 and not self.config.trainer.image_mode == "continuous":
+ sigma = sigma.squeeze(-1)
+ assert sigma.ndim == 1, sigma.shape
+
+ if not self.time_conditioning and getattr(self.config.model, "force_time_conditioning", False):
+ sigma = torch.zeros_like(sigma)
+
+ return sigma
+
+ def forward(
+ self,
+ x,
+ sigma,
+ batch=None,
+ forward_attention_mask=None,
+ return_additional_loss=False,
+ x_img_emb=None,
+ disable_ar_shift=False,
+ continuous_mode=False,
+ joint_ar_nar_mask=None,
+ return_logits=False,
+ block_mask=None,
+ update_cache_slice=None,
+ **kwargs,
+ ):
+ """Returns log score."""
+ sigma = self._process_sigma(sigma)
+ if self.config.trainer.image_mode == "continuous": assert "modality" in kwargs
+ should_autocast = (((self.config.trainer.disable_forward_autocast_during_eval and self.backbone.training) is False) and (self.dtype != torch.float32))
+ with ExitStack() as stack:
+ if should_autocast:
+ stack.enter_context(torch.autocast(device_type=self.device.type, dtype=self.dtype))
+
+ orig_modality = None
+ if self.config.backbone == "elm":
+ if getattr(self.config.trainer, "print_llm_ppl", False):
+ _labels = x.clone()
+ _labels[~forward_attention_mask] = -100
+ kwargs['labels'] = _labels
+
+ if "modality" in kwargs:
+ if self.config.mode == "eval": orig_modality = kwargs.pop("modality")
+ else: kwargs.pop("modality")
+
+ if "modality_mask" in kwargs: kwargs.pop("modality_mask")
+ if "x0" in kwargs: kwargs.pop("x0")
+ if "start_pos" in kwargs: kwargs.pop("start_pos")
+ if "sample_ids" in kwargs: kwargs.pop("sample_ids")
+
+ output = self.backbone(input_ids=x, **kwargs)
+
+ if self.config.mode == "eval": kwargs["modality"] = orig_modality
+
+ if isinstance(output, Tensor):
+ logits = output
+ else:
+ logits = output.logits
+
+ if getattr(self.config.trainer, "print_llm_ppl", False):
+ rprint(f"AR PPL: {torch.exp(output.loss)}")
+ else:
+ if self.config.trainer.compile == 'max-autotune' and not is_xla_available:
+ torch.compiler.cudagraph_mark_step_begin()
+
+ logits = self.backbone(x, sigma, continuous_mode=continuous_mode, x_img_emb=x_img_emb, block_mask=block_mask, update_cache_slice=update_cache_slice, **kwargs)
+ if self.config.trainer.force_bf16_eval:
+ logits = logits.to(torch.bfloat16)
+
+ if continuous_mode:
+ assert self.parameterization == "ar"
+ logits, logits_img = logits
+
+ if self.config.trainer.ar_shift and not disable_ar_shift:
+ # config trainer ar shift is for training
+ # disable ar shift is for sampling at inference
+ logits = logits[:, :-1]
+ xt = x[:, 1:]
+ if orig_modality is not None and self.config.mode == 'eval':
+ orig_modality = orig_modality[:, 1:]
+ else:
+ xt = x
+
+ if self.config.trainer.low_precision_loss:
+ logits = logits.to(self.dtype)
+ if continuous_mode:
+ logits_img = logits_img.to(self.dtype)
+
+ if self.parameterization == "planner":
+ return logits
+ elif self.config.trainer.ar_llm_loss:
+ assert not self.parameterization == "ar"
+ model_output = self._subs_parameterization(logits, xt=xt, modality=orig_modality), logits
+ if is_xla_available: shard_output(model_output[0], self.xla_mesh)
+ if is_xla_available: shard_output(model_output[1], self.xla_mesh)
+ return model_output if return_additional_loss else model_output[0]
+ elif self.parameterization == "ar":
+ if not getattr(self.config.trainer, "use_orig_unidisc_dit", False):
+ logits = torch.where(
+ torch.arange(logits.shape[-1], device=logits.device)[None, None, :] == self.mask_index, self.neg_infinity, logits
+ )
+
+ _modality = kwargs.get("modality") if batch is None else batch.get("modality")
+
+ # During eval, we let the sampler handle this part.
+ if getattr(self.config.model, "force_argmax_valid_indices", False) and _modality.shape[1] == (logits.shape[1] + 1):
+ if not self.allow_slicing:
+ logits = logits.clone()
+
+ logits[..., self.text_vocab_size:] = torch.where(
+ (kwargs.get("modality") == 0)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., self.text_vocab_size:]
+ )
+ logits[..., :self.text_vocab_size] = torch.where(
+ (kwargs.get("modality") == 1)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., :self.text_vocab_size]
+ )
+
+ logits = logits.log_softmax(-1)
+
+ if continuous_mode:
+ return (logits, logits_img)
+ elif self.parameterization == "subs":
+ if return_logits:
+ return logits
+ model_output = self._subs_parameterization(logits, xt=xt, batch=batch, **kwargs)
+ if is_xla_available: shard_output(model_output, self.xla_mesh)
+ return model_output
+ elif self.parameterization == "sedd":
+ return self._sedd_parameterization(logits=logits, xt=x, sigma=sigma)
+ elif self.parameterization == "d3pm":
+ return self._d3pm_parameterization(logits=logits)
+
+ return logits
+
+ def compute_loss(self, batch, prefix, batch_idx=-1):
+ if not is_xla_available and ((self.current_run_fwd_bwd_pass == 0 and self.config.mode == 'train') or batch_idx == 0):
+ self.visualize_samples(batch, batch_idx, split=prefix)
+ if getattr(self.config.trainer, 'overfit_on_first_batch', False):
+ if batch_idx <= 0:
+ # store it
+ self.overfit_batch = batch.copy()
+ else:
+ batch = self.overfit_batch
+
+ kwargs = self.get_cond_dict(batch)
+ modality_mask = batch.get("modality_mask", None)
+ (input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(batch["input_ids"], batch.get("attention_mask", None))
+
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ joint_ar_nar_mask, modality = None, None
+ if continuous_mode:
+ assert 'modality' in batch
+ x0, img_emb, attention_mask, modality = (
+ batch["input_ids"],
+ batch["img_emb"],
+ batch["attention_mask"],
+ batch["modality"],
+ ) # img_emb has [0.] * txt_len + img_emb
+ xt = x0
+ B, N_tot, C = img_emb.shape
+
+ noise_scheduler = self.get_vae().scheduler
+ noise = torch.randn_like(img_emb)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,), device=img_emb.device).long()
+ img_timesteps = timesteps.unsqueeze(-1).expand(-1, N_tot).to(self.dtype)
+ zero_timesteps = torch.zeros_like(img_timesteps)
+ unet_conditioning = torch.where(modality == 1, img_timesteps, zero_timesteps)
+ # unet_conditioning = timesteps.to(self.dtype)
+ # unet_conditioning = torch.where(modality_mask==1, timesteps.to(self.dtype), torch.zeros_like(timesteps.to(self.dtype)))
+ x_img_emb = noise_scheduler.add_noise(img_emb, noise, timesteps).to(self.dtype)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(img_emb, noise, timesteps) # todo, might break
+ elif noise_scheduler.config.prediction_type:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ target = target.to(self.dtype)
+ else:
+ unet_conditioning, xt, x0, x_img_emb, modality_mask = None, None, input_tokens, None, batch.get("modality_mask", None)
+ if self.parameterization != "ar":
+ t = self._sample_t(x0.shape[0], x0.device)
+ if self.T > 0:
+ t = (t * self.T).to(torch.int)
+ t = t / self.T
+ t += 1 / self.T # t \in {1/T, 2/T, ..., 1}
+
+ if self.change_of_variables:
+ unet_conditioning = t[:, None]
+ f_T = torch.log1p(-torch.exp(-self.noise.sigma_max))
+ f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min))
+ move_chance = torch.exp(f_0 + t * (f_T - f_0))
+ move_chance = move_chance[:, None]
+ else:
+ # total, rate
+ sigma, dsigma = self.noise(t)
+ unet_conditioning = sigma[:, None]
+ move_chance = 1 - torch.exp(-sigma[:, None])
+
+ xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
+ if (self.config.model.flex_attention_img_masking_prob is not None or self.config.model.flex_attention_txt_masking_prob is not None) and self.backbone.training:
+ assert xt.shape[1] == (self.config.model.img_length + self.config.model.txt_length)
+ txt_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_txt_masking_prob
+ img_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_img_masking_prob
+
+ # If we mask out a modality, we cannot let it only see itself
+ txt_batch_attn_dropout = txt_batch_attn_dropout & ~should_mask_txt.squeeze(-1)
+ img_batch_attn_dropout = img_batch_attn_dropout & ~should_mask_img.squeeze(-1)
+ kwargs['block_mask'] = get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, self.config.model.txt_length, xt.shape[0], xt.shape[1], xt.device)
+
+ # TODO: Somehow report these metrics so we know what's going on
+ ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | (txt_batch_attn_dropout | img_batch_attn_dropout).unsqueeze(-1)
+
+ if getattr(self.config.trainer, "interleaved_training_flex_attention", False):
+ kwargs['block_mask'] = get_interleaved_block_mask(batch["sample_ids"], batch_size=xt.shape[0], seq_len=xt.shape[1], device=xt.device)
+ kwargs['sample_ids'] = batch["sample_ids"]
+
+ elif self.config.trainer.ar_inpainting:
+ x0 = torch.cat([x0, x0], dim=1)
+ kwargs['modality'] = torch.cat([kwargs['modality'], kwargs['modality']], dim=1)
+ attention_mask = torch.cat([torch.zeros_like(attention_mask, dtype=attention_mask.dtype), torch.ones_like(attention_mask, dtype=attention_mask.dtype)], dim=1)
+ modality_mask = torch.cat([modality_mask, modality_mask], dim=1)
+ min_val, max_val = 0.0, 1.0
+ n = x0.shape[0]
+ _eps_t = torch.rand(n, device=self.device)
+ offset = torch.arange(n, device=self.device) / n
+ _eps_t = (_eps_t / n + offset) % 1
+ t = (max_val - min_val) * _eps_t + min_val
+ if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
+ t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device)
+ move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None]
+ move_indices[:, x0.shape[1] // 2:] = False
+ x0 = torch.where(move_indices, self.mask_index, x0)
+ xt = x0
+ else:
+ xt = x0
+ if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.trainer.rand_ar_modality_dropout is not None:
+ assert not is_xla_available
+ xt = xt.clone()
+ batch_modality_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.trainer.rand_ar_modality_dropout
+ first_modality = batch["modality"][:, 0]
+ first_modality_mask = batch["modality"] == first_modality[:, None]
+ xt = torch.where(first_modality_mask & batch_modality_dropout[:, None], self.mask_index, xt)
+ attention_mask = torch.where(first_modality_mask & batch_modality_dropout[:, None], False, attention_mask)
+ true_logits = None
+ model_output = self.forward(
+ xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=x_img_emb, joint_ar_nar_mask=joint_ar_nar_mask, **kwargs
+ )
+ if isinstance(model_output, tuple):
+ if continuous_mode:
+ model_output, img_output = model_output # model_output is for text, img_output is for image although both will have N_total length (zeroed out according to modality mask)
+ B, _, C = img_output.shape
+ # use modality mask to get the correct logits
+ x0 = x0[modality==0].reshape(B, -1)
+ xt = xt[modality==0].reshape(B, -1)
+ attention_mask = torch.ones_like(x0, dtype=torch.bool) # since we separate text, we don't need to mask it out
+ img_output = img_output[modality==1].reshape(B, -1, C)
+ target = target[modality==1].reshape(B, -1, C)
+ else:
+ model_output, true_logits = model_output
+
+ to_dtype = self.dtype if self.config.trainer.low_precision_loss else torch.float32
+ model_output = model_output.to(to_dtype)
+ if true_logits is not None:
+ true_logits = true_logits.to(self.dtype)
+
+ if continuous_mode:
+ img_output = img_output.to(to_dtype)
+ target = target.to(to_dtype)
+
+ # if prefix != 'train':
+ # breakpoint()
+
+ if self.config.trainer.ar_shift:
+ x0 = x0[:, 1:]
+ xt = xt[:, 1:]
+ attention_mask = attention_mask[:, 1:]
+ if modality_mask is not None: modality_mask = modality_mask[:, 1:]
+ if modality is not None: modality = modality[:, 1:]
+
+ if not self.is_compiled:
+ utils.print_nans(model_output, "model_output")
+
+ if self.parameterization == "sedd":
+ return dsigma[:, None] * self._score_entropy(model_output, sigma[:, None], xt, x0)
+ elif self.parameterization == "planner":
+ return F.binary_cross_entropy_with_logits(model_output.squeeze(-1), move_indices.float()).mean()
+
+ diffusion_loss = None
+ if self.T > 0:
+ diffusion_loss = self._d3pm_loss(model_output=model_output, xt=xt, x0=x0, t=t)
+ if self.parameterization == "d3pm":
+ reconstruction_loss = self._reconstruction_loss(x0)
+ elif self.parameterization == "subs" or self.parameterization == "ar":
+ reconstruction_loss = 0
+ # return reconstruction_loss + diffusion_loss
+
+ if self.parameterization == "ar":
+ if getattr(self.config.trainer, "use_orig_unidisc_dit", False):
+ return self.shortcut_return(model_output, x0, attention_mask, prefix)
+ else:
+ log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
+ else:
+ # SUBS parameterization, continuous time
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
+
+ if self.change_of_variables or self.importance_sampling:
+ return log_p_theta * torch.log1p(-torch.exp(-self.noise.sigma_min))
+
+ if self.parameterization == "ar" or getattr(self.config.trainer, "no_ce_weighting", False):
+ std_weighting = 1
+ else:
+ std_weighting = (dsigma / torch.expm1(sigma))[:, None]
+
+ # ddprint(f"self.current_run_fwd_bwd_pass: {self.current_run_fwd_bwd_pass}, log_p_theta: {torch.isnan(log_p_theta).any()}")
+ # if torch.isnan(log_p_theta).any() or self.current_run_fwd_bwd_pass > 15473:
+ # import pickle
+ # import time
+ # rank = get_rank()
+ # timestamp = int(time.time() * 1e9) # nanosecond timestep
+ # filename = f'batch_datastep_{self.current_run_fwd_bwd_pass}_rank{rank}_{timestamp}.pkl'
+ # with open(filename, 'wb') as f:
+ # pickle.dump(log_p_theta, f)
+ # ddprint(f"Saved batch to {filename}")
+
+ loss = -log_p_theta * std_weighting
+ if not (self.parameterization == "ar" or (self.config.trainer.ar_llm_loss and joint_ar_nar_mask is None) or getattr(self.config.trainer, "no_ce_weighting", False)):
+ gamma = getattr(self.config.trainer, "softmin_snr", None)
+ if gamma is not None:
+ softmin_weighting = (dsigma / (torch.expm1(sigma) + (1 / gamma)))[:, None]
+ loss = -log_p_theta * softmin_weighting
+
+ if diffusion_loss is not None:
+ assert self.T > 0
+ loss = diffusion_loss
+
+ std_loss = -log_p_theta * std_weighting
+ loss_dict = dict(std_loss=std_loss.detach(), extra_losses=dict())
+
+ if self.config.trainer.log_seperate_modal_losses:
+ assert not continuous_mode
+ loss_dict.update(
+ dict(
+ std_txt_loss=(std_loss.detach() * modality_mask[..., 0] * attention_mask),
+ std_img_loss=(std_loss.detach() * modality_mask[..., 1] * attention_mask)
+ )
+ )
+
+ if getattr(self.config.trainer, "mask_entire_modality", None) is not None and self.backbone.training and not self.config.parameterization == "ar":
+ loss_dict['batch_ignore_loss'] = ignore_batch_mask_for_metrics.squeeze(-1)
+
+ if joint_ar_nar_mask is not None:
+ if "batch_ignore_loss" in loss_dict:
+ loss_dict["batch_ignore_loss"] = loss_dict["batch_ignore_loss"] | joint_ar_nar_mask
+ else:
+ loss_dict["batch_ignore_loss"] = joint_ar_nar_mask
+
+ if (self.config.trainer.multimodal_batches or (self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None)) and not continuous_mode:
+ txt_mask = modality_mask[..., 0] & attention_mask
+ img_mask = modality_mask[..., 1] & attention_mask
+ txt_count = txt_mask.sum()
+ img_count = img_mask.sum()
+ total_count = txt_count + img_count
+ txt_frac = txt_count / total_count
+ img_frac = img_count / total_count
+ loss_dict["extra_losses"]["trainer/img_frac"] = img_frac
+ loss_dict["extra_losses"]["trainer/txt_frac"] = txt_frac
+ loss_dict["extra_losses"]["trainer/attention_mask_valid_frac"] = attention_mask.sum() / attention_mask.numel()
+ if "batch_ignore_loss" in loss_dict:
+ loss_dict["extra_losses"]["trainer/ignore_batch_metrics_frac"] = loss_dict["batch_ignore_loss"].sum() / loss_dict["batch_ignore_loss"].numel()
+
+ if joint_ar_nar_mask is not None:
+ pass # Defer loss mean until after ar_loss is calculated
+ elif self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None:
+ assert not continuous_mode
+ loss = loss * attention_mask
+ txt_loss = (
+ loss[txt_mask].sum() / txt_count
+ ) * txt_frac * self.config.trainer.text_loss_weight
+ img_loss = (
+ loss[img_mask].sum() / img_count
+ ) * img_frac * self.config.trainer.img_loss_weight
+
+ if getattr(self.config.trainer, "set_max_txt_loss_ratio", None) is not None and not (torch.isnan(img_loss).any() or torch.isnan(txt_loss).any()):
+ max_txt_loss = getattr(self.config.trainer, "set_max_txt_loss_ratio", 1.5) * img_loss.detach()
+ scale = torch.minimum(torch.tensor(1.0, device=txt_loss.device), max_txt_loss / (txt_loss.detach() + 1e-8))
+ txt_loss = txt_loss * scale
+
+ txt_loss = torch.nan_to_num(txt_loss, nan=0.0)
+ img_loss = torch.nan_to_num(img_loss, nan=0.0)
+
+ if getattr(self.config.trainer, "force_remove_img_tokens", False):
+ img_loss = torch.tensor(0, device=loss.device, dtype=loss.dtype)
+
+ loss = txt_loss + img_loss
+ loss_dict.update(dict(txt_loss=txt_loss.clone().detach(), img_loss=img_loss.clone().detach()))
+
+ elif continuous_mode:
+ img_loss = F.mse_loss(img_output, target)
+
+ if attention_mask[:, self.static_txt_sl].numel() == 0:
+ # Let grads pass even though this is zeros...
+ txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum()
+ else:
+ txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum() / attention_mask[:, self.static_txt_sl].sum()
+ loss = txt_loss + img_loss * self.config.trainer.image_loss_weight
+ loss_dict.update(dict(img_loss=img_loss.clone().detach(), txt_loss=txt_loss.clone().detach()))
+ else:
+ _attention_mask = torch.ones_like(attention_mask) if getattr(self.config.trainer, "force_full_attention_mask_loss_only", False) else attention_mask
+ loss = (loss * _attention_mask).sum() / _attention_mask.sum()
+ loss = torch.nan_to_num(loss, nan=0.0)
+
+ ar_loss = None
+ if self.config.trainer.ar_llm_loss:
+ assert not continuous_mode
+ valid_loss = xt == self.mask_index
+ _labels = x0.clone()
+ _labels = torch.where(valid_loss, _labels, -1)
+ _labels = torch.where(~attention_mask.to(torch.bool), -1, _labels)
+
+ _logits = true_logits
+ _logits[:, :, self.mask_index] += self.neg_infinity
+
+ if getattr(self.config.model, "force_argmax_valid_indices", False):
+ assert not self.config.trainer.multimodal_batches
+ _logits[:, self.static_txt_sl, self.text_vocab_size:] = torch.finfo(_logits.dtype).min
+ _logits[:, self.static_img_sl, : self.text_vocab_size] = torch.finfo(_logits.dtype).min
+
+ _logits = _logits.contiguous().view(-1, _logits.shape[-1])
+ _labels = _labels.contiguous().view(-1)
+
+ if self.config.trainer.ar_print_loss:
+ _labels = _labels.to(_logits.device)
+ ce_loss = loss_fct(_logits, _labels)
+ loss_fct = nn.CrossEntropyLoss(reduction='none')
+ ce_loss = ce_loss.mean(dim=-1)
+ if hasattr(self, 'histogram') is False:
+ self.histogram = {}
+
+ update_histogram(self.histogram, t, ce_loss)
+ rprint(f"ELM loss: move: {move_chance}, t:{t}, {ce_loss}")
+
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none' if joint_ar_nar_mask is not None else 'mean')
+ ce_loss = loss_fct(_logits, _labels)
+ loss_dict["extra_losses"]["trainer/ce_loss"] = ce_loss
+ ar_loss = ce_loss
+
+ if joint_ar_nar_mask is not None:
+ __true_logits = true_logits.clone()
+ __true_logits = torch.where(torch.arange(true_logits.shape[-1], device=true_logits.device)[None, None, :] == self.mask_index, self.neg_infinity, __true_logits)
+ log_softmax = __true_logits.log_softmax(-1)
+ ar_loss = -log_softmax.gather(-1, x0[:, :, None])[:, :, 0]
+
+ assert ar_loss is not None
+ assert ar_loss.ndim == 2
+ assert loss.ndim == 2
+ ar_loss_weight = joint_ar_nar_mask.sum(dim=0) / joint_ar_nar_mask.shape[0]
+ nar_loss_weight = 1 - ar_loss_weight
+ loss_dict["extra_losses"]["trainer/ar_loss_weight"] = ar_loss_weight.detach().float()
+ loss_dict["extra_losses"]["trainer/nar_loss_weight"] = nar_loss_weight.detach().float()
+ loss_dict["extra_losses"]["trainer/ce_loss"] = ar_loss.mean().detach().float()
+ ar_loss = (ar_loss * ar_loss_weight) * attention_mask
+ nar_loss = (loss * nar_loss_weight) * attention_mask
+ valid_count = attention_mask.sum()
+ if not is_xla_available:
+ ar_valid_count = attention_mask[joint_ar_nar_mask].sum()
+ nar_valid_count = attention_mask[~joint_ar_nar_mask].sum()
+ loss_dict["extra_losses"]["trainer/ar_loss"] = (ar_loss[joint_ar_nar_mask].sum() / ar_valid_count).detach().float()
+ loss_dict["extra_losses"]["trainer/nar_loss"] = (loss[~joint_ar_nar_mask].sum() / nar_valid_count).detach().float()
+ loss_dict["extra_losses"]["trainer/ar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/ar_loss"]).detach().float()
+ loss_dict["extra_losses"]["trainer/nar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/nar_loss"]).detach().float()
+ loss = (torch.where(joint_ar_nar_mask[:, None], ar_loss, nar_loss).sum() / valid_count) + weighted_z_loss
+ elif ar_loss is not None:
+ loss = ar_loss
+
+ loss_dict = dict(loss=loss, **loss_dict)
+ std_loss = loss_dict.get("std_loss", 0)
+ std_nlls = std_loss * attention_mask
+
+ if "batch_ignore_loss" in loss_dict:
+ attention_mask = torch.where(loss_dict['batch_ignore_loss'][:, None].repeat(1, attention_mask.shape[-1]), torch.full_like(attention_mask, False), attention_mask)
+
+ losses = Loss(
+ loss=loss_dict["loss"],
+ img_loss=loss_dict.get("img_loss", 0),
+ txt_loss=loss_dict.get("txt_loss", 0),
+ nlls=std_nlls,
+ txt_nlls=loss_dict.get("std_txt_loss", 0),
+ img_nlls=loss_dict.get("std_img_loss", 0),
+ token_mask=attention_mask,
+ modality_mask=modality_mask,
+ extra_losses=loss_dict.get("extra_losses", None),
+ )
+
+ if getattr(self.config.trainer, "disable_torchmetrics", False):
+ raise NotImplementedError("Torchmetrics disabled")
+
+ elif prefix == "train":
+ return losses
+ elif prefix == "val":
+ self.valid_metrics.update(losses.nlls, losses.token_mask)
+ if hasattr(self, "valid_txt_metrics"):
+ self.valid_txt_metrics.update(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
+ self.valid_img_metrics.update(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
+
+ elif prefix == "test":
+ self.test_metrics.update(losses.nlls, losses.token_mask)
+ metrics = self.test_metrics
+ self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
+ else:
+ raise ValueError(f"Invalid prefix: {prefix}")
+
+ @torch.no_grad()
+ def zero_shot_eval(self):
+ dataloader = self.validation_dataloader
+ total_batches = len(dataloader)
+ rprint(f"Zero shot eval with {total_batches} batches with limit_val_batches: {self.config.trainer.limit_val_batches}")
+ for idx, batch in tqdm(enumerate(dataloader), total=total_batches, desc="Zero shot eval validation steps", disable=not is_main_process()):
+ if self.config.trainer.limit_val_batches is not None and idx >= self.config.trainer.limit_val_batches:
+ break
+ self.zero_shot_eval_step(batch, idx)
+
+ self.zero_shot_eval_epoch_end()
+
+ def validate(self, state: TrainingState):
+ self.on_validation_epoch_start()
+
+ if getattr(self.config.eval, "compute_val_metrics_standalone", False) and getattr(self.config.eval, "bypass_normal_validation", False):
+ batch = next(iter(self.validation_dataloader))
+ self.on_validation_epoch_end(example_batch=batch)
+ self.on_validation_epoch_cleanup()
+ return
+
+ total_len = 10 if self.config.data.iterable or self.config.data.webdataset_indexed else len(self.validation_dataloader)
+ dprint(f"Validation batches: {total_len}")
+
+ total_batches = (
+ self.config.trainer.limit_val_batches
+ if (self.config.trainer.limit_val_batches is not None and self.fid_eval is False)
+ else total_len
+ )
+ if getattr(self.config.eval, 'pplx_full_dataset', False):
+ rprint("[INFO] PPLX full dataset eval, setting total_batches to total_len")
+ total_batches = total_len
+ elif self.config.eval.max_num_fid_batches_per_device is not None and self.fid_eval:
+ total_batches = min(total_len, self.config.eval.max_num_fid_batches_per_device)
+
+ _dataloader = self.train_dataloader if self.config.eval.val_with_train_data else self.validation_dataloader
+ rprint(f"Validating with {total_batches} batches on {self.world_size} GPUs with batch size {self.config.loader.eval_batch_size}")
+ for idx, batch in tqdm(enumerate(_dataloader), total=total_batches, desc="Validation steps", disable=not is_main_process()):
+ if self.config.trainer.limit_val_batches is not None and idx >= total_batches:
+ break
+ self.validation_step(batch, idx)
+
+ if getattr(self.config.eval, "eval_large_batch", None) is not None:
+ assert isinstance(batch, TensorDict)
+ dataloader_iter = iter(_dataloader)
+ large_batch = [next(dataloader_iter, None) for _ in range(getattr(self.config.eval, "eval_large_batch", None))]
+ large_batch = [b for b in large_batch if b is not None]
+ large_batch = torch.stack(large_batch, dim=0)
+ batch = large_batch
+ gprint(f"Large batch shape: {batch.shape}")
+ else:
+ batch = next(iter(_dataloader))
+
+ if self.config.eval.visualize_data_only:
+ return
+
+ if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
+ self.mauve_store_references(_dataloader)
+
+ if self.config.mode == "eval":
+ gprint(f"Batch shape: {batch['input_ids'].shape}")
+
+ self.on_validation_epoch_end(example_batch=batch)
+ self.on_validation_epoch_cleanup()
+
+ @cached_property
+ def global_batch_size(self):
+ """Batch size for a single step over all GPUs"""
+ # SPMD treats all ranks [regardless of node] as a single device
+ return self.step_batch_size * (1 if (self.config.trainer.xla_spmd and is_xla_available) else self.world_size)
+
+ @cached_property
+ def step_batch_size(self):
+ """Batch size for a single step for a single GPU"""
+ return self.config.loader.batch_size * self.config.trainer.accumulate_grad_batches
+
+ @cached_property
+ def world_size(self):
+ """Number of GPUs over all nodes"""
+ return get_world_size()
+
+ @cached_property
+ def num_tokens_per_sample(self):
+ """Number of tokens per sample"""
+ return self.config.model.length
+
+ @cached_property
+ def gradient_accumulation_steps(self):
+ """Number of gradient accumulation steps"""
+ return self.config.trainer.accumulate_grad_batches
+
+ @cached_property
+ def static_txt_sl(self):
+ return slice(None, self.config.model.txt_length)
+
+ @cached_property
+ def static_img_sl(self):
+ return slice(-self.config.model.img_length, None)
+
+ def img_txt_pair_batch_mask(self, batch=None):
+ return batch["modality_mask"][..., 1].sum(dim=-1) > 0
+
+ def txt_sl(self, batch=None):
+ return batch["modality_mask"][..., 0]
+
+ def img_sl(self, batch=None):
+ return batch["modality_mask"][..., 1]
+
+ @cached_property
+ def is_compiled(self):
+ return is_xla_available or self.config.trainer.compile
+
+ @property
+ def allow_slicing(self):
+ return not is_xla_available and not self.backbone.training
+
+ @property
+ def training(self):
+ return self.backbone.training
+
+ def get_step_metrics(self):
+ return {
+ "trainer/global_step": self.global_step,
+ "global_samples": self.global_step * self.global_batch_size,
+ "train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
+ "effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
+ "effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
+ }
+
+ def train(self):
+ tr = self.config.trainer
+ total_batch_size = self.global_batch_size
+ initial_global_step = self.global_step
+ true_step = 0
+ first_epoch = 0
+ self.current_run_global_step = 0
+ self.current_run_fwd_bwd_pass = 0
+ rprint(f"Started at step {self.accelerator.step}")
+ if self.non_embedding_params < 1e9:
+ with try_except(write_error_to_file=True, clear_cuda_cache=True):
+ self.print_hashes()
+
+ # There is an unknown bug with accelerator where non-master ranks don't load the step count from a checkpoint.
+ # We workaround by broadcasting the step count if necessary
+ if is_torch_cuda_available():
+ dprint(f"Gathering step from {self.world_size} ranks")
+ starting_steps = gather_object([self.accelerator.step])
+ rprint(f"Starting steps: {starting_steps}")
+ if not all([x > 0 for x in starting_steps]):
+ rprint(f"Not all ranks have >0 step, setting to: {starting_steps[0]}")
+ self.accelerator.step = starting_steps[0]
+
+ if is_xla_available:
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.profiler as xp
+ assert (self.config.trainer.accumulate_grad_batches == 1) or getattr(self.config.trainer, "allow_accum_grad_batches_xla", False), "Accumulate grad batches must be 1 for XLA"
+
+ rprint(f"***** Starting training at global step: {self.global_step} *****")
+ rprint(f" Instantaneous batch size per device = {self.config.loader.batch_size}")
+ rprint(f" Gradient Accumulation steps = {tr.accumulate_grad_batches}")
+ rprint(f" Num GPUs = {tr.devices}")
+ rprint(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ rprint(f" Total optimization steps = {tr.max_steps}")
+ rprint(f" Reported Global Batch Size: {self.global_batch_size}, Reported Step Batch Size: {self.step_batch_size}, Reported World Size: {self.world_size}")
+
+ if not self.config.data.iterable and not self.config.data.webdataset_indexed and is_torch_cuda_available():
+ num_epoch_steps = len(self.train_dataloader)
+ rprint(f" Num examples = {len(self.train_dataloader.dataset)}")
+ rprint(f" Num batches each epoch = {len(self.train_dataloader)}")
+ rprint(f"Train Dataloader Size on single GPU: {num_epoch_steps}")
+ if len(self.train_dataloader.dataset) < total_batch_size:
+ rprint("The training dataloader is smaller than the total batch size. This may lead to unexpected behaviour.")
+ else:
+ num_epoch_steps = 10000
+
+ if self.config.trainer.pytorch_profile:
+ profiler = Profiler(
+ output_dir=self.config.output_dir, warmup_steps=tr.profiler_warmup_steps, active_steps=tr.profiler_active_steps, record_memory=True
+ )
+
+ if self.config.trainer.viz_images_only:
+ return self.viz_images_from_dataloader()
+
+ progress_bar = tqdm(range(0, tr.max_steps), initial=initial_global_step, desc="Steps", disable=not is_local_main_process(), leave=False, smoothing=0.15)
+
+ global_step_metrics = defaultdict(float)
+ global_extra_wandb_metrics = dict()
+ accumulate_steps = 0
+ first_start_time = time.time()
+ self.on_train_start()
+
+ rprint(f"Training for {tr.num_epochs} epochs...")
+ last_end_step_time = start_timing(f"Dataloading accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
+ for epoch in range(first_epoch, tr.num_epochs):
+ rprint(f"Starting epoch {epoch}...")
+ for step, batch in enumerate(self.train_dataloader):
+ ddprint(f"Data Step: {step}")
+ if self.config.trainer.iterate_dataloader_only:
+ rprint(f"Iterating dataloader only: {step}")
+ # rprint((batch["modality"] == 0).sum(), (batch["modality"] == 1).sum())
+ if (batch["attention_mask"] == 0).all(dim=-1).any():
+ breakpoint()
+ batch = self.update_batch(batch)
+ if (batch["sample_ids"] == -1).all(dim=-1).any():
+ breakpoint()
+ continue
+
+ elif getattr(self.config.trainer, "iterate_dataloader_n_dataloader_batches", None) is not None and step <= self.config.trainer.iterate_dataloader_n_dataloader_batches:
+ self.current_run_fwd_bwd_pass += 1
+ if self.current_run_fwd_bwd_pass % self.config.trainer.accumulate_grad_batches == 0:
+ self.global_step += 1
+ self.current_run_global_step += 1
+ ddprint(f"Iterating dataloader only for {self.config.trainer.iterate_dataloader_n_dataloader_batches} dataloader batches. At step {self.global_step=}, {self.current_run_global_step=}, {self.current_run_fwd_bwd_pass=}")
+ continue
+
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+ if self.config.trainer.sync_dataloader_timing: synchronize_device()
+ global_step_metrics[f"dataloading_time"] += end_timing(last_end_step_time)
+
+ if self.config.trainer.nvtx_profile and self.is_compiled and step == 4:
+ torch.cuda.cudart().cudaProfilerStart()
+
+ if self.current_run_global_step == 1 and is_xla_available:
+ gprint(f"First start time: {time.time() - first_start_time}")
+
+ if getattr(self.config.data, "force_dummy_tensordict", False):
+ gprint(self.global_step, self.current_run_global_step, true_step, batch["idx"].tolist(), batch["dataset_idx"].tolist())
+
+ if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.global_step == self.config.trainer.assert_at_n_steps:
+ gprint(batch["img_input_ids"].min(), batch["img_input_ids"].max(), batch["txt_input_ids"].min(), batch["txt_input_ids"].max())
+
+ if batch is None:
+ rprint(f"Batch is None at step {step}")
+ continue
+
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+ ddprint(f"After Data Step 2: {step}")
+ with nullcontext() if is_xla_available else self.accelerator.accumulate(self.backbone):
+ ddprint(f"Before forward pass for global_step: {self.global_step}")
+ start_forward_time = start_timing(f"Forward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
+ global_step_metrics["examples_seen_per_gpu"] += len(next(iter(batch.values())))
+ state: TrainingState = TrainingState(
+ epoch_step=step,
+ num_epoch_steps=num_epoch_steps,
+ global_step=self.global_step,
+ epoch=epoch,
+ true_step=true_step,
+ current_run_global_step=self.current_run_global_step,
+ )
+
+ if self.accelerator.sync_gradients and is_xla_available is False:
+ self.cb_handler.on_train_step_start(state=state, unit=None)
+
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ ddprint(f"Before Fwd: {step}")
+ with xp.StepTrace('Forward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
+ losses = self.training_step(batch, step)
+
+ ddprint(f"After Fwd: {step}")
+ global_step_metrics["forward_pass_time"] += end_timing(start_forward_time)
+ true_step += 1
+ evaluate_extra_log_data = lambda: dict()
+
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ if isinstance(losses, dict):
+ for k, v in losses.items():
+ if isinstance(v, torch.Tensor):
+ global_step_metrics[k.removeprefix("metric_")] += v.detach().cpu().item()
+ else:
+ global_extra_wandb_metrics[k.removeprefix("metric_")] = v
+ losses = dict(
+ filter(lambda item: not item[0].startswith("metric_"), losses.items())
+ ) # Allow for custom metrics that are not losses
+ loss = sum(losses.values())
+ elif isinstance(losses, Loss):
+ loss = losses.loss
+ metrics = self.train_metrics(losses.nlls, losses.token_mask)
+ if hasattr(self, "txt_metrics") and losses.modality_mask is not None:
+ txt_metrics = self.txt_metrics(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
+ if hasattr(self, "img_metrics") and losses.modality_mask is not None:
+ img_metrics = self.img_metrics(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
+
+ extra_losses_dict = losses.extra_losses
+ extra_losses_dict = extra_losses_dict if extra_losses_dict is not None else dict()
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ def evaluate_extra_log_data():
+ if hasattr(self, "txt_metrics"):
+ return {
+ **{f"train/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(txt_metrics).items()},
+ **{f"train/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(img_metrics).items()},
+ }
+ else:
+ return {}
+
+ ddprint(f"Before loss: {step}")
+ incremental_dict_update(global_extra_wandb_metrics, {
+ "trainer/loss": loss,
+ "trainer/img_loss": losses.img_loss,
+ "trainer/txt_loss": losses.txt_loss,
+ **{
+ "global_samples": self.global_step * self.global_batch_size,
+ "train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
+ "effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
+ "effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
+ },
+ **metrics,
+ **extra_losses_dict,
+ })
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+ else:
+ loss = losses
+
+ if is_torch_cuda_available():
+ global_step_metrics["loss"] = loss.detach().cpu().item() # Only on the main process to avoid syncing
+
+ ddprint(f"Before backward pass for global_step: {self.global_step}")
+
+ # Short-circuit to avoid XLA eval
+ if tr.backward_pass and (is_xla_available or torch.isfinite(loss).all()):
+ start_backward_time = start_timing(f"Backward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
+ if self.accelerator.sync_gradients:
+ start_sync_time = start_timing(f"Gradient Sync global_step:{self.global_step}")
+ if getattr(self.config.trainer, "sync_timing", False):
+ sync_times(self.device)
+
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ # After each fwd, we perform a bwd. However, if we are accumulating there is an internal no_sync so the gradients remain on the GPU until
+ # the final bwd before a step. This can be controlled by sync_each_batch. Note that for the last bwd, the sync happens inside the bwd call below, so any timing for stragglers needs to happen before this call.
+ with xp.StepTrace('Backward', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
+ ddprint(f"Before accelerator.backward for global_step: {self.global_step}")
+ self.accelerator.backward(loss)
+ ddprint(f"After accelerator.backward for global_step: {self.global_step}")
+
+ with xp.StepTrace('After Backward + Clip', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
+ if self.accelerator.sync_gradients:
+ ddprint(f"Before after.backward for global_step: {self.global_step}")
+ self.after_backward(state)
+ if tr.gradient_clip_val is not None:
+ ddprint(f"Before self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
+ total_grad_norm = self.accelerator.clip_grad_norm_(self.backbone.parameters(), tr.gradient_clip_val)
+ ddprint(f"After self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
+
+ with xp.StepTrace('Optimizer + Scheduler Step', step_num=step) if self.config.trainer.tpu_profile else nullcontext():
+ ddprint(f"Before optimizer step for global_step: {self.global_step}, {step}")
+ if is_xla_available and False:
+ # TODO: xm.optimizer_step(self.optimizer) does not appear to be needed for XLA
+ xm.optimizer_step(self.optimizer)
+ else:
+ self.optimizer.step()
+ ddprint(f"After optimizer step for global_step: {self.global_step}, {step}")
+ self.lr_scheduler.step()
+ ddprint(f"After lr_scheduler step for global_step: {self.global_step}, {step}")
+
+ zero_grad_kwargs = dict()
+ if "apex" not in self.config.trainer.optimizer_cls:
+ zero_grad_kwargs["set_to_none"] = tr.set_grads_to_none
+
+ ddprint(f"Before zero_grad for global_step: {self.global_step}, {step}")
+ self.optimizer.zero_grad(**zero_grad_kwargs)
+ ddprint(f"Zeroed gradients for global_step: {self.global_step}, {step}")
+
+ if self.accelerator.sync_gradients:
+ if self.ema is not None:
+ if self.config.trainer.use_custom_ema:
+ ema_update(self.unwrap_model(self.ema), self.unwrap_model(self.backbone), self.config.trainer.ema)
+ else:
+ self.ema.step(self.get_params())
+ global_step_metrics["gradient_sync_time"] += end_timing(start_sync_time)
+
+ global_step_metrics["backward_pass_time"] += end_timing(start_backward_time)
+ else:
+ if not torch.isfinite(loss).all(): gprint(f"Loss is not finite: {loss}")
+ gprint("Skipping backward pass!")
+
+ accumulate_steps += 1
+ self.current_run_fwd_bwd_pass += 1
+
+ # Important: A single "global_step" is a single optimizer step. The accumulate decorator silently skips backward + optimizer to allow for gradient accumulation.
+ # A "true_step" counts the number of forward passes (on a per-GPU basis). The condition below should only happen immediately after a backward + optimizer step.
+ ddprint(f"Syncing gradients for global_step: {self.global_step}. Should sync: {self.accelerator.sync_gradients}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
+ if self.accelerator.sync_gradients:
+ start_gradient_sync_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
+
+ ddprint(f"Before on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
+ state.batch = batch
+ del loss, losses, batch
+ gradient_sync_time_after_train_step_end_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
+ self.on_train_step_end(state)
+ ddprint(f"After on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
+ global_step_metrics["gradient_sync_time_after_train_step_end"] += end_timing(gradient_sync_time_after_train_step_end_time)
+
+ if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ if self.config.trainer.profile_memory and self.global_step + 1 >= tr.max_steps:
+ rprint("Finished profiling memory...")
+ break
+
+ if self.config.trainer.pytorch_profile and profiler.step(self.global_step):
+ rprint(f"Profiling finished at step: {self.global_step}")
+ break
+
+ if getattr(self.config.trainer, "throw_failure_for_testing", False) and self.current_run_global_step == 5:
+ raise RuntimeError("Test failure")
+
+ if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ progress_bar.update(1)
+ self.global_step += 1
+ self.current_run_global_step += 1
+ global_step_metrics["gradient_sync_time"] += end_timing(start_gradient_sync_time)
+
+ logs = {
+ "examples_seen": self.global_step * total_batch_size,
+ "trainer/global_step": self.global_step,
+ **{k:v for k, v in global_step_metrics.items()},
+ **{f"lr_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_last_lr())},
+ **global_extra_wandb_metrics,
+ }
+
+ if is_torch_cuda_available():
+ logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
+ logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
+ logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
+ logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
+
+ if is_xla_available:
+ if self.global_step % getattr(self.config.trainer, "log_every_n_steps", 1) == 0:
+ xm.add_step_closure(update_logs, args=(logs, evaluate_extra_log_data), run_async=False)
+ del logs
+ global_extra_wandb_metrics = dict()
+ if self.config.trainer.tpu_force_mark_step: xm.mark_step()
+ else:
+ logs.update(evaluate_extra_log_data())
+ progress_bar.set_postfix(**logs)
+ log(logs)
+ global_extra_wandb_metrics = dict()
+
+
+ if getattr(self.config.trainer, "sync_timing", False):
+ global_step_metrics = {f"rank_{get_rank()}/{k}": v for k, v in global_step_metrics.items()}
+ all_step_metrics = self.accelerator.gather_for_metrics([global_step_metrics], use_gather_object=True)
+ merged_metrics = {k: v for d in all_step_metrics for k, v in d.items()}
+ log(merged_metrics)
+
+ if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
+
+ global_step_metrics = defaultdict(float)
+ accumulate_steps = 0
+
+ if self.global_step >= tr.max_steps:
+ break
+
+ ddprint(f"After logging for step v3: {self.global_step}, {step}")
+
+ if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.current_run_global_step >= getattr(self.config.trainer, "assert_at_n_steps", None):
+ raise RuntimeError(f"Assertion failed at step {self.current_run_global_step}")
+
+ ddprint(f"After logging for step v4: {self.global_step}, {step}")
+
+ if is_xla_available and self.config.trainer.tpu_profile and (self.global_step == 0 or self.global_step % 50 == 0) and is_main_process():
+ import torch_xla.debug.metrics as met
+ rprint(met.metrics_report())
+ met.clear_all()
+
+ if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
+ ddprint(f"Finished sync_gradients: {self.global_step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
+
+ ddprint(f"Finished step: {self.global_step},{step},{self.accelerator.step},{self.accelerator.gradient_accumulation_steps},{self.accelerator.gradient_state.__repr__()}")
+ if self.config.trainer.sync_dataloader_timing: synchronize_device()
+ last_end_step_time = start_timing(f"Dataloading #{true_step + 1}")
+
+ if self.global_step >= tr.max_steps:
+ break
+
+ dprint(f"Finished epoch: {epoch}")
+
+ # Create the pipeline using using the trained modules and save it.
+ rprint("Training finished.")
+ barrier()
+
+ if tr.profile_memory:
+ print_memory(verbose=True)
+ save_memory_profile(self.config.output_dir / "profile")
+
+ if tr.pytorch_profile:
+ profiler.finish()
+ elif tr.nvtx_profile:
+ torch.cuda.cudart().cudaProfilerStop()
+ elif self.global_step > 100 or tr.skip_early_checkpointing is False:
+ self.checkpoint(state)
+
+ barrier()
diff --git a/model_eval.py b/model_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78b17948f8f1249710cbba8826ec68f549d66c3
--- /dev/null
+++ b/model_eval.py
@@ -0,0 +1,4042 @@
+import ast
+from copy import deepcopy
+import json
+import math
+import os
+import pickle
+import random
+import shutil
+import string
+import time
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+import pandas as pd
+from constants import UNIDISC_DIR
+from data_defs import InterleavedBatch
+import einops
+import numpy as np
+from unidisc.utils.simple_llm import get_llm
+from unidisc.utils.viz_utils import augment_image_with_random_object_coco, create_text_image
+import torch
+import torch.utils.checkpoint
+from accelerate.utils import gather, gather_object
+from image_utils import Im
+from jaxtyping import Bool, Float, Integer
+from PIL import Image
+from tensordict import TensorDict, tensorclass
+from torch import Tensor
+from tqdm import tqdm
+from collections import defaultdict
+import torch.nn.functional as F
+import utils
+import wandb
+from decoupled_utils import (barrier, dprint, get_num_gpus, get_rank, get_world_size,
+ gprint, is_main_process, print_memory_summary,
+ rprint, save_memory_profile, show_memory_usage, try_except, sanitize_filename)
+from unidisc.tokenizers.chameleon_tokenizers import (decode_ids_batched,
+ get_chameleon_images)
+from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch
+from unidisc.utils.throughput_monitor import get_available_flops
+from model_utils import (_sample_categorical, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log,
+ remap_image_torch, replace_nan_dict,
+ wrapped_batch_decode)
+from torch import nn
+from model_utils import get_block_mask, MauveScore, Entropy
+
+def get_anole_data(self, model, processor, prompt, image, dtype, device):
+ inputs = processor(text=prompt, images=[image], padding=True, return_tensors="pt").to(device=device, dtype=dtype)
+ image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
+ special_image_mask = inputs["input_ids"] == model.model.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(inputs["input_ids"].device, inputs["input_ids"].dtype)
+ inputs["input_ids"] = inputs["input_ids"].masked_scatter(special_image_mask, image_tokens)
+ inputs.pop("pixel_values")
+ return inputs
+
+def calculate_chameleon_perplexity(self, model, processor, prompts, images, dtype=torch.bfloat16, return_all=False, standalone=False):
+ """
+ Calculate perplexities for multiple prompts and images using the Chameleon model.
+
+ Args:
+ model (ChameleonForConditionalGeneration): The Chameleon model.
+ processor (ChameleonProcessor): The Chameleon processor.
+ prompts (List[str]): List of prompt strings.
+ images (List[Image.Image]): List of PIL Image objects.
+ device (str): The device to use for computation (default: "cuda:0").
+ dtype (torch.dtype): The data type to use (default: torch.bfloat16).
+
+ Returns:
+ List[float]: List of perplexities for each prompt-image pair.
+ """
+ device = self.device
+ if model is None or processor is None:
+ model = getattr(self, "chameleon_model", None)
+ processor = getattr(self, "chameleon_processor", None)
+ if model is None:
+ from image_utils import Im
+ from transformers import (ChameleonForConditionalGeneration, ChameleonProcessor)
+ self.chameleon_model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
+ self.chameleon_processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
+
+ model = self.chameleon_model
+ processor = self.chameleon_processor
+ assert len(prompts) == len(images), "Number of prompts and images must match"
+
+ perplexities = []
+
+ for prompt, image in zip(prompts, images):
+ if not standalone:
+ txt_first_prompt = f"{prompt} "
+ img_first_prompt = f" {prompt}"
+ else:
+ txt_first_prompt = prompt
+ img_first_prompt = ""
+ tot_ppl = 0.0
+ tot_loss = 0.0
+ img_loss = 0.0
+ txt_loss = 0.0
+ for i, _prompt in enumerate([txt_first_prompt, img_first_prompt]):
+ inputs = self.get_anole_data(model, processor, _prompt, image, dtype, device)
+ img_start_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_start_token)['input_ids'][1]
+ img_end_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_end_token)['input_ids'][1]
+ if i == 0:
+ # text first
+ mod_mask = torch.cumsum(inputs['input_ids'] == img_start_tok_id, dim=1).bool()
+ else:
+ # img first
+ mod_mask = torch.cumsum(inputs['input_ids'] == img_end_tok_id, dim=1).bool()
+ mod_mask = mod_mask.cumsum(dim=1) > 1
+ output = model(
+ input_ids=inputs['input_ids'].to(device),
+ attention_mask=inputs['attention_mask'].to(device),
+ labels=inputs['input_ids'].to(device)
+ )
+ loss = output.loss
+ perplexity = torch.exp(loss).item()
+ tot_ppl += perplexity
+ logits = output.logits
+ logits = logits.transpose(-1, -2)
+ sample_chunk = inputs["input_ids"]
+ nlls = F.cross_entropy(logits[..., :-1].to(self.device), sample_chunk[..., 1:].to(self.device), reduction="none")
+ mod_mask = mod_mask[:, 1:]
+ # img nll is where mod_mask == 1
+ zeros = torch.zeros_like(nlls)
+ img_nll = torch.where(mod_mask, nlls, zeros).mean().item()
+ txt_nll = torch.where(~mod_mask, nlls, zeros).mean().item()
+ tot_loss += loss.item()
+ if not standalone:
+ txt_loss += txt_nll
+ img_loss += img_nll
+ else:
+ if i == 0:
+ txt_loss += loss.item()
+ else:
+ img_loss += loss.item()
+
+ if not standalone:
+ tot_ppl /= 2
+ tot_loss /= 2
+ img_loss /= 2
+ txt_loss /= 2
+
+ if return_all:
+ perplexities.append((tot_ppl, tot_loss, img_loss, txt_loss))
+ else:
+ perplexities.append(tot_ppl)
+
+ print(f"Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
+ return perplexities
+
+def get_every_n_evals(self, n):
+ return (
+ self.config.mode == "eval"
+ or ((self.num_evals > 0 or getattr(self.config.eval, "log_on_start", False)) and n > 0 and self.num_evals % n == 0)
+ ) and n != -1
+
+@try_except(write_error_to_file=True)
+def on_validation_epoch_start(self):
+ rprint("on_validation_epoch_start")
+ # EMA (Exponential Moving Average) is a technique used to maintain a moving average of model parameters
+ # It can help stabilize training and potentially improve model performance
+ if self.ema is not None and not self.config.trainer.use_custom_ema:
+ # Store the current model parameters in the EMA object
+ rprint(" [WARNING] USING EMA IN on_validation_epoch_start - THIS MIGHT RESET LOADED WEIGHTS ".center(100, "!"))
+ self.ema.store(self.get_params())
+ # Copy the EMA parameters to the current model
+ self.ema.copy_to(self.get_params())
+
+ self.backbone.eval()
+ self.reset_validation_metrics()
+
+ if getattr(self.config.trainer, "disable_torchmetrics", False) is False:
+ assert self.valid_metrics.nll.mean_value == 0
+ assert self.valid_metrics.nll.weight == 0
+ if self.non_embedding_params < 1e9:
+ self.print_hashes()
+ if (
+ self.image_model
+ and getattr(self.config.model, "image_model_fid_eval", False)
+ and self.get_every_n_evals(getattr(self.config.eval, "log_every_n_fid", 10))
+ ):
+
+ self.fid_eval = True
+ if self.config.eval.fid_mode == "inline":
+ from vqgan.inception_metrics import MultiInceptionMetrics
+ self.inception_metrics = MultiInceptionMetrics(
+ reset_real_features=False,
+ compute_unconditional_metrics=True,
+ compute_conditional_metrics=False,
+ compute_conditional_metrics_per_class=False,
+ num_classes=1000,
+ num_inception_chunks=10,
+ manifold_k=3,
+ )
+ if self.config.mode == "eval":
+ self.computed_tokens = []
+ else:
+ if getattr(self.config.eval, "force_fid_output_dir", None) is None:
+ shm_path = Path("/dev/shm") / os.getenv("USER")
+ fid_save_path = shm_path / Path(self.config.output_dir).parent.stem / Path(self.config.output_dir).stem / f"{self.num_evals}_{self.global_step}" / "fid_gen"
+ else:
+ fid_save_path = Path(getattr(self.config.eval, "force_fid_output_dir", None)) / "fid_gen"
+ fid_save_path.mkdir(parents=True, exist_ok=True)
+ fid_gt_path = fid_save_path.parent / (fid_save_path.name.replace("gen", "gt"))
+ fid_gt_path.mkdir(parents=True, exist_ok=True)
+ self.fid_gen_dir = fid_save_path
+ self.fid_gt_dir = fid_gt_path
+ rprint(f"FID eval output dir: {self.fid_gen_dir}, FID GT dir: {self.fid_gt_dir}")
+
+ rprint(f"Setting FID eval for epoch {self.num_evals}")
+ else:
+ self.fid_eval = False
+ if self.image_model and getattr(self.config.model, "image_model_fid_eval", False):
+ rprint(f"Not setting FID eval: num_evals: {self.num_evals} % {getattr(self.config.eval, 'log_every_n_fid', 10)}")
+
+ if self.config.eval.compute_img_to_txt_mauve_clip:
+ shm_path = Path("/dev/shm") / os.getenv("USER")
+ img_to_txt_mauve_save_path = shm_path / Path(self.config.output_dir).parent.stem / Path(self.config.output_dir).stem / f"{self.num_evals}_{self.global_step}" / "img_to_txt_mauve_gen"
+ img_to_txt_mauve_save_path.mkdir(parents=True, exist_ok=True)
+ img_to_txt_mauve_gt_path = img_to_txt_mauve_save_path.parent / (img_to_txt_mauve_save_path.name.replace("gen", "gt"))
+ img_to_txt_mauve_gt_path.mkdir(parents=True, exist_ok=True)
+ self.img_to_txt_mauve_gen_dir = img_to_txt_mauve_save_path
+ self.img_to_txt_mauve_gt_dir = img_to_txt_mauve_gt_path
+ rprint(f"Img to txt mauve eval gen dir: {self.img_to_txt_mauve_gen_dir}, gt dir: {self.img_to_txt_mauve_gt_dir}")
+
+ self.saved_tokens = defaultdict(list)
+ self.validation_start_time = time.time()
+
+ if getattr(self.config.trainer, "attach_oom_observer_eval", False):
+ from torchtnt.utils.oom import attach_oom_observer
+ attach_oom_observer(output_dir=str(self.config.output_dir), trace_max_entries=1000000)
+ rprint(f"Attached OOM observer to {self.config.output_dir}")
+ self.gpu_memory_reserved = torch.cuda.memory_reserved()
+
+
+def sample(self, return_input_ids=False, **kwargs):
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ text_only = kwargs.get("text_only", False)
+ kwargs.pop("text_only", None)
+ assert not continuous_mode
+ txt_tokens, img_tokens = self._sample(text_only=text_only, **kwargs)
+ if img_tokens is not None:
+ img_pred = decode_latents(self.config, self.get_vae(), img_tokens)
+ else:
+ img_pred = None
+ if txt_tokens is not None:
+ txt_pred = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+ else:
+ txt_pred = None
+ if return_input_ids:
+ return txt_pred, img_pred, txt_tokens, img_tokens
+ else:
+ return txt_pred, img_pred
+
+
+@torch.no_grad()
+def predict_step(self, batch, batch_idx, dataloader_idx=0):
+ batch = self.update_batch(batch)
+ assert (batch["input_ids"][~batch["x0_unmask"]] == self.mask_index).all()
+ txt_pred, img_pred, txt_tokens, img_tokens = self.sample(x0=batch["input_ids"], x0_unmask=batch["x0_unmask"], return_input_ids=True)
+ batch.update(dict(txt_pred=txt_pred, img_pred=img_pred, txt_tokens=txt_tokens, img_tokens=(img_tokens + self.text_vocab_size)))
+ return batch
+
+@torch.no_grad()
+def zero_shot_eval_step(self, batch, batch_idx):
+ batch = self.zero_shot_update_batch(batch)
+ dataset_name = self.config.data.train
+
+ def get_similarity(x0, batch, num_timesteps=None, txt_cond=True, return_unweighed=False, do_unconditional=False):
+ # NOTE - this function assume [txt, img] order with self.config.model.txt_length + self.config.model.img_length
+ # given a batch of img+text, get the similarity score
+ return_unweighed = return_unweighed or getattr(self.config.eval, "return_unweighed_sim", False)
+ class_log_probs = []
+ unweighed_class_log_probs = []
+ num_timesteps = num_timesteps or self.config.sampling.steps
+ effective_batch_size = batch['modality'].shape[0]
+ empty_device_cache()
+ times = torch.linspace(0, 1, steps=num_timesteps + 2)[1:-1].to(self.device).to(torch.float32)
+
+ if getattr(self.config.eval, "use_random_timesteps_same_batch", False):
+ times = torch.rand(num_timesteps, device=x0.device)
+ times = torch.sort(times)[0]
+
+ if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
+ # get a (B, num_timesteps) random timesteps
+ times = torch.rand(effective_batch_size, num_timesteps, device=x0.device)
+ times = torch.sort(times)[0]
+ print(f'Times: {times}')
+
+ do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
+ # unweighed/weighed, randomized but different over batch, randomized but same over batch,
+ cond_mask = torch.full_like(x0, False, device=x0.device).bool()
+ if txt_cond:
+ cond_mask[:, :self.config.model.txt_length] = True
+ else:
+ # img conditioned
+ cond_mask[:, self.config.model.txt_length:] = True
+ full_mask = torch.full_like(x0, self.mask_index, device=x0.device)
+ pad_mask = x0 == self.tokenizer.pad_token_id
+ rprint(f'Getting similarity with {times.shape[0]} timesteps, {effective_batch_size} samples, {do_unconditional} unconditional, {self.parameterization} parameterization, {self.config.eval.cfg} cfg, {num_timesteps} num_timesteps, {txt_cond} txt_cond')
+ # for t in times:
+ # # t = self._sample_t(1, x0.device).expand(effective_batch_size)
+ # breakpoint()
+ # if getattr(self.config.eval, "`use_random_timesteps_diff_batch`", False):
+ # t = t.expand(effective_batch_size)
+ # else:
+ # t = t.expand(1)
+ for i in range(num_timesteps):
+ empty_device_cache()
+ if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
+ t = times[:, i]
+ else:
+ t = times[i]
+ t = t.expand(effective_batch_size)
+ sigma, dsigma = self.noise(t)
+ # print(sigma, t)
+ unet_conditioning = None # sigma[:, None] -> This causes CUDA OOM
+ move_chance = 1 - torch.exp(-sigma[:, None])
+
+ xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, _, __ = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
+ if not do_unconditional:
+ cond = torch.where(cond_mask, x0, xt)
+ if self.config.eval.cfg is not None:
+ uncond = torch.where(cond_mask, full_mask, xt)
+ cond_output = self.forward(
+ cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
+ )
+ uncond_output = self.forward(
+ uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
+ )
+ cat_output = torch.stack([cond_output, uncond_output])
+ logits = cfg(self.config, t, cat_output).squeeze(0)
+ model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
+ else:
+ # return logits false so already done with subs parameterization
+ model_output = self.forward(
+ cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
+ )
+ else:
+ if self.config.eval.cfg is not None:
+ uncond = torch.where(cond_mask, full_mask, xt)
+ cond_output = self.forward(
+ xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
+ )
+ uncond_output = self.forward(
+ uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
+ )
+ cat_output = torch.stack([cond_output, uncond_output])
+ logits = cfg(self.config, t, cat_output).squeeze(0)
+ model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
+ else:
+ # return logits false so already done with subs parameterization
+ model_output = self.forward(
+ xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
+ )
+
+
+ # print(f'Time: {t[0]}')
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
+ # print(f'Log P Theta before pad remove: {-log_p_theta.mean()} | {(log_p_theta == 0).sum()}')
+ zeros = torch.zeros_like(log_p_theta)
+ log_p_theta = torch.where(pad_mask, zeros, log_p_theta)
+ # zero out the loss on conditioned part
+ if not do_unconditional:
+ log_p_theta = torch.where(cond_mask, zeros, log_p_theta)
+ # print(f'Log P Theta after pad remove: {-log_p_theta.mean()} | {(log_p_theta == 0).sum()}')
+ std_weighting = (dsigma / torch.expm1(sigma))[:, None]
+ unweighed_log_p_theta = -log_p_theta
+ loss = -log_p_theta * std_weighting
+ log_probs = loss.sum(dim=-1) / (~pad_mask).sum(dim=-1)
+ unweighed_log_probs = unweighed_log_p_theta.sum(dim=-1) / (~pad_mask).sum(dim=-1)
+ # print(f'Weighed loss: {log_probs.mean()} | Log P Theta: {-log_p_theta.mean()} | Std Weighting: {std_weighting.mean()}')
+ class_log_probs.append(log_probs)
+ unweighed_class_log_probs.append(unweighed_log_probs)
+ overall_time_log_probs = torch.stack(class_log_probs) # (num_time, B)
+ unweighed_overall_time_log_probs = torch.stack(unweighed_class_log_probs) # (num_time, B)
+ if return_unweighed:
+ return unweighed_overall_time_log_probs.mean(dim=0) # (B)
+ return overall_time_log_probs.mean(dim=0) # (B)
+
+ def get_similarity_ar(x0, batch, txt_cond=True, do_unconditional=False, **kwargs):
+ # get likelihood for each token and then average
+ img_first = kwargs.get("img_first", False)
+ if img_first:
+ x0 = torch.cat([x0[:, self.config.model.txt_length:], x0[:, :self.config.model.txt_length]], dim=1)
+ mod = batch['modality']
+ mod = torch.cat([mod[:, self.config.model.txt_length:], mod[:, :self.config.model.txt_length]], dim=1)
+ else:
+ mod = batch['modality']
+ empty_device_cache()
+ do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
+
+ if getattr(self.config.eval, "cfg", None):
+ rprint('NOT SETTING CFG for AR')
+ # if getattr(self.config.eval, "cfg", None):
+ # cat_mod_input_ids = torch.cat([x0, torch.where(batch['modality'] == 1, self.mask_index, x0)], dim=0)
+ # _modality = torch.cat([batch['modality'], batch['modality']], dim=0)
+ # cat_p_x0 = self.forward(
+ # cat_mod_input_ids,
+ # sigma=None,
+ # batch=dict(modality=_modality), modality=_modality
+ # )
+ # logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
+ # _w = getattr(self.config.eval, "cfg", None)
+ # model_output = (1 + _w) * logit_c - _w * logit_u
+ # else:
+ model_output = self.forward(x=x0, sigma=None, modality=mod)
+ x0 = x0[:, 1:]
+ # attention_mask = batch['attention_mask'][0][None, :].repeat(x0.shape[0], 1)[:, 1:]
+ attention_mask = x0 != self.tokenizer.pad_token_id
+ log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
+ if img_first:
+ txt_sl = slice(self.config.model.img_length-1, None)
+ img_sl = slice(None, self.config.model.img_length-1)
+ else:
+ txt_sl = slice(None, self.config.model.txt_length - 1)
+ img_sl = slice(self.config.model.txt_length - 1, None)
+ nll = (-log_p_theta * attention_mask).sum(dim=-1) / attention_mask.sum(dim=-1)
+ txt_nll = (-log_p_theta[:, txt_sl] * attention_mask[:, txt_sl]).sum(dim=-1) / attention_mask[:, txt_sl].sum(dim=-1)
+ img_nll = (-log_p_theta[:, img_sl] * attention_mask[:, img_sl]).sum(dim=-1) / attention_mask[:, img_sl].sum(dim=-1)
+ if do_unconditional:
+ return nll
+ return img_nll if txt_cond else txt_nll
+
+ def get_similarity_chameleon(zipp, batch, txt_cond=True, do_unconditional=False, prompts=None, images=None, **kwargs):
+ # get likelihood for each token and then average
+ empty_device_cache()
+ img_first = kwargs.get("img_first", False)
+ img_start_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_start_token)['input_ids'][1]
+ img_end_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_end_token)['input_ids'][1]
+ do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
+ if not prompts and not images:
+ prompt, image = zipp
+ if img_first:
+ _prompt = f" {prompt}"
+ else:
+ _prompt = f"{prompt} "
+ inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, _prompt, image, dtype=self.dtype, device=self.device)
+
+ else:
+ inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, prompts, images, dtype=self.dtype, device=self.device)
+ # mod mask which is one for image tokens from the indx we see img_start_tok_id to img_end_tok_id
+
+ if img_first:
+ mod_mask = torch.cumsum(inputs['input_ids'] == img_end_tok_id, dim=1).bool()
+ else:
+ mod_mask = torch.cumsum(inputs['input_ids'] == img_start_tok_id, dim=1).bool()
+
+ mod_mask = mod_mask.cumsum(dim=1) > 1
+ output = self.chameleon_model(
+ input_ids=inputs['input_ids'].to(self.device),
+ attention_mask=inputs['attention_mask'].to(self.device),
+ labels=inputs['input_ids'].to(self.device)
+ )
+ loss = output.loss
+ logits = output.logits
+ logits = logits.transpose(-1, -2)
+ sample_chunk = inputs["input_ids"]
+ nlls = F.cross_entropy(logits[..., :-1].to(self.device), sample_chunk[..., 1:].to(self.device), reduction="none")
+ mod_mask = mod_mask[:, 1:]
+ # img nll is where mod_mask == 1
+ zeros = torch.zeros_like(nlls)
+ img_nll = torch.where(mod_mask, nlls, zeros)
+ txt_nll = torch.where(~mod_mask, nlls, zeros)
+ if do_unconditional:
+ return nlls.mean(dim=-1)
+ return img_nll.mean(dim=-1) if txt_cond else txt_nll.mean(dim=-1)
+
+ if dataset_name == "nlphuji/flickr30k":
+ txt_tokens, img_tokens = self._sample(
+ text_only=False,
+ x0=batch["input_ids"],
+ x0_unmask=batch["attention_mask"],
+ modality=batch["modality"],
+ )
+ img_samples = decode_latents(self.config, self.get_vae(), img_tokens[:, :self.config.model.img_length])
+ txt_samples = wrapped_batch_decode(self.tokenizer, txt_tokens[:, self.config.model.img_length:], clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['gt_input_ids'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ self.compute_cider(txt_samples, gt_text_samples)
+ elif dataset_name == "facebook/winoground":
+ # breakpoint()
+ # if batch_idx <= 15:
+ # return
+ a0_0 = batch["input_ids_0_0"] # a
+ a0_1 = batch["input_ids_0_1"] # d
+ a1_0 = batch["input_ids_1_0"] # b
+ a1_1 = batch["input_ids_1_1"] # c
+
+ text_correct_count = 0
+ image_correct_count = 0
+ group_correct_count = 0
+
+ wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)
+
+ s0_0, s0_1, s1_0, s1_1 = None, None, None, None
+ modes = ['image', 'text', 'group']
+
+ if wino_chameleon:
+ txt0 = wrapped_batch_decode(tokens=batch['caption_0_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
+ txt1 = wrapped_batch_decode(tokens=batch['caption_1_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
+ img0 = Im(batch['img_0']).pil
+ img1 = Im(batch['img_1']).pil
+ prompts = [txt0, txt0, txt1, txt1]
+ images = [img0, img1, img0, img1]
+ zipp = list(zip(prompts, images))
+
+ # note - signs are reversed since we have loss, so want to minimize instead of maximize
+ def text_correct(result):
+ return torch.logical_and(result["s0_i0"] < result["s1_i0"], result["s1_i1"] < result["s0_i1"])
+
+ def image_correct(result):
+ return torch.logical_and(result["s0_i0"] < result["s0_i1"], result["s1_i1"] < result["s1_i0"])
+
+ def group_correct(result):
+ return torch.logical_and(image_correct(result), text_correct(result))
+ results_cond = {}
+ for mode in modes:
+ do_unconditional = (mode == 'group')
+ txt_cond = not (mode == 'text')
+ img_first = mode == 'text'
+ if wino_chameleon:
+ do_unconditional = True
+ s0_0 = get_similarity_chameleon(zipp[0], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s0_1 = get_similarity_chameleon(zipp[1], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s1_0 = get_similarity_chameleon(zipp[2], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s1_1 = get_similarity_chameleon(zipp[3], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ elif self.parameterization == "ar":
+ s0_0 = get_similarity_ar(a0_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s0_1 = get_similarity_ar(a0_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s1_0 = get_similarity_ar(a1_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ s1_1 = get_similarity_ar(a1_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
+ else:
+ s0_0 = get_similarity(a0_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
+ s0_1 = get_similarity(a0_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
+ s1_0 = get_similarity(a1_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
+ s1_1 = get_similarity(a1_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
+ result = {
+ "s0_i0": s0_0,
+ "s0_i1": s0_1,
+ "s1_i0": s1_0,
+ "s1_i1": s1_1,
+ }
+ if mode == 'text':
+ results_cond['text'] = text_correct(result)
+ text_correct_count += text_correct(result).sum().item()
+ elif mode == 'image':
+ results_cond['image'] = image_correct(result)
+ image_correct_count += image_correct(result).sum().item()
+ elif mode == 'group':
+ if getattr(self.config.eval, "wino_group_conditional", False):
+ rprint('[Winoground] Using conditional group accuracy')
+ group_correct_count = (torch.logical_and(results_cond['text'], results_cond['image'])).sum().item()
+ else:
+ rprint('[Winoground] Using unconditional group accuracy')
+ group_correct_count += group_correct(result).sum().item()
+ bsz = a0_0.shape[0]
+ txt_acc = text_correct_count / bsz
+ img_acc = image_correct_count / bsz
+ group_acc = group_correct_count / bsz
+
+ self.win_text_accuracy.update(txt_acc)
+ self.win_image_accuracy.update(img_acc)
+ self.win_group_accuracy.update(group_acc)
+ running_avg_txt = self.win_text_accuracy.compute()
+ running_avg_img = self.win_image_accuracy.compute()
+ running_avg_group = self.win_group_accuracy.compute()
+ rprint(f"[{batch_idx}] Winoground Text Accuracy: {txt_acc} ({running_avg_txt}), Image Accuracy: {img_acc} ({running_avg_img}), Group Accuracy: {group_acc} ({running_avg_group})")
+ else:
+ # def randomize_batch - input is a batch. for the batch['input_ids'] which contains self.config.model.txt_length txt tokens + self.config.model.img_length img tokens which are PAIRED
+ # we want to randomly swap the img/txt tokens between each other
+ x0 = batch['input_ids']
+ img_first = getattr(self.config.model, "img_first", False)
+ only_one_correct = getattr(self.config.eval, "only_one_correct", False)
+ wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)
+ # todo check attn mask for text retrieval
+ x0_txt = x0.clone()
+ x0_img = x0.clone()
+ if only_one_correct:
+ # for each sample from 1st batch onwards, shuffle the img/txt tokens, as in map randomly
+ x0c = x0.clone()
+ if img_first:
+ second_half = x0c[1:, self.config.model.img_length:]
+ else:
+ second_half = x0c[1:, self.config.model.txt_length:]
+ # shuffle second half
+ # second_half = second_half[torch.randperm(second_half.size(0))]
+ second_half = torch.cat([second_half[1:], second_half[0].unsqueeze(0)], dim=0)
+ # replace img tokens with txt tokens
+ if img_first:
+ x0c[1:, self.config.model.img_length:] = second_half
+ else:
+ x0c[1:, self.config.model.txt_length:] = second_half
+ if wino_chameleon:
+ if img_first:
+ img_tokens = x0c[:, :self.config.model.img_length]
+ txt_tokens = x0c[:, self.config.model.img_length:]
+ else:
+ txt_tokens = x0c[:, :self.config.model.txt_length]
+ img_tokens = x0c[:, self.config.model.txt_length:]
+ dec_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ dec_imgs = decode_latents(self.config, self.get_vae(), img_tokens - self.text_vocab_size)
+ dec_imgs = [Im(img).pil for img in dec_imgs]
+ if img_first:
+ # append '' to beginning of each txt sample
+ dec_txt = [' ' + txt for txt in dec_txt]
+ else:
+ dec_txt = [txt + ' ' for txt in dec_txt]
+ class_sim = get_similarity_chameleon(None, batch, do_unconditional=True, img_first=img_first, prompts=dec_txt, images=dec_imgs)
+ if torch.isinf(class_sim).any():
+ rprint(f'[Chameleon] Inf found in class_sim, check transformers version')
+ breakpoint()
+ elif self.parameterization == "ar":
+ class_sim = get_similarity_ar(x0c, batch, do_unconditional=True)
+ else:
+ class_sim = get_similarity(x0c, batch, do_unconditional=True)
+
+ topk = class_sim.topk(k=1, dim=0, largest=False)
+ topk_indices = topk.indices
+ topk_acc = (topk_indices == 0).float().mean().item()
+ rprint(f"[{batch_idx}] Datacomp Correct Pair Retrieval Acc: {topk_acc} ({self.datacomp_img_acc.compute()})")
+ self.datacomp_img_acc.update(topk_acc)
+ else:
+ if img_first:
+ # image retrieval given text, so fix text
+ x0_txt[:, self.config.model.img_length:] = x0[0, self.config.model.img_length:] # make all texts the first text
+
+ # text retrieval given image
+ x0_img[:, :self.config.model.img_length] = x0[0, :self.config.model.img_length] # make all images the first image
+ else:
+ # image retrieval given text, so fix text
+ x0_txt[:, :self.config.model.txt_length] = x0[0, :self.config.model.txt_length] # make all texts the first text
+
+ # text retrieval given image
+ x0_img[:, self.config.model.txt_length:] = x0[0, self.config.model.txt_length:] # make all images the first image
+
+ if self.parameterization == "ar":
+ txt_class_sim = get_similarity_ar(x0_txt, batch, txt_cond=True)
+ img_class_sim = get_similarity_ar(x0_img, batch, txt_cond=True) # TODO MAYBE REVERT?
+ else:
+ txt_class_sim = get_similarity(x0_txt, batch, txt_cond=True)
+ img_class_sim = get_similarity(x0_img, batch, txt_cond=False)
+
+ img_topk = img_class_sim.topk(k=1, dim=0, largest=False)
+ txt_topk = txt_class_sim.topk(k=1, dim=0, largest=False)
+
+ img_topk_indices = img_topk.indices
+ txt_topk_indices = txt_topk.indices
+
+ img_acc = (img_topk_indices == 0).float().mean().item()
+ txt_acc = (txt_topk_indices == 0).float().mean().item()
+ rprint(f"[{batch_idx}] Datacomp Text Retrieval Acc: {img_acc}, Datacomp Image Retrieval Accuracy: {txt_acc}")
+ self.datacomp_img_acc.update(img_acc)
+ self.datacomp_txt_acc.update(txt_acc)
+ # img_class_sim is (B) - argmin since loss txt_conds
+
+@torch.no_grad()
+def validation_step(self, batch, batch_idx):
+ batch = self.update_batch(batch)
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+
+ if self.config.mode == "eval":
+ logs = dict()
+ logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
+ logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
+ logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
+ logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
+ log({**logs, **self.get_step_metrics()})
+
+ if self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10)) \
+ and self.image_model \
+ and (batch_idx == 0 or self.config.eval.visualize_data_only) \
+ and not continuous_mode:
+ self.visualize_samples(batch, batch_idx)
+ if self.config.eval.visualize_data_only: return
+
+ if batch_idx < self.config.eval.num_sample_batches and self.config.eval.compute_generative_perplexity:
+ if continuous_mode:
+ # todo update to use modality once multimodal batches update is done by alex
+ gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['text_tokens'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos) # since input_ids is for images
+ else:
+ input_ids = batch["input_ids"]
+ pad_tokens = torch.full_like(input_ids, self.tokenizer.pad_token_id)
+ text_tokens = torch.where(batch["modality"] == 0, input_ids, pad_tokens)
+ gt_text_samples = wrapped_batch_decode(self.tokenizer, text_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ if getattr(self.config.trainer, "disable_text_modality", False):
+ gt_text_samples = [' ']
+ self.compute_generative_perplexity(gt_text_samples, gt=True)
+
+ if getattr(self.config.trainer, "log_flops", False) \
+ and batch_idx == 0 \
+ and self.current_run_global_step <= 1 \
+ and self.config.trainer.fsdp is False:
+ self.log_flops(batch=batch, batch_idx=batch_idx)
+ if self.fid_eval:
+ if self.config.eval.fid_mode == "inline":
+ self.update_inline_fid(batch, batch_idx)
+ elif self.config.eval.fid_mode == "clean":
+ self.update_clean_fid(batch, batch_idx)
+ else:
+ raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")
+
+ if getattr(self.config.eval, "get_top_k", False) and self.config.parameterization == "ar":
+ self.get_top_k(batch, batch_idx)
+
+ try:
+ if self.config.eval.compute_img_to_txt_mauve_clip and not self.config.eval.unconditional_fid:
+ self.update_img_to_txt_mauve_clip(batch, batch_idx)
+ except Exception as e:
+ empty_device_cache()
+ rprint(f"Error in update_img_to_txt_mauve_clip: {e}")
+
+ if (self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10)) \
+ and continuous_mode \
+ and self.config.eval.generate_samples \
+ and not self.config.eval.test_eval_speed):
+ # todo remove this from here and move to on_validation_epoch_end
+ data = self.sample_transfusion(batch_size_per_gpu=batch['input_ids'].shape[0])
+ # TODO @sid support batching. prob pass list of lists to be general.
+ rec_embs = [data.xt_img_embed[i, data.modality[i] == 1] for i in range(data.shape[0])]
+ # stack and transpose
+ rec_embs = torch.stack(rec_embs)
+ rec_txt = data.xt_ids[data.modality == 0][None]
+ recon_image = decode_latents(self.config, self.get_vae(), rec_embs, batched=True) # TODO @sid support batching e.g. not just first element. prob pass list of lists to be general.
+ txt = wrapped_batch_decode(self.tokenizer, rec_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ rprint(f"Sampled {len(txt)} text samples:\n {txt[:1][:50]}")
+ image_list = [wandb.Image(img) for img in recon_image]
+ val_loss = self.compute_loss(batch, prefix="val")
+ log({"val/gen_img": image_list, "val/loss": val_loss, **self.get_step_metrics()})
+
+ if (
+ self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
+ and (self.unified_model or self.cub_model or self.vggface_model)
+ and batch_idx < getattr(self.config.eval, "num_masking_viz_batches", 1)
+ and not continuous_mode # todo add masking val support s
+ ):
+ self.sample_masking(batch=batch, batch_idx=batch_idx)
+
+ return self.compute_loss(batch, prefix="val", batch_idx=batch_idx)
+
+@try_except(write_error_to_file=True)
+@torch.no_grad()
+def zero_shot_eval_epoch_end(self, example_batch=None):
+ dataset_name = self.config.data.train
+ dprint("zero_shot_eval_epoch_end")
+ if dataset_name == "nlphuji/flickr30k":
+ cider_score = self.cider_score.compute()
+ rprint('Flickr30k CIDEr score: ', cider_score)
+ # log it
+ log({
+ 'val/cider_score': cider_score
+ })
+ elif dataset_name == "facebook/winoground":
+ win_text_accuracy = self.win_text_accuracy.compute()
+ win_image_accuracy = self.win_image_accuracy.compute()
+ win_group_accuracy = self.win_group_accuracy.compute()
+ rprint(f'Winoground Text Accuracy: {win_text_accuracy}')
+ rprint(f'Winoground Image Accuracy: {win_image_accuracy}')
+ rprint(f'Winoground Group Accuracy: {win_group_accuracy}')
+ # log it
+ log({
+ 'val/win_text_accuracy': win_text_accuracy,
+ 'val/win_image_accuracy': win_image_accuracy,
+ 'val/win_group_accuracy': win_group_accuracy
+ })
+ else:
+ datacomp_img_acc = self.datacomp_img_acc.compute()
+ datacomp_txt_acc = self.datacomp_txt_acc.compute()
+ rprint(f'Datacomp Text Accuracy: {datacomp_img_acc}')
+ rprint(f'Datacomp Image Accuracy: {datacomp_txt_acc}')
+ # log it
+ log({
+ 'val/datacomp_text_retr_acc': datacomp_img_acc,
+ 'val/datacomp_img_retr_acc': datacomp_txt_acc
+ })
+
+@try_except(write_error_to_file=True)
+@torch.no_grad()
+def get_img_text_saturation_batch(self, example_batch):
+ max_sampling_steps = self.config.model.length
+ batch_size_per_gpu = example_batch["input_ids"].shape[0]
+ do_standalone = getattr(self.config.eval, "cham_standalone", False)
+ pplx_per_step = []
+ # make stpes linspace between 1 and max_sampling_steps with 100 steps
+ # steps = np.linspace(1, max_sampling_steps, 10).astype(int)
+ # steps = [1,2,4,8,16,32,64,128,256,512,1024]
+ steps = [1,2,4,8,16,32,64] # todo revert
+
+ rprint(f"do_standalone: {do_standalone} with steps: {steps}")
+ dec_txt_list = []
+ dec_img_list = []
+ for step in steps:
+ rprint(f"Step: {step}")
+ (txt_tokens, img_tokens), nfe_cnt = self._sample(text_only=False, batch_size_per_gpu=batch_size_per_gpu, sample_modality=example_batch["modality"], return_nfe=True, num_steps=step)
+ decoded_img = Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil
+ decoded_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ if not isinstance(decoded_img, list):
+ decoded_img = [decoded_img]
+ if not isinstance(decoded_txt, list):
+ decoded_txt = [decoded_txt]
+ dec_txt_list.append(decoded_txt)
+ dec_img_list.append(decoded_img)
+ tot_ppl, tot_loss, img_loss, txt_loss = self.calculate_chameleon_perplexity(self.chameleon_model, self.chameleon_processor, prompts=decoded_txt, images=decoded_img, return_all=True)[0]
+ rprint(f"Step {step} - Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
+ pplx_per_step.append((step, tot_ppl, tot_loss, img_loss, txt_loss))
+ empty_device_cache()
+ return dec_txt_list, dec_img_list, pplx_per_step
+
+@torch.no_grad()
+@try_except(write_error_to_file=True)
+@torch.no_grad()
+def on_validation_epoch_end(self, example_batch=None):
+ dprint("on_validation_epoch_end")
+
+ if self.config.eval.compute_val_metrics_standalone:
+ self.compute_val_metrics_standalone()
+
+ all_val_metrics = self.get_step_metrics()
+ all_val_metrics.update(self.valid_metrics.compute())
+ if hasattr(self, "valid_txt_metrics"):
+ valid_txt_metrics = self.valid_txt_metrics.compute()
+ valid_img_metrics = self.valid_img_metrics.compute()
+ all_val_metrics.update({
+ **{f"val/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_txt_metrics).items()},
+ **{f"val/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_img_metrics).items()},
+ })
+
+ log(all_val_metrics)
+
+ gprint("example_batch['input_ids'].ndim: ", example_batch['input_ids'].ndim)
+ if example_batch['input_ids'].ndim == 3:
+ combined_batches = example_batch
+ example_batch = self.update_batch(example_batch[0])
+ else:
+ example_batch = self.update_batch(example_batch)
+
+ if self.config.eval.auto_enhance:
+ self.auto_enhance(combined_batches)
+
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ compute_chameleon_perplexity = getattr(self.config.eval, "compute_chameleon_perplexity", False)
+ all_images = []
+ with try_except(write_error_to_file=True, clear_cuda_cache=True):
+ if self.fid_eval:
+ if self.config.eval.fid_mode == "inline":
+ self.compute_inline_fid_eval()
+ elif self.config.eval.fid_mode == "clean":
+ self.compute_clean_fid_eval()
+ else:
+ raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")
+
+ if self.config.eval.calculate_clip_score:
+ prefix = "unconditional" if self.config.eval.unconditional_fid else "fid"
+ self.compute_clip_score(self.fid_gen_dir, f"{prefix}_gen")
+ self.compute_clip_score(self.fid_gt_dir, f"{prefix}_gt")
+ if self.config.trainer.ar_inpainting:
+ import shutil
+ target_dir = Path(self.fid_gt_dir).parent / "fid_inpainting"
+ target_dir.mkdir(parents=True, exist_ok=True)
+
+ for img_file in Path(self.fid_gt_dir).rglob("*.png"):
+ shutil.copy2(img_file, target_dir / img_file.name)
+
+ for json_file in Path(self.fid_gen_dir).rglob("*.json"):
+ shutil.copy2(json_file, target_dir / json_file.name)
+
+ self.compute_clip_score(target_dir, f"{prefix}_inpainting")
+
+ if self.config.eval.unconditional_fid and \
+ self.config.eval.compute_img_to_txt_mauve_during_unconditional_fid and self.config.eval.compute_img_to_txt_mauve_clip:
+ rprint("Computing img to txt mauve during unconditional fid")
+ # CLIP score is the same as the fid clip score so we don't need to compute it again
+ gen_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gen_txt_tokens"])
+ gt_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gt_txt_tokens"])
+ if not getattr(self.config.eval, "global_disable_mauve", False):
+ self.compute_mauve_entropy(self.fid_gen_dir, self.fid_gt_dir, gen_txt_tokens, gt_txt_tokens, "unconditional")
+ elif self.config.eval.compute_img_to_txt_mauve_clip:
+ gen_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gen_txt_tokens"])
+ gt_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gt_txt_tokens"])
+ if not getattr(self.config.eval, "global_disable_mauve", False):
+ self.compute_mauve_entropy(self.img_to_txt_mauve_gen_dir, self.img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, "img_to_txt")
+ if self.config.eval.calculate_clip_score:
+ self.compute_clip_score(self.img_to_txt_mauve_gen_dir, "img_to_txt_mauve_gen")
+ self.compute_clip_score(self.img_to_txt_mauve_gt_dir, "img_to_txt_mauve_gt")
+ self.compute_mauve_entropy(self.img_to_txt_mauve_gen_dir, self.img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, "img_to_txt")
+
+ should_eval_speed = getattr(self.config.eval, "test_eval_speed", False)
+ if self.config.eval.generate_samples:
+ with try_except(write_error_to_file=True):
+ empty_device_cache()
+ if getattr(self.config.eval, 'set_random_gen_seed', False):
+ new_seed = get_rank() * 10 + 32
+ torch.manual_seed(new_seed)
+ torch.cuda.manual_seed(new_seed)
+ random.seed(new_seed)
+ np.random.seed(new_seed)
+
+ tot_time_per_sample = []
+ tot_token_time_per_token = []
+ tot_nfe_cnt = 0
+ batch_size_per_gpu = self.config.loader.eval_batch_size
+ sampling_steps = self.config.sampling.steps
+ num_batches = self.config.eval.num_sample_batches
+ gen_ppl_max_batches = 1e8
+ compute_entropy = getattr(self.config.eval, "compute_entropy", False)
+ compute_gen_ppl = self.config.eval.compute_generative_perplexity
+ entropies = []
+
+ if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
+ mauve_N = self.config.eval.mauve_num_samples
+ # we need to generate this many samples distributed over the batch size * num_gpus
+ # if not clean division, generate one extra batch we can discard later
+ num_batches = math.ceil(mauve_N / (batch_size_per_gpu * get_num_gpus()))
+ should_eval_speed = True # if we are generating this many samples might as well time it
+ gen_ppl_max_batches = getattr(self.config.eval, "gen_ppl_max_batches", 1e8) # since we are generating a lot of samples, we can compute gen ppl for a few batches but not all since that'll be slow with eval_mode = llama
+ compute_entropy = True
+ compute_gen_ppl = True
+ rprint(f"[MAUVE] Generating {mauve_N} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, num_batches: {num_batches}, max_gen_ppl_batches: {gen_ppl_max_batches}")
+
+ rprint(f"Generating {num_batches} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, compute_entropy: {compute_entropy}, compute_gen_ppl: {compute_gen_ppl}")
+ all_samples = []
+ get_img_text_saturation = getattr(self.config.eval, "get_img_text_saturation", False)
+ for i in tqdm(range(num_batches), desc="Generating samples"):
+ if get_img_text_saturation:
+ dec_txt_list, dec_img_list, all_vals = self.get_img_text_saturation_batch(example_batch)
+ # Prepare data for logging
+ df = pd.DataFrame(all_vals, columns=["step", "tot_ppl", "tot_loss", "img_loss", "txt_loss"])
+ df.to_csv(Path(self.config.output_dir) / f"img_text_saturation_batch_{i}.csv", index=False)
+ rprint(f"Saved img_text_saturation_batch_{i}.csv to {Path(self.config.output_dir) / f'img_text_saturation_batch_{i}.csv'}")
+
+ log_data = []
+ for (step, tot_ppl, tot_loss, img_loss, txt_loss), dec_txt, dec_img in zip(all_vals, dec_txt_list, dec_img_list):
+ concatenated_text = ' | '.join(dec_txt)
+ concatenated_image = dec_img[0]
+ log_data.append([step, tot_ppl, tot_loss, img_loss, txt_loss, concatenated_text, wandb.Image(concatenated_image)])
+
+ # Log to wandb
+ log_table = wandb.Table(columns=["Step", "Total PPL", "Total Loss", "Image Loss", "Text Loss", "Generated Text", "Generated Image"], data=log_data)
+ wandb.log({"img_text_saturation": log_table, "trainer/global_step": self.global_step})
+ rprint("Logged img_text_saturation table to wandb")
+ # log (step, Im)
+ # make it into pd df and store in output_dir
+ break
+ if should_eval_speed:
+ start_time = start_timing(sync=True, enable=True, message="Evaluating inference speed")
+
+ if self.parameterization == "ar" and continuous_mode:
+ data = self.sample_transfusion(text_only=True, batch_size_per_gpu=batch_size_per_gpu)
+ txt_tokens = data.xt_ids[:, self.static_txt_sl]
+ else:
+ (txt_tokens, img_tokens), nfe_cnt = self._sample(
+ text_only=False,
+ batch_size_per_gpu=batch_size_per_gpu,
+ sample_modality=example_batch["modality"],
+ return_nfe=True,
+ )
+ tot_nfe_cnt += nfe_cnt
+ if should_eval_speed:
+ tot_time = end_timing(start_time, enable=True, sync=True)
+ if continuous_mode: assert (data.modality == 0).all()
+ tot_time_per_sample.append(tot_time)
+ tot_token_time_per_token.append((tot_time) / self.config.model.length)
+
+ if compute_entropy:
+ entropies.append(self.compute_entropy(txt_tokens).item())
+
+ if compute_chameleon_perplexity:
+ all_images.extend(Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil)
+ text_samples = wrapped_batch_decode(self.tokenizer, txt_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+ if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
+ self.mauve_predictions.extend(text_samples)
+ if len(text_samples) > 0 and len(text_samples[0]) > 0 and self.config.eval.compute_generative_perplexity and i <= gen_ppl_max_batches:
+ self.compute_generative_perplexity(text_samples)
+
+ rprint(f"Generated {len(text_samples)} samples - {[text_samples[i][:200] for i in range(min(len(text_samples), 5))]}")
+ all_samples.extend(text_samples)
+
+ # TODO: @ssg2 is this needed?
+ # Log the last generated samples
+ # if not compute_chameleon_perplexity:
+ # text_samples = all_samples[:self.config.sampling.num_sample_log]
+ # all_images = all_images[:self.config.sampling.num_sample_log]
+
+ avg_nfe_cnt = tot_nfe_cnt / num_batches
+ if should_eval_speed:
+
+ # TODO: @ssg2 is this needed?
+ # data_dict = {
+ # f"samples": wandb.Table(columns=["Generated Samples", "Time per sample", "Time per token", "Generated Images"], data=[[s, t, tt, wandb.Image(img)] for s, t, tt, img in zip(text_samples, tot_time_per_sample, tot_token_time_per_token, all_images )]),
+ # "trainer/global_step": self.global_step,
+ # }
+
+ data_dict = {
+ f"samples": wandb.Table(columns=["Generated Samples", "Generated Images"], data=[[s, wandb.Image(img)] for s, img in zip(all_samples[:self.config.sampling.num_sample_log], all_images[:self.config.sampling.num_sample_log])]),
+ "trainer/global_step": self.global_step,
+ }
+ assert len(tot_time_per_sample) == len(tot_token_time_per_token)
+ if len(tot_time_per_sample) > 1:
+ tot_time_per_sample = tot_time_per_sample[1:] # exclude warmup
+ tot_token_time_per_token = tot_token_time_per_token[1:]
+ print(f'Have {len(tot_time_per_sample)} samples')
+ print(f'tot_time_per_sample: {tot_time_per_sample}')
+ print(f'tot_token_time_per_token: {tot_token_time_per_token}')
+ avg_time_per_sample = sum(tot_time_per_sample) / len(tot_time_per_sample)
+ avg_time_per_token = sum(tot_token_time_per_token) / len(tot_token_time_per_token)
+ data_dict["val/avg_time_per_sample"] = avg_time_per_sample
+ data_dict["val/avg_time_per_token"] = avg_time_per_token
+ data_dict["val/avg_nfe_cnt"] = avg_nfe_cnt
+ rprint(f"Time per sample: avg (excluding warmup): {avg_time_per_sample} - {tot_time_per_sample} ")
+ rprint(f"Time per token: avg (excluding warmup): {avg_time_per_token} - {tot_token_time_per_token} ")
+ with open(Path(self.config.output_dir) / "times.txt", "a") as f:
+ f.write(f"{avg_time_per_sample}, {avg_time_per_token}\n")
+ f.write(f"{tot_time_per_sample}\n")
+ f.write(f"{tot_token_time_per_token}\n")
+ rprint(f"Logged time per sample and time per token to {Path(self.config.output_dir) / 'times.txt'}")
+ else:
+ if len(text_samples) > 0 and isinstance(text_samples[0], list):
+ text_samples = [[item] for sublist in text_samples for item in sublist]
+ else:
+ text_samples = [[item] for item in text_samples]
+
+ data_dict = {
+ "samples": wandb.Table(columns=["Generated Samples"], data=text_samples),
+ **self.get_step_metrics()
+ }
+
+ if compute_gen_ppl:
+ data_dict["val/gen_ppl"] = self.gen_ppl_metric.compute()
+ data_dict["val/gt_gen_ppl"] = self.gt_gen_ppl_metric.compute()
+ self.gen_ppl_metric.reset()
+ self.gt_gen_ppl_metric.reset()
+
+ if compute_entropy:
+ data_dict["val/val_entropy"] = sum(entropies) / len(entropies) if len(entropies) > 0 else 0
+
+ if compute_chameleon_perplexity:
+ if getattr(self.config.eval, "max_chameleon_samples", False):
+ all_images = all_images[:self.config.eval.max_chameleon_samples]
+ all_samples = all_samples[:self.config.eval.max_chameleon_samples]
+ pplxs = self.calculate_chameleon_perplexity(self.chameleon_model, self.chameleon_processor, images=all_images, prompts=all_samples)
+
+ # take average of pplxs
+ avg_pplx = sum(pplxs) / len(pplxs)
+ data_dict["val/chameleon_ppl"] = avg_pplx
+
+ if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
+ all_mauve_preds = gather_object(self.mauve_predictions)
+ all_mauve_refs = gather_object(self.mauve_references)
+ data_dict["val/mauve_score"] = self.get_mauve_score(all_mauve_preds, all_mauve_refs, "standalone")
+
+ log(data_dict)
+
+ # Note: the above function got a little complicated due to the use in scoring/speed evals, etc. so we use the below function
+ # for both unconditional *and* conditional sampling.
+ if (
+ ((self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
+ and (self.image_model or self.config.trainer.multimodal_batches)
+ and not getattr(self.config.model, "img_cond", False)
+ and not should_eval_speed) or getattr(self.config.eval, "force_eval_uncond", False)) and not getattr(self.config.eval, "force_disable_eval_uncond", False)
+ ):
+ dprint("Generating samples")
+ with try_except(write_error_to_file=True):
+ has_label = getattr(self.config.model, "cond_label", False)
+ sample_kwargs = dict()
+
+ if has_label:
+ label = torch.randint(0, self.config.model.label_vocab_size, (self.config.loader.eval_batch_size,)).to(device=self.device, dtype=torch.int64)
+ sample_kwargs["label"] = label
+ else:
+ label = torch.randint(0, 1, (self.config.loader.eval_batch_size * 20,))
+
+ text_samples_list = []
+ img_samples_list = []
+ for j in range(getattr(self.config.eval, "num_uncond_sample_batches", 1)):
+ if continuous_mode:
+ data = self.sample_transfusion(batch_size_per_gpu=self.config.loader.eval_batch_size)
+ text_samples = data.xt_ids[:, self.static_txt_sl]
+ img_samples = data.xt_img_embed[:, self.static_img_sl]
+ img_samples = decode_latents(self.config, self.get_vae(), img_samples)
+ else:
+ if getattr(self.config.eval, "eval_large_batch", None) is not None:
+ data = combined_batches[j]
+ data = self.update_batch(data)
+ rprint(f"Taken slice {j} of {getattr(self.config.eval, 'eval_large_batch', None)}")
+ else:
+ data = example_batch
+
+ _modality = data.get("modality", None)
+ _bs = min(self.config.eval.perplexity_batch_size, self.config.loader.eval_batch_size)
+ if _bs < _modality.shape[0]:
+ _modality = _modality[:_bs]
+
+ text_samples, img_samples = self._sample(
+ text_only=False,
+ num_steps=self.config.sampling.max_sampling_steps,
+ batch_size_per_gpu=_bs,
+ example_batch=data,
+ sample_batch_idx=j,
+ modality=_modality,
+ sample_ids=data.get("sample_ids", None),
+ allow_interleaved_conditional=True,
+ **sample_kwargs
+ )
+ num_text_tokens = self.config.model.txt_length if self.config.model.txt_length > 0 else 128
+ if text_samples is None:
+ text_samples = [torch.zeros((self.config.loader.eval_batch_size, num_text_tokens), dtype=torch.int64, device=self.device)]
+ elif isinstance(text_samples, list):
+ new_text_samples = []
+ for text_sample in text_samples:
+ text_samples_padded = torch.nn.functional.pad(text_sample, (0, num_text_tokens - text_sample.shape[-1]), value=self.tokenizer.pad_token_id) if text_sample.shape[-1] < num_text_tokens else text_sample[..., :num_text_tokens]
+ new_text_samples.append(text_samples_padded)
+ text_samples = new_text_samples
+ else:
+ text_samples = [torch.nn.functional.pad(text_samples, (0, num_text_tokens - text_samples.shape[-1]), value=self.tokenizer.pad_token_id) if text_samples.shape[-1] < num_text_tokens else text_samples[..., :num_text_tokens]]
+
+ text_samples_list.extend(text_samples)
+ if img_samples is not None:
+ if isinstance(img_samples, list):
+ img_samples_list.extend(img_samples)
+ else:
+ img_samples_list.append(img_samples)
+
+ if len(text_samples_list) > 0 and any(text_samples is not None for text_samples in text_samples_list):
+ text_samples = torch.cat(text_samples_list, dim=0)
+ else:
+ text_samples = None
+ has_img = any(img_samples is not None for img_samples in img_samples_list)
+ log_dict = {}
+ try:
+ if has_img:
+ if isinstance(img_samples_list[0], Tensor):
+ img_samples = torch.cat(img_samples_list, dim=0)
+ if img_samples.ndim == 2:
+ pred_img = decode_latents(self.config, self.get_vae(), img_samples)
+ else:
+ pred_img = img_samples
+
+ log_dict.update({"val/gen_images": wandb.Image(pred_img)})
+ else:
+ pred_img = img_samples_list
+ for i, img in enumerate(img_samples_list):
+ log_dict[f"val/gen_images_{i}"] = wandb.Image(img)
+ else:
+ pred_img = img_samples_list
+ except Exception as e:
+ rprint(f"Error during gather: {e}")
+ pred_img = [None] * len(img_samples_list)
+ has_img = False
+ with try_except(write_error_to_file=True):
+ if text_samples is not None:
+ text_samples = gather(text_samples)
+ pred_txt = wrapped_batch_decode(self.tokenizer, text_samples, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ prefix = "class_cond" if has_label else ("cond" if self.config.trainer.interleaved else "uncond")
+
+ if isinstance(pred_img, Tensor):
+ pred_img = pred_img.float().cpu()
+
+ pred_img = gather_object(pred_img)
+ gen_table = wandb.Table(columns=[*([f"{prefix}_sampled_image"] if has_img else []), f"{prefix}_sampled_caption", *(["Label"] if has_label else [])])
+ for img, caption, label in zip(pred_img, pred_txt, label):
+ gen_table.add_data(*([wandb.Image(img)] if has_img else []), caption, *([label] if has_label else []))
+ log_dict[f"{prefix}_sample_table"] = gen_table
+ log({**log_dict, **self.get_step_metrics()})
+
+ if getattr(self.config.trainer, "print_llm_loss", False) and hasattr(self, 'histogram') and not should_eval_speed:
+ avg_losses = {t: sum(l) / len(l) for t, l in self.histogram.items()}
+ timesteps, avg_losses = zip(*sorted(avg_losses.items()))
+
+ from io import BytesIO
+
+ import matplotlib.pyplot as plt
+
+ plt.plot(timesteps, avg_losses)
+ plt.xlabel('Timesteps')
+ plt.ylabel('Average Loss')
+ plt.title('Loss over Time')
+ plt.show()
+
+ buf = BytesIO()
+ plt.savefig(buf, format='png')
+ plt.close()
+ buf.seek(0)
+ img = Image.open(buf)
+ log({"loss_over_time": wandb.Image(img)})
+ rprint("Logged loss over time")
+
+ if hasattr(self, "valid_txt_metrics"):
+ self.valid_metrics.reset()
+ self.valid_txt_metrics.reset()
+ self.valid_img_metrics.reset()
+
+ if (time.time() - getattr(self, "validation_start_time", time.time())) > 15:
+ rprint(f"Validation took: {time.time() - self.validation_start_time} seconds")
+
+ dprint("on_validation_epoch_end finished")
+
+def on_validation_epoch_cleanup(self):
+ self.reset_validation_metrics()
+ self.fid_eval = False
+ self.saved_tokens = defaultdict(list)
+ if hasattr(self, "inception_metrics"): del self.inception_metrics
+
+ if "tokens" in self.config.data.train and hasattr(self, "vae"):
+ del self.vae
+ self.vae = None
+
+ if is_main_process() and not getattr(self.config.eval, "disable_fid_cleanup", False): self.cleanup_fid_output()
+ empty_device_cache()
+
+ if getattr(self.config.trainer, "attach_oom_observer_eval", False):
+ if hasattr(self, "gpu_memory_reserved") and self.gpu_memory_reserved is not None:
+ cur_gpu_memory_reserved = torch.cuda.memory_reserved()
+ if getattr(self.config.trainer, "force_save_eval_memory_profile", False) or (cur_gpu_memory_reserved - self.gpu_memory_reserved > 4 * 1024**3): # 4GB in bytes
+ rprint(f"Warning: GPU memory usage increased by more than 4GB during validation. Initial: {self.gpu_memory_reserved / 1024**3:.2f}GB, Current: {cur_gpu_memory_reserved / 1024**3:.2f}GB")
+ oom_dir = Path(self.config.output_dir) / "oom_profile"
+ oom_dir.mkdir(parents=True, exist_ok=True)
+ save_memory_profile(oom_dir)
+ self.gpu_memory_reserved = None
+ dprint("Disabled memory history")
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ dprint("on_validation_epoch_cleanup finished")
+
+def gather_tokens(self, tokens):
+ tokens = torch.cat(tokens, dim=0).to(device=self.device, dtype=torch.int64)
+ tokens = gather(tokens)
+ return tokens
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def get_top_k(self, batch, batch_idx):
+ if batch_idx == 0:
+ all_top_k = {1: [], 2: [], 5: []}
+ for i in range(16):
+ mod_input_ids = batch['input_ids'].clone()
+ mod_input_ids[:, self.static_txt_sl] = mod_input_ids[i, self.static_txt_sl]
+ mod_attention_mask = batch['attention_mask'].clone()
+ mod_attention_mask[:, self.static_txt_sl] = mod_attention_mask[i, self.static_txt_sl]
+
+ if getattr(self.config.eval, "cfg", None):
+ cat_mod_input_ids = torch.cat([mod_input_ids, torch.where(batch['modality'] == 1, self.mask_index, mod_input_ids)], dim=0)
+ cat_p_x0 = self.forward(
+ cat_mod_input_ids,
+ sigma=None,
+ attention_mask=mod_attention_mask,
+ batch=dict(modality=batch['modality']), modality=batch['modality']
+ )
+ logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
+ _w = getattr(self.config.eval, "cfg", None)
+ model_output = (1 + _w) * logit_c - _w * logit_u
+ else:
+ model_output = self.forward(mod_input_ids, sigma=None, attention_mask=mod_attention_mask, batch=dict(modality=batch['modality']), modality=batch['modality'])
+
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=mod_input_ids[:, 1:, None]).squeeze(-1)
+ mean_nll = (-log_p_theta * mod_attention_mask[:, 1:]).sum(dim=-1) / mod_attention_mask[:, 1:].sum(dim=-1)
+
+ for k in [1, 2, 5]:
+ topk_values, topk_indices = torch.topk(mean_nll, k, dim=0)
+ all_top_k[k].append(0 in topk_indices.tolist())
+
+ for k in [1, 2, 5]:
+ retrieval_rate = sum(all_top_k[k]) / len(all_top_k[k])
+ rprint(f"{retrieval_rate:.2%} retrieved in top {k}")
+ log({f"val/top_{k}": retrieval_rate})
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def compute_clip_score(self, output_dir, prefix):
+ from model_utils import calculate_clip_score
+ caption_paths = [str(x.as_posix()) for x in Path(output_dir).glob('*.png') if x.is_file() and x.with_suffix('.json').exists()]
+ captions_mapping = {str(x): json.load(Path(x).with_suffix('.json').open())['caption'] for x in caption_paths}
+ clip_score = calculate_clip_score(caption_paths, captions_mapping=captions_mapping)
+ clip_score *= 100 # For some reason people scale cosine sim
+ rprint(f"{prefix} CLIP score: {clip_score}")
+ log({f"val/{prefix}_clip_score": clip_score, **self.get_step_metrics()})
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def compute_inline_fid(self):
+ rprint(f"FID Eval. We have {len(self.inception_metrics.fake_uncond_features)} batches.")
+ try:
+ if self.config.mode == "eval" and not self.config.trainer.image_mode == "continuous":
+ output_dir = Path("eval_tokens").resolve()
+ output_dir.mkdir(parents=True, exist_ok=True)
+ dataset_size = sum(x[-1].shape[0] for x in self.computed_tokens)
+ data = TensorDict(
+ {
+ "txt_input_ids": torch.cat([x[1] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int32),
+ "img_input_ids": torch.cat([x[2] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int16),
+ "gt_img_input_ids": torch.cat([x[3] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int16),
+ },
+ batch_size=[dataset_size],
+ )
+ save_loc = str(output_dir / f"{get_rank()}")
+ data.memmap(save_loc)
+ gprint(f"Saved tokens to {save_loc}")
+
+ rank = get_rank()
+ output_folder = Path("fid_metrics")
+ output_folder.mkdir(parents=True, exist_ok=True)
+ torch.save(self.inception_metrics.fake_uncond_features, output_folder / f"rank_{rank}_fake_uncond_features.pt")
+ torch.save(self.inception_metrics.fake_uncond_logits, output_folder / f"rank_{rank}_fake_uncond_logits.pt")
+ torch.save(self.inception_metrics.real_features, output_folder / f"rank_{rank}_real_features.pt")
+ rprint(f"Saved rank_{rank} tensors.")
+ except Exception as e:
+ gprint(f"Error during all_gather_object or saving tensors: {e}")
+
+ with torch.autocast(device_type=self.device.type, enabled=False):
+ metrics = self.inception_metrics.compute() # Gather is done internally
+
+ rprint(f"Computed metrics: {metrics}")
+ metrics = {f"val/{k}": v for k, v in metrics.items()}
+ log({**metrics, "trainer/global_step": self.global_step})
+ output_folder = Path("fid_metrics")
+ output_folder.mkdir(parents=True, exist_ok=True)
+ with open(output_folder / f'metrics_{get_rank()}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt', "w") as f:
+ for k, v in metrics.items():
+ f.write(f"val/{k}: {v}\n")
+
+ self.fid_eval = False
+ del self.inception_metrics
+ rprint("Finished FID eval")
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def compute_clean_fid_eval(self):
+ with try_except(write_error_to_file=True):
+ images = []
+ for i, filename in enumerate(sorted(Path(self.fid_gen_dir).iterdir(), key=lambda x: random.random())):
+ if i >= self.config.loader.eval_batch_size * get_world_size():
+ break
+ if filename.is_file() and filename.suffix == ".png":
+ for i in range(3):
+ try:
+ img = Image.open(filename)
+ except Exception as e:
+ time.sleep(0.1)
+ rprint(f"Error opening image {filename}: {e}")
+ images.append(np.array(img))
+ images = np.stack(images)
+ log({"val/fid_gen_img_at_compute": wandb.Image(Im(images).torch)})
+
+ from cleanfid import fid
+ kwargs = dict()
+ if self.config.eval.clean_fid_use_precomputed_stats:
+ kwargs.update(dict(
+ dataset_name=self.config.eval.clean_fid_precomputed_name,
+ dataset_res=self.config.eval.clean_fid_precomputed_res,
+ dataset_split=self.config.eval.clean_fid_precomputed_split,
+ ))
+ else:
+ kwargs.update(dict(fdir2=str(self.fid_gt_dir)))
+
+ score = fid.compute_fid(
+ fdir1=str(self.fid_gen_dir),
+ use_dataparallel=False,
+ **kwargs
+ )
+
+ rprint(f"FID score: {score}")
+ metrics = {"val/fid_unconditional": score, **self.get_step_metrics()}
+ log(metrics)
+
+ metrics = {f"val/{k}": v for k, v in metrics.items()}
+ output_folder = Path("fid_metrics")
+ output_folder.mkdir(parents=True, exist_ok=True)
+ with open(output_folder / f'metrics_{get_rank()}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt', "w") as f:
+ for k, v in metrics.items():
+ f.write(f"{k}: {v}\n")
+
+ self.fid_eval = False
+
+def sample_for_fid(self, batch, batch_idx, return_gt_img=False, return_gt_txt=False, img_to_txt_gen=False):
+ """This function is also used for img -> txt generation."""
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ sample_kwargs = self.get_cond_dict(batch)
+ orig_modality, orig_input_ids = None, None
+ if img_to_txt_gen:
+ if self.config.parameterization == "ar":
+ txt_first_sl = slice(None, self.config.model.txt_length)
+ img_first_sl = slice(None, self.config.model.img_length)
+ if (batch["modality"][:, txt_first_sl] == 0).all(): # Flip [txt, img] -> [img, txt]
+ assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all()
+ flipped_batch = dict()
+ img_slice = slice(-self.config.model.img_length, None)
+ txt_slice = slice(None, self.config.model.txt_length)
+ for key in ["modality", "attention_mask", "input_ids"]:
+ flipped_batch[key] = torch.cat([batch[key][:, img_slice], batch[key][:, txt_slice]], dim=1)
+
+ batch = flipped_batch
+ else:
+ assert (batch["modality"][:, img_first_sl] == 1).all() # We already have [img, txt]
+
+ assert (batch["modality"][:, :self.config.model.img_length] == 1).all(), "Img tokens should be 0"
+ else:
+ assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() # We already have [txt, img]
+
+ sample_kwargs["sample_modality"] = batch["modality"]
+ _x0_unmask = (batch["modality"] == 1)
+ elif getattr(self.config.eval, "unconditional_fid", False):
+ sample_kwargs["x0_unmask"] = None
+ sample_kwargs["x0"] = None
+ sample_kwargs["sample_modality"] = batch["modality"]
+ elif self.config.trainer.ar_inpainting:
+ assert getattr(self.config.eval, "txt_conditional_fid", False)
+ min_val, max_val = getattr(self.config.eval, "ar_inpainting_min_val", 0.9), getattr(self.config.eval, "ar_inpainting_max_val", 1.0)
+ n = batch["modality"].shape[0]
+ _eps_t = torch.rand(n, device=self.device)
+ t = (max_val - min_val) * _eps_t + min_val
+ if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
+ t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device)
+ if self.config.parameterization == "ar":
+ orig_modality, orig_input_ids = batch["modality"].clone(), batch["input_ids"].clone()
+ del batch["batch_contains_img"]
+ batch.auto_batch_size_()
+ batch = torch.cat([batch, batch], dim=1)
+ x0 = batch["input_ids"]
+ move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None] # Unmask so we switch sign compared to move_indices
+ move_indices[:, x0.shape[1] // 2:] = False
+ batch["input_ids"] = torch.where(move_indices, self.mask_index, x0)
+ _x0_unmask = torch.zeros_like(batch["input_ids"], dtype=torch.bool)
+ _x0_unmask[:, :batch["input_ids"].shape[1] // 2] = True
+ else:
+ _x0_unmask = torch.rand(*batch["modality"].shape, device=batch["modality"].device) > t[:, None] # Unmask so we switch sign compared to move_indices
+ sample_kwargs["sample_modality"] = batch["modality"]
+ sample_kwargs["x0_unmask"] = _x0_unmask
+ sample_kwargs["x0"] = batch["input_ids"]
+ elif getattr(self.config.eval, "class_conditional_fid", False) or getattr(self.config.eval, "txt_conditional_fid", False):
+ sample_kwargs["x0"] = batch["input_ids"]
+ if getattr(self.config.eval, "class_conditional_fid", False):
+ sample_kwargs["sample_modality"] = torch.full_like(batch["modality"], 1)
+ sample_kwargs["sample_modality"][:, 0] = 0
+ _x0_unmask = torch.zeros_like(batch["input_ids"], dtype=torch.bool)
+ _x0_unmask[..., 0] = True
+ elif getattr(self.config.eval, "txt_conditional_fid", False):
+ assert ((batch["modality"] == 1).sum(dim=-1) > 0).all(), "No img samples provided"
+ sample_kwargs["sample_modality"] = batch["modality"]
+ _x0_unmask = (batch["modality"] == 0)
+ sample_kwargs["x0_unmask"] = _x0_unmask
+
+ if continuous_mode:
+ data = self.sample_transfusion(batch_size_per_gpu=self.config.loader.eval_batch_size)
+ gen_txt_tokens = data.xt_ids[:, self.static_txt_sl]
+ gen_img_tokens = data.xt_img_embed[:, self.static_img_sl]
+ gen_img = decode_latents(self.config, self.get_vae(), gen_img_tokens)
+ else:
+ gen_txt_tokens, gen_img_tokens = self._sample(text_only=False, **sample_kwargs)
+ gen_img = decode_latents(self.config, self.get_vae(), gen_img_tokens)
+
+ fid_rec_img, gt_img_tokens, gt_txt_tokens = None, None, None
+ if return_gt_img:
+ if "img" in batch:
+ fid_rec_img = batch["img"]
+ else:
+ if orig_modality is None:
+ orig_modality = batch.get("modality", None)
+ if orig_input_ids is None:
+ orig_input_ids = batch["input_ids"]
+
+ _, gt_img_tokens = self.decode_batch(orig_input_ids, text_only=False, sample_modality=orig_modality)
+ if gt_img_tokens.shape[0] == 0:
+ rprint(f"{gt_img_tokens.shape} {batch['input_ids'].shape}")
+ fid_rec_img = decode_latents(self.config, self.get_vae(), gt_img_tokens)
+
+ if return_gt_txt:
+ if orig_input_ids is None:
+ orig_input_ids = batch["input_ids"]
+ if orig_modality is None:
+ orig_modality = batch.get("modality", None)
+ gt_txt_tokens, _ = self.decode_batch(orig_input_ids, text_only=False, sample_modality=orig_modality)
+
+ _prefix = "img_to_txt" if img_to_txt_gen else ("unconditional" if getattr(self.config.eval, "unconditional_fid", False) else "txt_to_img")
+ self.saved_tokens[_prefix + "_gen_img_tokens"].append(gen_img_tokens.detach().cpu().to(torch.int32))
+ self.saved_tokens[_prefix + "_gen_txt_tokens"].append(gen_txt_tokens.detach().cpu().to(torch.int32))
+ if gt_img_tokens is not None: self.saved_tokens[_prefix + "_gt_img_tokens"].append(gt_img_tokens.detach().cpu().to(torch.int32))
+ if gt_txt_tokens is not None: self.saved_tokens[_prefix + "_gt_txt_tokens"].append(gt_txt_tokens.detach().cpu().to(torch.int32))
+
+ return gen_img, gen_txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img
+
+
+def update_inline_fid(self, batch, batch_idx):
+ gen_img, txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=True, return_gt_txt=True)
+
+ if self.config.mode == "eval":
+ self.computed_tokens.append((txt_tokens, gen_img_tokens, gt_img_tokens))
+ with torch.autocast(device_type=self.device.type, enabled=False):
+ self.inception_metrics.update(remap_image_torch(fid_rec_img).to(self.device), None, image_type="real")
+ self.inception_metrics.update(remap_image_torch(gen_img).to(self.device), None, image_type="unconditional")
+
+ if batch_idx == 0:
+ log({"val/fid_gen": wandb.Image(gen_img), "val/fid_gt": wandb.Image(fid_rec_img), **self.get_step_metrics()})
+
+ if batch_idx > 0 and batch_idx % 5 == 0 and self.config.mode == "eval":
+ gprint(f"Saving rank_{get_rank()} tensors.")
+ try:
+ rank = get_rank()
+ torch.save(self.inception_metrics.fake_uncond_features, f"{batch_idx}_rank_{rank}_fake_uncond_features.pt")
+ torch.save(self.inception_metrics.fake_uncond_logits, f"{batch_idx}_rank_{rank}_fake_uncond_logits.pt")
+ torch.save(self.inception_metrics.real_features, f"{batch_idx}_rank_{rank}_real_features.pt")
+ gprint(f"Saved rank_{rank} tensors.")
+ except Exception as e:
+ gprint(f"Error during all_gather_object or saving tensors: {e}")
+
+def update_clean_fid(self, batch, batch_idx):
+ assert hasattr(self, "fid_gen_dir")
+ save_gt_img = not self.config.eval.clean_fid_use_precomputed_stats
+ gen_img, txt_tokens, gt_img_tokens, gt_txt_tokens, img_samples, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=save_gt_img, return_gt_txt=True)
+
+ if self.config.model.image_model_fid_eval:
+ txt_samples = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+ gt_txt_samples = wrapped_batch_decode(self.tokenizer, gt_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+ save_loc = Path(self.fid_gen_dir)
+ save_loc.mkdir(parents=True, exist_ok=True)
+ quantized_img = remap_image_torch(gen_img).permute(0, 2, 3, 1).cpu().numpy()
+
+ if save_gt_img:
+ gt_quantized_img = remap_image_torch(fid_rec_img).permute(0, 2, 3, 1).cpu().numpy()
+ save_loc_gt = Path(self.fid_gt_dir)
+ save_loc_gt.mkdir(parents=True, exist_ok=True)
+
+ for i in range(gen_img.shape[0]):
+ gen_img_pil = Image.fromarray(quantized_img[i])
+ suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
+ filename = f"{batch_idx}_{get_rank()}_{i}_{suffix}.png"
+ out_file_path = save_loc / filename
+ gen_img_pil.save(out_file_path)
+
+ if self.config.eval.txt_conditional_fid:
+ with open(out_file_path.with_suffix(".json"), 'w') as json_file:
+ json.dump({"caption": txt_samples[i]}, json_file)
+
+ if save_gt_img:
+ gt_img_pil = Image.fromarray(gt_quantized_img[i])
+ gt_out_file_path = save_loc_gt / filename
+ gt_img_pil.save(gt_out_file_path)
+
+ if self.config.eval.txt_conditional_fid:
+ with open(gt_out_file_path.with_suffix(".json"), 'w') as json_file:
+ json.dump({"caption": gt_txt_samples[i]}, json_file)
+
+ if batch_idx == 0:
+ rprint(f"Logging at batch idx {batch_idx}")
+ time.sleep(0.2)
+ with try_except(write_error_to_file=True):
+ images = []
+ for i, filename in enumerate(sorted(Path(self.fid_gen_dir).iterdir(), key=lambda x: random.random())):
+ if i >= self.config.loader.eval_batch_size * get_world_size():
+ break
+ if filename.is_file() and filename.suffix == ".png":
+ img = Image.open(filename)
+ images.append(np.array(img))
+ images = np.stack(images)
+ log({"val/fid_gen_img": wandb.Image(Im(images).torch)})
+ rprint(f"FID Txt: {txt_samples[0]}")
+
+def update_img_to_txt_mauve_clip(self, batch, batch_idx):
+ assert hasattr(self, "img_to_txt_mauve_gen_dir")
+ save_gt_img = True
+ empty_device_cache()
+ gen_img, gen_txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=save_gt_img, return_gt_txt=True, img_to_txt_gen=True)
+
+ gen_txt_samples = wrapped_batch_decode(self.tokenizer, gen_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+ gt_txt_samples = wrapped_batch_decode(self.tokenizer, gt_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+ save_loc = Path(self.img_to_txt_mauve_gen_dir)
+ save_loc.mkdir(parents=True, exist_ok=True)
+ quantized_img = remap_image_torch(gen_img).permute(0, 2, 3, 1).cpu().numpy()
+
+ if save_gt_img:
+ gt_quantized_img = remap_image_torch(fid_rec_img).permute(0, 2, 3, 1).cpu().numpy()
+ save_loc_gt = Path(self.img_to_txt_mauve_gt_dir)
+ save_loc_gt.mkdir(parents=True, exist_ok=True)
+
+ for i in range(gen_img.shape[0]):
+ gen_img_pil = Image.fromarray(quantized_img[i])
+ suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
+ filename = f"{batch_idx}_{get_rank()}_{i}_{suffix}.png"
+ out_file_path = save_loc / filename
+ gen_img_pil.save(out_file_path)
+ with open(out_file_path.with_suffix(".json"), 'w') as json_file:
+ json.dump({"caption": gen_txt_samples[i]}, json_file)
+
+ if save_gt_img:
+ gt_img_pil = Image.fromarray(gt_quantized_img[i])
+ gt_out_file_path = save_loc_gt / filename
+ gt_img_pil.save(gt_out_file_path)
+ with open(gt_out_file_path.with_suffix(".json"), 'w') as json_file:
+ json.dump({"caption": gt_txt_samples[i]}, json_file)
+
+ if batch_idx == 0:
+ rprint(f"GT img -> txt mauve: {gt_txt_samples[0]}")
+ rprint(f"Gen img -> txt mauve: {gen_txt_samples[0]}")
+
+def compute_mauve_entropy(self, img_to_txt_mauve_gen_dir, img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, prefix):
+ gt_txt = []
+ gt_img = []
+ gt_dir = Path(img_to_txt_mauve_gt_dir)
+ gen_dir = Path(img_to_txt_mauve_gen_dir)
+ stems = [f.stem for f in gt_dir.iterdir() if f.suffix == '.json' and (gen_dir / f.name.replace("gt", "gen")).exists()]
+ assert len(stems) > 0, f"No stems found in {gt_dir} and {gen_dir}"
+ rprint(f"Found {len(stems)} unique stems")
+
+ gt_img = []
+ gt_txt = []
+ gen_txt = []
+ gen_img = []
+ data_dict = {}
+ for stem in stems:
+ gt_img_path = gt_dir / f"{stem}.png"
+ gt_img.append(Image.open(gt_img_path))
+
+ gen_img_path = gen_dir / f"{stem}.png"
+ gen_img.append(Image.open(gen_img_path))
+
+ with open(gt_dir / f"{stem}.json", 'r') as f:
+ gt_txt.append(json.load(f)["caption"])
+
+ with open(gen_dir / f"{stem}.json", 'r') as f:
+ gen_txt.append(json.load(f)["caption"])
+
+ table = wandb.Table(columns=["GT Image", "GT Text", "Generated Image", "Generated Text"])
+ num_samples_to_display = min(20, len(stems))
+ for i in range(num_samples_to_display):
+ table.add_data(
+ wandb.Image(gt_img[i]),
+ gt_txt[i],
+ wandb.Image(gen_img[i]),
+ gen_txt[i]
+ )
+
+ data_dict[f"val/{prefix}_mauve_samples"] = table
+ if not getattr(self.config.eval, "global_disable_mauve", False):
+ data_dict[f"val/{prefix}_mauve_score"] = self.get_mauve_score(gen_txt, gt_txt, prefix)
+ data_dict[f"val/{prefix}_gt_entropy"] = self.compute_entropy(gt_txt_tokens)
+ data_dict[f"val/{prefix}_gen_entropy"] = self.compute_entropy(gen_txt_tokens)
+ data_dict[f"val/{prefix}_percent_valid_txt_tokens"] = self.count_valid_tokens(gen_txt_tokens).float().mean(dim=-1) / gen_txt_tokens.shape[-1]
+ log({**data_dict, **self.get_step_metrics()})
+
+def count_valid_tokens(self, text_tokens):
+ after_first_eos = torch.cumsum(text_tokens == self.tokenizer.eos_token_id, dim=1).bool()
+ after_first_eos_mask = after_first_eos.cumsum(dim=1) > 1
+ return ~after_first_eos_mask
+
+def get_valid_seq(self, text_tokens):
+ if self.tokenizer.bos_token_id == self.tokenizer.eos_token_id:
+ assert False, "BOS and EOS are the same."
+
+ eos_positions = (text_tokens == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
+ if len(eos_positions) > 0:
+ return text_tokens[..., :eos_positions[0] + 1]
+ else:
+ return text_tokens
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def compute_entropy(self, text_tokens):
+ """Compute the entropy of the generated text.
+ Definition Pg 33 of https://arxiv.org/pdf/2409.02908
+
+ Args:
+ text_tokens: Tensor of generated text tokens. (B, L)
+ Returns:
+ Entropy of the generated text.
+ """
+ val_entropy = Entropy(sync_on_compute=False).to(self.device)
+ B, L = text_tokens.shape
+ K = self.tokenizer.vocab_size # Use the actual vocabulary size
+
+ # Compute entropy for each sequence in the batch
+ entropies = []
+ for seq in text_tokens:
+ seq_length = seq.numel()
+ token_frequencies = torch.bincount(self.get_valid_seq(seq), minlength=K)
+ p_k = token_frequencies.float() / seq_length
+ p_k = p_k.to(self.device)
+ nll = -torch.sum(p_k * torch.log(p_k + 1e-10))
+ entropies.append(nll)
+
+ # Calculate the average entropy across the batch
+ avg_entropy = torch.mean(torch.tensor(entropies))
+
+ # Update the validation entropy metric
+ val_entropy.update(avg_entropy, weight=B)
+ return val_entropy.compute()
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def get_mauve_score(self, pred, gt, prefix):
+ from evaluate import load
+ mauve = load('mauve')
+
+ # We require a list of strings for pred, gt
+ mauve_metric = MauveScore(sync_on_compute=False).to(self.device)
+ rprint(f"Generated {len(pred)} MAUVE predictions")
+ assert len(pred) >= self.config.eval.mauve_num_samples
+ rprint(f'Before removing duplicates: {len(pred)}')
+ pred_text = list(set(pred))
+ rprint(f'After removing duplicates: {len(pred_text)}')
+ ref_text = list(set(gt))
+ store_path = os.path.join(self.config.output_dir, f"{prefix}_mauve_predictions.pkl")
+ with open(store_path, "wb") as f:
+ pickle.dump(pred_text, f)
+
+ rprint(f"Stored {len(pred_text)} unique MAUVE predictions to {store_path}")
+
+ min_len = min(len(pred_text), len(ref_text))
+ pred_text = pred_text[:min_len]
+ ref_text = ref_text[:min_len]
+
+ rprint(f"Computing img to txt MAUVE score for {len(pred_text)} unique predictions and {len(ref_text)} references")
+
+ # compute mauve score
+ device_id = 0 # this is main process
+ mauve_divergence_curve_discretization_size = self.config.eval.mauve_divergence_curve_discretization_size
+ mauve_scaling_factor = self.config.eval.mauve_scaling_factor
+ avg_over_seed = self.config.eval.mauve_average_over_seeds
+
+ # generate avg_over_seed number of seeds randomly
+ random_seeds = [random.randint(0, 100000) for _ in range(avg_over_seed)]
+ for seed in random_seeds:
+ mauve_score = mauve.compute(
+ references=ref_text,
+ predictions=pred_text,
+ device_id=device_id,
+ divergence_curve_discretization_size=mauve_divergence_curve_discretization_size,
+ mauve_scaling_factor=mauve_scaling_factor
+ )
+ mauve_metric.update(mauve_score.mauve)
+ rprint(f"MAUVE score for seed {seed}: {mauve_score.mauve}")
+ store_path = os.path.join(self.config.output_dir, f"{prefix}_mauve_score_seed_{seed}.txt")
+ with open(store_path, "w") as f:
+ f.write(str(mauve_score))
+
+ rprint(f"Stored MAUVE score for seed {seed} to {store_path}")
+
+ avg_mauve_score = mauve_metric.compute()
+ return avg_mauve_score
+
+
+def _sample_prior(self, *batch_dims):
+ return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
+
+def get_cfg_weight(self, t):
+ _cfg = self.config.eval.cfg
+ if not getattr(self.config.eval, "force_cfg_value", False):
+ if _cfg == -1:
+ _cfg = torch.linspace(0, 10, t.shape[0]).to(t.device)
+
+ if getattr(self.config.eval, "cfg_min_timestep", None) is not None and getattr(self.config.eval, "cfg_max_timestep", None) is not None:
+ _w = (_cfg * ((t - getattr(self.config.eval, "cfg_max_timestep")) / (getattr(self.config.eval, "cfg_min_timestep") - getattr(self.config.eval, "cfg_max_timestep"))))[:, None]
+ else:
+ _w = (_cfg * (1 - t))[:, None]
+ else:
+ _w = _cfg
+
+ if getattr(self.config.eval, "cfg_min_timestep", None) is not None:
+ _w = torch.where(t > getattr(self.config.eval, "cfg_min_timestep", None), _w, torch.tensor(0.0))
+
+ if getattr(self.config.eval, "cfg_max_timestep", None) is not None:
+ _w = torch.where(t < getattr(self.config.eval, "cfg_max_timestep", None), _w, torch.tensor(0.0))
+
+ if not isinstance(_w, torch.Tensor):
+ _w = torch.tensor(_w)
+
+ return _w
+
+def _ddpm_forward(self, x, t, sigma_t, x0=None, x0_unmask=None, force_cfg=None, **kwargs):
+ _w = None
+ if getattr(self.config.eval, "cfg", None) is not None and x0_unmask is not None and x0_unmask.sum() > 0:
+ _w = self.get_cfg_weight(t)
+
+ orig_modality, orig_sample_ids = None, None
+ if _w is not None and (_w > 0).any():
+ x_uncond = x.clone()
+ x_uncond[x0_unmask] = self.mask_index
+ if getattr(self.config.eval, "split_cfg_batches", False):
+ cat_p_x0 = torch.cat([
+ self.forward(
+ x=x,
+ sigma=sigma_t,
+ return_logits=True,
+ **kwargs
+ ),
+ self.forward(
+ x=x_uncond,
+ sigma=sigma_t,
+ return_logits=True,
+ **kwargs
+ )
+ ], dim=0)
+ else:
+ orig_modality = kwargs.get("modality", None)
+ if orig_modality is not None:
+ orig_modality = orig_modality.clone()
+ kwargs["modality"] = torch.cat([orig_modality, orig_modality], dim=0)
+
+ orig_sample_ids = kwargs.get("sample_ids", None)
+ if orig_sample_ids is not None:
+ orig_sample_ids = orig_sample_ids.clone()
+ kwargs["sample_ids"] = torch.cat([orig_sample_ids, orig_sample_ids], dim=0)
+
+ if self.config.trainer.interleaved_training_flex_attention:
+ assert 'sample_ids' in kwargs
+ kwargs['block_mask'] = get_interleaved_block_mask(kwargs['sample_ids'], x.shape[0], x.shape[-1], self.device)
+
+ cat_p_x0 = self.forward(
+ x=torch.cat([x, x_uncond], dim=0),
+ sigma=torch.cat([sigma_t, sigma_t], dim=0) if sigma_t is not None else None,
+ return_logits=True,
+ **kwargs
+ )
+ kwargs["modality"] = orig_modality
+ kwargs["sample_ids"] = orig_sample_ids
+
+ logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
+ if isinstance(_w, torch.Tensor) and _w.ndim == 2 and logit_c.ndim == 3:
+ _w = _w.unsqueeze(-1)
+ output_logits = (1 + _w) * logit_c - _w * logit_u
+ _modality = kwargs.get("modality", None)
+ if self.config.trainer.ar_shift:
+ _modality = _modality[:, 1:]
+
+ p_x0 = self._subs_parameterization(output_logits, xt=None, batch=None, modality=_modality)
+ p_x0 = p_x0.exp()
+ del logit_c, logit_u, cat_p_x0, output_logits, orig_modality, orig_sample_ids, x, x_uncond
+ else:
+ p_x0 = self.forward(x=x, sigma=sigma_t, **kwargs)
+ p_x0 = p_x0.exp()
+
+ if self.config.trainer.force_bf16_eval:
+ p_x0 = p_x0.to(torch.bfloat16)
+
+ kwargs.pop("attention_caching", None)
+ kwargs.pop("block_mask", None)
+
+ if getattr(self.config.eval, "force_empty_cache", False):
+ empty_device_cache()
+
+ return p_x0
+
+
+def sample_masking(self, batch, batch_idx):
+ assert (self.config.loader.batch_size == self.config.loader.eval_batch_size) or self.config.mode == 'eval' # need for modality otherwise x and modality have different batch sizes
+ if getattr(self.config.model, "img_cond", False):
+ text_samples, img_samples = self._sample(text_only=False, **self.get_cond_dict(batch))
+ pred_img = decode_latents(self.config, self.get_vae(), img_samples)
+ log({"val/gen_images_": wandb.Image(pred_img), "trainer/global_step": self.global_step})
+
+ orig_bs = batch["input_ids"].shape[0]
+ bs = min(10, max(1, int(orig_bs // 2)))
+ bs = getattr(self.config.eval, "masking_batch_size", bs)
+ bs = min(bs, orig_bs)
+
+ if getattr(self.config.eval, "num_random_masking", None) is not None:
+ num_random_masking = getattr(self.config.eval, "num_random_masking", 1)
+ bs = max(bs, num_random_masking)
+ else:
+ num_random_masking = max((x0.shape[0] + 1) // 4, 1)
+
+ _attention_mask = (batch["attention_mask"] if "attention_mask" in batch else None)[:bs]
+ _input_ids = (batch["input_ids"])[:bs]
+ _x_modality = (batch["modality"])[:bs] if "modality" in batch else None
+
+ if _x_modality.shape[0] != bs:
+ _x_modality = _x_modality[[0]].repeat(bs, 1)
+
+ (input_tokens, output_tokens, _attention_mask) = self._maybe_sub_sample(_input_ids, _attention_mask)
+ x0 = input_tokens
+ forward_kwargs = self.get_cond_dict(batch)
+ forward_kwargs['is_sample_masking'] = True
+
+ if "x_cond" in forward_kwargs:
+ forward_kwargs["x_cond"] = forward_kwargs["x_cond"][:bs]
+
+ assert output_tokens is None
+ assert self.T == 0 and self.change_of_variables is False
+
+ random_masking_ratio = getattr(self.config.eval, "random_masking_ratio", 0.95)
+ t = random_masking_ratio + (1 - random_masking_ratio) * torch.rand(num_random_masking, device=x0.device)
+ sigma, dsigma = self.noise(t)
+ unet_conditioning = sigma[:, None]
+ move_chance = 1 - torch.exp(-sigma[:, None])
+
+ unet_conditioning = torch.cat([unet_conditioning, unet_conditioning.new_full((bs - num_random_masking, 1), torch.nan)], dim=0)
+ move_chance = torch.cat([move_chance, move_chance.new_full((bs - num_random_masking, move_chance.shape[1]), 1)], dim=0)
+
+ uniform_mask = torch.full(x0.shape, True, device=x0.device, dtype=torch.bool)
+ text_only_mask = uniform_mask.clone()
+ text_only_mask = torch.where(_x_modality == 1, False, text_only_mask)
+
+ image_only_mask = uniform_mask.clone()
+ image_only_mask = torch.where(_x_modality == 0, False, image_only_mask)
+ image_only_mask = torch.where(batch["batch_contains_img"][:bs, None], image_only_mask, True)
+ mask_dict = dict(mask_all=uniform_mask, mask_text_only=text_only_mask, mask_image_only=image_only_mask)
+
+ if getattr(self.config.eval, "mask_img_only", False):
+ uniform_mask = torch.full(x0.shape, True, device=x0.device, dtype=torch.bool)
+ image_only_mask = torch.where(_x_modality == 0, False, uniform_mask)
+ move_chance = torch.ones_like(move_chance)
+ mask_dict = dict(mask_image_only=image_only_mask)
+ elif getattr(self.config.eval, "mask_img_only_keep_partial", False):
+ mask_dict = dict(mask_image_only=image_only_mask)
+ elif getattr(self.config.eval, "mask_all_only", False):
+ mask_dict = dict(mask_all=uniform_mask)
+
+ only_uniform_mask = getattr(self.config.eval, "only_uniform_mask", False)
+
+ table_dict = dict()
+ for mask_name, allow_move_mask in mask_dict.items():
+ if mask_name == "mask_all" and not only_uniform_mask:
+ _move_chance = 0.5 + (1 - 0.5) * torch.rand_like(move_chance)
+ elif mask_name == "mask_text_only":
+ _move_chance = torch.zeros_like(move_chance)
+ else:
+ _move_chance = move_chance
+
+ xt = self.q_xt(
+ x0,
+ _move_chance,
+ allow_move_mask,
+ mask_image_square=(mask_name != "mask_text_only") and not only_uniform_mask,
+ mask_text_region=(mask_name != 'mask_image_only') and not only_uniform_mask
+ )
+
+ if getattr(self.config.eval, "single_step_denoising", False):
+ forward_kwargs.pop("is_sample_masking", None)
+ model_output = self.forward(xt, unet_conditioning, **forward_kwargs)
+ if not self.is_compiled:
+ utils.print_nans(model_output, "model_output")
+ model_output = model_output.exp()
+ pred_tokens = model_output.argmax(dim=-1)
+ pred_tokens = torch.where(xt == self.mask_index, pred_tokens, xt)
+ pred_text, pred_img = self.decode_batch(pred_tokens, text_only=False, sample_modality=_x_modality)
+ pred_img = decode_latents(self.config, self.get_vae(), pred_img)
+ pred_txt = wrapped_batch_decode(self.tokenizer, pred_text, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ else:
+ xt_unmasked = xt != self.mask_index
+ pred_txt, pred_img = self.sample(x0=xt, x0_unmask=xt_unmasked, sample_modality=_x_modality, **forward_kwargs)
+
+ gen_table = wandb.Table(columns=["GT Img", "GT Caption", "Masked Img", "Masked Caption", "Pred Img", "Pred Caption", "Move chance"])
+ masked_txt, masked_img, mask_text_mask, mask_img_mask = self.decode_batch(
+ xt, text_only=False, return_masks=True, allow_mask_index=True, sample_modality=_x_modality
+ )
+
+ downscale_ratio = self.config.model.downscale_ratio
+ latent_dim = self.config.data.resolution // downscale_ratio
+
+ img_mask = einops.repeat(
+ einops.rearrange(mask_img_mask[:, self.static_img_sl], "b (h w) -> b h w", h=latent_dim, w=latent_dim),
+ "b h w -> b (h na) (w nb)",
+ na=downscale_ratio,
+ nb=downscale_ratio,
+ )
+
+ gt_txt, gt_img = self.decode_batch(_input_ids, text_only=False, sample_modality=_x_modality)
+ gt_txt = wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ gt_img = decode_latents(self.config, self.get_vae(), gt_img)
+
+ masked_txt = wrapped_batch_decode(self.tokenizer, masked_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
+ masked_img = gt_img.clone().permute(0, 2, 3, 1)
+ masked_img[img_mask] = torch.tensor([0.5, 0.5, 0.5], dtype=masked_img.dtype, device=masked_img.device)
+ masked_img = masked_img.permute(0, 3, 1, 2)
+ for _gt_img, _gt_txt, _masked_img, _masked_txt, _pred_img, _pred_txt, _move_chance in zip(
+ gt_img, gt_txt, masked_img, masked_txt, pred_img, pred_txt, move_chance
+ ):
+ gen_table.add_data(
+ wandb.Image(_gt_img), _gt_txt, wandb.Image(_masked_img), _masked_txt, wandb.Image(_pred_img), _pred_txt, _move_chance
+ )
+
+ table_suffix = f"_{batch_idx}"
+ table_dict[f"{mask_name}_sample_table{table_suffix}"] = gen_table
+
+ log({**table_dict, "trainer/global_step": self.global_step})
+
+def log_flops(self, batch, batch_idx):
+ use_torch_tnt = False
+ use_native_torch = True
+ use_fvcore = False
+ with torch.enable_grad():
+ with torch.autocast(self.device.type, dtype=self.dtype):
+ new_batch_idxs = batch["input_ids"].new_ones((self.config.loader.batch_size, self.config.model.length))
+ if use_fvcore:
+ # Broken due to some issue with triton
+ from fvcore.nn import (ActivationCountAnalysis,
+ FlopCountAnalysis, flop_count_str,
+ flop_count_table)
+ example_input = (new_batch_idxs, None)
+ fca = FlopCountAnalysis(self.accelerator.unwrap_model(self.backbone), example_input)
+ aca = ActivationCountAnalysis(self.accelerator.unwrap_model(self.backbone), example_input)
+ print(flop_count_table(fca, max_depth=1))
+ print(flop_count_str(fca))
+ print(fca.total())
+
+ if use_torch_tnt:
+ from torchtnt.utils.module_summary import get_module_summary
+ module_summary = get_module_summary(self.backbone, module_args=(new_batch_idxs, None), module_kwargs={})
+ rprint(module_summary)
+ rprint(f"TorchTNT Forward FLOPs: {module_summary.flops_forward / 1e12:.2f} FLOPs")
+ rprint(f"TorchTNT Backward FLOPs: {module_summary.flops_backward / 1e12:.2f} FLOPs")
+ rprint(f"TorchTNT Total FLOPs: {(module_summary.flops_forward + module_summary.flops_backward) / 1e12:.2f} FLOPs")
+
+ if use_native_torch:
+ from torch.utils.flop_counter import FlopCounterMode
+ flop_counter = FlopCounterMode(self.backbone, display=True, depth=3)
+ with flop_counter:
+ fake_batch = {}
+ fake_batch["input_ids"] = new_batch_idxs
+ fake_batch['attention_mask'] = batch['attention_mask'].new_ones(new_batch_idxs.shape)
+ if 'modality' in batch:
+ fake_batch['modality'] = batch['modality'].new_ones(new_batch_idxs.shape)
+ fake_batch['x0'] = fake_batch["input_ids"]
+ t = self._sample_t(fake_batch['x0'].shape[0], fake_batch['x0'].device)
+ sigma, dsigma = self.noise(t)
+ move_chance = 1 - torch.exp(-sigma[:, None])
+ xt = self.q_xt(fake_batch['x0'], move_chance)
+ fake_batch['xt'] = xt
+ if self.config.trainer.image_mode == "continuous":
+ B, T = fake_batch["input_ids"].shape
+ indices = fake_batch["input_ids"].to(batch['text_tokens'].dtype)
+ fake_sigma = torch.ones(B, T, device=self.device).long()
+ fake_x_img_emb = torch.randn(B, T, 4 * (self.config.model.patching_downscale ** 2), device=self.device)
+ fake_modality = torch.zeros(B, T, device=self.device, dtype=torch.long)
+ fake_modality[:, self.config.model.txt_length:] = True
+ logits = self.backbone(indices=indices, sigma=fake_sigma, continuous_mode=True, x_img_emb=fake_x_img_emb, modality=fake_modality) # todo remove hardcoding 4
+ else:
+ logits = self.backbone(fake_batch["input_ids"], sigma=None, modality=fake_batch.get("modality", None))
+ from transformers.modeling_outputs import \
+ CausalLMOutputWithPast
+ if isinstance(logits, torch.Tensor):
+ logits = logits
+ elif isinstance(logits, tuple):
+ logits = logits[0]
+ elif isinstance(logits, CausalLMOutputWithPast):
+ logits = logits.logits
+
+ loss = logits.mean().to(torch.float32)
+ loss.backward()
+
+ total_flops = flop_counter.get_total_flops()
+ rprint(f"Total FLOPs Per Sample Fwd+Bwd: {(total_flops / self.config.loader.batch_size) / 1e12:.2f} TFLOPs")
+ rprint(f"Total FLOPs Per Fwd+Bwd: {total_flops / 1e12:.2f} TFLOPs")
+ rprint(f"Total FLOPs Per Global Step: {(total_flops / 1e12) * self.world_size * self.gradient_accumulation_steps:.2f} TFLOPs")
+
+ rprint(f"GPU available FLOP/s: {get_available_flops(new_batch_idxs.device, self.dtype) / 1e12:.2f} TFLOP/s")
+ rprint(f"Total available FLOP/s: {(get_available_flops(new_batch_idxs.device, self.dtype) / 1e12) * self.world_size * self.gradient_accumulation_steps:.2f} TFLOP/s")
+ rprint(f"Used Batch Size: {self.config.loader.batch_size} for FLOP Calculations")
+
+@torch.inference_mode()
+def _ddpm_update(self, x, t, dt, **kwargs):
+ sigma_t, _ = self.noise(t)
+ sigma_s, _ = self.noise(t - dt)
+ if sigma_t.ndim > 1:
+ sigma_t = sigma_t.squeeze(-1)
+ if sigma_s.ndim > 1:
+ sigma_s = sigma_s.squeeze(-1)
+ assert sigma_t.ndim == 1, sigma_t.shape
+ assert sigma_s.ndim == 1, sigma_s.shape
+ move_chance_t = 1 - torch.exp(-sigma_t)
+ move_chance_s = 1 - torch.exp(-sigma_s)
+ move_chance_t = move_chance_t[:, None, None]
+ move_chance_s = move_chance_s[:, None, None]
+ nfe_cnt = 0
+ _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t
+ p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs)
+ nfe_cnt += 1
+ assert move_chance_t.ndim == p_x0.ndim
+ # Technically, this isn't q_xs since there's a division
+ # term that is missing. This division term doesn't affect
+ # the samples.
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+ _x = _sample_categorical(q_xs)
+
+ copy_flag = (x != self.mask_index).to(x.dtype)
+ del p_x0, q_xs, move_chance_t, move_chance_s
+ return copy_flag * x + (1 - copy_flag) * _x, nfe_cnt
+
+@torch.inference_mode()
+def _ddpm_caching_update(self, x, t, dt, p_x0=None, x0=None, x0_unmask=None, modality=None,**kwargs):
+ assert self.config.noise.type == "loglinear"
+ sigma_t, _ = self.noise(t)
+ if t.ndim > 1:
+ t = t.squeeze(-1)
+
+ nfe_cnt = 0
+ assert t.ndim == 1
+ move_chance_t = t[:, None, None]
+ move_chance_s = (t - dt)[:, None, None]
+ assert move_chance_t.ndim == 3, move_chance_t.shape
+
+ if p_x0 is None:
+ _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t
+ p_x0 = self._ddpm_forward(x, t, _sigma, x0=x0, x0_unmask=x0_unmask, modality=modality, **kwargs)
+ nfe_cnt += 1
+ assert move_chance_t.ndim == p_x0.ndim
+ if self.config.trainer.force_bf16_eval: empty_device_cache()
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+ _x = _sample_categorical(q_xs)
+ copy_flag = (x != self.mask_index).to(x.dtype)
+ if self.config.trainer.force_bf16_eval: empty_device_cache()
+
+ if self.config.trainer.ar_shift:
+ if x0 is not None:
+ _x = torch.cat([x0[:, [0]], _x], dim=1)
+ else:
+ _x = torch.cat([torch.full_like(_x[..., :1], fill_value=self.tokenizer.pad_token_id), _x], dim=1)
+
+ del q_xs, move_chance_t, move_chance_s
+ return p_x0, copy_flag * x + (1 - copy_flag) * _x, nfe_cnt
+
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+@torch.inference_mode()
+def _sample(
+ self,
+ num_steps=None,
+ eps=1e-5,
+ text_only=True,
+ x0=None,
+ x0_unmask=None,
+ batch_size_per_gpu=None,
+ example_batch=None,
+ sample_batch_idx=None,
+ sample_modality=None,
+ sample_ids=None,
+ return_raw_data=False,
+ **kwargs,
+):
+ """Generate samples from the model."""
+ if not (x0 is None) == (x0_unmask is None):
+ breakpoint()
+ assert (x0 is None) == (x0_unmask is None), f"x0: {x0} x0_unmask: {x0_unmask}"
+ batch_size_per_gpu = (x0.shape[0] if x0 is not None else self.config.loader.eval_batch_size) if batch_size_per_gpu is None else batch_size_per_gpu
+ sample_modality = kwargs.get("modality", None) if sample_modality is None else sample_modality
+ kwargs['modality'] = sample_modality
+ kwargs['sample_ids'] = sample_ids
+ return_nfe = kwargs.pop('return_nfe', False)
+ is_sample_masking = kwargs.pop('is_sample_masking', False)
+ allow_interleaved_conditional = kwargs.pop('allow_interleaved_conditional', False)
+ nfe_cnt = 0
+ assert batch_size_per_gpu > 0
+ if num_steps is None:
+ num_steps = self.config.sampling.steps
+ if getattr(self.config.eval, "test_eval_speed", False) and getattr(self.config.eval, 'eval_at_ratio_length', False):
+ num_steps = self.config.model.length
+ if getattr(self.config.eval, "num_steps_ratio", None) is not None:
+ num_steps = int(num_steps * self.config.eval.num_steps_ratio)
+
+ decode_kwargs = dict(sample_modality=sample_modality, return_raw_data=return_raw_data, is_sample_masking=is_sample_masking)
+
+ if x0 is not None and x0_unmask is not None:
+ x = self._sample_prior(batch_size_per_gpu, x0.shape[1]).to(self.device)
+ decode_kwargs['x0_unmask'] = x0_unmask
+ if getattr(self.config.eval, "visualize_sample", False):
+ x_viz = x.clone()
+ x_viz = torch.where(x0_unmask, x0, x)
+ _mask_id = self.tokenizer("mask")['input_ids']
+ assert len(_mask_id) == 3
+ x_viz[x_viz == self.mask_index] = _mask_id[1]
+ ret_txt, ret_img = self.decode_sampling(x_viz, text_only, **kwargs, **decode_kwargs, image_save_postfix="_masked_input")
+ print(ret_txt)
+
+ elif (self.config.trainer.interleaved and not self.config.backbone == "chameleon") and allow_interleaved_conditional:
+ assert self.config.trainer.interleaved_training_flex_attention
+ x0 = example_batch['input_ids'].to(self.device)
+ total_samples = getattr(self.config.eval, "num_uncond_sample_batches", 1) - 1
+ half_uncond = getattr(self.config.eval, "half_uncond", False)
+ if not half_uncond or sample_batch_idx >= total_samples // 2:
+ unmask_modality = getattr(self.config.eval, "unmask_modality", sample_batch_idx % 2)
+ x0_unmask = sample_modality == unmask_modality
+ if x0_unmask.sum() == x0.numel():
+ unmask_modality = 1 - unmask_modality
+ x0_unmask = sample_modality == unmask_modality
+
+ if x0.shape != sample_modality.shape:
+ breakpoint()
+
+ if unmask_modality == 1:
+ x0_unmask = torch.zeros_like(x0_unmask)
+ for i in range(x0.shape[0]):
+ eos_pos = (x0[i] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
+ if len(eos_pos) > 0:
+ idx = random.randint(0, len(eos_pos) - 2)
+ x0_unmask[i, :] = True
+ if len(eos_pos) >= idx + 1:
+ _sl = slice(eos_pos[idx], None)
+ else:
+ _sl = slice(eos_pos[idx] + 2, eos_pos[idx+1] - 1)
+
+ x0_unmask[i, _sl] = (sample_modality[i, _sl] == 1)
+
+ # Set first sentence to be unmasked
+ for i in range(x0.shape[0]):
+ eos_pos = (x0[i] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
+ if len(eos_pos) > 0:
+ assert (eos_pos[0] < 48) or (sample_modality[i].sum() == 0), f"eos_pos: {eos_pos}"
+ x0_unmask[i, :eos_pos[0]+1] = True
+
+ if unmask_modality == 1 and x0_unmask.sum() == 0:
+ x0_unmask = torch.ones_like(x0_unmask)
+ print(f"Found no umasked tokens, unmasking random sequences")
+ for i in range(x0.shape[0]):
+ seq_len = (x0[i] != self.tokenizer.pad_token_id).sum()
+ if seq_len == 0:
+ continue
+
+ start_pos = random.randint(0, seq_len-1)
+ max_len = min(seq_len - start_pos, 200)
+ unmask_len = random.randint(1, max_len)
+ x0_unmask[i, start_pos:start_pos+unmask_len] = False
+
+ gprint(f"Unmasking modality: {unmask_modality}, Unmasking {(x0_unmask.sum() / x0_unmask.numel()):.2%} of image tokens. Txt tokens: {(sample_modality == 0).sum()}, Img tokens: {(sample_modality == 1).sum()}")
+
+ x0_unmask[~example_batch['attention_mask']] = True
+ x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
+ decode_kwargs['x0_unmask'] = x0_unmask
+ x = torch.where(x0_unmask, x0, x)
+
+ if getattr(self.config.eval, "visualize_sample", False):
+ x_viz = x.clone()
+ _mask_id = self.tokenizer("mask")['input_ids']
+ assert len(_mask_id) == 3
+ _mask_id = _mask_id[1]
+ x_viz[x == self.mask_index] = _mask_id
+ self.decode_sampling(x_viz, text_only, **kwargs, **decode_kwargs, image_save_postfix="_x0_unmasked")
+
+ if self.parameterization == "ar" or getattr(self.config.eval, "eval_large_batch", None) is not None:
+ rprint(f"Masking all tokens by default.")
+ x0_unmask = torch.zeros(*x0.shape, device=x0.device).to(torch.bool)
+ else:
+ rprint(f"Hit chamelon sample")
+ if sample_batch_idx == getattr(self.config.eval, "num_uncond_sample_batches", 1) - 1:
+ x0_unmask = torch.zeros(*x0.shape, device=x0.device, dtype=torch.bool)
+ x0_unmask[..., -20:] = True
+ rprint(f"Unmasking first {x0_unmask.shape[-1] // 2} tokens")
+ else:
+ x0_unmask = torch.rand(*x0.shape, device=x0.device) < (sample_batch_idx / 60)
+ rprint(f"Unmasking {(sample_batch_idx / 60)} of image_tokens, {x0_unmask.sum()}")
+
+ x = self._sample_prior(batch_size_per_gpu, x0.shape[1]).to(self.device)
+ _img_indices = torch.isin(x0, torch.tensor(list(image_indices), device=self.device))
+ if getattr(self.config.eval, "unmask_chameleon_txt", False):
+ rprint(f"Unmasking all text tokens")
+ x0_unmask |= _img_indices
+ x0_unmask[:, :4] = True
+ rprint(f"All tokens: {x0_unmask.tolist()}")
+ # assert sample_modality is None
+ # decode_kwargs['sample_modality'] = torch.isin(x0, torch.tensor(list(image_indices), device=self.device)).to(torch.long)
+ else:
+ x0_unmask |= (~_img_indices)
+
+ kwargs['forward_attention_mask'] = attention_mask
+ decode_kwargs['image_indices'] = image_indices
+ decode_kwargs['x0_unmask'] = x0_unmask
+ rprint(f"Unmasking: {torch.sum(x0_unmask)}")
+ else:
+ x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
+ decode_kwargs['x0_unmask'] = x0_unmask
+
+ if self.config.trainer.interleaved_training_flex_attention:
+ assert 'sample_ids' in kwargs
+ kwargs['block_mask'] = get_interleaved_block_mask(kwargs['sample_ids'], x.shape[0], x.shape[-1], self.device)
+
+ if num_steps > (~x0_unmask).sum(dim=-1).min():
+ rprint(f"num_steps {num_steps} > sequence length {(~x0_unmask).sum(dim=-1).min()}, setting num_steps to sequence length")
+ num_steps = (~x0_unmask).sum(dim=-1).min()
+
+ if self.parameterization == "ar":
+ with show_memory_usage(empty_cache=True):
+ out, nfe_cnt = self._ar_sampler(batch_size_per_gpu, x0=x0, x0_unmask=x0_unmask, **kwargs)
+ res = self.decode_sampling(out, text_only, **kwargs, **decode_kwargs)
+ if return_nfe:
+ return res, nfe_cnt
+ return res
+
+ if x0 is not None and x0_unmask is not None:
+ x = torch.where(x0_unmask, x0, x)
+
+ if self.sampler == "maskgit" or self.sampler == "first_hitting" or self.sampler == "maskgit_nucleus":
+ sampling_schedule = 'arccos' if self.sampler in ['maskgit', 'maskgit_nucleus'] else 'linear'
+
+ # v1
+ # schedule = adap_sche(num_steps, mode=sampling_schedule, seq_len=x.shape[-1], leave=False)
+
+ # v2
+ # make seq length equal to max number of masked tokens in any sample in the batch
+ # Calculate the number of masked tokens for each sample in the batch
+ # num_masked = (x == self.mask_index).sum(dim=-1)
+ # Get the maximum number of masked tokens across all samples
+ # min_masked = num_masked.min().item()
+ # schedule = adap_sche(num_steps, mode=sampling_schedule, seq_len=min_masked, leave=False)
+
+ # v3 - use x shape
+ schedule = adap_sche(x=x, step=num_steps, mask_index=self.mask_index, mode=sampling_schedule)
+ print(f"schedule: {schedule}")
+
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
+ dt = (1 - eps) / num_steps
+ p_x0_cache = None
+
+ is_x_sliced = False
+ attention_caching = self.config.eval.attention_caching
+ attention_caching_txt_to_img_ratio = getattr(self.config.eval, "attention_caching_txt_to_img_ratio", 10)
+ if attention_caching:
+ backbone = self.accelerator.unwrap_model(self.backbone)
+ backbone.set_flex_attention_cache(x.shape[0], x.shape[1], self.device, self.dtype)
+ full_data = dict()
+ x_next = None
+
+ # At the beginning of _sample method, after initializing variables
+ if getattr(self.config.eval, "visualize_denoising", False):
+ denoising_steps = [x.clone()]
+
+ for i in range(num_steps):
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device, dtype=self.dtype if self.config.trainer.force_bf16_eval else torch.float32)
+ if attention_caching:
+ if i % attention_caching_txt_to_img_ratio == 0:
+ if is_x_sliced:
+ def replace_new_data(_key, _new_data):
+ if full_data[_key] is not None:
+ full_data[_key][:,self.static_txt_sl] = _new_data
+ return full_data[_key]
+
+ x = replace_new_data("x", x)
+ x0 = replace_new_data("x0", x0)
+ x0_unmask = replace_new_data("x0_unmask", x0_unmask)
+ p_x0_cache = replace_new_data("p_x0_cache", p_x0_cache)
+ kwargs["modality"] = replace_new_data("modality", kwargs.get("modality", None))
+ del full_data
+ full_data = dict()
+ is_x_sliced = False
+
+ update_cache_slice = None
+ block_mask = True
+ elif (i - 1) % attention_caching_txt_to_img_ratio == 0:
+ update_cache_slice = slice(0, x.shape[1])
+ block_mask = get_block_mask(
+ txt_batch_attn_dropout=torch.zeros(x.shape[0], dtype=torch.bool, device=x.device),
+ img_batch_attn_dropout=torch.ones(x.shape[0], dtype=torch.bool, device=x.device),
+ txt_length=self.config.model.txt_length,
+ batch_size=x.shape[0],
+ seq_len=x.shape[1],
+ device=x.device
+ )
+ else:
+ update_cache_slice = self.static_txt_sl
+ block_mask = True
+ if not is_x_sliced:
+ is_x_sliced = True
+
+ def clone_if_valid(_data):
+ if _data is not None:
+ return _data.clone()
+ else:
+ return None
+
+ def sl_if_valid(_data):
+ if _data is not None:
+ return _data[:, self.static_txt_sl]
+ else:
+ return None
+
+ full_data.update(x=clone_if_valid(x), x0=clone_if_valid(x0), x0_unmask=clone_if_valid(x0_unmask), modality=clone_if_valid(kwargs.get("modality", None)), p_x0_cache=clone_if_valid(p_x0_cache))
+ x = sl_if_valid(x)
+ x0 = sl_if_valid(x0)
+ x0_unmask = sl_if_valid(x0_unmask)
+ x_next = sl_if_valid(x_next)
+ p_x0_cache = sl_if_valid(p_x0_cache)
+ kwargs["modality"] = sl_if_valid(kwargs.get("modality", None))
+
+ kwargs["update_cache_slice"] = update_cache_slice
+ kwargs["block_mask"] = block_mask
+
+ if self.sampler == "maskgit":
+ x, nfe_step_cnt = self._maskgit_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs)
+ elif self.sampler == "maskgit_nucleus":
+ x, nfe_step_cnt = self._maskgit_nucleus_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs)
+ elif self.sampler == "first_hitting":
+ x, nfe_step_cnt = self._first_hitting_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs)
+ elif self.sampler == "ddpm":
+ x, nfe_step_cnt = self._ddpm_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, **kwargs)
+ elif self.sampler == "ddpm_tweedie":
+ assert not return_nfe, "Tweedie sampler does not support return_nfe"
+ x = self._ddpm_update_finetune_controlled_tweedie(x, t, dt, sampling_step=i, **kwargs)
+ nfe_step_cnt = 0
+ elif self.sampler == "ddpm_cache":
+ p_x0_cache, x_next, nfe_step_cnt = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, x0=x0, x0_unmask=x0_unmask, **kwargs)
+ if not torch.allclose(x_next, x) or self.time_conditioning:
+ p_x0_cache = None # Disable caching
+ x = x_next
+ else:
+ x, nfe_step_cnt = self._analytic_update(x, t, dt)
+
+ nfe_cnt += nfe_step_cnt
+ if self.tokenizer.eos_token_id in x and getattr(self.config.trainer, "force_after_eos_padding", False) and (self.tokenizer.eos_token_id != self.tokenizer.bos_token_id) and not attention_caching:
+ after_first_eos = torch.cumsum(x == self.tokenizer.eos_token_id, dim=1).bool()
+ after_first_eos_mask = after_first_eos.cumsum(dim=1) > 1
+ to_mask = ((after_first_eos_mask & (sample_modality == 0)) & (x != self.tokenizer.pad_token_id)) & (x != self.mask_index)
+ x[to_mask] = self.tokenizer.pad_token_id
+
+ if to_mask.sum() > 0:
+ rprint(f"Masked an avg of {torch.sum(to_mask, dim=1).float().mean()} tokens due to EOS.")
+
+ if x0 is not None and x0_unmask is not None: x = torch.where(x0_unmask, x0, x)
+
+ # Add capture of current state for visualization
+ if getattr(self.config.eval, "visualize_denoising", False) and i % getattr(self.config.eval, "visualize_step_interval", max(1, num_steps // 10)) == 0:
+ denoising_steps.append(x.clone())
+
+ clear_gpu_memory_if_needed()
+
+ if getattr(self.config.eval, "visualize_denoising", False) and denoising_steps:
+ if denoising_steps[-1] is not x:
+ denoising_steps.append(x.clone())
+
+ step_images = []
+ for step_x in denoising_steps:
+ _, step_res = self.decode_sampling(step_x, text_only=False, bypass_return_interleaved_modalities_split=True, **kwargs, **decode_kwargs)
+ if not isinstance(step_res, Image.Image):
+ step_res = step_res[0]
+ step_images.append(step_res)
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ date_folder = datetime.now().strftime("%Y-%m-%d")
+ save_dir = Path("/dev/shm") / os.getenv("USER", 'user') / "denoise_vis" / date_folder / f"{timestamp}.png"
+ save_dir.parent.mkdir(parents=True, exist_ok=True)
+ Im.concat_horizontal(step_images).save(save_dir)
+ rprint(f"Saved denoising visualization to {save_dir}")
+
+ if is_x_sliced:
+ def replace_new_data(_key, _new_data):
+ if full_data[_key] is not None:
+ full_data[_key][:,self.static_txt_sl] = _new_data
+ return full_data[_key]
+
+ x = replace_new_data("x", x)
+ x0 = replace_new_data("x0", x0)
+ x0_unmask = replace_new_data("x0_unmask", x0_unmask)
+ p_x0_cache = replace_new_data("p_x0_cache", p_x0_cache)
+ kwargs["modality"] = replace_new_data("modality", kwargs.get("modality", None))
+ del full_data
+ full_data = dict()
+ is_x_sliced = False
+
+ if self.config.sampling.noise_removal:
+ t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
+ if self.sampler == "analytic":
+ x = self._denoiser_update(x, t)
+ else:
+ unet_conditioning = self.noise(t)[0]
+ x = self.forward(x=x, sigma=unet_conditioning, **kwargs).argmax(dim=-1)
+
+ if x0 is not None and x0_unmask is not None:
+ x = torch.where(x0_unmask, x0, x)
+ res = self.decode_sampling(x, text_only, **kwargs, **decode_kwargs)
+
+ if return_nfe:
+ return res, nfe_cnt
+ return res
+
+def decode_sampling(self, x, text_only, is_sample_masking=False, bypass_return_interleaved_modalities_split=False, **kwargs):
+ if self.config.trainer.interleaved and getattr(self.config.eval, "return_interleaved_modalities_split", False) and not bypass_return_interleaved_modalities_split:
+ decoded_data = self.decode_batch({"input_ids": x, **kwargs}, text_only=False)
+ image_save_postfix = kwargs.get("image_save_postfix", None)
+ assert len(decoded_data) == 1
+ all_imgs = []
+ all_txt = []
+ for i in range(min(len(decoded_data), 64)):
+ sample_data, sample_modalities = decoded_data[i].to_list()
+ ret = self.get_interleaved_image(sample_data, sample_modalities, image_save_postfix=image_save_postfix)
+ all_txt_in_sample = []
+ all_img_in_sample = []
+ for j in range(len(sample_data)):
+ if sample_modalities[j] == 0:
+ text_samples = sample_data[j]
+ pred_txt = wrapped_batch_decode(
+ self.tokenizer, text_samples[None], clean_up_tokenization_spaces=False, skip_special_tokens=False, disable_mask_after_eos=True
+ )
+ all_txt_in_sample.extend(pred_txt)
+ else:
+ img_samples = sample_data[j]
+ pred_img = decode_latents(self.config, self.get_vae(), img_samples[None])
+ all_img_in_sample.extend([Im(x).pil for x in pred_img])
+
+ # in case we have text..... This causes [" text...", ""], which we merge below.
+ if len(all_txt_in_sample) >= 2 and all_txt_in_sample[-1] == self.tokenizer.eos_token:
+ all_txt_in_sample[-2] += all_txt_in_sample[-1]
+ all_txt_in_sample.pop()
+
+ all_txt.extend(all_txt_in_sample)
+ all_imgs.extend(all_img_in_sample)
+
+ print(f"Returning... all_txt: {all_txt}, all_imgs: {all_imgs}")
+ for i in range(len(all_imgs)):
+ filename = f"img_{get_rank()}_{str(time.time()).replace('.', '__')}.png"
+ Im(all_imgs[i]).save(filename)
+ return all_txt, all_imgs
+ elif (self.config.trainer.interleaved and not is_sample_masking) or getattr(self.config.eval, "fake_interleaved", False):
+ image_save_postfix = kwargs.get("image_save_postfix", None)
+ decoded_data = self.decode_batch({"input_ids": x, **kwargs}, text_only=False)
+ all_imgs = []
+ all_txt_ids = []
+ num_text_tokens = self.config.model.txt_length
+ for i in range(min(len(decoded_data), 64)):
+ sample_data, sample_modalities = decoded_data[i].to_list()
+ all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities, image_save_postfix=image_save_postfix))
+ all_txt_ids_in_sample = []
+ for j in range(len(sample_data)):
+ if sample_modalities[j] == 0:
+ text_samples = sample_data[j]
+ if text_samples.shape[-1] < num_text_tokens:
+ text_samples = torch.nn.functional.pad(
+ text_samples,
+ (0, num_text_tokens - text_samples.shape[-1]),
+ value=self.tokenizer.pad_token_id
+ )
+ else:
+ text_samples = text_samples[..., :num_text_tokens]
+ all_txt_ids_in_sample.append(text_samples)
+
+ if len(all_txt_ids_in_sample) == 0:
+ all_txt_ids_in_sample.append(torch.zeros((num_text_tokens), dtype=torch.long, device=self.device))
+
+ all_txt_ids.append(torch.cat(all_txt_ids_in_sample, dim=0))
+
+ if kwargs.get("return_raw_data", False):
+ return all_txt_ids, all_imgs, x
+
+ return all_txt_ids, all_imgs
+ else:
+ ret = self.decode_batch(x, text_only=text_only, **kwargs)
+ if getattr(self.config.eval, "visualize_sample", False):
+ self.save_image_text_pair(ret[1], ret[0][:, self.static_txt_sl])
+ return ret
+
+
+@tensorclass
+class InputData:
+ # x0: Float[Tensor, "b c h w"]
+ xt_ids: Integer[Tensor, "b h w c"]
+
+ # x0_emb: Optional[Float[Tensor, "b h w 2"]] = None
+ xt_img_embed: Optional[Float[Tensor, "b h w 2"]] = None
+ modality: Bool[Tensor, "b h w"] = False
+ sigma: Optional[Float[Tensor, "b"]] = None
+
+@torch.no_grad()
+def sample_transfusion(
+ self,
+ batch_size_per_gpu=None,
+ text_only=False, # todo maybe make default True
+):
+ """Generate samples from the model in autoregressive discrete mode for text and diffusion for image."""
+ # x0 = example_batch["input_ids"] # (for img tokens?)
+ # x0_emb = batch["img_emb"]
+ B = batch_size_per_gpu if batch_size_per_gpu is not None else self.config.loader.eval_batch_size
+ T = self.config.model.length
+ C = self.config.model.downscale_ratio # = vae_latent_dim * (patching_downscale ** 2)
+ # num_pred_tokens = T - 1
+ num_img_tokens = self.config.model.img_length
+ num_img_diffusion_steps = self.config.sampling.steps
+
+ # TODO @sid This should be what we want? but for interleaved, we should prob add eos after text and before start img.
+ xt_ids = torch.full((B, T), fill_value=self.tokenizer.pad_token_id, dtype=torch.long, device=self.device)
+ xt_ids[:, 0] = self.tokenizer.bos_token_id
+ xt_img_embed = torch.zeros((B, T, C), device=self.device)
+ modality = torch.zeros((B, T), dtype=torch.long, device=self.device) # assuming everything is text initially
+ sigma = torch.zeros((B, T), dtype=self.dtype, device=self.device)
+ data = InputData(xt_ids=xt_ids, xt_img_embed=xt_img_embed, modality=modality, sigma=sigma, batch_size=[B])
+
+ noise = torch.distributions.Gumbel(0, 1).sample((data.shape[0], T, self.vocab_size)).to(self.device)
+ img_start_token_id = self.tokenizer.eos_token_id
+ i = 1 # since we already have
+ continuous_diffusion_mode = False
+ while i < T:
+ if continuous_diffusion_mode:
+ # Diffusing mode
+ img_sl = slice(i, i+num_img_tokens)
+ data.modality[:, img_sl] = 1
+ data.xt_img_embed[:, img_sl] = self.sample_continuous_image(data, img_sl=img_sl, num_steps=num_img_diffusion_steps, return_embeddings=True) # (b, n_img, latent_dim * 4)
+ i += num_img_tokens
+ continuous_diffusion_mode = False
+ break
+ else:
+ # autoregressive mode
+ ar_sl = slice(None, i)
+ if self.use_kv_cache:
+ start_pos = i - 1
+ kv_sl = slice(start_pos, i)
+ else:
+ kv_sl = ar_sl
+ start_pos=None
+ pred_logits, pred_noise = self.forward(x=data.xt_ids[:, kv_sl], sigma=data.sigma[:, ar_sl], modality=data.modality[:, ar_sl], x_img_emb=data.xt_img_embed[:, ar_sl], disable_ar_shift=True, continuous_mode=True, start_pos=start_pos)
+ pred_logits = pred_logits[:, -1]
+ y = (pred_logits + noise[:, i]).argmax(-1)
+ # y = (pred_logits).argmax(-1)
+
+ data.xt_ids[:, i] = y
+ # data.xt_ids[:, i + 1] = y
+ i += 1
+ if not text_only and (i == self.config.model.txt_length-1 or torch.all(y == img_start_token_id)): # todo make variable
+ continuous_diffusion_mode = True
+
+ if self.config.model.use_kv_cache:
+ backbone = self.accelerator.unwrap_model(self.backbone)
+ backbone.reset_kv_cache(batch_size=self.config.model.inference_max_batch_size, seq_len=self.config.model.inference_max_seq_len, dtype=self.dtype, device=self.device)
+
+ return data
+
+def sample_continuous_image(self, data: InputData, img_sl, num_steps=None, return_embeddings=False):
+ if num_steps is None:
+ num_steps = self.config.sampling.steps
+ B = data.xt_img_embed.shape[0]
+ noise_scheduler = self.vae.scheduler
+ noise_scheduler.set_timesteps(num_steps, device=self.device)
+ timesteps = noise_scheduler.timesteps
+ data.xt_img_embed[:, img_sl] = torch.randn_like(data.xt_img_embed[:, img_sl])
+
+ visible_sl = slice(None, img_sl.stop)
+ for i in range(num_steps+1):
+ data.sigma[:, img_sl] = (timesteps[i] * torch.ones(B, device=self.device)).unsqueeze(-1)
+ pred_logits, pred_noise = self.forward(
+ x=data.xt_ids[:, visible_sl], sigma=data.sigma[:, visible_sl], x_img_emb=data.xt_img_embed[:, visible_sl], modality=data.modality[:, visible_sl], disable_ar_shift=True, continuous_mode=True
+ ) # exp not needed since we predict noise (b,n,c) in latent space directly, not a probability distribution
+ data.xt_img_embed[:, img_sl] = noise_scheduler.step(pred_noise[:, img_sl], timesteps[i], data.xt_img_embed[:, img_sl]).prev_sample
+
+ if return_embeddings: return data.xt_img_embed[:, img_sl] # (b, n_img, latent_dim * 4)
+
+ # x = x.transpose(1, 2)
+ # data.xt_img_embed[:, img_sl] = data.xt_img_embed[:, img_sl].transpose(1, 2)
+ text_tokens, img_tokens = self.decode_batch(data.xt_ids[:, img_sl], text_only=False)
+ return text_tokens, img_tokens
+
+
+def cfg(config, t, cat_p_x0):
+ logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
+ _cfg = config.eval.cfg
+ if not getattr(config.eval, "force_cfg_value", False):
+ if _cfg == -1:
+ _cfg = torch.linspace(0, 10, t.shape[0]).to(t.device)
+ _w = (_cfg * (1 - t))[:, None, None]
+ else:
+ _w = _cfg
+
+ return (1 + _w) * logit_c - _w * logit_u
+
+def nucleus_sampling_batch(logits, top_p=0.9, temperature=1.0):
+ """
+ Perform nucleus (top-p) sampling on batched and sequenced logits.
+
+ Args:
+ logits (torch.Tensor): A tensor of shape (B, N, C) where B is the batch size,
+ N is the sequence length, and C is the number of classes.
+ top_p (float): The cumulative probability threshold for nucleus sampling.
+ temperature (float): Temperature value for scaling logits.
+
+ Returns:
+ torch.Tensor: Indices sampled from the filtered distribution for each position,
+ with shape (B, N).
+ """
+ B, N, C = logits.shape
+
+ # Apply softmax to get probabilities
+ # probs = torch.nn.functional.softmax(logits / temperature, dim=-1) # Shape: (B, N, C)
+ probs = logits / temperature
+
+ # Sort the probabilities in descending order
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1) # Both shape: (B, N, C)
+
+ # Compute the cumulative sum of probabilities
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Shape: (B, N, C)
+
+ # Create a mask for top-p
+ mask = cumulative_probs <= top_p # Shape: (B, N, C)
+
+ # Ensure at least one token is included
+ mask[:, :, 0] = True
+
+ # Apply the mask to the sorted probabilities
+ filtered_probs = sorted_probs * mask.float() # Shape: (B, N, C)
+
+ # Renormalize the probabilities
+ filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True) # Shape: (B, N, C)
+
+ # Sample from the renormalized distribution
+ sampled_indices = torch.multinomial(filtered_probs.view(-1, C), num_samples=1).squeeze(-1) # Shape: (B*N)
+
+ # Reshape sampled_indices to (B, N)
+ sampled_indices = sampled_indices.view(B, N)
+
+ # Gather the original indices based on sorted_indices
+ final_indices = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1) # Shape: (B, N)
+
+ return final_indices
+
+def nucleus_sampling(logits, top_p=0.9, temperature=1.0):
+ """
+ Perform nucleus (top-p) sampling on the given logits.
+
+ Args:
+ logits (torch.Tensor): A tensor of shape (B, C) where B is the batch size
+ and C is the number of classes.
+ top_p (float): The cumulative probability threshold for nucleus sampling.
+
+ Returns:
+ torch.Tensor: Indices sampled from the filtered distribution.
+ """
+ # Apply softmax to get probabilities
+ probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
+
+ # Sort the probabilities in descending order and get the sorted indices
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+
+ # Compute the cumulative sum of probabilities along the last dimension
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+ # Create a mask to filter out probabilities that contribute to top_p mass
+ mask = cumulative_probs <= top_p
+
+ # Ensure at least one token is always included
+ mask[..., 0] = True # Always include the most probable token
+
+ # Zero out probabilities that are not part of the top-p mass
+ filtered_probs = sorted_probs * mask.float()
+
+ # Renormalize the filtered probabilities
+ filtered_probs /= (filtered_probs.sum(dim=-1, keepdim=True))
+ # Sample from the renormalized distribution
+ sampled_indices = torch.multinomial(filtered_probs, num_samples=1)[:, 0]
+ # Map back to original indices
+ final_indices = sorted_indices.gather(dim=-1, index=sampled_indices.unsqueeze(-1)).squeeze(-1)
+
+ return final_indices
+
+def clear_gpu_memory_if_needed():
+ if torch.cuda.is_available():
+ current_memory = torch.cuda.memory_reserved() / torch.cuda.get_device_properties(0).total_memory
+ if current_memory >= 0.50:
+ torch.cuda.empty_cache()
+
+def _ar_sampler(self, B, x0=None, x0_unmask=None, modality=None, **kwargs):
+ assert B > 0
+ assert (x0 is None) == (x0_unmask is None), f"x0: {x0} x0_unmask: {x0_unmask}"
+ num_pred_tokens = self.config.model.length - 1
+ x = torch.zeros((B, num_pred_tokens + 1), dtype=torch.long, device=self.device)
+ x[:, 0] = self.tokenizer.bos_token_id
+ if x0 is not None: x = torch.where(x0_unmask, x0, x)
+ split_cfg_batches = getattr(self.config.eval, "split_cfg_batches", False) and not self.config.model.use_kv_cache
+ effective_bs = B * 2 if ((self.config.eval.cfg is not None and x0 is not None) and split_cfg_batches is False) else B
+ top_p = getattr(self.config.eval, "top_p", None)
+ temperature = getattr(self.config.eval, "temperature", 1.0)
+ if self.config.model.use_kv_cache:
+ assert getattr(self.config.model, "inference_max_batch_size", None) is None
+ assert getattr(self.config.model, "inference_max_seq_len", None) is None
+ self.accelerator.unwrap_model(self.backbone).reset_kv_cache(
+ batch_size=effective_bs,
+ seq_len=num_pred_tokens,
+ dtype=self.dtype,
+ device=self.device
+ )
+
+ _x, _modality = None, None
+ if self.config.eval.cfg is not None and x0 is not None:
+ if split_cfg_batches is False:
+ _x = torch.cat([x, torch.where(x0_unmask, self.mask_index, x)], dim=0)
+ _modality = torch.cat([modality, modality], dim=0)
+
+ nfe_cnt = 0
+ noise = torch.distributions.Gumbel(0, 1).sample((B, num_pred_tokens, self.vocab_size)).to(self.device) # precompute noise
+ for i in range(num_pred_tokens):
+ start_pos = i if self.use_kv_cache else None
+ ar_sl = slice(start_pos, i+1)
+
+ if self.config.eval.cfg is not None and x0 is not None:
+ if split_cfg_batches:
+ logit_c = self.forward(
+ x=x[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True
+ )[:, -1]
+ logit_u = self.forward(
+ x=torch.where(x0_unmask, self.mask_index, x)[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True
+ )[:, -1]
+ else:
+ _x[:B] = x
+ _x[B:] = torch.where(x0_unmask, self.mask_index, x)
+ next_logits = self.forward(x=_x[:, ar_sl], sigma=None, modality=_modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True)[:, -1]
+ logit_c, logit_u = next_logits.chunk(2, dim=0)
+
+ _w = self.get_cfg_weight(1 - (i / num_pred_tokens))
+ next_logits = (1 + _w) * logit_c - _w * logit_u
+ else:
+ next_logits = self.forward(x=x[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True)[:, -1]
+
+ if getattr(self.config.model, "force_argmax_valid_indices", False):
+ # start_pos = i
+ next_sl = slice(i + 1, i + 2)
+ try:
+ next_logits[..., self.text_vocab_size:] = torch.where((modality[:, next_sl] == 0), torch.finfo(next_logits.dtype).min, next_logits[..., self.text_vocab_size:])
+ next_logits[..., :self.text_vocab_size] = torch.where((modality[:, next_sl] == 1), torch.finfo(next_logits.dtype).min, next_logits[..., :self.text_vocab_size])
+ except:
+ breakpoint()
+ if top_p is not None:
+ # do nucleus sampling
+ y = nucleus_sampling(next_logits, top_p=top_p, temperature=temperature)
+ else:
+ next_logits = next_logits + noise[:, i]
+ nfe_cnt += 1
+ y = (next_logits).argmax(-1)
+ x[:, i + 1] = y
+ if x0 is not None: x = torch.where(x0_unmask, x0, x)
+ if not self.config.model.use_kv_cache:
+ empty_device_cache()
+
+ if getattr(self.config.eval, "force_empty_cache", False):
+ empty_device_cache()
+
+ if self.config.model.use_kv_cache:
+ # TODO: PyTorch must have a b
+ del noise, next_logits, _x, _modality
+ self.accelerator.unwrap_model(self.backbone).reset_kv_cache(
+ batch_size=effective_bs,
+ seq_len=num_pred_tokens,
+ dtype=self.dtype,
+ device=self.device,
+ set_to_none=True
+ )
+
+ return x, nfe_cnt
+
+def handle_interleaved_decode(self, sample, allow_mask_index=False, new_mask_index=None, **kwargs):
+ batch = sample
+ sample_modality = sample.get("modality", None)
+ sample = sample.get("input_ids", None)
+
+ text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id)
+ img_tokens = torch.where((sample_modality == 1), sample, self.mask_index)
+
+ invalid_text_mask = (text_tokens >= self.text_vocab_size) & (sample_modality == 0)
+ invalid_img_mask = (img_tokens < self.text_vocab_size) & (sample_modality == 1)
+ mask_img_mask = (img_tokens == self.mask_index) & (sample_modality == 1)
+
+ if invalid_text_mask.sum() > 0:
+ assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}, {invalid_text_mask.nonzero()[:4]}"
+ text_tokens[invalid_text_mask] = self.mask_index
+
+ if new_mask_index is not None:
+ img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index))
+
+ sample = torch.where(sample_modality == 1, img_tokens - self.text_vocab_size, text_tokens)
+ if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0:
+ if new_mask_index is not None:
+ assert img_invalid_mask_v2.sum().item() == 0
+ sample[mask_img_mask] = new_mask_index
+ else:
+ sample[mask_img_mask] = 0
+ sample[invalid_img_mask] = 0
+
+ new_batch = {**batch, "input_ids": sample}
+ new_batch = InterleavedBatch.custom_from_dict(new_batch)
+ new_batch = new_batch.to_elements()
+ return new_batch
+
+def decode_batch(self,
+ sample,
+ text_only=True,
+ return_masks: bool = False,
+ allow_mask_index: bool = False,
+ new_mask_index=None,
+ sample_modality=None,
+ **kwargs
+):
+
+ if isinstance(sample, dict) or isinstance(sample, TensorDict):
+ if self.config.trainer.interleaved or getattr(self.config.eval, "fake_interleaved", False):
+ return handle_interleaved_decode(self, sample, allow_mask_index=allow_mask_index, new_mask_index=new_mask_index, **kwargs)
+ else:
+ sample_modality = sample.get("modality", None)
+ sample = sample.get("input_ids", None)
+
+ img_tokens = None
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ if continuous_mode:
+ text_tokens, img_tokens = sample[..., self.static_txt_sl], sample[..., self.static_img_sl]
+ elif self.unified_model and self.config.trainer.multimodal_batches and sample_modality is not None:
+ if (sample_modality == 0).all(dim=-1).sum() > 0:
+ text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id)
+ img_tokens = torch.where((sample_modality == 1)[:, self.static_img_sl], sample[:, self.static_img_sl], self.mask_index)
+ else:
+ text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id)
+ img_tokens = torch.where((sample_modality == 1), sample, self.mask_index)
+
+ invalid_text_mask = text_tokens >= self.text_vocab_size
+ if getattr(self.config.model, "add_labels", None) is not None:
+ invalid_img_mask = (img_tokens < self.text_vocab_size) | (img_tokens >= (self.vocab_size - self.config.model.add_labels))
+ else:
+ invalid_img_mask = (img_tokens < self.text_vocab_size)
+ mask_text_mask = text_tokens == self.mask_index
+ mask_img_mask = img_tokens == self.mask_index
+ if invalid_text_mask.sum() > 0:
+ assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}, {invalid_text_mask.nonzero()[:4]}"
+ text_tokens[invalid_text_mask] = self.mask_index
+
+ if new_mask_index is not None:
+ img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index))
+
+ img_tokens = img_tokens - self.text_vocab_size
+ if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0:
+ if new_mask_index is not None:
+ assert img_invalid_mask_v2.sum().item() == 0
+ img_tokens[mask_img_mask] = new_mask_index
+ else:
+ img_tokens[mask_img_mask] = 0
+ img_tokens[invalid_img_mask] = 0
+
+ if img_tokens.shape[-1] != self.config.model.img_length:
+ if (sample_modality[:, -self.config.model.img_length:].sum(dim=-1) == self.config.model.img_length).all():
+ img_tokens = img_tokens[:, -self.config.model.img_length:]
+ elif (sample_modality[:, :self.config.model.img_length].sum(dim=-1) == self.config.model.img_length).all():
+ img_tokens = img_tokens[:, :self.config.model.img_length]
+
+ elif self.unified_model:
+ text_tokens, img_tokens = sample[..., self.static_txt_sl], sample[..., self.static_img_sl]
+ invalid_text_mask = text_tokens >= self.text_vocab_size
+ invalid_img_mask = img_tokens < self.text_vocab_size
+ mask_text_mask = text_tokens == self.mask_index
+ mask_img_mask = img_tokens == self.mask_index
+
+ if invalid_text_mask.sum() > 0:
+ assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}"
+ text_tokens[invalid_text_mask] = self.mask_index
+
+ if new_mask_index is not None:
+ img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index))
+
+ img_tokens = img_tokens - self.text_vocab_size
+ if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0:
+ assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_img_mask.sum(): {invalid_img_mask.sum()}"
+ if new_mask_index is not None:
+ assert img_invalid_mask_v2.sum().item() == 0
+ img_tokens[mask_img_mask] = new_mask_index
+ else:
+ img_tokens[mask_img_mask] = 0
+ img_tokens[invalid_img_mask] = 0
+
+ try:
+ assert img_tokens.shape[-1] == self.config.model.img_length, f"img_tokens.shape[-1]: {img_tokens.shape[-1]}, config.model.img_length: {self.config.model.img_length}, sample_modality: {sample_modality}"
+ except:
+ breakpoint()
+
+ elif self.image_model:
+ text_tokens, img_tokens = None, sample
+ else:
+ text_tokens, img_tokens = sample, None
+ if text_only:
+ return text_tokens
+ else:
+ if return_masks:
+ return text_tokens, img_tokens, mask_text_mask, mask_img_mask
+ else:
+ return text_tokens, img_tokens
+
+def optional_add_bos(self, _x, x0):
+ if self.config.trainer.ar_shift:
+ if x0 is not None:
+ _x = torch.cat([x0[:, [0]], _x], dim=1)
+ else:
+ _x = torch.cat([torch.full_like(_x[..., :1], fill_value=self.tokenizer.pad_token_id), _x], dim=1)
+ return _x
+
+def adap_sche(x, step, mask_index, mode="arccos"):
+ """ Create a 2D sampling scheduler
+ :param
+ x -> torch.Tensor: input tensor with shape (B, seq_len)
+ step -> int: number of prediction steps during inference
+ mode -> str: the rate of value to unmask
+ leave -> bool: tqdm arg on either to keep the bar or not
+ :return
+ scheduler -> torch.LongTensor(): 2D tensor of shape (B, max_seq_len) with schedules for each sample
+ """
+ num_masked = (x == mask_index).sum(dim=-1).to(x.device)
+
+ r = torch.linspace(1, 0, step)
+
+ if mode == "root":
+ val_to_mask = 1 - (r ** .5)
+ elif mode == "linear":
+ val_to_mask = 1 - r
+ elif mode == "square":
+ val_to_mask = 1 - (r ** 2)
+ elif mode == "cosine":
+ val_to_mask = torch.cos(r * math.pi * 0.5)
+ elif mode == "arccos":
+ val_to_mask = torch.arccos(r) / (math.pi * 0.5)
+ else:
+ return None
+ val_to_mask = val_to_mask.to(x.device)
+ schedules = []
+ for seq_len in num_masked:
+ print(f"seq_len: {seq_len}")
+ sche = (val_to_mask / val_to_mask.sum()) * seq_len
+ sche = sche.round()
+ sche[sche == 0] = 1
+ sche[-1] += seq_len - sche.sum()
+ sche[-1] = max(sche[-1], 0)
+ schedules.append(sche.int())
+
+ return torch.stack(schedules, dim=0)
+
+
+@torch.no_grad()
+def _first_hitting_update(self, x, t, dt, schedule=None, step=None, **kwargs):
+ sigma_t, _ = self.noise(t)
+ sigma_s, _ = self.noise(t - dt)
+ if sigma_t.ndim > 1:
+ sigma_t = sigma_t.squeeze(-1)
+ if sigma_s.ndim > 1:
+ sigma_s = sigma_s.squeeze(-1)
+ assert sigma_t.ndim == 1, sigma_t.shape
+ assert sigma_s.ndim == 1, sigma_s.shape
+ move_chance_t = 1 - torch.exp(-sigma_t)
+ move_chance_s = 1 - torch.exp(-sigma_s)
+ move_chance_t = move_chance_t[:, None, None]
+ move_chance_s = move_chance_s[:, None, None]
+
+ _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t
+ nfe_cnt = 0
+ p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs)
+ nfe_cnt += 1
+
+ copy_flag = (x != self.mask_index) # [B, N]
+
+ # TODO: inefficient that we sample all tokens even if we only want to unmask a few
+ _x = _sample_categorical(p_x0)
+
+ num_unmask = schedule[:, step]
+ num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1))
+ if torch.all(num_unmask <= 0):
+ return x, nfe_cnt
+
+ random_values = torch.rand_like(copy_flag, dtype=torch.float32)
+ random_values = torch.where(~copy_flag, random_values, -1)
+ _, indices = torch.sort(random_values, dim=-1, descending=True)
+ range_tensor = torch.arange(copy_flag.shape[-1], device=copy_flag.device).expand(copy_flag.shape)
+ final_mask = range_tensor < num_unmask[:, None]
+
+ result = torch.zeros_like(copy_flag)
+ result.scatter_(-1, indices, final_mask)
+
+ return torch.where(result, _x, x), nfe_cnt
+
+@torch.no_grad()
+def _maskgit_update(self, x, t, dt, schedule=None, step=None, **kwargs):
+ sigma_t, _ = self.noise(t)
+ sigma_s, _ = self.noise(t - dt)
+ if sigma_t.ndim > 1:
+ sigma_t = sigma_t.squeeze(-1)
+ if sigma_s.ndim > 1:
+ sigma_s = sigma_s.squeeze(-1)
+ assert sigma_t.ndim == 1, sigma_t.shape
+ assert sigma_s.ndim == 1, sigma_s.shape
+ move_chance_t = 1 - torch.exp(-sigma_t)
+ move_chance_s = 1 - torch.exp(-sigma_s)
+ move_chance_t = move_chance_t[:, None, None]
+ move_chance_s = move_chance_s[:, None, None]
+ nfe_cnt = 0
+ _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t
+
+ copy_flag = (x != self.mask_index)
+ r_temp = getattr(self.config.eval, 'maskgit_r_temp', 10)
+ num_unmask = schedule[:, step]
+ # rprint(f"num_unmask: {num_unmask}, (~copy_flag).sum(dim=-1).max().item(): {(~copy_flag).sum(dim=-1).max().item()}")
+ num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1))
+ if torch.all(num_unmask <= 0):
+ return x, nfe_cnt
+
+ p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs)
+ nfe_cnt += 1
+ pred_code = torch.multinomial(p_x0.view(-1, p_x0.shape[-1]), 1)[:, 0].view(p_x0.shape[:-1])
+ conf = torch.gather(p_x0, -1, pred_code.unsqueeze(-1)).squeeze(-1)
+
+ rand = r_temp * torch.from_numpy(np.random.gumbel(size=pred_code.shape)).to(self.device) * t
+ conf = torch.log(conf.squeeze()) + rand
+
+ if self.config.trainer.ar_shift:
+ copy_flag = copy_flag[:, 1:]
+
+ # do not predict on already predicted tokens
+ conf = torch.where(copy_flag, -torch.inf, conf)
+
+ # Choose the predicted tokens with the highest confidence
+ # Get the maximum num_unmask across the batch for top k
+ max_num_unmask = num_unmask.max().item()
+
+ # Use top k to get the highest confidence tokens
+ tresh_conf, indice_mask = torch.topk(conf, k=max_num_unmask, dim=-1)
+
+ # tresh_conf is [B, max_num_unmask]
+ # for each sample i, we want to get num_unmask[i] highest confidence tokens
+
+ # handle the case where num_unmask is 0 by setting the threshold to inf
+ gather_indices = torch.clamp(num_unmask - 1, min=0)[:, None]
+ tresh_conf = tresh_conf.gather(-1, gather_indices)
+ tresh_conf = torch.where((num_unmask <= 0)[:, None], torch.inf, tresh_conf)
+
+ # replace the chosen tokens
+ conf = (conf >= tresh_conf.expand_as(conf))
+ if self.config.trainer.ar_shift:
+ out = torch.where(conf, pred_code, x[:, 1:])
+ out = optional_add_bos(self, out, x0=kwargs.get("x0", None))
+ else:
+ out = torch.where(conf, pred_code, x)
+
+ if getattr(self.config.eval, "allow_token_updates", False):
+ out = torch.where(copy_flag, p_x0.argmax(dim=-1), out)
+
+ del conf, indice_mask, gather_indices, tresh_conf, pred_code, p_x0
+ if getattr(self.config.eval, "force_empty_cache", False):
+ empty_device_cache()
+
+ return out, nfe_cnt
+
+
+@torch.no_grad()
+def _maskgit_nucleus_update(self, x, t, dt, schedule=None, step=None, **kwargs):
+ nfe_cnt = 0
+ _sigma = None # sigma useless for non time-conditioned models like us
+
+ copy_flag = (x != self.mask_index)
+ if self.config.trainer.ar_shift:
+ copy_flag = copy_flag[:, 1:]
+
+ assert getattr(self.config.eval, 'maskgit_r_temp', None) != None
+ r_temp = getattr(self.config.eval, "maskgit_r_temp", 10)
+ num_unmask = schedule[:, step]
+ num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1))
+ if num_unmask <= 0:
+ return x, nfe_cnt
+
+ p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs)
+ nfe_cnt += 1
+ top_p = getattr(self.config.eval, "top_p", 0.95)
+ temperature = getattr(self.config.eval, "temperature", 0.9)
+ if top_p is not None:
+ pred_code = nucleus_sampling_batch(p_x0, top_p=top_p, temperature=temperature)
+ else:
+ pred_code = torch.multinomial(p_x0.view(-1, p_x0.shape[-1]), 1)[:, 0].view(p_x0.shape[:-1]) # pred tokens?
+ conf = torch.gather(p_x0, -1, pred_code.unsqueeze(-1)).squeeze(-1)
+
+ rand = r_temp * torch.from_numpy(np.random.gumbel(size=pred_code.shape)).to(self.device) * t
+ conf = torch.log(conf.squeeze()) + rand
+
+ # do not predict on already predicted tokens
+ conf = torch.where(copy_flag, -torch.inf, conf)
+
+ # chose the predicted token with the highest confidence
+ # get the maximum num_unmask across the batch for top k
+ max_num_unmask = num_unmask.max().item()
+
+ tresh_conf, indice_mask = torch.topk(conf, k=max_num_unmask, dim=-1)
+
+ # for each sample i, we want to get num_unmask[i] highest confidence tokens
+ # handle the case where num_unmask is 0 by setting the threshold to inf
+ gather_indices = torch.clamp(num_unmask - 1, min=0)[:, None]
+ tresh_conf = tresh_conf.gather(-1, gather_indices.long())
+ tresh_conf = torch.where((num_unmask <= 0)[:, None], torch.inf, tresh_conf)
+
+ # replace the chosen tokens
+ conf = (conf >= tresh_conf)
+ if self.config.trainer.ar_shift:
+ out = torch.where(conf, pred_code, x[:, 1:])
+ out = optional_add_bos(self, out, x0=kwargs.get("x0", None))
+ else:
+ out = torch.where(conf, pred_code, x)
+ return out, nfe_cnt
+
+
+
+@torch.no_grad()
+def _ddpm_update_finetune_controlled_tweedie(self, x, t, dt, reward_model=None, repeats=10, sampling_step=None, **kwargs):
+ sigma_t, _ = self.noise(t)
+ sigma_s, _ = self.noise(t - dt)
+ if sigma_t.ndim > 1:
+ sigma_t = sigma_t.squeeze(-1)
+ if sigma_s.ndim > 1:
+ sigma_s = sigma_s.squeeze(-1)
+ assert sigma_t.ndim == 1, sigma_t.shape
+ assert sigma_s.ndim == 1, sigma_s.shape
+ move_chance_t = 1 - torch.exp(-sigma_t)
+ move_chance_s = 1 - torch.exp(-sigma_s)
+ move_chance_t = move_chance_t[:, None, None]
+ move_chance_s = move_chance_s[:, None, None]
+ _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t
+ p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs)
+ assert move_chance_t.ndim == p_x0.ndim
+
+ if self.config.trainer.force_bf16_eval: empty_device_cache()
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+ copy_flag = (x != self.mask_index).to(x.dtype)
+
+ del p_x0, move_chance_t, move_chance_s
+ resample_interval = getattr(self.config.eval, "tweedie_resample_interval", None)
+ return_single_sample = False
+ _repeats = repeats
+ if resample_interval is not None and sampling_step % resample_interval != 0:
+ _repeats = 1
+ return_single_sample = True
+
+ # Generate 10 samples for each position
+ samples = [copy_flag * x + (1 - copy_flag) * optional_add_bos(self, _sample_categorical(q_xs), x0=kwargs.get("x0", None)) for _ in range(_repeats)]
+
+ if return_single_sample:
+ return samples[0]
+
+ if not hasattr(self, "reward_model"):
+ from unidisc.tokenizers.laion_aesthetic_v2 import get_predictor_func
+ self.reward_model = get_predictor_func(self.device)
+ rprint("Using reward model. Should delete this after eval.")
+
+ # TODO: Make this more general (e.g., support interleaved text/image)
+ # Get scores for each sample
+ scores = []
+ expected_x0_args = []
+ for i in range(repeats):
+ # Use Tweedie's formula. Aim to calcuate r(E[x_0|x_t])
+ expected_x0 = self._ddpm_forward(samples[i], t, sigma_s, **kwargs) # Calcualte E[x_0|x_t]
+ if getattr(self.config.eval, "use_generic_tweedie_rewards", False):
+ assert self.config.trainer.interleaved
+ expected_x0_arg = torch.argmax(expected_x0, dim=-1)
+ expected_x0_args.append(expected_x0_arg)
+ assert samples[0].shape[0] == 1
+ else:
+ expected_x0[..., :self.text_vocab_size] = 0
+ expected_x0[..., self.mask_index] = 0
+ expected_x0[..., self.text_vocab_size:] = expected_x0[..., self.text_vocab_size:] + 1e-6
+ expected_x0_arg = torch.argmax(expected_x0, dim=-1)
+ expected_x0_arg = expected_x0_arg - self.text_vocab_size
+ expected_x0_img_pred = decode_latents(self.config, self.get_vae(), expected_x0_arg[:, self.static_img_sl])
+ scorer = self.reward_model(expected_x0_img_pred) # [B]
+
+ scorer = scorer.squeeze()
+ if scorer.ndim == 0:
+ scorer = scorer[None]
+ scores.append(torch.from_numpy(scorer))
+
+ if getattr(self.config.eval, "use_generic_tweedie_rewards", False):
+ orig_modality = kwargs.get("modality", None)
+ if orig_modality is not None:
+ orig_modality = orig_modality.clone()
+ kwargs["modality"] = orig_modality.repeat(len(expected_x0_args), 1)
+
+ orig_sample_ids = kwargs.get("sample_ids", None)
+ if orig_sample_ids is not None:
+ orig_sample_ids = orig_sample_ids.clone()
+ kwargs["sample_ids"] = orig_sample_ids.repeat(len(expected_x0_args), 1)
+
+ decoded_data = self.decode_batch({"input_ids": torch.cat(expected_x0_args, dim=0), **kwargs}, text_only=False)
+ kwargs["modality"] = orig_modality
+ kwargs["sample_ids"] = orig_sample_ids
+
+ all_imgs = []
+ all_txt_ids = []
+ for i in range(len(decoded_data)):
+ sample_data, sample_modalities = decoded_data[i].to_list()
+ assert len(sample_data) == 2
+ assert sample_modalities == [0, 1]
+ sample_text = wrapped_batch_decode(
+ self.tokenizer,
+ sample_data[0][None],
+ clean_up_tokenization_spaces=True,
+ skip_special_tokens=False,
+ disable_mask_after_eos=True
+ )
+ assert len(sample_text) == 1
+ all_txt_ids.append(sample_text[0])
+ all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities, single_image_only=True, disable_img_save=True))
+
+ all_imgs = torch.cat(all_imgs, dim=0)
+ reward_config = getattr(self.config.eval, "tweedie_reward_config")
+ scores = self.get_rewards(reward_config, all_imgs, all_txt_ids).float().cpu()
+ scores = torch.softmax(scores, dim=0)[None]
+ else:
+ scores = torch.stack(scores, dim=1)
+ scores = torch.softmax(scores, dim=1) # Convert scores to probabilities for each batch
+
+ # Sample from the weighted categorical distribution formed by scores
+ # Select the index of the highest score for each batch
+ final_sample_indices = torch.argmax(scores, dim=1) # Shape [batch_size]
+ final_samples = [samples[final_sample_indices[j]][j,:] for j in range(x.size(0))] # Select the chosen samples using gathered indices
+ final_samples = torch.stack(final_samples, dim=0)
+ return final_samples
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def visualize_samples(self, batch, batch_idx, split='val'):
+ split = split.removesuffix("/")
+ gt_txt = None
+ step_metrics = self.get_step_metrics()
+ step_metrics["trainer/global_step"] = (batch_idx if self.config.eval.visualize_data_only else self.global_step)
+ rprint('[IMPORTANT] Visualizing ground truth samples, verify tokenization')
+
+ if getattr(self.config.eval, "disable_visualization", False):
+ return
+
+ if self.config.trainer.interleaved:
+ decoded_data = self.decode_batch(batch, text_only=False)
+ all_imgs = []
+ max_num = 10000 if getattr(self.config.eval, "visualize_data_only", False) else 32
+ for i in range(min(len(decoded_data), max_num)):
+ sample_data, sample_modalities = decoded_data[i].to_list()
+ all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities))
+
+ if not getattr(self.config.eval, "visualize_data_only", False):
+ log({f"{split}/rec_img": wandb.Image(Im.concat_horizontal(*all_imgs).pil), **step_metrics})
+ else:
+ gt_txt, gt_img = self.decode_batch(batch["input_ids"], text_only=False, sample_modality=batch.get("modality", None))
+ if gt_img is not None:
+ rec_img = decode_latents(self.config, self.get_vae(), gt_img)
+ log({f"{split}/rec_img": wandb.Image(rec_img), **step_metrics})
+
+ gt_txt = gt_txt[:4]
+ if self.config.trainer.multimodal_batches:
+ txt_batch = batch["input_ids"][~self.img_txt_pair_batch_mask(batch)]
+ if txt_batch.shape[0] > 0:
+ rprint(f"Txt Only (GT): {wrapped_batch_decode(self.tokenizer, txt_batch[:4], clean_up_tokenization_spaces=True, skip_special_tokens=True)}")
+ else:
+ rprint(f"GT Captions: {wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True)}")
+ else:
+ if gt_txt is not None:
+ rprint(f"GT Captions: {wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True)}")
+
+ if getattr(self.config.eval, "visualize_data_only", False):
+ exit()
+
+ if split == "train":
+ if hasattr(self, "vae"):
+ del self.vae
+ empty_device_cache()
+
+
+@try_except(write_error_to_file=True, clear_cuda_cache=True)
+def mauve_store_references(self, dataloader):
+ total_batches = len(dataloader)
+ sample_batch = next(iter(dataloader))
+ batch_size = sample_batch["input_ids"].shape[0]
+ # only execute on rank 0
+ N = self.config.eval.mauve_num_samples
+ if not is_main_process():
+ return
+ if N is None or N <= 0 or batch_size * total_batches < N:
+ rprint(f"[WARNING] Skipping Mauve reference storage. N: {N}, batch_size: {batch_size}, total_batches: {total_batches}")
+ return
+ # need to get N samples from dataloader, which has a batch size of batch_size
+ # we need to get ceil(N / batch_size) batches
+ num_batches = math.ceil(N / batch_size)
+ # store in self.mauve_references
+ for i, batch in tqdm(enumerate(dataloader), total=num_batches, desc="Mauve storing references"): #, disable=not is_main_process()):
+ if i >= num_batches:
+ break
+ reference_txt_tokens, _ = self.decode_batch(batch["input_ids"], text_only=False, sample_modality=batch.get("modality", None))
+ reference_txt = wrapped_batch_decode(self.tokenizer, reference_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+ self.mauve_references.extend(reference_txt)
+
+ assert len(self.mauve_references) >= N, f"len(self.mauve_references) ({len(self.mauve_references)}) < N ({N})"
+ self.mauve_references = self.mauve_references[:N]
+ save_path = os.path.join(self.config.output_dir, f'mauve_references_{N}.pkl')
+ with open(save_path, 'wb') as f:
+ pickle.dump(self.mauve_references, f)
+ rprint(f"[MAUVE] Stored {N} references in {save_path}")
+
+
+@try_except(write_error_to_file=True)
+def cleanup_fid_output(self):
+ if getattr(self.config.eval, "force_fid_output_dir", None) is not None:
+ return
+ if hasattr(self, "fid_gen_dir"):
+ fid_output_dir_path = Path(self.fid_gen_dir)
+ if fid_output_dir_path.exists() and fid_output_dir_path.is_dir():
+ rprint(f"Removing fid output dir: {fid_output_dir_path}")
+ shutil.rmtree(fid_output_dir_path)
+
+ if hasattr(self, "fid_gt_dir"):
+ fid_gt_dir_path = Path(self.fid_gt_dir)
+ if fid_gt_dir_path.exists() and fid_gt_dir_path.is_dir():
+ rprint(f"Removing fid gt dir: {fid_gt_dir_path}")
+ shutil.rmtree(fid_gt_dir_path)
+
+ if hasattr(self, "img_to_txt_mauve_gen_dir"):
+ img_to_txt_mauve_gen_dir_path = Path(self.img_to_txt_mauve_gen_dir)
+ if img_to_txt_mauve_gen_dir_path.exists() and img_to_txt_mauve_gen_dir_path.is_dir():
+ rprint(f"Removing img to txt mauve gen dir: {img_to_txt_mauve_gen_dir_path}")
+ shutil.rmtree(img_to_txt_mauve_gen_dir_path)
+
+ if hasattr(self, "img_to_txt_mauve_gt_dir"):
+ img_to_txt_mauve_gt_dir_path = Path(self.img_to_txt_mauve_gt_dir)
+ if img_to_txt_mauve_gt_dir_path.exists() and img_to_txt_mauve_gt_dir_path.is_dir():
+ rprint(f"Removing img to txt mauve gt dir: {img_to_txt_mauve_gt_dir_path}")
+ shutil.rmtree(img_to_txt_mauve_gt_dir_path)
+
+def compute_val_metrics_standalone(self):
+ rprint("Computing validation metrics standalone")
+ self.reset_validation_metrics()
+ num_samples = 0
+ for i, batch in tqdm(enumerate(self.validation_dataloader), desc="Standalone validation steps", disable=not is_main_process(), leave=False):
+ batch = self.update_batch(batch)
+ num_samples += batch["input_ids"].shape[0]
+ self.compute_loss(batch, prefix="val", batch_idx=i)
+ if i >= self.config.eval.num_val_metrics_standalone_batches_per_device:
+ break
+
+ log({**self.get_step_metrics(), "num_samples": num_samples * get_world_size()})
+ rprint(f"Finished computing validation metrics standalone.")
+
+
+def compute_val_metrics_constant_per_batch(self):
+ rprint("Computing validation metrics standalone")
+ self.reset_validation_metrics()
+ if self.config.eval.num_val_metrics_standalone_batches_per_device is None or self.config.eval.num_val_metrics_standalone_batches_per_device <= 0:
+ return
+ num_samples = 0
+ for i, batch in tqdm(enumerate(self.validation_dataloader), desc="Standalone validation steps", disable=not is_main_process(), leave=False):
+ batch = self.update_batch(batch)
+ num_samples += batch["input_ids"].shape[0]
+ self.compute_loss(batch, prefix="val", batch_idx=i)
+ if i >= self.config.eval.num_val_metrics_standalone_batches_per_device:
+ break
+
+ log({**self.get_step_metrics(), "num_samples": num_samples * get_world_size()})
+ rprint(f"Finished computing validation metrics standalone.")
+
+def get_interleaved_image(self, sample_data, sample_modalities, single_image_only=False, disable_img_save=False, image_save_postfix=None):
+ all_sample_imgs = []
+ single_image_only = self.config.eval.auto_enhance or single_image_only or getattr(self.config.eval, "fake_interleaved", False)
+ if getattr(self.config.eval, "disable_shm_save", False):
+ disable_img_save = True
+
+ if not disable_img_save:
+ date_folder = datetime.now().strftime("%Y-%m-%d")
+ save_dir = Path("/dev/shm") / os.getenv("USER", 'user') / "imgs" / date_folder
+ save_dir.mkdir(exist_ok=True, parents=True)
+
+ for j in range(len(sample_data)):
+ if sample_modalities[j] == 0 and not single_image_only:
+ sample_text = wrapped_batch_decode(
+ self.tokenizer,
+ sample_data[j][None],
+ clean_up_tokenization_spaces=True,
+ skip_special_tokens=False,
+ disable_mask_after_eos=True
+ )
+ txt_image = create_text_image(text=sample_text[0], desired_width=self.config.data.resolution)
+ all_sample_imgs.append(txt_image)
+ elif sample_modalities[j] == 1:
+ sample_img = decode_latents(self.config, self.get_vae(), sample_data[j][None])
+ all_sample_imgs.append(sample_img)
+
+ if not disable_img_save:
+ image_save_postfix = image_save_postfix or ""
+ filename = f"img_{get_rank()}_{str(time.time()).replace('.', '__')}"[:100] + f"{image_save_postfix}.png"
+ save_path = save_dir / filename
+ if single_image_only:
+ if not disable_img_save:
+ gprint(Im(all_sample_imgs[0]).save(save_path))
+ assert len(all_sample_imgs) == 1, "Expected single image only"
+ return all_sample_imgs[0]
+ else:
+ img = Im.concat_vertical(*all_sample_imgs).pil
+ if not disable_img_save:
+ gprint(Im(img).save(save_path))
+ return img
+
+
+def get_hpsv2_score(
+ self,
+ images,
+ prompts
+):
+ from unidisc.tokenizers.hpsv2_img_score import score, initialize_model
+ if not hasattr(self, "hpsv2_model_dict"):
+ self.hpsv2_model_dict = initialize_model(self.device, "v2.1")
+
+ if isinstance(images, Tensor):
+ images = [Im(x).pil for x in images]
+
+ with torch.inference_mode(mode=False), torch.no_grad():
+ scores = []
+ for img, prompt in zip(images, prompts):
+ scores.append(score(self.hpsv2_model_dict, img, prompt)[0].item())
+ return torch.tensor(scores)
+
+def get_dfn_score(
+ self,
+ images,
+ prompts
+):
+ if isinstance(images, Tensor):
+ images = [Im(x).pil for x in images]
+
+ from open_clip import create_model_from_pretrained, get_tokenizer
+
+ if not hasattr(self, "dfn_model"):
+ self.dfn_model, self.dfn_preprocess = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384')
+ self.dfn_tokenizer = get_tokenizer('ViT-H-14')
+ self.dfn_model.to(str(self.device))
+
+ assert len(images) == len(prompts), "Expected same number of images and prompts"
+ images = torch.stack([self.dfn_preprocess(x) for x in images])
+ text = self.dfn_tokenizer(prompts, context_length=self.dfn_model.context_length)
+ dfn_dtype = next(iter(self.dfn_model.parameters())).dtype
+
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ image_features = self.dfn_model.encode_image(images.to(device=self.device, dtype=dfn_dtype))
+ text_features = self.dfn_model.encode_text(text.to(device=self.device))
+ image_features = F.normalize(image_features, dim=-1)
+ text_features = F.normalize(text_features, dim=-1)
+ sim = (image_features * text_features).sum(dim=-1)
+
+ return sim
+
+
+def get_clip_score(
+ self,
+ images,
+ prompts
+):
+
+ if isinstance(images, Tensor):
+ images = [Im(x).pil for x in images]
+
+ from transformers import (
+ CLIPTokenizer,
+ CLIPTextModelWithProjection,
+ CLIPVisionModelWithProjection,
+ CLIPImageProcessor,
+ )
+
+ if not hasattr(self, "clip_tokenizer"):
+ clip_id = "openai/clip-vit-large-patch14"
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained(clip_id)
+ self.clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(self.device)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(clip_id)
+ self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(self.device)
+
+ assert len(images) == len(prompts), "Expected same number of images and prompts"
+
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ preprocessed_images = self.clip_image_processor(images, return_tensors="pt")["pixel_values"]
+ image_features = self.clip_image_encoder(pixel_values=preprocessed_images.to(self.device)).image_embeds
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
+
+ tokenized_text = self.clip_tokenizer(
+ prompts,
+ max_length=self.clip_tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt"
+ )
+ text_features = self.clip_text_encoder(input_ids=tokenized_text.input_ids.to(self.device)).text_embeds
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+ sim = (image_features * text_features).sum(dim=-1)
+
+ return sim
+
+def get_laion_aesthetic_score(
+ self,
+ images,
+ prompts
+):
+ from unidisc.tokenizers.laion_aesthetic_v2 import get_predictor_func
+ if not hasattr(self, "laion_aesthetic_model"):
+ self.laion_aesthetic_model = get_predictor_func(self.device)
+
+ return torch.from_numpy(self.laion_aesthetic_model(images)).squeeze(-1)
+
+def get_model_likelihood_score(self, batch, num_timesteps=100, return_unweighed=True):
+ class_log_probs = []
+ unweighed_class_log_probs = []
+ effective_batch_size = batch['modality'].shape[0]
+ empty_device_cache()
+ times = torch.linspace(0, 1, steps=num_timesteps + 2)[1:-1].to(self.device).to(torch.float32)
+ attention_mask = batch['attention_mask']
+
+ for i in range(num_timesteps):
+ empty_device_cache()
+ t = times[i]
+ t = t.expand(effective_batch_size)
+ sigma, dsigma = self.noise(t)
+
+ unet_conditioning = None # sigma[:, None] -> This causes CUDA OOM
+ move_chance = 1 - torch.exp(-sigma[:, None])
+
+ x0 = batch['input_ids']
+ xt = self.q_xt(x0, move_chance)
+
+ model_output = self.forward(
+ xt, unet_conditioning, return_additional_loss=True, batch=batch, modality=batch['modality']
+ )
+
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
+ log_p_theta = torch.where(attention_mask, log_p_theta, 0)
+ std_weighting = (dsigma / torch.expm1(sigma))[:, None]
+ unweighed_log_p_theta = -log_p_theta
+ loss = -log_p_theta * std_weighting
+ log_probs = loss.sum(dim=-1) / attention_mask.sum(dim=-1)
+ unweighed_log_probs = unweighed_log_p_theta.sum(dim=-1) / attention_mask.sum(dim=-1)
+ # print(f'Weighed loss: {log_probs.mean()} | Log P Theta: {-log_p_theta.mean()} | Std Weighting: {std_weighting.mean()}')
+ class_log_probs.append(log_probs)
+ unweighed_class_log_probs.append(unweighed_log_probs)
+
+ overall_time_log_probs = torch.stack(class_log_probs) # (num_time, B)
+ unweighed_overall_time_log_probs = torch.stack(unweighed_class_log_probs) # (num_time, B)
+
+ if return_unweighed:
+ return unweighed_overall_time_log_probs.mean(dim=0) # (B)
+ return overall_time_log_probs.mean(dim=0) # (B)
+
+def get_chameleon_score(self, images, prompts):
+ return torch.tensor(self.calculate_chameleon_perplexity(None, None, prompts, images))
+
+def get_text_likelihood_score(self, images, prompts):
+ return self.compute_generative_perplexity(prompts, return_raw_score=True)
+
+@torch.inference_mode()
+def get_text_reward_model_score(
+ self,
+ images,
+ prompts
+):
+ if not hasattr(self, "text_reward_model"):
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ model_name = "Skywork/Skywork-Reward-Llama-3.1-8B"
+ self.text_reward_model = AutoModelForSequenceClassification.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map=self.device,
+ num_labels=1,
+ )
+ self.text_reward_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ prompt = "Please generate a realistic caption for a text-to-image generator. The caption should have proper grammar and describe a realistic scene that a user might ask for. The caption should not be non-sensical. The caption does not need to be elaborate, but should be descriptive and realistic. Penalize improper grammar and spelling."
+
+ batch_size = 4
+ formatted_conversations = []
+ for resp in prompts:
+ conv = [{"role": "user", "content": prompt}, {"role": "assistant", "content": resp}]
+ formatted = self.text_reward_tokenizer.apply_chat_template(conv, tokenize=False)
+ formatted_conversations.append(formatted)
+
+ all_scores = []
+ for i in range(0, len(formatted_conversations), batch_size):
+ batch_texts = formatted_conversations[i : i + batch_size]
+ batch_inputs = self.text_reward_tokenizer(
+ batch_texts, return_tensors="pt", padding=True, truncation=True
+ ).to(self.device)
+
+ with torch.no_grad():
+ batch_logits = self.text_reward_model(**batch_inputs).logits.squeeze(-1)
+
+ all_scores.extend(batch_logits.cpu().tolist())
+
+ return torch.tensor(all_scores).to(self.device)
+
+
+def get_rewards(self, reward_config, images, prompts, batch=None, return_raw_rewards=False):
+ assert isinstance(images, Tensor) and isinstance(prompts, list), "Expected images to be a Tensor and prompts to be a list"
+ assert images.ndim == 4 and 0 <= images.min() and images.max() <= 1, "Expected images to be in [0, 1]"
+ assert len(prompts) == images.shape[0], "Expected same number of images and prompts"
+ reward_name_to_fn = dict(
+ dfn_score=self.get_dfn_score,
+ clip_score=self.get_clip_score,
+ hpsv2_score=self.get_hpsv2_score,
+ laion_aesthetic_score=self.get_laion_aesthetic_score,
+ model_likelihood_score=self.get_model_likelihood_score,
+ chameleon_score=self.get_chameleon_score,
+ text_likelihood_score=self.get_text_likelihood_score,
+ text_reward_model_score=self.get_text_reward_model_score
+ )
+
+ rewards = []
+ raw_rewards = dict()
+ for reward_name, reward_weight in reward_config.items():
+ start_time = time.time()
+ assert reward_name in reward_name_to_fn, f"Invalid reward name: {reward_name}"
+ reward_fn = reward_name_to_fn[reward_name]
+ if reward_name == "model_likelihood_score" or reward_name == "chameleon_score" or reward_name == "text_likelihood_score":
+ assert batch is not None, "Expected batch to be provided for model likelihood score"
+ if reward_name == "chameleon_score" or reward_name == "text_likelihood_score":
+ reward = reward_fn(images, prompts).cpu()
+ else:
+ reward = reward_fn(batch=batch).cpu()
+ raw_rewards[reward_name] = reward
+ rprint(f"Orig {reward_name}: {reward}")
+ reward = -reward
+ reward = (reward - reward.min()) / (reward.max() - reward.min())
+ rprint(f"Normalized {reward_name}: {reward}")
+ else:
+ reward = reward_fn(images, prompts).cpu()
+ raw_rewards[reward_name] = reward
+ # reward = reward.softmax(dim=-1)
+ reward = (reward - reward.min()) / (reward.max() - reward.min())
+
+ reward = torch.nan_to_num(reward, nan=0.0)
+ rewards.append(reward * reward_weight)
+ print(f"Processed {reward_name} in {time.time() - start_time:.2f} seconds")
+
+ rewards = torch.stack(rewards, dim=-1).sum(dim=-1)
+
+ if return_raw_rewards:
+ return rewards, raw_rewards
+
+ return rewards
+
+def clear_reward_models(self):
+ if hasattr(self, "laion_aesthetic_model"):
+ del self.laion_aesthetic_model
+ if hasattr(self, "dfn_model"):
+ del self.dfn_model
+ if hasattr(self, "dfn_tokenizer"):
+ del self.dfn_tokenizer
+ if hasattr(self, "clip_tokenizer"):
+ del self.clip_tokenizer
+ if hasattr(self, "clip_text_encoder"):
+ del self.clip_text_encoder
+ if hasattr(self, "clip_image_processor"):
+ del self.clip_image_processor
+ if hasattr(self, "clip_image_encoder"):
+ del self.clip_image_encoder
+ if hasattr(self, "text_reward_model"):
+ del self.text_reward_model
+ if hasattr(self, "text_reward_tokenizer"):
+ del self.text_reward_tokenizer
+ if hasattr(self, "hpsv2_model_dict"):
+ del self.hpsv2_model_dict
+
+def auto_enhance(self, batch):
+ gprint(f"Auto enhancing")
+ from dataloader import tokenize_text
+ assert isinstance(batch, TensorDict), "Expected batch to be a TensorDict"
+ batch = batch.squeeze(1)
+ assert batch['input_ids'].ndim == 2, "Expected batch to be 2D"
+
+ # from datasets import load_dataset
+ # dataset = load_dataset("nateraw/parti-prompts", split='train')
+ # dataset = dataset.filter(lambda x: x["Category"] == "Artifacts")
+
+ x0 = batch["input_ids"].clone()
+ add_object = getattr(self.config.eval, "auto_enhance_add_object", False)
+ if add_object:
+ img_tokens = x0[:, self.static_img_sl] - self.text_vocab_size
+ assert 0 <= img_tokens.min() and img_tokens.max() <= self.image_vocab_size, "Expected img tokens to be in [0, img_vocab_size]"
+ orig_imgs = decode_latents(self.config, self.get_vae(), img_tokens)
+ orig_imgs = [Im(img).pil for img in orig_imgs]
+ aug_imgs = [augment_image_with_random_object_coco(img, str(UNIDISC_DIR / "archive" / "objects")) for img in orig_imgs]
+ gprint(f"Augmented {len(aug_imgs)} images")
+ aug_imgs = torch.stack([Im(img).torch for img in aug_imgs]).to(self.device)
+ image_ids = get_image_batch(self.config, self.get_vae(), {"img": aug_imgs}, self.device)
+ x0[:, self.static_img_sl] = image_ids + self.text_vocab_size
+
+ gen_batch = batch.clone()
+ if 'interleaved_metadata' in gen_batch:
+ del gen_batch['interleaved_metadata']
+ gen_batch.auto_batch_size_()
+
+ orig_caption = wrapped_batch_decode(self.tokenizer, batch['input_ids'][:, self.static_txt_sl], clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=True)
+
+ max_num_augmentations = getattr(self.config.eval, "max_num_auto_enhance_augmentations", 10)
+
+ llm_func = get_llm(llm_model_type="")
+ llm_augmented_captions = [llm_func(cap, fake_openai_failure=False)[0] for cap in orig_caption]
+ _augmented_captions = []
+ for caps in llm_augmented_captions:
+ _shuf = deepcopy(caps)
+ random.shuffle(_shuf)
+ assert len(_shuf) >= max_num_augmentations, "Expected at least max_num_augmentations augmentations"
+ _augmented_captions.append(_shuf[:max_num_augmentations])
+ gprint(f"Augmented {len(_augmented_captions)} captions")
+
+ _orig_imgs = Im(decode_latents(self.config, self.get_vae(), x0[:, self.static_img_sl] - self.text_vocab_size)).pil
+ if not isinstance(_orig_imgs, list):
+ _orig_imgs = [_orig_imgs]
+
+ num_iter_per_sample = self.config.eval.num_auto_enhance_iter
+ num_iter = num_iter_per_sample * max_num_augmentations
+ bs = 1
+ n = num_iter * bs * len(_augmented_captions)
+ _gen_batch = []
+ for i in range(len(_augmented_captions)):
+ for j in range(num_iter):
+ _gen_batch.append(gen_batch[[i]])
+ gen_batch = torch.cat(_gen_batch, dim=0)
+
+ txt_data = [tokenize_text(self.tokenizer, self.config.data.block_size, caps) for caps in _augmented_captions]
+ txt_sl = slice(None, self.config.data.block_size)
+ real_captions = []
+ augmented_captions = []
+ orig_images = []
+
+ gprint(f"Generating {num_iter} samples, gen_batch shape: {gen_batch.shape}")
+
+ for j in range(len(_augmented_captions)):
+ for k in range(max_num_augmentations):
+ sl = slice(j * max_num_augmentations + k * num_iter_per_sample, j * max_num_augmentations + (k + 1) * num_iter_per_sample)
+ gen_batch[sl]['input_ids'][:, txt_sl] = txt_data[j]['input_ids'][k]
+ gen_batch[sl]['attention_mask'][:, txt_sl] = txt_data[j]['attention_mask'][k]
+ augmented_captions.extend([_augmented_captions[j][k]] * num_iter_per_sample)
+ real_captions.extend([orig_caption[j]] * num_iter_per_sample)
+ orig_images.extend([_orig_imgs[j]] * num_iter_per_sample)
+
+ # min_val, max_val = 0.94, 0.98
+ # _eps_t = torch.rand(n, device=self.device)
+ # offset = torch.arange(n, device=self.device) / n
+ # _eps_t = (_eps_t / n + offset) % 1
+ # t = (max_val - min_val) * _eps_t + min_val
+
+ if getattr(self.config.eval, "auto_enhance_use_low_masking", False):
+ mean_txt, std_txt = 0.85, 0.2 / 0.8416 # First half
+ mean_img, std_img = 0.75, 0.04 / 1.645 # Second half - higher mean = more masking
+ else:
+ mean_txt, std_txt = 0.85, 0.2 / 0.8416 # First half
+ mean_img, std_img = 0.95, 0.04 / 1.645 # Second half - higher mean = more masking
+
+ def slice_len(_sl, _seq_len):
+ # TODO: This is super incorrect
+ assert _sl.step is None
+ if _sl.start is not None and _sl.start < 0:
+ assert _sl.stop is None
+ return -_sl.start
+ else:
+ return (_sl.stop if _sl.stop is not None else _seq_len) - (_sl.start if _sl.start is not None else 0)
+
+ seq_len = x0.shape[1]
+
+ t = torch.zeros((n,), device=self.device)
+ t = t.to(torch.float32)
+
+ t_txt = torch.normal(mean=mean_txt, std=std_txt, size=(n,), device=self.device)
+ t_img = torch.normal(mean=mean_img, std=std_img, size=(n,), device=self.device)
+
+ t_txt = torch.clamp(t_txt, max=1.0)
+ t_img = torch.clamp(t_img, max=1.0)
+ move_indices = torch.zeros(n, seq_len, device=self.device, dtype=torch.bool)
+
+ move_indices[:, self.static_txt_sl] = torch.rand(move_indices.shape[0], slice_len(self.static_txt_sl, seq_len), device=self.device) < t_txt.unsqueeze(1)
+ move_indices[:, self.static_img_sl] = torch.rand(move_indices.shape[0], slice_len(self.static_img_sl, seq_len), device=self.device) < t_img.unsqueeze(1)
+
+ x0_unmask = ~move_indices
+ rprint(f"Text masking ratio: {move_indices[:, self.static_txt_sl].sum() / move_indices[:, self.static_txt_sl].numel():.3f}")
+ rprint(f"Image masking ratio: {move_indices[:, self.static_img_sl].sum() / move_indices[:, self.static_img_sl].numel():.3f}")
+ rprint(f"Num unmasked: {x0_unmask.sum(dim=-1).float().mean():.1f}")
+
+ text_samples_list = []
+ img_samples_list = []
+
+ x0 = x0.to(self.device)
+ x0_unmask = x0_unmask.to(self.device)
+
+ idx = 0
+ for i in range(len(_augmented_captions)):
+ for j in range(num_iter_per_sample):
+ _modality = gen_batch[[idx]].get("modality", None)
+ _sample_ids = gen_batch[[idx]].get("sample_ids", None)
+ if _modality is not None:
+ _modality = _modality.to(self.device)
+ if _sample_ids is not None:
+ _sample_ids = _sample_ids.to(self.device)
+ else:
+ _sample_ids = torch.zeros_like(_modality)
+ text_samples, img_samples, x = self._sample(
+ text_only=False,
+ num_steps=self.config.sampling.max_sampling_steps,
+ batch_size_per_gpu=bs,
+ modality=_modality,
+ sample_ids=_sample_ids,
+ x0=gen_batch["input_ids"][[idx]].to(self.device),
+ x0_unmask=x0_unmask[[idx]].to(self.device),
+ return_raw_data=True,
+ allow_interleaved_conditional=True
+ )
+ gen_batch[[idx]]['input_ids'] = x
+ text_samples_list.extend(text_samples)
+ img_samples_list.extend(img_samples)
+ rprint(f"Sampled {j + 1} / {num_iter}")
+ idx += 1
+
+ # gen_batch = torch.cat([gen_batch, orig_batch], dim=0)
+ # for i in range(orig_batch.shape[0]):
+ # _modality = orig_batch[[i]].get("modality", None)
+ # _sample_ids = orig_batch[[i]].get("sample_ids", None)
+ # if _modality is not None:
+ # _modality = _modality.to(self.device)
+ # if _sample_ids is not None:
+ # _sample_ids = _sample_ids.to(self.device)
+ # else:
+ # _sample_ids = torch.zeros_like(_modality)
+ # res = self.decode_sampling(
+ # orig_batch[[i]]["input_ids"].to(self.device),
+ # text_only=False,
+ # modality=_modality,
+ # sample_ids=_sample_ids
+ # )
+ # text_samples_list.extend(res[0])
+ # img_samples_list.extend(res[1])
+ # augmented_captions.append(orig_caption[i])
+ # real_captions.append(orig_caption[i])
+ # orig_images.append(orig_imgs[i])
+
+ text_samples_list = wrapped_batch_decode(
+ self.tokenizer,
+ torch.stack(text_samples_list, dim=0),
+ clean_up_tokenization_spaces=True,
+ skip_special_tokens=True,
+ disable_mask_after_eos=True
+ )
+
+ # for i in range(len(text_samples_list) - orig_batch.shape[0], len(text_samples_list)):
+ # text_samples_list[i] = "Original: " + text_samples_list[i]
+
+ img_samples_list = torch.cat(img_samples_list, dim=0)
+
+ reward_config = self.config.eval.auto_enhance_reward_config
+ rewards, raw_rewards = self.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
+
+ gprint(f"Avg Rewards: {rewards}")
+
+ sorted_indices = torch.argsort(rewards, descending=True).tolist()
+ sorted_text_samples = [text_samples_list[i] for i in sorted_indices]
+ sorted_augmented_captions = [augmented_captions[i] for i in sorted_indices]
+ sorted_real_captions = [real_captions[i] for i in sorted_indices]
+ sorted_img_samples = [img_samples_list[i] for i in sorted_indices]
+ sorted_orig_images = [orig_images[i] for i in sorted_indices]
+ sorted_avg_rewards = [rewards[i] for i in sorted_indices]
+ sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards}
+
+ text_samples_list = sorted_text_samples
+ real_captions = sorted_real_captions
+ augmented_captions = sorted_augmented_captions
+ img_samples_list = sorted_img_samples
+ orig_images = sorted_orig_images
+ raw_rewards = sorted_raw_rewards
+
+ # clear all reward models
+ self.clear_reward_models()
+
+ log_dict = {}
+ with try_except(write_error_to_file=True):
+ if text_samples_list is not None:
+ gprint(f"Gathering {len(text_samples_list)} text samples")
+ text_samples_list = gather_object(text_samples_list)
+
+ real_captions = gather_object(real_captions)
+ augmented_captions = gather_object(augmented_captions)
+ prefix = "auto_enhance"
+
+ if isinstance(img_samples_list, Tensor): img_samples_list = img_samples_list.float().cpu()
+ img_samples_list = [Im(img).pil for img in img_samples_list]
+ img_samples_list = gather_object(img_samples_list)
+ orig_images = gather_object(orig_images)
+
+ dprint(f"Gathered {len(text_samples_list)} text samples")
+
+ new_sorted_avg_rewards = gather_object(sorted_avg_rewards)
+ sorted_avg_rewards = new_sorted_avg_rewards
+
+ new_raw_rewards = {k: gather_object(v) for k, v in raw_rewards.items()}
+ raw_rewards = new_raw_rewards
+ rprint(f"Finished gathering, length: {len(orig_images)}")
+
+ gen_table = wandb.Table(columns=[f"real_caption", f"original_image", f"augmented_caption", f"sampled_caption", f"sampled_image", f"avg_reward", *reward_config.keys()])
+ assert len(img_samples_list) == len(text_samples_list) == len(augmented_captions) == len(real_captions) == len(sorted_avg_rewards)
+ for real_caption, orig_img, augmented_caption, sampled_caption, sampled_img, avg_reward, *rewards in zip(real_captions, orig_images, augmented_captions, text_samples_list, img_samples_list, sorted_avg_rewards, *raw_rewards.values()):
+ gen_table.add_data(real_caption, wandb.Image(Im(orig_img).pil), augmented_caption, sampled_caption, wandb.Image(Im(sampled_img).pil), avg_reward, *rewards)
+
+ log_dict[f"{prefix}_sample_table"] = gen_table
+
+ log({**log_dict, **self.get_step_metrics()})
+
+def save_image_text_pair(self, image_tensor, text_tensor, single_image_only=False, disable_img_save=False, image_save_postfix=None):
+ """
+ Take separate image and text tensors and save them as paired visualizations.
+
+ Args:
+ image_tensor: Tensor [B, N] of image tokens
+ text_tensor: Tensor [B, M] of text tokens
+ single_image_only: If True, only return the image without text visualization
+ disable_img_save: If True, don't save to disk
+ image_save_postfix: Optional postfix for the saved image filename
+
+ Returns:
+ PIL Image or tensor of concatenated images and text visualizations
+ """
+ batch_size = image_tensor.shape[0]
+ assert batch_size == text_tensor.shape[0], "Batch sizes must match between image and text tensors"
+
+ all_paired_imgs = []
+
+ # Check config settings for single_image_only
+ if hasattr(self, 'config') and hasattr(self.config, 'eval'):
+ single_image_only = self.config.eval.auto_enhance or single_image_only or getattr(self.config.eval, "fake_interleaved", False)
+
+ if hasattr(self, 'config') and hasattr(self.config.eval, "disable_shm_save"):
+ disable_img_save = disable_img_save or getattr(self.config.eval, "disable_shm_save", False)
+
+ # Create save directory if needed
+ if not disable_img_save:
+ date_folder = datetime.now().strftime("%Y-%m-%d")
+ save_dir = Path("/dev/shm") / os.getenv("USER", 'user') / "paired_imgs" / date_folder
+ save_dir.mkdir(exist_ok=True, parents=True)
+
+ for i in range(batch_size):
+ pair_imgs = []
+
+ # Process text (if not in single_image_only mode)
+ if not single_image_only:
+ sample_text = wrapped_batch_decode(
+ self.tokenizer,
+ text_tensor[i:i+1],
+ clean_up_tokenization_spaces=True,
+ skip_special_tokens=False,
+ disable_mask_after_eos=True
+ )
+ txt_image = create_text_image(text=sample_text[0], desired_width=self.config.data.resolution)
+ pair_imgs.append(txt_image)
+
+ # Process image
+ img_tokens = image_tensor[i:i+1]
+ sample_img = decode_latents(self.config, self.get_vae(), img_tokens)
+ pair_imgs.append(sample_img)
+
+ # Combine text and image for this pair
+ if single_image_only:
+ all_paired_imgs.append(pair_imgs[0])
+ else:
+ paired_img = Im.concat_vertical(*pair_imgs).pil
+ all_paired_imgs.append(paired_img)
+
+ # Save images if needed
+ if not disable_img_save:
+ image_save_postfix = image_save_postfix or ""
+ for i, img in enumerate(all_paired_imgs):
+ filename = f"pair_{get_rank()}_{i}_{str(time.time()).replace('.', '__')}"[:100] + f"{image_save_postfix}.png"
+ save_path = save_dir / filename
+ gprint(Im(img).save(save_path))
+
+ # Return either a single image or all as list
+ if batch_size == 1:
+ return all_paired_imgs[0]
+ else:
+ return all_paired_imgs
\ No newline at end of file
diff --git a/model_setup.py b/model_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..33109c617f7be9689f7145bca64cdc9cb6e5d97c
--- /dev/null
+++ b/model_setup.py
@@ -0,0 +1,1054 @@
+import functools
+import itertools
+import os
+import signal
+import subprocess
+import sys
+import time
+import typing
+from functools import partial
+from pathlib import Path
+from types import FrameType
+from contextlib import nullcontext
+
+import transformers
+from constants import HF_TOKEN, HF_CACHE_DIR
+import hydra
+import hydra.utils
+import torch
+import torch.utils.checkpoint
+from accelerate.utils import gather, gather_object
+from omegaconf import open_dict, read_write
+from safetensors.torch import load_file
+
+import models.noise_schedule as noise_schedule
+import utils
+import wandb
+from decoupled_utils import (barrier, dprint, get_slurm_job_id, get_world_size, gprint,
+ is_local_main_process, is_main_process,
+ is_torch_cuda_available, is_torch_xla_available,
+ module_hash, parameter_hash, print_memory,
+ rank_zero_fn, rprint, save_memory_profile,
+ show_memory_usage, try_except, use_dist)
+from unidisc.tokenizers.image_tokenizers import get_vae as tokenizer_get_vae
+from unidisc.utils.xla_utils import (tpu_spmd_dataloader, wrap_xla_fsdp)
+from model_utils import BPD, NLL, Perplexity, empty_device_cache, log, CIDErScore, Accuracy
+from unidisc.utils.trainer_utils import (TrainingState, check_every_n_epochs,
+ check_every_n_steps, handle_checkpointing_dirs, count_parameters)
+from utils import compile_model, grad_norm
+
+is_xla_available = is_torch_xla_available()
+if is_xla_available:
+ rprint("Using standalone torchmetrics on XLA")
+ from unidisc.utils.standalone_metrics import MetricCollection
+else:
+ from torchmetrics import MetricCollection
+
+def init(self, config, tokenizer, device):
+ import models
+ import models.elm_custom as elm_custom
+
+ self.global_step = 0
+ self.current_run_global_step = 0
+ self.current_run_fwd_bwd_pass = 0
+ self.num_evals = 0
+
+ self.config = config
+ self.device = device
+ self.image_model = False
+ self.unified_model = False
+
+ self.dtype = (
+ torch.float32
+ if ("fp32" in self.config.trainer.precision or "no" in self.config.trainer.precision)
+ else (torch.bfloat16 if "bf16" in self.config.trainer.precision else torch.float16)
+ )
+ rprint(f"Set compute dtype in model: {self.dtype}")
+
+ if getattr(self.config.model, "image_model", False):
+ self.image_model = True
+ if "tokens" not in self.config.data.train:
+ self.vae = self.get_vae()
+ if self.config.data.cond_resolution is not None:
+ self.cond_vae = self.get_cond_vae()
+ else:
+ self.vae = None
+ self.cond_vae = None
+
+ if getattr(self.config.model, "unified_model", False):
+ self.unified_model = True
+
+ self.tokenizer = tokenizer
+ self.sampler = self.config.sampling.predictor
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
+ self.antithetic_sampling = self.config.trainer.antithetic_sampling
+ self.importance_sampling = self.config.trainer.importance_sampling
+ self.change_of_variables = self.config.trainer.change_of_variables
+ if getattr(self.config.trainer, "add_label", False):
+ assert self.image_model and self.unified_model
+
+ if self.image_model is False or self.unified_model:
+ self.vocab_size = len(self.tokenizer)
+ if getattr(self.config.model, "force_text_vocab_size", None) is not None:
+ self.vocab_size = self.config.model.force_text_vocab_size
+ if not hasattr(self.tokenizer, "mask_token") or self.tokenizer.mask_token is None:
+ self.mask_index = self.vocab_size
+ self.vocab_size += 1
+ else:
+ self.mask_index = self.tokenizer.mask_token_id
+ if self.image_model:
+ if self.unified_model:
+ self.text_vocab_size = self.vocab_size
+ self.vocab_size += self.config.model.image_vocab_size
+ self.image_vocab_size = self.config.model.image_vocab_size
+ if getattr(self.config.model, "add_labels", None) is not None:
+ rprint(f"Adding labels: {self.config.model.add_labels}")
+ self.vocab_size += self.config.model.add_labels
+ rprint(f"Text vocab size: {self.text_vocab_size}, Image vocab size: {self.image_vocab_size}")
+ else:
+ self.vocab_size = self.config.model.image_vocab_size + 1
+ self.mask_index = self.vocab_size - 1
+ self.text_vocab_size = 0
+ else:
+ self.text_vocab_size = self.vocab_size
+
+ rprint(f"Vocab size: {self.vocab_size}, Mask index: {self.mask_index}")
+ rprint(f"Image Model: {self.image_model}, Unified Model: {self.unified_model}")
+ self.parameterization = self.config.parameterization
+
+ tf_kwargs = dict(device_map=self.device, use_auth_token=HF_TOKEN, torch_dtype=self.dtype if (self.config.model.use_lora or self.config.trainer.low_precision_params) else torch.float32, trust_remote_code=True, cache_dir=HF_CACHE_DIR)
+ tf_kwargs['attn_implementation'] = 'sdpa' if is_xla_available else 'flash_attention_2'
+ force_sdpa_attention = os.environ.get("UNIDISC_FORCE_CHAMELEON_SDPA_ATTENTION", "0") == "1"
+ force_eager_attention = os.environ.get("UNIDISC_FORCE_EAGER_ATTENTION", "0") == "1"
+ if force_sdpa_attention:
+ tf_kwargs['attn_implementation'] = 'sdpa'
+ rprint("WARNING!!!! Forcing SDPA Attention")
+ if force_eager_attention:
+ tf_kwargs['attn_implementation'] = 'eager'
+ rprint("WARNING!!!! Forcing Eager Attention")
+
+ if is_xla_available:
+ del tf_kwargs['cache_dir']
+ rprint(f"Using cache dir: {HF_CACHE_DIR}")
+
+ if self.config.backbone == "dit":
+ dit_kwargs = dict(mask_index=self.mask_index)
+ if getattr(self.config.trainer, "use_orig_unidisc_dit", False):
+ from accelerate.utils import set_seed; set_seed(42)
+ if self.config.model.full_attention:
+ import models.dit_orig
+ _backbone_cls = models.dit_orig.DIT
+ rprint("WARNING!!!! Using original DIT")
+ dit_kwargs.pop('mask_index')
+ else:
+ import models.autoregressive_orig
+ _backbone_cls = models.autoregressive_orig.AR
+ dit_kwargs['causal'] = not self.config.model.full_attention
+ rprint(f"WARNING!!!! Using original AR DIT, {dit_kwargs}")
+ else:
+ import models.dit
+ _backbone_cls = models.dit.DIT
+ dit_kwargs['text_vocab_size'] = self.text_vocab_size
+ dit_kwargs['autocast_dtype'] = self.dtype
+ dit_kwargs['device'] = self.device
+ dit_kwargs['static_img_sl'] = self.static_img_sl
+ dit_kwargs['static_txt_sl'] = self.static_txt_sl
+
+ self.backbone = _backbone_cls(
+ config=self.config,
+ vocab_size=self.vocab_size,
+ **dit_kwargs
+ )
+ utils.print_trainable_parameters(self.backbone)
+ if self.config.model.mup:
+ self.get_base_shapes_for_mup(self.backbone)
+ elif self.config.backbone == "elm":
+ del tf_kwargs['attn_implementation']
+ config = transformers.AutoConfig.from_pretrained(self.config.model.model_id, **tf_kwargs)
+ config.extra_tokens = self.vocab_size - config.vocab_size
+ config.full_attention = self.config.model.full_attention
+ config.is_compiled = self.is_compiled
+ _cls = elm_custom.OpenELMForCausalLM if self.config.trainer.scratch else partial(elm_custom.OpenELMForCausalLM.from_pretrained, pretrained_model_name_or_path=self.config.model.model_id)
+ self.backbone = _cls(
+ config=config,
+ )
+ if self.config.model.use_lora:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=16,
+ lora_alpha=32,
+ target_modules=["qkv_proj"],
+ lora_dropout=0.05,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ self.backbone = get_peft_model(self.backbone, lora_config)
+ self.backbone.model.transformer.token_embeddings_extra.requires_grad_(True)
+ if hasattr(self.backbone.model, "lm_extra"):
+ self.backbone.model.lm_extra.requires_grad_(True)
+ else:
+ self.backbone.requires_grad_(True)
+ self.backbone.train()
+ rprint("Using Full ELM")
+
+ if getattr(self.config.trainer, "scratch", False):
+ rprint("Training from scratch")
+ self.backbone.apply(self.backbone._init_weights)
+
+ if getattr(self.config.trainer, "use_gradient_checkpointing", False):
+ self.backbone.gradient_checkpointing_enable()
+ utils.print_trainable_parameters(self.backbone)
+ elif self.config.backbone == "ar":
+ self.backbone = models.autoregressive.AR(self.config, vocab_size=self.vocab_size, mask_index=self.mask_index)
+ else:
+ raise ValueError(f"Unknown backbone: {self.config.backbone}")
+
+ self.T = self.config.T
+ self.subs_masking = self.config.subs_masking
+ self.softplus = torch.nn.Softplus()
+ if getattr(self.config.trainer, "disable_torchmetrics", False) is False:
+ # metrics are automatically reset at end of epoch
+ metrics = MetricCollection(
+ {
+ "nll": NLL(sync_on_compute=False),
+ "bpd": BPD(sync_on_compute=False),
+ "ppl": Perplexity(sync_on_compute=False),
+ },
+ compute_groups=(not is_torch_xla_available() and not getattr(self.config.trainer, "disable_distributed_torchmetrics", False))
+ )
+ metrics.set_dtype(torch.float64)
+ self.train_metrics = metrics.clone(prefix="train/")
+ self.valid_metrics = metrics.clone(prefix="val/")
+ self.test_metrics = metrics.clone(prefix="test/")
+
+ if getattr(self.config.trainer, "log_seperate_modal_losses", False):
+ self.txt_metrics = metrics.clone(prefix="train/")
+ self.img_metrics = metrics.clone(prefix="train/")
+
+ if getattr(self.config.eval, "compute_chameleon_perplexity", False) or getattr(self.config.eval, "wino_chameleon", False):
+ rprint("[INFO] Loading Big Chameleon Model")
+ # pip install 'git+ssh://git@github.com/alexanderswerdlow/image_utils.git@wip_v1' --force-reinstall
+ from image_utils import Im
+ from transformers import (ChameleonForConditionalGeneration, ChameleonProcessor)
+ self.chameleon_model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
+ self.chameleon_processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
+
+ if self.config.mode == "zero-shot-eval":
+ # flickr cider
+ self.cider_score = CIDErScore(sync_on_compute=False)
+
+ # winoground
+ self.win_text_accuracy = Accuracy(sync_on_compute=False)
+ self.win_image_accuracy = Accuracy(sync_on_compute=False)
+ self.win_group_accuracy = Accuracy(sync_on_compute=False)
+
+
+ self.datacomp_img_acc = Accuracy(sync_on_compute=False)
+ self.datacomp_txt_acc = Accuracy(sync_on_compute=False)
+
+ self.eval_model_tokenizer = transformers.AutoTokenizer.from_pretrained(self.gen_ppl_eval_model_name_or_path)
+ if self.eval_model_tokenizer.pad_token is None:
+ self.eval_model_tokenizer.pad_token = self.eval_model_tokenizer.eos_token
+ self.eval_model_tokenizer.pad_token_id = self.eval_model_tokenizer.eos_token_id
+
+ self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
+ if self.config.trainer.ema > 0:
+ if self.config.trainer.use_custom_ema:
+ from copy import deepcopy
+ self.ema = deepcopy(self.backbone).eval()
+ self.ema.to(self.device)
+ else:
+ self.ema = models.ema.EMAModel(self.get_params(), decay=self.config.trainer.ema)
+ rprint(f"Using EMA with decay {self.config.trainer.ema}")
+ else:
+ self.ema = None
+
+ self.lr = self.config.optim.lr
+ self.sampling_eps = self.config.trainer.sampling_eps
+ self.time_conditioning = self.config.time_conditioning
+ self.neg_infinity = -1000000.0
+ self.fast_forward_epochs = None
+ self.fast_forward_batches = None
+ self._validate_configuration()
+
+ self.fid_eval = False
+
+ if ((self.config.slurm or self.config.trainer.restart_on_failure) and not self.config.trainer.force_disable_signal_handler) and self.config.mode == 'train':
+ self.register_signal_handler()
+
+ if getattr(self.config.model, "image_model_fid_eval", False) or getattr(self.config.trainer, "disable_strict_load", False):
+ self.strict_loading = False
+
+ if self.config.backbone != 'dit' and self.config.backbone != 'chameleon':
+ assert self.config.model.force_argmax_valid_indices is False
+
+ if self.config.parameterization == "ar":
+ assert self.config.trainer.ar_shift
+
+ self.trainable_params = sum(p.numel() for p in self.backbone.parameters() if p.requires_grad)
+ self.frozen_params = sum(p.numel() for p in self.backbone.parameters() if not p.requires_grad)
+ self.non_embedding_params = count_parameters(self.backbone)
+ rprint(f"Total trainable parameters (excluding embeddings): {self.non_embedding_params:,}, Total trainable parameters: {self.trainable_params:,}, Total frozen parameters: {self.frozen_params:,}")
+ self._validate_configuration()
+
+ if not self.config.trainer.low_precision_params:
+ for name, param in self.backbone.named_parameters():
+ if param.requires_grad and param.dtype != torch.float32:
+ raise ValueError(f"Parameter {name} is not in fp32. It is in {param.dtype}")
+
+ if self.config.eval.test_eval_speed:
+ rprint("WARNING!!!! Running eval speed test")
+
+ self.use_kv_cache = getattr(self.config.model, "use_kv_cache", False)
+ if not getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
+ assert self.config.mode == 'eval' # shouldn't really be on in train mode
+ rprint(f"WARNING!!!! Disabling gen pplx cleanup, having eval model {self.gen_ppl_eval_model_name_or_path} in memory always!!!!")
+ self.gen_pplx_eval_model = transformers.AutoModelForCausalLM.from_pretrained(self.gen_ppl_eval_model_name_or_path).eval()
+
+ if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
+ self.mauve_predictions = []
+ self.mauve_references = []
+
+ if self.config.mode == "zero-shot-eval":
+ self.cider_score_metric = CiderScorer()
+
+ if self.config.mode == "eval":
+ self.backbone.eval()
+ self.backbone.requires_grad_(False)
+
+ if self.config.trainer.awr:
+ breakpoint()
+ config = transformers.AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-135M", **tf_kwargs)
+ config.vocab_size = self.vocab_size
+ config.full_attention = True
+ self.awr_policy = llama_custom.LlamaForCausalLM(
+ config=config,
+ )
+
+
+def to(self, device):
+ self.device = device
+ self.backbone.to(device)
+ self.train_metrics.to(device)
+ self.test_metrics.to(device)
+ if hasattr(self, "txt_metrics"):
+ self.txt_metrics.to(device)
+ if hasattr(self, "img_metrics"):
+ self.img_metrics.to(device)
+
+ if self.ema is not None:
+ self.ema.to(device)
+
+def reset_validation_metrics(self):
+ metrics = MetricCollection(
+ {
+ "nll": NLL(sync_on_compute=False),
+ "bpd": BPD(sync_on_compute=False),
+ "ppl": Perplexity(sync_on_compute=False),
+ },
+ compute_groups=(not is_torch_xla_available() and not getattr(self.config.trainer, "disable_distributed_torchmetrics", False))
+ )
+ metrics.set_dtype(torch.float64)
+
+ if getattr(self.config.trainer, "disable_torchmetrics", False) is False or hasattr(self, "valid_metrics"):
+ self.valid_metrics = metrics.clone(prefix="val/").to(self.device)
+
+ if getattr(self.config.trainer, "log_seperate_modal_losses", False):
+ self.valid_txt_metrics = metrics.clone(prefix="val/").to(self.device)
+ self.valid_img_metrics = metrics.clone(prefix="val/").to(self.device)
+
+ self.gen_ppl_metric = Perplexity(sync_on_compute=False).to(self.device)
+ self.gt_gen_ppl_metric = Perplexity(sync_on_compute=False).to(self.device)
+
+def get_params(self):
+ return itertools.chain(self.backbone.parameters())
+
+def get_vae(self):
+ if getattr(self, "vae", None) is not None:
+ return self.vae
+
+ empty_device_cache()
+
+ self.vae = tokenizer_get_vae(self.config, self.device)
+
+ return self.vae
+
+def get_cond_vae(self):
+ if getattr(self, "cond_vae", None) is not None:
+ return self.cond_vae
+
+ torch.cuda.empty_cache()
+ self.cond_vae = get_vae(self.config, self.device, use_cond=True)
+ return self.cond_vae
+
+
+def configure_optimizers(self):
+ # TODO(yair): Lightning currently giving this warning when using `fp16`:
+ # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
+ # Not clear if this is a problem or not.
+ # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
+ kwargs = dict(
+ betas=(self.config.optim.beta1, self.config.optim.beta2),
+ eps=self.config.optim.eps,
+ weight_decay=self.config.optim.weight_decay,
+ )
+ if getattr(self.config.trainer, "adafactor", False):
+ optim_cls = Adafactor
+ kwargs = dict()
+ kwargs.update({"scale_parameter": False, "relative_step": False})
+ rprint("Using Adafactor")
+ if getattr(self.config.trainer, "ademamix", False):
+ from unidisc.utils.ademamix import AdEMAMix
+ optim_cls = AdEMAMix
+ rprint("Using AdEMAMix")
+ elif is_xla_available:
+ from torch_xla.amp.syncfree import AdamW
+ optim_cls = AdamW
+ rprint("Using XLA AdamW")
+ elif getattr(self.config.trainer, "is_deepspeed", False):
+ import deepspeed
+ optim_cls = deepspeed.ops.adam.FusedAdam
+ kwargs["set_grad_none"] = True
+ else:
+ optim_cls = torch.optim.AdamW
+ kwargs["fused"] = self.config.optim.fused
+
+ if self.config.model.mup:
+ from mup import MuAdam
+ optim_cls = partial(MuAdam, impl=optim_cls)
+
+ optimizer = optim_cls(
+ self.get_params(),
+ lr=self.config.optim.lr,
+ **kwargs,
+ )
+
+ scheduler = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=optimizer)
+ scheduler_dict = {
+ "scheduler": scheduler,
+ "interval": "step",
+ "monitor": "val/loss",
+ "name": "trainer/lr",
+ }
+ return [optimizer], [scheduler_dict]
+
+def _validate_configuration(self):
+ assert not (self.change_of_variables and self.importance_sampling)
+ if self.parameterization == "sedd":
+ assert not self.importance_sampling
+ assert not self.change_of_variables
+ if self.parameterization == "d3pm":
+ assert self.T > 0
+ if self.T > 0:
+ assert self.parameterization in {"d3pm", "subs"}
+ if self.subs_masking:
+ assert self.parameterization == "d3pm"
+
+ if hasattr(self.config.model, "text_vocab_size"):
+ assert self.config.model.text_vocab_size == self.text_vocab_size, f"text_vocab_size {self.config.model.text_vocab_size} != {self.text_vocab_size}"
+
+ if getattr(self.config.trainer, "first_token_dropout", None) is not None:
+ assert self.config.data.allow_label is True
+ assert self.config.trainer.add_label is True
+ assert self.config.model.add_labels > 0
+ assert self.config.trainer.joint_ar_nar_prob is None
+ assert self.config.trainer.mask_entire_modality is None
+
+ if getattr(self.config.eval, "class_conditional_fid", False):
+ assert self.config.eval.fid_mode == "inline"
+
+ assert getattr(self.config.model, "mask_entire_modality", None) is None
+
+ if self.config.trainer.interleaved and not getattr(self.config.eval, "auto_enhance", False) and not getattr(self.config.trainer, "bypass_interleaved_check", False):
+ assert self.config.data.use_packing_collate or self.config.mode == 'eval'
+ assert self.config.data.dynamic_packing_lengths
+ assert self.config.data.require_sample_ids
+ assert self.config.trainer.interleaved_training_flex_attention
+ assert self.config.data.use_slow_tokenizer and self.config.data.add_image_token
+ assert not getattr(self.config.trainer, "force_full_attention_mask_loss_only", False)
+
+ assert self.config.sampling.steps == self.config.sampling.max_sampling_steps
+
+def register_signal_handler(self):
+ def _handler(sig, frame: FrameType | None, prior_handler=None):
+ rprint(f"Called sig handler with {sig=} {self.global_step=}")
+ if sig == signal.SIGUSR1:
+ signal.signal(sig, signal.SIG_IGN)
+
+ checkpoint_path = Path(self.config.output_dir) / "checkpoints"
+ timeout_minutes = self.config.trainer.ckpt_recent_timeout_minutes
+
+ # Don't re-save checkpoint within this interval to avoid unecessary re-writing.
+ # If we checkpoint on SIGUSR2, we don't need to do it on SIGTERM
+ recent_ckpt_exists = checkpoint_path.exists() and any(
+ (time.time() - p.stat().st_mtime) < (timeout_minutes * 60) for p in checkpoint_path.iterdir() if p.is_dir()
+ )
+ if (self.current_run_global_step > 100 and recent_ckpt_exists is False) or self.config.trainer.skip_early_checkpointing is False:
+ rprint(f"Saving checkpoint due to {sig}")
+ self.checkpoint()
+ rprint(f"Finished saving checkpoint due to {sig}")
+ else:
+ rprint(f"Checkpoint already saved within {timeout_minutes} minutes, called by {sig}. Current run global step: {self.current_run_global_step}")
+
+ job_str = get_slurm_job_id()
+ if is_main_process():
+ if sig == signal.SIGTERM:
+ if self.current_run_global_step > 100 and self.config.devices >= 4:
+ wandb.alert(title="Terminated", text=f"Terminated by SIGTERM at {self.global_step}")
+ rprint("Marking experiment as preempting")
+ wandb.mark_preempting()
+
+ rprint(f"Prior handler on rank: {prior_handler}")
+ is_custom_sbatch_launcher = os.environ.get("CUSTOM_SBATCH_LAUNCHER", "0") == "1"
+ if is_custom_sbatch_launcher:
+ rprint("Using custom sbatch launcher, requeueing job manually")
+ subprocess.check_call(["scontrol", "requeue", job_str])
+ rprint("Finished requeueing job")
+ elif prior_handler is not None and callable(prior_handler):
+ rprint("Calling prior signal handler")
+ prior_handler(sig, frame, exit_on_requeue=False)
+ rprint(f"Returned from prior signal handler")
+ else:
+ # TODO: For some unknown reason, sometimes the main process [and a few others] hangs doesn't properly receive the signal.
+ # Generally, we want to let the main process checkpoint/exit but if it fails, we let any rank re-queue.
+ if self.config.slurm:
+ time.sleep(180)
+ rprint(f"WARNING: Not on rank zero! Timed out waiting for main process to exit...Requeuing job...")
+ rprint(f"WARNING: Not on rank zero! Using prior signal handler: {prior_handler}. ")
+ else:
+ time.sleep(5)
+
+ try:
+ if prior_handler is not None and callable(prior_handler):
+ rprint("WARNING: Not on rank zero! Returning to prior handler")
+ prior_handler(sig, frame, exit_on_requeue=False)
+ rprint(f"WARNING: Not on rank zero! Returned from prior handler")
+ except:
+ rprint(f"WARNING: Not on rank zero! Failed to return to prior handler")
+
+ if self.config.slurm:
+ time.sleep(5) # Should be enough time for SLURM to send a SIGTERM to all ranks. If not, we resort to manual requeueing.
+ rprint(f"WARNING: Not on rank zero! Failed to requeue using prior handler, requeuing job ourselves... {job_str}")
+ subprocess.check_call(["scontrol", "requeue", job_str])
+ rprint(f"WARNING: Not on rank zero! Requeued job: {job_str}")
+
+ if self.config.slurm:
+ if torch.distributed.is_initialized():
+ rprint(f"Destroying process group...")
+ torch.distributed.destroy_process_group()
+ return sys.exit(0)
+ else:
+ rprint(f"Not on SLURM, not exiting")
+
+ prior_sigterm_handler = signal.getsignal(signal.SIGTERM)
+ prior_sigusr1_handler = signal.getsignal(signal.SIGUSR1)
+ prior_sigusr2_handler = signal.getsignal(signal.SIGUSR2)
+
+ rprint(f"Found Prior SIGTERM handler: {prior_sigterm_handler}, type: {type(prior_sigterm_handler)}")
+ rprint(f"Found Prior SIGUSR1 handler: {prior_sigusr1_handler}, type: {type(prior_sigusr1_handler)}")
+ rprint(f"Found Prior SIGUSR2 handler: {prior_sigusr2_handler}, type: {type(prior_sigusr2_handler)}")
+
+ signal.signal(signal.SIGTERM, functools.partial(_handler, prior_handler=prior_sigterm_handler))
+ signal.signal(signal.SIGUSR2, functools.partial(_handler, prior_handler=prior_sigusr2_handler))
+ signal.signal(signal.SIGUSR1, functools.partial(_handler, prior_handler=prior_sigusr1_handler))
+
+def on_train_start(self):
+ gprint(f"Starting train at step: {self.global_step}")
+
+ if is_main_process() and getattr(self.config.trainer, "compile", None) is None and getattr(self.config.trainer, "watch_gradients", True):
+ wandb.watch(
+ self.backbone,
+ log=("all" if getattr(self.config.trainer, "watch_all", False) else "gradients"),
+ log_freq=getattr(self.config.trainer, "watch_gradients_freq", 500),
+ )
+
+ if getattr(self.config.trainer, "attach_oom_observer_train", False):
+ from torchtnt.utils.oom import attach_oom_observer
+ attach_oom_observer(output_dir=str(self.config.output_dir), trace_max_entries=500000)
+ gprint(f"Attached OOM observer to {self.config.output_dir}")
+
+ if self.config.trainer.nvtx_profile and self.is_compiled is False:
+ torch.cuda.cudart().cudaProfilerStart()
+
+ # TODO: Make sure we don't need the code below with the new accelerate code.
+ return
+
+def optimizer_step(self, *args, **kwargs):
+ super().optimizer_step(*args, **kwargs)
+ if self.ema is not None:
+ self.ema.update(self.get_params())
+
+def init_dataloader(self, train_dataloader, val_dataloader):
+ rprint("Creating train_dataset + self.train_dataloader")
+ self.train_dataloader = train_dataloader
+ self.validation_dataloader = val_dataloader
+ if not self.config.data.iterable and not self.config.data.webdataset_indexed: assert len(self.validation_dataloader) > 0
+
+def init_optimizer_lr_scheduler(self):
+ [optimizer], [scheduler_dict] = self.configure_optimizers()
+ self.optimizer = optimizer
+ self.lr_scheduler = scheduler_dict["scheduler"]
+
+def set_accelerator(self, accelerator, ckpt_path=None):
+ if ckpt_path is not None:
+ rprint(f"Set accelerator with ckpt path {ckpt_path}")
+
+ self.accelerator = accelerator
+ self.device = accelerator.device
+ self.dtype = getattr(torch, self.config.trainer.dtype.split(".")[-1])
+
+ def _load(obj, path, update_fn=None, key="model"):
+ _ckpt_path = Path(path)
+ if _ckpt_path.is_dir() and (_ckpt_path / "model.safetensors").exists():
+ _ckpt_path = _ckpt_path / "model.safetensors"
+ path = str(_ckpt_path)
+
+ print(f"Loading from {_ckpt_path}, {_ckpt_path.suffix}, {_ckpt_path.is_dir()}")
+ if _ckpt_path.suffix == ".safetensors":
+ state_dict = load_file(path)
+ elif _ckpt_path.is_dir():
+ if getattr(self.config.trainer, 'dynamic_convert_to_normal_state_dict', False):
+ gprint(f"Converting distributed checkpoint to normal state dict")
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
+ import hashlib
+ ckpt_hash = hashlib.md5(str(path).encode()).hexdigest()[:8] + "_" + Path(path).stem
+ new_path = str(Path("/dev/shm") / os.getenv("USER", "aswerdlo") / f"tmp_ckpt_{ckpt_hash}.pth")
+ dcp_to_torch_save(path, new_path)
+ gprint(f"Converted distributed checkpoint to normal state dict at {new_path}")
+ state_dict = torch.load(new_path)
+ gprint(f"Loaded state dict from {path}")
+ else:
+ gprint(f"Loading from distributed checkpoint directory {path}")
+ import torch.distributed.checkpoint as dcp
+ state_dict = {
+ key: obj.state_dict(),
+ }
+ if getattr(self.config.trainer, 'ignore_chameleon_embed', False):
+ for k in list(state_dict[key].keys()):
+ if "embed_tokens" in k:
+ state_dict[key].pop(k)
+ gprint(f"Ignoring {k}")
+ dcp.load(
+ state_dict=state_dict,
+ checkpoint_id=path,
+ )
+ gprint(f"Loaded state dict from {path}")
+ # obj.load_state_dict(state_dict[key])
+ else:
+ state_dict = torch.load(path)
+
+ if 'model' in state_dict and len(state_dict) < 10:
+ state_dict = state_dict['model']
+
+ state_dict = {k.replace("_orig_module.", ""): v for k, v in state_dict.items()}
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ if self.config.backbone == 'llama' and "lm_head.weight" in state_dict and "model.embed_tokens.weight" not in state_dict:
+ # LLaMa ties weights
+ state_dict["model.embed_tokens.weight"] = state_dict["lm_head.weight"].clone()
+
+ if update_fn is not None:
+ state_dict = update_fn(state_dict)
+ elif getattr(self.config.trainer, 'use_orig_unidisc_dit', False):
+ # loading from the original .ckpt files from unidisc repo
+ state_dict = state_dict['state_dict']
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+
+ try:
+ kwargs = {}
+ kwargs['strict'] = self.config.trainer.disable_strict_load
+ if '.bin' in str(path):
+ kwargs = {}
+ obj.load_state_dict(state_dict, **kwargs)
+ except Exception as e:
+ rprint(f"Failed to load state dict: {e}")
+ rprint(f"State dict keys: {state_dict.keys()}")
+ rprint(f"Model state dict keys: {obj.state_dict().keys()}")
+ raise e
+
+ if self.config.mode != 'eval':
+ self.init_optimizer_lr_scheduler()
+
+ if getattr(self.config.trainer, "bypass_load_from_state_dicts_if_resuming", False) and ckpt_path is not None:
+ rprint(f"Skipping load from state dicts since we are resuming from: {ckpt_path}")
+ else:
+ if self.config.trainer.load_from_state_dict is not None:
+ rprint(f"Loading model state dict from {self.config.trainer.load_from_state_dict}")
+ _load(self.backbone, self.config.trainer.load_from_state_dict)
+ rprint(f"Loaded model state dict from {self.config.trainer.load_from_state_dict}")
+
+ if getattr(self.config.trainer, "load_from_optimizer_state_dict", None) is not None:
+ # TODO: Optimizer.bin from accelerate is the wrong format here. Look into this. The keys/are different and need to be mapped.
+ def update_param_group(state_dict):
+ rprint(f"len(self.optimizer.param_groups): {len(self.optimizer.param_groups[0]['params'])}, len(state_dict['param_groups']): {len(state_dict['param_groups'][0]['params'])}")
+ rprint(f"self.optimizer.param_groups: {self.optimizer.param_groups[0]['params']}")
+ rprint(f"state_dict['param_groups']: {state_dict['param_groups'][0]['params']}")
+ state_dict["param_groups"] = self.optimizer.param_groups
+ return state_dict
+
+ _load(self.optimizer, self.config.trainer.load_from_optimizer_state_dict, update_fn=update_param_group, key="optim")
+ rprint(f"Loaded optimizer state dict from {self.config.trainer.load_from_optimizer_state_dict}")
+
+ if self.config.mode == 'eval':
+ rprint(f"Moving model to {self.device}")
+
+ self.backbone.to(self.device)
+ if getattr(self.config.trainer, 'force_bf16_eval', False) and self.config.mode == 'eval':
+ self.backbone.to(torch.bfloat16)
+
+ # Model needs to be wrapped before optimizer is created for fsdp
+ if self.config.trainer.xla_spmd and is_xla_available:
+ self.backbone = wrap_xla_fsdp(self.config, self.backbone)
+
+ self.backbone, self.ema = self.accelerator.prepare(self.backbone, self.ema)
+
+ if self.config.trainer.compile and not is_xla_available:
+ rprint("Compiling entire model...")
+ self.backbone = compile_model(self.config, self.backbone)
+
+ if getattr(self.config.trainer, 'mup_coord_plot', False):
+ self.get_coord_plot()
+
+ if self.config.mode == 'eval':
+ return
+
+ if not self.config.data.iterable and not self.config.data.webdataset_indexed and self.train_dataloader is not None and self.config.data.wrap_dataloaders:
+ rprint(f"Before prepare: Train len: {len(self.train_dataloader)}, Validation len: {len(self.validation_dataloader)}")
+
+ if getattr(self.config.eval, 'test_eval_speed', False):
+ self.optimizer, self.lr_scheduler = None, None
+ else:
+ if getattr(self.config.trainer, 'force_disable_wrap_optimizer', False) is False and self.config.mode != 'eval':
+ self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+ self.optimizer, self.lr_scheduler
+ )
+ elif self.config.mode != 'eval':
+ rprint("WARNING: Not wrapping optimizer with accelerator.prepare()")
+
+ if self.config.data.webdataset_iterable is False and self.config.data.wrap_dataloaders:
+ self.train_dataloader, self.validation_dataloader = self.accelerator.prepare(self.train_dataloader, self.validation_dataloader)
+ else:
+ rprint("WARNING: Not wrapping dataloaders with accelerator.prepare()")
+
+ if is_xla_available and self.config.trainer.fsdp:
+ self.train_dataloader = tpu_spmd_dataloader(self.train_dataloader, self.device)
+ self.validation_dataloader = tpu_spmd_dataloader(self.validation_dataloader, self.device)
+
+ if not self.config.data.iterable and not self.config.data.webdataset_indexed and self.train_dataloader is not None:
+ rprint(f"After prepare: Train len: {len(self.train_dataloader)}, Validation len: {len(self.validation_dataloader)}")
+
+ if (self.config.trainer.use_spmd_distributed_checkpointing or self.config.trainer.use_simple_spmd_distributed_checkpointing) and is_xla_available:
+ gprint("Initializing distributed process group")
+ import torch.distributed as dist
+ import torch_xla.distributed.xla_backend
+ import torch_xla.runtime as xr
+ dist.init_process_group('gloo', init_method='xla://')
+ gprint("Distributed process group initialized, before creating checkpoint manager")
+
+ if (self.config.trainer.use_spmd_distributed_checkpointing and self.config.trainer.disable_all_checkpointing is False) and is_xla_available:
+ gprint("Initializing checkpoint manager")
+ from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
+ self.chkpt_mgr = CheckpointManager(self.config.checkpointing.save_dir, self.config.trainer.ckpt_steps)
+ gprint(f"Checkpoint manager created")
+
+ if getattr(self.config.trainer, "force_from_ckpt", None) is not None:
+ ckpt_path = getattr(self.config.trainer, "force_from_ckpt")
+ if ckpt_path == "":
+ ckpt_path = None
+
+ if ckpt_path is not None and Path(ckpt_path).exists():
+ rprint(f"Loading checkpoint {ckpt_path}")
+ if self.config.trainer.use_spmd_distributed_checkpointing and self.config.trainer.disable_all_checkpointing is False:
+ gprint("Loading checkpoint for XLA")
+ from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
+ tracked_steps = self.chkpt_mgr.all_steps()
+ if tracked_steps:
+ rprint(f"Found tracked steps: {tracked_steps}")
+ best_step = max(tracked_steps) # Choose the highest step
+ prime_optimizer(self.optimizer) # Before restoring the checkpoint, the optimizer state must be primed to allow state to be loaded into it.
+ state_dict = {'model': self.accelerator.unwrap_model(self.backbone).state_dict(), 'optim': self.optimizer.state_dict()}
+ self.chkpt_mgr.restore(best_step, state_dict)
+ self.backbone.load_state_dict(state_dict['model'])
+ self.optimizer.load_state_dict(state_dict['optim'])
+ else:
+ import os
+ folder_contents = os.listdir(ckpt_path)
+ gprint(f"Contents of the folder {ckpt_path}: {folder_contents}")
+ self.accelerator.load_state(ckpt_path, strict=self.config.trainer.disable_strict_load is False)
+
+ elif ckpt_path is not None:
+ rprint(f"WARNING: Checkpoint {ckpt_path} does not exist")
+
+ if getattr(self.config.trainer, "reset_lr_scheduler_step", False):
+ with open_dict(self.config):
+ with read_write(self.config):
+ rprint(f"Resetting lr scheduler")
+ if getattr(self.config.trainer, "global_num_warmup_steps", None) is not None:
+ self.config.lr_scheduler.num_warmup_steps = self.config.trainer.global_num_warmup_steps
+ rprint(f"Set num_warmup_steps to {self.config.lr_scheduler.num_warmup_steps}")
+
+ if getattr(self.config.trainer, "global_num_training_steps", None) is not None:
+ self.config.lr_scheduler.num_training_steps = self.config.trainer.global_num_training_steps
+ rprint(f"Set num_training_steps to {self.config.lr_scheduler.num_training_steps}")
+ if not self.config.trainer.disable_adjust_num_warmup_steps:
+ _world_size = 1 if (is_xla_available and self.config.trainer.xla_spmd) else self.world_size
+ rprint(f"Warmup steps was {self.config.lr_scheduler.num_warmup_steps}")
+ self.config.lr_scheduler.num_warmup_steps = self.config.lr_scheduler.num_warmup_steps * _world_size
+ rprint(f"Warmup steps is now {self.config.lr_scheduler.num_warmup_steps}, world size is {_world_size}")
+
+ if hasattr(self.config.lr_scheduler, "num_training_steps"):
+ rprint(f"num_training_steps was: {self.config.lr_scheduler.num_training_steps}. Applying to num_training_steps")
+ self.config.lr_scheduler.num_training_steps = self.config.trainer.global_num_training_steps * _world_size
+
+ rprint(f"Set num_warmup_steps to {self.config.lr_scheduler.num_warmup_steps}")
+ if getattr(self.config.trainer, "global_num_training_steps", None) is not None:
+ rprint(f"Set num_training_steps to {self.config.lr_scheduler.num_training_steps}")
+ self.lr_scheduler.scheduler = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=self.lr_scheduler.scheduler.optimizer)
+ rprint(self.lr_scheduler.scheduler.__dict__)
+ rprint(self.lr_scheduler.scheduler.state_dict())
+ rprint("WARNING!!! Resetting lr scheduler")
+ elif getattr(self.config.trainer, "force_reset_optimizer_lr_scheduler", False):
+ self.init_optimizer_lr_scheduler()
+ self.lr_scheduler, self.optimizer = self.accelerator.prepare(self.lr_scheduler, self.optimizer)
+
+def set_callbacks(self):
+ from torchtnt.framework._callback_handler import CallbackHandler
+
+ from unidisc.utils.throughput_monitor import ThroughputMonitor
+
+ precomputed_flops_per_sample = {}
+ _flops_per_sample = precomputed_flops_per_sample.get(self.config.model.name, 0)
+ if _flops_per_sample == 0 or self.config.backbone != 'dit':
+ # Assume approx 6ND for decoder transformer model
+ _flops_per_sample = 6 * self.config.model.length * self.non_embedding_params
+
+ if self.config.trainer.xla_spmd and is_xla_available:
+ _flops_per_sample /= self.world_size
+
+ callbacks = []
+ callbacks.append(
+ ThroughputMonitor(
+ batch_size_fn=None,
+ length_fn=None,
+ log_every_n_steps=50,
+ window_size=2,
+ separator="_",
+ world_size=1 if self.config.trainer.xla_spmd else self.world_size,
+ device=self.device,
+ dtype=self.dtype,
+ flops_per_sample=_flops_per_sample
+ )
+ )
+
+ self.cb_handler = CallbackHandler(callbacks)
+
+@try_except(write_error_to_file=True)
+def checkpoint(self, state: TrainingState = None):
+ if is_torch_xla_available():
+ gprint("Saving checkpoint on XLA...")
+
+ self.on_train_resume() # In case we start checkpointing in the middle of validation
+
+ checkpoint_all_ranks = self.config.trainer.checkpoint_all_ranks
+ if (not is_main_process()) and checkpoint_all_ranks is False:
+ return
+
+ if self.current_run_global_step < 200 and self.config.trainer.skip_early_checkpointing:
+ rprint("Skipping checkpointing for the first 200 steps...")
+ return
+
+ if self.config.trainer.disable_all_checkpointing:
+ rprint("Disabled all checkpointing...")
+ return
+
+ start_time = time.time()
+ if self.config.trainer.use_simple_spmd_distributed_checkpointing and is_xla_available:
+ import torch.distributed.checkpoint as dist_cp
+ import torch_xla.experimental.distributed_checkpoint as xc
+ gprint("Saving checkpoint...0")
+ import torch_xla.core.xla_model as xm
+ xm.mark_step()
+ gprint("Saving checkpoint...1")
+ xm.wait_device_ops()
+ gprint("Saving checkpoint...2")
+ CHECKPOINT_DIR = Path(self.config.checkpointing.save_dir) / f"checkpoint_{self.global_step}"
+ gprint("Saving checkpoint...4")
+
+ if is_main_process():
+ gprint(f"Clearing old checkpoints")
+ handle_checkpointing_dirs(self.config, prefix="checkpoint")
+ gprint(f"Finished clearing old checkpoints")
+
+ state_dict = {
+ "model": self.backbone.state_dict(),
+ }
+ if not self.config.trainer.ckpt_model_only:
+ gprint("Saving optimizer state dict")
+ state_dict["optim"] = self.optimizer.state_dict()
+
+ gprint(f"Saving checkpoint...5 to {CHECKPOINT_DIR}")
+ dist_cp.save(
+ state_dict=state_dict,
+ storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
+ planner=xc.SPMDSavePlanner(),
+ )
+ if is_main_process():
+ from main import save_config_to_ckpt
+ save_config_to_ckpt(self.config, CHECKPOINT_DIR, self)
+ gprint("Saving checkpoint...6")
+ elif self.config.checkpointing.use_automatic_naming:
+ rprint("Saving checkpoint...")
+ self.accelerator.save_state()
+ rprint("Saved checkpoint...")
+ else:
+ rprint(f"Saving checkpoint...")
+ prefix = "checkpoint"
+ Path(self.config.checkpointing.save_dir).mkdir(exist_ok=True, parents=True)
+
+ if is_main_process():
+ handle_checkpointing_dirs(self.config, prefix="checkpoint")
+
+ save_path = Path(self.config.checkpointing.save_dir) / f"{prefix}_{self.global_step}"
+ save_path.mkdir(exist_ok=True, parents=True)
+
+ if checkpoint_all_ranks:
+ barrier()
+
+ if self.config.trainer.ckpt_model_only:
+ from safetensors.torch import save_file, save_model
+ try:
+ self.accelerator.save_model(self.backbone, save_path)
+ except Exception as e:
+ rprint(f"Failed to save model with 'save_file': {e}")
+ if getattr(self.config.trainer, 'finetuning_mode', False):
+ rprint("Fallback to 'save_model' instead")
+ if is_main_process():
+ save_model(self.backbone, save_path / "model.safetensors")
+ else:
+ try:
+ self.accelerator.save_state(save_path)
+ except Exception as e:
+ from traceback import print_exc
+ print_exc()
+ gprint(f"Failed to save state: {e}, saving model instead")
+ self.accelerator.save_model(self.backbone, save_path)
+ gprint("Saved model instead")
+
+ if checkpoint_all_ranks:
+ barrier()
+
+ rprint(f"Saved checkpoint to: {save_path}")
+ with try_except(write_error_to_file=True, clear_cuda_cache=True):
+ self.print_hashes()
+
+ rprint(f"Checkpointing took: {time.time() - start_time} seconds")
+
+def print_hashes(self):
+ if self.config.trainer.fsdp:
+ rprint('Skipping module hash for FSDP')
+ return
+
+ rprint(f"Module hash: {module_hash(self.backbone)}")
+ if self.ema is not None:
+ if self.config.trainer.use_custom_ema:
+ rprint(f"EMA hash: {module_hash(self.ema)}")
+ else:
+ rprint(f"EMA hash: {parameter_hash(self.ema.state_dict()['shadow_params'])}")
+
+@try_except(write_error_to_file=True)
+def on_train_step_end(self, state: TrainingState):
+ self.cb_handler.on_train_step_end(state=state, unit=self)
+ del state.batch
+ tr = self.config.trainer
+ if check_every_n_steps(
+ state, tr.val_check_interval, run_first=tr.eval_on_start, all_processes=True, decay_steps=tr.eval_decay_steps
+ ) or check_every_n_epochs(state, tr.eval_epochs, all_processes=True):
+ rprint(f"Starting validation at {state.global_step}...")
+ with show_memory_usage():
+ with try_except(write_error_to_file=True, clear_cuda_cache=True):
+ with nullcontext() if is_xla_available else (torch.no_grad() if getattr(self.config.trainer, "force_disable_inference_mode", False) else torch.inference_mode()):
+ self.validate(state)
+ self.on_validation_epoch_cleanup()
+ self.num_evals += 1
+ self.on_train_resume()
+ dprint("All processes finished validation")
+
+ xla_spmd = self.config.trainer.use_spmd_distributed_checkpointing
+ if xla_spmd and self.config.trainer.disable_all_checkpointing is False and self.global_step > 10:
+ # Call every step, but only runs after n steps internally
+ gprint("Might save async checkpoint...")
+ if getattr(self.config.checkpointing, "save_optimizer_state", True):
+ state_dict = {'model': self.backbone.state_dict(), 'optim': self.optimizer.state_dict()}
+ else:
+ gprint("[WARNING] Not saving optimizer state")
+ state_dict = {'model': self.backbone.state_dict()}
+ if self.chkpt_mgr.save_async(self.global_step, state_dict):
+ gprint(f'Checkpoint taken at step {self.global_step}')
+
+ current_time = time.time()
+ if not hasattr(self, "last_checkpoint_time"):
+ self.last_checkpoint_time = current_time
+
+ checkpoint_due_to_time = (current_time - self.last_checkpoint_time) >= (tr.ckpt_every_n_minutes * 60)
+ checkpoint_due_to_step = check_every_n_steps(state, tr.ckpt_steps, run_first=False, all_processes=True)
+
+ if is_torch_cuda_available() and tr.ckpt_every_n_minutes > 0:
+ should_ckpt_all_ranks = gather_object([checkpoint_due_to_time or checkpoint_due_to_step])
+ else:
+ should_ckpt_all_ranks = [checkpoint_due_to_step]
+
+ if should_ckpt_all_ranks[0] and not xla_spmd: # To avoid timing inconsistencies, we take the value from the main process
+ rprint(f"Saving checkpoint at {self.global_step}...due to {'time' if checkpoint_due_to_time else 'step'}. Ranks thought: {should_ckpt_all_ranks}")
+ self.last_checkpoint_time = current_time
+ self.checkpoint(state)
+ rprint(f"Checkpoint saved at {self.global_step}...")
+
+def after_backward(self, state):
+ freq = getattr(self.config.trainer, "log_grad_norm_every_n_steps", 200 if self.is_compiled else 50)
+ if not is_xla_available and self.config.trainer.log_grad_norm and check_every_n_steps(state, freq, run_first=True, all_processes=False):
+ norms, total_norm = grad_norm(self.backbone, norm_type=2, group_separator="")
+ grad_norm_dict = {f"grad_norms/{k}": v for k, v in norms.items()}
+ if 'text-diffusion' in self.config.wandb.project:
+ grad_norm_dict = {k.replace("module.", ""): v for k, v in grad_norm_dict.items()}
+ log({**grad_norm_dict, "trainer/total_grad_norm": total_norm, "trainer/global_step": self.global_step})
+
+from model_utils import Loss
+def shortcut_return(self, logprobs, output_tokens, attention_mask, prefix): # For comparing to unidisc only
+ loss = -logprobs.gather( -1, output_tokens[:, :, None])[:, :, 0]
+ nlls = loss * attention_mask
+ count = attention_mask.sum()
+
+ batch_nll = nlls.sum()
+ token_nll = batch_nll / count
+
+ losses = Loss(
+ loss=token_nll,
+ img_loss=0,
+ txt_loss=0,
+ nlls=nlls,
+ txt_nlls=0,
+ img_nlls=0,
+ token_mask=attention_mask,
+ modality_mask=None,
+ extra_losses=None,
+ )
+
+ if getattr(self.config.trainer, "disable_torchmetrics", False):
+ raise NotImplementedError("Torchmetrics disabled")
+
+ elif prefix == "train":
+ return losses
+ elif prefix == "val":
+ self.valid_metrics.update(losses.nlls, losses.token_mask)
+ elif prefix == "test":
+ self.test_metrics.update(losses.nlls, losses.token_mask)
+ metrics = self.test_metrics
+ self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
+ else:
+ raise ValueError(f"Invalid prefix: {prefix}")
+
+def unwrap_model(self, model):
+ from diffusers.utils.torch_utils import is_compiled_module
+ model = self.accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
\ No newline at end of file
diff --git a/model_utils.py b/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd2623a26cecb3cd2409540e051c8807978cae65
--- /dev/null
+++ b/model_utils.py
@@ -0,0 +1,844 @@
+import math
+import os
+import random
+import typing
+from contextlib import nullcontext
+from dataclasses import dataclass
+from pathlib import Path
+from types import FrameType
+from typing import Dict, List, Optional, Tuple, Union
+
+import einops
+import hydra
+import hydra.utils
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchmetrics
+import transformers
+from image_utils import Im
+from torch import Tensor, nn
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from tqdm.auto import tqdm
+
+import models
+import wandb
+from decoupled_utils import (Profiler, barrier, dprint, get_rank,
+ get_slurm_job_id, get_world_size, gprint,
+ is_local_main_process, is_main_process,
+ is_torch_cuda_available, is_torch_xla_available,
+ module_hash, mprint, parameter_hash, print_memory,
+ rank_zero_fn, rprint, save_memory_profile,
+ show_memory_usage, try_except, use_dist)
+
+is_xla_available = is_torch_xla_available()
+if is_xla_available:
+ from unidisc.utils.standalone_metrics import MeanMetric, MetricCollection
+else:
+ from torchmetrics import MetricCollection
+ from torchmetrics.aggregation import MeanMetric
+
+LOG2 = math.log(2)
+
+@try_except(write_error_to_file=True)
+def log(*arg, **kwargs):
+ for key, value in arg[0].items():
+ if isinstance(value, torch.Tensor):
+ arg[0][key] = value.detach().cpu().float()
+
+ if is_main_process():
+ wandb.log(*arg, **kwargs)
+
+def replace_nan_dict(x):
+ return {k: v.nan_to_num(0) for k, v in x.items()}
+
+def ddprint(*args, **kwargs):
+ mprint(*args, **kwargs)
+
+def empty_device_cache():
+ if is_torch_cuda_available():
+ torch.cuda.empty_cache()
+ else:
+ dprint("Not using cuda, skipping cache clear")
+
+def update_logs(_logs, _extra_logs):
+ _logs.update(_extra_logs())
+ for k, v in _logs.items():
+ if isinstance(v, torch.Tensor):
+ _logs[k] = v.detach().cpu().item()
+ gprint(f"Converting {k} to item: {v}")
+
+ log(_logs)
+
+def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
+ param_dict_src = dict(model_src.named_parameters())
+ for p_name, p_dest in model_dest.named_parameters():
+ if p_name not in param_dict_src:
+ print(f"Parameter {p_name} not found in src: {param_dict_src}")
+ p_src = param_dict_src[p_name]
+ assert p_src is not p_dest
+ p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
+
+def identity(x):
+ return x
+
+def remap_image_torch(image):
+ image_torch = image * 255
+ image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
+ return image_torch
+
+def _sample_categorical(categorical_probs):
+ gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
+
+def wrapped_batch_decode(tokenizer, tokens, disable_mask_after_eos=False, **kwargs):
+ tokens = tokens.clone()
+ if (tokenizer.bos_token_id != tokenizer.eos_token_id) and not disable_mask_after_eos:
+ after_first_five = torch.cumsum(tokens == tokenizer.eos_token_id, dim=1).bool()
+ tokens[after_first_five.cumsum(dim=1) > 1] = tokenizer.pad_token_id
+ return tokenizer.batch_decode(tokens, **kwargs)
+
+def _unsqueeze(x, reference):
+ return x.view(*x.shape, *((1,) * (len(reference.shape) - len(x.shape))))
+
+
+@dataclass
+class Loss:
+ loss: torch.FloatTensor
+ img_loss: torch.FloatTensor = None
+ txt_loss: torch.FloatTensor = None
+ nlls: torch.FloatTensor = None
+ token_mask: torch.FloatTensor = None
+ txt_nlls: torch.FloatTensor = None
+ img_nlls: torch.FloatTensor = None
+ extra_losses: dict = None
+ modality_mask: torch.FloatTensor = None
+
+
+class NLL(MeanMetric):
+ pass
+
+
+class BPD(NLL):
+ def compute(self) -> Tensor:
+ """Computes the bits per dimension.
+
+ Returns:
+ bpd
+ """
+ return self.mean_value / self.weight / LOG2
+
+
+class Perplexity(NLL):
+ def compute(self) -> Tensor:
+ """Computes the Perplexity.
+
+ Returns:
+ Perplexity
+ """
+ return torch.exp(self.mean_value / self.weight)
+
+class Entropy(NLL):
+ def compute(self) -> Tensor:
+ """Computes the Entropy.
+
+ Returns:
+ Entropy
+ """
+ return self.mean_value / self.weight
+
+class MauveScore(NLL):
+ def compute(self) -> Tensor:
+ """Computes the Mauve Score.
+
+ Returns:
+ Mauve Score
+ """
+ return self.mean_value / self.weight
+
+class CIDErScore(NLL):
+ def compute(self) -> Tensor:
+ """Computes the CIDEr Score.
+
+ Returns:
+ CIDEr Score
+ """
+ return self.mean_value / self.weight
+
+class Accuracy(NLL):
+ def compute(self) -> Tensor:
+ """Computes the Accuracy.
+
+ Returns:
+ Accuracy
+ """
+ return self.mean_value / self.weight
+
+def get_coord_plot(self):
+ from mup.coord_check import get_coord_data, plot_coord_data
+ def gen(w):
+ def f():
+ from copy import deepcopy
+
+ from omegaconf import read_write
+
+ import models as _models
+ _conf = deepcopy(self.config)
+ with read_write(_conf):
+ _conf.model.hidden_size = _conf.model.n_heads * w
+
+ _backbone = _models.dit.DIT(
+ _conf, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
+ )
+ self.get_base_shapes_for_mup(_backbone)
+ return _backbone
+ return f
+
+ optimizer = 'adamw'
+ widths = np.array([2**i for i in range(2, 6)])
+ models = {int(w) * self.config.model.n_heads: gen(int(w)) for w in widths}
+
+ fake_dataloader = []
+ self.validation_dataloader.num_workers = 0
+ nsteps = 30
+ for i, dataloader_batch in enumerate(self.validation_dataloader):
+ fake_batch = self.update_batch(dataloader_batch)
+ fake_batch['x0'] = fake_batch["input_ids"]
+ t = self._sample_t(fake_batch['x0'].shape[0], fake_batch['x0'].device)
+ sigma, dsigma = self.noise(t)
+ move_chance = 1 - torch.exp(-sigma[:, None])
+ xt = self.q_xt(fake_batch['x0'], move_chance)
+ fake_batch['xt'] = xt
+
+ fake_dataloader.append(fake_batch)
+ if i >= nsteps:
+ break
+
+ def loss_fn(_batch, _logits):
+ attention_mask = _batch['attention_mask']
+ model_output = self._subs_parameterization(logits=_logits, xt=_batch['xt'])
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=_batch['x0'][:, :, None]).squeeze(-1)
+ std_weighting = (dsigma / torch.expm1(sigma))[:, None]
+ loss = -log_p_theta * std_weighting
+ loss = (loss * attention_mask).sum() / attention_mask.sum()
+ return loss
+
+ mup = True
+ lr = 1e-2
+ prm = 'μP' if mup else 'SP'
+ nseeds = 2
+ with torch.autocast(device_type=self.device.type, dtype=self.dtype):
+ df = get_coord_data(
+ models,
+ fake_dataloader,
+ lr=lr,
+ optimizer=optimizer,
+ nsteps=nsteps,
+ nseeds=nseeds,
+ dict_in_out=True,
+ lossfn=loss_fn,
+ mup=mup,
+ )
+
+ output_path = Path(__file__).parent / 'output' / f'{prm.lower()}_trsfmr_{optimizer}_coord.png'
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ plot_coord_data(
+ df,
+ legend='brief',
+ save_to=str(output_path.resolve()),
+ suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
+ face_color='xkcd:light grey' if not mup else None,
+ loglog=True
+ )
+ rprint(f"Saved coord plot to {output_path.resolve()}")
+
+ csv_path = output_path.with_suffix('.csv')
+ df.to_csv(csv_path, index=False)
+ rprint(f"DataFrame saved as CSV to {csv_path.resolve()}")
+
+ result = df[df['t'] == 1].nsmallest(100, 'l1').sort_values('l1', ascending=True)
+ with pd.option_context('display.max_rows', None, 'display.max_columns', None):
+ print(result[['module', 'width', 'l1']])
+ exit()
+
+def _score_entropy(self, log_score, sigma, xt, x0):
+ """Computes the SEDD loss.
+
+ Args:
+ log_score: float torch.Tensor with shape (batch_size,
+ diffusion_model_input_length, vocab_size),
+ log score, output of the denoising network.
+ xt: int torch.Tensor with shape (batch_size,
+ diffusion_model_input_length), input.
+ x0: int torch.Tensor with shape (batch_size,
+ diffusion_model_input_length), input.
+ sigma: float torch.Tensor with shape (batch_size, 1).
+
+ Returns:
+ loss with shape (batch_size, diffusion_model_input_length)
+ """
+ masked_indices = xt == self.mask_index
+
+ expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
+ q_ratio = 1 / expsig_minus_1[masked_indices]
+
+ words_that_were_masked = x0[masked_indices]
+
+ neg_term = q_ratio * torch.gather(log_score[masked_indices], -1, words_that_were_masked[..., None]).squeeze(-1)
+ score = log_score[masked_indices].exp()
+ if self.mask_index == self.vocab_size - 1:
+ pos_term = score[:, :-1].sum(dim=-1)
+ else:
+ pos_term = score[:, : self.mask_index].sum(dim=-1) + score[:, self.mask_index + 1 :].sum(dim=-1)
+ const = q_ratio * (q_ratio.log() - 1)
+
+ entropy = torch.zeros(*xt.shape, device=xt.device)
+ entropy[masked_indices] += pos_term - neg_term + const
+ return entropy
+
+@torch.no_grad
+def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
+ ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
+
+ num_steps = int(1 / dt)
+ sampling_steps = 0
+ intermediate_tokens = []
+ target = None
+ for _ in range(num_strides + 1):
+ p_x0_cache = None
+ x = self._sample_prior(n_samples, self.config.model.length).to(self.device)
+ if target is not None:
+ x[:, :-stride_length] = target
+ for i in range(num_steps + 1):
+ p_x0_cache, x_next, nfe_cnt = self._ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
+ if not torch.allclose(x_next, x) or self.time_conditioning:
+ p_x0_cache = None
+ sampling_steps += 1
+ x = x_next
+ x = self.forward(x, 0 * ones).argmax(dim=-1)
+ intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
+ target = x[:, stride_length:]
+
+ intermediate_tokens.append(target.cpu().numpy())
+ intermediate_text_samples = []
+ sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:] == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
+ for i in range(2, len(intermediate_tokens) + 1):
+ intermediate_text_samples.append(self.tokenizer.batch_decode(np.concatenate(intermediate_tokens[:i], axis=1)))
+ return (sampling_steps, intermediate_text_samples, sequence_lengths)
+
+def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
+ """Generate samples from the model."""
+ # Lightning auto-casting is not working in this method for some reason
+ if self.ema:
+ self.ema.store(self.get_params())
+ self.ema.copy_to(self.get_params())
+ self.backbone.eval()
+ (sampling_steps, samples, sequence_lengths) = self.sample_subs_guidance(
+ n_samples=self.config.loader.eval_batch_size, stride_length=stride_length, num_strides=num_strides, dt=dt
+ )
+ if self.ema:
+ self.ema.restore(self.get_params())
+ self.backbone.train()
+ self.noise.train()
+ return sampling_steps, samples, sequence_lengths
+
+def _reconstruction_loss(self, x0):
+ t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device)
+ assert self.config.noise.type == "loglinear"
+ # The above assert is for d3pm parameterization
+ unet_conditioning = self.noise(t0)[0][:, None]
+ model_output_t0 = self.forward(x0, unet_conditioning)
+ return -torch.gather(input=model_output_t0, dim=-1, index=x0[:, :, None]).squeeze(-1)
+
+def restore_model_and_sample(self, num_steps, eps=1e-5):
+ """Generate samples from the model."""
+ # Lightning auto-casting is not working in this method for some reason
+ if self.ema is not None:
+ self.ema.store(self.get_params())
+ self.ema.copy_to(self.get_params())
+ self.backbone.eval()
+ samples = self._sample(num_steps=num_steps, eps=eps)
+ if self.ema is not None:
+ self.ema.restore(self.get_params())
+ self.backbone.train()
+ return samples
+
+def get_score(self, x, sigma, **kwargs):
+ model_output = self.forward(x, sigma, **kwargs)
+ if self.parameterization == "subs":
+ # score(x, t) = p_t(y) / p_t(x)
+ # => log score(x, t) = log p_t(y) - log p_t(x)
+
+ # case 1: x = masked
+ # (i) y = unmasked
+ # log score(x, t) = log p_\theta(x)|_y + log k
+ # where k = exp(- sigma) / (1 - exp(- sigma))
+ # (ii) y = masked
+ # log score(x, t) = 0
+
+ # case 2: x = unmasked
+ # (i) y != masked, y != x
+ # log score(x_i, t) = - inf
+ # (ii) y = x
+ # log score(x_i, t) = 0
+ # (iii) y = masked token
+ # log score(x_i, t) = - log k
+ # where k = exp(- sigma) / (1 - exp(- sigma))
+
+ log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
+ assert log_k.ndim == 1
+
+ masked_score = model_output + log_k[:, None, None]
+ masked_score[:, :, self.mask_index] = 0
+
+ unmasked_score = self.neg_infinity * torch.ones_like(model_output)
+ unmasked_score = torch.scatter(unmasked_score, -1, x[..., None], torch.zeros_like(unmasked_score[..., :1]))
+ unmasked_score[:, :, self.mask_index] = -(log_k[:, None] * torch.ones_like(x))
+
+ masked_indices = (x == self.mask_index).to(model_output.dtype)[:, :, None]
+ model_output = masked_score * masked_indices + unmasked_score * (1 - masked_indices)
+ return model_output.exp()
+
+def _staggered_score(self, score, dsigma):
+ score = score.clone()
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
+ score *= dsigma.exp()[:, None]
+ score[..., self.mask_index] += extra_const
+ return score
+
+def _analytic_update(self, x, t, step_size):
+ curr_sigma, _ = self.noise(t)
+ next_sigma, _ = self.noise(t - step_size)
+ dsigma = curr_sigma - next_sigma
+ nfe_cnt = 0
+ score = self.get_score(x, curr_sigma)
+ nfe_cnt += 1
+ stag_score = self._staggered_score(score, dsigma)
+ probs = stag_score * self._transp_transition(x, dsigma)
+ return _sample_categorical(probs), nfe_cnt
+
+def _denoiser_update(self, x, t):
+ sigma, _ = self.noise(t)
+ score = self.get_score(x, sigma)
+ stag_score = self._staggered_score(score, sigma)
+ probs = stag_score * self._transp_transition(x, sigma)
+ probs[..., self.mask_index] = 0
+ samples = _sample_categorical(probs)
+ return samples
+
+def _transp_transition(self, i, sigma):
+ sigma = _unsqueeze(sigma, reference=i[..., None])
+ edge = torch.exp(-sigma) * F.one_hot(i, num_classes=self.vocab_size)
+ edge += torch.where(i == self.mask_index, 1 - torch.exp(-sigma).squeeze(-1), 0)[..., None]
+ return edge
+
+@torch.no_grad()
+def eval_retokenize(self, text_samples, max_length):
+ """Retokenizes samples for the eval model.
+
+ Args:
+ text_samples: List of sentences generated by the model.
+ Returns:
+ samples: Samples re-tokenized for the eval model
+ attn_mask: Attention mask for the eval model
+ eval_context_size: Size of the context for the eval model
+ """
+ if "llama2" in self.gen_ppl_eval_model_name_or_path:
+ tokenizer_kwargs = {
+ "text_samples": text_samples,
+ "return_tensors": "pt",
+ "return_token_type_ids": False,
+ "return_attention_mask": True,
+ "truncation": True,
+ "padding": True,
+ "max_length": max_length,
+ }
+ eval_context_size = 4096
+ else:
+ tokenizer_kwargs = {
+ "return_tensors": "pt",
+ "return_token_type_ids": False,
+ "return_attention_mask": True,
+ "truncation": True,
+ "padding": True,
+ "max_length": max_length,
+ }
+ eval_context_size = 1024
+
+ if getattr(self.config.eval, "force_eval_context_size_match_model", False):
+ eval_context_size = self.config.model.txt_length
+
+ samples = self.eval_model_tokenizer(text_samples, **tokenizer_kwargs)
+ attn_mask = samples["attention_mask"]
+ samples = samples["input_ids"]
+ if "llama2" not in self.gen_ppl_eval_model_name_or_path:
+ attn_mask = attn_mask.to(self.device)
+ samples = samples.to(self.device)
+ return samples, attn_mask, eval_context_size
+
+
+@try_except(write_error_to_file=True)
+@torch.no_grad()
+def compute_cider(self, text_samples, gt_text_samples):
+ """Compute the CIDEr score for the generated text.
+ Args:
+ text_samples: List of sentences generated by the model.
+ gt_text_samples: List of ground truth sentences.
+ Returns:
+ CIDEr score for the generated text.
+ """
+ for text_sample, gt_text_sample in zip(text_samples, gt_text_samples):
+ self.cider_score_metric += (text_sample, gt_text_sample)
+ score = self.cider_score_metric.compute_cider() # list of np.float64
+ avg_score = sum(score) / len(score)
+ self.cider_score.update(avg_score.item()) # weight=len(text_samples))
+
+
+def get_anole_data(model, processor, prompt, image, device):
+ inputs = processor(prompt, [image], padding=True, return_tensors="pt").to(device=device, dtype=dtype)
+ image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
+ special_image_mask = inputs["input_ids"] == model.model.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(inputs["input_ids"].device, inputs["input_ids"].dtype)
+ inputs["input_ids"] = inputs["input_ids"].masked_scatter(special_image_mask, image_tokens)
+ inputs.pop("pixel_values")
+ inputs['input_ids'] = torch.load('save.pth').to(device)
+ return inputs
+
+@try_except(write_error_to_file=True)
+@torch.inference_mode()
+def compute_generative_perplexity(self, text_samples: typing.List[str], retokenize: bool = True, max_length: typing.Optional[int] = None, gt: bool = False, return_raw_score: bool = False) -> None:
+ """Compute the generative perplexity of the model.
+
+ Args:
+ text_samples: List of sentences generated by the model.
+ retokenize: Whether to retokenize using eval model's tokenizer
+ max_length: Maximum sequence length for tokenization
+ gt: Whether these are ground truth samples
+ return_raw_score: Whether to return raw NLL scores instead of updating metrics
+
+ Returns:
+ If return_raw_score is True, returns tensor of NLL scores.
+ Otherwise updates internal perplexity metrics.
+ """
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ if not getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
+ eval_model = self.gen_pplx_eval_model
+ elif getattr(self.config.eval, 'gen_ppl_use_chameleon', False):
+ from transformers import (ChameleonForConditionalGeneration,
+ ChameleonProcessor)
+ model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
+ processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
+ image = Im(Im("https://cdn.outsideonline.com/wp-content/uploads/2023/03/Funny_Dog_H.jpg").np[50:-150, 550:-900, :]).resize(256, 256).pil
+ prompt = "A picture of a cat."
+ device = "cuda:0"
+ inputs = get_anole_data(model, processor, prompt, image, self.dtype, device)
+ output = model(input_ids=inputs['input_ids'].to(device))
+ attention_mask = torch.ones_like(inputs["input_ids"])
+ logits = output.logits
+ logits = logits.transpose(-1, -2)
+ sample_chunk = inputs["input_ids"]
+ nlls = F.cross_entropy(logits[..., :-1].to(device), sample_chunk[..., 1:].to(device), reduction="none")
+ nlls = nlls * attention_mask[..., 1:].to(nlls.dtype)
+ nlls = nlls.sum(-1) / attention_mask[..., 1:].sum(-1)
+ print(torch.exp(nlls))
+ else:
+ eval_model = transformers.AutoModelForCausalLM.from_pretrained(self.gen_ppl_eval_model_name_or_path).eval()
+ if max_length is None:
+ max_length = self.config.model.txt_length
+
+ if "llama2" not in self.gen_ppl_eval_model_name_or_path:
+ eval_model = eval_model.to(self.device)
+
+ # Re-tokenize using eval model's tokenizer
+ if retokenize:
+ (samples, attn_mask, eval_context_size) = self.eval_retokenize(text_samples, max_length=max_length)
+ else:
+ samples = text_samples
+ attn_mask = torch.ones(samples.shape).to(self.device)
+ eval_context_size = samples.shape[-1]
+
+ batch_size = min(self.config.eval.perplexity_batch_size, samples.shape[0])
+ num_batches = (samples.shape[0] + batch_size - 1) // batch_size
+ all_nlls = []
+ all_valid_mask = []
+ for i in range(num_batches):
+ batch_samples = samples[i * batch_size : (i + 1) * batch_size]
+ batch_attn_mask = attn_mask[i * batch_size : (i + 1) * batch_size]
+
+ with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]):
+ logits = eval_model(batch_samples, attention_mask=batch_attn_mask)[0]
+
+ logits = logits.transpose(-1, -2)
+ nlls = F.cross_entropy(logits[..., :-1], batch_samples[..., 1:], reduction="none")
+
+ # Only consider tokens up to first EOS or padding
+ first_eos = (batch_samples == self.eval_model_tokenizer.eos_token_id).cumsum(-1) <= 1
+ token_mask = batch_attn_mask[..., 1:] > 0
+ valid_mask = first_eos[..., 1:] * token_mask
+
+ if not return_raw_score:
+ if gt:
+ self.gt_gen_ppl_metric.update(nlls, valid_mask)
+ else:
+ self.gen_ppl_metric.update(nlls, valid_mask)
+ else:
+ all_nlls.append(nlls)
+ all_valid_mask.append(valid_mask)
+
+ if getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
+ eval_model.to(torch.device('cpu'))
+ del eval_model
+
+ if return_raw_score:
+ all_nlls = torch.cat(all_nlls)
+ all_valid_mask = torch.cat(all_valid_mask)
+ # Compute mean NLL per sequence, ignoring padding/post-EOS tokens
+ nll = (all_nlls * all_valid_mask).sum(-1) / all_valid_mask.sum(-1)
+ return nll
+
+def _d3pm_loss(self, model_output, xt, x0, t):
+ dt = 1 / self.T
+
+ if torch.is_tensor(t):
+ t = t[:, None]
+ assert t.ndim == 2
+ t = t.clamp(0.0, 1.0 - 1e-4)
+ alpha_t = 1 - t + torch.zeros_like(xt)
+ alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
+
+ log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
+ log_x_theta_at_m = model_output[:, :, self.mask_index]
+ x_theta_at_m = log_x_theta_at_m.exp()
+
+ term_1_coef = dt / t
+ term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
+ term_1_log_dr = log_x_theta_at_x0
+
+ term_2_coef = 1 - dt / t
+ term_2_log_nr = term_1_log_nr
+ term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)
+
+ L_vb_masked = term_1_coef * (term_1_log_nr - term_1_log_dr) + term_2_coef * (term_2_log_nr - term_2_log_dr)
+
+ L_vb = L_vb_masked * (xt == self.mask_index)
+
+ return self.T * L_vb
+
+def _d3pm_parameterization(self, logits):
+ if self.subs_masking:
+ logits[:, :, self.mask_index] += self.neg_infinity
+ logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
+ return logits
+
+def _sedd_parameterization(self, logits, xt, sigma):
+ esigm1_log = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(logits.dtype)
+ # logits shape
+ # (batch_size, diffusion_model_input_length, vocab_size)
+ logits = logits - esigm1_log[:, None, None] - np.log(logits.shape[-1] - 1)
+ # The below scatter operation sets the log score
+ # for the input word to 0.
+ logits = torch.scatter(logits, -1, xt[..., None], torch.zeros_like(logits[..., :1]))
+ return logits
+
+def get_base_shapes_for_mup(self, _model):
+ from copy import deepcopy
+
+ from mup import set_base_shapes
+ from omegaconf import read_write
+
+ base_config = deepcopy(self.config)
+ with read_write(base_config):
+ base_config.model.hidden_size = base_config.model.n_heads # We need at least n_heads dim
+
+ delta_config = deepcopy(base_config)
+ with read_write(delta_config):
+ delta_config.model.hidden_size = base_config.model.n_heads * 2
+
+ base_model = models.dit.DIT(
+ base_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
+ )
+
+ delta_model = models.dit.DIT(
+ delta_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
+ )
+
+ set_base_shapes(_model, base_model, delta=delta_model)
+
+
+def update_histogram(histogram, timesteps: torch.Tensor, losses: torch.Tensor):
+ for t, l in zip(timesteps, losses):
+ if t.item() in histogram:
+ histogram[t.item()].append(l.item())
+ else:
+ histogram[t.item()] = [l.item()]
+
+def _maybe_sub_sample(self, x0, attention_mask):
+ seqlen = x0.shape[1]
+ if seqlen > self.config.model.length:
+ if not getattr(self.config.eval, 'big_seq_len_eval', False):
+ assert seqlen == 2 * self.config.model.length
+ # cropping is needed for text8-crop dataset
+ # try the same starting point for now
+ start = np.random.choice(self.config.model.length)
+ end = start + self.config.model.length
+ input_tokens = x0[:, start:end]
+ output_tokens = x0[:, start + 1 : end + 1]
+ new_attention_mask = attention_mask[:, start:end]
+
+ # Helps with validation PPL, since the val
+ # examples will all start and end with BOS/EOS
+ input_tokens[:, 0] = self.tokenizer.bos_token_id
+ output_tokens[:, -1] = self.tokenizer.eos_token_id
+ else:
+ input_tokens = x0
+ output_tokens = None
+ new_attention_mask = attention_mask
+ return input_tokens, output_tokens, new_attention_mask
+
+from unidisc.tokenizers.image_tokenizers import decode_latents
+
+
+def viz_images_from_dataloader(self):
+ _iter = iter(self.train_dataloader)
+ random_elements = [next(_iter) for _ in range(10)]
+ # random_elements[0]['input_ids'] - self.text_vocab_size
+ out = decode_latents(self.config, self.get_vae(), torch.cat([torch.zeros_like(random_elements[0]['input_ids'][:, :1]), (random_elements[0]['input_ids'] - self.text_vocab_size)], dim=-1))
+ from image_utils import Im
+ print(Im(out[:16]).save())
+ breakpoint()
+ return random_elements
+
+try:
+ from torch.nn.attention.flex_attention import create_block_mask
+except:
+ pass
+
+def _attn_mask(txt_batch_dropout, img_batch_dropout, txt_length):
+ def mask_mod(b, h, q_idx, kv_idx):
+ txt_sees_txt = (q_idx < txt_length) & (kv_idx < txt_length)
+ img_sees_img_and_txt = (q_idx >= txt_length)
+ txt_dropout_case = ~txt_batch_dropout[b] | (txt_sees_txt | img_sees_img_and_txt)
+
+ img_sees_img = ((q_idx >= txt_length) & (kv_idx >= txt_length))
+ txt_sees_txt_and_img = (q_idx < txt_length)
+ img_dropout_case = ~img_batch_dropout[b] | (img_sees_img | txt_sees_txt_and_img)
+ return txt_dropout_case & img_dropout_case
+ return mask_mod
+
+
+def get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length, batch_size, seq_len, device):
+ return create_block_mask(
+ _attn_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length),
+ B = batch_size, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device
+ )
+
+def _interleaved_attn_mask(interleaved_sample_ids):
+ def mask_mod(b, h, q_idx, kv_idx):
+ return (interleaved_sample_ids[b, q_idx] == interleaved_sample_ids[b, kv_idx]) & (interleaved_sample_ids[b, q_idx] != -1)
+ return mask_mod
+
+def visualize_flex_attention(mask_mod, B, SEQ_LEN, H=16, HEAD_DIM=64, device="cuda"):
+ from models.archived.utils import visualize_attention_scores
+ def make_tensor():
+ return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)
+
+ query, key = make_tensor(), make_tensor()
+ visualize_attention_scores(
+ query,
+ key,
+ mask_mod=mask_mod,
+ device=device,
+ name="interleaved_attn_mask",
+ )
+
+def get_interleaved_block_mask(interleaved_sample_ids, batch_size, seq_len, device, visualize=False):
+ # Uncomment this to visualize the mask
+ if visualize:
+ visualize_flex_attention(_interleaved_attn_mask(interleaved_sample_ids), batch_size, seq_len, device=device)
+ if (interleaved_sample_ids == -1).all(dim=-1).any():
+ gprint(f"WARNING: Found all -1s in interleaved_sample_ids, setting one to 0")
+ interleaved_sample_ids = interleaved_sample_ids.clone()
+ interleaved_sample_ids[(interleaved_sample_ids == -1).all(dim=-1), 0] = 0
+
+ return create_block_mask(
+ _interleaved_attn_mask(interleaved_sample_ids),
+ B = batch_size, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device
+ )
+
+def calculate_clip_score(
+ image_paths: List[str],
+ captions_mapping: Dict[str, str],
+ device: torch.device = "cuda",
+ seed: Optional[int] = 42,
+ batch_size: int = 128,
+ dataloader_workers: int = 16,
+ verbose: bool = True,
+):
+ import clip
+ from T2IBenchmark.feature_extractors import (BaseFeatureExtractor,
+ InceptionV3FE)
+ from T2IBenchmark.loaders import CaptionImageDataset
+ from T2IBenchmark.model_wrapper import (ModelWrapperDataloader,
+ T2IModelWrapper)
+ from T2IBenchmark.utils import dprint, set_all_seeds
+
+ if seed:
+ set_all_seeds(seed)
+
+ model, preprocess = clip.load("ViT-B/32", device=device)
+ dataset = CaptionImageDataset(
+ images_paths=image_paths,
+ captions=list(map(lambda x: captions_mapping[x], image_paths)),
+ preprocess_fn=preprocess,
+ )
+ dataloader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ drop_last=False,
+ num_workers=dataloader_workers,
+ )
+
+ score_acc = 0.0
+ num_samples = 0.0
+
+ for image, caption in tqdm(dataloader):
+ image_embedding = model.encode_image(image.to(device))
+ caption_embedding = model.encode_text(clip.tokenize(caption, truncate=True).to(device))
+
+ image_features = image_embedding / image_embedding.norm(dim=1, keepdim=True).to(
+ torch.float32
+ )
+ caption_features = caption_embedding / caption_embedding.norm(
+ dim=1, keepdim=True
+ ).to(torch.float32)
+
+ score = (image_features * caption_features).sum()
+ score_acc += score
+ num_samples += image.shape[0]
+
+ clip_score = score_acc / num_samples
+ dprint(verbose, f"CLIP score is {clip_score}")
+
+ return clip_score
+
+def get_chameleon_txt_indices(vae, include_special_tokens=True):
+ image_indices = set(vae.chameleon_ori_translation.bpe2img.keys())
+ if include_special_tokens:
+ h_grids, w_grids = 32, 32
+ image_start_token = vae.token2id(vae.image_start_token)
+ n_grids_token = vae.token2id(vae.get_n_grids_token(h_grids))
+ image_end_token = vae.token2id(vae.image_end_token)
+ image_indices.add(image_start_token)
+ image_indices.add(n_grids_token)
+ image_indices.add(image_end_token)
+ image_indices.add(-100)
+ image_indices.add(1)
+ image_indices.update(range(8192, 8820 + 1))
+
+ return image_indices
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/autoregressive_orig.py b/models/autoregressive_orig.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef6caea44e2e49b7ae7fe5115a0180e225ae6fb4
--- /dev/null
+++ b/models/autoregressive_orig.py
@@ -0,0 +1,358 @@
+import math
+import typing
+
+import flash_attn
+import flash_attn.layers.rotary
+import huggingface_hub
+import omegaconf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+def bias_dropout_add_scale(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float,
+ training: bool,
+) -> torch.Tensor:
+ if bias is not None:
+ out = scale * F.dropout(
+ x + bias, p=prob, training=training
+ )
+ else:
+ out = scale * F.dropout(x, p=prob, training=training)
+
+ if residual is not None:
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add_scale(training):
+ def _bias_dropout_add(x, bias, scale, residual, prob):
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, training
+ )
+
+ return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_train(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float,
+) -> torch.Tensor:
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, True
+ )
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_inference(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float,
+) -> torch.Tensor:
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, False
+ )
+
+
+class Rotary(torch.nn.Module):
+ def __init__(self, dim, base=10_000):
+ super().__init__()
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2).float() / dim)
+ )
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+ t = torch.arange(
+ x.shape[seq_dim], device=x.device
+ ).type_as(self.inv_freq)
+ freqs = torch.einsum(
+ 'i,j->ij', t, self.inv_freq.clone()
+ )
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ # dims are: batch, seq_len, qkv, head, dim
+ self.cos_cached = emb.cos()[
+ None, :, None, None, :
+ ].repeat(1, 1, 3, 1, 1)
+ self.sin_cached = emb.sin()[
+ None, :, None, None, :
+ ].repeat(1, 1, 3, 1, 1)
+ # This makes the transformation on v an identity.
+ self.cos_cached[:, :, 2, :, :].fill_(1.0)
+ self.sin_cached[:, :, 2, :, :].fill_(0.0)
+
+ return self.cos_cached, self.sin_cached
+
+
+def rotate_half(x):
+ x1, x2 = (
+ x[..., : x.shape[-1] // 2],
+ x[..., x.shape[-1] // 2 :],
+ )
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(qkv, cos, sin):
+ cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
+ sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]
+ return flash_attn.layers.rotary.apply_rotary_emb_qkv_(
+ qkv, cos, sin
+ )
+
+
+#################################################################################
+# Layers #
+#################################################################################
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones([dim]))
+ self.dim = dim
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(enabled=False):
+ x = F.layer_norm(x.float(), [self.dim])
+ return x * self.weight[None, None, :]
+
+
+def residual_linear(x, W, x_skip, residual_scale):
+ """x_skip + residual_scale * W @ x"""
+ dim_out, dim_in = W.shape[0], W.shape[1]
+ return torch.addmm(
+ x_skip.view(-1, dim_out),
+ x.view(-1, dim_in),
+ W.T,
+ alpha=residual_scale,
+ ).view(*x.shape[:-1], dim_out)
+
+
+#################################################################################
+# Core Model #
+#################################################################################
+
+
+class DDiTBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ cond_dim,
+ mlp_ratio=4,
+ dropout=0.1,
+ causal=False,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.causal = causal
+
+ self.norm1 = LayerNorm(dim)
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
+ self.attn_out = nn.Linear(dim, dim, bias=False)
+ self.dropout1 = nn.Dropout(dropout)
+
+ self.norm2 = LayerNorm(dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_ratio * dim, bias=True),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(mlp_ratio * dim, dim, bias=True),
+ )
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout = dropout
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+ def forward(self, x, rotary_cos_sin, c, seqlens=None):
+ batch_size, seq_len = x.shape[0], x.shape[1]
+
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
+
+ # attention operation
+ x_skip = x
+ x = self.norm1(x)
+
+ qkv = self.attn_qkv(x)
+ qkv = rearrange(
+ qkv,
+ 'b s (three h d) -> b s three h d',
+ three=3,
+ h=self.n_heads,
+ )
+ with torch.cuda.amp.autocast(enabled=False):
+ cos, sin = rotary_cos_sin
+ qkv = apply_rotary_pos_emb(
+ qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
+ )
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ if seqlens is None:
+ cu_seqlens = torch.arange(
+ 0,
+ (batch_size + 1) * seq_len,
+ step=seq_len,
+ dtype=torch.int32,
+ device=qkv.device,
+ )
+ else:
+ cu_seqlens = seqlens.cumsum(-1)
+ x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, seq_len, 0.0, causal=self.causal
+ )
+
+ x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
+
+ scale = torch.ones(1, device=x.device, dtype=x.dtype)
+ x = bias_dropout_scale_fn(
+ self.attn_out(x), None, scale, x_skip, self.dropout
+ )
+
+ # mlp operation
+ x = bias_dropout_scale_fn(
+ self.mlp(self.norm2(x)), None, scale, x, self.dropout
+ )
+ return x
+
+
+class EmbeddingLayer(nn.Module):
+ def __init__(self, dim, vocab_dim):
+ super().__init__()
+ self.embedding = nn.Parameter(
+ torch.empty((vocab_dim, dim))
+ )
+ torch.nn.init.kaiming_uniform_(
+ self.embedding, a=math.sqrt(5)
+ )
+
+ def forward(self, x):
+ return self.embedding[x]
+
+
+class DDitFinalLayer(nn.Module):
+ def __init__(
+ self, hidden_size, out_channels, cond_dim, causal=False
+ ):
+ super().__init__()
+ self.causal = causal
+ assert causal == True
+
+ self.norm_final = LayerNorm(hidden_size)
+ self.linear = nn.Linear(hidden_size, out_channels)
+ self.linear.weight.data.zero_()
+ self.linear.bias.data.zero_()
+
+ def forward(self, x, c):
+ return self.linear(self.norm_final(x))
+
+
+class DDIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
+ def __init__(self, config, vocab_size: int, causal: bool):
+ super().__init__()
+ if type(config) == dict:
+ config = omegaconf.OmegaConf.create(config)
+
+ self.config = config
+ self.vocab_size = vocab_size
+ self.causal = (
+ hasattr(config.model, 'causal')
+ and config.model.causal
+ ) or causal
+ assert self.causal == True
+
+ self.vocab_embed = EmbeddingLayer(
+ config.model.hidden_size, vocab_size
+ )
+ self.rotary_emb = Rotary(
+ config.model.hidden_size // config.model.n_heads
+ )
+
+ blocks = []
+ for _ in range(config.model.n_blocks):
+ blocks.append(
+ DDiTBlock(
+ config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout,
+ causal=self.causal,
+ )
+ )
+ self.blocks = nn.ModuleList(blocks)
+
+ self.output_layer = DDitFinalLayer(
+ config.model.hidden_size,
+ vocab_size,
+ config.model.cond_dim,
+ causal=self.causal,
+ )
+ self.scale_by_sigma = config.model.scale_by_sigma
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+
+class AR(DDIT):
+ def __init__(self, config, vocab_size, mask_index, causal: bool = False):
+ super().__init__(config, vocab_size, causal)
+ self.mask_index = mask_index
+ self.neg_infinity = -1000.0
+
+ def forward(self, xt, sigma, **kwargs):
+ """Forward pass of the denoising model.
+
+ Args:
+ xt: int torch.Tensor with shape
+ (batch_size, diffusion_model_input_length), token ids.
+ sigma: float torch.Tensor with shape
+ (batch_size).
+
+ Returns:
+ log probability with shape
+ (batch_size, diffusion_model_input_length, vocab_size)
+ """
+ x = self.vocab_embed(xt)
+
+ rotary_cos_sin = self.rotary_emb(x)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](
+ x, rotary_cos_sin, None, seqlens=None
+ )
+ output = self.output_layer(x, None)
+
+ # log prob at the mask index = - infinity
+ output[:, :, self.mask_index] = self.neg_infinity
+
+ # Normalize the logits such that x.exp() is
+ # a probability distribution over vocab_size.
+ # x = x - torch.logsumexp(x, dim=-1, keepdim=True)
+ return output.log_softmax(-1)
diff --git a/models/configuration_openelm_local.py b/models/configuration_openelm_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..86497b6d76315e2b4e4d9f8b12c2895151a6c458
--- /dev/null
+++ b/models/configuration_openelm_local.py
@@ -0,0 +1,318 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+
+"""Implements HF OpenELMConfig based on PretrainedConfig"""
+from numbers import Number
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers import PretrainedConfig
+
+
+def make_divisible(
+ v: Union[float, int],
+ divisor: Optional[int] = 8,
+ min_value: Optional[Union[float, int]] = None,
+) -> Union[float, int]:
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by the divisor
+ It can be seen at:
+ https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
+
+ Args:
+ v: input value
+ divisor: default to 8
+ min_value: minimum divisor value
+ Returns:
+ new_v: new divisible value
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+def compute_heads(model_dim: int, head_dim: int) -> int:
+ """Compute the number of heads.
+
+ Args:
+ model_dim: Model dimension.
+ head_dim: Head dimension.
+
+ Returns:
+ An integer denoting number of heads in multi-head attention is returned.
+
+ Raises:
+ ValueError: if model dimension is not divisible by head dimension.
+ """
+ if model_dim % head_dim == 0:
+ return model_dim // head_dim
+ else:
+ raise ValueError(
+ f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
+ )
+
+
+OpenELM_CONFIGS = {
+ "OpenELM-270M": dict(
+ num_transformer_layers=16,
+ model_dim=1280,
+ head_dim=64,
+ num_gqa_groups=4,
+ normalize_qk_projections=True,
+ share_input_output_layers=True,
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
+ ffn_multipliers=(0.5, 4.0),
+ qkv_multipliers=(0.5, 1.0),
+ ),
+ "OpenELM-450M": dict(
+ num_transformer_layers=20,
+ model_dim=1536,
+ head_dim=64,
+ num_gqa_groups=4,
+ normalize_qk_projections=True,
+ share_input_output_layers=True,
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
+ ffn_multipliers=(0.5, 4.0),
+ qkv_multipliers=(0.5, 1.0),
+ ),
+ "OpenELM-1_1B": dict(
+ num_transformer_layers=28,
+ model_dim=2048,
+ head_dim=64,
+ num_gqa_groups=4,
+ normalize_qk_projections=True,
+ share_input_output_layers=True,
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
+ ffn_multipliers=(0.5, 4.0),
+ qkv_multipliers=(0.5, 1.0),
+ ),
+ "OpenELM-3B": dict(
+ num_transformer_layers=36,
+ model_dim=3072,
+ head_dim=128,
+ num_gqa_groups=4,
+ normalize_qk_projections=True,
+ share_input_output_layers=True,
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
+ ffn_multipliers=(0.5, 4.0),
+ qkv_multipliers=(0.5, 1.0),
+ ),
+}
+
+
+class OpenELMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the OpenELM model.
+ max_context_length (`int`, *optional*, defaults to 2048):
+ Maximum number of input tokens.
+ num_transformer_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer decoder.
+ model_dim (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
+ If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
+ resulting in uniform allocation of parameters.
+ If the qkv_multipliers is a List of Number, then each attention layer have different latent dimensions
+ assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in attention layer.
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
+ num_query_heads (`Union[int, None]`, *optional*, defaults to None):
+ The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
+ num_gqa_groups (`int`, *optional*, defaults to 1):
+ This variable allows to switch between multi-head attention, group query attention, and multi-query attention.
+ When num_gqa_groups == 1, then it is multi-head attention.
+ When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention
+ When num_gqa_groups == num_heads, then it is multi-query attention
+ ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
+ Feed-forward network (FFN) multipliers.
+ If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
+ resulting in uniform allocation of parameters.
+ If the ffn_multipliers is a List of Number, then each FFN layer have different latent dimensions
+ assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in FFN layer.
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
+ ffn_with_glu (`bool`, *optional*, defaults to True):
+ Whether to use FFN with Gated Linear Unit (GLU)
+ ffn_dim_divisor (`int`, *optional*, defaults to 256):
+ The ffn layer dimension divisor.
+ activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
+ The non-linear activation function (function or string) in the decoder.
+ normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
+ Type of normalization layer.
+ normalize_qk_projections (`bool`, *optional*, defaults to False):
+ Whether to normalize queries and keys after projections
+ share_input_output_layers (`bool`, *optional*, defaults to False):
+ Whether to share the embedding between input and output linear layer
+ rope_freq_constant (`int`, *optional*, defaults to 10000):
+ The base period of the RoPE embeddings.
+ rope_max_length (`int`, *optional*, defaults to 4096):
+ That rope_max_length is set to twice of max_context_length.
+ This allows flexibility in token lengths during training or fine-tuning.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ """
+
+ model_type = "openelm"
+
+ def __init__(
+ self,
+ vocab_size: int = 32000,
+ max_context_length: int = 2048,
+ num_transformer_layers: int = 12,
+ model_dim: int = 2048,
+ head_dim: int = 128,
+ qkv_multipliers: Union[Number, List[Number]] = 1.0,
+ num_query_heads: Union[int, None] = None,
+ num_gqa_groups: int = 1,
+ ffn_multipliers: Union[Number, List[Number]] = 4.0,
+ ffn_with_glu: bool = True,
+ ffn_dim_divisor: int = 256,
+ activation_fn_name: str = "swish",
+ normalization_layer_name: str = "rms_norm",
+ normalize_qk_projections: bool = False,
+ share_input_output_layers: bool = False,
+ rope_freq_constant: int = 10000,
+ rope_max_length: int = 4096,
+ initializer_range: float = 0.02,
+ use_cache: bool = True,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ **kwargs,
+ ) -> None:
+ self.vocab_size = vocab_size
+ self.max_context_length = max_context_length
+ self.num_transformer_layers = num_transformer_layers
+ self.model_dim = model_dim
+ self.head_dim = head_dim
+ self.qkv_multipliers = qkv_multipliers
+ self.num_query_heads = num_query_heads
+ self.num_gqa_groups = num_gqa_groups
+ self.ffn_multipliers = ffn_multipliers
+ self.ffn_with_glu = ffn_with_glu
+ self.ffn_dim_divisor = ffn_dim_divisor
+ self.activation_fn_name = activation_fn_name
+ self.normalization_layer_name = normalization_layer_name
+ self.normalize_qk_projections = normalize_qk_projections
+ self.share_input_output_layers = share_input_output_layers
+ self.rope_freq_constant = rope_freq_constant
+ self.rope_max_length = rope_max_length
+ self.num_query_heads = (
+ compute_heads(model_dim=model_dim, head_dim=head_dim)
+ if num_query_heads is None
+ else num_query_heads
+ )
+ self.initializer_range = initializer_range
+
+ self.__post_init__()
+ super().__init__(
+ use_cache=use_cache,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ def __post_init__(self) -> None:
+ if self.num_gqa_groups is not None:
+ head_multiple_of = self.num_gqa_groups
+ else:
+ head_multiple_of = 2
+
+ if isinstance(self.qkv_multipliers, Number):
+ # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
+ qkv_dim = make_divisible(
+ self.model_dim * self.qkv_multipliers,
+ divisor=self.head_dim * head_multiple_of,
+ )
+ query_dims = [int(qkv_dim)] * self.num_transformer_layers
+
+ elif (
+ isinstance(self.qkv_multipliers, (tuple, list))
+ and len(self.qkv_multipliers) == 2
+ ):
+ # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
+ # This results in variable allocation of parameters in attention layer.
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
+ qkv_multipliers = [
+ round(v, 2)
+ for v in np.linspace(
+ self.qkv_multipliers[0],
+ self.qkv_multipliers[1],
+ num=self.num_transformer_layers,
+ dtype=float,
+ )
+ ]
+ # Make sure that scaled model dimension is divisible by scaled head dimension.
+ query_dims = [
+ int(
+ make_divisible(
+ self.model_dim * m, divisor=self.head_dim * head_multiple_of
+ )
+ )
+ for m in qkv_multipliers
+ ]
+ else:
+ raise NotImplementedError(
+ f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
+ )
+
+ # compute the number of query, key, and value heads
+ # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
+ # For group query attention, the number of key and value heads are the same.
+ self.num_query_heads = [
+ int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
+ ]
+ self.num_kv_heads = [
+ q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
+ ]
+
+ # Feed-forward network (FFN) multipliers
+ if isinstance(self.ffn_multipliers, Number):
+ # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
+ self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
+ elif isinstance(self.ffn_multipliers, (tuple, list)):
+ # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
+ # This results in variable allocation of parameters in FFN layer.
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
+ if len(self.ffn_multipliers) == 2:
+ self.ffn_multipliers = [
+ round(v, 2)
+ for v in np.linspace(
+ self.ffn_multipliers[0],
+ self.ffn_multipliers[1],
+ num=self.num_transformer_layers,
+ dtype=float,
+ )
+ ]
+ else:
+ assert (
+ len(self.ffn_multipliers) == self.num_transformer_layers
+ ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
+ else:
+ raise NotImplementedError(
+ f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
+ )
+
+ # check num_query_heads divisible by num_kv_heads for every layer
+ for layer_idx in range(len(query_dims)):
+ assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
diff --git a/models/datasets/__init__.py b/models/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/datasets/combine_token_dicts.py b/models/datasets/combine_token_dicts.py
new file mode 100644
index 0000000000000000000000000000000000000000..601c92ec36312dcea6c64e5fee176cf44a49c6d5
--- /dev/null
+++ b/models/datasets/combine_token_dicts.py
@@ -0,0 +1,216 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+import typer
+from tensordict import TensorDict
+from typing_extensions import Annotated
+import time
+import shutil
+from decoupled_utils import rprint
+
+app = typer.Typer(pretty_exceptions_show_locals=False)
+typer.main.get_command_name = lambda name: name
+
+def split_dataset(dataset, n: int, m: int):
+ # Ensure m is valid
+ if m < 0 or m >= n:
+ raise ValueError(f"m must be between 0 and {n-1}, but got {m}.")
+
+ # Calculate the size of each subset
+ total_len = len(dataset)
+ subset_size = total_len // n
+ remainder = total_len % n
+
+ # Calculate the start and end index of the m-th subset
+ start_idx = m * subset_size + min(m, remainder)
+ end_idx = start_idx + subset_size + (1 if m < remainder else 0)
+
+ # Return the m-th subset
+ return dataset[slice(start_idx, end_idx)]
+
+@app.command()
+def main(
+ data_dir: Path,
+ splits: Optional[list[str]] = ["train", "val"],
+ add_vggface2_text_tokens: bool = False,
+ use_tmp: bool = False,
+ use_all: bool = False,
+ allow_zero_idx: bool = False,
+ use_timestamp: bool = False,
+ delete_after_combining: bool = False,
+ allow_existing: bool = False,
+ force_overwrite: bool = False,
+ move_files: bool = False,
+ allow_tmp: bool = False,
+ mem_efficient: bool = False,
+ output_dir: Optional[Path] = None,
+ require_image_tokens: bool = False,
+ min_idx: Optional[int] = None,
+ max_idx: Optional[int] = None,
+ split_num: Optional[int] = None,
+ split_idx: Optional[int] = None,
+):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ for split in splits:
+ if allow_tmp:
+ all_folders = sorted([folder for folder in data_dir.iterdir() if folder.is_dir() and split in folder.name and "_" in folder.name and (allow_existing or "existing" not in folder.name)])
+ print(f"All folders: len({len(all_folders)})")
+
+ from collections import defaultdict
+ unique_ids = defaultdict(list)
+ for folder in all_folders:
+ folder_id = int(folder.name.split("_")[-1])
+ unique_ids[folder_id].append(folder)
+
+ folders = []
+ for folder_id, _folders in unique_ids.items():
+ if len(_folders) == 1:
+ folders.append(_folders[0])
+ else:
+ for folder in _folders:
+ if "tmp" not in folder.name:
+ folders.append(folder)
+
+ folders = sorted(folders)
+ print(f"Using {len(folders)} folders for {split}")
+ else:
+ folders = sorted([folder for folder in data_dir.iterdir() if folder.is_dir() and split in folder.name and "_" in folder.name and (use_all or (not use_tmp or "tmp" in folder.name)) and (allow_existing or "existing" not in folder.name)])
+
+ if min_idx is not None and max_idx is not None:
+ print(f"Filtering with min_idx: {min_idx} and max_idx: {max_idx}")
+ _tmp_folders = []
+ for folder in folders:
+ _name = int(folder.name.split("_")[-1])
+ if min_idx <= _name <= max_idx:
+ _tmp_folders.append(folder)
+ folders = _tmp_folders
+ print(f"Filtered folders and got: {len(folders)}")
+
+ if split_num is not None and split_idx is not None:
+ folders = split_dataset(folders, split_num, split_idx)
+ print(f"Filtered folders and got: {len(folders)}")
+
+ initial_folder_count = len(folders)
+ folders = [folder for folder in folders if any(folder.iterdir())]
+ removed_folders_count = initial_folder_count - len(folders)
+ print(f"Removed {removed_folders_count} empty folders")
+ if len(folders) == 0:
+ print(f"No folders found for {split}")
+ continue
+ print(f"{split} folders: {folders}")
+ _tensors = [TensorDict.load_memmap(folder) for folder in folders if (folder / "meta.json").exists()]
+ _tensors = [tensor for tensor in _tensors if tensor.shape[0] > 0]
+ for _tensor in _tensors:
+ if "write_flag" not in _tensor:
+ _tensor["write_flag"] = torch.ones((len(_tensor), 1), dtype=torch.bool)
+ loaded_tensors = torch.cat(_tensors, dim=0)
+ del _tensors
+
+ if add_vggface2_text_tokens:
+ loaded_tensors.set("txt_input_ids", loaded_tensors["img_input_ids"].new_zeros(loaded_tensors["img_input_ids"].shape[0], 47), inplace=True)
+ loaded_tensors.set("txt_attention_mask", loaded_tensors["img_input_ids"].new_zeros(loaded_tensors["img_input_ids"].shape[0], 1), inplace=True)
+ print(f"Added VGGFace2 text tokens to {split}")
+
+ index_keys = ("img_label", "img_input_ids", "txt_input_ids", "input_ids")
+ if not mem_efficient:
+ for key in index_keys:
+ if key in loaded_tensors:
+ loaded_tensors[key] = loaded_tensors[key].to(torch.int32)
+
+ if "img_input_ids" in loaded_tensors:
+ written_indices = ((loaded_tensors["write_flag"] > 0).squeeze(-1) & (loaded_tensors["img_input_ids"] > 0).all(dim=-1))
+ else:
+ if mem_efficient:
+ written_indices = (loaded_tensors["write_flag"] > 0).squeeze(-1)
+ else:
+ written_indices = ((loaded_tensors["write_flag"] > 0).squeeze(-1) & (loaded_tensors["input_ids"] > 0).any(dim=-1))
+
+ print(f"Valid elements for {split}: {written_indices.shape[0]}")
+ loaded_tensors = loaded_tensors[written_indices]
+ invalid_indices = loaded_tensors["idx"].squeeze(-1) == -1
+ if require_image_tokens:
+ invalid_modality = ~(loaded_tensors["modality"] > 0).any(dim=-1)
+ invalid_indices |= invalid_modality
+ print(f"Found {invalid_modality.sum()} invalid indices for {split} due to missing image tokens")
+ print(f"Invalid indices for {split}: {invalid_indices.sum()}")
+
+ loaded_tensors = loaded_tensors[~invalid_indices]
+ if allow_zero_idx is False:
+ _, idx = torch.unique(loaded_tensors["idx"].to(device), dim=0, sorted=True, return_inverse=True)
+ loaded_tensors = loaded_tensors[torch.unique(idx, return_inverse=False).to(loaded_tensors.device)]
+
+ print(f"After filtering: {loaded_tensors.shape[0]}")
+
+ if loaded_tensors.shape[0] == 0:
+ rprint(f"WARNING!!! No valid elements for {split}")
+ return
+
+ for _key in ["img_input_ids", "input_ids"]:
+ if _key in loaded_tensors:
+ assert 0 <= loaded_tensors[_key].min() and loaded_tensors[_key].max() < torch.iinfo(torch.int16).max
+ loaded_tensors[_key] = loaded_tensors[_key].to(torch.int16)
+
+ index_keys = ("img_label", "txt_attention_mask", "attention_mask")
+ for key in index_keys:
+ if key in loaded_tensors:
+ loaded_tensors[key] = loaded_tensors[key].squeeze(-1)
+
+ if "write_flag" in loaded_tensors:
+ del loaded_tensors["write_flag"]
+
+ if split_idx is not None:
+ split = f"split_{split_idx}_{split}"
+
+ if use_timestamp:
+ loaded_tensors.memmap(data_dir / f"{split}_existing_{int(time.time())}")
+ else:
+ if (data_dir / f"{split}").exists():
+ print("Already exists!")
+ if force_overwrite:
+ shutil.rmtree(data_dir / f"{split}")
+ else:
+ breakpoint()
+
+ if output_dir is not None:
+ loaded_tensors.memmap(output_dir / f"{split}")
+ else:
+ loaded_tensors.memmap(data_dir / f"{split}")
+
+ if delete_after_combining:
+ for folder in folders:
+ try:
+ rprint(f"Removing folder: {folder}")
+ shutil.rmtree(folder)
+ except Exception as e:
+ rprint(f"Error removing folder: {e}")
+
+ if force_overwrite:
+ from pathlib import Path
+ for train_folder in Path(data_dir).glob('train_*'):
+ rprint(f"Removing folder: {train_folder}")
+ if train_folder.is_file():
+ train_folder.unlink()
+ else:
+ shutil.rmtree(train_folder)
+
+ train_dir = data_dir / 'train'
+ if train_dir.exists() and train_dir.is_dir():
+ for item in train_dir.iterdir():
+ shutil.move(str(item), str(train_dir.parent))
+ shutil.rmtree(train_dir)
+
+ elif move_files:
+ train_dir = data_dir / 'train'
+ if train_dir.exists() and train_dir.is_dir():
+ for item in train_dir.iterdir():
+ shutil.move(str(item), str(train_dir.parent))
+
+ # Check if train_dir is empty after moving files
+ if train_dir.exists() and train_dir.is_dir():
+ if not any(train_dir.iterdir()):
+ shutil.rmtree(train_dir)
+ rprint(f"Removed empty train directory: {train_dir}")
+
+if __name__ == "__main__":
+ app()
\ No newline at end of file
diff --git a/models/datasets/cub200.py b/models/datasets/cub200.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2e3c267705394c0b03f89451f711a3e63d647ac
--- /dev/null
+++ b/models/datasets/cub200.py
@@ -0,0 +1,356 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+
+from collections import defaultdict
+from io import BytesIO # Added import for BytesIO
+
+import torch
+import torch.utils.data as data
+from torch.autograd import Variable
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+
+import os
+import sys
+import numpy as np
+import pandas as pd
+from PIL import Image
+import numpy.random as random
+if sys.version_info[0] == 2:
+ import cPickle as pickle
+else:
+ import pickle
+
+#from pycocotools.coco import COCO
+from skimage import io
+import matplotlib.pyplot as plt
+from matplotlib import cm
+
+import nltk, sklearn
+nltk.download('punkt')
+nltk.download('averaged_perceptron_tagger')
+
+def prepare_data(data):
+ imgs, captions, captions_lens, class_ids, keys, wrong_caps, \
+ wrong_caps_len, wrong_cls_id, noise, word_labels = data
+
+ # sort data by the length in a decreasing order
+ sorted_cap_lens, sorted_cap_indices = \
+ torch.sort(captions_lens, 0, True)
+
+ real_imgs = []
+ for i in range(len(imgs)):
+ imgs[i] = imgs[i][sorted_cap_indices]
+ if False:
+ real_imgs.append(Variable(imgs[i]).cuda())
+ else:
+ real_imgs.append(Variable(imgs[i]))
+
+ noise = noise[sorted_cap_indices]
+ word_labels = word_labels[sorted_cap_indices]
+
+ captions = captions[sorted_cap_indices].squeeze()
+ class_ids = class_ids[sorted_cap_indices].numpy()
+ keys = [keys[i] for i in sorted_cap_indices.numpy()]
+
+ if False:
+ captions = Variable(captions).cuda()
+ sorted_cap_lens = Variable(sorted_cap_lens).cuda()
+ else:
+ captions = Variable(captions)
+ sorted_cap_lens = Variable(sorted_cap_lens)
+
+ w_sorted_cap_lens, w_sorted_cap_indices = \
+ torch.sort(wrong_caps_len, 0, True)
+
+ wrong_caps = wrong_caps[w_sorted_cap_indices].squeeze()
+ wrong_cls_id = wrong_cls_id[w_sorted_cap_indices].numpy()
+
+ if False:
+ wrong_caps = Variable(wrong_caps).cuda()
+ w_sorted_cap_lens = Variable(w_sorted_cap_lens).cuda()
+ else:
+ wrong_caps = Variable(wrong_caps)
+ w_sorted_cap_lens = Variable(w_sorted_cap_lens)
+
+
+ ##
+ return [real_imgs, captions, sorted_cap_lens,
+ class_ids, keys, wrong_caps, w_sorted_cap_lens, wrong_cls_id, noise, word_labels]
+
+
+
+def get_imgs(img_path, bbox, imsize, do_augment=False, image_cache=None):
+ """
+ Load image with caching of raw bytes to improve performance on repeated accesses.
+ Raw bytes are cached before any transformations like cropping to maintain compression.
+ """
+ if image_cache is None: image_cache = {}
+ if img_path in image_cache:
+ raw_bytes = image_cache[img_path]
+ else:
+ with open(img_path, 'rb') as f:
+ raw_bytes = f.read()
+ image_cache[img_path] = raw_bytes
+
+ img = Image.open(BytesIO(raw_bytes)).convert('RGB')
+ width, height = img.size
+
+ if bbox is not None:
+ r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
+ center_x = int((2 * bbox[0] + bbox[2]) / 2)
+ center_y = int((2 * bbox[1] + bbox[3]) / 2)
+ y1 = np.maximum(0, center_y - r)
+ y2 = np.minimum(height, center_y + r)
+ x1 = np.maximum(0, center_x - r)
+ x2 = np.minimum(width, center_x + r)
+ img = img.crop([x1, y1, x2, y2])
+
+ w, h = img.size
+ if do_augment:
+ if random.random() < 0.5:
+ img = F.hflip(img)
+ crop_side = random.randint(int(min(w, h) * 0.7), int(min(w, h) * 1.0))
+ left = random.randint(0, w - crop_side)
+ top = random.randint(0, h - crop_side)
+ img = F.crop(img, top, left, crop_side, crop_side)
+ img = F.resize(img, (imsize, imsize), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
+ else:
+ # if w != h:
+ # min_side = min(w, h)
+ # left = (w - min_side) // 2
+ # top = (h - min_side) // 2
+ # img = F.crop(img, top, left, min_side, min_side)
+ crop_side = int(min(w, h) * 0.9)
+ left = random.randint(0, w - crop_side)
+ top = random.randint(0, h - crop_side)
+ img = F.crop(img, top, left, crop_side, crop_side)
+ img = F.resize(img, (imsize, imsize), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
+
+ return img
+
+class TextDataset(data.Dataset):
+ def __init__(self, data_dir, split='train'):
+ self.transform = None
+ self.norm = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+ self.target_transform = None
+ self.embeddings_num = 10
+ self.imsize = 256
+ self.data = []
+ self.data_dir = data_dir
+ if data_dir.find('birds') != -1:
+ self.bbox = self.load_bbox()
+ else:
+ self.bbox = None
+ split_dir = os.path.join(data_dir, split)
+ self.split = split
+ self.filenames, self.captions, self.ixtoword, self.wordtoix, self.n_words = self.load_text_data(data_dir, split)
+ self.class_id = self.load_class_id(split_dir, len(self.filenames))
+ self.number_example = len(self.filenames)
+ self.image_cache = {}
+ print(f"CUB200 {split} dataset loaded with {len(self)} examples")
+
+ def load_bbox(self):
+ data_dir = self.data_dir
+ bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
+ df_bounding_boxes = pd.read_csv(bbox_path,
+ delim_whitespace=True,
+ header=None).astype(int)
+ #
+ filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
+ df_filenames = \
+ pd.read_csv(filepath, delim_whitespace=True, header=None)
+ filenames = df_filenames[1].tolist()
+ print('Total filenames: ', len(filenames), filenames[0])
+ #
+ filename_bbox = {img_file[:-4]: [] for img_file in filenames}
+ numImgs = len(filenames)
+ for i in range(0, numImgs):
+ bbox = df_bounding_boxes.iloc[i][1:].tolist()
+
+ key = filenames[i][:-4]
+ filename_bbox[key] = bbox
+ #
+ return filename_bbox
+
+ def load_captions(self, data_dir, filenames):
+ all_captions = []
+ for i in range(len(filenames)):
+ cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
+ with open(cap_path, "r") as f:
+ captions = f.read().split('\n')
+ cnt = 0
+ for cap in captions:
+ if len(cap) == 0:
+ continue
+ cap = cap.replace("\ufffd\ufffd", " ")
+ # picks out sequences of alphanumeric characters as tokens
+ # and drops everything else
+ from nltk.tokenize import RegexpTokenizer
+ tokenizer = RegexpTokenizer(r'\w+')
+ tokens = tokenizer.tokenize(cap.lower())
+ if len(tokens) == 0:
+ print('cap', cap)
+ continue
+
+ tokens_new = []
+ for t in tokens:
+ t = t.encode('ascii', 'ignore').decode('ascii')
+ if len(t) > 0:
+ tokens_new.append(t)
+ all_captions.append(tokens_new)
+ cnt += 1
+ if cnt == self.embeddings_num:
+ break
+ if cnt < self.embeddings_num:
+ print('ERROR: the captions for %s less than %d'
+ % (filenames[i], cnt))
+ return all_captions
+
+ def build_dictionary(self, train_captions, test_captions):
+ word_counts = defaultdict(float)
+ captions = train_captions + test_captions
+ for sent in captions:
+ for word in sent:
+ word_counts[word] += 1
+
+ vocab = [w for w in word_counts if word_counts[w] >= 0]
+
+ ixtoword = {}
+ ixtoword[0] = ''
+ wordtoix = {}
+ wordtoix[''] = 0
+ ix = 1
+ for w in vocab:
+ wordtoix[w] = ix
+ ixtoword[ix] = w
+ ix += 1
+
+ train_captions_new = []
+ for t in train_captions:
+ rev = []
+ for w in t:
+ if w in wordtoix:
+ rev.append(wordtoix[w])
+ # rev.append(0) # do not need '' token
+ # this train_captions_new hold index of each word in sentence
+ train_captions_new.append(rev)
+
+ test_captions_new = []
+ for t in test_captions:
+ rev = []
+ for w in t:
+ if w in wordtoix:
+ rev.append(wordtoix[w])
+ # rev.append(0) # do not need '' token
+ test_captions_new.append(rev)
+
+ return [train_captions_new, test_captions_new, ixtoword, wordtoix, len(ixtoword)]
+
+ def load_text_data(self, data_dir, split):
+ filepath = os.path.join(data_dir, 'captions.pickle')
+ train_names = self.load_filenames(data_dir, 'train')
+ test_names = self.load_filenames(data_dir, 'test')
+ if not os.path.isfile(filepath):
+ train_captions = self.load_captions(data_dir, train_names)
+ test_captions = self.load_captions(data_dir, test_names)
+
+ train_captions, test_captions, ixtoword, wordtoix, n_words = self.build_dictionary(train_captions, test_captions)
+ with open(filepath, 'wb') as f:
+ pickle.dump([train_captions, test_captions,
+ ixtoword, wordtoix], f, protocol=2)
+ print('Save to: ', filepath)
+ else:
+ with open(filepath, 'rb') as f:
+ print("filepath", filepath)
+ x = pickle.load(f)
+ train_captions, test_captions = x[0], x[1]
+ ixtoword, wordtoix = x[2], x[3]
+ del x
+ n_words = len(ixtoword)
+ print(f'Loaded from: {filepath}, Vocab size: {n_words}')
+ if split == 'train':
+ # a list of list: each list contains
+ # the indices of words in a sentence
+ captions = train_captions
+ filenames = train_names
+ else: # split=='test'
+ captions = test_captions
+ filenames = test_names
+
+ return filenames, captions, ixtoword, wordtoix, n_words
+
+ def load_class_id(self, data_dir, total_num):
+ if os.path.isfile(data_dir + '/class_info.pickle'):
+ with open(data_dir + '/class_info.pickle', 'rb') as f:
+ class_id = pickle.load(f, encoding='latin1')
+ else:
+ class_id = np.arange(total_num)
+ return class_id
+
+ def load_filenames(self, data_dir, split):
+ filepath = '%s/%s/filenames.pickle' % (data_dir, split)
+ if os.path.isfile(filepath):
+ with open(filepath, 'rb') as f:
+ filenames = pickle.load(f)
+ print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
+ else:
+ filenames = []
+ return filenames
+
+ def get_caption(self, sent_ix):
+ # a list of indices for a sentence
+ sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')
+ if (sent_caption == 0).sum() > 0:
+ print('ERROR: do not need END (0) token', sent_caption)
+ num_words = len(sent_caption)
+ # pad with 0s (i.e., '')
+ x = np.zeros((18, 1), dtype='int64')
+ x_len = num_words
+ if num_words <= 18:
+ x[:num_words, 0] = sent_caption
+ else:
+ ix = list(np.arange(num_words))
+ np.random.shuffle(ix)
+ ix = ix[:18]
+ ix = np.sort(ix)
+ x[:, 0] = sent_caption[ix]
+ x_len = 18
+ return x, x_len
+
+ def __getitem__(self, global_index):
+ index = global_index // self.embeddings_num
+ key = self.filenames[index]
+ cls_id = self.class_id[index]
+ # print(f"glindex: {global_index}, index: {index}, key: {key}, cls_id: {cls_id}")
+
+ if self.bbox is not None:
+ bbox = self.bbox[key]
+ data_dir = '%s/CUB_200_2011' % self.data_dir
+ else:
+ bbox = None
+ data_dir = self.data_dir
+
+ img_name = f'{data_dir}/images/{key}.jpg'
+ imgs = get_imgs(img_name, bbox=None, imsize=self.imsize, do_augment=self.split == 'train', image_cache=self.image_cache)
+ imgs = np.array(imgs) / 255.0
+ imgs = imgs.transpose(2, 0, 1)
+
+ # sent_ix = random.randint(0, self.embeddings_num)
+ # new_sent_ix = index * self.embeddings_num + sent_ix
+ new_sent_ix = global_index
+ caps, cap_len = self.get_caption(new_sent_ix)
+
+ return {
+ "img": imgs,
+ "input_ids": torch.from_numpy(caps).squeeze(-1),
+ "attention_mask": torch.ones((caps.shape[0],), dtype=torch.bool)
+ }
+
+ def __len__(self):
+ return len(self.filenames) * self.embeddings_num
\ No newline at end of file
diff --git a/models/datasets/image_datasets.py b/models/datasets/image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4af7185021e44e692f25559a469fc2295b7042e
--- /dev/null
+++ b/models/datasets/image_datasets.py
@@ -0,0 +1,953 @@
+from email.mime import image
+import os
+import random
+import typing
+from pathlib import Path
+from typing import Optional
+import subprocess
+import datasets
+import torch
+from numpy import pad
+from PIL import Image, ImageFile
+from tensordict import TensorDict
+from torchvision import transforms
+from decoupled_utils import get_world_size
+import time
+import re
+import shutil
+from constants import UNIDISC_DIR
+from decoupled_utils import barrier, get_rank, gprint, is_local_main_process, is_main_process, is_torch_cuda_available, is_torch_xla_available, rprint
+from models.datasets.webdataset_utils import get_data
+import hashlib
+from decoupled_utils import sanitize_filename
+from omegaconf import OmegaConf, read_write
+from models.datasets.misc_image_datasets import *
+from copy import deepcopy
+from datasets import Dataset, DatasetDict
+import numpy as np
+from PIL import Image
+import json
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import torch
+from torch.utils.data import Subset
+
+def split_dataset(dataset, n: int, m: int):
+ # Ensure m is valid
+ if m < 0 or m >= n:
+ raise ValueError(f"m must be between 0 and {n-1}, but got {m}.")
+
+ # Calculate the size of each subset
+ total_len = len(dataset)
+ subset_size = total_len // n
+ remainder = total_len % n
+
+ # Calculate the start and end index of the m-th subset
+ start_idx = m * subset_size + min(m, remainder)
+ end_idx = start_idx + subset_size + (1 if m < remainder else 0)
+
+ # Return the m-th subset
+ indices = list(range(start_idx, end_idx))
+ if isinstance(dataset, torch.utils.data.Dataset):
+ return Subset(dataset, indices)
+ else:
+ return dataset[slice(start_idx, end_idx)]
+
+def get_webdataset_indexed(config, tokenizer, transform, cond_transform, n_samples, name, should_tokenize=False):
+ should_tokenize = ("tokenize" in name) or should_tokenize
+ import wids # You need to use the custom sampler!!
+
+ custom_ignore_func_dict = {
+ "pixelprose": lambda x: len(x[".txt"]) > 400,
+ }
+
+ valid_func = None
+ for k in custom_ignore_func_dict.keys():
+ if k in name:
+ valid_func = custom_ignore_func_dict[k]
+ break
+
+ from dataloader import tokenize_text
+
+ def process(x, idx):
+ data = {}
+
+ if "mmc4" in name:
+ print(x['.json']['image_info'][0])
+ breakpoint()
+
+ img = x[".jpg"].convert("RGB")
+ data["is_valid"] = True
+ if valid_func is not None and valid_func(x) is False:
+ print(f"Invalid")
+ data["is_valid"] = False
+
+ data["img"] = transform(img)
+ if cond_transform is not None:
+ data["cond_img"] = cond_transform(x[".jpg"].convert("RGB"))
+
+ if data["img"].shape[0] != 3:
+ raise Exception(f"Image shape: {data['img'].shape}, {x['.jpg'].size}, {x['.jpg'].mode}")
+
+ if "pixelprose" in name:
+ before = x[".txt"]
+ x[".txt"] = re.sub(r"This image displays.*?(?=[a-zA-Z0-9])", "", x[".txt"])
+ if abs(len(before) - len(x[".txt"])) > 100:
+ data["is_valid"] = False
+
+ if not "imagenet" in name:
+ if should_tokenize:
+ tokens = tokenize_text(tokenizer, config.data.block_size, x[".txt"])
+ data["input_ids"] = tokens["input_ids"]
+ data["attention_mask"] = tokens["attention_mask"].float()
+ else:
+ data[".txt"] = x[".txt"]
+
+ data["idx"] = idx
+
+ return data
+
+ disable_split = False
+ if isinstance(config.data.raw_data_dir, str) and '*' in config.data.raw_data_dir:
+ import glob
+ index_path = sorted(glob.glob(config.data.raw_data_dir))
+ if not index_path:
+ raise ValueError(f"No files found matching the pattern: {config.data.raw_data_dir:}")
+ print(f"Expanded glob pattern to {len(index_path)} files")
+ if os.getenv("SLURM_ARRAY_TASK_COUNT", None) is not None:
+ index_path = split_dataset(index_path, int(os.getenv("SLURM_ARRAY_TASK_COUNT")), int(os.getenv("SLURM_ARRAY_TASK_ID")))
+ print(f"After splitting, dataset is length {len(index_path)}")
+ shards = []
+ for shard in index_path:
+ shards.append({"url": shard, "nsamples": wids.wids.compute_num_samples(shard)})
+ print(f"Shard: {shard}")
+ index_path = shards
+ disable_split = True
+ elif Path(config.data.raw_data_dir).is_file():
+ index_path = config.data.raw_data_dir
+ else:
+ default_path = Path(config.data.raw_data_dir) / "index.json"
+ shard_path = Path(config.data.raw_data_dir) / "shardindex.json"
+ index_path = str(default_path if default_path.exists() else shard_path)
+
+ assert getattr(config.data, "shard_list_path", None) is None, "shard_list_path is deprecated, use raw_data_dir instead"
+ dataset = wids.ShardListDataset(index_path) # lru_size=20
+ dataset = CustomTransformDataset(dataset, process)
+
+ if n_samples is not None:
+ from torch.utils.data import Subset
+ indices = torch.randperm(len(dataset))[:n_samples]
+ dataset = Subset(dataset, indices)
+
+ if config.data.split_dataset and not disable_split:
+ gprint(f"Original dataset was length {len(dataset)}")
+ dataset = split_dataset(dataset, int(os.getenv("SLURM_ARRAY_TASK_COUNT")), int(os.getenv("SLURM_ARRAY_TASK_ID")))
+ gprint(f"After splitting, dataset is length {len(dataset)}")
+
+ return dataset
+
+
+def _copy_data(src_path, dst_path, use_rsync=True):
+ dst_path.mkdir(parents=True, exist_ok=True)
+ if use_rsync:
+ rprint(f"Rsyncing data from {src_path} to {dst_path}")
+ rsync_command = ["rsync", "-av", str(src_path) + "/", str(dst_path) + "/"]
+ try:
+ result = subprocess.run(rsync_command, check=True, capture_output=True, text=True)
+ rprint(f"Rsync output: {result.stdout}")
+ rprint(f"Successfully rsynced data from {src_path} to {dst_path}")
+ except subprocess.CalledProcessError as e:
+ rprint(f"Rsync failed: {e}")
+ rprint(f"Rsync stderr: {e.stderr}")
+ raise
+ else:
+ rprint(f"Copying tensordict from {src_path} to {dst_path}")
+ shutil.copytree(src_path, dst_path)
+ rprint(f"Copied tensordict from {src_path} to {dst_path}")
+
+def copy_data(shm_path, src_path, dst_path):
+ shm_path.mkdir(parents=True, exist_ok=True)
+ use_rsync = True
+ if not dst_path.exists() or use_rsync:
+ _copy_data(src_path, dst_path, use_rsync=use_rsync)
+ else:
+ src_files = sum(1 for _ in src_path.rglob('*'))
+ dst_files = sum(1 for _ in dst_path.rglob('*'))
+ src_size = sum(f.stat().st_size for f in src_path.rglob('*') if f.is_file())
+ dst_size = sum(f.stat().st_size for f in dst_path.rglob('*') if f.is_file())
+ size_diff_percent = abs(src_size - dst_size) / max(src_size, dst_size) * 100
+ if src_files != dst_files or size_diff_percent > 10:
+ rprint(f"Src files: {src_files}, Dst files: {dst_files} contain different number of files, {src_size} {dst_size}, size difference {size_diff_percent}, Deleting {dst_path}")
+ shutil.rmtree(dst_path)
+ rprint(f"Deleted {dst_path}, copying from {src_path}")
+ _copy_data(src_path, dst_path, use_rsync=False)
+ rprint(f"Deleted and re-copied tensordict from {src_path} to {dst_path}")
+ else:
+ rprint(f"Tensordict already exists at {dst_path}, loading from there")
+
+def get_tensordict(config, path, dataset_idx, dataset_name=None):
+ parquet_files = list(Path(path).glob('*.arrow'))
+ if parquet_files:
+ # Does not load into memory by default
+ from datasets import load_from_disk
+ dataset = load_from_disk(path)
+ rprint(f"Loaded {len(dataset)} samples from {path} as parquet")
+ return dataset
+
+ if getattr(config.data, "force_dummy_tensordict", False):
+ return get_dummy_tensordict(config, 1000000, dataset_idx=dataset_idx)
+
+ if config.data.move_tensordict_to_shm:
+ assert config.data.keep_tensordict_on_disk is True
+ shm_path = Path(getattr(config.data, "shm_path", Path("/dev/shm") / Path.home().name))
+ src_path = Path(path)
+ dst_path = shm_path / (dataset_name if dataset_name is not None else src_path.name)
+ if getattr(config.data, "skip_copy_tensordict_to_shm", False):
+ gprint(f"Skipping copy of tensordict to SHM")
+ elif is_torch_xla_available():
+ if is_main_process():
+ copy_data(shm_path, src_path, dst_path)
+
+ barrier()
+ if not is_main_process():
+ import time
+ from torch_xla._internal import tpu
+ host_ip = tpu.get_worker_ips()[0]
+ file_dst_path = Path(shm_path)
+ src_path_on_host = f"aswerdlow@{host_ip}:{file_dst_path}"
+ gprint(f"Copying data from {src_path_on_host} to {file_dst_path}")
+ file_dst_path.mkdir(parents=True, exist_ok=True)
+ max_retries = 5
+ retry_delay = 15
+ for attempt in range(max_retries):
+ try:
+ gprint(f"After main copy, rsyncing data from {src_path_on_host} to {file_dst_path}")
+ command = f'bash {(UNIDISC_DIR / "scripts/rsync_data.sh").resolve()} {src_path_on_host}/ {file_dst_path}/'
+ os.environ.pop('SSH_AUTH_SOCK', None) # Breaks without this
+ gprint(command)
+ subprocess.run(command, shell=True, check=True)
+ gprint(f"Successfully rsynced data from {src_path_on_host} to {file_dst_path}")
+ break
+ except subprocess.CalledProcessError as e:
+ if attempt < max_retries - 1:
+ gprint(f"Rsync attempt {attempt + 1} failed. Retrying in {retry_delay} seconds..., {e}")
+ time.sleep(retry_delay)
+ retry_delay *= 2
+ else:
+ raise RuntimeError(f"Failed to rsync data after {max_retries} attempts: {e}")
+
+ gprint(f"Finished rsyncing data from {src_path_on_host} to {file_dst_path}")
+ barrier()
+ else:
+ if is_local_main_process():
+ copy_data(shm_path, src_path, dst_path)
+
+ # For now we assume we are on SPMD and there is only one process per worker [node]
+ if not is_torch_xla_available():
+ barrier()
+
+ else:
+ dst_path = Path(path)
+
+ path = dst_path
+ data = TensorDict.load_memmap(path)
+ if config.data.keep_tensordict_on_disk:
+ rprint(f"Keeping tensordict on disk at {path}")
+ else:
+ data = data.clone() # Move to CPU memory
+ rprint(f"Loaded {len(data)} samples from {path}")
+ return data
+
+
+def get_dummy_tensordict(config, dataset_size, txt_length=None, img_length=None, dataset_idx=0):
+ if img_length is None:
+ img_length = config.model.img_length
+ if txt_length is None:
+ txt_length = config.model.txt_length
+ return TensorDict(
+ {
+ "input_ids": torch.ones(dataset_size, config.model.length, dtype=torch.int32),
+ "attention_mask": torch.ones(dataset_size, config.model.length, dtype=torch.bool),
+ "img_input_ids": torch.ones(dataset_size, img_length, dtype=torch.int16),
+ "txt_input_ids": torch.ones(dataset_size, txt_length, dtype=torch.int32),
+ "txt_attention_mask": torch.ones(dataset_size, txt_length, dtype=torch.bool),
+ "idx": torch.arange(dataset_size, dtype=torch.int32).view(-1, 1),
+ "dataset_idx": torch.full((dataset_size,), fill_value=dataset_idx, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ },
+ batch_size=[dataset_size],
+ )
+
+
+def get_token_dataset(config, name, is_train, n_samples, n_duplicate, tokenizer):
+ assert getattr(config.data, "token_data_dir", None) is None, "token_data_dir is deprecated, use data_dir_train and data_dir_val instead"
+ if "dummy" in name:
+ return get_dummy_tensordict(config, n_samples if n_samples is not None else 100000)
+ data_key = (
+ config.data.data_dir_train if is_train else (config.data.data_dir_val if config.data.data_dir_val is not None else config.data.data_dir_train)
+ )
+ image_datasets_key = getattr(config.data, "image_data_train", None) if is_train else getattr(config.data, "image_data_val", None)
+
+ if config.data.use_weighted_tensordict_sampler:
+ _dataset_cls = MultipleTensorDictDataset
+ _datasets = [get_tensordict(config, x['dir'], dataset_idx=i, dataset_name=x['name']) for i, x in enumerate(data_key)]
+ _weights = [x['weight'] for x in data_key]
+ _dataset_names = [x['name'] for x in data_key]
+ _kwargs = dict()
+ _kwargs["config"] = config
+ _kwargs["tokenizer"] = tokenizer
+
+ if any(not isinstance(x, TensorDict) for x in _datasets):
+ _kwargs["returns_parquet"] = True
+ elif getattr(config.data, "add_text_to_weighted_sampler", False):
+ from datasets import load_dataset, interleave_datasets
+ rprint("Loading smollm datasets")
+ ds1 = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", cache_dir=config.data.cache_dir, streaming=True)
+ ds2 = load_dataset("HuggingFaceTB/smollm-corpus", "fineweb-edu-dedup", split="train", cache_dir=config.data.cache_dir, streaming=True)
+ # DKYoon/SlimPajama-6B, "cerebras/SlimPajama-627B"
+ ds3 = load_dataset("DKYoon/SlimPajama-6B", split="train", cache_dir=config.data.cache_dir, streaming=True)
+ ds4 = load_dataset("bigcode/starcoderdata", split="train", cache_dir=config.data.cache_dir, streaming=True)
+ rprint(f"Finished loading datasets")
+ if getattr(config.data, "code_only", False):
+ _dataset = ds4
+ else:
+ _dataset = interleave_datasets([ds1, ds2, ds3, ds4], probabilities=[0.3, 0.3, 0.2, 0.2], seed=config.seed)
+
+ rprint(f"Finished interleaving datasets")
+ _datasets.append(_dataset)
+ _weights.append(1)
+ _dataset_names.append("SlimPajama-627B")
+ _kwargs["returns_tokenized_text"] = True
+ rprint(f"Finished creating dataset")
+ elif image_datasets_key is not None:
+ returns_raw_images = False
+ tokenize_vqvae_in_dataloader = False
+ allow_label = False
+ for key in image_datasets_key:
+ _key = OmegaConf.to_object(key)
+ if _key.get("raw_images", False) or config.data.force_raw_images_in_multiple_tensordict:
+ rprint(f"WARNING!!! Using raw images")
+ returns_raw_images = True
+
+ if _key.get("tokenize_vqvae_in_dataloader", False):
+ tokenize_vqvae_in_dataloader = True
+
+ if _key.get("allow_label", False):
+ rprint(f"WARNING!!! Using allow_label")
+ allow_label = True
+
+ if config.data.force_raw_images_in_multiple_tensordict:
+ tokenize_vqvae_in_dataloader = False
+ _key["tokenize_vqvae_in_dataloader"] = False
+ _key["disable_text_modality"] = True
+
+ image_config = OmegaConf.merge(deepcopy(config),
+ {
+ "data": {
+ **{k:v for k,v in _key.items() if k not in {"dir", "weight", "name", "raw_images"}}
+ },
+ }
+ )
+ image_dataset = get_image_dataset(
+ mode="train" if is_train else "val",
+ config=image_config,
+ tokenizer=tokenizer,
+ allow_aug=is_train,
+ force_aug=False,
+ name=key["train"] if is_train else key["val"],
+ )
+ _datasets.append(image_dataset)
+ _weights.append(key["weight"])
+ _dataset_names.append(key["name"])
+
+ _kwargs["returns_raw_images"] = returns_raw_images
+ _kwargs["returns_tokenize_vqvae_in_dataloader"] = tokenize_vqvae_in_dataloader
+ _kwargs["allow_label"] = allow_label
+
+ if n_samples is not None:
+ if getattr(config.data, "force_no_shuffle_tensordict", False):
+ _datasets = [data[:n_samples] for data in _datasets]
+ else:
+ _datasets = [data[torch.randperm(len(data), generator=torch.Generator().manual_seed(config.seed))[:n_samples]] for data in _datasets]
+ rprint(f"Sampled {n_samples} samples from {len(_datasets)}, is_train: {is_train}.")
+
+ data = _dataset_cls(datasets=_datasets, weights=_weights, dataset_names=_dataset_names, **_kwargs)
+ else:
+ data = get_tensordict(config, data_key, 0)
+ if n_samples is not None:
+ if getattr(config.data, "force_no_shuffle_tensordict", False):
+ indices = list(range(n_samples))
+ else:
+ indices = torch.randperm(len(data), generator=torch.Generator().manual_seed(config.seed))[:n_samples]
+ data = data[indices]
+ rprint(f"Sampled {n_samples} samples from {len(data)}, is_train: {is_train}. First 2 indices: {indices[:2]}")
+
+ if n_duplicate is not None:
+ data = torch.cat([data for _ in range(n_duplicate)], dim=0)
+ rprint(f"Duplicated {n_duplicate} times, now {len(data)} samples")
+
+ return data
+
+
+class UnpairedDatasetWrapper(torch.utils.data.Dataset):
+ def __init__(self, img_dataset, txt_dataset, shuffle=True):
+ self.img_dataset = img_dataset
+ self.txt_dataset = txt_dataset
+ self.shuffle = shuffle
+
+ def __len__(self):
+ if self.shuffle:
+ return min(len(self.img_dataset), len(self.txt_dataset))
+ else:
+ return max(len(self.img_dataset), len(self.txt_dataset))
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ if self.shuffle:
+ img_idx = torch.randint(0, len(self.img_dataset), (1,)).item()
+ txt_idx = torch.randint(0, len(self.txt_dataset), (1,)).item()
+ else:
+ txt_idx = idx
+ img_idx = idx % len(self.img_dataset)
+ return dict(**self.img_dataset[img_idx], **self.txt_dataset[txt_idx])
+ except Exception as e:
+ gprint(e)
+ import traceback
+
+ traceback.print_exc()
+ idx = (idx + 1) % len(self)
+
+
+def get_unpaired_dataset(config=None, tokenizer=None, mode="train", **kwargs):
+ image_dataset = get_image_dataset(config=config, mode=mode, tokenizer=tokenizer, **kwargs)
+ from models.datasets.text_datasets import get_text_dataset
+
+ text_dataset = get_text_dataset(
+ dataset_name=getattr(config.data, "txt_train", "text8") if mode == "train" else getattr(config.data, "txt_val", "text8"),
+ tokenizer=tokenizer,
+ mode="test" if (mode == "validation" and getattr(config.data, "txt_val", "text8") == "lm1b") else mode,
+ wrap=config.data.wrap,
+ block_size=config.model.txt_length, # Intentional
+ cache_dir=config.data.cache_dir,
+ num_proc=config.data.num_proc,
+ streaming=config.data.streaming,
+ )
+ return UnpairedDatasetWrapper(image_dataset, text_dataset, shuffle=getattr(config.data, "force_disable_shuffle", False) is False)
+
+
+def get_transform(resolution, orig_mode, allow_aug, force_aug, aggressive_aug=False):
+ if orig_mode == "train" and (allow_aug or force_aug):
+ if aggressive_aug:
+ rprint("Using aggressive augmentations")
+ transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop((resolution, resolution), scale=(0.8, 1.0), ratio=(0.97, 1.03)),
+ transforms.RandomHorizontalFlip(1.0 if force_aug else 0.5),
+ transforms.ToTensor(),
+ ]
+ )
+ else:
+ transform = transforms.Compose(
+ [
+ transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.RandomCrop((resolution, resolution)),
+ transforms.RandomHorizontalFlip(1.0 if force_aug else 0.5),
+ transforms.ToTensor(),
+ ]
+ )
+ else:
+ transform = transforms.Compose(
+ [
+ transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop((resolution, resolution)),
+ transforms.ToTensor(),
+ ]
+ )
+ return transform
+
+def load_vqvae_from_cache(config, full_cache_path):
+ global_cache_parent = os.environ.get("DIFFUSION_DATA_DIR", None)
+ if global_cache_parent is not None:
+ global_full_cache_path = Path(global_cache_parent) / full_cache_path.relative_to(Path(config.data.cache_dir))
+ gprint(f"Checking global cache path: {global_full_cache_path}")
+ if global_full_cache_path.exists() and len(list(global_full_cache_path.iterdir())) > 0:
+ gprint(f"Loading data from global cache: {global_full_cache_path}")
+ full_cache_path = global_full_cache_path
+
+ if not (full_cache_path.exists() and len(list(full_cache_path.iterdir())) > 0):
+ gprint(f"Cache path {full_cache_path} does not exist or is empty")
+ return None
+
+ gprint(f"Loading data from: {full_cache_path}, found {len(list(full_cache_path.iterdir()))} shards")
+ ret = []
+ kwargs = dict()
+
+ if config.data.keep_hf_dataset_in_memory:
+ kwargs["keep_in_memory"] = True
+ if config.loader.num_workers > 0:
+ for _ in range(5):
+ gprint(f"WARNING!!!! Keeping dataset in memory and num_workers > 0, this will cause excessive memory usage")
+ else:
+ gprint(f"Loading datasets into memory")
+
+ for folder in full_cache_path.iterdir():
+ if folder.is_dir():
+ ret.append(datasets.load_from_disk(folder, **kwargs))
+
+ ret = datasets.concatenate_datasets(ret).with_format("torch")
+ gprint(f"Loaded data from cache: {full_cache_path} with {len(ret)} samples")
+ return ret
+
+def get_vqvae_dataloader(config, name, split):
+ cache_key = f'vqvae_tokenized_{name}_{split}_{config.data.resolution}'
+ vae_ckpt_hash = ""
+
+ if hasattr(config.model, "use_custom_vae_ckpt") and config.model.use_custom_vae_ckpt:
+ vae_ckpt_hash = hashlib.md5(str(Path(config.model.use_custom_vae_ckpt).name).encode()).hexdigest()[:8]
+ cache_key += f"_{vae_ckpt_hash}"
+ if hasattr(config.model, "vae_type") and config.model.vae_type != "VQ-16":
+ cache_key += f"_{config.model.vae_type}"
+ if getattr(config.data, "vqvae_cache_suffix", None) is not None:
+ cache_key += f"_{config.data.vqvae_cache_suffix}"
+
+ cache_dir = config.data.cache_dir
+ full_cache_path = Path(cache_dir) / "tokens" / sanitize_filename(cache_key)
+ return full_cache_path
+
+
+def get_image_dataset(mode, config, tokenizer, allow_aug=True, force_aug=False, name=None, **kwargs):
+ rprint(f"Getting image dataset with mode {mode}")
+ if getattr(config.data, "tokenizers_parallelism", None) is not None:
+ rprint(f"Setting tokenizers parallelism to {config.data.tokenizers_parallelism}")
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" if config.data.tokenizers_parallelism is False else "true"
+
+ resolution = config.data.resolution
+ name = name or config.data.train
+ streaming = config.data.streaming
+ precache = config.data.precache
+ dynamic = streaming or precache is False
+
+ orig_mode = mode
+ block_size = getattr(config.data, "block_size", 1024)
+ is_train = orig_mode == "train"
+
+ n_duplicate_train = getattr(config.data, "n_duplicate_train", None)
+ n_duplicate_val = getattr(config.data, "n_duplicate_val", None)
+ n_duplicate = n_duplicate_train if is_train else n_duplicate_val
+
+ n_val_samples = getattr(config.data, "n_val_samples", None)
+ n_train_samples = getattr(config.data, "n_train_samples", None)
+ n_samples = n_train_samples if is_train else n_val_samples
+
+ raw_data_dir = getattr(config.data, "raw_data_dir", getattr(config.data, "data_dir", None))
+ rprint(f"Data dir is {raw_data_dir}")
+ unified_model = getattr(config.model, "unified_model", False) and getattr(config.data, "unpaired", False) is False
+
+ cond_resolution = getattr(config.data, "cond_resolution", None)
+
+ if "sora" in name:
+ return get_sora_dataset(config=config, tokenizer=tokenizer, **kwargs)
+ elif "tokens" in name:
+ print(f"Loading token dataset {name}")
+ assert config.data.use_token_dataset, "data.use_token_dataset must be true to load token datasets"
+ return get_token_dataset(config, name, is_train, n_samples, n_duplicate, tokenizer)
+
+ dataset_splits = {
+ "cassiekang/cub200_dataset": (
+ "train"
+ if ((orig_mode == "train" and n_train_samples is not None) or (orig_mode != "train" and n_val_samples is not None))
+ else "train+test"
+ ),
+ "nlphuji/flickr30k": "test",
+ "richwardle/reduced-imagenet": "train",
+ "tglcourse/lsun_church_train": "train" if is_train else "test",
+ "pixparse/cc12m-wds": "train",
+ "imagenet": "train" if is_train else "val",
+ "imagefolder": "train" if is_train else "validation",
+ "ILSVRC/imagenet-1k": "train" if is_train else "validation",
+ "pouya-haghi/imagenet-subset": "validation",
+ "laion/clevr-webdataset": "train" if is_train else "validation",
+ "pcuenq/lsun-bedrooms": "train" if is_train else "test",
+ "facebook/winoground": "test",
+ "sayakpaul/coco-30-val-2014": "train"
+ }
+
+ split = dataset_splits[name] if name in dataset_splits else "train"
+
+ if n_samples is not None:
+ split = f"{split}[:{n_samples}]"
+
+ extra_kwargs = dict()
+ cache_dir = Path(config.data.cache_dir)
+ cache_dir.mkdir(parents=True, exist_ok=True)
+
+ if "HF_HUB_DATASETS_TOKEN" in os.environ:
+ extra_kwargs["token"] = os.environ["HF_HUB_DATASETS_TOKEN"]
+
+ if name == "mmc4" or name == "cambrian":
+ from unidisc.tokenizers.tokenize_interleaved import JsonlDataset
+ dataset = JsonlDataset(glob_pattern=config.data.raw_data_dir)
+
+ if n_samples is not None:
+ from torch.utils.data import Subset
+ indices = list(range(len(dataset)))[:n_samples]
+ dataset = Subset(dataset, indices)
+
+ if config.data.split_dataset:
+ if getattr(config.data, "split_dataset_total_count", None) is not None and \
+ getattr(config.data, "split_dataset_cur_idx", None) is not None:
+ gprint(f"Splitting dataset into {config.data.split_dataset_total_count} shards, original length {len(dataset)}")
+ dataset = split_dataset(dataset, config.data.split_dataset_total_count, config.data.split_dataset_cur_idx)
+
+ gprint(f"Original dataset was length {len(dataset)}")
+ total_count, cur_idx = int(os.getenv("SLURM_ARRAY_TASK_COUNT")), int(os.getenv("SLURM_ARRAY_TASK_ID"))
+ dataset = split_dataset(dataset, total_count, cur_idx)
+ gprint(f"After splitting, dataset is length {len(dataset)}")
+
+ return dataset
+
+ if name == "imagefolder":
+ from datasets.data_files import DataFilesDict
+ with open(config.data.train_data_dir, "r") as f:
+ train_txt = [f"{config.data.data_dir}/{line.strip()}" for line in f.readlines()]
+ with open(config.data.val_data_dir, "r") as f:
+ val_txt = [f"{config.data.data_dir}/{line.strip()}" for line in f.readlines()]
+ data_files = DataFilesDict({"train": train_txt, "validation": val_txt})
+ extra_kwargs["data_files"] = data_files
+
+ if config.data.tokenize_vqvae_in_dataloader and not getattr(config.data, "allow_aug_vqvae_dataloader", False):
+ rprint(f"WARNING!!!! Disabling augmentations for VQVAE dataloader")
+ allow_aug = False
+ force_aug = False
+
+ transform = get_transform(resolution, orig_mode, allow_aug, force_aug, getattr(config.data, "aggressive_aug", False))
+ if cond_resolution is not None:
+ cond_transform = get_transform(cond_resolution, orig_mode, allow_aug, force_aug)
+ else:
+ cond_transform = None
+
+ if kwargs.get("transform", None) is not None:
+ rprint(f"Using transform from kwargs: {kwargs['transform']}")
+ transform = kwargs.pop("transform")
+
+ if name == "torchvision_imagenet":
+ from torchvision.datasets import ImageFolder
+
+ raw_data_dir = Path(config.data.raw_data_dir)
+ raw_data_dir = raw_data_dir / "train" if orig_mode == "train" else raw_data_dir / "val"
+ dataset = ImageFolder(raw_data_dir, transform=transform)
+ dataset = CustomTransformDataset(dataset, lambda x, idx: {"img": x[0], "label": x[1]})
+ return dataset
+
+ if "pixparse/cc12m-wds-fast" in name or "pixparse/cc3m-wds-fast" in name or "indexed" in name:
+ return get_webdataset_indexed(config, tokenizer, transform, cond_transform, n_samples, name, should_tokenize=True)
+
+ if name == "vggface2":
+ dataset = VGGFace(
+ Path(raw_data_dir),
+ is_train,
+ transform=transform,
+ filter_resolution=(resolution - 48),
+ cond_transform=cond_transform,
+ v2=getattr(config.data, "add_vggface_v2_attributes", False),
+ )
+ rprint(f"VGGFace2 has size {len(dataset)}")
+ return dataset
+
+ if name == "cub2011_custom":
+ from models.datasets.cub200 import TextDataset
+ dataset = TextDataset(data_dir='/path/to/cub200/birds', split='train' if is_train else 'test')
+ return dataset
+
+ wds_config = OmegaConf.create(
+ {
+ "train_data": None,
+ "val_data": None,
+ "dataset_type": "webdataset",
+ "train_data_upsampling_factors": None,
+ "batch_size": config.loader.batch_size if mode == "train" else config.loader.eval_batch_size,
+ "workers": config.loader.num_workers,
+ "distributed": True,
+ "seed": config.seed,
+ "val_num_samples": None,
+ "train_num_samples": config.data.webdataset_train_num_samples,
+ "val_num_samples": config.data.webdataset_val_num_samples,
+ "world_size": config.trainer.devices * config.trainer.num_nodes,
+ "block_size": block_size,
+ }
+ )
+ if config.data.dataset_type == "webdataset":
+ clean_brace_escape = lambda x: x.replace("[", "{").replace("]", "}")
+ wds_config.train_data = clean_brace_escape(config.data.webdataset_train_data)
+ wds_config.val_data = clean_brace_escape(config.data.webdataset_val_data)
+
+ if getattr(config.data, "webdataset_prefix", None) is not None:
+ wds_config.train_data = config.data.webdataset_prefix.replace("LITERALQUOTE", "'").replace("LITERALSPACE", " ") + wds_config.train_data
+ wds_config.val_data = config.data.webdataset_prefix.replace("LITERALQUOTE", "'").replace("LITERALSPACE", " ") + wds_config.val_data
+
+ if getattr(config.data, "webdataset_postfix", None) is not None:
+ wds_config.train_data = wds_config.train_data + config.data.webdataset_postfix.replace("LITERALQUOTE", "'").replace("LITERALSPACE", " ")
+ wds_config.val_data = wds_config.val_data + config.data.webdataset_postfix.replace("LITERALQUOTE", "'").replace("LITERALSPACE", " ")
+
+ return get_data(wds_config, (transform, transform), epoch=0, tokenizer=tokenizer)
+ if name == "laion400m":
+ # TODO: Debug if these configs are correct!!!! Not fully sure how the webdataset sharded dataloader should work.
+ wds_config.train_data = "/grogu/datasets/laion400m/dataset/{00000..00625}.tar"
+ wds_config.val_data = "/grogu/datasets/laion400m/dataset/{00000..00625}.tar"
+ return get_data(wds_config, (transform, transform), epoch=0, tokenizer=tokenizer)
+ elif name == "cc12m_3m":
+ # TODO: Debug if these configs are correct!!!! Not fully sure how the webdataset sharded dataloader should work.
+ wds_config.train_data = config.data.raw_data_dir + "/cc3m-train-{0000..0575}.tar"
+ wds_config.val_data = config.data.raw_data_dir + "/cc3m-validation-{0000..0015}.tar"
+ return get_data(wds_config, (transform, transform), epoch=0, tokenizer=tokenizer)
+ elif name == "facecaption":
+ if getattr(config.data, "webdataset_iterable", False):
+ wds_config.train_data = "/grogu/user/mprabhud/data/diffusion/facecaption/{00000..00001}.tar"
+ wds_config.val_data = "/grogu/user/mprabhud/data/diffusion/facecaption/{00000..00001}.tar"
+ return get_data(wds_config, (transform, transform), epoch=0, tokenizer=tokenizer)
+ elif getattr(config.data, "webdataset_indexed", False) is False:
+ return get_webdataset_indexed(config, tokenizer, transform, cond_transform, n_samples, name, should_tokenize=True)
+ else:
+ raise Exception("Unknown webdataset type")
+
+ # hf webdataset
+ if name == "pixparse/cc12m-wds":
+ extra_kwargs["data_dir"] = config.data.raw_data_dir
+
+ if name == "generated_images":
+ extra_kwargs["data_files"] = {"train": getattr(config.data, "parquet_path", None)}
+
+ if name != "imagefolder":
+ rprint(f"Loading dataset {name}, split={split}, streaming={streaming}, cache_dir={cache_dir}, extra_kwargs={extra_kwargs}, dynamic={dynamic}")
+
+ load_map = {"pixparse/cc12m-wds": "webdataset", "laion400m": "webdataset", "generated_images": "parquet"}
+ load_name = load_map.get(name, name)
+ if streaming is False:
+ extra_kwargs["num_proc"] = 16
+
+ if config.data.tokenize_vqvae_in_dataloader:
+ full_cache_path = get_vqvae_dataloader(config, name, split)
+ _ret = load_vqvae_from_cache(config, full_cache_path)
+ if _ret is not None: return _ret
+ from model import get_image_batch, get_vae
+ rank = get_rank()
+ vae = get_vae(config, device="cpu").eval()
+ vae.to(f"cuda:{rank}")
+
+ def tokenize_vqvae(batch):
+ device = f"cuda:{rank}"
+ img_input_ids = get_image_batch(config, vae, batch, device)
+ batch.pop("img")
+ batch["img_input_ids"] = img_input_ids
+ return batch
+
+ if config.data.keep_hf_dataset_in_memory:
+ extra_kwargs["keep_in_memory"] = True
+ gprint(f"WARNING!!!! Keeping dataset in memory")
+
+ if name == "geneval":
+ def create_blank_image():
+ return Image.new("RGB", (resolution, resolution), color=(255, 255, 255))
+
+ # https://github.com/djghosh13/geneval/blob/main/prompts/generation_prompts.txt
+ prompts_path = Path.home() / ".cache" / "unidisc" / "geneval_generation_prompts.txt"
+ if not prompts_path.exists():
+ prompts_path.parent.mkdir(parents=True, exist_ok=True)
+ import urllib.request
+ urllib.request.urlretrieve(
+ "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/generation_prompts.txt",
+ prompts_path
+ )
+ with open(prompts_path, "r") as f:
+ captions = [line.strip() for line in f.readlines()]
+
+ dataset = Dataset.from_dict({
+ "caption": captions,
+ "image": [
+ create_blank_image() for i in range(len(captions))
+ ],
+ })
+ elif name == "MJHQ":
+ def create_blank_image():
+ return Image.new("RGB", (resolution, resolution), color=(255, 255, 255))
+ prompts_path = Path.home() / ".cache" / "unidisc" / "MJHQ_meta_data.json"
+ if not prompts_path.exists():
+ prompts_path.parent.mkdir(parents=True, exist_ok=True)
+ import urllib.request
+ urllib.request.urlretrieve(
+ "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json",
+ prompts_path
+ )
+
+ with open(prompts_path, "r") as f:
+ data = json.load(f)
+ captions = [item["prompt"] for item in data.values()]
+
+ dataset = Dataset.from_dict({
+ "caption": captions,
+ "image": [
+ create_blank_image() for i in range(len(captions))
+ ],
+ })
+ else:
+ dataset = datasets.load_dataset(load_name, split=split, streaming=streaming, cache_dir=cache_dir, **extra_kwargs)
+
+ dataset_keys = {
+ "cassiekang/cub200_dataset": ("image", "text"),
+ "Andron00e/CUB200-custom": ("image",),
+ "nlphuji/flickr30k": ("image", "caption"),
+ "ILSVRC/imagenet-1k": ("image", "label"),
+ "richwardle/reduced-imagenet": ("image",),
+ "tglcourse/lsun_church_train": ("image",),
+ "imagefolder": ("image",),
+ "pixparse/cc12m-wds": ("jpg", "txt"),
+ "pravsels/FFHQ_1024": ("image",),
+ "pravsels/SFHQ_256": ("image",),
+ "jxie/celeba-hq": ("image",),
+ "tglcourse/lsun_church_train": ("image",),
+ "pouya-haghi/imagenet-subset": ("image",),
+ "DeepLearner101/ImageNetSubsetValidate": ("image",),
+ "PixArt-alpha/SAM-LLaVA-Captions10M": ("__key__", "txt"),
+ "generated_images": ("__key__", "caption"),
+ "laion/clevr-webdataset": ("jpg","txt"),
+ "pcuenq/lsun-bedrooms": ("image",),
+ "facebook/winoground": ("image_0", "image_1", "caption_0", "caption_1"),
+ "sayakpaul/coco-30-val-2014": ("image", "caption"),
+ "geneval": ("image", "caption"),
+ "MJHQ": ("image", "caption"),
+ }
+
+ from dataloader import tokenize_text
+
+ def preprocess_images(example, index: typing.Optional[typing.Any] = None):
+ data = {}
+ if dataset_keys[name][0] == "__key__":
+ images = []
+ is_valid = []
+ for key, _image_path in zip(example[dataset_keys[name][0]], example["image_path"]):
+ img_path = (
+ (Path(config.data.raw_data_dir) / key).with_suffix(".jpg") if not key.endswith(".jpg") else (Path(config.data.raw_data_dir) / key)
+ )
+ allow_relative = False
+ if Path(_image_path).exists() and Path(_image_path).stat().st_size > 0:
+ img = Image.open(_image_path)
+ is_valid.append(True)
+ elif allow_relative and img_path.exists() and img_path.stat().st_size > 0:
+ img = Image.open(img_path)
+ is_valid.append(True)
+ else:
+ img = Image.new("RGB", (resolution, resolution), color=(255, 255, 255))
+ is_valid.append(False)
+ images.append(img)
+ data["is_valid"] = is_valid
+ if sum(data["is_valid"]) < len(data["is_valid"]):
+ gprint(f"WARNING!!! Found {len(data['is_valid']) - sum(data['is_valid'])} invalid images")
+ else:
+ images = [image.convert("RGB") for image in example[dataset_keys[name][0]]]
+
+ data["img"] = [transform(image) for image in images]
+ if cond_resolution is not None:
+ data["cond_img"] = [cond_transform(image) for image in images]
+
+ if index is not None:
+ data["idx"] = index
+
+ if "idx" in example:
+ data["idx"] = example["idx"]
+
+ if dynamic and dataset_keys[name][0] is not None:
+ data["img"] = torch.stack(data["img"])
+
+ if "label" in example:
+ data["label"] = example["label"]
+ if (unified_model or getattr(config.data, "txt_only", False)) and not getattr(config.data, "disable_text_modality", False):
+ tokenizer.padding_side = "right"
+ tokenizer.truncation_side = "right"
+
+ if name == "facebook/winoground":
+ caption_0 = example["caption_0"]
+ caption_1 = example["caption_1"]
+ img_0 = example["image_0"]
+ img_1 = example["image_1"]
+ # tokenize and store captions separately
+ tokens_0 = tokenize_text(tokenizer, block_size, caption_0)
+ tokens_1 = tokenize_text(tokenizer, block_size, caption_1)
+ data["caption_0_input_ids"] = tokens_0["input_ids"]
+ data["caption_0_attention_mask"] = tokens_0["attention_mask"].float()
+ data["caption_1_input_ids"] = tokens_1["input_ids"]
+ data["caption_1_attention_mask"] = tokens_1["attention_mask"].float()
+ # convert img_0 and img_1 which are lists of PIL images to tensors
+ # convert some rgba pil images to rgb
+ data["img_0"] = torch.stack([transform(img.convert("RGB")) for img in img_0])
+ data["img_1"] = torch.stack([transform(img.convert("RGB")) for img in img_1])
+ else:
+ text_data = example[dataset_keys[name][1]]
+ if isinstance(text_data[0], list):
+ # Flickr has a list of captions for each image
+ text_data = [random.choice(_data) for _data in text_data]
+
+ tokens = tokenize_text(tokenizer, block_size, text_data)
+ data["input_ids"] = tokens["input_ids"]
+ data["attention_mask"] = tokens["attention_mask"].float()
+
+ return data
+
+ if precache is False:
+ tokenized_dataset = dataset.with_transform(preprocess_images)
+ else:
+ extra_kwargs = dict()
+ if streaming is False:
+ extra_kwargs["load_from_cache_file"] = True
+ else:
+ if name == "pixparse/cc12m-wds":
+ extra_kwargs["remove_columns"] = ["__key__", "jpg", "__url__", "json", "txt"]
+ elif name == "ILSVRC/imagenet-1k":
+ extra_kwargs["remove_columns"] = ["image"]
+
+ tokenized_dataset = dataset.map(preprocess_images, batched=True, with_indices=True, **extra_kwargs)
+ allowed_column_names = ["img", "input_ids", "attention_mask", "tokens", "text", "idx"]
+ current_column_names = tokenized_dataset.column_names
+ if current_column_names is not None:
+ for column_name in current_column_names:
+ if column_name not in allowed_column_names:
+ tokenized_dataset = tokenized_dataset.remove_columns(column_name)
+
+ if n_duplicate is not None:
+ tokenized_dataset = datasets.concatenate_datasets([tokenized_dataset] * n_duplicate)
+
+ ret = tokenized_dataset if dynamic else tokenized_dataset.with_format("torch")
+ if isinstance(dataset, torch.utils.data.IterableDataset) or "cc12m" in name:
+ ret = ResilientIterableDatasetWrapper(ret)
+
+ if config.data.tokenize_vqvae_in_dataloader:
+ assert config.data.force_mp_spawn
+ ret = ret.shard(num_shards=get_world_size(), index=get_rank(), contiguous=True, keep_in_memory=True)
+ gprint(f"Rank {rank} has {len(ret)} samples. World size is {get_world_size()}")
+ ret = ret.map(tokenize_vqvae, batch_size=getattr(config.data, "vqvae_batch_size", 128), batched=True, keep_in_memory=True)
+ ret.reset_format()
+ allowed_column_names = ["img_input_ids"]
+ map_column_list = getattr(config.data, "map_columns", None)
+ if map_column_list is not None:
+ for old_column_name, new_column_name in map_column_list.items():
+ ret = ret.rename_column(old_column_name, new_column_name)
+ if getattr(config.data, "allow_label", False):
+ allowed_column_names.append("label")
+ if getattr(config.data, "allowed_columns_vqvae_dataloader", None):
+ allowed_column_names.extend(list(config.data.allowed_columns_vqvae_dataloader))
+ current_column_names = ret.column_names
+ if current_column_names is not None:
+ for column_name in current_column_names:
+ if column_name not in allowed_column_names:
+ ret = ret.remove_columns(column_name)
+ rank_cache_path = full_cache_path / f"rank_{rank}"
+ gprint(f"Rank {rank} has saved to {rank_cache_path} with {len(ret)} samples")
+ ret.save_to_disk(rank_cache_path)
+ barrier()
+ gprint(f"Rank {rank} has finished saving to {rank_cache_path}. Sleeping for a bit. You may want to Ctrl+C now")
+ time.sleep(60 * 30)
+ ret = load_vqvae_from_cache(config, full_cache_path)
+ gprint(f"Rank {rank} has finished loading from file: {rank_cache_path}")
+
+ return ret
diff --git a/models/datasets/misc_image_datasets.py b/models/datasets/misc_image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f155a41f1b779ca946c57914691d5e64ec59208
--- /dev/null
+++ b/models/datasets/misc_image_datasets.py
@@ -0,0 +1,604 @@
+import os
+import random
+import typing
+from pathlib import Path
+from typing import Optional
+
+import datasets
+import numpy as np
+import pandas as pd
+from unidisc.tokenizers.conversation import get_image_gen_tokens, get_image_suffix
+import torch
+import torch.nn as nn
+from numpy import pad
+from PIL import Image, ImageFile
+from tensordict import TensorDict
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.folder import default_loader
+import re
+import shutil
+from constants import LIB_DIR
+from decoupled_utils import barrier, gprint, is_main_process, is_torch_cuda_available, rprint
+from models.datasets.webdataset_utils import get_data
+from unidisc.utils.tensor_utils import get_interleaved_indices, get_contiguous_blocks, packbits, unpackbits
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+class ResilientIterableDatasetWrapper(torch.utils.data.IterableDataset):
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __iter__(self):
+ iterator = iter(self.dataset)
+ while True:
+ try:
+ yield next(iterator)
+ except StopIteration:
+ raise StopIteration
+ except Exception as e:
+ gprint(e)
+ iterator = iter(self.dataset)
+
+
+class ResilientDatasetWrapper(torch.utils.data.Dataset):
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ return self.dataset[idx]
+ except Exception as e:
+ gprint(e)
+ import traceback
+ traceback.print_exc()
+ idx = (idx + 1) % len(self.dataset)
+
+
+class CustomTransformDataset(Dataset):
+ def __init__(self, original_dataset, transform):
+ self.original_dataset = original_dataset
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.original_dataset)
+
+ def __getitem__(self, idx):
+ for i in range(10):
+ try:
+ data = self.original_dataset[idx]
+ if i > 0:
+ rprint(f"Took {i} times")
+ break
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ gprint(e)
+
+ transformed_data = self.transform(data, idx=idx)
+ return transformed_data
+
+class TensorCollate(nn.Module):
+ def __init__(self, device=None, transform=None, enable_cuda_in_tensordict_collate=True):
+ super().__init__()
+ self.device = torch.device(device) if device is not None else None
+ self.transform = transform
+ self.enable_cuda_in_tensordict_collate = enable_cuda_in_tensordict_collate
+
+ def __call__(self, x: TensorDict):
+ if self.device is not None and self.device.type == "cuda" and self.enable_cuda_in_tensordict_collate:
+ out = x.pin_memory() # move data to RAM
+ else:
+ out = x
+
+ if self.device and self.enable_cuda_in_tensordict_collate:
+ out = out.to(self.device)
+
+ if self.transform:
+ out = self.transform(out)
+
+ return out
+
+def clean_identity(value):
+ cleaned_value = "".join(filter(str.isdigit, str(value)))
+ return int(cleaned_value) if cleaned_value else None
+
+
+class VGGFace(Dataset):
+ def __init__(self, path, is_train, filter_resolution: int = 196, transform=None, cond_transform=None, v2=False):
+ self.path = Path(path)
+ self.is_train = is_train
+
+ self.train_folders = self.get_folders("train")
+ self.test_folders = self.get_folders("test")
+ self.prefix = "train" if self.is_train else "test"
+ self.gender_meta = pd.read_csv(self.path / 'meta' / 'identity_meta.csv', on_bad_lines='skip')
+ self.v2 = v2
+ self.transform = transform
+ self.cond_transform = cond_transform
+ self.filter_resolution = filter_resolution
+
+ cache_file = self.path / f"{self.prefix}_{'filtered' if filter_resolution == 196 else ('unfiltered' if filter_resolution is None else 'filtered_' + str(filter_resolution))}.pkl"
+ if cache_file.exists():
+ self.data = pd.read_pickle(cache_file)
+ else:
+ self.data = pd.read_csv(self.path / "MAAD_Face.csv")
+ self.data["Identity"] = self.data["Identity"].apply(clean_identity)
+ self.data = self.data[self.data["Identity"].isin(self.train_folders if self.is_train else self.test_folders)]
+ def get_image_size(file_path):
+ with Image.open(file_path) as img:
+ return img.size
+
+ self.data['Resolution'] = self.data.apply(lambda row: get_image_size(self.path / "data" / self.prefix / f"{row['Filename']}"), axis=1)
+ if filter_resolution:
+ self.data = self.data[self.data['Resolution'].apply(lambda x: x[0] >= filter_resolution and x[1] >= filter_resolution)]
+
+ self.data = self.data.drop('Resolution', axis=1)
+ self.data.to_pickle(cache_file)
+
+ def get_folders(self, split):
+ train_path = Path(self.path) / "data" / split
+ folders = [int(folder.name[1:]) for folder in train_path.iterdir() if folder.is_dir()]
+ return folders
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ row = self.data.iloc[idx]
+ img_path = self.path / "data" / self.prefix / f"{row['Filename']}"
+ attr = row.to_numpy()[2:].astype(int)
+ tokens = attr.copy() + 1
+ non_zero_mask = attr > 0
+ non_zero_idx = np.where(non_zero_mask)[0]
+
+ if self.v2:
+ attn_mask = np.ones(48)
+ matched_ = self.gender_meta[self.gender_meta["Class_ID"] == row.Filename.split("/")[0]]
+ assert len(matched_) <= 1, f"idx: {idx}, filename: {row}"
+ if len(matched_) == 1:
+ matched_row = matched_.iloc[0]
+ is_female = matched_row[" Gender"] == " f"
+ else:
+ is_female = False
+ attn_mask[0] = 0
+
+ tokens[non_zero_idx] = non_zero_idx + 3
+ tokens = np.concatenate([np.array([2 if is_female else 0]), tokens])
+ else:
+ attn_mask = np.zeros(len(tokens))
+ tokens[non_zero_idx] = non_zero_idx + 2
+
+ img = Image.open(img_path)
+ ret_dict = {"img": img, "input_ids": tokens, "attention_mask": attn_mask, "idx": idx}
+
+ if self.transform:
+ ret_dict["img"] = self.transform(img)
+
+ if self.cond_transform is not None:
+ ret_dict["cond_img"] = self.cond_transform(img)
+
+ return ret_dict
+
+class Cub2011(VisionDataset):
+ def __init__(
+ self,
+ root: Path,
+ train=True,
+ transform=None,
+ target_transform=None,
+ transforms=None,
+ shuffle_attributes=False,
+ n_duplicate=None,
+ n_samples=None,
+ **kwargs,
+ ):
+ super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform, transforms=transforms)
+ self.train = train
+ self.shuffle_attributes = shuffle_attributes
+ self.n_duplicate = n_duplicate
+ self.n_samples = n_samples
+ self.loader = default_loader
+ self._load_metadata()
+
+ def _load_metadata(self):
+ images = pd.read_csv(self.root / "images.txt", sep=" ", names=["img_id", "filepath"])
+ image_class_labels = pd.read_csv(self.root / "image_class_labels.txt", sep=" ", names=["img_id", "target"])
+ train_test_split = pd.read_csv(self.root / "train_test_split.txt", sep=" ", names=["img_id", "is_training_img"])
+
+ data = images.merge(image_class_labels, on="img_id")
+ self.data = data.merge(train_test_split, on="img_id")
+ class_names = pd.read_csv(self.root / "classes.txt", sep=" ", names=["class_name"], usecols=[1])
+ self.class_names = class_names["class_name"].to_list()
+
+ if self.train:
+ self.data = self.data[(self.data.is_training_img == 1) | (self.data.index < 10000)]
+ else:
+ self.data = self.data[(self.data.is_training_img == 0) & (self.data.index >= 10000)]
+
+ df_images = pd.read_csv(self.root / "images.txt", sep="\s+", names=["img_id", "img_path"])
+ df_labels = pd.read_csv(self.root / "classes.txt", sep="\s+", names=["cls_id", "cls_name"])
+ df_is_train = pd.read_csv(self.root / "train_test_split.txt", sep="\s+", names=["img_id", "is_train"])
+
+ df_att = pd.read_csv(self.root / "attributes.txt", sep="\s+", names=["att_id", "att_name"])
+ df_att_ant = pd.read_csv(
+ self.root / "attributes/image_attribute_labels_filtered.txt", names=["img_id", "att_id", "is_pres", "cert_id", "time"], sep="\s+"
+ )
+
+ image_ids = df_att_ant["img_id"].unique()
+ df_images = df_images[df_images["img_id"].isin(image_ids)]
+ df_is_train = df_is_train[df_is_train["img_id"].isin(image_ids)]
+
+ df_data_att = pd.merge(df_att_ant, df_att, on="att_id", how="left")
+ df_data_att = df_data_att.loc[(df_data_att["is_pres"] == 1) & (df_data_att["cert_id"] > 2)]
+
+ self.df_data_att = df_data_att
+
+ def __len__(self):
+ orig_size = len(self.data)
+ if self.n_samples is not None:
+ orig_size = self.n_samples
+ if self.n_duplicate is not None:
+ orig_size = orig_size * self.n_duplicate
+ return orig_size
+
+ def __getitem__(self, idx):
+ if isinstance(idx, torch.Tensor):
+ idx = idx.item()
+
+ if self.n_samples is not None:
+ idx = idx % self.n_samples
+
+ idx = idx % len(self.data)
+ sample = self.data.iloc[idx]
+ img_id = sample["img_id"]
+ path = self.root / "images" / sample.filepath
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+
+ data = {"img": img}
+ data["text"] = ", ".join(list(self.df_data_att.loc[(self.df_data_att["img_id"] == img_id)]["att_name"].values))
+ tokens = torch.full((312,), dtype=torch.int64, fill_value=0) # 40 is our pad token
+ _atts = self.df_data_att.loc[(self.df_data_att["img_id"] == img_id)]["att_id"].values
+ _atts = _atts.tolist()
+ if self.shuffle_attributes:
+ random.shuffle(_atts)
+ tokens[: len(_atts)] = torch.tensor(_atts)
+ data["input_ids"] = tokens
+ data["attention_mask"] = tokens > 0
+ return data
+
+
+class TokenDataset(Dataset):
+ def __init__(self, path, n_samples: typing.Optional[int] = None, n_duplicate: Optional[int] = None, should_aug: bool = False):
+ self.path = path
+ self.data = TensorDict.load_memmap(path)
+ self.n_samples = n_samples
+ self.n_duplicate = n_duplicate
+ self.device = None
+
+ def to_gpu(self, device):
+ self.device = device
+ self.data = self.data.to(self.device)
+
+ def __len__(self):
+ if self.n_duplicate is None and self.n_samples is None:
+ return len(self.data)
+ else:
+ n_duplicate = 1 if self.n_duplicate is None else self.n_duplicate
+ n_samples = 1 if self.n_samples is None else self.n_samples
+ return n_samples * n_duplicate
+
+ def __getitem__(self, idx):
+ n_samples = self.n_samples if self.n_samples is not None else len(self.data)
+ n_duplicate = self.n_duplicate if self.n_duplicate is not None else 1
+ idx = idx % (n_samples * n_duplicate)
+ element = self.data[idx]
+
+ index_keys = ["img_input_ids", "txt_input_ids"]
+ for key in index_keys:
+ if key in element:
+ element[key] = element[key].to(torch.int64)
+
+ index_keys = ["img_label"]
+ for key in index_keys:
+ if key in element:
+ element[key] = element[key].squeeze(-1)
+
+ return element.to_dict()
+
+
+def get_sora_dataset(mode, config, tokenizer, should_aug=True, **kwargs):
+ assert (LIB_DIR / "Open-Sora-Plan").exists()
+ __import__("sys").path.append(str(LIB_DIR / "Open-Sora-Plan"))
+ from opensora.dataset.transform import (CenterCropResizeVideo,
+ RandomHorizontalFlipVideo,
+ TemporalRandomCropGlobal,
+ ToTensorVideo)
+
+ from models.datasets.t2v_datasets import T2V_dataset
+
+ is_train = mode == "train"
+ n_duplicate_train = getattr(config.data, "n_duplicate_train", None)
+ n_duplicate_val = getattr(config.data, "n_duplicate_val", None)
+ n_duplicate = n_duplicate_train if is_train else n_duplicate_val
+
+ n_val_samples = getattr(config.data, "n_val_samples", None)
+ n_train_samples = getattr(config.data, "n_train_samples", None)
+ n_samples = n_train_samples if is_train else n_val_samples
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ is_celeb = getattr(config.data, "celeb", False)
+ temporal_sample = TemporalRandomCropGlobal() # DynamicSampleDuration
+
+ transform = transforms.Compose(
+ [
+ ToTensorVideo(),
+ CenterCropResizeVideo(config.data.resolution),
+ *([RandomHorizontalFlipVideo(p=0.5)] if is_train and should_aug else []), # This may mess up certain captions.
+ ]
+ )
+
+ dataset = T2V_dataset(
+ num_frames=config.data.num_frames,
+ transform=transform,
+ temporal_sample=temporal_sample,
+ tokenizer=tokenizer,
+ hf_format=True,
+ unified_model=config.model.unified_model,
+ specified_keywords_only=getattr(config.data, "specified_keywords_only", None),
+ ignore_clips=True,
+ celeb_only=is_celeb,
+ model_max_length=128,
+ is_train=is_train,
+ n_duplicate=n_duplicate,
+ n_samples=n_samples,
+ **kwargs,
+ )
+
+ return dataset
+
+
+def get_sample_ids_from_attention_mask(attention_mask):
+ if attention_mask.all():
+ return torch.zeros_like(attention_mask, dtype=torch.int)
+
+ # Convert boolean tensor to integer for easy manipulation (True -> 1, False -> 0)
+ inverted = (~attention_mask).to(torch.int)
+
+ # Find the last position where the False sequence starts
+ diff = inverted.diff(dim=0, prepend=torch.tensor([0], dtype=inverted.dtype))
+
+ # Find the starting position of the last contiguous False sequence
+ nonzero_indices = (diff == 1).nonzero(as_tuple=True)[0]
+ if nonzero_indices.numel() == 0: assert False
+ last_false_start = nonzero_indices.max(dim=0)[0] if nonzero_indices.numel() > 0 else torch.tensor(0)
+
+ # Mark all elements in the last contiguous False sequence as -1
+ output = torch.zeros_like(attention_mask, dtype=torch.int)
+ output[last_false_start:] = inverted[last_false_start:].cumsum(0).ne(0).to(torch.int) * -1
+
+ return output
+
+
+class MultipleTensorDictDataset(Dataset):
+ def __init__(self, datasets, weights, dataset_names, config, tokenizer=None, returns_raw_images=False, returns_tokenized_text=False, returns_parquet=False, returns_tokenize_vqvae_in_dataloader=False, allow_label=False):
+ self.datasets = [x.to("cpu") if isinstance(x, TensorDict) else x for x in datasets]
+ self.weights = weights
+ self.dataset_names = dataset_names
+ self.add_dataset_idx = True
+ self.tokenizer = tokenizer # this is for text only
+ self.text_vocab_size = getattr(config.model, "text_vocab_size")
+
+ self.config = config
+ self.returns_raw_images = returns_raw_images
+ self.returns_tokenized_text = returns_tokenized_text
+ self.returns_parquet = returns_parquet
+ self.returns_tokenize_vqvae_in_dataloader = returns_tokenize_vqvae_in_dataloader
+ self.seq_len = config.model.length
+ self.allow_label = allow_label
+ self.require_sample_ids = getattr(config.data, "require_sample_ids", False)
+ self.remove_txt_img_padding = getattr(config.data, "remove_txt_img_padding", False)
+ self.add_image_gen_tokens = getattr(config.data, "add_image_gen_tokens", False)
+ self.dynamic_packing_lengths = getattr(config.data, "dynamic_packing_lengths", False)
+
+ if self.dynamic_packing_lengths:
+ # We can't directly stack here, we first need to pack/pad in the packing collate
+ rprint(f"Removing __getitems__ from {self.__class__.__name__} as we are using dynamic packing lengths")
+ if hasattr(self, '__getitems__'):
+ delattr(self.__class__, '__getitems__')
+
+ if self.allow_label and not self.returns_raw_images:
+ self.raw_images_keys_supported = ["input_ids", "attention_mask", "modality", "label", "sample_ids"]
+ else:
+ self.raw_images_keys_supported = ["img", "input_ids", "attention_mask", "modality", "idx", "label", "sample_ids"]
+
+ assert not getattr(config.trainer, "force_shift_image_batches", False)
+
+ def __len__(self):
+ return sum(10 if isinstance(dataset, torch.utils.data.IterableDataset) else len(dataset) for dataset in self.datasets)
+
+ def __getitem__(self, index_data):
+ dataset_idx, idx = index_data
+ dataset = self.datasets[dataset_idx]
+ if isinstance(dataset, TensorDict):
+ data = dataset[idx]
+ txt_len = None
+
+ if "attention_mask" in data and (data["attention_mask"] == False).all():
+ is_pad = data["input_ids"] == self.tokenizer.pad_token_id
+ change_points = torch.where(is_pad[:-1] != is_pad[1:])[0] + 1
+ if change_points.numel() > 0 and is_pad[-1]:
+ start_pos = change_points[-1].item()
+ data["attention_mask"][:start_pos] = True
+
+ if "input_ids" not in data:
+ if self.remove_txt_img_padding:
+ image_gen_tokens = get_image_gen_tokens(self.tokenizer)
+ new_txt_input_ids = data["txt_input_ids"].to(torch.int64)[data["txt_attention_mask"].to(torch.bool)]
+ new_txt_attention_mask = data["txt_attention_mask"].to(torch.bool)[data["txt_attention_mask"].to(torch.bool)]
+ new_txt_input_ids = torch.cat([image_gen_tokens["input_ids"][0], new_txt_input_ids], dim=-1)
+
+ if new_txt_input_ids[-1] == self.tokenizer.eos_token_id:
+ new_txt_input_ids = new_txt_input_ids[:-1]
+ new_txt_attention_mask = new_txt_attention_mask[:-1]
+
+ new_txt_input_ids = torch.cat([new_txt_input_ids, torch.tensor(get_image_suffix(self.tokenizer), dtype=torch.int64)], dim=-1)
+ new_txt_attention_mask = torch.cat([new_txt_attention_mask, torch.ones_like(new_txt_attention_mask[:1])], dim=-1)
+ new_txt_input_modality = torch.zeros((new_txt_input_ids.shape[0],), dtype=torch.int64)
+ img_modality = torch.ones((data["img_input_ids"].shape[0],), dtype=torch.int64)
+
+ new_input_ids = torch.cat([new_txt_input_ids, data["img_input_ids"].to(torch.int64), torch.tensor([self.tokenizer.eos_token_id], dtype=torch.int64)], dim=-1)
+ new_attention_mask = torch.ones_like(new_input_ids, dtype=torch.bool)
+ new_modality = torch.cat([new_txt_input_modality, img_modality, torch.zeros_like(new_txt_input_modality[:1])], dim=-1)
+
+ txt_len = None
+ data = TensorDict.from_dict(
+ {
+ "input_ids": new_input_ids,
+ "attention_mask": new_attention_mask,
+ "modality": new_modality
+ },
+ batch_size=[],
+ )
+ else:
+ txt_len = data["txt_input_ids"].shape[0]
+ data = TensorDict.from_dict(
+ {
+ "input_ids": torch.cat(
+ [data["txt_input_ids"].to(torch.int64), data["img_input_ids"].to(torch.int64)], dim=-1
+ ),
+ "attention_mask": torch.cat(
+ [data["txt_attention_mask"].to(torch.bool), torch.ones_like(data["img_input_ids"]).to(torch.bool)], dim=-1
+ ),
+ },
+ batch_size=[],
+ )
+
+ if self.require_sample_ids and "sample_ids" not in data:
+ data["sample_ids"] = get_sample_ids_from_attention_mask(data["attention_mask"])
+
+ else:
+ if "modality" in data and data["modality"].shape[-1] != data["input_ids"].shape[-1]:
+ data["modality"] = unpackbits(data["modality"]).to(torch.int64)
+
+ if "attention_mask" in data and data["attention_mask"].shape[-1] != data["input_ids"].shape[-1]:
+ data["attention_mask"] = unpackbits(data["attention_mask"]).to(torch.bool)
+
+ if "modality" not in data:
+ data["modality"] = torch.zeros((data["input_ids"].shape[0],), dtype=torch.int64)
+
+ elif data["modality"].shape[0] == 1:
+ data["modality"] = data["modality"].expand(data["input_ids"].shape[0])
+
+ if txt_len is not None:
+ data["modality"][txt_len:] = 1
+
+ if "idx" in data:
+ data.pop("idx")
+ else:
+ if isinstance(dataset, torch.utils.data.IterableDataset):
+ data = next(iter(dataset))
+ else:
+ data = dataset[idx]
+
+ if self.returns_raw_images:
+ if not isinstance(data, TensorDict):
+ data = TensorDict.from_dict(data, batch_size=[])
+
+ if "idx" in data and len(data["idx"].shape) == 0:
+ data["idx"] = data["idx"].unsqueeze(-1)
+
+ if "input_ids" not in data:
+ data["input_ids"] = torch.full((self.seq_len,), dtype=torch.int64, fill_value=-1)
+ data["attention_mask"] = torch.full((self.seq_len,), dtype=torch.bool, fill_value=True)
+ data["modality"] = torch.full((self.seq_len,), dtype=torch.int64, fill_value=1)
+
+ elif "modality" not in data:
+ data["modality"] = torch.full((self.seq_len,), dtype=torch.int64, fill_value=1) # assuming images
+ data["modality"][:data["input_ids"].shape[0]] = 0
+ data["input_ids"] = torch.cat([data["input_ids"], torch.full((self.seq_len - data["input_ids"].shape[0],), dtype=torch.int64, fill_value=-1)])
+ data["attention_mask"] = torch.cat([data["attention_mask"], torch.full((self.seq_len - data["attention_mask"].shape[0],), dtype=torch.bool, fill_value=True)]).bool()
+
+ elif self.returns_tokenized_text:
+ from dataloader import tokenize_text
+ _txt = data["content"] if "content" in data else data["text"]
+ data = tokenize_text(self.tokenizer, self.text_length, _txt)
+ data = TensorDict.from_dict({
+ "input_ids": data["input_ids"].to(torch.int64),
+ "attention_mask": data["attention_mask"].to(torch.bool)},
+ batch_size=[])
+ if "modality" not in data:
+ data["modality"] = torch.full((data["input_ids"].shape[0], ), dtype=torch.int64, fill_value=0)
+ elif self.returns_parquet:
+ if "attention_mask" not in data:
+ data["attention_mask"] = torch.ones((len(data["input_ids"])), dtype=torch.bool)
+ data = TensorDict.from_dict({
+ "input_ids": data["input_ids"],
+ "attention_mask": data["attention_mask"].bool() if isinstance(data["attention_mask"], torch.Tensor) else torch.tensor(data["attention_mask"], dtype=torch.bool)
+ }, batch_size=[])
+
+ if "modality" not in data:
+ data["modality"] = torch.full((data["input_ids"].shape[0],), dtype=torch.int64, fill_value=0)
+
+ if self.require_sample_ids and "sample_id" not in data:
+ sequence_starts = (data["input_ids"] == self.tokenizer.bos_token_id).long()
+ assert sequence_starts[0] == 1
+ sample_ids = torch.cumsum(sequence_starts, dim=0) - 1
+ unique_ids, counts = torch.unique(sample_ids, return_counts=True)
+ occurrence_mask = torch.isin(sample_ids, unique_ids[counts < 10]) # Require at least 10 tokens to be presen
+ data["sample_ids"] = torch.where(occurrence_mask, -1, sample_ids)
+
+ elif self.returns_tokenize_vqvae_in_dataloader:
+ if "txt_input_ids" in data and "txt_attention_mask" in data:
+ modality = torch.zeros(data["txt_input_ids"].shape[0] + data["img_input_ids"].shape[0], dtype=torch.int64)
+ modality[data["txt_input_ids"].shape[0]:] = 1
+ data = TensorDict.from_dict({
+ "input_ids": torch.cat([data["txt_input_ids"], data["img_input_ids"]], dim=-1),
+ "attention_mask": torch.cat([data["txt_attention_mask"], torch.ones_like(data["img_input_ids"], dtype=torch.bool)], dim=-1).bool(),
+ "modality": modality
+ }, batch_size=[])
+ else:
+ data = TensorDict.from_dict({
+ "input_ids": data["img_input_ids"],
+ "attention_mask": torch.ones_like(data["img_input_ids"], dtype=torch.bool),
+ "modality": torch.full((data["img_input_ids"].shape[0],), dtype=torch.int64, fill_value=1)
+ }, batch_size=[])
+ else:
+ raise ValueError(f"Unsupported return type")
+
+ data["input_ids"] = data["input_ids"].to(torch.int64)
+ data["input_ids"] = torch.where(
+ (data["modality"] == 1) & (data["input_ids"] != -1),
+ data["input_ids"] + self.config.data.img_token_shift,
+ data["input_ids"]
+ )
+
+ if not self.allow_label and "label" in data:
+ data.pop("label")
+
+ if self.returns_raw_images or self.allow_label:
+ # fill in the missing keys in tensor dict for both text and image batches
+ for key in self.raw_images_keys_supported:
+ if key not in data:
+ if key == "img":
+ data[key] = torch.zeros((3, self.config.data.resolution, self.config.data.resolution), dtype=torch.float32)
+ elif key == "label":
+ data[key] = torch.full((1,), dtype=torch.int64, fill_value=0)
+ else:
+ data[key] = torch.full((self.config.model.length,), dtype=torch.int64, fill_value=self.tokenizer.pad_token_id)
+
+ if "attention_mask" in data and (data["attention_mask"] == 0).all():
+ breakpoint()
+
+ return data.clone()
+
+ def __getitems__(self, index_data_list):
+ return torch.stack([self.__getitem__(index_data) for index_data in index_data_list]).clone()
diff --git a/models/datasets/precompute_text_tokens.py b/models/datasets/precompute_text_tokens.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3617da9f3ad409439950033a8f3182d3710e1c8
--- /dev/null
+++ b/models/datasets/precompute_text_tokens.py
@@ -0,0 +1,211 @@
+import os
+import shutil
+import signal
+import sys
+import time
+from contextlib import ExitStack
+from functools import partial
+from pathlib import Path
+
+from accelerate.utils import gather_object
+from torchinfo import summary
+
+from unidisc.tokenizers.chameleon_tokenizers import tokenize_chameleon
+from utils import _print_config, set_numa_affinity, set_omega_conf_resolvers
+
+sys.path.append(str(Path(__file__).parent.parent.parent / "unidisc/misc/hydra_submitit_launcher"))
+import itertools
+import json
+import os
+import random
+import sys
+from contextlib import nullcontext
+from pathlib import Path
+
+import fsspec
+import hydra
+import numpy as np
+import omegaconf
+import rich.syntax
+import rich.tree
+import torch
+from accelerate import Accelerator
+from PIL import Image
+from tensordict import TensorDict
+from tqdm import tqdm
+from viztracer import VizTracer
+
+from dataloader import get_dataloaders, get_tokenizer, tokenize_text
+from decoupled_utils import (barrier, breakpoint_on_error, get_world_size,
+ is_local_main_process, is_main_process,
+ rank_zero_fn, rprint, set_global_breakpoint,
+ set_global_exists, gprint)
+from model import decode_latents, get_image_batch, get_vae
+from models.datasets.combine_token_dicts import main as combine_token_dicts
+from models.datasets.vggface_v2_attributes import (get_inference_func,
+ get_output)
+from utils import (_print_config, set_numa_affinity, set_omega_conf_resolvers,
+ set_torch_defaults)
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+
+set_global_breakpoint() # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
+set_global_exists()
+set_omega_conf_resolvers()
+set_torch_defaults()
+
+def get_dict(config, dataset_size):
+ data = TensorDict(
+ {
+ "input_ids": torch.zeros(dataset_size, config.model.img_length, dtype=torch.int16),
+ "idx": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ "modality": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int16),
+ },
+ batch_size=[dataset_size],
+ )
+ return data
+
+def _group_texts(examples, block_size, bos, eos):
+ # Concatenate all texts.
+ concatenated_examples = list(itertools.chain(* examples['input_ids']))
+ total_length = len(concatenated_examples)
+ # TODO(yair): look into not dropping the remainder but rather padding it.
+ # We drop the small remainder, and if the total_length < block_size - 2
+ # we exclude this batch and return an empty dict.
+ # We could add padding if the model supported it instead of
+ # this drop, you can customize this part to your needs.
+ new_block_size = block_size # [BOS] and [EOS] to be added
+ total_length = (total_length // new_block_size) * new_block_size
+ # Split by chunks of max_len.
+ result = {}
+ _values = []
+ _attn_masks = []
+ for i in range(0, total_length, new_block_size):
+ _data = concatenated_examples[i : i + new_block_size]
+ _data[0] = bos
+ _data[-1] = eos
+ _values.append(_data)
+
+ result['input_ids'] = _values
+
+ # We don't have pad tokens when wrapped so we can ignore these
+ # result['attention_mask'] = _attn_masks
+ # result['modality'] = [[0] for _ in range(len(result['input_ids']))]
+
+ return result
+
+def preprocess_and_tokenize(example, tokenizer, dataset_name, wrap, block_size, EOS, BOS):
+ if dataset_name == 'ptb':
+ text = example['sentence']
+ elif 'scientific_papers' in dataset_name:
+ text = example['article']
+ else:
+ text = example['text']
+
+ tokenizer.padding_side = 'right'
+ tokenizer.truncation_side = 'right'
+
+ if wrap:
+ tokens = tokenizer(text,
+ add_special_tokens=True,
+ return_attention_mask=False,
+ return_token_type_ids=False)
+ tokens = {'input_ids': tokens['input_ids']}
+ # Still missing BOS, but will be added in group_texts
+ else:
+ tokens = tokenizer(text,
+ max_length=block_size,
+ padding='max_length',
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_token_type_ids=True)
+ return tokens
+
+def add_modality(output_dataset):
+ modality_column = torch.zeros(len(output_dataset), 1, dtype=torch.long)
+ output_dataset = output_dataset.add_column("modality", modality_column)
+
+import datasets
+@hydra.main(version_base=None, config_path="../../configs", config_name="config")
+def main(config):
+ """Main entry point for training."""
+ _print_config(config, resolve=True, save_cfg=True)
+ tokenizer = get_tokenizer(config)
+ block_size = config.data.block_size
+
+ wrap = True
+ streaming = config.data.streaming
+ num_proc = config.data.num_proc
+ split = getattr(config.data, "split", "train")
+ use_cache = False
+
+ assert getattr(config.data, "use_slow_tokenizer", False) is False
+
+ output_dir = config.data.token_output_dir
+ output_dir = Path(f"{output_dir}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ tensordict_output_dir = output_dir.parent / f"{output_dir.stem}_tensordict"
+ tensordict_output_dir.mkdir(parents=True, exist_ok=True)
+
+ dataset_name = config.data.train
+ if isinstance(dataset_name, list):
+ data = datasets.concatenate_datasets([
+ datasets.load_dataset(name, split=split, cache_dir=config.data.cache_dir, streaming=streaming)
+ for name in dataset_name
+ ])
+ else:
+ _args = []
+ if getattr(config.data, "add_load_dataset_args", None) is not None:
+ _args.append(getattr(config.data, "add_load_dataset_args", None))
+ data = datasets.load_dataset(dataset_name, *_args, split=split, cache_dir=config.data.cache_dir, streaming=streaming)
+
+
+ EOS = tokenizer.eos_token_id
+ BOS = tokenizer.bos_token_id
+
+ if config.data.n_train_samples is not None:
+ print(f"Selecting {config.data.n_train_samples} samples")
+ data = data.select(range(config.data.n_train_samples))
+
+ _preprocess_and_tokenize = partial(preprocess_and_tokenize, tokenizer=tokenizer, dataset_name=dataset_name, wrap=wrap, block_size=block_size, EOS=EOS, BOS=BOS)
+ if streaming:
+ tokenized_dataset = data.map(
+ _preprocess_and_tokenize,
+ batched=True
+ )
+ else:
+ rprint(f"Tokenizing with num_proc: {num_proc}")
+ tokenized_dataset = data.map(
+ _preprocess_and_tokenize,
+ batched=True,
+ num_proc=num_proc,
+ load_from_cache_file=use_cache,
+ desc='Tokenizing')
+
+ tokenized_dataset = tokenized_dataset.remove_columns('text')
+ columns_to_keep = ['input_ids']
+ if tokenized_dataset.column_names is not None:
+ columns_to_remove = [col for col in tokenized_dataset.column_names if col not in columns_to_keep]
+ tokenized_dataset = tokenized_dataset.remove_columns(columns_to_remove)
+
+ output_dataset = None
+ if wrap:
+ group_texts = partial(_group_texts, block_size=block_size, bos=BOS, eos=EOS)
+ if streaming:
+ chunked_dataset = tokenized_dataset.map(group_texts, batched=True)
+ else:
+ chunked_dataset = tokenized_dataset.map(group_texts, batched=True, num_proc=num_proc, load_from_cache_file=use_cache, desc='Grouping')
+ chunked_dataset.save_to_disk(output_dir)
+
+ output_dataset = chunked_dataset.with_format('torch')
+ else:
+ if streaming is False:
+ tokenized_dataset.save_to_disk(output_dir)
+ output_dataset = tokenized_dataset.with_format('torch')
+
+if __name__ == "__main__":
+ with breakpoint_on_error():
+ main()
diff --git a/models/datasets/precompute_tokens.py b/models/datasets/precompute_tokens.py
new file mode 100644
index 0000000000000000000000000000000000000000..25513367282a9352a469ef40ae824c447faa73e6
--- /dev/null
+++ b/models/datasets/precompute_tokens.py
@@ -0,0 +1,577 @@
+import os
+import shutil
+import signal
+import sys
+import time
+from contextlib import ExitStack
+from functools import partial
+from pathlib import Path
+
+from accelerate.utils import gather_object, gather
+from torchinfo import summary
+
+from unidisc.tokenizers.chameleon_tokenizers import tokenize_chameleon, tokenize_chameleon_fast, get_chameleon_images, decode_ids, decode_ids_batched, tokenize_chameleon_mmc4, tokenize_regular_cambrian_mmc4
+from utils import _print_config, set_numa_affinity, set_omega_conf_resolvers
+
+sys.path.append(str(Path(__file__).parent.parent.parent / "unidisc/misc/hydra_submitit_launcher"))
+
+import json
+import os
+import random
+import sys
+from contextlib import nullcontext
+from pathlib import Path
+
+import fsspec
+import hydra
+import numpy as np
+import omegaconf
+import rich.syntax
+import rich.tree
+import torch
+from accelerate import Accelerator
+from PIL import Image
+from tensordict import TensorDict
+from tqdm import tqdm
+try:
+ from viztracer import VizTracer
+except ImportError:
+ print("VizTracer not installed, skipping")
+
+from dataloader import get_dataloaders, get_tokenizer, tokenize_text
+from decoupled_utils import (barrier, breakpoint_on_error, get_local_rank, get_rank, get_world_size,
+ is_local_main_process, is_main_process,
+ rank_zero_fn, rprint, set_global_breakpoint,
+ set_global_exists, gprint)
+from model import decode_latents, get_image_batch, get_vae
+from models.datasets.combine_token_dicts import main as combine_token_dicts
+from models.datasets.vggface_v2_attributes import (get_inference_func,
+ get_output)
+from utils import (_print_config, set_numa_affinity, set_omega_conf_resolvers,
+ set_torch_defaults)
+from omegaconf import DictConfig, OmegaConf, open_dict, read_write
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+
+set_global_breakpoint() # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
+set_global_exists()
+set_omega_conf_resolvers()
+set_torch_defaults()
+
+def get_batch_size(config):
+ with open_dict(config):
+ if any(x.lower() in torch.cuda.get_device_name().lower() for x in ["v100", "1080", "2080", "quadro", "titan"]) or torch.cuda.get_device_capability()[0] <= 7:
+ config.trainer.precision = "no"
+ config.model.force_optimized_native_attn = False
+ config.trainer.compile = False
+ config.loader.batch_size = config.loader.batch_size // 3
+ print(f"Found {torch.cuda.get_device_name().lower()}, set batch size to {config.loader.batch_size}")
+ return config
+
+def enc(data, idx, encode_images, config, vae, batch, accelerator, mixed_precision, tokenizer, vgg_data, existing_ids=None, device=None, mapping=None):
+
+ if isinstance(batch, list):
+ bs = len(batch)
+ elif "img" in batch:
+ bs = batch["img"].shape[0]
+ else:
+ bs = batch["attention_mask"].shape[0]
+
+ sl = slice(idx * bs, (idx + 1) * bs)
+ if not isinstance(batch, list) and "idx" in batch:
+ if set(data[sl]["idx"].flatten().tolist()) == set(batch["idx"].tolist()):
+ rprint(f"Skipping {idx} as all samples have already been processed 1")
+ return
+ if existing_ids is not None:
+ set_inter = set(batch["idx"].tolist()) & existing_ids
+ if len(set_inter) == bs:
+ rprint(f"Skipping {idx} as all samples have already been processed 2")
+ return
+ elif len(set_inter) > 0:
+ rprint(f"Running {idx} as some samples have already been processed: {len(set_inter)}")
+ else:
+ if (data[sl]["idx"] != -1).all():
+ rprint(f"Skipping {idx} as all samples have already been processed")
+ return
+
+ if not isinstance(batch, list) and "img" in batch:
+ batch["img"] = batch["img"].to(device=device, dtype=torch.bfloat16 if mixed_precision else None)
+
+ with torch.no_grad():
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=mixed_precision):
+ use_chameleon = getattr(config.data, "use_chameleon", False)
+ use_mmc4 = config.data.train == "mmc4"
+ use_cambrian = config.data.train == "cambrian"
+ if not use_chameleon and not use_mmc4 and not use_cambrian:
+ if tokenizer is not None and getattr(config.model, "unified_model", False):
+ if "input_ids" in batch and "attention_mask" in batch:
+ tokens = batch
+ else:
+ tokens = tokenize_text(tokenizer, config.data.block_size, batch[".txt"])
+
+ batch["txt_input_ids"] = tokens["input_ids"]
+ batch["txt_attention_mask"] = tokens["attention_mask"].float()
+ elif getattr(config.data, "add_vggface_v2_attributes", False) and "vggface" not in config.data.train:
+ txt_input_ids, txt_attention_mask = get_output(batch, **vgg_data)
+ batch["txt_input_ids"] = txt_input_ids
+ batch["txt_attention_mask"] = txt_attention_mask
+ elif getattr(config.data, "txt_only", False):
+ batch["txt_input_ids"] = batch["input_ids"]
+ batch["txt_attention_mask"] = batch["attention_mask"]
+
+ if getattr(config.model, "unified_model", False) is False:
+ if getattr(config.data, "txt_only", False):
+ batch["modality"] = torch.full((bs, 1), fill_value=0, dtype=torch.int16)
+ else:
+ batch["modality"] = torch.full((bs, 1), fill_value=1, dtype=torch.int16)
+
+ if isinstance(batch, list) and batch[0].get("idx", None) is not None:
+ _idx = torch.tensor([x["idx"] for x in batch], dtype=torch.int32).unsqueeze(-1)
+ elif "idx" in batch:
+ _idx = batch["idx"].to(torch.int32).unsqueeze(-1)
+ else:
+ _idx = torch.full((bs, 1), fill_value=0, dtype=torch.int32)
+
+ if "is_valid" in batch:
+ _idx[~batch["is_valid"]] = -1
+ if (_idx == -1).all():
+ gprint(f"WARNING: All samples are invalid")
+
+ sl = slice(idx * bs, (idx + 1) * bs)
+ assert (idx + 1) * bs <= len(data), f"Index {idx} + batch size {bs} is greater than the data length {len(data)}"
+
+ if encode_images:
+ if use_chameleon:
+ if isinstance(batch, list):
+ all_input_ids, all_attention_masks = tokenize_chameleon_mmc4(config, tokenizer, vae, batch, device, mapping)
+ else:
+ all_input_ids, all_attention_masks = tokenize_chameleon_fast(config, tokenizer, vae, batch)
+
+ # all_input_ids_gt, all_attention_masks_gt = tokenize_chameleon(config, tokenizer, vae, batch)
+ # txt_tokens, img_tokens = decode_ids_batched(_vae, all_input_ids[:4], return_tokens=True)
+ # img = decode_latents(config, _vae, img_tokens)
+ # from image_utils import Im; Im(img).save()
+
+ elif use_mmc4 or use_cambrian:
+ all_input_ids, all_attention_masks, all_modality = tokenize_regular_cambrian_mmc4(config, tokenizer, vae, batch, device, mapping)
+ if all_input_ids is None:
+ return
+ else:
+ image_ids = get_image_batch(config, vae, batch, device)
+
+ if use_chameleon or use_mmc4 or use_cambrian:
+ if not use_chameleon:
+ assert (all_input_ids < torch.iinfo(torch.int16).max).all()
+
+ _kwargs = {}
+ if use_mmc4 or use_cambrian:
+ _kwargs["modality"] = all_modality.to(torch.int8)
+
+ data[sl] = TensorDict(
+ {
+ "input_ids": all_input_ids.to(torch.int32 if use_chameleon else torch.int16),
+ "attention_mask": all_attention_masks.to(torch.bool),
+ "idx": _idx,
+ "write_flag": torch.ones((bs, 1), dtype=torch.bool),
+ **_kwargs,
+ },
+ batch_size=[bs],
+ )
+ elif getattr(config.model, "cond_label", False):
+ data[sl] = TensorDict(
+ {
+ "img_input_ids": image_ids.to(torch.int16),
+ "img_label": batch["label"].to(torch.int32).unsqueeze(-1),
+ "idx": _idx,
+ "write_flag": torch.ones((bs, 1), dtype=torch.bool),
+ },
+ batch_size=[bs],
+ )
+ elif getattr(config.model, "unified_model", False) or getattr(config.data, "add_vggface_v2_attributes", False):
+ data[sl] = TensorDict(
+ {
+ "img_input_ids": image_ids.to(torch.int16),
+ "txt_input_ids": (batch.get("txt_input_ids") if batch.get("txt_input_ids") is not None else batch["input_ids"]).to(
+ torch.int32
+ ),
+ "txt_attention_mask": (
+ batch.get("txt_attention_mask") if batch.get("txt_attention_mask") is not None else batch["attention_mask"]
+ ).to(torch.bool),
+ "idx": _idx,
+ "write_flag": torch.ones((bs, 1), dtype=torch.bool),
+ },
+ batch_size=[bs],
+ )
+ else:
+ data[sl] = TensorDict(
+ {"input_ids": image_ids.to(torch.int32), "attention_mask": torch.ones((image_ids.shape[0], image_ids.shape[1]), dtype=torch.bool), "idx": _idx, "write_flag": torch.ones((bs, 1), dtype=torch.bool), "modality": batch["modality"].to(torch.int16)},
+ batch_size=[bs],
+ )
+
+ elif getattr(config.data, "txt_only", False):
+ data[sl] = TensorDict(
+ {"input_ids": batch['input_ids'].to(torch.int32), "attention_mask": batch['attention_mask'].to(torch.bool), "idx": _idx, "write_flag": torch.ones((bs, 1), dtype=torch.bool), "modality": batch["modality"].to(torch.int16)},
+ batch_size=[bs],
+ )
+ else:
+ real_image = batch["img"]
+ if (config.data.resolution == 512 and batch["img"].shape[0] > 16) or (config.model.downscale_ratio <= 8):
+ chunk_size = 8 if (config.model.image_vocab_size > 64000 or config.model.downscale_ratio <= 8) else 16
+ chunks = [batch["img"][i : i + chunk_size] for i in range(0, batch["img"].shape[0], chunk_size)]
+ rec_img_list = []
+ for chunk in chunks:
+ batch_chunk = {"img": chunk}
+ image_ids = get_image_batch(config, vae, batch_chunk, device)
+ rec_img = decode_latents(config, vae, image_ids)
+ rec_img_list.append(rec_img)
+ rec_img = torch.cat(rec_img_list, dim=0)
+ else:
+ image_ids = get_image_batch(config, vae, batch, device)
+ rec_img = decode_latents(config, vae, image_ids)
+
+ viz_img = torch.cat([real_image, rec_img], dim=-1)
+ from image_utils import Im
+
+ if getattr(config.model, 'custom_vae_name', None) is not None:
+ custom_str = getattr(config.model, 'custom_vae_name')
+ else:
+ custom_str = f"{'_custom' if getattr(config.model, 'use_custom_vae_ckpt', False) else ''}"
+ (Path(__file__).parent.parent.parent / "output").mkdir(parents=True, exist_ok=True)
+ Im(viz_img).save(
+ Path(__file__).parent.parent.parent / f"output/{config.data.train.replace('/', '')}_seq{image_ids.shape[1]}_res{config.data.resolution}_{config.model.vae_type}{custom_str}_voc{config.model.image_vocab_size}.png"
+ )
+
+ # Create directories for saving images
+ dataset_name = config.data.train.replace('/', '')
+ vae_name = f"seq{image_ids.shape[1]}_res{config.data.resolution}_{config.model.vae_type}{custom_str}_voc{config.model.image_vocab_size}"
+ output_dir = Path(__file__).parent.parent.parent / "output" / dataset_name / vae_name
+ gt_output_dir = Path(__file__).parent.parent.parent / "output" / dataset_name / f"GT_{config.data.resolution}"
+ output_dir.mkdir(parents=True, exist_ok=True)
+ gt_output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save each image separately
+ for i, (real, rec) in enumerate(zip(real_image, rec_img)):
+ print(Im(rec).save(output_dir / f"{i}.png"))
+ if (gt_output_dir / f"{i}.png").exists() is False:
+ print(Im(real).save(gt_output_dir / f"{i}.png"))
+
+ gprint(f"Exiting")
+ exit()
+
+
+def get_dict(config, dataset_size):
+ if getattr(config.data, "use_chameleon", False) or config.data.train == "cambrian" or config.data.train == "mmc4":
+ input_ids_dtype = torch.int32 if getattr(config.data, "use_chameleon", False) else torch.int16
+ data = TensorDict(
+ {
+ "input_ids": torch.zeros(dataset_size, config.model.length, dtype=input_ids_dtype),
+ "attention_mask": torch.zeros(dataset_size, config.model.length, dtype=torch.bool),
+ "modality": torch.full((dataset_size, config.model.length), fill_value=-1, dtype=torch.int8),
+ "idx": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ },
+ batch_size=[dataset_size],
+ )
+ elif getattr(config.model, "cond_label", False):
+ data = TensorDict(
+ {
+ "img_input_ids": torch.zeros(dataset_size, config.model.img_length, dtype=torch.int16),
+ "img_label": torch.zeros(dataset_size, 1, dtype=torch.int32),
+ "idx": torch.full((dataset_size,), fill_value=-1, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ },
+ batch_size=[dataset_size],
+ )
+ elif getattr(config.model, "unified_model", False) or getattr(config.data, "add_vggface_v2_attributes", False):
+ data = TensorDict(
+ {
+ "img_input_ids": torch.zeros(dataset_size, config.model.img_length, dtype=torch.int16),
+ "txt_input_ids": torch.zeros(dataset_size, config.model.txt_length, dtype=torch.int32),
+ "txt_attention_mask": torch.zeros(dataset_size, config.model.txt_length, dtype=torch.bool),
+ "idx": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ },
+ batch_size=[dataset_size],
+ )
+ else:
+ data = TensorDict(
+ {
+ "input_ids": torch.zeros(dataset_size, config.model.txt_length if config.data.txt_only else config.model.img_length, dtype=torch.int16),
+ "idx": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int32),
+ "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
+ "modality": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int16),
+ },
+ batch_size=[dataset_size],
+ )
+ return data
+
+def signal_handler(signum, frame, train_data, tmp_path):
+ """Handle signals to save temporary train data."""
+ rprint(f"Received signal {signum}, saving temporary train data.")
+ print(f"[PRINT] Received signal {signum}, saving temporary train data.")
+ save_tmp_data(train_data, tmp_path)
+ sys.exit
+
+def save_tmp_data(data, tmp_path):
+ """Save data to a temporary path."""
+ if tmp_path.exists() and tmp_path.is_dir():
+ rprint(f"Deleting {tmp_path}")
+ shutil.rmtree(tmp_path) # Delete old tmp directory if it exists
+ rprint(f"Saving tmp data to {tmp_path}")
+ data.memmap(tmp_path, copy_existing=True)
+
+def periodic_save(data, tmp_path, start_time, interval=2 * 60 * 60):
+ """Periodically save data to a temporary path."""
+ current_time = time.time()
+ if current_time - start_time >= interval:
+ rprint(f"Hit periodic save interval, saving tmp data to {tmp_path}")
+ save_tmp_data(data, tmp_path)
+ return current_time # Reset start time
+ return start_time
+
+@hydra.main(version_base=None, config_path="../../configs", config_name="config")
+def main(config):
+ """Main entry point for training."""
+
+ try:
+ import resource
+ soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard_limit, hard_limit)) # Set the soft limit to the hard limit
+ rprint(f"Successfully set RLIMIT_NOFILE to {hard_limit}")
+ except Exception as e:
+ rprint(f"Failed to set RLIMIT_NOFILE: {e}")
+
+ mixed_precision = False
+ train_start_time = time.time()
+
+ from datetime import timedelta
+ from accelerate import Accelerator, DataLoaderConfiguration
+ from accelerate.utils import InitProcessGroupKwargs
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=3600))
+ prepare_kwargs = {}
+ if config.data.train == "mmc4":
+ prepare_kwargs["dispatch_batches"] = False
+
+ accelerator = Accelerator(mixed_precision="bf16" if mixed_precision else None, kwargs_handlers=[kwargs], dataloader_config=DataLoaderConfiguration(**prepare_kwargs))
+ device = torch.device(f"cuda:{accelerator.local_process_index}")
+
+ import socket
+ hostname = socket.gethostname()
+ print(f"Hostname: {hostname}, Process index: {accelerator.process_index}, {device}, local_process_index: {accelerator.local_process_index}, get_local_process_index: {get_local_rank()}, device: {device}")
+ _print_config(config, resolve=True, save_cfg=True)
+
+ config = get_batch_size(config)
+
+ # with omegaconf.open_dict(config):
+ # batch_sizes = gather_object([config.loader.batch_size])
+ # rprint(f"Batch sizes: {batch_sizes}")
+ # smallest_batch_size = min(batch_sizes)
+ # config.loader.batch_size = smallest_batch_size
+ # rprint(f"New config batch size: {config.loader.batch_size}")
+
+ prefix = f"[Rank {accelerator.process_index}/{accelerator.num_processes}, Node: {os.environ.get('SLURM_NODEID', 'N/A')}, Hostname: {os.environ.get('SLURM_JOB_NODELIST', 'N/A')}, {config.data.train}]"
+ print(f"{prefix} Starting precomputing tokens")
+ save_validation_dataloader = getattr(config.data, "save_validation_dataloader", False)
+ save_train_dataloader = getattr(config.data, "save_train_dataloader", False)
+
+ tokenizer = get_tokenizer(config)
+ train_dataloader, val_dataloader = get_dataloaders(
+ config, tokenizer=tokenizer, allow_aug=False, force_aug=getattr(config.data, "force_aug", False), skip_valid=not save_validation_dataloader
+ )
+
+ train_dataloader = accelerator.prepare(train_dataloader)
+ if save_validation_dataloader:
+ val_dataloader = accelerator.prepare(val_dataloader)
+ encode_images = getattr(config.model, "encode_images", False)
+
+ use_chameleon = getattr(config.data, "use_chameleon", False)
+ use_mmc4 = config.data.train == "mmc4"
+ use_cambrian = config.data.train == "cambrian"
+ mapping = None
+
+ if use_chameleon:
+ from unidisc.tokenizers.chameleon_tokenizers import ItemProcessor
+ vae = ItemProcessor(target_size=config.data.resolution)
+ else:
+ vae = get_vae(config, device)
+
+ if use_mmc4:
+ import pandas as pd
+ mapping = pd.read_parquet(config.data.mmc4_mapping_parquet)
+ # Keep tar_filepath if it exists, otherwise use shard_path or map img2dataset_shard_id
+ if "tar_filepath" in mapping.columns:
+ pass
+ elif "shard_path" in mapping.columns:
+ mapping = mapping.rename(columns={"shard_path": "tar_filepath"})
+ mapping["tar_filepath"] = mapping["tar_filepath"].str.replace(".parquet", ".tar")
+ else:
+ tar_path = Path(config.data.mmc4_tar_path)
+ mapping["tar_filepath"] = mapping["img2dataset_shard_id"].apply(lambda x: tar_path / f"{x}.tar")
+
+ mapping = mapping[['url', 'tar_filepath', 'key']]
+ mapping = mapping.set_index("url").sort_index()
+
+ if use_mmc4 or use_cambrian:
+ assert config.data.use_slow_tokenizer and config.data.add_image_token
+
+ if config.data.iterable:
+ train_dataset_size = getattr(config.data, "train_dataset_size", None)
+ else:
+ print(f"{prefix} Train dataloader: {len(train_dataloader)} batches")
+ print(f"{prefix} Train underlying dataset: {len(train_dataloader.dataset)} samples")
+ train_dataset_size = (len(train_dataloader.dataset) // accelerator.num_processes) + config.loader.batch_size
+ if save_validation_dataloader:
+ print(f"{prefix} Val dataloader: {len(val_dataloader)} batches")
+ print(f"Val underlying dataset: {len(val_dataloader.dataset)} samples")
+ val_dataset_size = (len(val_dataloader.dataset) // accelerator.num_processes) + config.loader.batch_size
+
+ print(f"{prefix} Train dataset size: {train_dataset_size} for 1 GPU")
+ if save_validation_dataloader:
+ print(f"{prefix} Val dataset size: {val_dataset_size} for 1 GPU")
+
+ rank = accelerator.process_index
+ output_dir = config.data.token_output_dir
+ output_dir = Path(f"{output_dir}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ assert config.data.force_disable_shuffle
+
+ debug = getattr(config.data, "debug", False)
+ print(f"{prefix} Output dir: {output_dir}")
+
+ vgg_data = None
+ if getattr(config.data, "add_vggface_v2_attributes", False):
+ print(f"{prefix} Adding VGGFace V2 attributes")
+ vgg_data = get_inference_func()
+ vgg_data["model"] = accelerator.prepare(vgg_data["model"])
+
+ if not config.data.split_dataset and is_main_process() and any(output_dir.iterdir()):
+ rprint(f"Found temporary directories in output dir, combining them")
+ combine_token_dicts(output_dir, use_tmp=False, use_timestamp=True, delete_after_combining=True)
+ for item in output_dir.iterdir():
+ if item.is_dir() and "tmp" in item.name:
+ rprint(f"Removing temporary directory: {item}")
+ shutil.rmtree(item)
+
+ # barrier() # TODO: Should be a barrier here
+ if not config.data.split_dataset:
+ existing_folders = sorted([folder for folder in output_dir.iterdir() if folder.is_dir() and "existing" in folder.name])
+ if existing_folders:
+ rprint(f"Found existing folders: {existing_folders}")
+ existing_data = torch.cat([TensorDict.load_memmap(folder) for folder in existing_folders], dim=0)
+ rprint(f"Concatenated existing data with shape: {existing_data.shape}")
+ existing_ids = set(existing_data["idx"].to(torch.int32).flatten().tolist())
+ else:
+ rprint("No existing folders found")
+ existing_ids = None
+ else:
+ existing_ids = None
+
+ if save_train_dataloader:
+ if not config.data.split_dataset and getattr(config.data, "allow_load_from_tmp", True) and Path(output_dir / f"tmp_train_{rank}").exists():
+ rprint("Found tmp_train_{rank} in output dir, loading from it")
+ train_data = TensorDict.load_memmap(output_dir / f"tmp_train_{rank}")
+ train_data = train_data.clone()
+ else:
+ train_data = get_dict(config, train_dataset_size)
+
+ print(f"{prefix} Starting train dataloader")
+ if config.data.split_dataset:
+ rank = int(os.getenv("SLURM_ARRAY_TASK_ID"))
+ print(f"Using task id: {rank}")
+
+ split_path = output_dir / f"train_{rank}"
+ tmp_train_path = output_dir / f"tmp_train_{rank}"
+
+ signal.signal(signal.SIGUSR1, partial(signal_handler, train_data=train_data, tmp_path=tmp_train_path))
+ signal.signal(signal.SIGUSR2, partial(signal_handler, train_data=train_data, tmp_path=tmp_train_path))
+
+ try:
+ signal.signal(signal.SIGKILL, partial(signal_handler, train_data=train_data, tmp_path=tmp_train_path))
+ except:
+ rprint(f"Failed to set SIGKILL handler")
+
+ start_time = time.time()
+ with VizTracer(output_file="optional.json", tracer_entries=5000000) if debug else nullcontext():
+ for i, batch in tqdm(enumerate(train_dataloader), leave=False, disable=not is_local_main_process()):
+ if i == 0 and "img" in batch:
+ print(f"Batch shape: {batch['img'].shape}")
+ if debug and i >= 1:
+ break
+ enc(train_data, i, encode_images, config, vae, batch, accelerator, mixed_precision, tokenizer, vgg_data=vgg_data, existing_ids=existing_ids, device=device, mapping=mapping)
+ try:
+ if not config.data.split_dataset or True:
+ start_time = periodic_save(train_data, tmp_train_path, start_time, getattr(config.data, "periodic_save", 2 * 60 * 60))
+ except Exception as e:
+ gprint(f"Failed to save train data: {e}")
+ start_time = time.time()
+
+ if debug:
+ exit()
+
+ del train_dataloader
+ print(f"{prefix} Saving train data")
+ if split_path.exists() and split_path.is_dir():
+ rprint(f"Removing {split_path}")
+ shutil.rmtree(split_path)
+
+ split_path.mkdir(parents=True, exist_ok=True)
+ gprint(f"Saving train data to {split_path}: {train_data.shape}")
+ train_data.memmap(split_path, copy_existing=True)
+
+ if tmp_train_path.exists() and tmp_train_path.is_dir():
+ rprint(f"Removing {tmp_train_path}")
+ shutil.rmtree(tmp_train_path)
+
+ if not config.data.split_dataset:
+ with open(output_dir / f"train_{rank}.completed", 'w') as f:
+ f.write(f"Processing done for rank {rank}\n")
+
+ print(f"{prefix} Finished train dataloader")
+
+ if save_validation_dataloader:
+ val_data = get_dict(config, val_dataset_size)
+ split_path = output_dir / f"val_{rank}"
+ split_path.mkdir(parents=True, exist_ok=True)
+ tmp_val_path = output_dir / f"tmp_val_{rank}"
+ print(f"Starting val dataloader")
+ start_time = time.time() # Track start time for periodic saving
+ for i, batch in tqdm(enumerate(val_dataloader), leave=False):
+ if debug and i >= 10:
+ break
+ enc(val_data, i, encode_images, config, vae, batch, accelerator, mixed_precision, tokenizer, vgg_data=vgg_data, device=device)
+
+ # Periodically save data
+ start_time = periodic_save(val_data, tmp_val_path, start_time)
+
+ print(f"{prefix} Saving val data")
+ if split_path.exists() and split_path.is_dir():
+ rprint(f"Removing {split_path}")
+ shutil.rmtree(split_path)
+ split_path.mkdir(parents=True, exist_ok=True)
+ rprint(f"Saving val data to {split_path}")
+ val_data.memmap(split_path, copy_existing=True)
+ if tmp_val_path.exists() and tmp_val_path.is_dir():
+ shutil.rmtree(tmp_val_path) # Delete tmp directory after final save
+ print(f"{prefix} Finished val dataloader")
+
+ rprint(f"{prefix} Finished precomputing tokens")
+
+ if config.data.split_dataset:
+ rprint(f"We are splitting the dataset and thus exiting.")
+ exit()
+
+ if get_world_size() > 1 and (time.time() - train_start_time) > 60 * 60:
+ time.sleep(60 * 60)
+ barrier()
+
+ rprint('after barrier')
+ if is_main_process():
+ combine_token_dicts(data_dir=output_dir, allow_zero_idx=True, move_files=True, delete_after_combining=True)
+
+ barrier()
+ rprint(f"Finished concating tokens")
+
+
+if __name__ == "__main__":
+ with breakpoint_on_error():
+ main()
diff --git a/models/datasets/text_datasets.py b/models/datasets/text_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..738d4bc578ddf940ad3ac99b1029d4ced6135945
--- /dev/null
+++ b/models/datasets/text_datasets.py
@@ -0,0 +1,487 @@
+
+import functools
+import itertools
+import json
+import math
+import os
+import random
+import re
+import shutil
+import typing
+import urllib
+import zipfile
+from pathlib import Path
+
+import datasets
+import fsspec
+import pandas as pd
+import requests
+import tokenizers
+import torch
+import transformers
+import utils
+from decoupled_utils import rprint
+
+def wt_detokenizer(string):
+ # contractions
+ string = string.replace("s '", "s'")
+ string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+ # number separators
+ string = string.replace(" @-@ ", "-")
+ string = string.replace(" @,@ ", ",")
+ string = string.replace(" @.@ ", ".")
+ # punctuation
+ string = string.replace(" : ", ": ")
+ string = string.replace(" ; ", "; ")
+ string = string.replace(" . ", ". ")
+ string = string.replace(" ! ", "! ")
+ string = string.replace(" ? ", "? ")
+ string = string.replace(" , ", ", ")
+ # double brackets
+ string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+ string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+ string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+ string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+ string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+ # miscellaneous
+ string = string.replace("= = = =", "====")
+ string = string.replace("= = =", "===")
+ string = string.replace("= =", "==")
+ string = string.replace(" " + chr(176) + " ", chr(176))
+ string = string.replace(" \n", "\n")
+ string = string.replace("\n ", "\n")
+ string = string.replace(" N ", " 1 ")
+ string = string.replace(" 's", "'s")
+ return string
+
+
+def ptb_detokenizer(x):
+ x = x.replace(" 's", "'s")
+ x = x.replace("s ' ", "s' ")
+ x = x.replace(" n't", "n't")
+ x = x.replace(" \n ", "\n")
+ x = x.replace("\\/", "/")
+ for _ in range(10):
+ x = x.replace(" N ", " 1 ")
+ x = x.replace("$ 1", "$1")
+ x = x.replace("# 1", "#1")
+ x = x.replace("", "?")
+ return x
+
+
+def lm1b_detokenizer(x):
+ x = x.replace('http : / / ', 'http://')
+ x = x.replace('https : / / ', 'https://')
+ x = re.sub(r' \'(\w+)', r"'\1", x)
+ x = re.sub(r' (\w+) \. ', r' \1. ', x)
+ x = re.sub(r' (\w+) \.$', r' \1.', x)
+ x = x.replace(' ? ', '? ')
+ x = re.sub(r' \?$', '?', x)
+ x = x.replace(' ! ', '! ')
+ x = re.sub(r' \!$', '!', x)
+ x = x.replace(' , ', ', ')
+ x = x.replace(' : ', ': ')
+ x = x.replace(' ; ', '; ')
+ x = x.replace(' / ', '/')
+ x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
+ x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
+ x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
+ x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
+ x = x.replace('$ ', '$')
+ x = x.replace('£ ', '£')
+ return x
+
+
+def lambada_detokenizer(text):
+ text = text.replace("“", '"')
+ text = text.replace("”", '"')
+ return '\n'+text.strip()
+
+
+def scientific_papers_detokenizer(x):
+ x = wt_detokenizer(x)
+ x = lm1b_detokenizer(x)
+ return x
+
+
+class Text8Tokenizer(transformers.PreTrainedTokenizer):
+ def __init__(
+ self,
+ bos_token='[BOS]',
+ eos_token='[EOS]',
+ sep_token='[SEP]',
+ cls_token='[CLS]',
+ pad_token='[PAD]',
+ mask_token='[MASK]',
+ unk_token='[UNK]',
+ **kwargs):
+ self.characters = list('abcdefghijklmnopqrstuvwxyz ')
+ self._vocab_str_to_int = {
+ '[CLS]': 0,
+ '[SEP]': 1,
+ '[BOS]': 2,
+ '[EOS]': 3,
+ '[MASK]': 4,
+ '[PAD]': 5,
+ '[RESERVED]': 6,
+ '[UNK]': 7,
+ ** {ch: i + 8 for i, ch in enumerate(self.characters)}}
+ self._vocab_int_to_str = {
+ v: k for k, v in self._vocab_str_to_int.items()}
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ unk_token=unk_token,
+ **kwargs)
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self._vocab_str_to_int)
+
+ def _tokenize(self, text: str, **kwargs):
+ return list(text.lower())
+
+ def _convert_token_to_id(self, token: str) -> int:
+ return self._vocab_str_to_int.get(
+ token, self._vocab_str_to_int['[UNK]'])
+
+ def _convert_id_to_token(self, index: int) -> str:
+ return self._vocab_int_to_str[index]
+
+ def convert_tokens_to_string(self, tokens):
+ return ''.join(tokens)
+
+ def get_vocab(self) -> typing.Dict[str, int]:
+ return self._vocab_str_to_int
+
+
+def get_lambada_test_dataset():
+ url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"
+
+ def read_jsonl_to_list(url):
+ response = requests.get(url, stream=True)
+ data_list = []
+
+ # Process each line in the response content
+ for line in response.iter_lines(decode_unicode=True):
+ if line:
+ data = json.loads(line)
+ data_list.append(data)
+
+ return data_list
+
+ lambada_data = read_jsonl_to_list(url)
+ dataset = datasets.Dataset.from_list(lambada_data)
+ return dataset
+
+def get_text8_dataset(cache_dir, max_seq_length=256,
+ drop_last=True, crop_train=False):
+ """Adapted from:
+ https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344
+
+ Args:
+ cache_dir: str, path to cache directory.
+ max_seq_length: int, maximum length of sequences.
+ (default: 256, as in D3PM codebase.)
+ drop_last: bool, whether to drop the last incomplete
+ batch. (default: True, as in D3PM codebase.)
+ crop_train: bool, whether to subsample contiguous
+ subsequences from training example. serves to
+ make sure transformer models with absolute position
+ embeddings do not have incorrect position-wise
+ marginals. (default: False, but necessary to match D3PM AR)
+
+ Returns:
+ dataset: dataset.DatasetDict, with keys 'train',
+ 'valid', 'test'.
+ """
+ url = 'http://mattmahoney.net/dc/text8.zip'
+ if not crop_train:
+ cache_dir = f'{cache_dir}/text8'
+ else:
+ cache_dir = f'{cache_dir}/text8-crop-train'
+ split_names = ['train', 'validation', 'test']
+ if not all([
+ utils.fsspec_exists(os.path.join(cache_dir, split))
+ for split in split_names
+ ]):
+ # Check if raw data exists
+ raw_cache_dir = os.path.join(cache_dir, 'raw_data')
+ if not all([
+ utils.fsspec_exists(
+ os.path.join(raw_cache_dir, f'text8.{split}.txt'))
+ for split in split_names
+ ]):
+ if not utils.fsspec_exists(
+ os.path.join(raw_cache_dir, 'text8.zip')):
+ utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
+ print('Downloading text8 from URL {}.'.format(url))
+ with (urllib.request.urlopen(url) as in_stream,
+ open(os.path.join(raw_cache_dir, 'text8.zip'),
+ 'wb') as out_file):
+ shutil.copyfileobj(in_stream, out_file)
+
+ with fsspec.open(
+ os.path.join(raw_cache_dir, 'text8.zip'),
+ 'rb') as f:
+ rawdata = zipfile.ZipFile(f).read(
+ 'text8').decode('utf-8')
+
+ # Splits taken from D3PM codebase
+ splits = {
+ 'train': rawdata[:90000000],
+ 'validation': rawdata[90000000: 95000000],
+ 'test': rawdata[95000000:],
+ }
+
+ for split, data in splits.items():
+ _path = os.path.join(raw_cache_dir,
+ f'text8.{split}.txt')
+ with fsspec.open(_path, 'w') as f:
+ f.write(data)
+ else:
+ splits = {}
+ for split in split_names:
+ _path = os.path.join(raw_cache_dir,
+ f'text8.{split}.txt')
+ with fsspec.open(_path, 'r') as f:
+ splits[split] = f.read()
+
+ # Chunk and save as datasets.DatasetDict
+ def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+ dataset_dict = {}
+ for k, v in splits.items():
+ if k == 'train' and crop_train == True:
+ chunk_size = 2 * max_seq_length
+ else:
+ chunk_size = max_seq_length
+ text = list(chunks(v, chunk_size))
+ if drop_last and len(text[-1]) < chunk_size:
+ text = text[:-1]
+ dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
+ dataset = datasets.DatasetDict(dataset_dict)
+ dataset.save_to_disk(cache_dir)
+ else:
+ dataset = datasets.load_from_disk(cache_dir)
+
+ return dataset
+
+
+def _group_texts(examples, block_size, bos, eos):
+ # Concatenate all texts.
+ concatenated_examples = list(itertools.chain(* examples['input_ids']))
+ total_length = len(concatenated_examples)
+ # TODO(yair): look into not dropping the remainder but rather padding it.
+ # We drop the small remainder, and if the total_length < block_size - 2
+ # we exclude this batch and return an empty dict.
+ # We could add padding if the model supported it instead of
+ # this drop, you can customize this part to your needs.
+ new_block_size = block_size - 2 # [BOS] and [EOS] to be added
+ total_length = (total_length // new_block_size) * new_block_size
+ # Split by chunks of max_len.
+ result = {}
+ _values = []
+ _attn_masks = []
+ for i in range(0, total_length, new_block_size):
+ _values.append(
+ [bos]
+ + concatenated_examples[i : i + new_block_size]
+ + [eos])
+ _attn_masks.append(torch.ones(block_size))
+ result['input_ids'] = _values
+ result['attention_mask'] = _attn_masks
+ return result
+
+
+def get_text_dataset(dataset_name, tokenizer, wrap, mode, cache_dir, block_size=1024, num_proc=len(os.sched_getaffinity(0)), streaming=False, **kwargs):
+ if wrap:
+ filename = f'{dataset_name}_{mode}_bs{block_size}_{tokenizer.__class__.__name__}_wrapped.dat'
+ else:
+ filename = f'{dataset_name}_{mode}_bs{block_size}_{tokenizer.__class__.__name__}_unwrapped.dat'
+ _path = os.path.join(cache_dir, filename)
+ if utils.fsspec_exists(_path):
+ print(f'Loading data from: {_path}')
+ _dataset = datasets.load_from_disk(_path).with_format('torch')
+ rprint(f"Sample 0: {_dataset[0]}")
+ rprint(f"Sample -1: {_dataset[-1]}")
+ return _dataset
+ print(f'Generating new data at: {_path}')
+
+ crop_train = dataset_name == 'text8-crop'
+ if mode == 'train' and crop_train:
+ # double block size for sub-sampling
+ block_size *= 2
+
+ if dataset_name == 'wikitext103':
+ dataset = datasets.load_dataset(
+ 'wikitext',
+ name='wikitext-103-raw-v1',
+ cache_dir=cache_dir)
+ elif dataset_name == 'wikitext2':
+ dataset = datasets.load_dataset(
+ 'wikitext',
+ name='wikitext-2-raw-v1',
+ cache_dir=cache_dir)
+ elif dataset_name == 'ptb':
+ dataset = datasets.load_dataset(
+ 'ptb_text_only', cache_dir=cache_dir)
+ elif dataset_name == 'lambada':
+ dataset = get_lambada_test_dataset()
+ elif dataset_name == 'text8':
+ assert wrap
+ dataset = get_text8_dataset(
+ cache_dir, max_seq_length=block_size)
+ elif dataset_name == 'text8-crop':
+ dataset = get_text8_dataset(
+ cache_dir, max_seq_length=block_size, crop_train=True)
+ elif dataset_name == 'openwebtext-train':
+ dataset = datasets.load_dataset(
+ 'openwebtext',
+ split='train' if streaming else 'train[:-100000]',
+ cache_dir=cache_dir,
+ streaming=streaming, trust_remote_code=True)
+ elif dataset_name == 'openwebtext-valid':
+ dataset = datasets.load_dataset(
+ 'openwebtext',
+ split='train' if streaming else 'train[-100000:]',
+ cache_dir=cache_dir,
+ streaming=streaming)
+ elif dataset_name == 'scientific_papers_arxiv':
+ dataset = datasets.load_dataset(
+ 'scientific_papers', 'arxiv',
+ trust_remote_code=True,
+ cache_dir=cache_dir,
+ streaming=streaming)
+ elif dataset_name == 'scientific_papers_pubmed':
+ dataset = datasets.load_dataset(
+ 'scientific_papers', 'pubmed',
+ trust_remote_code=True,
+ cache_dir=cache_dir,
+ streaming=streaming)
+ elif dataset_name == 'ag_news':
+ dataset = datasets.load_dataset(
+ 'ag_news',
+ cache_dir=cache_dir,
+ streaming=streaming)
+ else:
+ dataset = datasets.load_dataset(
+ dataset_name,
+ cache_dir=cache_dir,
+ streaming=streaming,
+ trust_remote_code=True)
+
+ if dataset_name in ['lambada', 'openwebtext-train',
+ 'openwebtext-valid']:
+ data = dataset
+ else:
+ data = dataset[mode]
+
+ if dataset_name.startswith('wikitext'):
+ detokenizer = wt_detokenizer
+ elif dataset_name == 'ptb':
+ detokenizer = ptb_detokenizer
+ elif dataset_name == 'lm1b':
+ detokenizer = lm1b_detokenizer
+ elif dataset_name == 'lambada':
+ detokenizer = lambada_detokenizer
+ elif dataset_name.startswith('scientific_papers'):
+ detokenizer = scientific_papers_detokenizer
+ else:
+ detokenizer = None
+
+ def _apply_detokenizer(detokenizer):
+ def detok(text):
+ for i, t in enumerate(text, 0):
+ text[i] = detokenizer(t)
+ return text
+ return detok
+
+ EOS = tokenizer.encode(tokenizer.eos_token)[0]
+ BOS = tokenizer.encode(tokenizer.bos_token)[0]
+
+ def preprocess_and_tokenize(example):
+ if dataset_name == 'ptb':
+ text = example['sentence']
+ elif 'scientific_papers' in dataset_name:
+ text = example['article']
+ else:
+ text = example['text']
+
+ if detokenizer is not None:
+ text = _apply_detokenizer(detokenizer)(text)
+
+ tokenizer.padding_side = 'right'
+ tokenizer.truncation_side = 'right'
+
+ if wrap:
+ tokens = tokenizer(text,
+ add_special_tokens=False,
+ return_attention_mask=False,
+ return_token_type_ids=False)
+ tokens = {'input_ids':
+ [t + [EOS] for t in tokens['input_ids']]}
+ # Still missing BOS, but will be added in group_texts
+ else:
+ tokens = tokenizer(text,
+ max_length=block_size,
+ padding='max_length',
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_token_type_ids=True)
+ return tokens
+ if streaming:
+ tokenized_dataset = data.map(
+ preprocess_and_tokenize,
+ batched=True
+ )
+ else:
+ rprint(f"Tokenizing with num_proc: {num_proc}")
+ tokenized_dataset = data.map(
+ preprocess_and_tokenize,
+ batched=True,
+ num_proc=num_proc,
+ load_from_cache_file=True,
+ desc='Tokenizing')
+ if dataset_name == 'ptb':
+ tokenized_dataset = tokenized_dataset.remove_columns(
+ 'sentence')
+ elif 'scientific_papers' in dataset_name:
+ tokenized_dataset = tokenized_dataset.remove_columns([
+ 'article', 'abstract', 'section_names'])
+ elif dataset_name == 'ag_news':
+ tokenized_dataset = tokenized_dataset.remove_columns(
+ ['text', 'label'])
+ else:
+ tokenized_dataset = tokenized_dataset.remove_columns(
+ 'text')
+
+ if not wrap:
+ if streaming is False:
+ tokenized_dataset.save_to_disk(_path)
+ return tokenized_dataset.with_format('torch')
+
+ group_texts = functools.partial(
+ _group_texts, block_size=block_size, bos=BOS, eos=EOS)
+ if streaming:
+ chunked_dataset = tokenized_dataset.map(
+ group_texts,
+ batched=True)
+ else:
+ chunked_dataset = tokenized_dataset.map(
+ group_texts,
+ batched=True,
+ num_proc=num_proc,
+ load_from_cache_file=True,
+ desc='Grouping')
+ chunked_dataset.save_to_disk(_path)
+ chunked_dataset = chunked_dataset.with_format('torch')
+ return chunked_dataset
diff --git a/models/datasets/webdataset_utils.py b/models/datasets/webdataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e095df9a628e6bf6e23c3e159cbe2fe784f5d3c
--- /dev/null
+++ b/models/datasets/webdataset_utils.py
@@ -0,0 +1,488 @@
+import ast
+import json
+import logging
+import math
+import os
+import random
+import sys
+from dataclasses import dataclass
+from multiprocessing import Value
+from typing import List
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.datasets as datasets
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+from torch.utils.data.distributed import DistributedSampler
+from webdataset.filters import _shuffle
+from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+from torch.utils.data import default_collate
+
+
+class SharedEpoch:
+ def __init__(self, epoch: int = 0):
+ self.shared_epoch = Value('i', epoch)
+
+ def set_value(self, epoch):
+ self.shared_epoch.value = epoch
+
+ def get_value(self):
+ return self.shared_epoch.value
+
+
+@dataclass
+class DataInfo:
+ dataloader: DataLoader
+ sampler: DistributedSampler = None
+ shared_epoch: SharedEpoch = None
+
+ def set_epoch(self, epoch):
+ if self.shared_epoch is not None:
+ self.shared_epoch.set_value(epoch)
+ if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
+ self.sampler.set_epoch(epoch)
+
+
+def expand_urls(urls, weights=None):
+ if weights is None:
+ expanded_urls = wds.shardlists.expand_urls(urls)
+ return expanded_urls, None
+ if isinstance(urls, str):
+ urllist = urls.split("::")
+ weights = weights.split('::')
+ assert len(weights) == len(urllist),\
+ f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
+ weights = [float(weight) for weight in weights]
+ all_urls, all_weights = [], []
+ for url, weight in zip(urllist, weights):
+ import braceexpand
+ expanded_url = list(braceexpand.braceexpand(url))
+ expanded_weights = [weight for _ in expanded_url]
+ all_urls.extend(expanded_url)
+ all_weights.extend(expanded_weights)
+ return all_urls, all_weights
+ else:
+ all_urls = list(urls)
+ return all_urls, weights
+
+
+def get_dataset_size(shards):
+ shards_list, _ = expand_urls(shards)
+ dir_path = os.path.dirname(shards_list[0])
+ sizes_filename = os.path.join(dir_path, 'sizes.json')
+ len_filename = os.path.join(dir_path, '__len__')
+ if os.path.exists(sizes_filename):
+ sizes = json.load(open(sizes_filename, 'r'))
+ total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
+ elif os.path.exists(len_filename):
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+ total_size = ast.literal_eval(open(len_filename, 'r').read())
+ else:
+ total_size = None # num samples undefined
+ # some common dataset sizes (at time of authors last download)
+ # CC3M (train): 2905954
+ # CC12M: 10968539
+ # LAION-400M: 407332084
+ # LAION-2B (english): 2170337258
+ num_shards = len(shards_list)
+ return total_size, num_shards
+
+def count_samples(dataloader):
+ os.environ["WDS_EPOCH"] = "0"
+ n_elements, n_batches = 0, 0
+ for images, texts in dataloader:
+ n_batches += 1
+ n_elements += len(images)
+ assert len(images) == len(texts)
+ return n_elements, n_batches
+
+
+def filter_no_caption_or_no_image(sample):
+ has_caption = ('txt' in sample)
+ has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
+ return has_caption and has_image
+
+
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+ logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+ return True
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ if "fname" not in filesample or "data" not in filesample:
+ continue
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+def pytorch_worker_seed(increment=0):
+ """get dataloader worker seed from pytorch"""
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ # favour using the seed already created for pytorch dataloader workers if it exists
+ seed = worker_info.seed
+ if increment:
+ # space out seed increments so they can't overlap across workers in different iterations
+ seed += increment * max(1, worker_info.num_workers)
+ return seed
+ # fallback to wds rank based seed
+ return wds.utils.pytorch_worker_seed()
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+class detshuffle2(wds.PipelineStage):
+ def __init__(
+ self,
+ bufsize=1000,
+ initial=100,
+ seed=0,
+ epoch=-1,
+ ):
+ self.bufsize = bufsize
+ self.initial = initial
+ self.seed = seed
+ self.epoch = epoch
+
+ def run(self, src):
+ if isinstance(self.epoch, SharedEpoch):
+ epoch = self.epoch.get_value()
+ else:
+ # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
+ # situation as different workers may wrap at different times (or not at all).
+ self.epoch += 1
+ epoch = self.epoch
+ rng = random.Random()
+ if self.seed < 0:
+ # If seed is negative, we use the worker's seed, this will be different across all nodes/workers
+ seed = pytorch_worker_seed(epoch)
+ else:
+ # This seed to be deterministic AND the same across all nodes/workers in each epoch
+ seed = self.seed + epoch
+ rng.seed(seed)
+ return _shuffle(src, self.bufsize, self.initial, rng)
+
+
+class ResampledShards2(IterableDataset):
+ """An iterable dataset yielding a list of urls."""
+
+ def __init__(
+ self,
+ urls,
+ weights=None,
+ nshards=sys.maxsize,
+ worker_seed=None,
+ deterministic=False,
+ epoch=-1,
+ ):
+ """Sample shards from the shard list with replacement.
+
+ :param urls: a list of URLs as a Python list or brace notation string
+ """
+ super().__init__()
+ urls, weights = expand_urls(urls, weights)
+ self.urls = urls
+ self.weights = weights
+ if self.weights is not None:
+ assert len(self.urls) == len(self.weights),\
+ f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
+ assert isinstance(self.urls[0], str)
+ self.nshards = nshards
+ self.rng = random.Random()
+ self.worker_seed = worker_seed
+ self.deterministic = deterministic
+ self.epoch = epoch
+ print(f"ResampledShards2: {self.urls}, {self.weights}, {self.nshards}, {self.worker_seed}, {self.deterministic}, {self.epoch}")
+
+ def __iter__(self):
+ """Return an iterator over the shards."""
+ if isinstance(self.epoch, SharedEpoch):
+ epoch = self.epoch.get_value()
+ else:
+ # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
+ # situation as different workers may wrap at different times (or not at all).
+ self.epoch += 1
+ epoch = self.epoch
+ if self.deterministic:
+ # reset seed w/ epoch if deterministic
+ if self.worker_seed is None:
+ # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
+ seed = pytorch_worker_seed(epoch)
+ else:
+ seed = self.worker_seed() + epoch
+ self.rng.seed(seed)
+ for _ in range(self.nshards):
+ if self.weights is None:
+ yield dict(url=self.rng.choice(self.urls))
+ else:
+ yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
+
+def tokenize_func(tokenizer, block_size=32):
+ from dataloader import tokenize_text
+ def _f(dictionary):
+ dictionary.update(**tokenize_text(tokenizer, block_size, dictionary['text'], return_token_type_ids=False))
+ return {k:v for k,v in dictionary.items() if k in ("img", "input_ids", "attention_mask", "text")}
+
+ return _f
+
+def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
+ input_shards = args.train_data if is_train else args.val_data
+ assert input_shards is not None
+ resampled = getattr(args, 'dataset_resampled', False) and is_train
+
+ num_shards = None
+ if is_train:
+ if args.train_num_samples is not None:
+ num_samples = args.train_num_samples
+ else:
+ num_samples, num_shards = get_dataset_size(input_shards)
+ if not num_samples:
+ raise RuntimeError(
+ 'Currently, the number of dataset samples must be specified for the training dataset. '
+ 'Please specify it via `--train-num-samples` if no dataset length info is present.')
+ else:
+ # Eval will just exhaust the iterator if the size is not specified.
+ num_samples = args.val_num_samples or 0
+
+ shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
+
+ if is_train and args.train_data_upsampling_factors is not None:
+ assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
+
+ if resampled:
+ pipeline = [ResampledShards2(
+ input_shards,
+ weights=args.train_data_upsampling_factors,
+ deterministic=True,
+ epoch=shared_epoch,
+ )]
+ else:
+ pipeline = [wds.SimpleShardList(input_shards)]
+
+ # at this point we have an iterator over all the shards
+ if is_train:
+ if not resampled:
+ pipeline.extend([
+ detshuffle2(
+ bufsize=_SHARD_SHUFFLE_SIZE,
+ initial=_SHARD_SHUFFLE_INITIAL,
+ seed=args.seed,
+ epoch=shared_epoch,
+ ),
+ wds.split_by_node,
+ wds.split_by_worker,
+ ])
+ pipeline.extend([
+ # at this point, we have an iterator over the shards assigned to each worker at each node
+ tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
+ wds.shuffle(
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
+ initial=_SAMPLE_SHUFFLE_INITIAL,
+ ),
+ ])
+ else:
+ pipeline.extend([
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker
+ wds.tarfile_to_samples(handler=log_and_continue),
+ ])
+ pipeline.extend([
+ wds.select(filter_no_caption_or_no_image),
+ wds.decode("pilrgb", handler=log_and_continue),
+ wds.rename(img="jpg;png;jpeg;webp", text="txt"),
+ wds.map_dict(img=preprocess_img),
+ wds.map(tokenize_func(tokenizer, block_size=args.block_size)),
+ wds.batched(args.batch_size, partial=not is_train, collation_fn=default_collate)
+ ])
+
+ dataset = wds.DataPipeline(*pipeline)
+
+ if is_train:
+ if not resampled:
+ num_shards = num_shards or len(expand_urls(input_shards)[0])
+ assert num_shards >= args.workers * args.world_size, f'number of shards must be >= total workers, {num_shards} < {args.workers * args.world_size}'
+ # roll over and repeat a few samples to get same number of full batches on each node
+ round_fn = math.floor if floor else math.ceil
+ global_batch_size = args.batch_size * args.world_size
+ num_batches = round_fn(num_samples / global_batch_size)
+ num_workers = max(1, args.workers)
+ num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+ dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
+ else:
+ # last batches are partial, eval is done on single (master) node
+ num_batches = math.ceil(num_samples / args.batch_size)
+
+ dataloader = wds.WebLoader(
+ dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=args.workers,
+ persistent_workers=args.workers > 0,
+ )
+
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+ # if is_train:
+ # # roll over and repeat a few samples to get same number of full batches on each node
+ # global_batch_size = args.batch_size * args.world_size
+ # num_batches = math.ceil(num_samples / global_batch_size)
+ # num_workers = max(1, args.workers)
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
+ # num_samples = num_batches * global_batch_size
+ # dataloader = dataloader.with_epoch(num_batches)
+ # else:
+ # # last batches are partial, eval is done on single (master) node
+ # num_batches = math.ceil(num_samples / args.batch_size)
+
+ # add meta-data to dataloader instance for convenience
+ dataloader.num_batches = num_batches
+ dataloader.num_samples = num_samples
+
+ return dataloader
+
+
+
+class SyntheticDataset(Dataset):
+
+ def __init__(
+ self,
+ transform=None,
+ image_size=(224, 224),
+ caption="Dummy caption",
+ dataset_size=100,
+ tokenizer=None,
+ ):
+ self.transform = transform
+ self.image_size = image_size
+ self.caption = caption
+ self.image = Image.new('RGB', image_size)
+ self.dataset_size = dataset_size
+
+ self.preprocess_txt = lambda text: tokenizer(text)[0]
+
+ def __len__(self):
+ return self.dataset_size
+
+ def __getitem__(self, idx):
+ if self.transform is not None:
+ image = self.transform(self.image)
+ return image, self.preprocess_txt(self.caption)
+
+
+def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
+ image_size = preprocess_fn.transforms[0].size
+ dataset = SyntheticDataset(
+ transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
+ num_samples = len(dataset)
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+ shuffle = is_train and sampler is None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=shuffle,
+ num_workers=args.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler).dataloader
+
+
+def get_dataset_fn(data_path, dataset_type):
+ if dataset_type == "webdataset":
+ return get_wds_dataset
+ elif dataset_type == "synthetic":
+ return get_synthetic_dataset
+ else:
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+def get_data(args, preprocess_fns, epoch=0, tokenizer=None, is_train=True):
+ preprocess_train, preprocess_val = preprocess_fns
+ if is_train:
+ return get_dataset_fn(args.train_data, args.dataset_type)(
+ args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
+ else:
+ return get_dataset_fn(args.val_data, args.dataset_type)(
+ args, preprocess_val, is_train=False, tokenizer=tokenizer)
+
+if __name__ == "__main__":
+ from omegaconf import OmegaConf
+ config = OmegaConf.create({
+ 'train_data': '/grogu/datasets/laion400m/dataset/{00000..00625}.tar',
+ 'dataset_type': 'webdataset',
+ 'train_data_upsampling_factors': None,
+ 'batch_size': 64,
+ 'workers': config.loader.num_workers,
+ 'distributed': True,
+ 'seed': config.seed,
+ 'val_num_samples': None,
+ 'train_num_samples': 100000,
+ 'val_data': None,
+ 'imagenet_val': None,
+ 'imagenet_v2': None,
+ 'world_size': 1,
+ "data" : {
+ "tokenizer_name_or_path" : "gpt2",
+ }
+ })
+ args = config
+ from dataloader import get_tokenizer
+ tokenizer = get_tokenizer(config)
+ from torchvision import transforms
+ transform_train = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ]
+ )
+
+ data = get_data(config, (transform_train, transform_train), epoch=0, tokenizer=tokenizer)
+ print(data)
+ batch = next(iter(data['train'].dataloader))
+ breakpoint()
\ No newline at end of file
diff --git a/models/dit.py b/models/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a33d78a1a7cb4e4878d8e931cf8e257d3c338a8
--- /dev/null
+++ b/models/dit.py
@@ -0,0 +1,1500 @@
+import math
+import typing
+from contextlib import nullcontext
+import os
+
+# Torch must be imported before flash-attn
+from unidisc.utils.tensor_utils import get_contiguous_blocks, get_interleaved_indices
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
+from decoupled_utils import gprint, is_torch_xla_available, rprint
+from models.standalone_rotary import flash_torch_apply_rotary_emb_torch
+
+import huggingface_hub
+import omegaconf
+from einops import rearrange
+
+is_xla_available = is_torch_xla_available()
+
+force_cudnn_spda_context = os.environ.get("UNIDISC_FORCE_CUDNN_SPDA_CONTEXT", "0") == "1"
+allow_any_spda = os.environ.get("UNIDISC_ALLOW_ANY_SPDA", "0") == "1"
+force_xla_flash_attention = os.environ.get("UNIDISC_FORCE_XLA_FLASH_ATTENTION", "0") == "1"
+use_non_packed_fa2 = os.getenv("UNIDISC_USE_NON_PACKED_FA2", "0") == "1"
+disable_flash_attention_3 = os.getenv("UNIDISC_FORCE_DISABLE_FA3", "0") == "1"
+is_xla_linear_patched = os.getenv("UNIDISC_IS_XLA_LINEAR_PATCHED", "0") == "1"
+use_causal_attn = os.getenv("UNIDISC_USE_CAUSAL_ATTN", "0") == "1"
+
+if force_cudnn_spda_context: rprint("Forcing cudnn spda context")
+if allow_any_spda: rprint("Allowing any spda")
+if force_xla_flash_attention: rprint("Forcing xla flash attention")
+if use_non_packed_fa2: rprint("Using non-packed Flash Attention 2!")
+if disable_flash_attention_3: rprint("Disabling Flash Attention 3!")
+
+try:
+ failed_to_import_fa3 = True
+ if disable_flash_attention_3 is False:
+ from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
+
+ failed_to_import_fa3 = False
+ rprint("Imported Flash Attention 3!")
+except:
+ rprint("Not using Flash Attention 3!")
+
+try:
+ import flash_attn.layers.rotary
+ from flash_attn.layers.rotary import apply_rotary_emb
+
+ if failed_to_import_fa3:
+ from flash_attn.flash_attn_interface import (flash_attn_func,
+ flash_attn_qkvpacked_func,
+ flash_attn_varlen_func,
+ flash_attn_varlen_qkvpacked_func)
+ rprint("Imported Flash Attention 2!")
+except:
+ rprint("Failed to import Flash Attention 2!")
+
+try:
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+except:
+ pass
+
+try:
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+ compiled_flex_attention = torch.compile(flex_attention)
+except:
+ pass
+
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+@torch.no_grad()
+def get_transfusion_mask(B, N_tot, img_start_idx, img_length, modality):
+ # todo - this is temporary and only works for [text+image] mode. does NOT handle interleaved (need to use modality_mask for this)
+ # (B, N_tot) -> (B, N_tot, N_tot)
+ rows, cols = torch.meshgrid(torch.arange(N_tot), torch.arange(N_tot), indexing="ij")
+ idxs = torch.stack([rows, cols], dim=-1).to(modality.device)
+ idxs = idxs.expand(B, -1, -1, -1)
+ q_idx, kv_idx = idxs.unbind(dim=-1)
+
+ offset = torch.full((B,), img_start_idx, device=modality.device).unsqueeze(-1).unsqueeze(-1)
+ limit = torch.full((B,), img_length, device=modality.device).unsqueeze(-1).unsqueeze(-1)
+
+ ar = q_idx >= kv_idx
+ nar = (q_idx >= offset) & (kv_idx >= limit)
+ mask = ar | nar
+
+ # Assume that batches with all text are autoregressive only
+ mask = torch.where(((modality == 0).all(dim=-1))[:, None, None], ar, mask)
+ return mask
+
+@torch.compiler.disable()
+def add_img_data_to_blocks(input_emb, rotary_emb, modality_mask, sample_ids, add_data, img_count_embedding):
+ """
+ Dynamically adds 2D RoPE embeddings to image blocks. Handles variable resolutions by matching to hardcoded block sizes.
+ """
+ assert sample_ids is not None
+ B, N = modality_mask.shape
+ batch_indices, start_positions, end_positions = get_interleaved_indices(modality_mask)
+
+ block_sizes = end_positions - start_positions
+ unique_block_sizes = [size for size in torch.unique(block_sizes).tolist() if size in add_data.keys()]
+
+ # For each block, count number of blocks before it within same sample_id group
+ block_counts = torch.zeros_like(batch_indices)
+ for i in range(len(batch_indices)):
+ curr_sample_id = sample_ids[batch_indices[i], start_positions[i]]
+
+ # Find blocks before this one with same batch index and sample_id
+ prev_blocks_mask = (batch_indices[:i] == batch_indices[i]) & \
+ (sample_ids[batch_indices[:i], start_positions[:i]] == curr_sample_id)
+
+ block_counts[i] = prev_blocks_mask.sum()
+
+ for block_size in unique_block_sizes:
+ block_mask = (block_sizes == block_size)
+ block_indices = block_mask.nonzero(as_tuple=False).squeeze()
+ if block_indices.ndim == 0:
+ block_indices = block_indices.unsqueeze(0)
+
+ if block_indices.numel() == 0:
+ continue
+
+ # Get the batch indices and start positions for these blocks
+ batch_idx = batch_indices[block_indices]
+ start_pos = start_positions[block_indices]
+ img_idx = block_counts[block_indices] # Get the block count for each selected block
+
+ # Calculate the maximum valid length for each block (in case they exceed N)
+ max_lengths = torch.clamp(N - start_pos, max=block_size)
+ max_block_length = max_lengths.max().item()
+
+ positions = start_pos.unsqueeze(1) + torch.arange(max_block_length, device=rotary_emb.device).unsqueeze(0) # [num_blocks, max_block_length]
+
+ # Create a mask to handle blocks that may be shorter than block_size
+ valid_mask = torch.arange(max_block_length, device=rotary_emb.device).unsqueeze(0) < max_lengths.unsqueeze(1)
+ positions = positions * valid_mask # Positions beyond valid lengths are set to zero
+ batch_idx_expanded = batch_idx.unsqueeze(1).expand(-1, max_block_length)
+
+ if input_emb is not None:
+ input_emb_to_add_full = img_count_embedding[img_idx][:, None, :] # Shape: [block_size]
+ input_emb_to_add = input_emb_to_add_full.expand(-1, valid_mask.shape[-1], -1)
+ input_emb_to_add = input_emb_to_add * valid_mask.unsqueeze(-1) # Mask data beyond valid lengths
+ input_emb[batch_idx_expanded[valid_mask], positions[valid_mask], :] = input_emb[batch_idx_expanded[valid_mask], positions[valid_mask], :] + input_emb_to_add[valid_mask]
+
+ rotary_emb_to_add_full = add_data[block_size] # Shape: [block_size]
+ rotary_emb_to_add = rotary_emb_to_add_full[:max_block_length].unsqueeze(0).expand(batch_idx.size(0), -1, -1)
+ rotary_emb_to_add = rotary_emb_to_add * valid_mask.unsqueeze(-1) # Mask data beyond valid lengths
+ rotary_emb[batch_idx_expanded[valid_mask], positions[valid_mask], :] = rotary_emb_to_add[valid_mask]
+
+@torch.compiler.disable()
+def add_txt_data_to_blocks(rotary_emb, modality_mask, sample_ids, add_data):
+ assert sample_ids is not None
+ batch_indices, start_positions, end_positions = get_contiguous_blocks(sample_ids)
+ block_sizes = end_positions - start_positions
+ for i in range(len(batch_indices)):
+ batch_idx = batch_indices[i]
+ start_pos = start_positions[i]
+ block_size = block_sizes[i]
+ sample_slice = slice(start_pos, start_pos+block_size)
+ rotary_emb[batch_idx, sample_slice, :] = torch.where(modality_mask[batch_idx, sample_slice, None], rotary_emb[batch_idx, sample_slice, :], add_data[:block_size])
+
+def apply_xla_flash_attention_with_spmd(query_states, key_states, value_states, causal=False):
+ from torch_xla.experimental.custom_kernel import flash_attention
+
+ # q, k, v should all have the shape [B, n_head, S, head_dim]
+ head_dim = query_states.size()[-1]
+ query_states = query_states / math.sqrt(head_dim)
+
+ # Our simplified version of decoder only model does not use any mask.
+ # flash_attention will use the global_mesh set in the TrainDecoderOnlyFSDPv2.
+ attn_output = flash_attention(query_states, key_states, value_states, causal=causal, partition_spec=("fsdp", None, None, None))
+ return attn_output
+
+
+def ckpt_wrapper(module):
+ def ckpt_forward(*inputs):
+ outputs = module(*inputs)
+ return outputs
+
+ return ckpt_forward
+
+
+# To avoid XLA issues
+if is_xla_available:
+ def _dropout(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
+ if p > 0.0:
+ return F.dropout(input=x, p=p, training=training).to(torch.bfloat16)
+ else:
+ return x
+else:
+ def _dropout(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
+ if p > 0.0:
+ return F.dropout(input=x, p=p, training=training)
+ else:
+ return x
+
+
+def bias_dropout_add_scale(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: typing.Optional[torch.Tensor],
+ residual: typing.Optional[torch.Tensor],
+ prob: float,
+ training: bool,
+ modality: typing.Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+ out = _dropout(x=(x + bias) if bias is not None else x, p=prob, training=training)
+
+ if scale is not None:
+ out = scale * out
+
+ if modality is not None:
+ out = torch.where((modality == 1).unsqueeze(-1), out, _dropout(x, p=prob, training=training))
+
+ if modality is not None:
+ out = torch.where((modality == 1).unsqueeze(-1), out, x)
+
+ if residual is not None:
+ out = residual + out
+
+ return out
+
+
+def get_bias_dropout_add_scale(training):
+ def _bias_dropout_add(x, bias, scale, residual, prob):
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, training)
+
+ return _bias_dropout_add
+
+# function overload
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ return x * (1 + scale) + shift
+
+def modulate_with_mask(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor, modality: torch.Tensor) -> torch.Tensor:
+ # Only images need time conditioning
+ return torch.where(modality.unsqueeze(-1) == 1, x * (1 + scale) + shift, x)
+
+if is_xla_available:
+ def bias_dropout_add_scale_fused_train(
+ x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: typing.Optional[torch.Tensor], residual: typing.Optional[torch.Tensor], prob: float, modality: typing.Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, True, modality)
+
+ def bias_dropout_add_scale_fused_inference(
+ x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: typing.Optional[torch.Tensor], residual: typing.Optional[torch.Tensor], prob: float, modality: typing.Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, False, modality)
+
+ def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor, modality: torch.Tensor=None) -> torch.Tensor:
+ if modality is not None and modality.any():
+ return modulate_with_mask(x, shift, scale, modality)
+ return modulate(x, shift, scale)
+else:
+ @torch.jit.script
+ def bias_dropout_add_scale_fused_train(
+ x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: typing.Optional[torch.Tensor], residual: typing.Optional[torch.Tensor], prob: float, modality: typing.Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, True, modality)
+
+
+ @torch.jit.script
+ def bias_dropout_add_scale_fused_inference(
+ x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: typing.Optional[torch.Tensor], residual: typing.Optional[torch.Tensor], prob: float, modality: typing.Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, False, modality)
+
+
+ @torch.jit.script
+ def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor, modality: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
+ if modality is not None and modality.any():
+ return modulate_with_mask(x, shift, scale, modality)
+ return modulate(x, shift, scale)
+
+
+class Rotary(torch.nn.Module):
+ def __init__(self, dim, base=10_000):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, seq_len, device=None):
+ # seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
+ # dims are: batch, seq_len, qkv, head, dim
+ self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
+ self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
+ # This makes the transformation on v an identity.
+ self.cos_cached[:, :, 2, :, :].fill_(1.0)
+ self.sin_cached[:, :, 2, :, :].fill_(0.0)
+
+ return self.cos_cached, self.sin_cached
+
+ @staticmethod
+ def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0, ntk_factor: float = 1.0):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with
+ given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials
+ using the given dimension 'dim' and the end index 'end'. The 'theta'
+ parameter scales the frequencies. The returned tensor contains complex
+ values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation.
+ Defaults to 10000.0.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex
+ exponentials.
+ """
+
+ theta = theta * ntk_factor
+
+ rprint(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
+
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim))
+ t = torch.arange(seq_len, device=freqs.device, dtype=torch.float) # type: ignore
+ t = t / rope_scaling_factor
+ freqs = torch.outer(t, freqs).float() # type: ignore
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ cos = cos[:, : cos.shape[-1] // 2]
+ sin = sin[:, : sin.shape[-1] // 2]
+ return cos, sin
+
+
+def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(qkv, cos, sin):
+ # cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
+ # sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]
+ return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
+
+#################################################################################
+# Layers #
+#################################################################################
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones([dim]))
+ self.dim = dim
+
+ def forward(self, x):
+ with torch.amp.autocast(x.device.type, enabled=False):
+ x = F.layer_norm(x.float(), [self.dim])
+
+ if is_xla_available:
+ x = x.to(torch.bfloat16)
+ if x.ndim == 3:
+ return (x * self.weight[None, None, :]).to(torch.bfloat16)
+ elif x.ndim == 2:
+ return (x * self.weight[None]).to(torch.bfloat16)
+ else:
+ if x.ndim == 3:
+ return x * self.weight[None, None, :]
+ elif x.ndim == 2:
+ return x * self.weight[None]
+
+
+def residual_linear(x, W, x_skip, residual_scale):
+ """x_skip + residual_scale * W @ x"""
+ dim_out, dim_in = W.shape[0], W.shape[1]
+ return torch.addmm(x_skip.view(-1, dim_out), x.view(-1, dim_in), W.T, alpha=residual_scale).view(*x.shape[:-1], dim_out)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True)
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
+ args = (t[:, None].float() * freqs[None] if t.ndim == 1 else t[..., None].float() * freqs[None, None]) # TODO @sid I think this is right but remind me if things aren't working
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedderCFG(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+class LabelEmbedder(nn.Module):
+ """Embeds class labels into vector representations.
+
+ Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, cond_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
+ self.num_classes = num_classes
+
+ # TODO think of initializing with 0.02 std deviation like in original DiT paper
+
+ def forward(self, labels):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+def get_norm(*args, norm_type="layernorm", elementwise_affine=False, **kwargs):
+ if norm_type == "layernorm":
+ return LayerNorm(*args, **kwargs)
+ elif norm_type == "rms":
+ return RMSNorm(*args, **kwargs)
+ else:
+ raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+def get_linear(*args, **kwargs):
+ return nn.Linear(*args, **kwargs)
+
+def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ dropout=0.1,
+ cross_attn=False,
+ attn_type="flash",
+ is_compiled=False,
+ force_varlen_attn=False,
+ force_cast_bf16=False,
+ qk_norm=False,
+ use_flash_attn_3=False,
+ use_spda_attn=False,
+ compile_flag_pos_emb=False,
+ causal=False,
+ use_kv_cache=False,
+ time_conditioning=False,
+ use_flex_attention=False,
+ idx=None,
+ attn_dropout=None
+ ):
+ super().__init__()
+ self.cross_attn = cross_attn
+ self.attn_type = attn_type
+ self.force_varlen_attn = force_varlen_attn
+ self.is_compiled = is_compiled
+ self.compile_flag_pos_emb = compile_flag_pos_emb
+ self.n_heads = n_heads
+ self.force_cast_bf16 = force_cast_bf16
+ self.qk_norm = qk_norm
+ self.head_dim = dim // n_heads
+ self.dropout = dropout
+ self.use_flash_attn_3 = use_flash_attn_3
+ self.use_spda_attn = use_spda_attn
+ self.causal = causal
+ self.use_kv_cache = use_kv_cache
+ self.time_conditioning = time_conditioning
+ self.use_flex_attention = use_flex_attention
+ self.idx = idx
+ self.attn_dropout = attn_dropout
+ if self.attn_dropout is None:
+ self.attn_dropout = 0
+
+ self.old_start_pos = None
+
+ self.attn_qkv = get_linear(dim, 3 * dim, bias=False)
+
+ if self.cross_attn:
+ self.attn_qkv_cond = get_linear(dim, 3 * dim, bias=False)
+
+ self.attn_out = get_linear(dim, dim, bias=False)
+
+ if self.qk_norm:
+ self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
+ self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)
+ assert self.cross_attn is False
+
+ self.softmax_scale = None
+
+ if self.use_flash_attn_3 or self.use_spda_attn:
+ assert self.attn_type == "flash" and self.force_varlen_attn is False
+ assert self.cross_attn is False
+
+ if self.use_flex_attention:
+ assert self.attn_type == "flash" and self.use_spda_attn
+ assert allow_any_spda is False
+ assert self.softmax_scale is None
+
+ self.use_flex_attention_cache = False
+ self.warn_cache_dtype = True
+
+ def update_kv_cache(self, q, new_k, new_v, batch_size, start_pos, seq_len):
+ self.cache_k[:, start_pos : start_pos + seq_len] = new_k
+ self.cache_v[:, start_pos : start_pos + seq_len] = new_v
+ k = self.cache_k[:, :start_pos + seq_len] # (batch_size, cache_len + seq_len, nheads, headdim)
+ v = self.cache_v[:, :start_pos + seq_len] # (batch_size, cache_len + seq_len, nheads, headdim)
+ return q, k, v # q is (batch_size, seq_len, nheads*headdim)
+
+ def reset_kv_cache(self, batch_size, seq_len, dtype, device, set_to_none=False):
+ assert self.use_kv_cache
+ if set_to_none:
+ del self.cache_k
+ del self.cache_v
+ self.cache_k = None
+ self.cache_v = None
+ else:
+ self.cache_k = torch.zeros(
+ batch_size, seq_len, self.n_heads, self.head_dim, dtype=dtype, device=device
+ )
+ self.cache_v = torch.zeros(
+ batch_size, seq_len, self.n_heads, self.head_dim, dtype=dtype, device=device
+ )
+
+ def set_flex_attention_cache(self, batch_size, seq_len, device, dtype):
+ assert self.use_flex_attention
+ self.use_flex_attention_cache = True
+ self.cache_k = torch.zeros(batch_size, self.n_heads, seq_len, self.head_dim, device=device, dtype=dtype)
+ self.cache_v = torch.zeros(batch_size, self.n_heads, seq_len, self.head_dim, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ x,
+ x_cond=None,
+ x_skip=None,
+ rotary_cos_sin=None,
+ cu_seqlens=None,
+ max_seqlen_in_batch=None,
+ bias_dropout_scale_fn=None,
+ gate_msa=None,
+ attention_mask=None,
+ start_pos=None,
+ modality=None,
+ block_mask=None,
+ update_cache_slice=None,
+ ):
+ if x.ndim == 2:
+ batch_size, seq_len = 1, x.shape[0]
+ has_batch_dim = False
+ else:
+ batch_size, seq_len = x.shape[0], x.shape[1]
+ has_batch_dim = True
+
+ if is_xla_linear_patched:
+ x = x.to(torch.float32)
+
+ qkv = self.attn_qkv(x)
+ if self.use_kv_cache and start_pos is not None:
+ if not self.cache_k.dtype == self.cache_v.dtype == qkv.dtype:
+ self.cache_k = self.cache_k.to(qkv.dtype)
+ self.cache_v = self.cache_v.to(qkv.dtype)
+
+ if is_xla_linear_patched:
+ qkv = qkv.to(torch.bfloat16)
+
+ if self.cross_attn:
+ qkv_cond = self.attn_qkv_cond(x_cond)
+
+ if not has_batch_dim:
+ if self.cross_attn:
+ q = q.unsqueeze(0)
+ kv = kv.unsqueeze(0)
+ else:
+ qkv = qkv.unsqueeze(0)
+
+ # qkv now has b s (three h d)
+ if self.qk_norm:
+ if is_xla_available:
+ if is_xla_linear_patched:
+ qkv_size = self.n_heads * self.head_dim
+ qkv = torch.cat(
+ [
+ self.q_norm(qkv[:, :, :qkv_size].to(torch.bfloat16)).to(torch.bfloat16),
+ self.k_norm(qkv[:, :, qkv_size : 2 * qkv_size].to(torch.bfloat16)).to(torch.bfloat16),
+ qkv[:, :, 2 * qkv_size :].to(torch.bfloat16),
+ ],
+ dim=-1,
+ ).to(torch.bfloat16)
+ else:
+ qkv_size = self.n_heads * self.head_dim
+ qkv = torch.cat(
+ [self.q_norm(qkv[:, :, :qkv_size]), self.k_norm(qkv[:, :, qkv_size : 2 * qkv_size]), qkv[:, :, 2 * qkv_size :]], dim=-1
+ )
+ else:
+ qkv_size = self.n_heads * self.head_dim
+ qkv[:, :, :qkv_size] = self.q_norm(qkv[:, :, :qkv_size])
+ qkv[:, :, qkv_size : 2 * qkv_size] = self.k_norm(qkv[:, :, qkv_size : 2 * qkv_size])
+
+ if rotary_cos_sin is not None:
+ orig_dtype = qkv.dtype
+ assert not (self.is_compiled and self.qk_norm is None)
+ if cu_seqlens is not None and self.force_varlen_attn is False:
+ assert not self.cross_attn, "Not yet supported"
+ assert qkv.is_contiguous()
+ qkv = rearrange(qkv, "b s (three h d) -> (b s) three h d", three=3, h=self.n_heads)
+ qk = qkv[:, :2].reshape(seq_len, -1, self.head_dim) # (b s) (two h) d
+ with torch.autocast(x.device.type, enabled=False):
+ cos, sin = rotary_cos_sin
+ qk = apply_rotary_emb(
+ qk, cos.to(qkv.dtype), sin.to(qkv.dtype), inplace=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch
+ )
+ qkv[:, :2] = qk.reshape(seq_len, 2, -1, self.head_dim)
+ else:
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
+ if self.cross_attn:
+ qkv_cond = rearrange(qkv_cond, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
+
+ with torch.autocast(x.device.type, enabled=is_xla_available):
+ cos, sin = rotary_cos_sin
+
+ # TODO: This causes a ~4-8% slowdown on XLA
+ if self.compile_flag_pos_emb:
+ if is_xla_available:
+ if is_xla_linear_patched:
+ cos, sin, qkv = cos.to(torch.bfloat16), sin.to(torch.bfloat16), qkv.to(torch.bfloat16)
+ qk = qkv[:, :, :2].to(torch.bfloat16).reshape(batch_size, seq_len, -1, self.head_dim).to(torch.bfloat16)
+ qk = flash_torch_apply_rotary_emb_torch(qk, cos, sin)
+ qkv = qkv.clone() # TODO: Appears to be needed for XLA
+ qkv = qkv.to(torch.bfloat16)
+ qkv[:, :, :2] = qk.to(torch.bfloat16).reshape(batch_size, seq_len, 2, -1, self.head_dim).to(torch.bfloat16)
+ qkv = qkv.to(torch.bfloat16)
+ else:
+ qk = qkv[:, :, :2].reshape(batch_size, seq_len, -1, self.head_dim)
+ qk = flash_torch_apply_rotary_emb_torch(qk, cos, sin).to(x)
+ qkv = qkv.clone() # TODO: Appears to be needed for XLA
+ qkv[:, :, :2] = qk.reshape(batch_size, seq_len, 2, -1, self.head_dim)
+ qkv = qkv.to(x)
+ else:
+ qk = qkv[:, :, :2].reshape(batch_size, seq_len, -1, self.head_dim)
+ qk = flash_torch_apply_rotary_emb_torch(qk, cos, sin)
+ qkv[:, :, :2] = qk.reshape(batch_size, seq_len, 2, -1, self.head_dim)
+ else:
+ qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+
+ if self.cross_attn:
+ qkv_cond = apply_rotary_pos_emb(qkv_cond, cos.to(qkv_cond.dtype), sin.to(qkv_cond.dtype))
+ qkv_cond = qkv_cond.to(orig_dtype)
+ q, _, _ = qkv.unbind(dim=2)
+ _, k_cond, v_cond = qkv_cond.unbind(dim=2)
+
+ qkv = qkv.to(orig_dtype)
+ if self.force_varlen_attn:
+ assert start_pos is not None
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+ else:
+ assert not self.use_flash_attn_3
+ if cu_seqlens is not None:
+ assert False
+ else:
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
+
+ if self.use_kv_cache:
+ assert self.attn_type == "flash" and self.use_spda_attn and allow_any_spda is False and not self.use_flex_attention
+
+ if self.attn_type == "flash":
+ if cu_seqlens is None and self.force_varlen_attn is False: # qkv: (batch_size, seqlen, 3, nheads, headdim)
+ if self.use_flash_attn_3:
+ # We do not yet support flash attn 3 for cross attention
+ q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
+ x = flash_attn_func_v3(
+ q, k, v, softmax_scale=self.softmax_scale, causal=self.causal
+ )[0]
+ elif self.use_spda_attn:
+ if allow_any_spda:
+ b, s, _, h, d = qkv.shape
+ q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
+ q = q.view(b, -1, h, d).transpose(1, 2)
+ k = k.view(b, -1, h, d).transpose(1, 2)
+ v = v.view(b, -1, h, d).transpose(1, 2)
+
+ if attention_mask is None:
+ with nullcontext() if allow_any_spda else sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]):
+ x = sdpa(q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=None, is_causal=self.causal)
+ else:
+ x = sdpa(q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attention_mask, is_causal=self.causal)
+ else:
+ if is_xla_linear_patched:
+ qkv = qkv.to(torch.bfloat16)
+
+ q, k, v = qkv.unbind(dim=2)
+ disable_causal_attn = False
+ if self.use_kv_cache and start_pos is not None:
+ disable_causal_attn = True
+ q, k, v = self.update_kv_cache(q, k, v, batch_size, start_pos, seq_len)
+
+ is_causal = self.causal and not disable_causal_attn
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+ if self.use_flex_attention:
+ # During inference we have a variable batch size which is not supported by torch.compile w/flex attention right now
+ # See: https://github.com/pytorch/pytorch/issues/136196
+ if self.training:
+ x = compiled_flex_attention(q, k, v, block_mask=block_mask)
+ else:
+ # Step 0: We want full attention for joint img/txt update
+ # Step 1: We want txt -> (txt + img) attention and img -> img attention. Cache the img kv for the next step
+ # Step 2...N: We want txt -> (txt + img) attention, using the cached kv for img
+ if self.use_flex_attention_cache:
+ if seq_len != self.cache_k.shape[2]: # Step 2
+ assert update_cache_slice is not None
+ # (B, H, S, D)
+ self.cache_k[:, :, update_cache_slice] = k
+ self.cache_v[:, :, update_cache_slice] = v
+ elif block_mask is not None and block_mask is not True: # Step 1
+ assert update_cache_slice is not None
+ assert (update_cache_slice.stop - update_cache_slice.start) == k.shape[2]
+ self.cache_k = k
+ self.cache_v = v
+ else: # Step 0
+ pass
+
+ assert block_mask is not None
+
+ # Hack to set full attention when we explicitly want it
+ if block_mask is True:
+ block_mask = None
+ x = flex_attention(q, k, v, block_mask=block_mask)
+ elif force_xla_flash_attention:
+ assert not is_causal, "XLA Flash Attention does not support causal attention"
+ x = apply_xla_flash_attention_with_spmd(q=q, k=k, v=v, causal=is_causal)
+ elif force_cudnn_spda_context:
+ with (
+ nullcontext()
+ if (is_xla_available or attention_mask is not None)
+ else sdpa_kernel(backends=[
+ SDPBackend.CUDNN_ATTENTION,
+ *([] if (self.use_spda_attn and force_cudnn_spda_context) else [SDPBackend.FLASH_ATTENTION])
+ ])
+ ):
+ dropout_p = self.attn_dropout if self.training else 0
+ x = sdpa(q, k, v, attn_mask=None, is_causal=is_causal, scale=self.softmax_scale, dropout_p=dropout_p)
+ else:
+ dropout_p = self.attn_dropout if self.training else 0
+ x = sdpa(q, k, v, attn_mask=attention_mask, is_causal=is_causal, scale=self.softmax_scale, dropout_p=dropout_p)
+
+ if is_xla_linear_patched:
+ x = x.to(torch.bfloat16)
+
+ elif self.cross_attn:
+ x = flash_attn_func(q, k_cond, v_cond, dropout_p=0.0, softmax_scale=self.softmax_scale, causal=self.causal)
+ else:
+ if use_non_packed_fa2:
+ q, k, v = qkv.unbind(dim=2)
+ x = flash_attn_func(
+ q, k, v, dropout_p=0.0, softmax_scale=self.softmax_scale, causal=self.causal
+ )
+ else:
+ x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=self.softmax_scale, causal=self.causal)
+
+ if self.use_spda_attn:
+ x = rearrange(x, "b h s d -> b s (h d)", b=batch_size)
+ else:
+ x = rearrange(x, "b s h d -> b s (h d)", b=batch_size)
+ else:
+ if cu_seqlens is None:
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
+
+ # If we want all *other* ops to be FP32, we still need to cast the input for attn to BF16 as Flash Attn only supports FP16/BF16. This is a quick hack to do this.
+ with torch.amp.autocast(x.device.type, dtype=torch.bfloat16) if self.force_cast_bf16 else nullcontext():
+ if self.cross_attn:
+ if self.force_cast_bf16:
+ q = q.to(torch.bfloat16)
+ k_cond = k_cond.to(torch.bfloat16)
+ v_cond = v_cond.to(torch.bfloat16)
+ x = flash_attn_varlen_func(
+ q, k_cond, v_cond, cu_seqlens, seq_len, dropout_p=0.0, softmax_scale=self.softmax_scale, causal=self.causal
+ )
+ else:
+ if self.force_cast_bf16:
+ qkv = qkv.to(torch.bfloat16)
+ x = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, seq_len, dropout_p=0.0, softmax_scale=self.softmax_scale, causal=self.causal
+ )
+ x = rearrange(x, "(b s) h d -> b s (h d)", b=batch_size)
+
+ if not has_batch_dim:
+ x = x.squeeze(0)
+
+ if is_xla_linear_patched:
+ x = x.to(torch.float32)
+
+ if bias_dropout_scale_fn is not None:
+ return bias_dropout_scale_fn(
+ x=self.attn_out(x),
+ bias=None,
+ scale=gate_msa,
+ residual=x_skip,
+ prob=self.dropout,
+ modality=(modality if self.time_conditioning else None),
+ )
+ else:
+ return self.attn_out(x)
+
+
+class DDiTBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ cond_dim,
+ mlp_ratio=4,
+ dropout=0.1,
+ time_conditioning=True,
+ img_cond=False,
+ norm_type="layernorm",
+ sandwich_normalization=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.time_conditioning = time_conditioning
+
+ self.dropout = dropout
+ self.attention = Attention(dim, n_heads, dropout, **kwargs)
+ self.img_cond = img_cond
+ if img_cond:
+ self.cross_attention = Attention(dim, n_heads, dropout, cross_attn=True, **kwargs)
+
+ self.norm1 = get_norm(dim, norm_type=norm_type)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm2 = get_norm(dim, norm_type=norm_type)
+
+ self.mlp = nn.Sequential(
+ get_linear(dim, mlp_ratio * dim, bias=True), nn.GELU(approximate="tanh"), get_linear(mlp_ratio * dim, dim, bias=True)
+ )
+ self.dropout2 = nn.Dropout(dropout)
+
+ if self.time_conditioning:
+ self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+ self.sandwich_normalization = sandwich_normalization
+ if self.sandwich_normalization:
+ self.post_ff_norm = get_norm(dim, norm_type=norm_type)
+ self.pre_residual_norm = get_norm(dim, norm_type=norm_type)
+ assert self.img_cond is False, "Sandwich normalization is not supported with cross attention."
+ else:
+ self.pre_residual_norm = nn.Identity()
+ self.post_ff_norm = nn.Identity()
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+ def reset_kv_cache(self, *args, **kwargs):
+ self.attention.reset_kv_cache(*args, **kwargs)
+
+ def set_flex_attention_cache(self, *args, **kwargs):
+ self.attention.set_flex_attention_cache(*args, **kwargs)
+
+ def forward(
+ self,
+ x,
+ rotary_cos_sin=None,
+ c=None,
+ cu_seqlens=None,
+ max_seqlen_in_batch=None,
+ x_cond=None,
+ attention_mask=None,
+ modality=None,
+ start_pos=None,
+ block_mask=None,
+ update_cache_slice=None,
+ ):
+
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
+
+ if self.time_conditioning:
+ _cond = self.adaLN_modulation(c)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (_cond if _cond.ndim == 3 else _cond[:, None, :]).chunk(6, dim=2)
+ else:
+ gate_msa, gate_mlp = None, None
+ x_skip = x
+ x = self.norm1(x)
+
+ if self.time_conditioning:
+ x = modulate_fused(x, shift_msa, scale_msa, modality)
+
+ # Self Attention Start
+ x = self.attention(
+ x,
+ rotary_cos_sin=rotary_cos_sin,
+ cu_seqlens=cu_seqlens,
+ max_seqlen_in_batch=max_seqlen_in_batch,
+ x_skip=x_skip,
+ bias_dropout_scale_fn=None if self.sandwich_normalization else bias_dropout_scale_fn,
+ gate_msa=gate_msa,
+ attention_mask=attention_mask,
+ modality=modality,
+ start_pos=start_pos,
+ block_mask=block_mask,
+ update_cache_slice=update_cache_slice,
+ )
+
+ # Self Attention End
+ if self.sandwich_normalization:
+ x = x_skip + self.pre_residual_norm(x)
+
+
+ # Cross Attention Start
+ if self.img_cond:
+ x = self.cross_attention(
+ x,
+ x_cond=x_cond,
+ rotary_cos_sin=rotary_cos_sin,
+ cu_seqlens=cu_seqlens,
+ max_seqlen_in_batch=max_seqlen_in_batch,
+ x_skip=x_skip,
+ bias_dropout_scale_fn=bias_dropout_scale_fn,
+ gate_msa=gate_msa,
+ )
+ # Cross Attention End
+
+ # mlp operation
+ _modality = (modality if self.time_conditioning else None)
+ if self.time_conditioning:
+ # assert not self.sandwich_normalization
+ x = bias_dropout_scale_fn(
+ x=self.post_ff_norm(self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp, modality))),
+ bias=None,
+ scale=gate_mlp,
+ residual=x,
+ prob=self.dropout,
+ modality=_modality,
+ )
+ else:
+ x = bias_dropout_scale_fn(
+ x=self.post_ff_norm(self.mlp(self.norm2(x))),
+ bias=None,
+ scale=None,
+ residual=x,
+ prob=self.dropout,
+ modality=_modality,
+ )
+
+ return x
+
+
+class EmbeddingLayer(nn.Module):
+ def __init__(self, dim, vocab_dim):
+ super().__init__()
+ self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
+ torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
+
+ def forward(self, x):
+ return self.embedding[x]
+
+
+def get_2d_rope(seq_len_2d, dim, linear_factor):
+ seq_len_2d_side = int(math.sqrt(seq_len_2d))
+ assert seq_len_2d_side**2 == seq_len_2d, f"seq_len_2d must be a square number, got {seq_len_2d}"
+ if linear_factor is not None:
+ rprint(f"Using Scale factor: {linear_factor}")
+ ntk_factor = 1.0
+ rotary_emb_2d = get_2d_rotary_pos_embed_lumina(
+ dim,
+ seq_len_2d_side,
+ seq_len_2d_side,
+ linear_factor=linear_factor,
+ ntk_factor=ntk_factor,
+ )
+ cos_2d_emb = rotary_emb_2d.flatten(0, 1).real
+ sin_2d_emb = rotary_emb_2d.flatten(0, 1).imag
+ return cos_2d_emb, sin_2d_emb
+
+class DDitFinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, cond_dim, time_conditioning=True, norm_type="layernorm", zero_linear_init=True):
+ super().__init__()
+ self.time_conditioning = time_conditioning
+ self.norm_final = get_norm(hidden_size, norm_type=norm_type)
+
+ linear_kwargs = dict()
+ self.linear = get_linear(hidden_size, out_channels, **linear_kwargs)
+
+ if zero_linear_init:
+ self.linear.weight.data.zero_()
+ self.linear.bias.data.zero_()
+ else:
+ self.linear.bias.data.zero_()
+
+ if self.time_conditioning:
+ self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+ def forward(self, x, c, modality):
+ if self.time_conditioning:
+ _cond = self.adaLN_modulation(c)
+ shift, scale = (_cond if _cond.ndim == 3 else _cond[:, None, :]).chunk(2, dim=2)
+ x = modulate_fused(self.norm_final(x), shift, scale, modality)
+ else:
+ x = self.norm_final(x)
+
+ x = self.linear(x)
+ return x
+
+
+class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
+ def __init__(self, config, vocab_size: int, text_vocab_size: int, mask_index: int, dtype=None, device=None, static_img_sl=None, static_txt_sl=None, **kwargs):
+ super().__init__()
+ if type(config) == dict:
+ config = omegaconf.OmegaConf.create(config)
+
+ self.config = config
+ self.autocast_dtype = dtype
+ self.vocab_size = vocab_size
+ self.text_vocab_size = text_vocab_size
+ self.time_conditioning = config.time_conditioning or getattr(self.config.model, "force_time_conditioning", False)
+ self.use_gradient_checkpointing = getattr(config.trainer, "use_gradient_checkpointing", False)
+ self.img_cond = getattr(config.model, "img_cond", False)
+ self.mask_index = mask_index
+ self.force_cast_bf16 = (self.autocast_dtype == torch.float32)
+ self.use_flash_attn_3 = getattr(config.model, "use_flash_attn_3", False)
+ self.use_spda_attn = getattr(config.model, "use_spda_attn", False)
+ self.compile_flag_pos_emb = getattr(config.trainer, "compile_flag_pos_emb", False)
+ self.sandwich_normalization = getattr(config.model, "sandwich_normalization", False)
+ self.use_kv_cache = getattr(config.model, "use_kv_cache", False)
+ self.use_flex_attention = getattr(config.model, "use_flex_attention", False)
+ self.static_img_sl = static_img_sl
+ self.static_txt_sl = static_txt_sl
+ self.causal = not config.model.full_attention
+
+ if getattr(config.model, "use_flash_attn_3", False):
+ assert not failed_to_import_fa3
+
+ if getattr(self.config.model, "cond_label", False):
+ self.y_embedder = LabelEmbedderCFG(1000, config.model.cond_dim, 0.1)
+
+ if getattr(config.model, "use_pretrained_img_emb", False):
+ from model import get_vae
+
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size, text_vocab_size + 1)
+ if getattr(config.model, "freeze_txt_emb", False):
+ self.vocab_embed.requires_grad_(False)
+ device = next(iter(self.vocab_embed.parameters())).device
+ vae = get_vae(config, device)
+ self.img_vocab_embed = vae.quantize.embedding
+ if self.time_conditioning: # TODO: Debug
+ rprint("Requires grad: False")
+ self.img_vocab_embed.requires_grad_(False)
+ self.img_vocab_proj = get_linear(self.img_vocab_embed.embedding_dim, config.model.hidden_size)
+ self.split_embed = True
+ self.new_mask_index = text_vocab_size
+ rprint(f"Using pretrained image embedding. Projecting from: {self.img_vocab_embed.embedding_dim} to {config.model.hidden_size}")
+ else:
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size, vocab_size)
+ self.split_embed = False
+
+ self.is_compiled = getattr(config.trainer, "compile", False)
+ if self.img_cond:
+ if getattr(config.model, "use_pretrained_img_emb", False):
+ cond_vae = get_vae(config, device, use_cond=True)
+ self.cond_img_vocab_embed = cond_vae.quantize.embedding
+ self.cond_img_vocab_proj = get_linear(self.cond_img_vocab_embed.embedding_dim, config.model.hidden_size)
+ else:
+ self.cond_img_vocab_embed = EmbeddingLayer(config.model.hidden_size, config.model.cond_image_vocab_size)
+
+ img_cond_blocks = []
+ for idx in range(8):
+ img_cond_blocks.append(
+ DDiTBlock(
+ config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout,
+ img_cond=False,
+ time_conditioning=self.time_conditioning,
+ attn_type=config.model.attn_type,
+ is_compiled=self.is_compiled,
+ force_varlen_attn=config.model.force_varlen_attn,
+ force_cast_bf16=self.force_cast_bf16,
+ norm_type=config.model.norm_type,
+ qk_norm=config.model.qk_norm,
+ use_flash_attn_3=self.use_flash_attn_3,
+ use_spda_attn=self.use_spda_attn,
+ compile_flag_pos_emb=self.compile_flag_pos_emb,
+ sandwich_normalization=self.sandwich_normalization,
+ causal=not config.model.full_attention,
+ use_kv_cache=self.use_kv_cache,
+ use_flex_attention=self.use_flex_attention,
+ idx=idx,
+ attn_dropout=getattr(config.model, "attn_dropout", None),
+ )
+ )
+ self.img_cond_blocks = nn.ModuleList(img_cond_blocks)
+ self.img_cond_rotary_emb = Rotary(config.model.hidden_size // config.model.n_heads)
+ assert not self.is_compiled, "Need to fix rotary embeddings"
+
+ self.sigma_map = None
+ if self.time_conditioning and getattr(self.config.model, "cond_label", False) is False:
+ self.sigma_map = TimestepEmbedder(config.model.cond_dim)
+ rprint(f"Using timestep embedder with dim: {config.model.cond_dim}")
+
+ self.use_legacy_rotary = False
+ self.modality_embed = None
+
+ if self.config.model.modality_embed:
+ self.modality_embed = EmbeddingLayer(self.config.model.hidden_size, 2)
+
+ continuous_mode = self.config.trainer.image_mode == "continuous"
+ if continuous_mode:
+ assert getattr(config.model, "vae_type", None) == "stable_diffusion"
+ # an extra projection layer for the continuous diffusion
+ self.continuous_img_proj = get_linear(4 * (config.model.patching_downscale ** 2), config.model.hidden_size) # todo remove 4 (vae hardcode)
+
+ if self.config.model.rope_2d:
+ seq_len_1d = self.config.model.txt_length
+ seq_len_2d = self.config.model.img_length
+ linear_factor = getattr(config.model, "linear_factor", 1.0)
+ dim = config.model.hidden_size // config.model.n_heads
+
+ if self.config.data.require_sample_ids:
+ for seq_len_2d, linear_factor in ((256, 1), (1024, 2), (2304, 3), (4096, 4)):
+ cos_2d_emb, sin_2d_emb = get_2d_rope(seq_len_2d, dim, linear_factor)
+ self.register_buffer(f'rotary_cos_emb_img_{seq_len_2d}', cos_2d_emb, persistent=False)
+ self.register_buffer(f'rotary_sin_emb_img_{seq_len_2d}', sin_2d_emb, persistent=False)
+
+ max_images_in_sequence = 16
+ self.img_count_embedding = nn.Parameter(torch.zeros((max_images_in_sequence, config.model.hidden_size)))
+ else:
+ cos_2d_emb, sin_2d_emb = get_2d_rope(seq_len_2d, dim, linear_factor)
+ self.register_buffer('rotary_cos_emb_img', cos_2d_emb, persistent=False)
+ self.register_buffer('rotary_sin_emb_img', sin_2d_emb, persistent=False)
+
+ rotary_emb_1d = Rotary(dim)(seq_len_1d)
+ cos_1d_emb = rotary_emb_1d[0][0, :, 0, 0, : cos_2d_emb.shape[1]]
+ sin_1d_emb = rotary_emb_1d[1][0, :, 0, 0, : sin_2d_emb.shape[1]]
+
+ if self.config.trainer.multimodal_batches:
+ seq_len_1d = self.config.model.length
+ rotary_emb_1d = Rotary(config.model.hidden_size // config.model.n_heads)(seq_len_1d)
+ cos_1d_emb = rotary_emb_1d[0][0,:,0, 0,: cos_2d_emb.shape[1]]
+ sin_1d_emb = rotary_emb_1d[1][0,:,0, 0,: sin_2d_emb.shape[1]]
+ self.register_buffer('rotary_cos_emb_txt', cos_1d_emb, persistent=False)
+ self.register_buffer('rotary_sin_emb_txt', sin_1d_emb, persistent=False)
+ else:
+ seq_len_1d = self.config.model.length
+ self.rotary_emb_1d = Rotary(config.model.hidden_size // config.model.n_heads)(seq_len_1d)
+ cos_1d_emb = self.rotary_emb_1d[0][0,:,0, 0,: self.rotary_emb_1d[0].shape[-1] // 2]
+ sin_1d_emb = self.rotary_emb_1d[1][0,:,0, 0,: self.rotary_emb_1d[1].shape[-1] // 2]
+ self.register_buffer('rotary_cos_emb', cos_1d_emb, persistent=False)
+ self.register_buffer('rotary_sin_emb', sin_1d_emb, persistent=False)
+
+ blocks = []
+ for idx in range(config.model.n_blocks):
+ blocks.append(
+ DDiTBlock(
+ config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout,
+ time_conditioning=self.time_conditioning,
+ img_cond=self.img_cond,
+ attn_type=config.model.attn_type,
+ is_compiled=self.is_compiled,
+ force_varlen_attn=config.model.force_varlen_attn,
+ force_cast_bf16=self.force_cast_bf16,
+ norm_type=config.model.norm_type,
+ qk_norm=config.model.qk_norm,
+ use_flash_attn_3=self.use_flash_attn_3,
+ use_spda_attn=self.use_spda_attn,
+ compile_flag_pos_emb=self.compile_flag_pos_emb,
+ sandwich_normalization=self.sandwich_normalization,
+ causal=not config.model.full_attention,
+ use_kv_cache=self.use_kv_cache,
+ use_flex_attention=self.use_flex_attention,
+ idx=idx,
+ attn_dropout=getattr(config.model, "attn_dropout", None),
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+ self.output_layer = DDitFinalLayer(
+ config.model.hidden_size,
+ 1 if config.parameterization == "planner" else vocab_size,
+ config.model.cond_dim,
+ time_conditioning=self.time_conditioning,
+ norm_type=config.model.norm_type,
+ zero_linear_init=config.model.zero_linear_init,
+ )
+
+ if continuous_mode:
+ assert getattr(self.config.model, "vae_type", None) == "stable_diffusion"
+ self.output_later_img = DDitFinalLayer(
+ config.model.hidden_size,
+ 4 * (config.model.patching_downscale ** 2), # todo, remove hardcoding
+ config.model.cond_dim,
+ time_conditioning=self.time_conditioning,
+ norm_type=config.model.norm_type,
+ zero_linear_init=config.model.zero_linear_init,
+ )
+
+ self.scale_by_sigma = config.model.scale_by_sigma
+ self.txt_dropout = getattr(config.model, "txt_dropout", None)
+ if config.parameterization != "ar":
+ rprint(f"Not using AR, disabling txt dropout")
+ self.txt_dropout = None
+
+ self.txt_length = self.config.model.txt_length
+ self.img_length = self.config.model.img_length
+ self.total_length = self.config.model.length
+ assert (self.txt_length + self.img_length == self.total_length) or self.config.trainer.multimodal_batches
+ self.allow_compiled_embed = self.config.model.rope_2d is False and self.config.model.modality_embed is False and not getattr(self.config.model, "disable_allow_compiled_embed", False)
+ self.multimodal_batches = self.config.trainer.multimodal_batches
+ self.rope_2d = self.config.model.rope_2d
+ rprint(f"DIT Found XLA: {is_xla_available}")
+ self.require_sample_ids = self.config.data.require_sample_ids
+
+ if self.config.model.force_optimized_native_attn:
+ assert force_cudnn_spda_context
+ assert self.config.model.use_spda_attn
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+ def reset_kv_cache(self, *args, **kwargs):
+ for block in self.blocks:
+ block.reset_kv_cache(*args, **kwargs)
+
+ def set_flex_attention_cache(self, *args, **kwargs):
+ for block in self.blocks:
+ block.set_flex_attention_cache(*args, **kwargs)
+
+ def forward(
+ self,
+ indices,
+ sigma=None,
+ label=None,
+ x_cond=None,
+ attention_mask=None,
+ continuous_mode=False,
+ x_img_emb=None,
+ modality=None,
+ start_pos=None,
+ block_mask=None,
+ update_cache_slice=None,
+ sample_ids=None,
+ ):
+ if self.txt_dropout is not None and self.training:
+ mask = torch.rand_like(indices, dtype=torch.float) < self.txt_dropout
+ indices = torch.where(mask & (modality == 0), self.mask_index, indices)
+
+ if self.split_embed:
+ # TODO: This is a bit inefficient
+ text_mask = indices < self.text_vocab_size
+ img_mask = (indices >= self.text_vocab_size) & (indices != self.mask_index)
+ mask_token_mask = indices == self.mask_index
+
+ text_indices = indices.clone()
+ text_indices[~text_mask] = 0 # Set non-text tokens to 0
+ text_indices[mask_token_mask] = self.new_mask_index
+ txt_x = self.vocab_embed(text_indices)
+
+ img_indices = indices.clone() - self.text_vocab_size
+ img_indices[~img_mask] = 0 # Set non-image tokens to 0
+ img_x = self.img_vocab_proj(self.img_vocab_embed(img_indices))
+
+ mask_x = self.vocab_embed(torch.full_like(indices, self.new_mask_index))
+ x = torch.where(text_mask.unsqueeze(-1), txt_x, torch.where(img_mask.unsqueeze(-1), img_x, mask_x))
+ elif continuous_mode:
+ assert sigma is not None
+ text_embed = self.vocab_embed(indices)
+ img_embed = self.continuous_img_proj(x_img_emb)
+ x = torch.where(modality[:, :, None] == 1, img_embed, text_embed)
+ attention_mask_shape = self.total_length if self.use_kv_cache else modality.shape[1]
+ attention_mask = get_transfusion_mask(indices.shape[0], attention_mask_shape, self.txt_length, self.img_length, modality)
+ if self.use_kv_cache:
+ # we only care about (seq_len, cache_len+seq_len)
+ assert self.total_length <= self.inference_max_seq_len
+ seq_len = indices.shape[1]
+ attention_mask = attention_mask[:, start_pos:start_pos+seq_len, :start_pos+seq_len]
+ x = x[:, start_pos:start_pos+seq_len, :]
+ attention_mask = attention_mask.unsqueeze(1).to(x.device) # (B, 1, N_tot, N_tot) for SDPA
+ else:
+ x = self.vocab_embed(indices)
+ x = x.to(self.autocast_dtype)
+ c = None
+ if self.sigma_map is not None:
+ c = F.silu(self.sigma_map(sigma))
+
+ if label is not None:
+ assert c is None
+ c = self.y_embedder(label, train=self.training)
+
+ if x_cond is not None:
+ assert not self.use_kv_cache
+ if self.split_embed:
+ x_cond = self.cond_img_vocab_proj(self.cond_img_vocab_embed(x_cond))
+ else:
+ x_cond = self.cond_img_vocab_embed(x_cond)
+
+ img_cond_rotary_cos_sin = True if self.is_compiled else self.img_cond_rotary_emb(x_cond)
+ img_cond_attention_args = (img_cond_rotary_cos_sin, None, None, None, None, attention_mask, start_pos)
+ with torch.autocast(x_cond.device.type, dtype=self.autocast_dtype):
+ for i in range(len(self.img_cond_blocks)):
+ x_cond = (
+ checkpoint(ckpt_wrapper(self.img_cond_blocks[i]), x_cond, *img_cond_attention_args, use_reentrant=True)
+ if (self.use_gradient_checkpointing and self.training)
+ else self.img_cond_blocks[i](x_cond, *img_cond_attention_args)
+ )
+
+ if self.modality_embed is not None:
+ if self.multimodal_batches:
+ assert modality is not None
+ try:
+ x = x + torch.where((modality == 0).unsqueeze(-1), self.modality_embed(0).unsqueeze(0).unsqueeze(0), self.modality_embed(1).unsqueeze(0).unsqueeze(0))
+ except:
+ breakpoint()
+ else:
+ x[:, self.static_txt_sl] = x[:, self.static_txt_sl] + self.modality_embed(0).unsqueeze(0).unsqueeze(0)
+ x[:, self.static_img_sl] = x[:, self.static_img_sl] + self.modality_embed(1).unsqueeze(0).unsqueeze(0)
+
+ if self.is_compiled and self.allow_compiled_embed:
+ rotary_cos_sin = True
+ else:
+ if self.use_legacy_rotary:
+ rotary_cos_sin = self.rotary_emb(x)
+ else:
+ if self.modality_embed is not None and self.rope_2d and self.multimodal_batches:
+ valid_sl = slice(start_pos, start_pos+x.shape[1]) if start_pos is not None else slice(None, x.shape[1])
+ if self.require_sample_ids:
+ assert modality.shape == indices.shape == sample_ids.shape
+ cos = torch.zeros((x.shape[0], *self.rotary_cos_emb_txt.shape), device=x.device, dtype=x.dtype)
+ sin = torch.zeros((x.shape[0], *self.rotary_sin_emb_txt.shape), device=x.device, dtype=x.dtype)
+ modality_mask = modality.bool()
+ @torch.compiler.disable()
+ def fn():
+ add_img_data_to_blocks(x, cos, modality_mask, sample_ids, {
+ 256: self.rotary_cos_emb_img_256,
+ 1024: self.rotary_cos_emb_img_1024,
+ 2304: self.rotary_cos_emb_img_2304,
+ 4096: self.rotary_cos_emb_img_4096
+ }, self.img_count_embedding)
+ add_img_data_to_blocks(None, sin, modality_mask, sample_ids, {
+ 256: self.rotary_sin_emb_img_256,
+ 1024: self.rotary_sin_emb_img_1024,
+ 2304: self.rotary_sin_emb_img_2304,
+ 4096: self.rotary_sin_emb_img_4096
+ }, None)
+ add_txt_data_to_blocks(cos, modality_mask, sample_ids, self.rotary_cos_emb_txt)
+ add_txt_data_to_blocks(sin, modality_mask, sample_ids, self.rotary_sin_emb_txt)
+
+ fn()
+ rotary_cos_sin = (cos, sin)
+ elif modality.shape[-1] != self.img_length:
+ # Pretty hacky but we want to support the following batch: [[text img], [text], [img]]
+ pad_size = modality.shape[-1] - self.img_length
+ pad_size = max(pad_size, 0)
+ padding = torch.full((1, pad_size, self.rotary_cos_emb_img.shape[-1]), torch.nan, device=x.device, dtype=x.dtype)
+ rotary_cos_sin = (
+ torch.where(modality[:, :, None] == 0, self.rotary_cos_emb_txt[None, valid_sl], torch.cat([padding, self.rotary_cos_emb_img[None, valid_sl]], dim=1)[:, valid_sl]).squeeze(0),
+ torch.where(modality[:, :, None] == 0, self.rotary_sin_emb_txt[None, valid_sl], torch.cat([padding, self.rotary_sin_emb_img[None, valid_sl]], dim=1)[:, valid_sl]).squeeze(0)
+ )
+ else:
+ rotary_cos_sin = (
+ torch.where(modality[:, :, None] == 0, self.rotary_cos_emb_txt[None, valid_sl], self.rotary_cos_emb_img[None, valid_sl]).squeeze(0),
+ torch.where(modality[:, :, None] == 0, self.rotary_sin_emb_txt[None, valid_sl], self.rotary_sin_emb_img[None, valid_sl]).squeeze(0)
+ )
+ else:
+ rotary_cos_sin = (self.rotary_cos_emb, self.rotary_sin_emb)
+
+ if start_pos is not None: assert self.use_kv_cache
+ if self.use_kv_cache and start_pos is not None:
+ cos, sin = rotary_cos_sin
+ seq_len = x.shape[1]
+ if cos.ndim == 3:
+ rotary_cos_sin = (
+ cos[:, start_pos:start_pos+seq_len],
+ sin[:, start_pos:start_pos+seq_len]
+ )
+ elif cos.ndim == 2:
+ rotary_cos_sin = (
+ cos[start_pos:start_pos+seq_len],
+ sin[start_pos:start_pos+seq_len]
+ )
+ else:
+ raise ValueError(f"Invalid rotary cos and sin shape for KV cache slicing: {cos.shape}")
+
+ if self.causal and self.use_flex_attention and block_mask is None and not (self.use_kv_cache and start_pos is not None):
+ # For causal, we do not need a mask if we are using KV cache
+ block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=x.shape[1], KV_LEN=x.shape[1])
+
+ attention_args = (rotary_cos_sin, c, None, None, x_cond, attention_mask, modality, start_pos, block_mask, update_cache_slice)
+ with torch.autocast(x.device.type, dtype=self.autocast_dtype):
+ for i in range(len(self.blocks)):
+ x = (
+ checkpoint(ckpt_wrapper(self.blocks[i]), x, *attention_args, use_reentrant=True)
+ if (self.use_gradient_checkpointing and self.training)
+ else self.blocks[i](x, *attention_args)
+ )
+
+ if continuous_mode:
+ x_img_emb = self.output_later_img(x, c, modality)
+
+ x = self.output_layer(x, c, modality)
+
+ if continuous_mode:
+ return (x, x_img_emb)
+
+ return x
diff --git a/models/dit_orig.py b/models/dit_orig.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d09d98e2edbce6f3acb7cec64a5a14ebfaaf3e
--- /dev/null
+++ b/models/dit_orig.py
@@ -0,0 +1,373 @@
+import math
+import typing
+
+import flash_attn
+import flash_attn.layers.rotary
+import huggingface_hub
+import omegaconf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+def bias_dropout_add_scale(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float,
+ training: bool) -> torch.Tensor:
+ if bias is not None:
+ out = scale * F.dropout(x + bias, p=prob, training=training)
+ else:
+ out = scale * F.dropout(x, p=prob, training=training)
+
+ if residual is not None:
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add_scale(training):
+ def _bias_dropout_add(x, bias, scale, residual, prob):
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, training)
+
+ return _bias_dropout_add
+
+
+# function overload
+def modulate(x: torch.Tensor,
+ shift: torch.Tensor,
+ scale: torch.Tensor) -> torch.Tensor:
+ return x * (1 + scale) + shift
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_train(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float) -> torch.Tensor:
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_inference(
+ x: torch.Tensor,
+ bias: typing.Optional[torch.Tensor],
+ scale: torch.Tensor,
+ residual: typing.Optional[torch.Tensor],
+ prob: float) -> torch.Tensor:
+ return bias_dropout_add_scale(
+ x, bias, scale, residual, prob, False)
+
+
+@torch.jit.script
+def modulate_fused(x: torch.Tensor,
+ shift: torch.Tensor,
+ scale: torch.Tensor) -> torch.Tensor:
+ return modulate(x, shift, scale)
+
+
+class Rotary(torch.nn.Module):
+ def __init__(self, dim, base=10_000):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ # dims are: batch, seq_len, qkv, head, dim
+ self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
+ self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
+ # This makes the transformation on v an identity.
+ self.cos_cached[:,:,2,:,:].fill_(1.)
+ self.sin_cached[:,:,2,:,:].fill_(0.)
+
+ return self.cos_cached, self.sin_cached
+
+
+def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(qkv, cos, sin):
+ cos = cos[0,:,0,0,:cos.shape[-1]//2]
+ sin = sin[0,:,0,0,:sin.shape[-1]//2]
+ return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
+
+
+# function overload
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Layers #
+#################################################################################
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones([dim]))
+ self.dim = dim
+ def forward(self, x):
+ with torch.cuda.amp.autocast(enabled=False):
+ x = F.layer_norm(x.float(), [self.dim])
+ return x * self.weight[None,None,:]
+
+
+def residual_linear(x, W, x_skip, residual_scale):
+ """x_skip + residual_scale * W @ x"""
+ dim_out, dim_in = W.shape[0], W.shape[1]
+ return torch.addmm(
+ x_skip.view(-1, dim_out),
+ x.view(-1, dim_in),
+ W.T,
+ alpha=residual_scale).view(*x.shape[:-1], dim_out)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True))
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ - math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding,
+ torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """Embeds class labels into vector representations.
+
+ Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, num_classes, cond_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
+ self.num_classes = num_classes
+
+ # TODO think of initializing with 0.02 std deviation like in original DiT paper
+
+ def forward(self, labels):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core Model #
+#################################################################################
+
+
+class DDiTBlock(nn.Module):
+ def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
+ super().__init__()
+ self.n_heads = n_heads
+
+ self.norm1 = LayerNorm(dim)
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
+ self.attn_out = nn.Linear(dim, dim, bias=False)
+ self.dropout1 = nn.Dropout(dropout)
+
+ self.norm2 = LayerNorm(dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_ratio * dim, bias=True),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(mlp_ratio * dim, dim, bias=True))
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout = dropout
+
+ self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+
+ def forward(self, x, rotary_cos_sin, c, seqlens=None):
+ batch_size, seq_len = x.shape[0], x.shape[1]
+
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
+
+ (shift_msa, scale_msa, gate_msa, shift_mlp,
+ scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+
+ # attention operation
+ x_skip = x
+ x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+
+ qkv = self.attn_qkv(x)
+ qkv = rearrange(qkv,
+ 'b s (three h d) -> b s three h d',
+ three=3,
+ h=self.n_heads)
+ with torch.cuda.amp.autocast(enabled=False):
+ cos, sin = rotary_cos_sin
+ qkv = apply_rotary_pos_emb(
+ qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ if seqlens is None:
+ cu_seqlens = torch.arange(
+ 0, (batch_size + 1) * seq_len, step=seq_len,
+ dtype=torch.int32, device=qkv.device)
+ else:
+ cu_seqlens = seqlens.cumsum(-1)
+ x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, seq_len, 0., causal=False)
+
+ x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
+
+ x = bias_dropout_scale_fn(self.attn_out(x),
+ None,
+ gate_msa,
+ x_skip,
+ self.dropout)
+
+ # mlp operation
+ x = bias_dropout_scale_fn(
+ self.mlp(modulate_fused(
+ self.norm2(x), shift_mlp, scale_mlp)),
+ None, gate_mlp, x, self.dropout)
+ return x
+
+
+
+class EmbeddingLayer(nn.Module):
+ def __init__(self, dim, vocab_dim):
+ super().__init__()
+ self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
+ torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
+
+ def forward(self, x):
+ return self.embedding[x]
+
+
+class DDitFinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, cond_dim):
+ super().__init__()
+ self.norm_final = LayerNorm(hidden_size)
+ self.linear = nn.Linear(hidden_size, out_channels)
+ self.linear.weight.data.zero_()
+ self.linear.bias.data.zero_()
+
+ self.adaLN_modulation = nn.Linear(cond_dim,
+ 2 * hidden_size,
+ bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+ x = modulate_fused(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
+ def __init__(self, config, vocab_size: int):
+ super().__init__()
+ if type(config) == dict:
+ config = omegaconf.OmegaConf.create(config)
+
+ self.config = config
+ self.vocab_size = vocab_size
+
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
+ vocab_size)
+ self.sigma_map = TimestepEmbedder(config.model.cond_dim)
+ self.rotary_emb = Rotary(
+ config.model.hidden_size // config.model.n_heads)
+
+ blocks = []
+ for _ in range(config.model.n_blocks):
+ blocks.append(DDiTBlock(config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout))
+ self.blocks = nn.ModuleList(blocks)
+
+ self.output_layer = DDitFinalLayer(
+ config.model.hidden_size,
+ vocab_size,
+ config.model.cond_dim)
+ self.scale_by_sigma = config.model.scale_by_sigma
+
+ def _get_bias_dropout_scale(self):
+ if self.training:
+ return bias_dropout_add_scale_fused_train
+ else:
+ return bias_dropout_add_scale_fused_inference
+
+ def forward(self, indices, sigma, **kwargs):
+ if sigma is None:
+ sigma = torch.zeros(indices.shape[0], device=indices.device)
+
+ x = self.vocab_embed(indices)
+ c = F.silu(self.sigma_map(sigma))
+
+ rotary_cos_sin = self.rotary_emb(x)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](x, rotary_cos_sin, c, seqlens=None)
+ x = self.output_layer(x, c)
+
+ return x
\ No newline at end of file
diff --git a/models/elm_custom.py b/models/elm_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c5a0ee6635c8093f666456bf53ceec73a484ea
--- /dev/null
+++ b/models/elm_custom.py
@@ -0,0 +1,1050 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+try:
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
+except:
+ pass
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+# this import has to be relative, otherwise, when setting trust_remote_code=True
+# huggingface transformers won't be able to load the module correctly
+from models.configuration_openelm_local import OpenELMConfig, make_divisible
+
+
+class OpenELMRMSNorm(nn.Module):
+ def __init__(self, num_features: int, eps: float = 1e-6):
+ """
+ Initialize the OpenELMRMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.num_features = num_features
+
+ def _norm(self, x: Tensor) -> Tensor:
+ """
+ Apply the OpenELMRMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ Forward pass through the OpenELMRMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying OpenELMRMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+ def extra_repr(self) -> str:
+ return (
+ super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
+ )
+
+
+class OpenELMPreTrainedModel(PreTrainedModel):
+ config_class = OpenELMConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["OpenELMDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ def __init__(self, *inputs, **kwargs) -> None:
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module) -> None:
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, OpenELMRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+def _rotate_half(x: Tensor) -> Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
+ return (x * pos_cos) + (_rotate_half(x) * pos_sin)
+
+
+class OpenELMRotaryEmbedding(torch.nn.Module):
+ """
+ The rotary position embeddings (aka RoPE) from `RoFormer `_.
+
+ RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
+ explicit relative positional dependencies.
+
+ Args:
+ model_dim: The dimensionality of the model's hidden state.
+ max_seq_length: Maximum sequence length.
+ freq_constant: A constant used for computing frequencies.
+ """
+
+ def __init__(
+ self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
+ ) -> None:
+ inv_freq = 1.0 / (
+ freq_constant
+ ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
+ )
+ super().__init__()
+
+ self.model_dim = model_dim
+ self.freq_constant = freq_constant
+ self.max_seq_length = max_seq_length
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self._cached_cos = None
+ self._cached_sin = None
+ self._cached_seq_length = max_seq_length
+ self._compute_sin_cos_embeddings(max_seq_length)
+
+ def extra_repr(self) -> str:
+ return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"
+
+ def _compute_sin_cos_embeddings(
+ self,
+ key_len: int,
+ key_device: torch.device = torch.device("cpu"),
+ key_dtype: torch.dtype = torch.float32,
+ ) -> None:
+ """
+ Compute sine and cos embeddings.
+
+ Args:
+ key_len: Number of tokens in the key embeddings in the transformer model.
+ device: Device where the key embeddings are stored.
+ key_dtype: Data type of the key embeddings.
+
+ Returns:
+ None
+
+ ...note:
+ We recalculate the sine and cosine embeddings if any of the following conditions are met:
+ 1. The number of tokens in key embeddings are greater than the cached sequence length.
+ 2. Sine and cosine caches are empty.
+ 3. The device and data type of sine and cosine embeddings does not match with the key embeddings.
+ """
+ if (
+ key_len > self._cached_seq_length
+ or self._cached_cos is None
+ or (self._cached_cos is not None and self._cached_cos.device != key_device)
+ or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
+ or self._cached_sin is None
+ or (self._cached_sin is not None and self._cached_sin.device != key_device)
+ or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
+ ):
+ self._cached_seq_length = max(key_len, self._cached_seq_length)
+
+ # The shape of 'pos_index' is [number of key tokens]
+ pos_index = torch.arange(
+ self._cached_seq_length,
+ dtype=torch.float32,
+ device=self.inv_freq.device,
+ )
+ # The shape of 'pos_index_theta' is [number of key tokens, model dimension]
+ pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
+ # The shape of 'emb' is [number of key tokens, model dimension]
+ emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)
+
+ # the shape of cos and sin embeddings is [number of key tokens, model_dim]
+ cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
+ sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)
+
+ # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
+ self._cached_cos = cos_emb[None, None, :, :]
+ self._cached_sin = sin_emb[None, None, :, :]
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ The forward function of RoPE embeddings.
+
+ Args:
+ query: Query embeddings in the transformer model. The shape of query embeddings is
+ [Batch, number of query heads, number of query tokens, model dimension].
+ key: Key embeddings in the transformer model. The shape of key embeddings is
+ [Batch, number of key heads, number of key tokens, model dimension].
+
+ Returns:
+ A tuple containing the query and key embeddings with positional information. The shape of the returned query
+ and key embeddings is the same as the input query and key embeddings respectively.
+
+ ...note:
+ The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors
+ are casted to original input datatype.
+ """
+ dim = key.shape[-1]
+ key_len = key.shape[2]
+ query_len = query.shape[2]
+
+ assert dim == self.model_dim
+ assert key.device == query.device
+ assert key.dtype == query.dtype
+
+ # In the context of self-attention, the lengths of keys and queries are equal.
+ # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
+ # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
+ # represent embeddings of previous tokens and the current token, while the query corresponds
+ # to the embedding of the current token only.
+ assert (
+ key_len >= query_len
+ ), "Number of keys has to be greater than or equal to number of queries."
+
+ query_float = query.float()
+ key_float = key.float()
+
+ self._compute_sin_cos_embeddings(
+ key_len, key_device=key_float.device, key_dtype=key_float.dtype
+ )
+ query_float = _apply_rotary_pos_emb(
+ x=query_float,
+ pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
+ pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
+ )
+ key_float = _apply_rotary_pos_emb(
+ x=key_float,
+ pos_sin=self._cached_sin[..., :key_len, :],
+ pos_cos=self._cached_cos[..., :key_len, :],
+ )
+
+ return query_float.type_as(query), key_float.type_as(key)
+
+
+class OpenELMMultiHeadCausalAttention(nn.Module):
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.full_attention = config.full_attention
+ head_dim = config.head_dim
+ q_heads = config.num_query_heads[layer_idx]
+ k_heads = config.num_kv_heads[layer_idx]
+ v_heads = config.num_kv_heads[layer_idx]
+
+ self.qkv_proj = nn.Linear(
+ in_features=config.model_dim,
+ out_features=(q_heads + k_heads + v_heads) * head_dim,
+ bias=False,
+ )
+
+ self.pos_embedding = OpenELMRotaryEmbedding(
+ model_dim=config.head_dim,
+ max_seq_length=config.rope_max_length,
+ freq_constant=config.rope_freq_constant,
+ )
+
+ if config.normalize_qk_projections:
+ self.q_norm = OpenELMRMSNorm(
+ num_features=config.head_dim,
+ )
+ self.k_norm = OpenELMRMSNorm(
+ num_features=config.head_dim,
+ )
+ else:
+ self.q_norm = None
+ self.k_norm = None
+
+ self.out_proj = nn.Linear(
+ in_features=q_heads * head_dim,
+ out_features=config.model_dim,
+ bias=False,
+ )
+
+ self.head_dim = config.head_dim
+ self.num_q_heads = q_heads
+ self.num_k_heads = k_heads
+ self.num_v_heads = v_heads
+ self.transformer_dim = config.model_dim
+ self.num_groups = self.num_q_heads // self.num_k_heads
+
+ def extra_repr(self) -> str:
+ return (
+ super().extra_repr()
+ + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Forward pass of multi-head self-attention.
+
+ Args:
+ hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
+ past_key_value: Tensor storing the cached keys and values.
+ output_attentions: output attention weights.
+ use_cache: Specifies whether to use kv-cache for generation.
+ cache_position: used for updating the kv-cache.
+
+ Returns:
+ The output of the same shape as the input, optionally with a tensor containing cached keys and values.
+ """
+
+ # scaled_dot_product_attention does not return attention weights, set output_attentions to False
+ output_attentions = False
+ batch_size, seq_length, d_model = hidden_states.size()
+
+ # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
+ qkv = self.qkv_proj(hidden_states)
+ # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
+ qkv = qkv.reshape(
+ batch_size,
+ seq_length,
+ self.num_q_heads + self.num_k_heads + self.num_v_heads,
+ self.head_dim,
+ )
+ # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
+ qkv = qkv.transpose(1, 2)
+ # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h]
+ queries, keys, values = qkv.split(
+ [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
+ )
+
+ if self.q_norm is not None:
+ queries = self.q_norm(queries)
+
+ if self.k_norm is not None:
+ keys = self.k_norm(keys)
+
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {"cache_position": cache_position}
+ keys, values = past_key_value.update(
+ keys, values, self.layer_idx, cache_kwargs
+ )
+
+ # Add positional embedding
+ queries, keys = self.pos_embedding(queries, keys)
+
+ if self.num_groups != 1:
+ # GQA
+ # [B, k_h, S, h] --> [B, q_h, S, h]
+ keys = keys.repeat_interleave(self.num_groups, dim=1)
+ # [B, v_h, S, h] --> [B, q_h, S, h]
+ values = values.repeat_interleave(self.num_groups, dim=1)
+
+ if self.full_attention:
+ is_causal=False
+ attn_output = F.scaled_dot_product_attention(
+ queries,
+ keys,
+ values,
+ is_causal=is_causal,
+ dropout_p=0,
+ )
+ else:
+ causal_mask = attention_mask
+ if attention_mask is not None and cache_position is not None:
+ causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]
+ attn_output = F.scaled_dot_product_attention(
+ queries,
+ keys,
+ values,
+ attn_mask=causal_mask,
+ dropout_p=0,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(
+ batch_size, seq_length, self.num_q_heads * self.head_dim
+ )
+ attn_output = self.out_proj(attn_output)
+ if not output_attentions:
+ attn_weights = None
+ return attn_output, attn_weights, past_key_value
+
+
+class OpenELMFeedForwardNetwork(nn.Module):
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
+ super().__init__()
+ ffn_multiplier = config.ffn_multipliers[layer_idx]
+ intermediate_dim = int(
+ make_divisible(
+ ffn_multiplier * config.model_dim,
+ divisor=config.ffn_dim_divisor,
+ )
+ )
+ if config.ffn_with_glu:
+ # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
+ self.proj_1 = nn.Linear(
+ in_features=config.model_dim,
+ out_features=2 * intermediate_dim,
+ bias=False,
+ )
+ self.proj_2 = nn.Linear(
+ in_features=intermediate_dim,
+ out_features=config.model_dim,
+ bias=False,
+ )
+ self.ffn_with_glu = True
+ else:
+ # Standard FFN, as described in https://arxiv.org/abs/1706.03762
+ self.proj_1 = nn.Linear(
+ in_features=config.model_dim,
+ out_features=intermediate_dim,
+ bias=False,
+ )
+ self.proj_2 = nn.Linear(
+ in_features=intermediate_dim,
+ out_features=config.model_dim,
+ bias=False,
+ )
+ self.ffn_with_glu = False
+
+ self.act = ACT2FN[config.activation_fn_name]
+
+ def extra_repr(self) -> str:
+ return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function of FFN layer.
+
+ Args:
+ x: Input tensor of the shape [batch size, sequence length, model dimension].
+
+ Returns:
+ A tensor of the same shape as the input.
+ """
+ if self.ffn_with_glu:
+ y_12 = self.proj_1(x)
+ y_1, y_2 = y_12.chunk(2, dim=-1)
+ y = self.act(y_1) * y_2
+ return self.proj_2(y)
+ else:
+ return self.proj_2(self.act(self.proj_1(x)))
+
+
+class OpenELMDecoderLayer(nn.Module):
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
+ super().__init__()
+ self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
+ self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
+ self.ffn_norm = OpenELMRMSNorm(
+ num_features=config.model_dim,
+ )
+ self.attn_norm = OpenELMRMSNorm(
+ num_features=config.model_dim,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+ residual = hidden_states
+ hidden_states = self.attn_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.ffn_norm(hidden_states)
+ hidden_states = self.ffn(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class OpenELMModel(OpenELMPreTrainedModel):
+ config_class = OpenELMConfig
+
+ def __init__(self, config: OpenELMConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.use_extra_tokens = False
+
+ self.token_embeddings = nn.Embedding(
+ embedding_dim=config.model_dim,
+ num_embeddings=config.vocab_size,
+ )
+
+ if hasattr(config, 'extra_tokens') and config.extra_tokens > 0:
+ self.use_extra_tokens = True
+ self.token_embeddings_extra = nn.Embedding(config.extra_tokens, config.model_dim)
+
+ self.layers = nn.ModuleList(
+ OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
+ for layer_idx in range(config.num_transformer_layers)
+ )
+ self.norm = OpenELMRMSNorm(num_features=config.model_dim)
+ if config.share_input_output_layers:
+ self.classifier = None
+ else:
+ self.classifier = nn.Linear(
+ in_features=config.model_dim,
+ out_features=config.vocab_size,
+ bias=False,
+ )
+ self.num_transformer_layers = config.num_transformer_layers
+ self.gradient_checkpointing = False
+
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
+ causal_mask = torch.full(
+ (config.max_context_length, config.max_context_length),
+ fill_value=True,
+ dtype=torch.bool,
+ )
+ self.register_buffer(
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+ self.reset_parameters(config=config)
+
+ def get_input_embeddings(self):
+ return self.token_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.token_embeddings = new_embeddings
+
+ def reset_parameters(self, config: OpenELMConfig) -> None:
+ """Initialize the layers in Language Model
+
+ The initialization scheme is followed, following `OPT `_.
+
+ Args:
+ use_megatron_std: Use standard deviation as described in Megatron-LM.
+
+ Returns:
+ None
+ """
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ std = module.in_features**-0.5
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ std = module.embedding_dim**-0.5
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+ elif isinstance(module, OpenELMRMSNorm):
+ if module.weight is not None:
+ torch.nn.init.ones_(module.weight)
+ if hasattr(module, "bias") and module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+
+ model_dim = config.model_dim
+ n_layers = config.num_transformer_layers
+ std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
+ for param_name, param in self.named_parameters():
+ if param_name.endswith("out_proj.weight") or param_name.endswith(
+ "ffn.proj_2.weight"
+ ):
+ torch.nn.init.normal_(param, mean=0.0, std=std)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if self.use_extra_tokens:
+ combined_weight = torch.cat([self.token_embeddings.weight, self.token_embeddings_extra.weight], dim=0)
+
+ if inputs_embeds is None:
+ if self.use_extra_tokens:
+ inputs_embeds = combined_weight[input_ids]
+ else:
+ inputs_embeds = self.token_embeddings(input_ids)
+
+ if self.training:
+ inputs_embeds.requires_grad_(True)
+
+
+ past_seen_tokens = 0
+ if use_cache: # kept for BC (cache positions)
+ if not isinstance(past_key_values, StaticCache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = (
+ next_decoder_cache.to_legacy_cache()
+ if isinstance(next_decoder_cache, Cache)
+ else next_decoder_cache
+ )
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(self, attention_mask, input_tensor):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ batch_size, seq_length = input_tensor.shape[:2]
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+
+ # support going beyond cached `max_position_embedding`
+ if seq_length > self.causal_mask.shape[-1]:
+ causal_mask = torch.full(
+ (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
+ fill_value=1,
+ )
+ self.register_buffer(
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
+ )
+
+ # We use the current dtype to avoid any overflows
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = (
+ self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
+ * min_dtype
+ )
+
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
+ if attention_mask is not None and attention_mask.dim() == 2:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
+ :, None, None, :
+ ].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+ # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(input_tensor, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+ if not is_tracing and torch.any(attention_mask != 1):
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = causal_mask.mul(
+ ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
+ ).to(dtype)
+
+ return causal_mask
+
+
+class OpenELMForCausalLM(OpenELMPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: OpenELMConfig):
+ super().__init__(config)
+ self.transformer = OpenELMModel(config)
+ self.vocab_size = config.vocab_size
+ self.is_compiled = config.is_compiled
+ if config.share_input_output_layers:
+ self.lm_head = None
+ else:
+ self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)
+ if hasattr(config, 'extra_tokens') and config.extra_tokens > 0:
+ self.use_extra_tokens = True
+ self.lm_extra = nn.Linear(config.model_dim, config.extra_tokens, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.token_embeddings
+
+ def set_input_embeddings(self, value):
+ self.transformer.token_embeddings = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.transformer = decoder
+
+ def get_decoder(self):
+ return self.transformer
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ sigma = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.lm_head is None:
+ # shared
+ combined_weight = torch.cat([self.transformer.token_embeddings.weight, self.transformer.token_embeddings_extra.weight], dim=0)
+ logits = F.linear(
+ hidden_states, combined_weight
+ )
+ else:
+ if self.use_extra_tokens:
+ combined_weight = torch.cat([self.transformer.token_embeddings.weight, self.lm_extra.weight], dim=0)
+ logits = F.linear(hidden_states, combined_weight)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits[:, : self.config.vocab_size]
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ if self.is_compiled:
+ return logits
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ **kwargs,
+ ):
+ past_length = 0
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if (
+ attention_mask is not None
+ and attention_mask.shape[1] > input_ids.shape[1]
+ ):
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ if self.generation_config.cache_implementation == "static":
+ # generation with static cache
+ cache_position = kwargs.get("cache_position", None)
+ if cache_position is None:
+ past_length = 0
+ else:
+ past_length = cache_position[-1] + 1
+ input_ids = input_ids[:, past_length:]
+ position_ids = position_ids[:, past_length:]
+
+ # we should only keep a `cache_position` in generate, and do +=1.
+ # same goes for position ids. Could also help with continued generation.
+ cache_position = torch.arange(
+ past_length,
+ past_length + position_ids.shape[-1],
+ device=position_ids.device,
+ )
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+ # We could use `next_tokens` directly instead.
+ model_inputs = {"input_ids": input_ids.contiguous()}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids.contiguous(),
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past
+ ),
+ )
+ return reordered_past
diff --git a/models/ema.py b/models/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..916ec22cdfa087e14f54b6ce5212d18ea29cf405
--- /dev/null
+++ b/models/ema.py
@@ -0,0 +1,371 @@
+import torch
+from decoupled_utils import is_torch_xla_available
+try:
+ if not is_torch_xla_available():
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ import deepspeed
+except:
+ is_deepspeed_zero3_enabled = lambda: False
+
+class ExponentialMovingAverage:
+ """
+ WARNING: DEPRECATED
+ Maintains (exponential) moving average of a set of parameters.
+ """
+
+ def __init__(self, parameters, decay, use_num_updates=True):
+ """
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
+ `model.parameters()`.
+ decay: The exponential decay.
+ use_num_updates: Whether to use number of updates when computing
+ averages.
+ """
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ self.decay = decay
+ self.num_updates = 0 if use_num_updates else None
+ self.shadow_params = [p.clone().detach()
+ for p in parameters if p.requires_grad]
+ self.collected_params = []
+
+ def move_shadow_params_to_device(self, device):
+ self.shadow_params = [i.to(device) for i in self.shadow_params]
+
+ def update(self, parameters):
+ """
+ Update currently maintained parameters.
+
+ Call this every time the parameters are updated, such as the result of
+ the `optimizer.step()` call.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object.
+ """
+ decay = self.decay
+ if self.num_updates is not None:
+ self.num_updates += 1
+ decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+ one_minus_decay = 1.0 - decay
+ with torch.no_grad():
+ parameters = [p for p in parameters if p.requires_grad]
+ for s_param, param in zip(self.shadow_params, parameters):
+ s_param.sub_(one_minus_decay * (s_param - param))
+
+ def copy_to(self, parameters):
+ """
+ Copy current parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages.
+ """
+ parameters = [p for p in parameters if p.requires_grad]
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.requires_grad:
+ param.data.copy_(s_param.data)
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
+
+ def state_dict(self):
+ return dict(decay=self.decay,
+ num_updates=self.num_updates,
+ shadow_params=self.shadow_params)
+
+ def load_state_dict(self, state_dict):
+ self.decay = state_dict['decay']
+ self.num_updates = state_dict['num_updates']
+ self.shadow_params = state_dict['shadow_params']
+
+
+
+from diffusers.utils import (
+ is_transformers_available,
+)
+from typing import Iterable, Union, Optional
+import contextlib
+import transformers
+import copy
+
+# Taken from diffusers
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float = 0.9999,
+ min_decay: float = 0.0,
+ update_after_step: int = 0,
+ use_ema_warmup: bool = False,
+ inv_gamma: Union[float, int] = 1.0,
+ power: Union[float, int] = 2 / 3,
+ foreach: bool = False,
+ ):
+ """
+ Args:
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+ decay (float): The decay factor for the exponential moving average.
+ min_decay (float): The minimum decay factor for the exponential moving average.
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+ use_ema_warmup (bool): Whether to use EMA warmup.
+ inv_gamma (float):
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+ foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
+ weights will be stored on CPU.
+
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ """
+
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+
+ self.temp_stored_params = None
+
+ self.decay = decay
+ self.min_decay = min_decay
+ self.update_after_step = update_after_step
+ self.use_ema_warmup = use_ema_warmup
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.optimization_step = 0
+ self.cur_decay_value = None # set in `step()`
+ self.foreach = foreach
+
+ def get_decay(self, optimization_step: int) -> float:
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+
+ if step <= 0:
+ return 0.0
+
+ if self.use_ema_warmup:
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+ else:
+ cur_decay_value = (1 + step) / (10 + step)
+
+ cur_decay_value = min(cur_decay_value, self.decay)
+ # make sure decay is not smaller than min_decay
+ cur_decay_value = max(cur_decay_value, self.min_decay)
+ return cur_decay_value
+
+ @torch.no_grad()
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
+ parameters = list(parameters)
+
+ self.optimization_step += 1
+
+ # Compute the decay factor for the exponential moving average.
+ decay = self.get_decay(self.optimization_step)
+ self.cur_decay_value = decay
+ one_minus_decay = 1 - decay
+
+ context_manager = contextlib.nullcontext
+
+ if self.foreach:
+ if is_transformers_available() and is_deepspeed_zero3_enabled():
+ context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
+
+ with context_manager():
+ params_grad = [param for param in parameters if param.requires_grad]
+ s_params_grad = [
+ s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
+ ]
+
+ if len(params_grad) < len(parameters):
+ torch._foreach_copy_(
+ [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
+ [param for param in parameters if not param.requires_grad],
+ non_blocking=True,
+ )
+
+ torch._foreach_sub_(
+ s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
+ )
+
+ else:
+ for s_param, param in zip(self.shadow_params, parameters):
+ if is_transformers_available() and is_deepspeed_zero3_enabled():
+ context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+
+ with context_manager():
+ if param.requires_grad:
+ s_param.sub_(one_minus_decay * (s_param - param))
+ else:
+ s_param.copy_(param)
+
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ parameters = list(parameters)
+ if self.foreach:
+ torch._foreach_copy_(
+ [param.data for param in parameters],
+ [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
+ )
+ else:
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.to(param.device).data)
+
+ def pin_memory(self) -> None:
+ r"""
+ Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
+ offloading EMA params to the host.
+ """
+
+ self.shadow_params = [p.pin_memory() for p in self.shadow_params]
+
+ def to(self, device=None, dtype=None, non_blocking=False) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype, non_blocking=non_blocking)
+ if p.is_floating_point()
+ else p.to(device=device, non_blocking=non_blocking)
+ for p in self.shadow_params
+ ]
+
+ def state_dict(self) -> dict:
+ r"""
+ Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
+ checkpointing to save the ema state dict.
+ """
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "min_decay": self.min_decay,
+ "optimization_step": self.optimization_step,
+ "update_after_step": self.update_after_step,
+ "use_ema_warmup": self.use_ema_warmup,
+ "inv_gamma": self.inv_gamma,
+ "power": self.power,
+ "shadow_params": self.shadow_params,
+ }
+
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ r"""
+ Args:
+ Save the current parameters for restoring later.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
+
+ def restore(self, parameters: Iterable[torch.nn.Parameter], raise_error_if_already_restored: bool = True) -> None:
+ r"""
+ Args:
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
+ affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+ validation (or model saving), use this to restore the former parameters.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ if self.temp_stored_params is None:
+ if raise_error_if_already_restored:
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+ return
+ if self.foreach:
+ torch._foreach_copy_(
+ [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
+ )
+ else:
+ for c_param, param in zip(self.temp_stored_params, parameters):
+ param.data.copy_(c_param.data)
+
+ # Better memory-wise.
+ self.temp_stored_params = None
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""
+ Args:
+ Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
+ ema state dict.
+ state_dict (dict): EMA state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+
+ self.decay = state_dict.get("decay", self.decay)
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
+ if not isinstance(self.min_decay, float):
+ raise ValueError("Invalid min_decay")
+
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
+ if not isinstance(self.optimization_step, int):
+ raise ValueError("Invalid optimization_step")
+
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
+ if not isinstance(self.update_after_step, int):
+ raise ValueError("Invalid update_after_step")
+
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+ if not isinstance(self.use_ema_warmup, bool):
+ raise ValueError("Invalid use_ema_warmup")
+
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+ if not isinstance(self.inv_gamma, (float, int)):
+ raise ValueError("Invalid inv_gamma")
+
+ self.power = state_dict.get("power", self.power)
+ if not isinstance(self.power, (float, int)):
+ raise ValueError("Invalid power")
+
+ shadow_params = state_dict.get("shadow_params", None)
+ if shadow_params is not None:
+ self.shadow_params = shadow_params
+ if not isinstance(self.shadow_params, list):
+ raise ValueError("shadow_params must be a list")
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+ raise ValueError("shadow_params must all be Tensors")
diff --git a/models/noise_schedule.py b/models/noise_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..1625a791a797cf14f2af17ae4320d11fde0bcf0b
--- /dev/null
+++ b/models/noise_schedule.py
@@ -0,0 +1,157 @@
+import abc
+
+import torch
+import torch.nn as nn
+
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+def get_noise(config, dtype=torch.float32):
+ if config.noise.type == 'geometric':
+ return GeometricNoise(config.noise.sigma_min,
+ config.noise.sigma_max)
+ elif config.noise.type == 'loglinear':
+ return LogLinearNoise()
+ elif config.noise.type == 'cosine':
+ return CosineNoise()
+ elif config.noise.type == 'cosinesqr':
+ return CosineSqrNoise()
+ elif config.noise.type == 'linear':
+ return Linear(config.noise.sigma_min,
+ config.noise.sigma_max,
+ dtype)
+ else:
+ raise ValueError(f'{config.noise.type} is not a valid noise')
+
+
+def binary_discretization(z):
+ z_hard = torch.sign(z)
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
+ return z_soft + (z_hard - z_soft).detach()
+
+
+class Noise(abc.ABC, nn.Module):
+ """
+ Baseline forward method to get the total + rate of noise at a timestep
+ """
+ def forward(self, t):
+ # Assume time goes from 0 to 1
+ return self.total_noise(t), self.rate_noise(t)
+
+ @abc.abstractmethod
+ def rate_noise(self, t):
+ """
+ Rate of change of noise ie g(t)
+ """
+ pass
+
+ @abc.abstractmethod
+ def total_noise(self, t):
+ """
+ Total noise ie \int_0^t g(t) dt + g(0)
+ """
+ pass
+
+
+class CosineNoise(Noise):
+ def __init__(self, eps=1e-3):
+ super().__init__()
+ self.eps = eps
+
+ def rate_noise(self, t):
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
+ scale = torch.pi / 2
+ return scale * sin / (cos + self.eps)
+
+ def total_noise(self, t):
+ cos = torch.cos(t * torch.pi / 2)
+ return - torch.log(self.eps + (1 - self.eps) * cos)
+
+
+class CosineSqrNoise(Noise):
+ def __init__(self, eps=1e-3):
+ super().__init__()
+ self.eps = eps
+
+ def rate_noise(self, t):
+ cos = (1 - self.eps) * (
+ torch.cos(t * torch.pi / 2) ** 2)
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
+ scale = torch.pi / 2
+ return scale * sin / (cos + self.eps)
+
+ def total_noise(self, t):
+ cos = torch.cos(t * torch.pi / 2) ** 2
+ return - torch.log(self.eps + (1 - self.eps) * cos)
+
+
+class Linear(Noise):
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
+ super().__init__()
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
+
+ def rate_noise(self, t):
+ return self.sigma_max - self.sigma_min
+
+ def total_noise(self, t):
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
+
+ def importance_sampling_transformation(self, t):
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
+ return (sigma_t - self.sigma_min) / (
+ self.sigma_max - self.sigma_min)
+
+
+class GeometricNoise(Noise):
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
+ super().__init__()
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
+
+ def rate_noise(self, t):
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
+ self.sigmas[1].log() - self.sigmas[0].log())
+
+ def total_noise(self, t):
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
+
+from decoupled_utils import is_torch_xla_available
+is_xla_available = is_torch_xla_available()
+
+class LogLinearNoise(Noise):
+ """Log Linear noise schedule.
+
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
+ ~1 when t varies from 0 to 1. Total noise is
+ -log(1 - (1 - eps) * t), so the sigma will be
+ (1 - eps) * t.
+ """
+ def __init__(self, eps=1e-3):
+ super().__init__()
+ self.eps = eps
+ self.sigma_max = self.total_noise(torch.tensor(1.0, dtype=torch.float32))
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0, dtype=torch.float32))
+
+ def rate_noise(self, t):
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
+
+ def total_noise(self, t):
+ if is_xla_available:
+ # XLA breaks here with large batch sizes...
+ return -torch.log(1 + (-(1 - self.eps) * t.to(torch.float64))).to(t.dtype)
+ else:
+ return -torch.log1p(-(1 - self.eps) * t)
+
+ def importance_sampling_transformation(self, t):
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
+ return t
diff --git a/models/standalone_rotary.py b/models/standalone_rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1eaac1b3c799c48574972751647f498ad2b4ed1
--- /dev/null
+++ b/models/standalone_rotary.py
@@ -0,0 +1,117 @@
+from typing import Tuple
+import torch
+from einops import rearrange, repeat
+
+def flash_torch_rotate_half(x, interleaved=False):
+ if not interleaved:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+ else:
+ x1, x2 = x[..., ::2], x[..., 1::2]
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
+
+
+def flash_torch_apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+ """
+ x: (batch_size, seqlen, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+ """
+ if x.shape[-3] < cos.shape[-2]:
+ # this fixes AR bug but NOT kv cache slicing
+ cos = cos[..., :x.shape[1], :]
+ sin = sin[..., :x.shape[1], :]
+
+ ro_dim = cos.shape[-1] * 2
+ assert ro_dim <= x.shape[-1]
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ return torch.cat(
+ [x[..., :ro_dim] * cos + flash_torch_rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
+ dim=-1,
+ )
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ """Applies the rotary embedding to the query and key tensors."""
+ x_ = torch.view_as_complex(
+ torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
+ dim=-1))
+ x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
+ x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
+ x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
+ -1).transpose(1, 2)
+ return x_out
+
+def rotate_half_(x):
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb_(x, cos, sin):
+ # NOTE: This could probably be moved to Triton
+
+ # Handle a possible sequence length mismatch in between q and k
+ cos = cos[:, :, : x.shape[-2], :]
+ sin = sin[:, :, : x.shape[-2], :]
+
+ return (x * cos) + (rotate_half_(x) * sin)
+
+class StandaloneRotaryEmbedding(torch.nn.Module):
+ """
+ The rotary position embeddings from RoFormer_ (Su et. al).
+ A crucial insight from the method is that the query and keys are
+ transformed by rotation matrices which depend on the relative positions.
+
+ Other implementations are available in the Rotary Transformer repo_ and in
+ GPT-NeoX_, GPT-NeoX was an inspiration
+
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
+ """
+
+ def __init__(self, dim_model: int, *_, **__):
+ super().__init__()
+ # Generate and save the inverse frequency buffer (non trainable)
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+ self.register_buffer("inv_freq", inv_freq)
+
+ self._seq_len_cached = None
+ self._cos_cached = None
+ self._sin_cached = None
+
+ def _update_cos_sin_tables(self, x, seq_dimension=1):
+ seq_len = x.shape[seq_dimension]
+
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seq_len != self._seq_len_cached
+ or self._cos_cached.device != x.device
+ or self._cos_cached.dtype != x.dtype
+ ):
+ self._seq_len_cached = seq_len
+ t = torch.arange(
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+ return self._cos_cached, self._sin_cached
+
+ def forward(
+ self, q: torch.Tensor, k: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+ k, seq_dimension=-2
+ )
+
+ return (
+ apply_rotary_pos_emb_(q, self._cos_cached, self._sin_cached),
+ apply_rotary_pos_emb_(k, self._cos_cached, self._sin_cached),
+ )
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..6b23bdda1c0b86b51e66a430ee053e1f32eb8251
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,109 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build]
+sources = ["unidisc"]
+
+[tool.uv]
+package = false
+default-groups = ["dev", "misc"]
+no-build-isolation-package = ["flash-attn"]
+
+[project]
+name = "unidisc"
+version = "0.0.1"
+authors = [{ name="Alexander Swerdlow", email="aswerdlow1@gmail.com" }]
+readme = "README.md"
+requires-python = ">=3.10,<3.13"
+dependencies = [
+ "setuptools>=75.8.0",
+ "torch>=2.6.0",
+ "torchvision>=0.21.0",
+ "diffusers~=0.32.2",
+ "transformers~=4.49.0",
+ "datasets~=3.2.0",
+ "numpy~=2.2",
+ "tensordict~=0.7.2",
+ "accelerate~=1.5.2",
+ "lightning_utilities~=0.12.0",
+ "hydra-core~=1.3.2",
+ "omegaconf~=2.3.0",
+ "torchtnt~=0.2.4",
+ "jaxtyping~=0.2.37",
+ "einops~=0.8.0",
+ "timm~=1.0.15",
+ "wandb~=0.19.6",
+ "image_utilities==0.0.3*",
+ "typer~=0.15.1",
+ "torchmetrics==1.6.1",
+ "rich~=13.9.4",
+ "fsspec",
+ "pandas",
+ "ml_collections",
+ "scikit-learn",
+ "torchinfo",
+ "sentencepiece",
+ "hf_transfer",
+ "ipdb",
+ "ipython",
+ "lovely-tensors",
+]
+
+[dependency-groups]
+dev = [
+ "peft",
+ "braceexpand",
+ "h5py",
+ "pynvml",
+ "evaluate",
+ "mauve-text",
+ "clean-fid",
+ "hpsv2x==1.2.0",
+ "open_clip_torch",
+ "T2IBenchmark",
+ "clip",
+ "python-fasthtml~=0.12.1",
+ "MonsterUI~=0.0.34",
+ "fastapi~=0.115.8",
+ "flash-attn~=2.7.4",
+]
+misc = [
+ "flask",
+ "werkzeug",
+ "sentence_transformers",
+ "opencv-python",
+ "lpips",
+ "simple_slurm",
+ "ftfy",
+ "bitsandbytes",
+ "requests",
+ "deepspeed",
+]
+# Not important and can be difficult to install.
+# This may be easier: `uv pip install fairseq --no-deps`
+# hard = [
+# "fairseq",
+# "langchain~=0.3.17",
+# "langchain_core~=0.3.15",
+# "langchain_groq~=0.2.1",
+# ]
+
+[tool.uv.sources]
+torch = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },]
+torchvision = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },]
+webdataset = { git = "ssh://git@github.com/alexanderswerdlow/webdataset.git", rev = "67d1d487dc1a9aa6aaf81e6712deaec29c1ae3d3" }
+submitit = { git = "ssh://git@github.com/alexanderswerdlow/submitit.git", rev = "eb6368c068a9a64e9f09c9128b47c39a81add324" }
+T2IBenchmark = { git = "ssh://git@github.com/boomb0om/text2image-benchmark.git", rev = "532229f679d7e97ecba61914db7276f95733e707" }
+clip = { git = "ssh://git@github.com/openai/CLIP.git", rev = "dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1" }
+# hydra-core = { git = "ssh://git@github.com/alexanderswerdlow/hydra.git", branch = "working_ci" }
+# hydra-submitit-launcher = { git = "ssh://git@github.com/alexanderswerdlow/hydra.git", branch = "working_ci", subdirectory = "plugins/hydra_submitit_launcher" }
+
+[[tool.uv.index]]
+name = "pytorch-cu124"
+url = "https://download.pytorch.org/whl/cu124"
+explicit = true
+
+[tool.black]
+line-length = 150
+target-version = ['py310']
diff --git a/scripts/precompute_tokens_slurm.sh b/scripts/precompute_tokens_slurm.sh
new file mode 100755
index 0000000000000000000000000000000000000000..9cd8ba70c6b47126135837c77debd92f593be6d0
--- /dev/null
+++ b/scripts/precompute_tokens_slurm.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+#SBATCH --job-name=precompute_tokens
+#SBATCH --partition=all
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-gpu=8
+#SBATCH --mem-per-gpu=32G
+#SBATCH --time=06:00:00
+#SBATCH --output=outputs/logs/%A_%a_%n_log.out
+#SBATCH --signal=B:USR2@600
+
+echo "ibstatus: $(ibstatus)"
+echo "ibdev2netdev: $(ibdev2netdev)"
+echo "rdma device: $(rdma link)"
+
+unset NCCL_P2P_LEVEL
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+export NCCL_DEBUG=INFO
+export NCCL_NSOCKS_PERTHREAD=4
+export NCCL_SOCKET_NTHREADS=2
+
+export LOGLEVEL=INFO
+export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+export MASTER_PORT=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))
+echo MASTER_ADDR: $MASTER_ADDR
+echo MASTER_PORT: $MASTER_PORT
+echo "environment: $(env | grep NCCL)"
+echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
+echo "SLURM_NNODES: $SLURM_NNODES"
+
+trap 'echo "SIGUSR2"; \
+if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
+if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
+# ps auxww | grep $USER; \
+pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*"); \
+echo "Found parent PIDs: $pid"; \
+for p in $pid; do \
+ echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
+ children=$(pgrep -P $p); \
+ echo "Children: $children"; \
+ if [ -n "$children" ]; then \
+ for child in $children; do \
+ ppid=$(ps -o ppid= -p $child | tr -d " ")
+ if [ "$ppid" -eq "$p" ]; then
+ echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
+ kill -USR2 $child &
+ else
+ echo "Skipping non-direct child process: PID $child with PPID $ppid"
+ fi
+ done; \
+ echo "Sent kill signals to children of $p"; \
+ else \
+ echo "No children found for $p"; \
+ fi; \
+done; \
+wait;' SIGUSR2
+
+num_processes=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+echo "num_processes: $num_processes"
+srun --label accelerate launch \
+ --rdzv_backend c10d \
+ --machine_rank $SLURM_NODEID \
+ --num_processes $num_processes \
+ --num_machines $SLURM_NNODES \
+ --dynamo_backend no \
+ --mixed_precision no \
+ --main_process_ip $MASTER_ADDR \
+ --main_process_port $MASTER_PORT \
+ "$@"
diff --git a/scripts/small_scale_eval.sh b/scripts/small_scale_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..16e38d594e40988cbd18e6b5ebe740c61d12e82d
--- /dev/null
+++ b/scripts/small_scale_eval.sh
@@ -0,0 +1,111 @@
+# MIN_WAIT=60 MAX_WAIT=300 bash scripts/osync.sh --on-changes --initiator=/home/aswerdlo/hdd/data/unidisc/ckpts/sync --target=ssh://mprabhud@grogu//grogu/user/mprabhud/aswerdlo/unidisc/ckpts/sync
+
+export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"
+export CUDA_HOME=$CONDA_PREFIX
+export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=0
+export NUM_GPUS=${NUM_GPUS:-4}
+export CONSTRAINT="L40|L40S|A100_40GB|A100_80GB|6000Ada|A6000|A4500"
+export MEM_PER_GPU=32
+export CPUS_PER_GPU=8
+
+export CKPT_DIR='/home/aswerdlo/repos/unidisc_arxiv'
+
+RUN_NAR=${RUN_NAR:-0}
+RUN_AR=${RUN_AR:-0}
+RUN_CC=${RUN_CC:-0}
+RUN_DB=${RUN_DB:-0}
+RUN_FLICKR=${RUN_FLICKR:-0}
+RUN_COCO=${RUN_COCO:-0}
+RUN_MEDIUM=${RUN_MEDIUM:-0}
+
+common_args=(\
+debug=true \
+model=$([[ "$RUN_MEDIUM" -eq 1 ]] && echo "medium" || echo "small") \
+loader.eval_batch_size=$([[ "$RUN_MEDIUM" -eq 1 ]] && echo "3" || echo "24") \
+trainer.compile=true \
++trainer.forced_keys='[eval.cfg,eval.unconditional_fid,sampling.predictor,data.fid_dataset,sampling.sampling_step_frac]' \
+model.force_optimized_native_attn=false \
+wandb.project='unidisc-jan-eval-ablations' \
+partition=preempt \
+wandb.tags='[11_12_fid_ar_v2]' \
+eval.fid_samples=16384 \
+sampling.predictor=maskgit \
+sampling.sampling_step_frac='0.05' \
+eval.cfg=2 \
+trainer.compile=false \
+slurm_name="${USER}_ablations_nar" \
+mem_per_gpu=$MEM_PER_GPU \
+cpus_per_gpu=$CPUS_PER_GPU \
+devices=$NUM_GPUS \
+constraint=$CONSTRAINT \
+partition=general)
+
+common_a_args=(\
++experiments='[small_scale_train,paired_standalone_fid_eval,master_eval,fid_hf]' data.fid_dataset="sayakpaul/coco-30-val-2014")
+
+common_b_args=(\
++experiments='[small_scale_train,paired_standalone_fid_eval,master_eval,fid_hf]' data.fid_dataset="nlphuji/flickr30k")
+
+common_c_args=(\
++experiments='[small_scale_train,paired_standalone_fid_eval,master_eval,fid_cc12m]')
+
+common_d_args=(\
++experiments='[small_scale_train,paired_standalone_fid_eval,master_eval,fid_datacomp1b]')
+
+if [ "$RUN_MEDIUM" -eq 1 ]; then
+ NAR_CKPT="$CKPT_DIR/300m_nar.safetensors"
+ AR_CKPT="$CKPT_DIR/300m_ar.safetensors"
+else
+ AR_CKPT="$CKPT_DIR/115m_ar.safetensors"
+ NAR_CKPT="$CKPT_DIR/115m_nar.safetensors"
+fi
+
+echo "RUN_AR: ${RUN_AR}, RUN_NAR: ${RUN_NAR}, RUN_MEDIUM: ${RUN_MEDIUM}"
+echo "RUN_CC: ${RUN_CC}, RUN_DB: ${RUN_DB}, RUN_FLICKR: ${RUN_FLICKR}, RUN_COCO: ${RUN_COCO}"
+echo "NAR_CKPT: ${NAR_CKPT}"
+echo "AR_CKPT: ${AR_CKPT}"
+
+if [ "$RUN_AR" -eq 1 ]; then
+ if [ "$RUN_COCO" -eq 1 ]; then
+ python main.py "${common_a_args[@]}" "${common_args[@]}" $@ parameterization=ar trainer.compile=false wandb.name="1_2_ar_60k" \
+ trainer.load_from_state_dict="$AR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+
+ if [ "$RUN_FLICKR" -eq 1 ]; then
+ python main.py "${common_b_args[@]}" "${common_args[@]}" $@ parameterization=ar trainer.compile=false wandb.name="1_2_ar_60k" \
+ trainer.load_from_state_dict="$AR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+
+ if [ "$RUN_CC" -eq 1 ]; then
+ echo "RUN_CC: ${RUN_CC}"
+ python main.py "${common_c_args[@]}" "${common_args[@]}" $@ parameterization=ar trainer.compile=false wandb.name="1_2_ar_60k" \
+ trainer.load_from_state_dict="$AR_CKPT" --multirun
+ fi
+
+ if [ "$RUN_DB" -eq 1 ]; then
+ python main.py "${common_d_args[@]}" "${common_args[@]}" $@ parameterization=ar trainer.compile=false wandb.name="1_2_ar_60k" \
+ trainer.load_from_state_dict="$AR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+fi
+
+if [ "$RUN_NAR" -eq 1 ]; then
+ if [ "$RUN_COCO" -eq 1 ]; then
+ python main.py "${common_a_args[@]}" "${common_args[@]}" $@ wandb.name="1_2_nar_325k" \
+ trainer.load_from_state_dict="$NAR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+
+ if [ "$RUN_FLICKR" -eq 1 ]; then
+ python main.py "${common_b_args[@]}" "${common_args[@]}" $@ wandb.name="1_2_nar_325k" \
+ trainer.load_from_state_dict="$NAR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+
+ if [ "$RUN_CC" -eq 1 ]; then
+ python main.py "${common_c_args[@]}" "${common_args[@]}" $@ wandb.name="1_2_nar_325k" \
+ trainer.load_from_state_dict="$NAR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+
+ if [ "$RUN_DB" -eq 1 ]; then
+ python main.py "${common_d_args[@]}" "${common_args[@]}" $@ wandb.name="1_2_nar_325k" \
+ trainer.load_from_state_dict="$NAR_CKPT" --multirun > /dev/null 2>&1 &
+ fi
+fi
diff --git a/scripts/train_large_scale_slurm.sh b/scripts/train_large_scale_slurm.sh
new file mode 100755
index 0000000000000000000000000000000000000000..13b265870f783ab42ac0790dcc8620e5ed996366
--- /dev/null
+++ b/scripts/train_large_scale_slurm.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+#SBATCH --job-name=unidisc
+#SBATCH --partition=preempt
+#SBATCH --nodes=2
+#SBATCH --gpus-per-node=8
+#SBATCH --cpus-per-gpu=12
+#SBATCH --mem-per-gpu=64G
+#SBATCH --constraint=L40S
+#SBATCH --time=31-00:00:00
+#SBATCH --output=outputs/logs/%x-%j-%N.out
+#SBATCH --error=outputs/logs/%x-%j-%N.out
+#SBATCH --requeue
+
+printenv
+
+echo "Hostname: $(hostname)"
+echo "ibstatus: $(ibstatus)"
+echo "ibdev2netdev: $(ibdev2netdev)"
+echo "rdma device: $(rdma link)"
+echo "hostnames: $(scontrol show hostnames $SLURM_JOB_NODELIST)"
+
+export LOGLEVEL=INFO
+export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+export MASTER_PORT=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))
+
+echo MASTER_ADDR: $MASTER_ADDR
+echo MASTER_PORT: $MASTER_PORT
+echo "environment: $(env | grep NCCL)"
+
+unset CUDA_VISIBLE_DEVICES
+unset CUDA_LAUNCH_BLOCKING
+unset NCCL_SOCKET_IFNAME
+unset NCCL_IB_DISABLE
+unset NCCL_NSOCKS_PERTHREAD
+unset NCCL_SOCKET_NTHREADS
+unset OMP_NUM_THREADS
+unset NCCL_P2P_LEVEL
+
+ulimit -l
+ulimit -a
+
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+export NCCL_DEBUG=INFO
+export PYTHONUNBUFFERED=1
+export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1
+export UNIDISC_DISABLE_APEX_RMSNORM=1
+export UNIDISC_ROOT_OUTPUT_DIR="outputs"
+export HYDRA_RUN_DIR_NAME='large_scale_v0'
+
+# accelerate
+num_processes=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+srun --label accelerate launch \
+ --multi_gpu \
+ --rdzv_backend c10d \
+ --machine_rank $SLURM_NODEID \
+ --num_processes $num_processes \
+ --num_machines $SLURM_NNODES \
+ --dynamo_backend no \
+ --mixed_precision no \
+ --main_process_ip $MASTER_ADDR \
+ --main_process_port $MASTER_PORT \
+ main.py experiments='[large_scale_train,large_scale_train_high_res_interleaved]' nodes=2
\ No newline at end of file
diff --git a/third_party/1d-tokenizer b/third_party/1d-tokenizer
new file mode 160000
index 0000000000000000000000000000000000000000..bdec006fb7226b309aaf3511956d0aafac30125a
--- /dev/null
+++ b/third_party/1d-tokenizer
@@ -0,0 +1 @@
+Subproject commit bdec006fb7226b309aaf3511956d0aafac30125a
diff --git a/third_party/LlamaGen b/third_party/LlamaGen
new file mode 160000
index 0000000000000000000000000000000000000000..c4cc58b72677ec1d94039e87aaff8d41a3cec232
--- /dev/null
+++ b/third_party/LlamaGen
@@ -0,0 +1 @@
+Subproject commit c4cc58b72677ec1d94039e87aaff8d41a3cec232
diff --git a/third_party/Lumina-mGPT b/third_party/Lumina-mGPT
new file mode 160000
index 0000000000000000000000000000000000000000..5888993293d2ee85898186c6add05337542a939f
--- /dev/null
+++ b/third_party/Lumina-mGPT
@@ -0,0 +1 @@
+Subproject commit 5888993293d2ee85898186c6add05337542a939f
diff --git a/third_party/Show-o b/third_party/Show-o
new file mode 160000
index 0000000000000000000000000000000000000000..b00d96e420686cc39358c068abda78a8c28acefb
--- /dev/null
+++ b/third_party/Show-o
@@ -0,0 +1 @@
+Subproject commit b00d96e420686cc39358c068abda78a8c28acefb
diff --git a/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd82c62bf8c93ca0fee1e81430498dee95662eaa
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+__version__ = "1.4.0.dev0"
diff --git a/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..417f6edff330fcf6e81993c94f25dfb76006245b
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py
@@ -0,0 +1,105 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+from hydra.core.config_store import ConfigStore
+
+
+@dataclass
+class BaseQueueConf:
+ """Configuration shared by all executors"""
+
+ submitit_folder: str = "${hydra.sweep.dir}/.submitit/%j"
+ python: Optional[str] = None
+ python_suffix: Optional[str] = None
+
+ # maximum time for the job in minutes
+ timeout_min: int = 60
+ # number of cpus to use for each task
+ cpus_per_task: Optional[int] = None
+ # number of gpus to use on each node
+ gpus_per_node: Optional[int] = None
+ # number of tasks to spawn on each node
+ tasks_per_node: int = 1
+ # memory to reserve for the job on each node (in GB)
+ mem_gb: Optional[int] = None
+ # number of nodes to use for the job
+ nodes: int = 1
+ # name of the job
+ name: str = "${hydra.job.name}"
+ # redirect stderr to stdout
+ stderr_to_stdout: bool = False
+
+
+@dataclass
+class SlurmQueueConf(BaseQueueConf):
+ """Slurm configuration overrides and specific parameters"""
+
+ _target_: str = (
+ "hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher"
+ )
+
+ # Params are used to configure sbatch, for more info check:
+ # https://github.com/facebookincubator/submitit/blob/main/submitit/slurm/slurm.py
+
+ # Following parameters are slurm specific
+ # More information: https://slurm.schedmd.com/sbatch.html
+ #
+ # slurm partition to use on the cluster
+ partition: Optional[str] = None
+ qos: Optional[str] = None
+ comment: Optional[str] = None
+ constraint: Optional[str] = None
+ exclude: Optional[str] = None
+ gres: Optional[str] = None
+ cpus_per_gpu: Optional[int] = None
+ gpus_per_task: Optional[int] = None
+ mem_per_gpu: Optional[str] = None
+ mem_per_cpu: Optional[str] = None
+ account: Optional[str] = None
+
+ # Following parameters are submitit specifics
+ #
+ # USR1 signal delay before timeout
+ signal_delay_s: int = 120
+ # Maximum number of retries on job timeout.
+ # Change this only after you confirmed your code can handle re-submission
+ # by properly resuming from the latest stored checkpoint.
+ # check the following for more info on slurm_max_num_timeout
+ # https://github.com/facebookincubator/submitit/blob/main/docs/checkpointing.md
+ max_num_timeout: int = 0
+ # Useful to add parameters which are not currently available in the plugin.
+ # Eg: {"mail-user": "blublu@fb.com", "mail-type": "BEGIN"}
+ additional_parameters: Dict[str, Any] = field(default_factory=dict)
+ # Maximum number of jobs running in parallel
+ array_parallelism: int = 256
+ # A list of commands to run in sbatch befure running srun
+ setup: Optional[List[str]] = None
+ # Any additional arguments that should be passed to srun
+ srun_args: Optional[List[str]] = None
+ signal: Optional[str] = None
+ post_srun_commands: Optional[List[str]] = None
+
+
+@dataclass
+class LocalQueueConf(BaseQueueConf):
+ _target_: str = (
+ "hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher"
+ )
+
+
+# finally, register two different choices:
+ConfigStore.instance().store(
+ group="hydra/launcher",
+ name="submitit_local",
+ node=LocalQueueConf(),
+ provider="submitit_launcher",
+)
+
+
+ConfigStore.instance().store(
+ group="hydra/launcher",
+ name="submitit_slurm",
+ node=SlurmQueueConf(),
+ provider="submitit_launcher",
+)
diff --git a/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/py.typed b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..80c2f56cba3d715a0d378801aa8ddd17e0764864
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py
@@ -0,0 +1,201 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence
+
+from hydra.core.singleton import Singleton
+from hydra.core.utils import JobReturn, filter_overrides, run_job, setup_globals
+from hydra.plugins.launcher import Launcher
+from hydra.types import HydraContext, TaskFunction
+from omegaconf import DictConfig, OmegaConf, open_dict
+
+from .config import BaseQueueConf
+
+class DummyLogger:
+ def __init__(self):
+ pass
+
+ def debug(self, message):
+ print(f"DEBUG: {message}")
+
+ def info(self, message):
+ print(f"INFO: {message}")
+
+ def warning(self, message):
+ print(f"WARNING: {message}")
+
+ def error(self, message):
+ print(f"ERROR: {message}")
+
+ def critical(self, message):
+ print(f"CRITICAL: {message}")
+
+log = DummyLogger()
+
+def tail_log_file(log_file_path, glob_str=None):
+ import subprocess
+ import time
+ max_retries = 60
+ retry_interval = 4
+ for _ in range(max_retries):
+ try:
+ if (glob_str is None and Path(log_file_path).exists()) or len(list(Path(log_file_path).rglob(glob_str))) > 0:
+ try:
+ if glob_str is None:
+ print(f"Tailing {log_file_path}")
+ proc = subprocess.Popen(['tail', '-f', "-n", "+1", f"{log_file_path}"], stdout=subprocess.PIPE)
+ else:
+ print(['tail', '-f', "-n", "+1", f"{log_file_path}/{glob_str}"])
+ proc = subprocess.Popen(['sh', '-c', f'tail -f -n +1 {log_file_path}/{glob_str}'], stdout=subprocess.PIPE)
+ for line in iter(proc.stdout.readline, b''):
+ print(line.decode('utf-8'), end='')
+ except:
+ proc.terminate()
+ except:
+ print(f"Tried to glob: {log_file_path}, {glob_str}")
+ finally:
+ time.sleep(retry_interval)
+
+ print(f"File not found: {log_file_path} after {max_retries * retry_interval} seconds...")
+
+class BaseSubmititLauncher(Launcher):
+ _EXECUTOR = "abstract"
+
+ def __init__(self, **params: Any) -> None:
+ self.params = {}
+ for k, v in params.items():
+ if OmegaConf.is_config(v):
+ v = OmegaConf.to_container(v, resolve=True)
+ self.params[k] = v
+
+ self.config: Optional[DictConfig] = None
+ self.task_function: Optional[TaskFunction] = None
+ self.sweep_configs: Optional[TaskFunction] = None
+ self.hydra_context: Optional[HydraContext] = None
+
+ def setup(
+ self,
+ *,
+ hydra_context: HydraContext,
+ task_function: TaskFunction,
+ config: DictConfig,
+ ) -> None:
+ self.config = config
+ self.hydra_context = hydra_context
+ self.task_function = task_function
+
+ def __call__(
+ self,
+ sweep_overrides: List[str],
+ job_dir_key: str,
+ job_num: int,
+ job_id: str,
+ singleton_state: Dict[type, Singleton],
+ sweep_keys: Optional[List[str]] = None,
+ ) -> JobReturn:
+ # lazy import to ensure plugin discovery remains fast
+ import submitit
+
+ assert self.hydra_context is not None
+ assert self.config is not None
+ assert self.task_function is not None
+
+ Singleton.set_state(singleton_state)
+ setup_globals()
+ sweep_config = self.hydra_context.config_loader.load_sweep_config(
+ self.config, sweep_overrides
+ )
+
+ with open_dict(sweep_config.hydra.job) as job:
+ # Populate new job variables
+ job.id = submitit.JobEnvironment().job_id # type: ignore
+ sweep_config.hydra.job.num = job_num
+ sweep_config.hydra.job.sweep_keys = sweep_keys
+
+ return run_job(
+ hydra_context=self.hydra_context,
+ task_function=self.task_function,
+ config=sweep_config,
+ job_dir_key=job_dir_key,
+ job_subdir_key="hydra.sweep.subdir",
+ )
+
+ def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
+ """Resubmit the current callable at its current state with the same initial arguments."""
+ # lazy import to ensure plugin discovery remains fast
+ import submitit
+
+ return submitit.helpers.DelayedSubmission(self, *args, **kwargs)
+
+ def launch(
+ self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int, sweep_keys=None
+ ) -> Sequence[JobReturn]:
+ # lazy import to ensure plugin discovery remains fast
+ import submitit
+
+ assert self.config is not None
+
+ num_jobs = len(job_overrides)
+ assert num_jobs > 0
+ params = self.params
+ # build executor
+ init_params = {"folder": self.params["submitit_folder"]}
+ specific_init_keys = {"max_num_timeout", "python", "python_suffix"}
+
+ init_params.update(
+ **{
+ f"{self._EXECUTOR}_{x}": y
+ for x, y in params.items()
+ if x in specific_init_keys
+ }
+ )
+ init_keys = specific_init_keys | {"submitit_folder"}
+ executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params)
+
+ # specify resources/parameters
+ baseparams = set(OmegaConf.structured(BaseQueueConf).keys())
+ params = {
+ x if x in baseparams else f"{self._EXECUTOR}_{x}": y
+ for x, y in params.items()
+ if x not in init_keys
+ }
+ executor.update_parameters(**params)
+
+ log.info(
+ f"Submitit '{self._EXECUTOR}' sweep output dir : "
+ f"{self.config.hydra.sweep.dir}"
+ )
+ sweep_dir = Path(str(self.config.hydra.sweep.dir))
+ sweep_dir.mkdir(parents=True, exist_ok=True)
+ if "mode" in self.config.hydra.sweep:
+ mode = int(str(self.config.hydra.sweep.mode), 8)
+ os.chmod(sweep_dir, mode=mode)
+
+ job_params: List[Any] = []
+ for idx, overrides in enumerate(job_overrides):
+ idx = initial_job_idx + idx
+ lst = " ".join(filter_overrides(overrides))
+ log.info(f"\t#{idx} : {lst}")
+ job_params.append(
+ (
+ list(overrides),
+ "hydra.sweep.dir",
+ idx,
+ f"job_id_for_{idx}",
+ Singleton.get_state(),
+ [] if sweep_keys is None else list(sweep_keys),
+ )
+ )
+
+ jobs = executor.map_array(self, *zip(*job_params))
+ # tail_log_file(str(Path(jobs[0].paths.stdout).parent.parent), "**/*.out")
+ return [j.results()[0] for j in jobs]
+
+
+class LocalLauncher(BaseSubmititLauncher):
+ _EXECUTOR = "local"
+
+
+class SlurmLauncher(BaseSubmititLauncher):
+ _EXECUTOR = "slurm"
diff --git a/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/PKG-INFO b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..fbebeb6f962a966e2dcb0d715f857bc04e398d6d
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/PKG-INFO
@@ -0,0 +1,18 @@
+Metadata-Version: 2.1
+Name: hydra-submitit-launcher
+Version: 1.4.0.dev0
+Summary: Submitit Launcher for Hydra apps
+Home-page: https://github.com/facebookincubator/submitit
+Author: Jeremy Rapin, Jieru Hu, Omry Yadan
+Author-email: jrapin@fb.com, jieru@fb.com, omry@fb.com
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Operating System :: MacOS
+Classifier: Operating System :: POSIX :: Linux
+Classifier: Development Status :: 4 - Beta
+Description-Content-Type: text/markdown
+Requires-Dist: hydra-core>=1.1.0.dev7
+Requires-Dist: submitit>=1.3.3
diff --git a/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/SOURCES.txt b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a62639cf3149ad60bc71a9098dade6f1b2034843
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/SOURCES.txt
@@ -0,0 +1,11 @@
+pyproject.toml
+setup.py
+hydra_plugins/hydra_submitit_launcher/__init__.py
+hydra_plugins/hydra_submitit_launcher/config.py
+hydra_plugins/hydra_submitit_launcher/py.typed
+hydra_plugins/hydra_submitit_launcher/submitit_launcher.py
+hydra_submitit_launcher.egg-info/PKG-INFO
+hydra_submitit_launcher.egg-info/SOURCES.txt
+hydra_submitit_launcher.egg-info/dependency_links.txt
+hydra_submitit_launcher.egg-info/requires.txt
+hydra_submitit_launcher.egg-info/top_level.txt
\ No newline at end of file
diff --git a/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/dependency_links.txt b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/requires.txt b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4b58671f54c263bfc77b3e9fa9cf42bb406a8df4
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/requires.txt
@@ -0,0 +1,2 @@
+hydra-core>=1.1.0.dev7
+submitit>=1.3.3
diff --git a/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/top_level.txt b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..829831622910a4e787fc904bcb5e97908c88b7fc
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/hydra_submitit_launcher.egg-info/top_level.txt
@@ -0,0 +1 @@
+hydra_plugins
diff --git a/third_party/hydra_submitit_launcher/pyproject.toml b/third_party/hydra_submitit_launcher/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b4d6dc930bf8b9179b5866ea6731c285acd3b1ad
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools", "wheel", "read-version"]
+build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/third_party/hydra_submitit_launcher/setup.py b/third_party/hydra_submitit_launcher/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..448ebd4a50dc71831a2d9e2e524f9f7c2b62cc28
--- /dev/null
+++ b/third_party/hydra_submitit_launcher/setup.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# type: ignore
+from pathlib import Path
+
+from read_version import read_version
+from setuptools import find_namespace_packages, setup
+
+setup(
+ name="hydra-submitit-launcher",
+ version=read_version("hydra_plugins/hydra_submitit_launcher", "__init__.py"),
+ author="Jeremy Rapin, Jieru Hu, Omry Yadan",
+ author_email="jrapin@fb.com, jieru@fb.com, omry@fb.com",
+ description="Submitit Launcher for Hydra apps",
+ long_description_content_type="text/markdown",
+ url="https://github.com/facebookincubator/submitit",
+ packages=find_namespace_packages(include=["hydra_plugins.*"]),
+ classifiers=[
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Operating System :: MacOS",
+ "Operating System :: POSIX :: Linux",
+ "Development Status :: 4 - Beta",
+ ],
+ install_requires=[
+ "hydra-core>=1.1.0.dev7",
+ "submitit>=1.3.3",
+ ],
+ include_package_data=True,
+)
diff --git a/unidisc/datasets/README.md b/unidisc/datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bb335dbedb86c6550fbbf856c7e5d8e4d73cebc8
--- /dev/null
+++ b/unidisc/datasets/README.md
@@ -0,0 +1,31 @@
+Broadly speaking, we have a few types of datasets:
+
+1. WebDataset (preferred)
+
+These are provided e.g., by [img2dataset](https://github.com/rom1504/img2dataset) and are in the standardized [WebDataset format](https://github.com/webdataset/webdataset) consisting of a collection of `tar` files. We generally use a modified dataloader to use an Indexed WebDataset, making things like pre-processing easier.
+
+2. Huggingface datasets
+
+These are loaded from HF in `dataloader.py`, typically for text datasets or smaller datasets (e.g., for evaluation).
+
+3. Tokenized TensorDict datasets
+
+This is how most training is done to avoid the overhead of VQ-VAE tokenization during training. The data is stored as integers on disk in a [TensorDict](https://github.com/pytorch/tensordict) container, and possibly loaded into memory during training (either in the dataloader process space or `/dev/shm`).
+
+
+
+# Synthetic generation
+
+Generating synthetic text/image pairs has been cited in many T2I papers as critical for efficient training. Unfortunately, almost all of these datasets are proprietary, so we opt to create our own.
+
+We first combine text captions from the following sources:
+
+[HPSv2](https://github.com/tgxs002/HPSv2)
+[ImageReward](https://github.com/THUDM/ImageReward)
+[PickScore](https://github.com/yuvalkirstain/PickScore)
+[simulacra-aesthetic-captions](https://github.com/JD-P/simulacra-aesthetic-captions/tree/main)
+[gecko_benchmark_t2i](https://github.com/google-deepmind/gecko_benchmark_t2i) (For evaluation)
+
+To further diversify, we use prompt an LLM to use these as inspiration for new captions, giving it a list of random entities from wordnet to incorporate into the caption. Finally, we use Stable Diffusion 3.5 medium to generate 512x512 images for each caption.
+
+We make this process fully distributed by having each job take unused (or less commonly used) captions, and generate images in small batches.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/cambrian/README.md b/unidisc/datasets/preprocessing/cambrian/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..610efa94d50019ebaadeaabe516e6ab630f698e5
--- /dev/null
+++ b/unidisc/datasets/preprocessing/cambrian/README.md
@@ -0,0 +1,33 @@
+# Cambrian
+
+First, download the Cambrian dataset from [here](https://huggingface.co/datasets/cambrian/cambrian-10m).
+
+```bash
+huggingface-cli download nyu-visionx/Cambrian-10M --local-dir . --repo-type dataset --resume-download
+```
+
+Next, untar all the *.tar.gz files.
+
+Next, precompute the tokens:
+
+**_Note:_** If you are on a SLURM cluster, you can replace `accelerate launch` with:
+
+```bash
+sbatch --time=2-00:00:00 --array=0-100%25 --cpus-per-gpu=12 --mem-per-gpu=100G --nodes=1 --gpus-per-node=1 --partition=preempt --job-name=cambrian_precompute_tokens scripts/precompute_tokens_slurm.sh
+```
+
+**_Note:_** If you want to only generate a subset of the tokens, append e.g., `data.n_train_samples=200` to the command.
+
+Finally, to tokenize the dataset, run:
+
+```bash
+accelerate launch models/datasets/precompute_tokens.py +experiments='[generated_images,tokenize,vq16_t2i]' data.token_output_dir="/path/to/token_output_dir" data.resolution=512 data.use_chameleon=false model.img_length=3072 data.block_size=3072 loader.batch_size=16 data.train='cambrian' data.raw_data_dir='/path/to/cambrian/jsons/Cambrian10M.jsonl' +model.text_vocab_size=32001 data.img_token_shift=32001 +data.use_identity_collate=true loader.num_workers=2 data.split_dataset=true +data.save_tmp_interval=3600 +data.use_slow_tokenizer=true +data.add_image_token=true
+```
+
+Now that the tokenization is complete, if it was done over multiple GPUs/nodes, you must combine the tensordicts on disk. If the to
+
+```bash
+python models/datasets/combine_token_dicts.py "/path/to/token_output_dir" --move_files --delete_after_combining --mem_efficient
+```
+
+**_Note:_** You may wish to add the `--allow_tmp` flag to the command if the tokenization was only partially completed (e.g., due to a SLURM job being preempted). In this case, the tokenization saves intermediate checkpoints with a `tmp_` prefix.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/capsfusion/README.md b/unidisc/datasets/preprocessing/capsfusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..55106a6ce59bfa52025684982ea7c792bf9193d2
--- /dev/null
+++ b/unidisc/datasets/preprocessing/capsfusion/README.md
@@ -0,0 +1,31 @@
+
+# CapsFusion
+
+First, download the parquet files from [here](https://huggingface.co/datasets/BAAI/CapsFusion-120M/tree/main).
+
+```bash
+huggingface-cli download BAAI/CapsFusion-120M --local-dir . --repo-type dataset --resume-download
+```
+
+Next, download the images using `img2dataset`.
+
+```bash
+URL_DIR=/path/to/capsfusion/parquet
+RAW_IMG_DIR=/path/to/capsfusion/wds
+
+mkdir -p $RAW_IMG_DIR
+img2dataset \
+ --input_format=parquet \
+ --url_list=$URL_DIR \
+ --output_folder=$RAW_IMG_DIR \
+ --processes_count=32 \
+ --image_size=512 \
+ --resize_mode=keep_ratio \
+ --resize_only_if_bigger=True \
+ --output_format=webdataset \
+ --url_col=image_url \
+ --caption_col=capsfusion \
+ --enable_wandb=True 2>&1 | tee -a caps_fusion_img_download.log
+```
+
+Please see the [WebDataset](../webdataset.md) for more information on how to further process and then tokenize the WebDataset.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/cc12m/README.md b/unidisc/datasets/preprocessing/cc12m/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b2facbc4e15167f9e0ee0c781cf624c34c817b4
--- /dev/null
+++ b/unidisc/datasets/preprocessing/cc12m/README.md
@@ -0,0 +1,12 @@
+# CC12M
+
+We used the very helpful [cc12m-wds](https://huggingface.co/datasets/pixparse/cc12m-wds) dataset to avoid having to download the images from the original source.
+
+## Downloading the dataset
+
+```bash
+huggingface-cli download pixparse/cc12m-wds --local-dir . --repo-type datasetd
+huggingface-cli download pixparse/cc3m-wds --local-dir . --repo-type dataset
+```
+
+widsindex create *train*.tar
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/datacomp_1b/README.md b/unidisc/datasets/preprocessing/datacomp_1b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..51e182d332121648c6d669f837a32e9afe882be3
--- /dev/null
+++ b/unidisc/datasets/preprocessing/datacomp_1b/README.md
@@ -0,0 +1,27 @@
+# Datacomp 1b
+
+We use ReComp DataComp1B for a set of re-captioned, high-quality image/text pairs.
+
+Please download the metadata from [here](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B):
+
+```bash
+huggingface-cli download UCSC-VLAA/Recap-DataComp-1B --repo-type dataset --local-dir .
+```
+
+Then optionally split the parquet files into smaller chunks:
+
+```bash
+python split_parquet.py
+```
+
+Then download the actual images into a WebDataset format using [img2dataset](https://github.com/rom1504/img2dataset). Change the resolution, number of processes, number of thread, and input/output folders as needed.
+
+```bash
+input_folder='/path/to/recap-datacomp-1b/data/train_data_split/split_0'
+img2dataset --url_list "$input_folder" --input_format "parquet" \
+--url_col "url" --caption_col "re_caption" --output_format webdataset \
+--output_folder recap_datacomp_1b_data --processes_count 16 --thread_count 128 \
+--save_additional_columns '["org_caption"]' --enable_wandb True --image_size 256 --output_folder "/scratch/data/datacomp_1b_${input_folder##*/}" --resize_mode center_crop
+```
+
+Please see the [WebDataset](../webdataset.md) for more information on how to further process and then tokenize the WebDataset.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/datacomp_1b/split_parquet.py b/unidisc/datasets/preprocessing/datacomp_1b/split_parquet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d30f22915b4bc85ee85c730136bab3d44bd2c6c
--- /dev/null
+++ b/unidisc/datasets/preprocessing/datacomp_1b/split_parquet.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+import shutil
+import random
+
+def split_parquet_files(input_folder, output_folder, max_files_per_folder=100):
+ input_folder = Path(input_folder)
+ output_folder = Path(output_folder)
+
+ parquet_files = list(input_folder.glob('*.parquet'))
+
+ random.shuffle(parquet_files)
+ output_folder.mkdir(parents=True, exist_ok=True)
+
+ for i in range(0, len(parquet_files), max_files_per_folder):
+ subfolder_name = f"subfolder_{i // max_files_per_folder + 1}"
+ subfolder_path = output_folder / subfolder_name
+
+ subfolder_path.mkdir(parents=True, exist_ok=True)
+
+ for file in parquet_files[i:i + max_files_per_folder]:
+ shutil.move(str(file), str(subfolder_path / file.name))
+
+# Example usage
+input_folder = '/path/to/recap-datacomp-1b/data/train_data'
+output_folder = '/path/to/recap-datacomp-1b/data/train_data_split'
+split_parquet_files(input_folder, output_folder)
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/journeydb/README.md b/unidisc/datasets/preprocessing/journeydb/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..515dbaacd4a25fd3ee0e9d521099c1cb2b98792f
--- /dev/null
+++ b/unidisc/datasets/preprocessing/journeydb/README.md
@@ -0,0 +1,16 @@
+# JourneyDB
+
+To download the dataset, run:
+
+```bash
+huggingface-cli download JourneyDB/JourneyDB --repo-type dataset --local-dir . --include "data/train/train_anno.jsonl.tgz"
+huggingface-cli download JourneyDB/JourneyDB --repo-type dataset --local-dir . --include "data/train/train_anno_realease_repath.jsonl.tgz"
+```
+
+To convert the dataset to a WebDataset, run:
+
+```bash
+python unidisc/datasets/preprocessing/journeydb/create_wds.py
+```
+
+Please see the [WebDataset](../webdataset.md) for more information on how to further process and then tokenize the WebDataset.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/journeydb/create_wds.py b/unidisc/datasets/preprocessing/journeydb/create_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f25fbdb790526a10985d494805bc6685f434943
--- /dev/null
+++ b/unidisc/datasets/preprocessing/journeydb/create_wds.py
@@ -0,0 +1,164 @@
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from contextlib import nullcontext
+import glob
+import json
+import os
+import subprocess
+import tarfile
+from pathlib import Path
+from viztracer import VizTracer
+
+import webdataset as wds
+from tqdm import tqdm
+import tempfile
+import shutil
+from io import BytesIO
+from PIL import Image
+import socket
+
+prefix_path = Path("/scratch")
+
+output_tar_dir = prefix_path / Path("journeydb_wds/train")
+output_tar_dir.mkdir(parents=True, exist_ok=True)
+output_tar_path = str(output_tar_dir) + f"/output_dataset_%06d.tar"
+input_tgz_path = prefix_path / Path("journeydb/data/train/train_anno_realease_repath.jsonl.tgz")
+tgz_dir = prefix_path / Path("journeydb/data/train/imgs")
+
+jsonl_data = []
+print(f"Opening {input_tgz_path}")
+with tarfile.open(input_tgz_path, "r:gz") as tar:
+ for member in tar.getmembers():
+ if member.isfile() and member.name.endswith('.jsonl'):
+ f = tar.extractfile(member)
+ for line in f:
+ jsonl_data.append(json.loads(line))
+
+tgz_files = list(tgz_dir.glob("*.tgz"))
+prefix_to_tgz = {tgz_file.stem: tgz_file for tgz_file in tgz_files}
+cached_tgz_files = {prefix: tarfile.open(tgz_file, "r:gz") for prefix, tgz_file in prefix_to_tgz.items()}
+
+print(f"Extracted {len(jsonl_data)} samples")
+prefix_to_samples = {}
+for sample in jsonl_data:
+ img_path = sample["img_path"].removeprefix("./")
+ prefix = img_path.split('/')[0]
+ if prefix not in prefix_to_samples:
+ prefix_to_samples[prefix] = []
+ prefix_to_samples[prefix].append(sample)
+
+print(f"Prefix to samples: {len(prefix_to_samples)}")
+profile = False
+max_samples = 3 if profile else None
+
+mem_path = Path('/dev/shm/aswerdlo')
+mem_path.mkdir(parents=True, exist_ok=True)
+resolution = 1024
+
+def process_prefix(samples, tgz_file_path, output_tar_path, mem_path, max_samples, worker_id):
+ with tarfile.open(tgz_file_path, "r:gz") as tgz:
+ tmpdirname = tempfile.mkdtemp(dir=mem_path)
+ try:
+ print(f"Extracting {tgz_file_path} to {tmpdirname}")
+ tgz.extractall(path=tmpdirname)
+ output_path = output_tar_path.removesuffix('.tar') + f"_{worker_id}.tar"
+ print(f"Extracted {tgz_file_path} to {tmpdirname}, writing to {output_path}")
+ with wds.ShardWriter(output_path, maxsize=500*1024*1024) as sink:
+ for idx, sample in tqdm(enumerate(samples), desc=f'Worker {worker_id}', total=len(samples)):
+ if max_samples is not None and idx >= max_samples:
+ break
+
+ if idx == 0 or idx % 1000 == 0:
+ print(f"Worker {worker_id} processed {idx} samples")
+
+ img_path = sample["img_path"].removeprefix("./")
+ file_path = os.path.join(tmpdirname, img_path)
+ if os.path.exists(file_path):
+ try:
+ img = Image.open(file_path)
+ width, height = img.size
+ if width > height:
+ left = (width - height) / 2
+ top = 0
+ right = (width + height) / 2
+ bottom = height
+ else:
+ left = 0
+ top = (height - width) / 2
+ right = width
+ bottom = (height + width) / 2
+ img = img.crop((left, top, right, bottom))
+ img = img.resize((resolution, resolution), Image.LANCZOS)
+
+ img_byte_arr = BytesIO()
+ img.save(img_byte_arr, format='JPEG', quality=95)
+ img_data = img_byte_arr.getvalue()
+
+ key = Path(img_path).stem
+ sample_dict = {
+ "__key__": key,
+ "txt": sample['Task2']["Caption"],
+ "jpg": img_data
+ }
+ sink.write(sample_dict)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ except Exception as e:
+ print(f"Skipping bad sample {file_path}: {e}")
+ finally:
+ shutil.rmtree(tmpdirname)
+
+with VizTracer(output_file="result2.json") if profile else nullcontext():
+ with ProcessPoolExecutor(max_workers=8) as executor:
+ futures = []
+ for worker_id, (prefix, samples) in enumerate(tqdm(sorted(prefix_to_samples.items()))):
+ if prefix not in prefix_to_tgz:
+ print(f"Prefix {prefix} not found in tgz files")
+ continue
+ tgz_file_path = prefix_to_tgz[prefix]
+ futures.append(executor.submit(process_prefix, samples, tgz_file_path, output_tar_path, mem_path, max_samples, worker_id))
+
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ print(f"Error processing prefix: {e}")
+
+# Rename the tar files to remove the worker_id prefix
+for worker_id in range(len(prefix_to_samples)):
+ for tar_file in glob.glob(str(output_tar_dir / f"output_dataset_{worker_id:06d}_*.tar")):
+ new_name = tar_file.replace(f"_{worker_id:06d}_", "_")
+ counter = 0
+ while os.path.exists(new_name):
+ counter += 1
+ new_name = new_name.replace(".tar", f"_{counter:06d}.tar")
+ os.rename(tar_file, new_name)
+
+for tgz in cached_tgz_files.values():
+ tgz.close()
+
+tar_files = glob.glob(str(output_tar_dir / "*.tar"))
+
+if tar_files:
+ os.chdir(output_tar_dir)
+ command = ["widsindex", "create"] + [os.path.basename(f) for f in tar_files]
+ try:
+ result = subprocess.run(command, check=True, capture_output=True, text=True)
+ print("widsindex command executed successfully.")
+ print("Output:", result.stdout)
+ except subprocess.CalledProcessError as e:
+ print(f"Error running widsindex command: {e.returncode}\n{e.stderr}")
+ os.chdir(Path(__file__).parent)
+else:
+ print("No tar files found in the output directory.")
+
+print("Testing WebDataset reading:")
+
+dataset = wds.WebDataset(tar_files).decode("rgb")
+for i, sample in enumerate(dataset):
+ if i >= 5: # Print details for the first 5 samples
+ break
+ print(f"Sample {i + 1}:")
+ print(f"Key: {sample['__key__']}")
+ print(f"Image size: {sample['jpg'].size}")
+ print(f"Text: {sample['txt']}")
+ print()
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/mmc4/concat_parquet.py b/unidisc/datasets/preprocessing/mmc4/concat_parquet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22b3a82a52efe442936947cc8f09c275ad5160a
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/concat_parquet.py
@@ -0,0 +1,52 @@
+import argparse
+from pathlib import Path
+import pandas as pd
+from tqdm import tqdm
+import glob
+
+def load_and_concatenate_parquet_files(file_patterns, relative_to=None):
+ # Initialize an empty list to hold dataframes
+ dataframes = []
+
+ # Iterate through all file patterns
+ for pattern in file_patterns:
+ # Expand the glob pattern
+ parquet_files = glob.glob(pattern)
+
+ for file in tqdm(parquet_files):
+ # Load the parquet file
+ try:
+ df = pd.read_parquet(file)
+ except Exception as e:
+ print(f"Error loading {file}: {e}")
+ continue
+
+ # Extract shard id from the file name
+ shard_id = Path(file).stem
+
+ # Add the shard id as a new column
+ df['img2dataset_shard_id'] = shard_id
+ df['tar_filepath'] = str(Path(file).with_suffix(".tar"))
+ if args.relative_to:
+ df['relative_tar_filepath'] = str(Path(file).with_suffix(".tar").relative_to(args.relative_to))
+
+ # Append the dataframe to the list
+ dataframes.append(df)
+
+ # Concatenate all dataframes
+ concatenated_df = pd.concat(dataframes, ignore_index=True)
+
+ return concatenated_df
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Concatenate Parquet files from multiple paths or glob patterns")
+ parser.add_argument("file_patterns", nargs='+', type=str, help="Paths or glob patterns for Parquet files")
+ parser.add_argument("-o", "--output", type=str, default="concatenated_mmc4.parquet", help="Output file name")
+ parser.add_argument("--relative-to", type=str, help="Base path to make tar_filepath relative to (optional)")
+ args = parser.parse_args()
+
+ concatenated_df = load_and_concatenate_parquet_files(args.file_patterns, args.relative_to)
+ output_file = Path(args.output)
+ concatenated_df.to_parquet(output_file, index=False)
+ print(f"Concatenated data saved to {output_file}")
+ print(f"Total rows: {len(concatenated_df)}")
diff --git a/unidisc/datasets/preprocessing/mmc4/download_shards.sh b/unidisc/datasets/preprocessing/mmc4/download_shards.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1077828a438e17e6c9bc88f299aeb9ad245dd927
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/download_shards.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Check if the destination folder argument is provided
+if [ $# -eq 0 ]; then
+ echo "Please provide the destination folder as an argument."
+ echo "Usage: ./download_and_unzip.sh /path/to/destination/folder"
+ exit 1
+fi
+
+# Set the download URL base
+URL_BASE="https://storage.googleapis.com/ai2-jackh-mmc4-public/data_v1.1/docs_no_face_shard_"
+
+# Smaller core shards
+# URL_BASE="https://storage.googleapis.com/ai2-jackh-mmc4-public/data_core_v1.1/docs_no_face_shard_"
+
+# Set the folder where you want to save the unzipped files
+DESTINATION_FOLDER="$1"
+
+# Create the destination folder if it doesn't exist
+mkdir -p "$DESTINATION_FOLDER"
+
+# Loop through the shard numbers and download and unzip the files, max: 23098
+for SHARD in {0..23098}; do
+ URL="${URL_BASE}${SHARD}_v2.jsonl.zip"
+ ZIP_FILE="${DESTINATION_FOLDER}/shard_${SHARD}.zip"
+ echo "Downloading shard $SHARD from $URL..."
+
+ # Download the file (continue if the file is missing or there is an error)
+ curl -fsSL --retry 3 --retry-delay 5 --max-time 20 --continue-at - "$URL" -o "$ZIP_FILE" || echo "Error downloading shard $SHARD, continuing..."
+
+ # Unzip the file if it was downloaded successfully
+ if [ -f "$ZIP_FILE" ]; then
+ echo "Unzipping $ZIP_FILE to $DESTINATION_FOLDER..."
+ yes | unzip -q "$ZIP_FILE" -d "$DESTINATION_FOLDER"
+
+ # Remove the zip file after unzipping
+ rm "$ZIP_FILE"
+ fi
+
+done
+
+echo "Download and unzip process completed."
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/mmc4/get_urls.py b/unidisc/datasets/preprocessing/mmc4/get_urls.py
new file mode 100644
index 0000000000000000000000000000000000000000..754fa5a12f2abc36b6c79d68c2de0f4b45ea677a
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/get_urls.py
@@ -0,0 +1,40 @@
+from multiprocessing import Pool
+import tqdm
+import argparse
+import json
+import os
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--input_jsonl', type=str, default=None, help='Local path to the input jsonl file')
+parser.add_argument("--output_dir", type=str)
+args = parser.parse_args()
+
+def gather_image_info(input_jsonl):
+ """Gather image info from the input jsonl"""
+ # data = []
+ output_urls = []
+ with open(input_jsonl) as f:
+ for line in tqdm.tqdm(f):
+ info = json.loads(line.strip())
+ for img_item in info['image_info']:
+ # data.append({
+ # 'local_identifier': img_item['image_name'],
+ # 'url': img_item['raw_url'],
+ # })
+ output_urls.append(img_item['raw_url'])
+ # return data
+ return output_urls
+
+filename = os.path.basename(args.input_jsonl)
+output_filepath = os.path.join(args.output_dir, filename.replace('.jsonl', '_urls.txt'))
+print(f'Reading from {args.input_jsonl} and writing to {output_filepath}')
+# data = gather_image_info(args)
+urls = gather_image_info(args.input_jsonl)
+# with open(output_filepath, 'w') as f:
+# for url in urls:
+# f.write(url + '\n')
+# save to parquet
+import pandas as pd
+df = pd.DataFrame({'url': urls})
+df.to_parquet(output_filepath.replace('.txt', '.parquet'))
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/mmc4/process_mmc4.py b/unidisc/datasets/preprocessing/mmc4/process_mmc4.py
new file mode 100644
index 0000000000000000000000000000000000000000..58235b4b3b5cd1b8a861a3ff25fdecdb82d20b20
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/process_mmc4.py
@@ -0,0 +1,183 @@
+import math
+import gzip
+import random
+import tarfile
+import argparse
+import itertools
+import ujson as json
+import pandas as pd
+import multiprocessing as mp
+
+from glob import glob
+from tqdm import tqdm
+from typing import Optional, List, Dict, Tuple
+from collections import defaultdict
+import base64
+
+def load_image_bytes_to_base64(image_bytes: bytes) -> str:
+ # convert image to jpeg, then to data:image/jpeg;base64,
+ encoded_string = base64.b64encode(image_bytes).decode('utf-8')
+ return f"data:image/jpeg;base64,{encoded_string}"
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input-mapping-parquet', type=str, default="data/raw/mmc4/data/images_no_face_v3.selected.parquet")
+parser.add_argument('--input-docs-glob', type=str, default="data/raw/mmc4/data/docs_no_face_v3/*.jsonl")
+parser.add_argument('--input-images-dir', type=str, default="data/raw/mmc4/data/images_no_face_v3")
+parser.add_argument('--output-filepath', type=str, help='Output file', default='data/processed/mmc4.shard_{shard_id:03d}.jsonl.gz')
+parser.add_argument('--n-workers', type=int, default=4)
+parser.add_argument('--chunk-size', type=int, default=32)
+parser.add_argument('--n-output-shards', type=int, default=128)
+parser.add_argument('--before-ratio', type=float, default=1.0) # default to always insert before text
+parser.add_argument('--seed', type=int, default=42)
+parser.add_argument('--remove-instances-missing-images', action="store_true")
+args = parser.parse_args()
+
+# Set seed
+random.seed(args.seed)
+
+# Load the mapping
+mapping = pd.read_parquet(args.input_mapping_parquet)
+mapping = mapping[["img2dataset_shard_id", "key", "url"]]
+mapping["tar_filepath"] = mapping["img2dataset_shard_id"].apply(lambda x: f"{args.input_images_dir}/{x}.tar")
+mapping = mapping[['url', 'tar_filepath', 'key']]
+mapping = mapping.set_index("url")
+
+def load_image(tar_filepath, key) -> bytes:
+ with tarfile.open(tar_filepath) as tar:
+ with tar.extractfile(f"{key}.jpg") as f:
+ return f.read()
+
+def convert_data_instance(jsonl_row: str) -> Optional[Tuple[List[Dict], Dict]]:
+ """
+ # Example
+ {'image_info': [{'face_detections': None,
+ 'image_name': 'b9040a0dbb22.jpg',
+ 'matched_sim': 0.27694183588027954,
+ 'matched_text_index': 2,
+ 'raw_url': 'http://www.hfitinfo.com/honda_fit_pics/3/2/index.90.jpg'},
+ {'face_detections': None,
+ 'image_name': 'db1c21bc8474.jpg',
+ 'matched_sim': 0.3234919607639313,
+ 'matched_text_index': 1,
+ 'raw_url': 'http://www.hfitinfo.com/honda_fit_pics/3/2/index.91.jpg'}],
+ 'similarity_matrix': [[0.24363446235656738,
+ 0.31758785247802734,
+ 0.27694183588027954],
+ [0.2233106791973114,
+ 0.3234919607639313,
+ 0.26118797063827515]],
+ 'text_list': ['When you lock the door using the lock tab on the driver’s '
+ 'door, all of the other doors and tailgate lock at the same '
+ 'time.',
+ 'Press the master door lock switch in as shown to lock or '
+ 'unlock all doors and the tailgate.',
+ 'When you lock/unlock the driver’s door and tailgate using the '
+ 'master lock switch, all the other doors lock/ unlock at the '
+ 'same time.'],
+ 'url': 'http://www.hfitinfo.com/hofi-48.html',
+ 'could_have_url_duplicate': 0 }
+ """
+ stat_counter = defaultdict(int)
+ text_list = jsonl_row["text_list"]
+ images_insert_before_text = [ [] for _ in range(len(text_list)) ]
+ images_insert_after_text = [ [] for _ in range(len(text_list)) ]
+
+ for image_info in jsonl_row["image_info"]:
+ # randomly decide whether to prepend or append the image to the corresponding text
+ insert_before = random.random() < args.before_ratio
+ try:
+ # print('image_info', image_info)
+ mapped_to = mapping.loc[image_info["raw_url"]]
+ tar_filepath = mapped_to["tar_filepath"]
+ key = mapped_to["key"]
+ except KeyError as e:
+ print('e', e)
+ if args.remove_instances_missing_images:
+ stat_counter["instance_skipped_due_to_missing_image"] += 1
+ return None # skip this instance
+ else:
+ stat_counter["n_missing_images"] += 1
+ continue # skip this image
+
+ print('tar_filepath', tar_filepath)
+
+ # Process image
+ image_bytes = load_image(tar_filepath, key)
+ image_base64 = load_image_bytes_to_base64(image_bytes)
+ image_content = {
+ "type": "image_url",
+ "image_url": {"url": image_base64}
+ }
+
+ stat_counter["n_images_inserted"] += 1
+
+ if insert_before:
+ stat_counter["n_images_inserted_before_text"] += 1
+ images_insert_before_text[image_info["matched_text_index"]].append(image_content)
+ else:
+ stat_counter["n_images_inserted_after_text"] += 1
+ images_insert_after_text[image_info["matched_text_index"]].append(image_content)
+
+ # flatten content: list of list of content -> list of content
+ content = []
+ for i, text in enumerate(text_list):
+ content.extend(images_insert_before_text[i])
+ content.append({"type": "text", "text": text})
+ content.extend(images_insert_after_text[i])
+
+ print(f"Saved a total of {stat_counter['n_images_inserted']} images")
+
+ return [
+ {
+ # since we are doing pre-training, we just set
+ # the role to assistant for all instances
+ # (this is required for training pipeline)
+ "role": "assistant",
+ "content": content
+ }
+ ], stat_counter
+
+
+# Load the docs
+docs_filepaths = glob(args.input_docs_glob)
+assert len(docs_filepaths) == 23085
+n_files_per_shard = math.ceil(len(docs_filepaths) / args.n_output_shards)
+
+pbar = tqdm(total=len(docs_filepaths))
+
+def jsonl_generator_fn(filepath):
+ with open(filepath) as f:
+ for line in f:
+ yield json.loads(line)
+ pbar.update(1)
+
+stats_counter = defaultdict(int)
+for shard_id in range(args.n_output_shards):
+ start = shard_id * n_files_per_shard
+ end = min((shard_id + 1) * n_files_per_shard, len(docs_filepaths))
+ pbar.set_description(f"Processing Shard {shard_id}: {start}-{end}")
+
+ # Build generator for parallel processing
+ jsonl_generator = itertools.chain.from_iterable(
+ map(jsonl_generator_fn, docs_filepaths[start:end])
+ )
+
+ # Process the data
+ # , initializer=initializer, initargs=(args,)
+ with mp.Pool(args.n_workers) as pool, \
+ gzip.open(args.output_filepath.format(shard_id=shard_id), "wt") as fout:
+ instances_generator = pool.imap(convert_data_instance, jsonl_generator, args.chunk_size)
+ # instances_generator = map(convert_data_instance, jsonl_generator)
+ for content in instances_generator:
+ if content is not None:
+ (instances, cur_stats_counter) = content
+ fout.write(json.dumps(instances) + "\n")
+ # add stats
+ for k, v in cur_stats_counter.items():
+ stats_counter[k] += v
+ pbar.set_postfix(stats_counter)
+
+pbar.close()
+
+with open(args.output_filepath.replace(".shard_{shard_id:03d}.jsonl.gz", ".stats.json"), "w") as f:
+ f.write(json.dumps(stats_counter, indent=4))
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/mmc4/read_wds.py b/unidisc/datasets/preprocessing/mmc4/read_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac9ec65057ebfdf8980ae8f9f58506f63d5e6e6
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/read_wds.py
@@ -0,0 +1,12 @@
+import webdataset as wds
+import braceexpand
+from torch.utils.data import IterableDataset
+from webdataset import gopen
+from itertools import islice
+dataset = wds.WebDataset("/home/aswerdlo/hdd/data/diffusion/mmc4/core/img2dataset_imgs/07878.tar")
+
+for sample in islice(dataset, 0, 3):
+ for key, value in sample.items():
+ print(key, repr(value)[:50])
+ print()
+ breakpoint()
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/mmc4/trim_jsonl.py b/unidisc/datasets/preprocessing/mmc4/trim_jsonl.py
new file mode 100644
index 0000000000000000000000000000000000000000..e90226cbdde6f60e5bd3ca7c4401a4bed140e611
--- /dev/null
+++ b/unidisc/datasets/preprocessing/mmc4/trim_jsonl.py
@@ -0,0 +1,111 @@
+import pandas as pd
+from pathlib import Path
+import tarfile
+import glob
+import pickle
+from tqdm import tqdm
+import json
+import multiprocessing as mp
+
+mmc4_mapping_parquet = "/path/mmc4/fewer_faces/concatenated_mmc4.parquet"
+
+mapping = pd.read_parquet(mmc4_mapping_parquet)
+print("Finished loading mapping")
+mapping = mapping[['url', 'tar_filepath', 'key']]
+mapping = mapping.set_index("url")
+mapping = mapping.sort_values('tar_filepath')
+print("Finished sorting mapping")
+
+def process_tar_file(tar_filepath):
+ try:
+ with tarfile.open(tar_filepath) as tar:
+ return tar_filepath, set(tar.getnames())
+ except:
+ return tar_filepath, set()
+
+def get_cache():
+ from constants import UNIDISC_DIR
+
+ cache_path = UNIDISC_DIR / "archive" / "tar_contents_cache.pkl"
+ if cache_path.exists():
+ with open(cache_path, 'rb') as f:
+ return pickle.load(f)
+
+ _tar_contents_cache = {}
+ unique_tar_filepaths = mapping['tar_filepath'].unique()
+
+ # Use all available CPU cores
+ with mp.Pool() as pool:
+ results = list(tqdm(
+ pool.imap(process_tar_file, unique_tar_filepaths),
+ total=len(unique_tar_filepaths),
+ desc="Building tar contents cache"
+ ))
+
+ # Convert results to dictionary
+ _tar_contents_cache = dict(results)
+
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(cache_path, 'wb') as f:
+ pickle.dump(_tar_contents_cache, f)
+
+ return _tar_contents_cache
+
+_tar_contents_cache = get_cache()
+jsonl_files = glob.glob("/path/mmc4/fewer_faces/shards/*.jsonl")
+output_dir = Path("/path/mmc4/fewer_faces/filtered_shards")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+data_items = []
+for file in tqdm(jsonl_files, desc="Reading JSONL files"):
+ with open(file, 'r') as f:
+ for line in tqdm(f, desc=f"Processing {file}"):
+ data = json.loads(line)
+ has_valid_image = False
+ tar_filepath = None # Will hold the tar_filepath of the first valid image
+
+ for image_info in data["image_info"]:
+ try:
+ mapped_to_ = mapping.loc[image_info["raw_url"]]
+ if isinstance(mapped_to_, pd.Series):
+ mapped_to_ = [mapped_to_]
+ elif isinstance(mapped_to_, pd.DataFrame):
+ mapped_to_ = [row for _, row in mapped_to_.iterrows()]
+ else:
+ mapped_to_ = [mapped_to_]
+ for mapped_to in mapped_to_:
+ tar_filepath_candidate = mapped_to["tar_filepath"]
+ relative_tar_filepath_candidate = None
+ if "relative_tar_filepath" in mapped_to:
+ relative_tar_filepath_candidate = mapped_to["relative_tar_filepath"]
+
+ key = mapped_to["key"]
+ if f"{key}.jpg" in _tar_contents_cache[tar_filepath_candidate]:
+ has_valid_image = True
+ tar_filepath = tar_filepath_candidate
+ image_info["tar_filepath"] = tar_filepath
+ if relative_tar_filepath_candidate is not None:
+ image_info["relative_tar_filepath"] = relative_tar_filepath_candidate
+ image_info["key"] = key
+ break
+ if has_valid_image:
+ break
+ except KeyError:
+ continue
+
+ if has_valid_image and tar_filepath is not None:
+ data_items.append((tar_filepath, line))
+
+# Sort data_items by tar_filepath
+data_items.sort(key=lambda x: x[0])
+chunk_size = len(data_items) // 200
+output_files = []
+for i in range(0, len(data_items), chunk_size):
+ chunk = data_items[i:i+chunk_size]
+ output_path = output_dir / f"sorted_shard_{i//chunk_size:05d}.jsonl"
+ with open(output_path, 'w') as f_out:
+ for _, line in chunk:
+ f_out.write(line)
+ output_files.append(output_path)
+
+print(f"Finished writing sorted data into {len(output_files)} files.")
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/pixelprose/README.md b/unidisc/datasets/preprocessing/pixelprose/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a74c38f595a24c2a8b1014f5c4a0124011f03ac
--- /dev/null
+++ b/unidisc/datasets/preprocessing/pixelprose/README.md
@@ -0,0 +1,20 @@
+# PixelProse
+
+To download the metadata, run:
+
+```bash
+huggingface-cli download tomg-group-umd/pixelprose --repo-type dataset --local-dir .
+```
+
+
+To download the images, run:
+
+```bash
+input_folder='/path/to/pixelprose/data'
+img2dataset --url_list "$input_folder" --input_format "parquet" \
+--url_col "url" --caption_col "vlm_caption" --output_format webdataset \
+--output_folder pixelprose_data --processes_count 16 --thread_count 32 \
+--save_additional_columns '["original_caption", "uid"]' --enable_wandb True --image_size 256 --output_folder "/path/to/output/folder" --resize_mode center_crop
+```
+
+Please see the [WebDataset](../webdataset.md) for more information on how to further process and then tokenize the WebDataset.
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/README.md b/unidisc/datasets/preprocessing/unidisc_dataset/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d2bf5176b6fabc38f75e2f8fd15319563c16035
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/README.md
@@ -0,0 +1,92 @@
+# UniDisc Dataset
+
+The dataset generation consists of several parts:
+
+1. Combine seed prompts from multiple prior datasets.
+2. Send off generation jobs, each of which use an LLM to augment the prompts and pass these to a diffusion model to generate images.
+3. Postprocess the generated files on disk to create a parquet metadata file and finally convert the images to a WebDataset.
+
+## Seed Prompts
+
+We acquire seed prompts from the following datasets:
+
+- [ImageRewardDB](https://huggingface.co/datasets/THUDM/ImageRewardDB/)
+Please see [process_image_reward.py](./combine_prompts/process_image_reward.py) for the script used to process this dataset.
+- [simulacra-aesthetic-captions](https://github.com/JD-P/simulacra-aesthetic-captions)
+Please see [process_sac.py](./combine_prompts/process_sac.py) for the script used to process this dataset.
+- [PickScore](https://github.com/yuvalkirstain/PickScore)
+Please see [process_pickscore.py](./combine_prompts/process_pickscore.py) for the script used to process this dataset.
+- [HPDv2](https://huggingface.co/datasets/ymhao/HPDv2)
+Please download and concatenate the `json` files in [this](https://huggingface.co/datasets/ymhao/HPDv2/tree/main/benchmark) folder.
+- [Gecko Benchmark](https://huggingface.co/datasets/google-deepmind/gecko_benchmark_t2i) (For validation prompts only)
+Please see [process_gecko.py](./combine_prompts/process_gecko.py) for the script used to process this dataset.
+
+## Generation
+
+### LLM Prompting
+First, you will need to setup the LLM prompting code. We found smaller models such as `gpt-4o-mini` to not work as well as larger models and even with smaller models, we did not have sufficient funds to generate ~2.3 billion tokens. Thus, we use `langchain` to use multiple free LLMs along with local Ollama instances and paid backup APIs to distribute traffic.
+
+To install and run Ollama:
+
+```bash
+curl -L https://ollama.com/download/ollama-linux-amd64 -o $HOME/bin/ollama
+chmod +x $HOME/bin/ollama
+OLLAMA_HOST=0.0.0.0:11434 ollama serve
+```
+
+Then, you will need to set the corresponding Ollama server hostname/port in [llm_prompting.py](./generate/llm_prompting.py). Moreover, the code supports a round-robin mechanism, allowing you to further balance the load across multiple hosts.
+
+Please see [llm_prompting.py](./generate/llm_prompting.py) for the code used to setup the LLM calling.
+
+### Server
+Next, the primary generation code consists of a client and server architecture to properly assign jobs. This allows for better distribution of workloads and for robust failure handling. Jobs may die for any number of reasons (pre-emption, bad GPUs, disk failures, etc.).
+
+Please see [image_server.py](./generate/image_server.py) for the code used to setup the job server.
+
+You may run the server using the following SLURM command:
+
+**Note**: The client hardcodes the server hostname in this setting, so you must match `main-node` to the `hosting_node` in [generate/generate_images.py](./generate/generate_images.py).
+
+```bash
+sbatch --job-name='image_server' --mem=16G --cpus-per-task=4 --nodelist=main-node --time=4320 --partition=general --wrap "python $UNIDISC_DIR/unidisc/datasets/prompts/image_server.py" --output=$UNIDISC_DIR/outputs/generate_images/image_server.out --error=$UNIDISC_DIR/outputs/generate_images/image_server.out
+```
+
+or standalone (again, you must properly set the `hosting_node` in [generate/generate_images.py](./generate/generate_images.py)):
+
+```bash
+python $UNIDISC_DIR/unidisc/datasets/prompts/image_server.py
+```
+
+### Client
+
+This will dispatch an sbatch array job with 128 simultaneous jobs and 1000 total jobs:
+
+**Note**: In this case, we are using the `HPDv2.json` file we have generated. You can use any other json file you specify.
+
+```bash
+python $UNIDISC_DIR/unidisc/datasets/prompts/generate_images.py HPDv2.json --expected_samples_per_index=200 --num_workers=128 --num_chunks=1000 --max_chunk_size=512 --use_slurm --compile=True &
+```
+
+To monitor outputs:
+
+```bash
+tail -f -n1000 $UNIDISC_DIR/outputs/generate_images/image_server.out
+cd $UNIDISC_DIR/outputs/generate_images && /bin/ls -t | head -10 | xargs tail -n 100
+find $UNIDISC_DIR/outputs/generate_images/generated_images -type f -name "*.json" | wc -l
+```
+
+
+## Postprocessing
+
+To create the parquet metadata file, run the following command:
+
+```bash
+python $UNIDISC_DIR/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_json_to_parquet.py /path/to/data /path/to/output.parquet
+```
+
+To package the images into a WebDataset, run the following command:
+
+```bash
+python $UNIDISC_DIR/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_parquet_to_wds.py --parquet_file /path/to/parquet/file --base_dir /path/to/image/folder --output_dir /path/to/output/webdataset --num_workers=64
+```
+
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_gecko.py b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_gecko.py
new file mode 100644
index 0000000000000000000000000000000000000000..be8c8d353086dafac9c2bf02bf35cb49ead0efaf
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_gecko.py
@@ -0,0 +1,18 @@
+import pandas as pd
+import json
+
+# https://github.com/google-deepmind/gecko_benchmark_t2i/blob/main/prompts.csv
+df = pd.read_csv('prompts.csv')
+
+def clean_prompt(prompt):
+ ascii_prompt = ''.join(char for char in prompt if ord(char) < 128)
+ cleaned_prompt = ascii_prompt.replace('\n', ' ').strip()
+ return cleaned_prompt
+
+
+df['prompt'] = df['prompt'].apply(clean_prompt)
+df = df.sample(frac=1).reset_index(drop=True)
+sampled_list = df['prompt'].tolist()
+
+with open('sampled_prompts.json', 'w') as f:
+ json.dump(sampled_list, f)
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_image_reward.py b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_image_reward.py
new file mode 100644
index 0000000000000000000000000000000000000000..48c447d371f26dbf6bc4ec0f2185988cec10c594
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_image_reward.py
@@ -0,0 +1,49 @@
+from datasets import load_dataset
+import json
+import re
+
+dataset = load_dataset("parquet", data_files={'test': 'https://huggingface.co/datasets/THUDM/ImageRewardDB/resolve/main/metadata-test.parquet', 'train': 'https://huggingface.co/datasets/THUDM/ImageRewardDB/resolve/main/metadata-train.parquet', 'validation': 'https://huggingface.co/datasets/THUDM/ImageRewardDB/resolve/main/metadata-validation.parquet'}, split='train+validation+test')
+
+unique_captions = set()
+
+for item in dataset:
+ unique_captions.add(item['prompt'])
+
+unique_captions = list(unique_captions)
+
+def correct_data(text):
+ # Replace 4, 3, and 2 digit numbers with spaces in between
+ text = re.sub(r'(\d) (\d) (\d) (\d)', r'\1\2\3\4', text)
+ text = re.sub(r'(\d) (\d) (\d)', r'\1\2\3', text)
+ text = re.sub(r'(\d) (\d)', r'\1\2', text)
+ # Remove spaces between a single digit and the letter k
+ text = re.sub(r'(\d) k', r'\1k', text)
+ # Replace 2 digits followed by 2 characters and a space
+ text = re.sub(r'(\d) (\d)(\w\w) ', r'\1\2\3 ', text)
+ # Remove leading and trailing whitespace
+ text = text.strip()
+ # Replace "( (" with "((" and "[ [" with "[["
+ text = re.sub(r'\( \(', r'((', text)
+ text = re.sub(r'\[ \[', r'[[', text)
+ # Replace ") )" with "))" and "] ]" with "]]"
+ text = re.sub(r'\) \)', r'))', text)
+ text = re.sub(r'\] \]', r']]', text)
+
+ text = re.sub(r'\( \(', r'(', text)
+ text = re.sub(r'\[ \[', r'[', text)
+ # Replace ") )" with ")" and "] ]" with "]"
+ text = re.sub(r'\) \)', r')', text)
+ text = re.sub(r'\] \]', r']', text)
+
+ text = re.sub(r'(\d) (\w) ', r'\1\2 ', text)
+
+ return text
+
+
+
+corrected_data = [correct_data(item) for item in unique_captions]
+
+with open('image_reward.json', 'w') as f:
+ json.dump(list(corrected_data), f)
+
+print(f"Total unique captions: {len(list(unique_captions))}")
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_pickscore.py b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_pickscore.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e10acc4e48989b25a2eb3ff4842d099b3f41178
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_pickscore.py
@@ -0,0 +1,26 @@
+from datasets import load_dataset
+import json
+
+# Load both datasets
+dataset_v1 = load_dataset('yuvalkirstain/pickapic_v1_no_images', split='train+validation+test')
+dataset_v2 = load_dataset('yuvalkirstain/pickapic_v2_no_images', split='train+validation+test')
+
+# Create a set to store unique captions
+unique_captions = set()
+
+# Add captions from v1 dataset
+for item in dataset_v1:
+ unique_captions.add(item['caption'])
+
+# Add captions from v2 dataset
+for item in dataset_v2:
+ unique_captions.add(item['caption'])
+
+# Convert set to list
+caption_list = list(unique_captions)
+
+# Write to JSON file
+with open('unique_captions.json', 'w') as f:
+ json.dump(caption_list, f, indent=2)
+
+print(f"Total unique captions: {len(caption_list)}")
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_sac.py b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..081468c12d542cc09b6dfd0c5e83b7670a4db84d
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/combine_prompts/process_sac.py
@@ -0,0 +1,30 @@
+import sqlite3
+
+# Connect to the SQLite database
+conn = sqlite3.connect('sac_public_2022_06_29.sqlite')
+cursor = conn.cursor()
+# Query to select distinct prompts from the generations table
+query = "SELECT DISTINCT prompt FROM generations"
+
+# Execute the query and fetch all results
+cursor.execute(query)
+unique_prompts = cursor.fetchall()
+
+# Close the connection
+conn.close()
+
+# Function to clean each prompt
+def clean_prompt(prompt):
+ # Remove non-ASCII characters
+ ascii_prompt = ''.join(char for char in prompt if ord(char) < 128)
+ # Replace newlines with space, strip leading/trailing whitespace
+ cleaned_prompt = ascii_prompt.replace('\n', ' ').strip()
+ return cleaned_prompt
+
+# Write the unique prompts to a text file, applying cleaning
+with open('unique_prompts.txt', 'w') as file:
+ for prompt in unique_prompts:
+ cleaned_prompt = clean_prompt(prompt[0])
+ # Write prompt to file if it is not empty or only whitespace
+ if cleaned_prompt:
+ file.write(cleaned_prompt + '\n')
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/generate/generate_images.py b/unidisc/datasets/preprocessing/unidisc_dataset/generate/generate_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..207eaf4a4514ce686be485a28404f3e3bec5e782
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/generate/generate_images.py
@@ -0,0 +1,455 @@
+from __future__ import annotations
+
+import io
+import json
+import os
+import signal
+import socket
+import subprocess
+import time
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import requests
+import torch
+import typer
+from tqdm import tqdm
+
+from decoupled_utils import breakpoint_on_error, check_gpu_memory_usage, rprint
+from unidisc.datasets.prompts.llm_prompting import get_llm
+
+typer.main.get_command_name = lambda name: name
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+hostname = socket.gethostname()
+hosting_node = "main-node-0-0"
+prompt_folder = Path(f"diffusion/prompts/inputs")
+root_output_folder = Path(f"diffusion/prompts/generated_images/v4")
+
+def get_list_from_file(data_path):
+ output_data = None
+
+ if not Path(data_path).exists():
+ data_path = prompt_folder / data_path
+
+ if str(data_path).endswith(".txt"):
+ with open(data_path, "r", encoding="utf-8") as file:
+ output_data = file.readlines()
+
+ elif str(data_path).endswith(".json"):
+ with open(data_path, "r", encoding="utf-8") as file:
+ output_data = json.load(file)
+
+ if output_data:
+ output_data = ["".join(char for char in line if ord(char) < 128).replace("\n", "").strip() for line in output_data]
+ output_data = [line for line in output_data if line]
+ else:
+ print("Warning: No data loaded from the file.")
+
+ output_data = [x.strip() for x in output_data]
+
+ return output_data
+
+
+def get_to_process(timestamp, data_path, return_raw_data=False):
+ data = get_list_from_file(data_path)
+ return list(range(len(data)))
+
+
+def get_pipe(compile=True, model="stabilityai/stable-diffusion-3-medium-diffusers", quantize_text_encoder=False, **kwargs):
+ from diffusers import (LuminaText2ImgPipeline, PixArtAlphaPipeline,
+ PixArtSigmaPipeline, StableDiffusion3Pipeline,
+ Transformer2DModel)
+
+ torch.set_float32_matmul_precision("high")
+
+ if compile:
+ torch._inductor.config.conv_1x1_as_mm = True
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.epilogue_fusion = False
+ torch._inductor.config.coordinate_descent_check_all_directions = True
+
+ if "PixArt-Sigma" in model:
+ model_cls = PixArtSigmaPipeline
+ transformer = Transformer2DModel.from_pretrained(
+ "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", # "PixArt-alpha/PixArt-Sigma-XL-2-512-MS",
+ subfolder="transformer",
+ use_safetensors=True,
+ **kwargs,
+ )
+ pipe = PixArtSigmaPipeline.from_pretrained(
+ "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", transformer=transformer, use_safetensors=True, **kwargs
+ )
+ elif "stable-diffusion-3" in model and quantize_text_encoder:
+ from transformers import BitsAndBytesConfig, T5EncoderModel
+
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+ text_encoder = T5EncoderModel.from_pretrained(model, subfolder="text_encoder_3", quantization_config=quantization_config, device_map="auto")
+ pipe = StableDiffusion3Pipeline.from_pretrained(model, text_encoder_3=text_encoder, device_map="balanced", **kwargs)
+ else:
+ if "stable-diffusion-3" in model:
+ model_cls = StableDiffusion3Pipeline
+ elif "PixArt-alpha" in model:
+ model_cls = PixArtAlphaPipeline
+ else:
+ model_cls = LuminaText2ImgPipeline
+
+ if "PixArt" in model:
+ kwargs["use_safetensors"] = True
+ kwargs["device_map"] = "balanced"
+
+ rprint(f"Loading model: {model}")
+ pipe = model_cls.from_pretrained(model, **kwargs)
+
+ pipe = pipe.to("cuda")
+ pipe.set_progress_bar_config(disable=True)
+
+ # pipe.enable_xformers_memory_efficient_attention()
+
+ if compile:
+ pipe.transformer.to(memory_format=torch.channels_last)
+ pipe.vae.to(memory_format=torch.channels_last)
+
+ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+ pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
+
+ rprint(f"Compiled model: {model}")
+
+ return pipe
+
+
+def get_indices_from_server(slurm_job_id, chunk_size, total_indices, output_dir, expected_samples_per_index):
+ import time
+ rprint(f"Getting indices from server: {slurm_job_id}, {chunk_size}, {total_indices}, {output_dir}, {expected_samples_per_index}")
+ retries = 10
+ for attempt in range(retries):
+ try:
+ response = requests.post(
+ f"http://{hosting_node}:5000/get_indices",
+ json={"slurm_job_id": slurm_job_id, "chunk_size": chunk_size, "total_indices": total_indices, "output_dir": str(output_dir), "expected_samples_per_index": expected_samples_per_index},
+ timeout=1200
+ )
+ rprint(f"Response: {response}")
+ if response.status_code == 200:
+ return response.json().get("indices", [])
+ except requests.RequestException as e:
+ rprint(f"Attempt {attempt + 1} failed: {e}")
+ if attempt < retries - 1:
+ time.sleep(120)
+ else:
+ raise Exception(f"Failed to get indices from server: {slurm_job_id}, {chunk_size}, {total_indices}, {output_dir}, {expected_samples_per_index}")
+
+
+def train(data_path, indices, batch_size=None, compile=True, model="stabilityai/stable-diffusion-3-medium-diffusers", resolution=512, augment_prompts=True, **kwargs):
+ rprint(f"Unused kwargs: {kwargs}, augment_prompts: {augment_prompts}")
+ rprint(f"Before getting device")
+ result = torch.cuda.get_device_name()
+ rprint(f"Initializing on {hostname}, Got {len(indices)} indices.")
+ rprint(f"GPU name: {result}")
+ check_gpu_memory_usage()
+
+ model_kwargs = dict(torch_dtype=torch.bfloat16, compile=compile, model=model)
+
+ if "A5500" in result or "A5000" in result:
+ if compile:
+ gpu_batch_size = 4
+ else:
+ gpu_batch_size = 8
+ elif "A100" in result or "6000 Ada" in result or "A6000" in result:
+ gpu_batch_size = 12
+ elif "V100" in result:
+ model_kwargs.update(torch_dtype=torch.float16)
+ gpu_batch_size = 6
+ else:
+ model_kwargs.update(torch_dtype=torch.float16)
+ gpu_batch_size = 4
+
+ if ("080" in result or "TITAN X" in result) and "stable-diffusion-3" in model:
+ rprint("Disabling text encoder 3 and compile...")
+ model_kwargs.update(text_encoder_3=None, tokenizer_3=None)
+ compile = False
+ model_kwargs["compile"] = False
+
+ if "3090" in result:
+ compile = False
+ model_kwargs["compile"] = False
+
+ if batch_size is None:
+ batch_size = gpu_batch_size
+
+ rprint(f"Using batch size: {batch_size}")
+ pipe = get_pipe(**model_kwargs)
+
+ if ("080" in result or "TITAN X" in result) and ("stable-diffusion-3" in model or "PixArt" in model):
+ rprint("Enabling model offload...")
+ pipe.enable_model_cpu_offload()
+
+ llm_model_type = "gpt-4o-mini"
+ llm = get_llm(hosting_node=hosting_node, llm_model_type=llm_model_type)
+
+ selected_lines = get_list_from_file(data_path)
+
+ output_folder = root_output_folder / data_path.stem
+ output_folder.mkdir(parents=True, exist_ok=True)
+
+ gpu_name = torch.cuda.get_device_name()
+ new_samples_per_index = 10
+
+ for i in tqdm(range(0, (len(indices) // batch_size) + 1), unit="it", unit_scale=batch_size, desc="Processing batches"):
+ initial_prompts = []
+ for j in range(batch_size):
+ if i * batch_size + j < len(indices):
+ initial_prompts.append(
+ (indices[i * batch_size + j], 0, "", selected_lines[indices[i * batch_size + j]], selected_lines[indices[i * batch_size + j]])
+ )
+
+ if len(initial_prompts) == 0:
+ rprint("No prompts to process...skipping")
+ continue
+
+ augmented_prompts = []
+ successful_initial_prompts = []
+ for prompt_index, _, _, original_prompt, _ in initial_prompts:
+ if augment_prompts:
+ try:
+ rprint(f'Given original prompt: "{original_prompt}", (index: {prompt_index})')
+ generated_prompts, llm_model_name = llm(prompt=original_prompt, new_samples_per_index=new_samples_per_index)
+ except Exception as e:
+ rprint(f"Error generating prompts for {original_prompt}: {e}")
+ continue
+
+ rprint(f"Generated prompts: {generated_prompts}")
+
+ if len(generated_prompts) < new_samples_per_index - 1: # Allow one less
+ rprint(f"Only {len(generated_prompts)} prompts generated for {original_prompt}")
+ continue
+
+ augmented_prompts.extend(
+ [(prompt_index, k + 1, llm_model_name, original_prompt, aug_prompt) for k, aug_prompt in enumerate(generated_prompts)]
+ )
+ successful_initial_prompts.append((prompt_index, 0, "", original_prompt, original_prompt))
+
+ successful_initial_prompts = [
+ prompt for prompt in successful_initial_prompts
+ if not (output_folder / f"{prompt[0]}_0.json").exists()
+ ]
+
+ if len(successful_initial_prompts) < len(initial_prompts):
+ rprint(f"Filtered out {len(initial_prompts) - len(successful_initial_prompts)} already processed initial prompts")
+
+ all_prompts = successful_initial_prompts + augmented_prompts
+ rprint(f"Total prompts to generate: {len(all_prompts)}")
+
+ total_generated = 0
+ for k in range(0, len(all_prompts), batch_size):
+ batch_all_prompts = all_prompts[k : k + batch_size]
+ pipe_kwargs = dict(negative_prompt=[""] * len(batch_all_prompts), height=resolution, width=resolution)
+
+ if "lumina" in model:
+ pipe_kwargs.pop("height")
+ pipe_kwargs.pop("width")
+
+ start_time = time.time()
+ images = pipe(list(map(lambda x: x[-1], batch_all_prompts)), **pipe_kwargs).images
+ end_time = time.time()
+ rprint(f"Image generation time: {end_time - start_time:.2f} seconds")
+
+ for j, image in enumerate(images):
+ total_generated += 1
+ prompt_idx, augmentation_idx, llm_model_name, original_prompt, augmented_prompt = batch_all_prompts[j]
+ generation_id = f"{prompt_idx}_{augmentation_idx}"
+ output_image_path = output_folder / f"{generation_id}.jpg"
+
+ while output_image_path.exists():
+ augmentation_idx += 1
+ generation_id = f"{prompt_idx}_{augmentation_idx}"
+ output_image_path = output_folder / f"{generation_id}.jpg"
+
+ image.save(output_image_path)
+ metadata = {
+ "prompt_index": prompt_idx,
+ "augmentation_idx": augmentation_idx,
+ "original_prompt": original_prompt,
+ "augmented_prompt": augmented_prompt,
+ "is_augmented": original_prompt != augmented_prompt,
+ "model_name": model,
+ "llm_model_name": llm_model_name,
+ "height": pipe_kwargs.get("height"),
+ "width": pipe_kwargs.get("width"),
+ "input_file": data_path.stem,
+ "hostname": hostname,
+ "image_path": str(output_image_path),
+ "gpu_name": gpu_name,
+ "generation_timestamp": datetime.now().isoformat(),
+ }
+
+ metadata_file = output_image_path.with_suffix(".json")
+ with open(metadata_file, "w") as f:
+ json.dump(metadata, f)
+
+ rprint(f"Generated {total_generated} prompts out of total {len(all_prompts)} prompts")
+
+ exit()
+
+
+def tail_log_file(log_file_path, glob_str=None):
+ import subprocess
+ import time
+
+ max_retries = 60
+ retry_interval = 4
+ for _ in range(max_retries):
+ try:
+ if (glob_str is None and Path(log_file_path).exists()) or len(list(Path(log_file_path).rglob(glob_str))) > 0:
+ try:
+ if glob_str is None:
+ print(f"Tailing {log_file_path}")
+ proc = subprocess.Popen(["tail", "-f", "-n", "+1", f"{log_file_path}"], stdout=subprocess.PIPE)
+ else:
+ print(["tail", "-f", "-n", "+1", f"{log_file_path}/{glob_str}"])
+ proc = subprocess.Popen(["sh", "-c", f"tail -f -n +1 {log_file_path}/{glob_str}"], stdout=subprocess.PIPE)
+ for line in iter(proc.stdout.readline, b""):
+ print(line.decode("utf-8"), end="")
+ except:
+ proc.terminate()
+ except:
+ print(f"Tried to glob: {log_file_path}, {glob_str}")
+ finally:
+ time.sleep(retry_interval)
+
+ print(f"File not found: {log_file_path} after {max_retries * retry_interval} seconds...")
+
+
+# TODO: Set this if desired
+cluster_node_gpus = {
+ "main-node-0-0": "titanx",
+}
+
+def get_excluded_nodes(*args):
+ return [x for x in cluster_node_gpus.keys() if any(s in cluster_node_gpus[x] for s in args)]
+
+
+def run_slurm(data_path, num_chunks, num_workers, current_datetime, partition, chunk_size, extra_args, tail_log=False):
+ print(f"Running slurm job with {num_chunks} chunks and {num_workers} workers...")
+ os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+ from simple_slurm import Slurm
+
+ hostname = socket.gethostname()
+
+ kwargs = dict()
+ # TODO: Only needed if you wish to exclude specific bad nodes.
+ if "main-node" in hostname:
+ exclude = set(get_excluded_nodes())
+ exclude.add("main-node-0-0")
+ kwargs["exclude"] = ",".join(exclude)
+ print(f"Excluding nodes: {kwargs['exclude']}")
+
+ log_folder = Path("outputs/generate_images")
+ log_folder.mkdir(parents=True, exist_ok=True)
+ slurm = Slurm(
+ "--requeue",
+ job_name=f"generate_parallel_{data_path.stem}",
+ cpus_per_task=8,
+ mem="24g",
+ export="ALL",
+ gres=["gpu:1"],
+ output=f"{str(log_folder)}/{Slurm.JOB_ARRAY_MASTER_ID}_{Slurm.JOB_ARRAY_ID}.out",
+ time=timedelta(days=3, hours=0, minutes=0, seconds=0) if "kate" in partition else timedelta(days=0, hours=6, minutes=0, seconds=0),
+ array=f"0-{num_chunks-1}%{num_workers}",
+ partition=partition,
+ comment="generate",
+ **kwargs,
+ )
+ job_id = slurm.sbatch(
+ f"python {Path(__file__).relative_to(os.getcwd())} {data_path} --is_slurm_task --slurm_task_datetme={current_datetime} --slurm_task_index=$SLURM_ARRAY_TASK_ID --chunk_size={chunk_size} {' '.join(extra_args)}"
+ )
+ print(f"Submitted job {job_id} with {num_chunks} tasks and {num_workers} workers...")
+ if tail_log:
+ tail_log_file(Path(f"outputs/generate_images"), f"{job_id}*")
+
+
+@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
+def main(
+ ctx: typer.Context,
+ data_path: Path,
+ num_workers: int = 1,
+ use_slurm: bool = False,
+ is_slurm_task: bool = False,
+ slurm_task_datetme: str = None,
+ slurm_task_index: int = None,
+ max_chunk_size: int = 20000,
+ num_chunks: Optional[int] = None,
+ shuffle: bool = True,
+ shuffle_seed: int = 42,
+ invalidate_cache: bool = False,
+ partition: str = "all",
+ chunk_size: Optional[int] = None,
+ tail_log: bool = False,
+):
+
+ rprint(f"Running with data_path: {data_path}, args: {ctx.args}")
+ default_values = dict(compile=False, batch_size=None, model="stabilityai/stable-diffusion-3-medium-diffusers", resolution=512, expected_samples_per_index=100, augment_prompts=True)
+ for arg in ctx.args:
+ if arg.removeprefix("--").split("=")[0] in default_values:
+ default_values[arg.removeprefix("--").split("=")[0]] = arg.split("=")[1]
+ else:
+ assert False, f"Unknown argument: {arg}"
+
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = str((Path.home() / ".cache" / "torchinductor").resolve())
+ current_datetime = datetime.now()
+ datetime_up_to_hour = current_datetime.strftime("%Y_%m_%d_%H_00_00") if use_slurm else current_datetime.strftime("%Y_%m_%d_00_00_00")
+ _timestamp = slurm_task_datetme if is_slurm_task else datetime_up_to_hour
+ if invalidate_cache or use_slurm:
+ _timestamp = current_datetime.strftime("%Y_%m_%d_%H_%M_00")
+
+ dataset = get_to_process(_timestamp, data_path)
+
+ if is_slurm_task:
+ slurm_job_id = os.environ.get("SLURM_JOB_ID")
+ if not slurm_job_id:
+ raise RuntimeError("SLURM_JOB_ID environment variable not set")
+ indices = get_indices_from_server(slurm_job_id, chunk_size, len(dataset), Path(root_output_folder) / data_path.stem, default_values["expected_samples_per_index"])
+ if indices is None or len(indices) == 0:
+ rprint(f"No images to process. Exiting...")
+ exit()
+ rprint(f"Running slurm task {slurm_task_index} with {len(indices)} images...")
+ train(data_path, indices, **default_values)
+ exit()
+
+ submission_list = list(range(len(dataset)))
+ if len(submission_list) == 0:
+ rprint("No images to process. Exiting...")
+ exit()
+
+ if shuffle:
+ import random
+
+ random.seed(shuffle_seed)
+ random.shuffle(submission_list)
+
+ if chunk_size is None:
+ chunk_size = min(len(submission_list) // num_workers, max_chunk_size) # Adjust this based on the number of workers
+
+ chunks = [submission_list[i : i + chunk_size] for i in range(0, len(submission_list), chunk_size)]
+ assert sum([len(chunk) for chunk in chunks]) == len(submission_list)
+ if len(chunks) > 999:
+ rprint(f"Too many chunks ({len(chunks)}), truncating to 999...")
+ chunks = chunks[:999]
+
+ num_chunks = num_chunks if num_chunks is not None else len(chunks)
+
+ if use_slurm:
+ run_slurm(data_path, num_chunks, num_workers, datetime_up_to_hour, partition, chunk_size, tail_log=tail_log, extra_args=ctx.args)
+ exit()
+ else:
+ import random
+ random.shuffle(submission_list)
+
+ with breakpoint_on_error():
+ train(data_path, submission_list, **default_values)
+
+
+if __name__ == "__main__":
+ app()
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/generate/image_server.py b/unidisc/datasets/preprocessing/unidisc_dataset/generate/image_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..219d0b303910a4b1905fc9760420661b40e56749
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/generate/image_server.py
@@ -0,0 +1,130 @@
+import json
+import random
+import subprocess
+from collections import defaultdict
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import joblib
+from flask import Flask, jsonify, request
+
+from decoupled_utils import rprint
+
+app = Flask(__name__)
+
+job_allocations = {}
+completed_indices = set()
+first_request = True
+
+allocations_file = Path("static/allocations.json")
+memory = joblib.Memory("static/.cache", verbose=0)
+cache_duration = timedelta(minutes=60)
+
+@memory.cache
+def load_completed_indices(output_dir, expected_samples_per_index):
+ rprint(f"Loading completed indices from {output_dir}")
+ completed_files = output_dir.glob("*.json")
+ index_count = defaultdict(int)
+ for f in completed_files:
+ index = int(f.stem.split("_")[0])
+ index_count[index] += 1
+ rprint(f"Finsihed loading completed indices: {index_count}")
+ return {index for index, count in index_count.items() if count >= (expected_samples_per_index - 2)}
+
+
+def get_running_jobs():
+ try:
+ result = subprocess.run(["squeue", "-h", "-o", "%A"], capture_output=True, text=True)
+ job_ids = result.stdout.split()
+ return set(job_ids)
+ except subprocess.CalledProcessError as e:
+ rprint(f"Failed to run squeue: {e}")
+ return set()
+
+
+def save_allocations():
+ allocations_file.parent.mkdir(parents=True, exist_ok=True)
+ with allocations_file.open("w") as f:
+ json.dump(
+ {
+ k: {"indices": list(v["indices"]), "timestamp": v["timestamp"].isoformat(), "dataset_key": v["dataset_key"]}
+ for k, v in job_allocations.items()
+ },
+ f,
+ )
+
+
+def load_allocations():
+ if allocations_file.exists():
+ with allocations_file.open("r") as f:
+ data = json.load(f)
+ return {
+ k: {"indices": set(v["indices"]), "timestamp": datetime.fromisoformat(v["timestamp"]), "dataset_key": v["dataset_key"]}
+ for k, v in data.items()
+ }
+ return {}
+
+
+job_allocations = load_allocations()
+
+
+@app.route("/get_indices", methods=["POST"])
+def get_indices():
+ global last_cache_time
+
+ slurm_job_id = request.json.get("slurm_job_id", None)
+ chunk_size = request.json.get("chunk_size", None)
+ total_indices = request.json.get("total_indices", None)
+ output_dir_path = request.json.get("output_dir", None)
+ expected_samples_per_index = int(request.json.get("expected_samples_per_index", 100))
+
+ if not slurm_job_id or not output_dir_path or not chunk_size or not total_indices:
+ return jsonify({"error": "SLURM job ID and output directory are required"}), 400
+
+ output_dir = Path(output_dir_path)
+ dataset_key = output_dir.stem
+ current_time = datetime.now()
+
+ if "last_cache_time" not in globals() or current_time - last_cache_time > cache_duration:
+ load_completed_indices.clear()
+ last_cache_time = current_time
+
+ completed_indices = set()
+ running_jobs = get_running_jobs()
+
+ n_hours = 12
+ threshold_time = datetime.now() - timedelta(hours=n_hours)
+ for job_id in list(job_allocations.keys()):
+ if job_id not in running_jobs or job_allocations[job_id]["timestamp"] < threshold_time:
+ rprint(f"Deleting job {job_id} from allocations")
+ del job_allocations[job_id]
+
+ if slurm_job_id in job_allocations:
+ allocated_indices = job_allocations[slurm_job_id]["indices"]
+ num_available_indices = None
+ rprint(f"Job {slurm_job_id} already allocated indices: {allocated_indices}")
+ else:
+ all_reserved_indices = {idx for indices in job_allocations.values() for idx in indices["indices"]}
+ available_indices = set(range(0, total_indices)) - completed_indices - all_reserved_indices
+
+ if not available_indices:
+ rprint(f"No indices available for job {slurm_job_id}, completed len: {len(completed_indices)}, running jobs: {len(running_jobs)}, current allocations: {len(job_allocations)}, expected samples per index: {expected_samples_per_index}, output dir: {output_dir}")
+ return jsonify({"error": "No indices available"}), 404
+
+ available_indices = list(available_indices)
+ random.shuffle(available_indices)
+ allocated_indices = set(available_indices[:chunk_size])
+ job_allocations[slurm_job_id] = {"indices": allocated_indices, "timestamp": datetime.now(), "dataset_key": dataset_key}
+ save_allocations()
+ num_available_indices = len(available_indices)
+
+ rprint(
+ f"Dataset {dataset_key}, total len: {total_indices}, chunk size: {chunk_size}, completed len: {len(completed_indices)}, available len: {num_available_indices}, running jobs: {len(running_jobs)}, current allocations: {len(job_allocations)}"
+ )
+ rprint(f"Job {slurm_job_id} allocated indices: {allocated_indices}")
+
+ return jsonify({"indices": list(allocated_indices)})
+
+
+if __name__ == "__main__":
+ app.run(host="0.0.0.0", port=5000, threaded=False)
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/generate/llm_prompting.py b/unidisc/datasets/preprocessing/unidisc_dataset/generate/llm_prompting.py
new file mode 100644
index 0000000000000000000000000000000000000000..0806aa82d1949a483670031f5148d8072b004043
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/generate/llm_prompting.py
@@ -0,0 +1,204 @@
+import functools
+import os
+import random
+import subprocess
+import time
+from contextlib import ExitStack
+
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
+from decoupled_utils import rprint
+
+OPENROUTER_BASE = "https://openrouter.ai"
+OPENROUTER_API_BASE = f"{OPENROUTER_BASE}/api/v1"
+OPENROUTER_REFERRER = "https://github.com/alexanderatallah/openrouter-streamlit"
+
+def get_ollama(hosting_node):
+ from langchain_community.chat_models import ChatOllama
+
+ possible_ports = [11434, 11435, 11436, 11437, 11438]
+ open_ports = []
+
+ for port in possible_ports:
+ result = subprocess.run(['nc', '-z', '-w1', hosting_node, str(port)], capture_output=True)
+ if result.returncode == 0:
+ open_ports.append(port)
+
+ if not open_ports:
+ open_ports = [11434]
+
+ chosen_port = random.choice(open_ports)
+
+ ollama_llm = ChatOllama(
+ model="llama3.1",
+ base_url=f"http://{hosting_node}:{chosen_port}",
+ temperature=0.8,
+ request_timeout=180,
+ )
+ return ollama_llm
+
+def get_groq_llama(model="llama3-70b-8192"):
+ from langchain_groq import ChatGroq
+ groq_llm = ChatGroq(
+ temperature=0.8,
+ model=model,
+ max_retries=0,
+ request_timeout=15,
+ )
+ return groq_llm
+
+def get_openai_azure():
+ from langchain_openai import AzureChatOpenAI
+ # Need to also set AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT
+ # Key only works for gpt-4o
+ os.environ["AZURE_OPENAI_API_VERSION"] = '2024-06-01'
+ os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-4o"
+ llm = AzureChatOpenAI(
+ openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+ azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
+ )
+ return llm
+
+def get_openai_openrouter(model="gpt-4o-mini"):
+ from langchain_openai import ChatOpenAI
+ llm = ChatOpenAI(
+ temperature=0.8,
+ model=model,
+ openai_api_key=os.environ["OPENROUTER_API_KEY"],
+ openai_api_base=OPENROUTER_API_BASE,
+ timeout=15,
+ )
+ return llm
+
+def get_llm(hosting_node, llm_model_type, **kwargs):
+ output_parser = JsonOutputParser()
+
+ prompt = ChatPromptTemplate.from_messages([
+ ('system', 'You are a helpful assistant.'),
+ ('user', """
+ I am generating a set of diverse prompts for a text-to-image model. Given the following prompt from a human user, please generate a set of {new_samples_per_index} diverse prompts that modify the original prompt in a meaningful way but maintains some of the original meaning or context. For example, you may add or remove objects, change the desired styling, the sentence structure, or reference different proper nouns. You might change the subject, time period, time of day, location, culture, camera angle, and other attributes. The new prompt does not need to be a complete sentence and may contain fragments and attributes if the original prompt does. You may substantially modify the prompt but make sure that the new prompt is self-contained and a plausible prompt that a user would ask a text-to-image model such as DALL-E or Stable Diffusion. Do not generate NSFW prompts. Do not preface the output with any numbers or text. {format_instructions}. The output should have keys as indices and values as the prompts, and should be valid, parseable JSON.
+
+ Original prompt: {prompt}
+ """
+ )])
+
+ if llm_model_type == "llama3.1":
+ llm = get_ollama(hosting_node)
+ rprint(f"Using Ollama on host {hosting_node}")
+ elif llm_model_type == "groq":
+ llm = get_groq_llama(**kwargs)
+ elif llm_model_type == "gpt-4o-mini-openrouter":
+ llm = get_openai_openrouter(**kwargs)
+ else:
+ from langchain_openai import ChatOpenAI
+ openai_llm = ChatOpenAI(
+ model="gpt-4o-mini",
+ temperature=0.8,
+ max_tokens=1000,
+ max_retries=0,
+ timeout=20,
+ )
+ llm = openai_llm.with_fallbacks([
+ *([get_openai_openrouter("gpt-4o-mini")] if "OPENROUTER_API_KEY" in os.environ else []),
+ get_groq_llama("llama-3.1-70b-versatile"),
+ get_groq_llama("llama3-70b-8192"),
+ get_groq_llama("llama-3.1-8b-instant"),
+ get_groq_llama("gemma2-9b-it"),
+ get_ollama(hosting_node)
+ ])
+ rprint("Using GPT4o-mini")
+
+ chain = prompt | llm
+
+ return functools.partial(forward_llm, chain=chain, output_parser=output_parser, llm_model_type=llm_model_type)
+
+def forward_llm(prompt, new_samples_per_index, chain, output_parser, llm_model_type, fake_openai_failure=False):
+ with ExitStack() as stack:
+ if "gpt" in llm_model_type:
+ from langchain_community.callbacks import get_openai_callback
+ cb = stack.enter_context(get_openai_callback())
+
+ if fake_openai_failure:
+ from unittest.mock import patch
+
+ import httpx
+ from openai import RateLimitError
+ request = httpx.Request("GET", "/")
+ response = httpx.Response(200, request=request)
+ error = RateLimitError("rate limit", response=response, body="")
+ stack.enter_context(patch("openai.resources.chat.completions.Completions.create", side_effect=error))
+
+ for i in range(5):
+ try:
+ start_time = time.time()
+ rprint(f"Calling LLM...")
+ output_message = chain.invoke({
+ "prompt": prompt,
+ "format_instructions": output_parser.get_format_instructions(),
+ "new_samples_per_index": new_samples_per_index
+ })
+
+ output = output_parser.invoke(output_message)
+ output = list(output.values())
+
+ if len([x for x in output if x is not None]) == 0:
+ raise ValueError("No output from LLM")
+
+ end_time = time.time()
+ rprint(f"LLM Time taken: {end_time - start_time:.2f} seconds")
+ break
+ except Exception as e:
+ rprint(f"Error, retrying: {i}, {e}")
+ if i == 4:
+ raise e
+ continue
+
+ try:
+ model_name = output_message.response_metadata['model_name']
+ rprint(f"Used model name: {model_name}")
+ except:
+ model_name = "Unknown"
+
+ if "gpt" in llm_model_type and i == 0:
+ rprint(cb)
+
+ output = [prompt for prompt in output if prompt is not None]
+
+ if len(output) == 0:
+ rprint("No output from LLM")
+ rprint(f"Raw: {output_message}")
+ output = []
+ else:
+ if any(x in output[0].lower() for x in [" here", "diverse"]):
+ rprint("Removing the first element.")
+ rprint(output[0])
+ output.pop(0)
+
+ output = [prompt.strip() for prompt in output]
+ output = [prompt for prompt in output if prompt != ""]
+
+ return output, model_name
+
+import json
+import random
+from pathlib import Path
+
+if __name__ == "__main__":
+ llm_func = get_llm("node-name", "gpt-4o-mini-openrouter")
+ from unidisc.datasets.prompts.generate_images import prompt_folder
+ input_directory = Path(prompt_folder)
+
+ for file_path in input_directory.glob("*.json"):
+ rprint(f"Opening {file_path}")
+ with file_path.open('r') as file:
+ prompts = json.load(file)
+
+ sampled_prompts = random.sample(prompts, min(2, len(prompts)))
+ for prompt in sampled_prompts:
+ rprint(f"Prompt: {prompt}")
+ rprint(llm_func(prompt, fake_openai_failure=False))
+
+ rprint("\n")
+
+
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_json_to_parquet.py b/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_json_to_parquet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5e0a47bd0f9d2f9a532cf2e2a636d6aaded5af
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_json_to_parquet.py
@@ -0,0 +1,33 @@
+import json
+from pathlib import Path
+from typing import Optional
+import pandas as pd
+import typer
+from tqdm import tqdm
+import socket
+
+def main(directories: list[Path], output_path: Optional[Path] = None):
+ data = []
+ for directory in directories:
+ for i, json_file in tqdm(enumerate(directory.glob('*.json'))):
+ try:
+ with json_file.open('r') as f:
+ metadata = json.load(f)
+ image_filename = json_file.with_suffix('.jpg').name
+ metadata['__key__'] = str((directory / image_filename).relative_to(directory.parent))
+ metadata["caption"] = metadata["augmented_prompt"]
+ metadata["subdirectory"] = str(directory.relative_to(directory.parent))
+ data.append(metadata)
+ except Exception as e:
+ print(f"Error loading {json_file}: {e}")
+
+ df = pd.DataFrame(data)
+ df['idx'] = df.index
+ hostname = socket.gethostname()
+ df['cluster'] = 'cluster_name'
+ df = df[df['image_path'].notna() & (df['image_path'] != '')]
+ df.to_parquet(output_path, index=False)
+ print(f"Metadata has been saved to {output_path}")
+
+if __name__ == "__main__":
+ typer.run(main)
\ No newline at end of file
diff --git a/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_parquet_to_wds.py b/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_parquet_to_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fa5cff41aec935f632ce17fbc1f9419f1084aac
--- /dev/null
+++ b/unidisc/datasets/preprocessing/unidisc_dataset/postprocess_dataset/convert_parquet_to_wds.py
@@ -0,0 +1,77 @@
+import pandas as pd
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, as_completed
+import webdataset as wds
+from pathlib import Path
+from PIL import Image
+from io import BytesIO
+import os
+
+def process_chunk(chunk, shard_id, base_dir, output_tar_path):
+ base_dir = Path(base_dir)
+ output_tar_path = output_tar_path.replace("output_dataset_", f"output_dataset_{shard_id}_")
+ print(f"Processing shard {shard_id} with {len(chunk)} samples, {output_tar_path}")
+ with wds.ShardWriter(output_tar_path, maxsize=500*1024*1024) as sink:
+ for _, row in chunk.iterrows():
+ # Construct the image path
+ # image_path = base_dir / row['__key__']
+ image_path = row['image_path']
+ if not Path(image_path).exists(): assert False
+ try:
+ # Load the image
+ with Image.open(image_path) as img:
+ img_byte_arr = BytesIO()
+ img.save(img_byte_arr, format='JPEG')
+ img_data = img_byte_arr.getvalue()
+ # Create a unique key for each sample
+ key = Path(image_path).stem
+
+ # Prepare the sample dictionary
+ sample = {
+ '__key__': key,
+ 'jpg': img_data,
+ 'txt': row['caption'],
+ 'meta.json': row.drop(['__key__', 'caption']).to_dict()
+ }
+ # Write the sample to the shard
+ sink.write(sample)
+ except Exception as e:
+ print(f"Error processing image {image_path}: {e}")
+
+
+def main(parquet_file, base_dir, output_dir, num_workers=8):
+ # Load the Parquet file into a DataFrame
+ df = pd.read_parquet(parquet_file)
+ print(f"Dataframe loaded with {len(df)} rows.")
+ print(f"Columns: {df.columns.tolist()}")
+
+ # Ensure the output directory exists
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_tar_path = str(output_dir) + "/output_dataset_%06d.tar"
+
+ # Split the DataFrame into chunks for each worker
+ df_split = np.array_split(df, num_workers)
+
+ # Use a ProcessPoolExecutor for parallel processing
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for shard_id, chunk in enumerate(df_split):
+ futures.append(executor.submit(process_chunk, chunk, shard_id, base_dir, output_tar_path))
+
+ # Collect results and handle exceptions
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ print(f"Error in worker: {e}")
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--parquet_file', type=str, required=True, help="Path to the Parquet file.")
+ parser.add_argument('--base_dir', type=str, required=True, help="Base directory where images are located.")
+ parser.add_argument('--output_dir', type=str, required=True, help="Directory to store the output WebDataset tar files.")
+ parser.add_argument('--num_workers', type=int, default=8, help="Number of parallel workers.")
+ args = parser.parse_args()
+ main(args.parquet_file, args.base_dir, args.output_dir, args.num_workers)
diff --git a/unidisc/datasets/preprocessing/webdataset.md b/unidisc/datasets/preprocessing/webdataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..0067f3f482c2da983e7850ab8ab1b4c2ec6bbab9
--- /dev/null
+++ b/unidisc/datasets/preprocessing/webdataset.md
@@ -0,0 +1,60 @@
+# WebDataset
+
+Many of our datasets are stored in an intermediate format called [WebDataset](https://github.com/webdataset/webdataset). This is simply a tar file containing a set of files in a specific format. Although typically WebDataset is used for distributed training where data repitition is not a large issue (e.g., with an IterableDataset), for some pre-processing tasks it is easier to have an indexed dataset.
+
+For this reason, you must run the following command to create an index, once inside the directory containing the tar files:
+
+**Note:** you may want/need to install the fork with `pip install git+ssh://git@github.com/alexanderswerdlow/webdataset.git@wip` to get a faster `widsindex` command.
+
+```bash
+widsindex create *.tar
+```
+
+
+To read a WebDataset, you can use the following code:
+
+```python
+import webdataset as wds
+import braceexpand
+from tqdm import tqdm
+shards = braceexpand.braceexpand('/scratch/data_cc3m/cc3m-train-{0000..0575}.tar')
+dataset = wds.WebDataset(shards, shardshuffle=True).shuffle(5000)
+
+for b in tqdm(dataset):
+ pass
+```
+
+
+
+
+## Tokenization
+
+Next, precompute the tokens:
+
+**_Note:_** If you are on a SLURM cluster, you can replace `accelerate launch` with:
+
+```bash
+sbatch --time=2-00:00:00 --array=0-100%25 --cpus-per-gpu=12 --mem-per-gpu=100G --nodes=1 --gpus-per-node=1 --partition=preempt --job-name=cambrian_precompute_tokens scripts/precompute_tokens_slurm.sh
+```
+
+**_Note:_** If you want to only generate a subset of the tokens, append e.g., `data.n_train_samples=200` to the command.
+
+**_Note:_** Set `data.block_size=128` if you want a different maximum token length.
+
+**_Note:_** `model.text_vocab_size` and `data.img_token_shift` are based on the text tokenizer used, in this case `Llama-2-7b-hf`.
+
+**_Note:_** Set the resolution as desired (e.g., 256, 512, 1024, etc.).
+
+Finally, to tokenize the dataset, run:
+
+```bash
+accelerate launch models/datasets/precompute_tokens.py +experiments='[webdataset,tokenize,vq16_t2i]' data.token_output_dir="/path/to/token_output_dir" data.resolution=512 data.use_chameleon=false loader.batch_size=16 data.raw_data_dir='/path/to/cambrian/jsons/Cambrian10M.jsonl' +model.text_vocab_size=32001 data.img_token_shift=32001 +data.use_identity_collate=true loader.num_workers=2 data.split_dataset=true +data.save_tmp_interval=3600 +data.use_slow_tokenizer=true +data.add_image_token=true
+```
+
+Now that the tokenization is complete, if it was done over multiple GPUs/nodes, you must combine the tensordicts on disk. If the to
+
+```bash
+python models/datasets/combine_token_dicts.py "/path/to/token_output_dir" --move_files --delete_after_combining --mem_efficient
+```
+
+**_Note:_** You may wish to add the `--allow_tmp` flag to the command if the tokenization was only partially completed (e.g., due to a SLURM job being preempted). In this case, the tokenization saves intermediate checkpoints with a `tmp_` prefix.
\ No newline at end of file
diff --git a/unidisc/datasets/sampler.py b/unidisc/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea99c809a8a5085314a2a65d79af5a191521cbbb
--- /dev/null
+++ b/unidisc/datasets/sampler.py
@@ -0,0 +1,160 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, Sampler
+from decoupled_utils import gprint, rprint, dprint, tensor_hash
+
+def _get_len(dataset):
+ if isinstance(dataset, torch.utils.data.IterableDataset):
+ return 1000000000
+ else:
+ return len(dataset)
+
+class WeightedDatasetSampler(Sampler):
+ def __init__(self, combined_dataset, generator=None, batch_size=100000):
+ self.dataset_names = combined_dataset.dataset_names
+ self.datasets = combined_dataset.datasets
+ self.weights = combined_dataset.weights
+ self.generator = generator
+ self.batch_size = batch_size
+
+ assert len(self.datasets) == len(self.weights), "Each dataset must have a corresponding weight"
+
+ # Samples per epoch is the least common multiple of dataset sizes
+ self.lcm_size = np.lcm.reduce([_get_len(d) for d in self.datasets])
+ if self.lcm_size < 1:
+ self.lcm_size = max([_get_len(d) for d in self.datasets]) * 1000
+
+ total_dataset_sizes_sum = sum(_get_len(d) for d in self.datasets)
+ self.weights = [
+ weight if weight >= 0 else _get_len(dataset) / total_dataset_sizes_sum
+ for weight, dataset in zip(self.weights, self.datasets)
+ ]
+ self.counts = {name: int(round(weight / sum(self.weights) * self.lcm_size)) for weight, name in zip(self.weights, self.dataset_names)}
+ rprint(f"LCM Size: {self.lcm_size}, Weighted Sampler Counts: {self.counts}, Weights: {self.weights}")
+ self.dataset_idx_to_name = {idx: name for idx, name in enumerate(self.dataset_names)}
+ self._reset_state()
+ self.raise_stop_iteration = False
+
+ def _generate_multinomial_batch(self):
+ normalized_weights = [weight / sum(self._state['available_weights']) for weight in self._state['available_weights']]
+ _batched_indices = torch.multinomial(torch.tensor(normalized_weights), self.batch_size, replacement=True, generator=self.generator)
+ for i, weight in enumerate(self.weights):
+ if weight == 0 and i in _batched_indices:
+ print(f"WARNING: sampling item with zero weight")
+
+ map_to_original = torch.tensor(self._state['available_datasets']).to(_batched_indices)
+ self._state['batched_indices'] = map_to_original[_batched_indices].tolist()
+ self._state['batch_pointer'] = 0
+
+ def state_dict(self):
+ if self.generator is not None:
+ gprint(f"Generator state at saving: {tensor_hash(self.generator.get_state())}")
+
+ gprint(f"Counts: {sum(self._state['current_counts'].values())}, Batch pointer: {self._state['batch_pointer']}")
+ return self._state
+
+ def load_state_dict(self, state_dict):
+ self._state['generator'] = state_dict['generator']
+ self._state['batched_indices'] = state_dict['batched_indices']
+ self._state['batch_pointer'] = state_dict['batch_pointer']
+
+ if set(state_dict['current_counts'].keys()) == set(self.dataset_names):
+ self._state['current_counts'] = state_dict['current_counts']
+ self._state['dataset_iters'] = state_dict['dataset_iters']
+ else:
+ rprint(f"Dataset names mismatch, updating state_dict")
+ self._state['current_counts'].update(state_dict['current_counts'])
+ if self._state['dataset_iters'] is None:
+ self._state['dataset_iters'] = {name: (torch.randperm(_get_len(self.datasets[idx]), generator=self._state['generator']), 0) for idx, name in enumerate(self.dataset_names) if name not in state_dict['current_counts']}
+ gprint(f"new len of dataset iters: {len(self._state['dataset_iters'])}")
+ self._state['dataset_iters'].update(state_dict['dataset_iters'])
+ gprint(f"final len of dataset iters: {len(self._state['dataset_iters'])}")
+
+ gprint(f"Weights: {self._state['available_weights']}")
+
+ gprint(f"Updated state_dict: {self._state['current_counts']}")
+ gprint(f"Finished loading sampler state_dict")
+ if self.generator is not None:
+ gprint(f"Generator state at loading: {tensor_hash(self._state['generator'].get_state())}")
+
+ gprint(f"Counts: {sum(self._state['current_counts'].values())}, Batch pointer: {self._state['batch_pointer']}")
+
+ def restart(self):
+ if self._state['dataset_iters'] is None:
+ rprint(f"Resetting dataset iter. We have: {self.dataset_names}")
+ self._state['dataset_iters'] = {name: (torch.randperm(_get_len(self.datasets[idx]), generator=self._state['generator']), 0) for idx, name in enumerate(self.dataset_names)}
+
+ def exhausted(self):
+ self._reset_state()
+ if self.raise_stop_iteration:
+ dprint(f"Sampler exhausted")
+ raise StopIteration
+ else:
+ self.restart()
+
+ def check_is_not_exhausted(self):
+ return any(self._state['current_counts'][name] < self.counts[name] for name in self.dataset_names)
+
+ def __iter__(self):
+ self.restart()
+ while self.check_is_not_exhausted() or self.raise_stop_iteration is False:
+ if len(self._state['available_datasets']) == 0 or (self.raise_stop_iteration is False and self.check_is_not_exhausted() is False):
+ self.exhausted()
+
+ if self._state['batched_indices'] is None or self._state['batch_pointer'] >= self.batch_size:
+ self._generate_multinomial_batch()
+
+ try:
+ dataset_name = self.dataset_idx_to_name[self._state['batched_indices'][self._state['batch_pointer']]]
+ except Exception as e:
+ gprint(f"Error in dataset_name: {e}, batch pointer: {self._state['batch_pointer']}, batched indices: {self._state['batched_indices']}, dataset idx to name: {self.dataset_idx_to_name}")
+ self._state['batch_pointer'] += 1
+
+ tensor, idx = self._state['dataset_iters'][dataset_name]
+ if idx >= len(tensor):
+ rprint(f"Resetting dataset iter for {dataset_name}")
+ tensor = torch.randperm(_get_len(self.datasets[self.dataset_names.index(dataset_name)]), generator=self._state['generator'])
+ idx = 0
+ self._state['dataset_iters'][dataset_name] = (tensor, idx)
+
+ self._state['dataset_iters'][dataset_name] = (tensor, idx + 1)
+ self._state['current_counts'][dataset_name] += 1
+
+ dataset_idx = self.dataset_names.index(dataset_name)
+ if self._state['current_counts'][dataset_name] >= self.counts[dataset_name]:
+ index = self._state['available_datasets'].index(dataset_idx)
+ self._state['available_datasets'].pop(index)
+ self._state['available_weights'].pop(index)
+ if len(self._state['available_datasets']) > 0:
+ rprint(f"{dataset_name} has no samples left, resetting dataset sampler")
+ self._generate_multinomial_batch()
+
+ # print(f"Yielding {dataset_idx}, {tensor[idx].item()}")
+ yield dataset_idx, tensor[idx].item()
+
+ self.exhausted()
+
+ def _reset_state(self):
+ self._state = {
+ 'batched_indices': None,
+ 'batch_pointer': 0,
+ 'current_counts': {name: 0 for name in self.dataset_names},
+ 'available_datasets': list(range(len(self.dataset_names))),
+ 'available_weights': [weight / sum(self.weights) for weight in self.weights],
+ 'dataset_iters': None,
+ 'generator': self.generator,
+ }
+
+ def __len__(self):
+ return sum(self.counts.values())
+
+class DummyDataset(Dataset):
+ def __init__(self, dataset_name, size):
+ self.dataset_name = dataset_name
+ self.size = size
+
+ def __len__(self):
+ return self.size
+
+ def __getitem__(self, idx):
+ return (self.dataset_name, idx) # Just return the index for testing
\ No newline at end of file
diff --git a/unidisc/tokenizers/__init__.py b/unidisc/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unidisc/tokenizers/chameleon_tokenizers.py b/unidisc/tokenizers/chameleon_tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec6226f39a6eb11745e9a076006aa999ea149311
--- /dev/null
+++ b/unidisc/tokenizers/chameleon_tokenizers.py
@@ -0,0 +1,905 @@
+import base64
+import io
+import json
+import random
+import sys
+import tarfile
+from tqdm import tqdm
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import random
+import io
+import socket
+from collections import defaultdict
+import pickle
+import torchvision
+from constants import LIB_DIR
+from dataloader import tokenize_text
+from decoupled_utils import gprint, rprint
+import pandas as pd
+import copy
+from pathlib import Path
+chameleon_path = LIB_DIR / "Lumina-mGPT/lumina_mgpt"
+sys.path.append(str(chameleon_path))
+
+try:
+ from data.convertsation import Conversation
+ from data.item_processor import FlexARItemProcessor
+ class ItemProcessor(FlexARItemProcessor):
+ def __init__(
+ self,
+ tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768",
+ conv_template=Conversation,
+ target_size=512,
+ ):
+ super().__init__(tokenizer, conv_template, target_size)
+ print(self.crop_size_list)
+
+ def process_item(self, img, txt, training_mode=True, out_flatten=True, w=None, h=None):
+ # Add custom codes here to convert raw_item to the standard format
+ # The standard format contains the "conversations" and "image" keys
+
+ _prompt = f"Generate an image of {w}x{h} according to the following prompt:\n{txt}" if w is not None and h is not None else f"Generate an image according to the following prompt:\n{txt}"
+ item = {
+ "conversations": [
+ {
+ "from": "human",
+ "value": _prompt
+ },
+ {
+ "from": "gpt",
+ "value": "<|image|>"
+ },
+ ],
+ "image": [img],
+ }
+
+ return super(ItemProcessor, self).process_item(item, training_mode, out_flatten)
+
+ def process_item_json(self, item, training_mode=True, out_flatten=True):
+ # Add custom codes here to convert raw_item to the standard format
+ # The standard format contains the "conversations" and "image" keys
+
+ # _prompt = f"Generate an image of {w}x{h} according to the following prompt:\n{txt}" if w is not None and h is not None else f"Generate an image according to the following prompt:\n{txt}"
+ # item = {
+ # "conversations": [
+ # {
+ # "from": "human",
+ # "value": _prompt
+ # },
+ # {
+ # "from": "gpt",
+ # "value": "<|image|>"
+ # },
+ # ],
+ # "image": [img],
+ # }
+
+ return super(ItemProcessor, self).process_item(item, training_mode, out_flatten)
+except Exception as e:
+ if chameleon_path.exists():
+ rprint(f"Failed to import Chameleon tokenizers from {chameleon_path}: {e}")
+
+
+
+def tensor_center_crop(tensor_image, crop_size):
+ _, _, h, w = tensor_image.shape
+
+ while h >= 2 * crop_size[0] and w >= 2 * crop_size[1]:
+ tensor_image = F.interpolate(tensor_image, size=(h // 2, w // 2), mode='area')
+ _, _, h, w = tensor_image.shape
+
+ scale = max(crop_size[0] / h, crop_size[1] / w)
+ new_h, new_w = round(h * scale), round(w * scale)
+ tensor_image = F.interpolate(tensor_image, size=(new_h, new_w), mode='bilinear')
+
+ crop_top = random.randint(0, new_h - crop_size[0])
+ crop_left = random.randint(0, new_w - crop_size[1])
+ crop_bottom = crop_top + crop_size[0]
+ crop_right = crop_left + crop_size[1]
+ return tensor_image[:, :, crop_top:crop_bottom, crop_left:crop_right]
+
+def var_center_crop(tensor_image, crop_size_list, random_top_k=1):
+ _, _, h, w = tensor_image.shape
+ rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
+ crop_size = random.choice(
+ sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
+ )[1]
+ # alternates = sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)
+ # for i, (a, (x, y)) in enumerate(alternates):
+ # print(f"{i}: {(x // 16) * (y // 16)}")
+ return tensor_center_crop(tensor_image, crop_size)
+
+def tokenize_chameleon_fast(config, tokenizer=None, vae=None, batch=None, txt_decoded=None, **kwargs):
+ assert "idx" in batch
+
+ if txt_decoded is None:
+ txt_decoded = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+ all_attention_masks = []
+ all_input_ids = []
+
+ bs = batch['img'].shape[0]
+ _img = var_center_crop(batch['img'], vae.crop_size_list, random_top_k=5)
+ image_toks = vae.chameleon_ori_image_tokenizer.img_tokens_from_tensor(_img)
+ image_toks = vae.chameleon_ori_translation.img2bpe_mapping_tensor[image_toks]
+ h, w = _img.shape[-2:]
+ h_grids, w_grids = h // vae.patch_size, w // vae.patch_size
+ full_image_toks = image_toks.reshape(bs, h // 16, w // 16)
+ new_line_id = vae.token2id(vae.new_line_token)
+
+ full_image_toks = torch.cat(
+ (
+ full_image_toks,
+ torch.full((bs, h // 16, 1), fill_value=new_line_id, device=full_image_toks.device, dtype=full_image_toks.dtype),
+ ),
+ dim=-1,
+ ).flatten(start_dim=1, end_dim=-1)
+
+ result_toks = torch.cat([
+ torch.tensor([
+ vae.token2id(vae.image_start_token),
+ vae.token2id(vae.get_n_grids_token(h_grids)),
+ vae.token2id(vae.get_n_grids_token(w_grids))
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1),
+ full_image_toks,
+ torch.tensor([
+ vae.token2id(vae.image_end_token)
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1)
+ ], dim=1)
+
+ input_ids = torch.full((bs, config.model.length,), fill_value=-100, dtype=torch.int64)
+ attention_mask = torch.full((bs, config.model.length,), fill_value=False, dtype=torch.bool)
+
+ for i in range(batch['input_ids'].shape[0]):
+ _img = (result_toks[i], config.data.resolution, config.data.resolution)
+ _txt = txt_decoded[i][:200]
+ tokens, labels = vae.process_item(_img, _txt, out_flatten=False, h=h, w=w)
+ idx = 0
+ for j, token_or_media in enumerate(tokens):
+ if isinstance(token_or_media, int):
+ input_ids[i, idx:idx+1] = token_or_media
+ idx += 1
+ else:
+ media_len = len(token_or_media["input_ids"])
+ input_ids[i, idx:idx+media_len] = token_or_media["input_ids"]
+ idx += media_len
+
+ attention_mask[i, :idx] = True
+ if idx >= config.model.length:
+ gprint("WARNING!!!! Truncating input ids")
+
+ all_attention_masks = attention_mask
+ all_input_ids = input_ids
+
+ return all_input_ids, all_attention_masks
+
+
+def tokenize_chameleon_mmc4(config, tokenizer, vae, batch, device, mapping, **kwargs):
+ all_images, all_content = get_mmc4(config, tokenizer, vae, batch, device, mapping)
+
+ _img = torch.cat([tensor_center_crop(torch.from_numpy(np.array(img))[None, :].permute(0, 3, 1, 2) / 255, (config.data.resolution, config.data.resolution)) for img in all_images])
+
+ all_attention_masks = []
+ all_input_ids = []
+
+ bs = _img.shape[0]
+ # _img = var_center_crop(batch['img'], vae.crop_size_list, random_top_k=5)
+ image_toks = vae.chameleon_ori_image_tokenizer.img_tokens_from_tensor(_img)
+ image_toks = vae.chameleon_ori_translation.img2bpe_mapping_tensor[image_toks]
+ h, w = _img.shape[-2:]
+ h_grids, w_grids = h // vae.patch_size, w // vae.patch_size
+ full_image_toks = image_toks.reshape(bs, h // 16, w // 16)
+ new_line_id = vae.token2id(vae.new_line_token)
+
+ full_image_toks = torch.cat(
+ (
+ full_image_toks,
+ torch.full((bs, h // 16, 1), fill_value=new_line_id, device=full_image_toks.device, dtype=full_image_toks.dtype),
+ ),
+ dim=-1,
+ ).flatten(start_dim=1, end_dim=-1)
+
+ # TODO: Currently unused
+ result_toks = torch.cat([
+ torch.tensor([
+ vae.token2id(vae.image_start_token),
+ vae.token2id(vae.get_n_grids_token(h_grids)),
+ vae.token2id(vae.get_n_grids_token(w_grids))
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1),
+ full_image_toks,
+ torch.tensor([
+ vae.token2id(vae.image_end_token)
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1)
+ ], dim=1)
+
+ input_ids = torch.full((bs, config.model.length,), fill_value=-100, dtype=torch.int64)
+ attention_mask = torch.full((bs, config.model.length,), fill_value=False, dtype=torch.bool)
+
+ for i in range(len(all_content)):
+ w, h = config.data.resolution, config.data.resolution
+ _item = all_content[i]
+ conversations = []
+ first_text = True
+ for item in _item:
+ if item['type'] == "text":
+ if first_text:
+ conversations.append({"from": "human", "value": f"Generate an image of {w}x{h} according to the following prompt:\n{item['text']}"})
+ first_text = False
+ else:
+ conversations.append({"from": "human", "value": item["text"]})
+ elif item['type'] == "image_url":
+ conversations.append({"from": "gpt", "value": "<|image|>"})
+
+ item = {
+ "conversations": conversations,
+ "image": [(full_image_toks[it["image_url"]["url"]], w, h) for it in _item if it["type"] == "image_url"],
+ }
+
+ tokens, labels = vae.process_item_json(item, out_flatten=False)
+ idx = 0
+ for j, token_or_media in enumerate(tokens):
+ if isinstance(token_or_media, int):
+ input_ids[i, idx:idx+1] = token_or_media
+ idx += 1
+ else:
+ media_len = len(token_or_media["input_ids"])
+ input_ids[i, idx:idx+media_len] = token_or_media["input_ids"]
+ idx += media_len
+
+ attention_mask[i, :idx] = True
+ if idx >= config.model.length:
+ gprint("WARNING!!!! Truncating input ids")
+
+ all_attention_masks = attention_mask
+ all_input_ids = input_ids
+
+ return all_input_ids, all_attention_masks
+
+def tokenize_chameleon(config, tokenizer, vae, batch, **kwargs):
+ assert "idx" in batch
+ txt_decoded = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ all_attention_masks = []
+ all_input_ids = []
+ for i in range(batch['input_ids'].shape[0]):
+ _img = Image.fromarray((batch['img'][i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+ _txt = txt_decoded[i][:100]
+
+ tokens, labels = vae.process_item(_img, _txt, out_flatten=False)
+ input_ids = []
+ first_img_id = None
+ for i, token_or_media in enumerate(tokens):
+ if isinstance(token_or_media, int):
+ input_ids.append(token_or_media)
+ else:
+ if len(input_ids) > (config.model.length - 1061):
+ gprint("WARNING!!!! Truncating input ids")
+ input_ids = input_ids[:-1061]
+ input_ids += token_or_media["input_ids"]
+ first_img_id = i
+
+ input_ids = torch.tensor(input_ids, dtype=torch.int64)
+ attention_mask = torch.zeros(config.model.length, dtype=torch.bool)
+ attention_mask[:len(input_ids)] = True
+ if len(input_ids) > config.model.length:
+ gprint("WARNING!!!! Truncating input ids, this should not happen")
+ input_ids = input_ids[:config.model.length]
+ attention_mask = attention_mask[:config.model.length]
+
+ input_ids = torch.cat([input_ids, -100 * torch.ones(config.model.length - len(input_ids), dtype=torch.int64)])
+ all_input_ids.append(input_ids)
+ all_attention_masks.append(attention_mask)
+
+ all_attention_masks = torch.stack(all_attention_masks, dim=0)
+ all_input_ids = torch.stack(all_input_ids, dim=0)
+ return all_input_ids, all_attention_masks
+
+_tar_cache = {}
+_tar_contents_cache = None
+
+def process_tar_file(tar_filepath):
+ try:
+ with tarfile.open(tar_filepath) as tar:
+ return tar_filepath, set(tar.getnames())
+ except:
+ return tar_filepath, set()
+
+def get_cache(mapping, split, parent_dir):
+ global _tar_contents_cache
+ if _tar_contents_cache is not None:
+ return _tar_contents_cache
+
+ hostname = socket.gethostname()
+ userhome = Path.home()
+ cache_path = userhome / ".cache" / "unidisc" / f"{split}_tar_contents_cache.pkl"
+
+ if cache_path.exists():
+ print("Loading tar contents cache")
+ with open(cache_path, 'rb') as f:
+ _tar_contents_cache = pickle.load(f)
+ return _tar_contents_cache
+
+ unique_tar_filepaths = mapping['tar_filepath'].unique()
+ if parent_dir is not None:
+ for i in range(len(unique_tar_filepaths)):
+ orig_tar_filepath = Path(unique_tar_filepaths[i])
+ relative_path = orig_tar_filepath.relative_to(*orig_tar_filepath.parts[:len(Path(parent_dir).parts)])
+ unique_tar_filepaths[i] = Path(parent_dir) / relative_path
+ if not unique_tar_filepaths[i].exists():
+ unique_tar_filepaths[i] = Path(parent_dir) / orig_tar_filepath.relative_to(*orig_tar_filepath.parts[:len(Path(parent_dir).parts) - 1])
+
+ print(f"Building tar contents cache for {len(unique_tar_filepaths)} files. Example: {unique_tar_filepaths[:4]}")
+ import multiprocessing as mp
+ with mp.Pool() as pool:
+ results = list(tqdm(
+ pool.imap(process_tar_file, unique_tar_filepaths),
+ total=len(unique_tar_filepaths),
+ desc="Building tar contents cache"
+ ))
+
+ _tar_contents_cache = dict(results)
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(cache_path, 'wb') as f:
+ pickle.dump(_tar_contents_cache, f)
+
+ return _tar_contents_cache
+
+def load_image(tar_filepath, key, split, parent_dir) -> bytes:
+ global _tar_contents_cache, _tar_cache
+ image_path = f"{key}.jpg"
+
+ if parent_dir is not None:
+ orig_tar_filepath = Path(tar_filepath)
+ relative_path = orig_tar_filepath.relative_to(*orig_tar_filepath.parts[:len(Path(parent_dir).parts)])
+ tar_filepath = Path(parent_dir) / relative_path
+ if not tar_filepath.exists():
+ tar_filepath = Path(parent_dir) / orig_tar_filepath.relative_to(*orig_tar_filepath.parts[:len(Path(parent_dir).parts) - 1])
+
+ if image_path not in _tar_contents_cache[tar_filepath]:
+ raise ValueError(f"Image {image_path} not found in {tar_filepath}")
+
+ hostname = socket.gethostname()
+ if parent_dir is None and "babel" in hostname:
+ tar_filepath = tar_filepath.replace("/scratch", "/other_path")
+
+ if tar_filepath not in _tar_cache:
+ _tar_cache[tar_filepath] = tarfile.open(tar_filepath)
+
+ tar = _tar_cache[tar_filepath]
+ with tar.extractfile(image_path) as f:
+ buffered_reader = io.BufferedReader(f)
+ return buffered_reader.read()
+
+def cleanup_tar_cache():
+ for tar in _tar_cache.values():
+ tar.close()
+ _tar_cache.clear()
+ _tar_contents_cache.clear()
+
+def get_mmc4(config, tokenizer, vae, batch, device, mapping):
+ split = "fewer_faces" if "fewer_faces" in getattr(config.data, "raw_data_dir") else "core"
+ parent_dir = Path(getattr(config.data, "mmc4_parent_dir", None))
+ get_cache(mapping, split, parent_dir)
+
+ all_images = []
+ image_idx = 0
+ all_content = []
+ remove_instances_missing_images = False
+ before_ratio = 0.8
+ for jsonl_row in batch:
+ stat_counter = defaultdict(int)
+ text_list = jsonl_row["text_list"]
+ images_insert_before_text = [ [] for _ in range(len(text_list)) ]
+ images_insert_after_text = [ [] for _ in range(len(text_list)) ]
+
+ for image_info in jsonl_row["image_info"]:
+ # randomly decide whether to prepend or append the image to the corresponding text
+ insert_before = random.random() < before_ratio
+ try:
+ mapped_to_ = mapping.loc[image_info["raw_url"]]
+ if isinstance(mapped_to_, pd.Series):
+ mapped_to_ = [mapped_to_]
+ elif isinstance(mapped_to_, pd.DataFrame):
+ mapped_to_ = [row for _, row in mapped_to_.iterrows()]
+ else:
+ mapped_to_ = [mapped_to_]
+
+ except KeyError as e:
+ if remove_instances_missing_images:
+ stat_counter["instance_skipped_due_to_missing_image"] += 1
+ break # skip this instance
+ else:
+ stat_counter["n_missing_images"] += 1
+ continue # skip this image
+
+ success = False
+ for mapped_to in mapped_to_:
+ try:
+ tar_filepath = mapped_to["tar_filepath"]
+ key = mapped_to["key"]
+ except Exception as e:
+ print(f"V2 Error mapping key to path: {e}")
+ continue
+
+ try:
+ image_bytes = load_image(tar_filepath, key, split, parent_dir)
+ except Exception as e:
+ # print(f"Failed to read key: {key}, {e}")
+ if remove_instances_missing_images:
+ stat_counter["instance_skipped_due_to_missing_image"] += 1
+ break # skip this instance
+ else:
+ stat_counter["n_missing_images"] += 1
+ continue # skip this image
+
+ image_pil = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+ success = True
+ image_content = {
+ "type": "image_url",
+ "image_url": {"url": image_idx}
+ }
+ image_idx += 1
+ all_images.append(image_pil)
+ stat_counter["n_images_inserted"] += 1
+
+ if insert_before:
+ stat_counter["n_images_inserted_before_text"] += 1
+ images_insert_before_text[image_info["matched_text_index"]].append(image_content)
+ else:
+ stat_counter["n_images_inserted_after_text"] += 1
+ images_insert_after_text[image_info["matched_text_index"]].append(image_content)
+
+ break
+
+ if not success:
+ print(f"Failed find image: {key}")
+
+ # flatten content: list of list of content -> list of content
+ content = []
+ for i, text in enumerate(text_list):
+ content.extend(images_insert_before_text[i])
+ content.append({"type": "text", "text": text})
+ content.extend(images_insert_after_text[i])
+ all_content.append(content)
+
+ reordered_images = []
+ old_to_new_idx = {}
+ new_idx = 0
+ for content_list in all_content:
+ for item in content_list:
+ if item["type"] == "image_url":
+ old_idx = item["image_url"]["url"]
+ if old_idx not in old_to_new_idx:
+ old_to_new_idx[old_idx] = new_idx
+ reordered_images.append(all_images[old_idx])
+ new_idx += 1
+
+ batch_size_map = defaultdict(list)
+ for i, content_list in enumerate(all_content):
+ for item in content_list:
+ if item["type"] == "image_url":
+ old_idx = item["image_url"]["url"]
+ item["image_url"]["url"] = old_to_new_idx[old_idx]
+ batch_size_map[i].append(item["image_url"]["url"])
+
+ all_images = reordered_images
+
+ return all_images, all_content, batch_size_map
+
+def tokenize_regular_cambrian_mmc4(config, tokenizer, vae, batch, device, mapping, inference_data=False,**kwargs):
+ """Use this for MMC4 and Cambrian. Ignore the other functions below."""
+ is_cambrian = config.data.train == "cambrian"
+ if inference_data:
+ breakpoint()
+ elif is_cambrian:
+ all_images = []
+ all_content = batch
+ parent_path = Path(getattr(config.data, "cambrian_path", "/cambrian_base_path"))
+ batch_size_map = defaultdict(list)
+ for i in range(len(batch)):
+ if "image" in batch[i]:
+ img = Image.open(parent_path / batch[i]["image"]).convert("RGB")
+ all_images.append(img)
+ batch_size_map[i].append(i)
+ else:
+ all_images, all_content, batch_size_map = get_mmc4(config, tokenizer, vae, batch, device, mapping)
+ if len(all_images) == 0:
+ gprint(f"No images, skipping...")
+ return None, None, None
+
+ output_list = []
+ for input_data in all_content:
+ conversations = []
+ current_human_text = ""
+ images_to_prepend = []
+ text_counter = 0
+ has_image = False
+ for item in input_data:
+ if item['type'] == 'image_url':
+ if text_counter == 0:
+ images_to_prepend.append('')
+ else:
+ if current_human_text:
+ current_human_text += ''
+ conversations.append({'from': 'human', 'value': current_human_text})
+ current_human_text = ""
+ else:
+ if conversations and conversations[-1]['from'] == 'human':
+ conversations[-1]['value'] += ''
+ else:
+ images_to_prepend.append('')
+ has_image = True
+ elif item['type'] == 'text':
+ text_counter += 1
+ if current_human_text:
+ current_human_text += ' ' + ' '.join(images_to_prepend) + ' ' + item['text']
+ images_to_prepend = []
+ else:
+ current_human_text = ''.join(images_to_prepend) + ' ' + item['text']
+ images_to_prepend = []
+
+ if current_human_text or images_to_prepend:
+ if current_human_text:
+ current_human_text += ' ' + ' '.join(images_to_prepend)
+ else:
+ current_human_text = ' '.join(images_to_prepend)
+ conversations.append({'from': 'human', 'value': current_human_text.strip()})
+
+ _kwargs = {}
+ if has_image:
+ _kwargs['image'] = {}
+
+ output_list.append({
+ "id": "1",
+ "conversations": conversations,
+ **_kwargs
+ })
+
+ all_content = output_list
+
+ _res = config.data.resolution
+ _length = config.model.length
+ if is_cambrian and len(all_images) == 0:
+ image_ids = None
+ else:
+ _img = torch.cat([tensor_center_crop(torch.from_numpy(np.array(img))[None, :].permute(0, 3, 1, 2) / 255, (_res, _res)) for img in all_images])
+ from model import get_image_batch
+ try:
+ batch_size = 32
+ image_ids = []
+ for i in range(0, len(_img), batch_size):
+ batch = _img[i:i+batch_size]
+ batch_ids = get_image_batch(config, vae, {"img": batch}, device)
+ image_ids.append(batch_ids)
+ image_ids = torch.cat(image_ids)
+ except Exception as e:
+ gprint(f"{_img.shape}, {e}")
+ import traceback
+ traceback.print_exc()
+
+ from unidisc.tokenizers.tokenize_interleaved import preprocess, _has_image
+ all_input_ids = []
+ all_attention_masks = []
+ all_modality = []
+ for i, sources in enumerate(all_content):
+ has_image = _has_image(sources)
+ sources = copy.deepcopy([e["conversations"] for e in [sources]])
+ _image_ids = None
+ if has_image:
+ try:
+ _image_ids = image_ids[batch_size_map[i]]
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ gprint(f"Error in tokenize_regular_cambrian_mmc4: {e}")
+ return None, None, None
+ try:
+ data_dict = preprocess(sources, tokenizer, has_image=has_image, image_ids=_image_ids)
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ gprint(f"Error in preprocess: {e}")
+ return None, None, None
+ input_ids = data_dict["input_ids"][0]
+ attention_mask = data_dict["attention_mask"][0]
+ modality = data_dict["modality"][0]
+ if (input_ids[-2:] == tokenizer.eos_token_id).all():
+ input_ids = input_ids[:-1]
+ attention_mask = attention_mask[:-1]
+ modality = modality[:-1]
+
+ if input_ids.shape[0] > _length:
+ gprint(f"WARNING!!!! Truncating input ids: {input_ids.shape[0]} vs. {_length}")
+ input_ids = input_ids[:_length]
+ attention_mask = attention_mask[:_length]
+ modality = modality[:_length]
+ input_ids = torch.nn.functional.pad(input_ids, (0, _length - input_ids.shape[-1]), value=tokenizer.pad_token_id)
+ attention_mask = torch.nn.functional.pad(attention_mask.bool(), (0, _length - attention_mask.shape[-1]), value=False)
+ modality = torch.nn.functional.pad(modality, (0, _length - modality.shape[-1]), value=0)
+
+ # We don't want to cut off an image.
+ if modality[-1] == 1:
+ is_image = modality == 1
+ change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1
+ if change_points.numel() > 0:
+ # Get start of last contiguous image sequence
+ start_pos = change_points[-1].item()
+ modality[start_pos:] = 0
+ attention_mask[start_pos:] = False
+ input_ids[start_pos:] = tokenizer.pad_token_id
+
+ all_input_ids.append(input_ids)
+ all_attention_masks.append(attention_mask)
+ all_modality.append(modality)
+
+ all_input_ids = torch.stack(all_input_ids)
+ all_attention_masks = torch.stack(all_attention_masks)
+ all_modality = torch.stack(all_modality)
+
+ return all_input_ids, all_attention_masks, all_modality
+
+def decode_ids_batched(vae, tokens, pad_token_id, **kwargs):
+ all_text_ids, all_image_ids = [], []
+ num_text_tokens = 0
+ for b in tokens:
+ try:
+ text_ids, image_ids = decode_ids(vae, b.tolist(), **kwargs)
+ except Exception as e:
+ breakpoint()
+ all_image_ids.append(torch.tensor(image_ids))
+ all_text_ids.append(torch.tensor(text_ids))
+ num_text_tokens = max(num_text_tokens, all_text_ids[-1].shape[-1])
+
+ for i in range(len(all_text_ids)):
+ all_text_ids[i] = torch.nn.functional.pad(all_text_ids[i], (0, num_text_tokens - all_text_ids[i].shape[-1]), value=pad_token_id) if all_text_ids[i].shape[-1] < num_text_tokens else all_text_ids[i]
+
+ all_text_ids = torch.stack(all_text_ids)
+ return all_text_ids.to(tokens.device), all_image_ids
+
+def decode_ids(vae, tokens, return_tokens=False):
+ try:
+ generated_images = []
+ generation_result_processed = []
+
+ i = 0
+ while i < len(tokens):
+ token_id = tokens[i]
+ if token_id == vae.token2id(vae.image_start_token):
+ cache = []
+ for j in range(i + 1, len(tokens)):
+ if tokens[j] != vae.token2id(vae.image_end_token):
+ cache.append(tokens[j])
+ i = j + 1
+ else:
+ if return_tokens:
+ image = cache
+ else:
+ try:
+ image = vae.decode_image(cache)
+ except Exception as e:
+ rprint(f"Failed to decode image: len: {len(cache)}, E: {e}")
+
+ generated_images.append(image)
+ generation_result_processed.append(vae.token2id("<|image|>"))
+ i = j + 1
+ break
+ else:
+ generation_result_processed.append(token_id)
+ i += 1
+
+ if return_tokens:
+ rprint(f"generation_result_processed: {generation_result_processed[:50]}")
+ generated = generation_result_processed
+ else:
+ try:
+ generated = vae.tokenizer.decode(generation_result_processed)
+ except:
+ generated = None
+ rprint("Failed to decode text")
+ except Exception as e:
+ breakpoint()
+
+ return generated, generated_images
+
+def get_chameleon_images(vae, batch):
+ start_img_token = vae.token2id(vae.image_start_token)
+ end_img_token = vae.token2id(vae.image_end_token)
+ all_images = []
+ for i in range(batch["input_ids"].shape[0]):
+ start_idx = (batch["input_ids"][i] == start_img_token).nonzero(as_tuple=True)[0].item()
+ end_idx = (batch["input_ids"][i] == end_img_token).nonzero(as_tuple=True)[0].item()
+ all_images.append(batch["input_ids"][[i], start_idx:end_idx])
+ return all_images
+
+
+
+def preprocess_image(sample, image_processor):
+ """
+ Convert images to tensors for training.
+ Augmentations: random horizontal flip.
+ Normalization handled by wds.
+ """
+ image = [image_processor(s).unsqueeze(0) for s in sample]
+ image = torch.cat(image, dim=0)
+ image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image)
+ return image
+
+def preprocess_interleaved(
+ sample,
+ tokenizer,
+ clip_processor,
+ sim_threshold,
+ min_num_images,
+ max_num_images,
+ max_tokens=256,
+):
+
+ Image.MAX_IMAGE_PIXELS = 1000000000
+ N_CHANNELS = 3
+ MIN_KB = 10
+
+ """
+ Preprocess an interleaved image-text sequence, either by calling preprocess_gpt_interleaved (if the sequence
+ is ChatGPT-generated) or by preprocessing in this function (if the sequences is from MMC4).
+ """
+ info = sample
+
+ sentences = info["text_list"]
+ sim_matrix = info["similarity_matrix"]
+
+ # load images first to find which ones are valid
+ valid_images, valid_image_indices = [], []
+ for i, sample_image in enumerate(info["image_info"]):
+ print(i)
+ if "image_base64" not in sample_image:
+ # print(f"No image_base64 in sample_image")
+ continue
+ image_base64 = sample_image["image_base64"]
+ rawbytes = base64.b64decode(image_base64)
+
+ # filter to images >= 10KB
+ if len(rawbytes) // 1000 <= MIN_KB:
+ # print(f"Image {i} is too small")
+ continue
+
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
+ valid_images.append(image)
+ valid_image_indices.append(i)
+
+ if len(valid_image_indices) == 0:
+ raise ValueError("No images in sample")
+
+ sim_matrix = np.array(sim_matrix) # of shape images x sentences
+ sim_matrix = sim_matrix[valid_image_indices]
+
+ # negate the similarities to turn then into costs
+ cost_matrix = -sim_matrix
+ # find one to one assignements
+ from scipy.optimize import linear_sum_assignment
+ image_indices, sentence_indices = linear_sum_assignment(cost_matrix)
+
+ images, sentence_ixs = [], []
+ for i, sim_ix in zip(image_indices, sentence_indices):
+ sim_score = sim_matrix[i][sim_ix]
+
+ if sim_score < sim_threshold:
+ continue
+
+ images.append(valid_images[i])
+ sentence_ixs.append(sim_ix)
+
+ if len(images) == 0:
+ raise ValueError("No images in sample after filtering")
+
+ # preprocess and pad images
+ images_tensors = preprocess_image(images, clip_processor)
+ keep_ixs = range(min(len(images_tensors), max_num_images))
+ images_tensors = images_tensors[keep_ixs]
+ sentence_ixs = [sentence_ixs[ix] for ix in keep_ixs]
+ if len(images_tensors) < max_num_images:
+ zero_padding = torch.zeros(
+ (
+ max_num_images - len(images_tensors),
+ N_CHANNELS,
+ images_tensors[0].shape[1],
+ images_tensors[0].shape[2],
+ ),
+ dtype=torch.float,
+ )
+ images_tensors = torch.cat((images_tensors, zero_padding), dim=0)
+
+ # preprocess and tokenize text
+ # add in and tokens
+ for ix in sentence_ixs:
+ sentences[ix] = f"<|endofchunk|>{sentences[ix]}"
+ text = " ".join(sentences)
+ text = text.replace("<|endofchunk|>", "", 1) # but remove first eoc
+ # whitespace cleanup
+ text = (
+ text.replace(" <|endofchunk|>", "<|endofchunk|>")
+ .replace(" ", "")
+ .replace(" ", "")
+ )
+ text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
+ tokenizer.padding_side = "right"
+ text_tensor = tokenizer(
+ text,
+ max_length=max_tokens,
+ truncation=True,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ # reject sequences with too few images (after truncation)
+ num_images = torch.count_nonzero(
+ text_tensor["input_ids"]
+ == tokenizer.additional_special_tokens_ids[
+ tokenizer.additional_special_tokens.index("")
+ ]
+ )
+ if num_images < min_num_images:
+ raise ValueError(f"Fewer than {min_num_images} images in sample")
+ elif (
+ num_images == 1 and random.random() <= 0.5
+ ): # 50% chance of keeping single image samples
+ raise ValueError("Only one image in sample")
+
+ # avoid the situation where there's one token and it's at the end
+ if (
+ num_images == 1
+ and text_tensor["input_ids"][:, -1]
+ == tokenizer.additional_special_tokens_ids[
+ tokenizer.additional_special_tokens.index("")
+ ]
+ ):
+ raise ValueError(
+ "Only one image at the end of sample, so labels will all be -100"
+ )
+
+ return (
+ images_tensors,
+ (text_tensor["input_ids"], text_tensor["attention_mask"]),
+ )
+
+# clip_processor = get_transform(512, 'train', True, False)
+# preprocess_interleaved(x['.json'], tokenizer, clip_processor, 0.5, 1, 100, 20000)
+
+if __name__ == "__main__":
+ vae = ItemProcessor(target_size=512)
+
+ from image_utils import Im
+ bs = 1
+ raw_img = Im.random().torch[None]
+ _img = var_center_crop(raw_img, vae.crop_size_list, random_top_k=1)
+ image_toks = vae.chameleon_ori_image_tokenizer.img_tokens_from_tensor(_img)
+ image_toks = vae.chameleon_ori_translation.img2bpe_mapping_tensor[image_toks]
+ h, w = _img.shape[-2:]
+ h_grids, w_grids = h // vae.patch_size, w // vae.patch_size
+ full_image_toks = image_toks.reshape(bs, h // 16, w // 16)
+ new_line_id = vae.token2id(vae.new_line_token)
+
+ full_image_toks = torch.cat(
+ (
+ full_image_toks,
+ torch.full((bs, h // 16, 1), fill_value=new_line_id, device=full_image_toks.device, dtype=full_image_toks.dtype),
+ ),
+ dim=-1,
+ ).flatten(start_dim=1, end_dim=-1)
+
+ result_toks = torch.cat([
+ torch.tensor([
+ vae.token2id(vae.image_start_token),
+ vae.token2id(vae.get_n_grids_token(h_grids)),
+ vae.token2id(vae.get_n_grids_token(w_grids))
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1),
+ full_image_toks,
+ torch.tensor([
+ vae.token2id(vae.image_end_token)
+ ], device=full_image_toks.device, dtype=full_image_toks.dtype).unsqueeze(0).expand(bs, -1)
+ ], dim=1)
+
+ img = (result_toks[0], 512, 512)
+ output = vae.process_item(img, "hello", out_flatten=False)
+ breakpoint()
\ No newline at end of file
diff --git a/unidisc/tokenizers/conversation.py b/unidisc/tokenizers/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4033c92d0e5bdfde9b423eb28f022ab05dc67c4
--- /dev/null
+++ b/unidisc/tokenizers/conversation.py
@@ -0,0 +1,698 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+ LLAMA_3 = auto()
+ MISTRAL = auto()
+ GEMMA = auto()
+ PHI3 = auto()
+ LLAMA_2_PLAIN = auto()
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ init_msg = init_msg[0].replace("", "").strip()
+ if 'mmtag' in self.version:
+ messages[0] = (init_role, init_msg)
+ messages.insert(0, (self.roles[0], ""))
+ messages.insert(1, (self.roles[1], "Received."))
+ else:
+ messages[0] = (init_role, "\n" + init_msg)
+
+ #print("message is", messages)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2_PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[-1]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += seps[0] + message + seps[1]
+ else:
+ ret += ""
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+ wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
+ wrap_sys = lambda msg: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>" if len(msg) > 0 else msg
+ wrap_inst_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
+ wrap_inst_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: ret += wrap_sys(self.system)
+
+ if i % 2 == 0:
+ message = wrap_inst_user(message)
+ ret += message
+ else:
+ message = wrap_inst_assistant(message)
+ ret += message
+ else:
+ ret += ""
+ ret += "<|start_header_id|>assistant<|end_header_id|>"
+ elif self.sep_style == SeparatorStyle.MISTRAL:
+ wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = "" # bos token
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += message + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ elif self.sep_style == SeparatorStyle.GEMMA:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.PHI3:
+ ret = self.system + self.sep
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += self.roles[i % 2] + message + self.sep
+ else:
+ ret += self.roles[i % 2]
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ if max(image.size) > max_len:
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ return image
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format=image_format)
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ return img_b64_str
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
+ images.append(image)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ img_b64_str = self.process_image(
+ image, "Default", return_pil=False,
+ image_format='JPEG')
+ img_str = f'
'
+ msg = img_str + msg.replace('', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_vicuna_cambrian = Conversation(
+ system="",
+ roles=("Human", "GPT"),
+ version="vicuna_cambrian",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep="\n",
+ sep2="\n\n",
+)
+
+conv_vicuna_cambrian = Conversation(
+ system="",
+ roles=("Human", "GPT"),
+ version="vicuna_cambrian",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep="\n",
+ sep2="\n\n",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_cambrian_llama_2 = Conversation(
+ system="""You are a highly intelligent multimodal AI with the ability to analyze and generate images.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_gemma = Conversation(
+ system="""""",
+ roles=("user\n", "model\n"),
+ version="gemma",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.GEMMA,
+ sep="\n",
+)
+
+conv_cambrian_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_cambrian_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_cambrian_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_cambrian_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_cambrian_cohere = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="coherev1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="<|END_OF_TURN_TOKEN|>",
+)
+
+conv_cambrian_cohere = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="coherev1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="<|END_OF_TURN_TOKEN|>",
+)
+
+conv_cambrian_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+ version="v1_mmtag",
+)
+
+conv_mistral_instruct = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_mistral_v2 = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"), # NOTE: these are not injected into the prompt. does not matter
+ version="mistral_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MISTRAL,
+ sep="",
+ sep2="",
+)
+
+conv_mistral_v2 = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"), # NOTE: these are not injected into the prompt. does not matter
+ version="mistral_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MISTRAL,
+ sep="",
+ sep2="",
+)
+
+conv_chatml_direct = Conversation(
+ system="""<|im_start|>system
+Answer the questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+
+conv_llama_3 = Conversation(
+ system="""You are Cambrian, a highly intelligent multimodal AI trained by NYU Vision X.
+ As a multimodal AI, you have the ability to process and analyze images. Whenever an image is present in the conversation, very carefully examine it and consider its content when formulating your response.
+ You should give concise responses to very simple questions, but provide thorough responses to more complex and open-ended questions. """,
+ roles=("USER", "ASSISTANT"),
+ version="llama_v3",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_3,
+ sep="<|begin_of_text|>",
+ sep2="<|end_of_text|>",
+)
+conv_chatml_direct = Conversation(
+ system="""<|im_start|>system\nYou are Cambrian, a highly intelligent multimodal AI trained by NYU Vision X. As a multimodal AI, you have the ability to process and analyze images. Whenever an image is present in the conversation, very carefully examine it and consider its content when formulating your response. You should give concise responses to very simple questions, but provide thorough responses to more complex and open-ended questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_cambrian_chatml = Conversation(
+ system="""<|im_start|>system\nYou are Cambrian, a highly intelligent multimodal AI trained by NYU Vision X. As a multimodal AI, you have the ability to process and analyze images. Whenever an image is present in the conversation, very carefully examine it and consider its content when formulating your response. You should give concise responses to very simple questions, but provide thorough responses to more complex and open-ended questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_phi3 = Conversation(
+ system="""<|system|>\nYou are a helpful AI assistant.""",
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
+ version="phi3",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.PHI3,
+ sep="<|end|>",
+)
+
+
+conv_cambrian_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+# conv_cambrian_llama_2 = Conversation(
+# system="""You are a highly intelligent multimodal AI with the ability to analyze and generate images.""",
+# roles=("USER", "ASSISTANT"),
+# version="llama_v2",
+# messages=(),
+# offset=0,
+# sep_style=SeparatorStyle.LLAMA_2,
+# sep="",
+# sep2="",
+# )
+
+conv_unidisc_llama_2 = Conversation(
+ system="""You are a highly intelligent multimodal AI with the ability to analyze and generate images.""",
+ roles=("", ""),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2_PLAIN,
+ sep="",
+ sep2="",
+)
+
+# default_conversation = conv_chatml_direct
+default_conversation = conv_unidisc_llama_2
+
+#default_conversation = conv_llama_3
+
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "vicuna_cambrian": conv_vicuna_cambrian,
+ "cohere_v1": conv_cambrian_cohere,
+ "vicuna_cambrian": conv_vicuna_cambrian,
+ "cohere_v1": conv_cambrian_cohere,
+ "llama_2": conv_llama_2,
+ "llama_3": conv_llama_3,
+ "llama_v3": conv_llama_3,
+ "mistral_instruct": conv_mistral_instruct,
+ "chatml_direct": conv_chatml_direct,
+ "cambrian_chatml": conv_cambrian_chatml,
+ "mistral_direct": conv_chatml_direct,
+ "mistral_v2": conv_mistral_v2,
+ "mistral_v2": conv_mistral_v2,
+
+ "plain": conv_cambrian_plain,
+ "v0_plain": conv_cambrian_plain,
+ "cambrian_v0": conv_cambrian_v0,
+ "v0_mmtag": conv_cambrian_v0_mmtag,
+ "cambrian_v1": conv_cambrian_v1,
+ "v1_mmtag": conv_cambrian_v1_mmtag,
+ "cambrian_llama_2": conv_cambrian_llama_2,
+ "mpt": conv_mpt,
+ "conv_gemma": conv_gemma,
+ "phi3": conv_phi3,
+}
+
+image_gen_data = None
+
+def get_image_gen_tokens(tokenizer):
+ global image_gen_data
+ if image_gen_data is None:
+ conv = default_conversation.copy()
+ # roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+ # sentence = {"from": "human", "value": "Please generate an image of: "}
+ # role = roles[sentence["from"]]
+ # conv.append_message(role, sentence["value"])
+ prompt = conv.get_prompt()
+ prompt = prompt.removesuffix('\n')
+ prompt = prompt.removesuffix('') # An EOS is added by the tokenizer
+ image_gen_data = tokenizer([prompt], return_tensors="pt")
+ return image_gen_data
+
+image_token_suffix = None
+
+def get_image_suffix(tokenizer):
+ global image_token_suffix
+ if image_token_suffix is None:
+ image_token_suffix = tokenizer("", add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False).input_ids
+ return image_token_suffix
+
+if __name__ == "__main__":
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf", add_eos_token=True, padding_side='right', use_fast=False)
+ special_token = ''
+ existing_id = 811
+ tmp_index = len(tokenizer)
+ tokenizer.add_special_tokens({
+ 'additional_special_tokens': [special_token]
+ }, replace_additional_special_tokens=False)
+ tokenizer._added_tokens_decoder[existing_id] = tokenizer._added_tokens_decoder.pop(tmp_index)
+ assert len(tokenizer.additional_special_tokens_ids) == 1
+ tokenizer.additional_special_tokens_ids = [existing_id]
+ tokenizer._added_tokens_encoder[''] = existing_id
+ tokenizer.total_vocab_size = tmp_index
+
+ conv = default_conversation.copy()
+ prompt = conv.get_prompt()
+ print(prompt)
+
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+ conversation_ = [
+ {"from": "human", "value": ""}, # What time of day does it seem to be in this picture?"
+ {"from": "gpt", "value": "Given the low light and the use of artificial lighting from the laptop screen, it appears to be nighttime in this image."},
+ # {"from": "human", "value": "What could the person in the image be doing?"},
+ # {"from": "gpt", "value": "The person seems to be using a laptop, possibly working, browsing the internet, or watching media content given the focused posture and the environment suggesting a moment of privacy and concentration."}
+ ]
+ for sentence in conversation_:
+ role = roles[sentence["from"]]
+ conv.append_message(role, sentence["value"])
+ prompt = conv.get_prompt()
+ print(prompt)
+
+ new_tokens = [""]
+ new_tokens = list(set(new_tokens) - set(tokenizer.get_vocab().keys()))
+ tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
+ print(tokenizer(new_tokens, add_special_tokens=False))
+ breakpoint()
diff --git a/unidisc/tokenizers/hpsv2_img_score.py b/unidisc/tokenizers/hpsv2_img_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f07e22951889365416ecdc065d46ae9bbde44e6
--- /dev/null
+++ b/unidisc/tokenizers/hpsv2_img_score.py
@@ -0,0 +1,109 @@
+import torch
+from PIL import Image
+from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+import warnings
+import argparse
+import os
+import requests
+from clint.textui import progress
+from typing import Union
+import huggingface_hub
+from hpsv2.utils import root_path, hps_version_map
+
+def initialize_model(device, hps_version):
+ model_dict = {}
+ model, preprocess_train, preprocess_val = create_model_and_transforms(
+ 'ViT-H-14',
+ 'laion2B-s32B-b79K',
+ precision='amp',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ force_custom_text=False,
+ force_patch_dropout=False,
+ force_image_size=None,
+ pretrained_image=False,
+ image_mean=None,
+ image_std=None,
+ light_augmentation=True,
+ aug_cfg={},
+ output_dict=True,
+ with_score_predictor=False,
+ with_region_predictor=False
+ )
+
+ cp = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[hps_version])
+ checkpoint = torch.load(cp, map_location=device)
+ model.load_state_dict(checkpoint['state_dict'])
+ tokenizer = get_tokenizer('ViT-H-14')
+ model = model.to(device)
+ model.eval()
+
+ model_dict['model'] = model
+ model_dict['preprocess_val'] = preprocess_val
+ model_dict['device'] = device
+ model_dict['tokenizer'] = tokenizer
+
+ return model_dict
+
+def score(model_dict, img_path: Union[list, str, Image.Image], prompt: str) -> list:
+ model = model_dict['model']
+ preprocess_val = model_dict['preprocess_val']
+ device = model_dict['device']
+ tokenizer = model_dict['tokenizer']
+
+ if isinstance(img_path, list):
+ result = []
+ for one_img_path in img_path:
+ # Load your image and prompt
+ with torch.no_grad():
+ # Process the image
+ if isinstance(one_img_path, str):
+ image = preprocess_val(Image.open(one_img_path)).unsqueeze(0).to(device=device, non_blocking=True)
+ elif isinstance(one_img_path, Image.Image):
+ image = preprocess_val(one_img_path).unsqueeze(0).to(device=device, non_blocking=True)
+ else:
+ raise TypeError('The type of parameter img_path is illegal.')
+ # Process the prompt
+ text = tokenizer([prompt]).to(device=device, non_blocking=True)
+ # Calculate the HPS
+ with torch.cuda.amp.autocast():
+ outputs = model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
+ logits_per_image = image_features @ text_features.T
+
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+ result.append(hps_score[0])
+ return result
+ elif isinstance(img_path, str):
+ # Load your image and prompt
+ with torch.no_grad():
+ # Process the image
+ image = preprocess_val(Image.open(img_path)).unsqueeze(0).to(device=device, non_blocking=True)
+ # Process the prompt
+ text = tokenizer([prompt]).to(device=device, non_blocking=True)
+ # Calculate the HPS
+ with torch.cuda.amp.autocast():
+ outputs = model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
+ logits_per_image = image_features @ text_features.T
+
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+ return [hps_score[0]]
+ elif isinstance(img_path, Image.Image):
+ # Load your image and prompt
+ with torch.no_grad():
+ # Process the image
+ image = preprocess_val(img_path).unsqueeze(0).to(device=device, non_blocking=True)
+ # Process the prompt
+ text = tokenizer([prompt]).to(device=device, non_blocking=True)
+ # Calculate the HPS
+ with torch.cuda.amp.autocast():
+ outputs = model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
+ logits_per_image = image_features @ text_features.T
+
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+ return [hps_score[0]]
+ else:
+ raise TypeError('The type of parameter img_path is illegal.')
\ No newline at end of file
diff --git a/unidisc/tokenizers/image_tokenizers.py b/unidisc/tokenizers/image_tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..97ff0fe86e8151fdde4f04b41cbb58ce239d05cd
--- /dev/null
+++ b/unidisc/tokenizers/image_tokenizers.py
@@ -0,0 +1,409 @@
+import sys
+from math import sqrt
+from pathlib import Path
+from types import FrameType
+
+import einops
+import hydra
+import hydra.utils
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from constants import LIB_DIR, UNIDISC_DIR
+from decoupled_utils import gprint, rprint
+from torchvision import transforms
+
+def get_vae(config, device, use_cond: bool = False):
+ def get_attr(attr_name):
+ if use_cond:
+ return getattr(config.model, f"cond_{attr_name}", None)
+ return getattr(config.model, attr_name, None)
+
+ vae_type = get_attr("vae_type")
+ if vae_type == "maskgit":
+ rprint(f"Using MaskGit VQGAN")
+ from vqgan.modeling_maskgit_vqgan import MaskGitVQGAN
+
+ vae = MaskGitVQGAN.from_pretrained(Path(__file__).parent.parent.parent / "vqgan" / "vqgan_pretrained")
+ assert get_attr("image_vocab_size") == vae.config.num_embeddings
+ elif vae_type == "taming":
+ rprint(f"Using Taming VQGAN")
+ from vqgan.modeling_taming_vqgan import VQGANModel
+
+ vae = VQGANModel.from_pretrained(Path(__file__).parent / "vqgan" / "vqgan_taming_ckpt")
+ elif vae_type == "diffusers":
+ from diffusers import VQModel
+
+ vae = VQModel.from_pretrained(get_attr("use_custom_vae_ckpt"), subfolder="vqvae")
+ vae.config.lookup_from_codebook = True
+ elif vae_type == "raw":
+ return None
+ elif vae_type == "video_vqvae":
+ sys.path.append(str(LIB_DIR / "Open-Sora-Plan"))
+ from opensora.models.ae import VQVAEModel
+
+ vae = VQVAEModel.download_and_load_model("kinetics_stride4x4x4")
+ elif vae_type == "VQ-16" or vae_type == "VQ-8":
+ sys.path.append(str(LIB_DIR / "LlamaGen"))
+ from tokenizer.tokenizer_image.vq_model_hf import (VQ_models_HF, VQModelHF)
+ if get_attr("use_custom_vae_ckpt") is not None:
+ from tokenizer.tokenizer_image.vq_model import VQ_models
+ vae = VQ_models["VQ-8" if vae_type == "VQ-8" else "VQ-16"](codebook_size=get_attr("image_vocab_size"), codebook_embed_dim=getattr(config.model, "codebook_embed_dim", 256))
+ vae.load_state_dict(torch.load(get_attr("use_custom_vae_ckpt"), map_location=device)["model"])
+ assert get_attr("downscale_ratio") == (8 if vae_type == "VQ-8" else 16)
+ elif vae_type == "VQ-8":
+ vae = VQ_models_HF["VQ-8"]()
+ vae.load_state_dict(torch.load(UNIDISC_DIR / "ckpts/vq_ds8_c2i.pt")["model"])
+ assert get_attr("downscale_ratio") == 8
+ elif vae_type == "VQ-16":
+ vae = VQModelHF.from_pretrained("FoundationVision/vq-ds16-c2i")
+ assert get_attr("downscale_ratio") == 16
+ assert (
+ get_attr("image_vocab_size") == vae.config.codebook_size
+ ), f"Image vocab size {get_attr('image_vocab_size')} does not match VAE codebook size {vae.config.codebook_size}"
+ elif vae_type == "lfq_128" or vae_type == "lfq_256":
+ sys.path.append(str(LIB_DIR / "Open-MAGVIT2"))
+ from magvit_inference import load_vqgan_new
+ from omegaconf import OmegaConf
+
+ if get_attr("use_custom_vae_ckpt") is not None and get_attr("use_custom_vae_config") is not None:
+ config_file = get_attr("use_custom_vae_config")
+ ckpt_path = get_attr("use_custom_vae_ckpt")
+ rprint(f"Using custom VAE config: {config_file} and ckpt: {ckpt_path}")
+ else:
+ config_file = LIB_DIR / "Open-MAGVIT2" / "configs" / f"imagenet_lfqgan_{128 if '128' in vae_type else '256'}_B.yaml"
+ ckpt_path = LIB_DIR / "Open-MAGVIT2" / "ckpts" / f"imagenet_{128 if '128' in vae_type else '256'}_B.ckpt"
+ configs = OmegaConf.load(config_file)
+ vae = load_vqgan_new(configs, ckpt_path).to(device)
+ elif vae_type == "bsq_18":
+ sys.path.append(str(LIB_DIR / "bsq-vit"))
+ from scripts.main_image_tokenizer import get_model
+ vae = get_model("bsq_18").to(device)
+ elif vae_type == 'cosmos':
+ # To use Cosmos, you first need to download the pretrained models from HuggingFace.
+ # from huggingface_hub import login, snapshot_download
+ # import os
+ # HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")
+ # login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)
+ # model_names = [
+ # # "Cosmos-0.1-Tokenizer-CI8x8",
+ # # "Cosmos-0.1-Tokenizer-CI16x16",
+ # # "Cosmos-0.1-Tokenizer-DI8x8",
+ # "Cosmos-0.1-Tokenizer-DI16x16",
+ # ]
+ # for model_name in model_names:
+ # hf_repo = "nvidia/" + model_name
+ # local_dir = "pretrained_ckpts/" + model_name
+ # os.makedirs(local_dir, exist_ok=True)
+ # print(f"downloading {model_name}...")
+ # snapshot_download(repo_id=hf_repo, local_dir=local_dir)
+ import importlib
+ import cosmos_tokenizer.image_lib
+ importlib.reload(cosmos_tokenizer.image_lib)
+ from cosmos_tokenizer.image_lib import ImageTokenizer
+ model_name = 'Cosmos-0.1-Tokenizer-DI16x16' # @param ["Cosmos-0.1-Tokenizer-CI16x16", "Cosmos-0.1-Tokenizer-CI8x8", "Cosmos-0.1-Tokenizer-DI8x8", "Cosmos-0.1-Tokenizer-DI16x16"]
+ cosmos_dir = Path(LIB_DIR / "Cosmos-Tokenizer")
+ encoder_ckpt = str(cosmos_dir / f"pretrained_ckpts/{model_name}/encoder.jit")
+ decoder_ckpt = str(cosmos_dir / f"pretrained_ckpts/{model_name}/decoder.jit")
+ vae = ImageTokenizer(
+ checkpoint_enc=encoder_ckpt,
+ checkpoint_dec=decoder_ckpt,
+ device="cuda",
+ dtype="bfloat16",
+ )
+
+ torch._C._jit_override_can_fuse_on_cpu(False)
+ torch._C._jit_override_can_fuse_on_gpu(False)
+ torch._C._jit_set_texpr_fuser_enabled(False)
+ torch._C._jit_set_nvfuser_enabled(False)
+ elif "titok" in vae_type:
+ from huggingface_hub import hf_hub_download
+ sys.path.append(str(LIB_DIR / "1d-tokenizer"))
+ from modeling.titok import TiTok
+ if vae_type == "titok256":
+ vae = TiTok.from_pretrained("yucornetto/tokenizer_titok_sl256_vq8k_imagenet")
+ elif vae_type == "titok128":
+ vae = TiTok.from_pretrained("yucornetto/tokenizer_titok_bl128_vq8k_imagenet")
+ elif vae_type == "titok64":
+ vae = TiTok.from_pretrained("yucornetto/tokenizer_titok_b64_imagenet")
+ else:
+ raise ValueError(f"Unknown TiTok type: {vae_type}")
+ vae.eval()
+ vae.requires_grad_(False)
+ elif vae_type == "chameleon":
+ from transformers import ChameleonForConditionalGeneration
+ model = ChameleonForConditionalGeneration.from_pretrained(
+ "leloy/Anole-7b-v0.1-hf",
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+ vae = model.model
+ vae.vqmodel.to(torch.float32)
+ if config.data.resolution == 256:
+ vae.vqmodel.quantize.quant_state_dims = [16, 16]
+ elif config.data.resolution == 512:
+ vae.vqmodel.quantize.quant_state_dims = [32, 32]
+ elif vae_type == "lumina":
+ from unidisc.tokenizers.chameleon_tokenizers import ItemProcessor
+ vae = ItemProcessor(target_size=config.data.resolution)
+ elif vae_type == "stable_diffusion":
+ from diffusers import StableDiffusionPipeline
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "benjamin-paine/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ) # since runwayml/stable-diffusion-v1-5 dont work now
+ vae = pipe.vae
+
+ # add pipe.scheduler to vae as a new attribute
+ vae.scheduler = pipe.scheduler
+ elif vae_type == "magvit":
+ # sys.path.append(str(LIB_DIR / "Show-o"))
+ import importlib.util
+ def load_package(alias, pkg_path):
+ pkg_path = Path(pkg_path)
+ init_file = pkg_path / "__init__.py"
+ spec = importlib.util.spec_from_file_location(
+ alias,
+ str(init_file),
+ submodule_search_locations=[str(pkg_path)]
+ )
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[alias] = module
+ spec.loader.exec_module(module)
+ return module
+ magvit2 = load_package("MAGVITv2", str(LIB_DIR / "Show-o" / "models"))
+ vae = magvit2.MAGVITv2.from_pretrained("showlab/magvitv2").to(device)
+ vae.requires_grad_(False)
+ vae.eval()
+ else:
+ raise ValueError(f"Unknown VAE type: {vae_type}")
+
+ if vae_type != "lumina":
+ vae.requires_grad_(False)
+ vae = vae.to(device)
+ return vae
+
+
+@torch.no_grad()
+def vae_encode_image(config, vae, image, device, vae_type: str, use_cond: bool = False):
+ def get_attr(attr_name):
+ if use_cond:
+ return getattr(config.model, f"cond_{attr_name}", None)
+ return getattr(config.model, attr_name, None)
+ with torch.autocast(device_type="cuda", enabled=False):
+ image = image.to(device=device, dtype=torch.float32)
+ assert image.min() >= 0 - 1e-2 and image.max() <= 1 + 1e-2, f"Image values out of bounds: {image.min()}, {image.max()}"
+ downscale_ratio = get_attr("downscale_ratio")
+ batch_size = image.shape[0]
+ latent_dim = image.shape[-1] // downscale_ratio
+
+ if vae_type == "stable_diffusion":
+ # continuous latents
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ latents = vae.encode(train_transforms(image).to(vae.dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor # shape = (B, C, H, W)
+ latents = torch.permute(latents, (0, 2, 3, 1)) # shape = (B, H, W, C)
+ latents = einops.rearrange(latents, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=config.model.patching_downscale, p2=config.model.patching_downscale) # shape = (B, H*W, C*config.model.patching_downscale**2)
+ return latents
+ image.clamp_(0, 1) # todo verify if needed for continuous vae
+ if vae_type == "diffusers":
+ if "CompVis/ldm-celebahq-256" in get_attr("use_custom_vae_ckpt"):
+ image = (image * 2) - 1
+ latents = vae.encode(image).latents
+ discrete = vae.quantize(latents)[-1][-1]
+ discrete = rearrange(discrete, "(b n) -> b n", b=batch_size)
+ elif vae_type == "raw":
+ discrete = rearrange((image * 255).to(torch.int64), "b c h w -> b (c h w)")
+ elif vae_type == "maskgit":
+ discrete = vae.get_code(image)
+ elif vae_type == "taming":
+ _, discrete = vae.encode(image)
+ elif vae_type == "video_vqvae":
+ vae.temporal_dim_length = image.shape[2]
+ vae.spatial_dim_length = image.shape[3]
+ discrete = vae.encode(image.to(device)) # [B, C, T, H, W]
+ discrete = rearrange(discrete, "b t h w -> b (t h w)")
+ elif vae_type == "VQ-8" or vae_type == "VQ-16":
+ image = (image * 2) - 1
+ latent, _, [_, _, discrete] = vae.encode(image)
+ discrete = rearrange(discrete, "(b h w) -> b (h w)", h=latent_dim, w=latent_dim)
+ elif vae_type == "lfq_128" or vae_type == "lfq_256":
+ _, _, _, indices = vae(image, return_indices=True)
+ discrete = rearrange(indices, "(b n) -> b n", b=batch_size)
+ elif vae_type == "bsq_18":
+ image = (image * 2) - 1
+ quant, loss, info = vae.encode(image, skip_quantize=False)
+ discrete = info["indices"]
+ breakpoint()
+ elif vae_type == "cosmos":
+ discrete, _ = vae.encode((image.to(device) * 2) - 1)
+ discrete = rearrange(discrete, "b h w -> b (h w)")
+ elif "titok" in vae_type:
+ discrete = vae.encode(image.to(device))[1]["min_encoding_indices"].squeeze(1)
+ elif vae_type == "chameleon":
+ image = (image * 2) - 1
+ discrete = vae.get_image_tokens(image)
+ elif vae_type == "lumina":
+ breakpoint()
+ elif vae_type == "magvit":
+ discrete = vae.get_code(image)
+ else:
+ raise ValueError(f"Unknown VAE type: {vae_type}")
+ return discrete
+
+
+@torch.no_grad()
+def vae_decode_image(config, vae, discrete, use_cond: bool = False):
+ if discrete is None or (not isinstance(discrete, list) and discrete.shape[1] == 0):
+ return torch.zeros(1, 3, config.data.resolution, config.data.resolution)
+
+ def get_attr(attr_name):
+ if use_cond:
+ return getattr(config.model, f"cond_{attr_name}", None)
+ return getattr(config.model, attr_name, None)
+
+ with torch.autocast(device_type="cuda", enabled=False):
+ vae_type = get_attr("vae_type")
+ latent_dim = config.data.resolution // get_attr("downscale_ratio")
+ if vae_type == "stable_diffusion":
+ # input - (B, N, C)
+ original_height = config.data.resolution // config.model.downscale_ratio
+ discrete = rearrange(discrete, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", h=original_height, p1=config.model.patching_downscale, p2=config.model.patching_downscale)
+ discrete = 1 / vae.config.scaling_factor * discrete
+ image = vae.decode(discrete.to(torch.float16), return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # image = image.detach().cpu().permute(0, 2, 3, 1).float()#.numpy()
+ return image
+
+ if not isinstance(discrete, list):
+ discrete = discrete.to(dtype=torch.int64)
+ if config.trainer.add_label and discrete.shape[-1] % 2 != 0:
+ discrete = discrete[:, 1:]
+
+ if vae_type != "chameleon" and vae_type != "lumina" and "titok" not in vae_type and discrete.shape[-1] != latent_dim ** 2:
+ for test_res in (128, 256, 512, 1024):
+ if (test_res // get_attr("downscale_ratio")) ** 2 == discrete.shape[-1]:
+ latent_dim = test_res // get_attr("downscale_ratio")
+ break
+ else:
+ raise ValueError(f"Unknown latent dimension: {latent_dim}")
+
+ if vae_type == "diffusers":
+ image = vae.decode(
+ discrete, force_not_quantize=True, shape=(discrete.shape[0], latent_dim, latent_dim, vae.config.latent_channels)
+ ).sample
+ if "CompVis/ldm-celebahq-256" in get_attr("use_custom_vae_ckpt"):
+ image = (image + 1) / 2
+ elif vae_type == "raw":
+ image = discrete / 255
+ latent_dim = int(sqrt(discrete.shape[1] // 3))
+ image = rearrange(image, "b (c h w) -> b c h w", c=3, h=latent_dim, w=latent_dim)
+ elif vae_type == "maskgit" or vae_type == "taming":
+ if not 0 <= discrete.min() and discrete.max() < get_attr("image_vocab_size"):
+ raise ValueError(f"Discrete values out of bounds: {discrete.min()}, {discrete.max()}")
+ assert 0 <= discrete.min() and discrete.max() < get_attr("image_vocab_size")
+ image = vae.decode_code(discrete)
+ elif vae_type == "video_vqvae":
+ image = vae.decode(
+ rearrange(discrete, "b (t h w) -> b t h w", t=vae.temporal_dim_length, h=vae.spatial_dim_length, w=vae.spatial_dim_length)
+ ) # [B, T // 4, H, W]
+ elif vae_type == "VQ-8" or vae_type == "VQ-16":
+ image = vae.decode_code(discrete, shape=(discrete.shape[0], vae.config.codebook_embed_dim, latent_dim, latent_dim))
+ image = (image + 1) / 2
+ elif vae_type == "lfq_128" or vae_type == "lfq_256":
+ x = discrete
+ # From taming/modules/vqvae/lookup_free_quantize.py. Index -> -1/1 float
+ mask = 2 ** torch.arange(vae.quantize.codebook_dim - 1, -1, -1, device=x.device, dtype=torch.long)
+ x = (x.unsqueeze(-1) & mask) != 0
+ x = (x * 2.0) - 1.0
+ x = rearrange(x, "b (h w) c -> b c h w", h=latent_dim, w=latent_dim)
+ image = vae.decode(x)
+ image = torch.clamp(image, 0.0, 1.0)
+ elif vae_type == "bsq_18":
+ quant = vae.quantize.get_codebook_entry(discrete)
+ image = vae.decode(quant)
+ image = torch.clamp(image, 0.0, 1.0)
+ elif vae_type == "cosmos":
+ image = vae.decode(discrete.reshape(discrete.shape[0], 16, 16))
+ image = (image / 2) + 0.5
+ image = torch.clamp(image, 0.0, 1.0)
+ elif "titok" in vae_type:
+ image = vae.decode_tokens(discrete.unsqueeze(1))
+ image = torch.clamp(image, 0.0, 1.0)
+ elif vae_type == "chameleon":
+ image = vae.decode_image_tokens(discrete)
+ image = (image + 1) / 2
+ image = torch.clamp(image, 0.0, 1.0)
+ elif vae_type == "lumina":
+ # We always expect either [B, N] or [[B, N], ...]
+ if not isinstance(discrete, list):
+ discrete = [discrete]
+
+ images = []
+ for i in range(len(discrete)):
+ for j in range(discrete[i].shape[0]):
+ images.append(torch.from_numpy(np.array(vae.decode_image(discrete[i][j].cpu().tolist()))))
+
+ image = torch.stack(images, dim=0).permute(0, 3, 1, 2) / 255
+ elif vae_type == "magvit":
+ image = vae.decode_code(discrete)
+ image = torch.clamp(image, 0.0, 1.0)
+ else:
+ raise ValueError(f"Unknown VAE type: {vae_type}")
+
+ image.clamp_(0, 1)
+ return image
+
+
+def auto_batch(config, fn, data):
+ split_size = 32 if getattr(config.eval, "force_empty_cache", False) else 128
+ if getattr(config.eval, "force_empty_cache", False):
+ from model_utils import empty_device_cache
+ empty_device_cache()
+
+ if data.shape[0] > split_size:
+ return torch.cat([fn(chunk) for chunk in torch.split(data, split_size, dim=0)], dim=0)
+ else:
+ return fn(data)
+
+def get_image_batch(config, vae, batch, device, use_cond: bool = False):
+ def get_attr(attr_name):
+ if use_cond:
+ return getattr(config.model, f"cond_{attr_name}", None)
+ return getattr(config.model, attr_name, None)
+
+ vae_type = get_attr("vae_type")
+ if "img" in batch:
+ return auto_batch(config, lambda img: vae_encode_image(config, vae, img, device, vae_type, use_cond), batch["img"])
+ elif "video" in batch:
+ if vae_type == "video_vqvae":
+ return vae_encode_image(config, vae, batch["video"], device, vae_type, use_cond)
+ else:
+ return torch.cat(
+ [
+ vae_encode_image(config, vae, batch["video"][:, :, frame_idx], device, vae_type, use_cond)
+ for frame_idx in range(batch["video"].shape[2])
+ ],
+ dim=-1,
+ )
+ else:
+ raise ValueError(f"Unknown batch type: {batch}")
+
+def decode_latents(config, vae, sample, use_cond: bool = False, batched: bool = True):
+ if getattr(config.model, "video_model", False):
+ num_frames = config.data.num_frames
+ frames = torch.split(sample, sample.shape[-1] // num_frames, dim=-1)
+ return torch.cat([vae_decode_image(config, vae, frame, use_cond) for frame in frames], dim=-2)
+ else:
+ if batched:
+ return auto_batch(config, lambda s: vae_decode_image(config, vae, s, use_cond), sample)
+ else:
+ return np.stack([vae_decode_image(config, vae, s.unsqueeze(0), use_cond).squeeze(0) for s in sample])
\ No newline at end of file
diff --git a/unidisc/tokenizers/laion_aesthetic_v2.py b/unidisc/tokenizers/laion_aesthetic_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff20ad0244d1f88eccd1125085f6ae6934bf83e1
--- /dev/null
+++ b/unidisc/tokenizers/laion_aesthetic_v2.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import clip
+import os
+import math
+from constants import UNIDISC_DIR
+from functools import partial
+
+aesthetic_path = str(UNIDISC_DIR / "ckpts" / "ava+logos-l14-linearMSE.pth")
+
+class AestheticPredictor(nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.layers = nn.Sequential(
+ nn.Linear(self.input_size, 1024),
+ nn.Dropout(0.2),
+ nn.Linear(1024, 128),
+ nn.Dropout(0.2),
+ nn.Linear(128, 64),
+ nn.Dropout(0.1),
+ nn.Linear(64, 16),
+ nn.Linear(16, 1)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+def get_image_features(image, device, model, preprocess):
+ image = preprocess(image)
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(device)
+ with torch.no_grad():
+ image_features = model.encode_image(image)
+ image_features /= image_features.norm(dim=-1, keepdim=True) # l2 normalize
+
+ image_features = image_features.cpu().detach().numpy()
+ return image_features
+
+def sigmoid(x):
+ return 1 / (1 + np.exp(-x))
+
+def orig_score(image=None, predictor=None, clip_model=None, clip_preprocess=None, device=None, prompt="", reverse=False):
+ image_features = get_image_features(image, device, clip_model, clip_preprocess)
+ score_origin = predictor(torch.from_numpy(image_features).to(device).float()).item() - 5.6
+ if reverse:
+ score_origin = score_origin*-1
+ _score = sigmoid(score_origin)
+ return _score
+
+def score(image=None, predictor=None, clip_model=None, clip_preprocess=None, device=None, prompt="", reverse=False):
+ image_features = get_image_features(image, device, clip_model, clip_preprocess)
+ score_origin = predictor(torch.from_numpy(image_features).to(device).float()) - 5.6
+ score_origin = score_origin.detach().cpu().numpy()
+ if reverse:
+ score_origin = score_origin*-1
+ _score = sigmoid(score_origin)
+ return _score
+
+@torch.no_grad()
+def get_predictor_func(device, accept_pillow=False):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ pt_state = torch.load(aesthetic_path, map_location=torch.device('cpu'))
+
+ # CLIP embedding dim is 768 for CLIP ViT L 14
+ predictor = AestheticPredictor(768)
+ predictor.load_state_dict(pt_state)
+ predictor.to(device)
+ predictor.eval()
+ clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
+ if not accept_pillow:
+ clip_preprocess.transforms = [clip_preprocess.transforms[0], clip_preprocess.transforms[1], clip_preprocess.transforms[4]]
+ get_reward = partial(score, predictor=predictor, clip_model=clip_model, clip_preprocess=clip_preprocess, device=device)
+ else:
+ get_reward = partial(orig_score, predictor=predictor, clip_model=clip_model, clip_preprocess=clip_preprocess, device=device)
+ return get_reward
+
+
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ get_reward = get_predictor_func(device, accept_pillow=False)
+ from image_utils import Im
+ rand_img = torch.rand(5, 3, 224, 224)
+ rand_img[1] = Im.random().resize(224, 224).torch
+ rand_img[2] = Im.random().resize(224, 224).torch
+ rand_img[3] = Im.random().resize(224, 224).torch
+ rand_img[4] = Im("https://img.freepik.com/premium-photo/majestic-3d-lion-illustration-retro-aesthetic-artwork_971394-242.jpg").resize(224, 224).torch
+ torch_rewards = get_reward(image=rand_img)
+ print(torch_rewards)
+ from PIL import Image
+ pil_images = [Image.fromarray((img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)) for img in rand_img]
+ get_reward_pil = get_predictor_func(device, accept_pillow=True)
+ pil_rewards = [get_reward_pil(image=pil_img) for pil_img in pil_images]
+ print(pil_rewards)
diff --git a/unidisc/tokenizers/text_reward_model.py b/unidisc/tokenizers/text_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..401e35f69eded900156d87d0d12cbecc7aa61c36
--- /dev/null
+++ b/unidisc/tokenizers/text_reward_model.py
@@ -0,0 +1,7 @@
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
+rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(reward_name), AutoTokenizer.from_pretrained(reward_name)
+question, answer = "", ""
+inputs = tokenizer(question, answer, return_tensors='pt')
+score = rank_model(**inputs).logits[0].cpu().detach()
+print(score)
\ No newline at end of file
diff --git a/unidisc/tokenizers/text_tokenizers.py b/unidisc/tokenizers/text_tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9cda1f74ce827c27efd3d6fb157d130980fa9c9
--- /dev/null
+++ b/unidisc/tokenizers/text_tokenizers.py
@@ -0,0 +1,104 @@
+
+import einops
+import hydra
+import hydra.utils
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from constants import LIB_DIR, UNIDISC_DIR
+from decoupled_utils import (Profiler, barrier, dprint, get_rank,
+ get_slurm_job_id, get_world_size, gprint,
+ is_main_process, print_memory, rank_zero_fn,
+ rprint, save_memory_profile, try_except)
+
+
+class VggFaceTokenizer:
+ def __init__(self, mask_token_id, v2=False):
+ self.mask_token_id = mask_token_id
+ self.v2 = v2
+ self.idx_to_attr = [
+ "Male",
+ "Young",
+ "Middle_Aged",
+ "Senior",
+ "Asian",
+ "White",
+ "Black",
+ "Rosy_Cheeks",
+ "Shiny_Skin",
+ "Bald",
+ "Wavy_Hair",
+ "Receding_Hairline",
+ "Bangs",
+ "Sideburns",
+ "Black_Hair",
+ "Blond_Hair",
+ "Brown_Hair",
+ "Gray_Hair",
+ "No_Beard",
+ "Mustache",
+ "5_o_Clock_Shadow",
+ "Goatee",
+ "Oval_Face",
+ "Square_Face",
+ "Round_Face",
+ "Double_Chin",
+ "High_Cheekbones",
+ "Chubby",
+ "Obstructed_Forehead",
+ "Fully_Visible_Forehead",
+ "Brown_Eyes",
+ "Bags_Under_Eyes",
+ "Bushy_Eyebrows",
+ "Arched_Eyebrows",
+ "Mouth_Closed",
+ "Smiling",
+ "Big_Lips",
+ "Big_Nose",
+ "Pointy_Nose",
+ "Heavy_Makeup",
+ "Wearing_Hat",
+ "Wearing_Earrings",
+ "Wearing_Necktie",
+ "Wearing_Lipstick",
+ "No_Eyewear",
+ "Eyeglasses",
+ "Attractive",
+ ]
+ if self.v2:
+ self.idx_to_attr.insert(0, "Female")
+
+ def batch_decode(self, tokens_list, show_mask_token=False):
+ decoded_strs = []
+ for tokens in tokens_list:
+ if self.v2:
+ assert len(tokens) == 48, f"Expected 49 tokens, got {len(tokens)}"
+ else:
+ assert len(tokens) == 47, f"Expected 48 tokens, got {len(tokens)}"
+
+ example_str = []
+ for attr in tokens:
+ if 2 <= attr <= (49 if self.v2 else 48):
+ example_str.append(self.idx_to_attr[attr - 2])
+ elif attr == self.mask_token_id:
+ if show_mask_token:
+ example_str.append("[MASK]")
+ elif 0 <= attr < 2:
+ pass
+ else:
+ gprint(f"Unknown attribute id: {attr}")
+
+ decoded_strs.append("The person is " + ", and ".join(example_str))
+ return decoded_strs
+
+ @property
+ def eos_token(self):
+ return "END OF SENTENCE"
+
+ @property
+ def eos_token_id(self):
+ return 999999
\ No newline at end of file
diff --git a/unidisc/tokenizers/tokenize_interleaved.py b/unidisc/tokenizers/tokenize_interleaved.py
new file mode 100644
index 0000000000000000000000000000000000000000..60ad2654c1ef725307b78a08266f7d454d67b466
--- /dev/null
+++ b/unidisc/tokenizers/tokenize_interleaved.py
@@ -0,0 +1,293 @@
+from typing import Dict, Sequence
+import transformers
+import copy
+from unidisc.tokenizers import conversation as conversation_lib
+import json
+import torch
+from torch.utils.data import Dataset, DataLoader
+import glob
+import bisect
+import os
+import subprocess
+from pathlib import Path
+import socket
+
+# Model Constants
+DEFAULT_IMAGE_TOKEN = ""
+
+def tokenizer_image_token(prompt, tokenizer, return_tensors=None, image_ids=None, start_idx=None):
+ prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('')]
+
+ input_ids = []
+ attention_mask = []
+ modality = []
+
+ start_idx = 0
+ for i, chunk in enumerate(prompt_chunks):
+ input_ids.extend(chunk)
+ attention_mask.extend([True] * len(chunk))
+ modality.extend([False] * len(chunk))
+
+ if i < len(prompt_chunks) - 1:
+ if image_ids is not None and start_idx < len(image_ids):
+ input_ids.extend([tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index("")]])
+ attention_mask.append(True)
+ modality.append(False)
+
+ input_ids.extend(image_ids[start_idx].tolist())
+ attention_mask.extend([True] * len(image_ids[start_idx]))
+ modality.extend([True] * len(image_ids[start_idx]))
+ start_idx += 1
+
+ if not input_ids[0] == tokenizer.bos_token_id:
+ input_ids = [tokenizer.bos_token_id] + input_ids
+ attention_mask = [True] + attention_mask
+ modality = [False] + modality
+
+ if not input_ids[-1] == tokenizer.eos_token_id:
+ input_ids = input_ids + [tokenizer.eos_token_id]
+ attention_mask = attention_mask + [True]
+ modality = modality + [False]
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return (torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.bool), torch.tensor(modality, dtype=torch.bool)), start_idx
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+
+ return (input_ids, attention_mask, modality), start_idx
+
+
+def preprocess_llama_2(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ image_ids = None,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+
+ # if not role == conv.roles[j % 2]:
+ # print(f"Role mismatch at {i}, {j}: {role} vs {conv.roles[j % 2]}")
+
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ print(f"After pre-processing: {conversations}")
+
+ # Tokenize conversations
+ if has_image:
+ data = []
+ start_idx = 0
+ for i, prompt in enumerate(conversations):
+ new_data, start_idx = tokenizer_image_token(prompt, tokenizer, return_tensors='pt', image_ids=image_ids, start_idx=start_idx)
+ data.append(new_data)
+
+ if not start_idx == len(image_ids):
+ breakpoint()
+
+ assert start_idx == len(image_ids), f"start_idx: {start_idx}, len(image_ids): {len(image_ids)}"
+
+ input_ids = torch.stack([x[0] for x in data], dim=0)
+ attention_mask = torch.stack([x[1] for x in data], dim=0)
+ modality = torch.stack([x[2] for x in data], dim=0)
+ else:
+ data = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=8192,
+ truncation=True,
+ )
+ attention_mask = data["attention_mask"]
+ input_ids = data["input_ids"]
+ modality = torch.zeros_like(attention_mask, dtype=torch.bool)
+
+ return dict(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ modality=modality
+ )
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ image_ids = None
+) -> Dict:
+
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image, image_ids=image_ids)
+
+
+def _has_image(sample: dict) -> bool:
+ return "image" in sample and not str(sample['image']) in ['', 'None', 'none', 'nan']
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+) -> Dict:
+ DEFAULT_IMAGE_PATCH_TOKEN = ""
+ DEFAULT_IM_START_TOKEN = ""
+ DEFAULT_IM_END_TOKEN = ""
+ is_multimodal = True
+ mm_use_im_start_end = False
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '')
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+class JsonlDataset(Dataset):
+ def __init__(self, glob_pattern, max_cache_size_gb=50, max_files=5):
+ self.files = glob.glob(glob_pattern)
+ self.files.sort()
+ self.line_counts = []
+ self.cumulative_sizes = []
+ total = 0
+
+ userhome = Path.home()
+ print(f"Using {userhome} as cache directory")
+ cache_dir = userhome / ".cache" / "unidisc"
+ cache_dir.mkdir(parents=True, exist_ok=True)
+ cache_file = cache_dir / "file_metadata.json"
+
+ # Try to load cached metadata
+ cached_metadata = {}
+ try:
+ if cache_file.exists():
+ with open(cache_file, 'r') as f:
+ cached_metadata = json.load(f)
+ except Exception as e:
+ print(f"Error loading cached metadata: {e}")
+
+ print(f"Loaded cached metadata with {len(cached_metadata)} files")
+
+ self.file_sizes = {}
+ needs_update = False
+ for filename in self.files:
+ if filename in cached_metadata:
+ count = cached_metadata[filename]['line_count']
+ size = cached_metadata[filename]['file_size']
+ else:
+ print(f"Processing {filename}")
+ result = subprocess.run(['wc', '-l', filename], capture_output=True, text=True)
+ count = int(result.stdout.split()[0])
+ size = os.path.getsize(filename)
+ cached_metadata[filename] = {
+ 'line_count': count,
+ 'file_size': size
+ }
+ needs_update = True
+
+ self.line_counts.append(count)
+ total += count
+ self.cumulative_sizes.append(total)
+ self.file_sizes[filename] = size
+
+ if needs_update and os.environ.get("SLURM_ARRAY_TASK_ID", None) is None:
+ print(f"Writing cached metadata to {cache_file}")
+ with open(cache_file, 'w') as f:
+ json.dump(cached_metadata, f)
+
+ self._cache = {}
+ self.current_cache_size = 0
+ self.max_cache_size = max_cache_size_gb * 1024 * 1024 * 1024 # Convert GB to bytes
+ self.max_files = max_files
+
+ if len(self) == 0:
+ print("No files to process")
+ print(len(cached_metadata))
+ print(cached_metadata)
+ exit()
+ else:
+ print(f"Total len: {len(self)}")
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ file_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ # print(f"File idx: {file_idx}")
+ if file_idx == 0:
+ line_idx = idx
+ else:
+ line_idx = idx - self.cumulative_sizes[file_idx - 1]
+
+ filename = self.files[file_idx]
+ if filename in self._cache:
+ lines = self._cache[filename]
+ else:
+ with open(filename, 'r') as f:
+ lines = f.readlines()
+
+ # print(f"Adding {filename} to cache")
+ # Remove files from cache until we have enough space and are under max files
+ while ((self.current_cache_size + self.file_sizes[filename] > self.max_cache_size or
+ len(self._cache) >= self.max_files) and self._cache):
+ removed_file = next(iter(self._cache))
+ self.current_cache_size -= self.file_sizes[removed_file]
+ self._cache.pop(removed_file)
+
+ if (self.file_sizes[filename] <= self.max_cache_size and
+ len(self._cache) < self.max_files):
+ self._cache[filename] = lines
+ self.current_cache_size += self.file_sizes[filename]
+
+ data = json.loads(lines[line_idx])
+ data["idx"] = idx
+ return data
+
+
+if __name__ == "__main__":
+ tokenizer = transformers.AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
+ tokenizer.model_max_length = 100000
+ tokenizer.padding_side = 'right'
+ tokenizer.add_eos_token = True
+
+ i = 0
+ dataset = JsonlDataset(glob_pattern="/scratch/aswerdlo/cambrian/jsons/gpt4v_77k.jsonl")
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=lambda x: x)
+
+ for batch in dataloader:
+ image_paths = []
+ for i in range(len(batch)):
+ if "image" in batch[i]:
+ image_paths.append(batch[i]["image"])
+
+ image_ids = torch.zeros((len(image_paths), 256), dtype=torch.int64)
+ for i, sources in enumerate(batch):
+ has_image = _has_image(sources)
+ sources = copy.deepcopy([e["conversations"] for e in [sources]])
+ if has_image:
+ sources = preprocess_multimodal(sources)
+
+ data_dict = preprocess(sources, tokenizer, has_image=has_image, image_ids=image_ids[[i]])
+ breakpoint()
\ No newline at end of file
diff --git a/unidisc/tokenizers/viz_tokenizers.py b/unidisc/tokenizers/viz_tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2308a08639fcd9c176750958985af6a3f0c3d4
--- /dev/null
+++ b/unidisc/tokenizers/viz_tokenizers.py
@@ -0,0 +1,88 @@
+from pathlib import Path
+from PIL import Image, ImageDraw, ImageFont
+import textwrap
+
+def custom_sort_key(name):
+ if "GT_" in name:
+ return (0, name)
+ elif "seq256" in name:
+ return (1, name)
+ elif "seq1024" in name:
+ return (2, name)
+ elif "seq4096" in name:
+ return (3, name)
+ else:
+ return (4, name)
+
+def visualize_datasets(root_dir, img_resolution=256, text_img_width=100, text_wrap_width=10, selected_datasets=None):
+ root_path = Path(root_dir)
+ for folder_path in root_path.iterdir():
+ if folder_path.is_dir():
+ datasets = {}
+ for dataset_path in folder_path.iterdir():
+ if dataset_path.is_dir():
+ if dataset_path.name == "output": continue
+ if selected_datasets and not any(x in dataset_path.name for x in selected_datasets):
+ continue
+ images = []
+ image_paths = [p for p in dataset_path.iterdir() if p.stem.isdigit()]
+ for image_path in sorted(image_paths, key=lambda x: int(x.stem)):
+ images.append(Image.open(image_path))
+ datasets[dataset_path.name] = images
+
+ num_images = len(images)
+
+ viz_per_image = False
+ if viz_per_image:
+ for index in range(num_images):
+ widths = [img.width for img in images]
+ max_width = max(widths)
+
+ # Create per_index image
+ per_index_heights = [images[index].resize((img_resolution, img_resolution), Image.LANCZOS).height for images in datasets.values() if len(images) > index]
+ per_index_total_height = sum(per_index_heights)
+ per_index_image = Image.new('RGB', (img_resolution + text_img_width, per_index_total_height)) # Set width to img_resolution + space for text
+ y_offset = 0
+ for dataset_name, images in sorted(datasets.items(), key=lambda x: custom_sort_key(x[0])):
+ if len(images) > index:
+ img = images[index].resize((img_resolution, img_resolution), Image.LANCZOS) # Resize image to img_resolution x img_resolution
+ text_img = Image.new('RGB', (text_img_width, img_resolution), (255, 255, 255)) # Create a white image for text
+ draw = ImageDraw.Draw(text_img)
+ font = ImageFont.load_default()
+ wrapped_text = textwrap.fill(dataset_name, width=text_wrap_width) # Wrap text to fit within the image
+ draw.text((10, 10), wrapped_text, fill=(0, 0, 0), font=font)
+ combined_img = Image.new('RGB', (img_resolution + text_img_width, img_resolution)) # Combined width of text and image
+ combined_img.paste(text_img, (0, 0))
+ combined_img.paste(img, (text_img_width, 0))
+ per_index_image.paste(combined_img, (0, y_offset))
+ y_offset += img.height
+
+ (folder_path / "output").mkdir(parents=True, exist_ok=True)
+ per_index_image.save(folder_path / "output" / f'{index}_per_index_viz.png')
+
+ # Create combined image for the entire dataset
+ num_datasets = len(datasets)
+ combined_image_width = (img_resolution + text_img_width) * num_images # Each column is an index + space for text
+ combined_image_height = img_resolution * num_datasets # Each row is a dataset
+ combined_image = Image.new('RGB', (combined_image_width, combined_image_height))
+
+ for row_index, (dataset_name, images) in enumerate(sorted(datasets.items(), key=lambda x: custom_sort_key(x[0]))):
+ for col_index, img in enumerate(images):
+ resized_img = img.resize((img_resolution, img_resolution), Image.LANCZOS)
+ text_img = Image.new('RGB', (text_img_width, img_resolution), (255, 255, 255)) # Create a white image for text
+ draw = ImageDraw.Draw(text_img)
+ font = ImageFont.load_default()
+ wrapped_text = textwrap.fill(dataset_name, width=text_wrap_width) # Wrap text to fit within the image
+ draw.text((10, 10), wrapped_text, fill=(0, 0, 0), font=font)
+ combined_img = Image.new('RGB', (img_resolution + text_img_width, img_resolution)) # Combined width of text and image
+ combined_img.paste(text_img, (0, 0))
+ combined_img.paste(resized_img, (text_img_width, 0))
+ x_offset = col_index * (img_resolution + text_img_width)
+ y_offset = row_index * img_resolution
+ combined_image.paste(combined_img, (x_offset, y_offset))
+
+
+ (folder_path / "output").mkdir(parents=True, exist_ok=True)
+ combined_image.save(folder_path / "output" / f'combined_viz_{img_resolution}.png')
+
+visualize_datasets('output', img_resolution=256, text_img_width=100, text_wrap_width=10, selected_datasets=["GT_256", 'titok128', 'titok256', 'cosmos'])
\ No newline at end of file
diff --git a/unidisc/utils/cuda_utils.py b/unidisc/utils/cuda_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d716444a63a2738ea17ef90917744d8359aef34f
--- /dev/null
+++ b/unidisc/utils/cuda_utils.py
@@ -0,0 +1,60 @@
+from torchtnt.utils.distributed import all_gather_tensors, get_global_rank
+import torch.distributed as dist
+from decoupled_utils import is_torch_xla_available, use_dist, rprint
+import torch
+
+def _get_min_max_indices(input_list):
+ min_index = -1
+ max_index = -1
+ min_value = float("inf")
+ max_value = float("-inf")
+ for rank, curr_value in enumerate(input_list):
+ if curr_value < min_value:
+ min_value = curr_value
+ min_index = rank
+ if curr_value > max_value:
+ max_value = curr_value
+ max_index = rank
+
+ return min_index, max_index
+
+
+def sync_times(device):
+ if not use_dist():
+ return
+ # Use torch.cuda.Event to measure time across multiple nodes
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+
+ # Record the start time
+ start_event.record()
+
+ # Perform a synchronization to ensure all nodes start timing at roughly the same point
+ dist.barrier()
+
+ # Record the end time
+ end_event.record()
+
+ # Wait for the end event to be completed
+ end_event.synchronize()
+
+ # Calculate the elapsed time on this node
+ elapsed_time = start_event.elapsed_time(end_event)
+
+ # Gather elapsed times from all nodes
+ elapsed_time_tensor = torch.tensor([elapsed_time], device=device)
+ all_elapsed_times = all_gather_tensors(elapsed_time_tensor)
+
+ # Convert tensor list to a list of times
+ elapsed_times_list = [tensor.item() for tensor in all_elapsed_times]
+
+ # Determine the fastest and slowest ranks
+ fastest_rank, slowest_rank = _get_min_max_indices(elapsed_times_list)
+ time_on_fastest_rank = elapsed_times_list[fastest_rank]
+ time_on_slowest_rank = elapsed_times_list[slowest_rank]
+ time_difference = time_on_slowest_rank - time_on_fastest_rank
+
+ # Print the time difference
+ rprint(
+ f"Time difference between fastest rank ({fastest_rank}: {time_on_fastest_rank} ms) and slowest rank ({slowest_rank}: {time_on_slowest_rank} ms) is {time_difference} milliseconds."
+ )
\ No newline at end of file
diff --git a/unidisc/utils/logging_utils.py b/unidisc/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b69ee34581a68ad4270644118ceeb5fbcd36d0d
--- /dev/null
+++ b/unidisc/utils/logging_utils.py
@@ -0,0 +1,181 @@
+import logging
+from pathlib import Path
+from typing import Optional
+
+import os
+import time
+import atexit
+
+class DebugLogger:
+ def __init__(self, log_name=None, file_name=None, log_dir="/dev/shm", buffer_size=100, flush_interval=1.0, add_user_prefix=True):
+ """
+ Initializes the logger.
+
+ :param identifier: Optional; Unique identifier for the log file (e.g., PID or custom string).
+ If None, uses the current process ID.
+ :param log_dir: Directory where log files are stored.
+ :param buffer_size: Number of messages to buffer before writing to file.
+ :param flush_interval: Time interval (in seconds) to flush logs if buffer isn't full.
+ """
+
+ if file_name is None:
+ file_name = f"pid_{os.getpid()}"
+
+ self.log_name = log_name
+ self.log_dir = Path(log_dir)
+ if add_user_prefix:
+ self.log_dir = self.log_dir / os.getenv("USER")
+
+ self.log_dir = self.log_dir / "logs"
+ self.log_dir.mkdir(parents=True, exist_ok=True)
+ prefix_timestamp = time.strftime("%Y%m%d_%H%M")
+ suffix_timestamp = f"{int(time.time())}_{int(time.time_ns() % 1_000_000_000)}"
+ self.log_file_path = os.path.join(self.log_dir, f"{prefix_timestamp}_{file_name}_{suffix_timestamp}.out")
+ self.buffer = []
+ self.buffer_size = buffer_size
+ self.flush_interval = flush_interval
+ self.last_flush_time = time.time()
+ self.file = open(self.log_file_path, 'a', buffering=1) # Line-buffered
+ atexit.register(self.close)
+
+ def log(self, *args, sep=' ', end='\n', flush=False, **kwargs):
+ """
+ Adds a log message to the buffer.
+
+ Supports arbitrary positional and keyword arguments like the built-in print() function.
+
+ :param args: Arbitrary positional arguments to be logged.
+ :param sep: String inserted between values, default a space.
+ :param end: String appended after the last value, default a newline.
+ :param kwargs: Additional keyword arguments (ignored but accepted for compatibility).
+ """
+ message = sep.join(str(arg) for arg in args) + end
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+ timestamp = f"{timestamp}, {self.log_name}" if self.log_name is not None else timestamp
+ formatted_message = f"[{timestamp}] {message}".rstrip('\n')
+ self.buffer.append(formatted_message)
+
+ if len(self.buffer) >= self.buffer_size:
+ self.flush()
+ elif (time.time() - self.last_flush_time) >= self.flush_interval:
+ self.flush()
+
+ def flush(self):
+ """
+ Writes buffered log messages to the file and clears the buffer.
+ """
+ if len(self.buffer) > 0:
+ try:
+ self.file.write('\n'.join(self.buffer) + '\n')
+ self.file.flush()
+ os.fsync(self.file.fileno()) # Ensure it's written to disk
+ self.buffer.clear()
+ self.last_flush_time = time.time()
+ except Exception as e:
+ print(f"Failed to write logs: {e}")
+
+ def close(self):
+ """
+ Flushes any remaining logs and closes the file.
+ """
+ self.flush()
+ if not self.file.closed:
+ self.file.close()
+
+logger: Optional[logging.Logger] = None
+file_only_logger: Optional[logging.Logger] = None # New global variable
+memory_logger: Optional[DebugLogger] = None
+
+def set_logger(name: str, log_file_path: Optional[str] = None):
+ global logger, file_only_logger, memory_logger
+ if logger is not None and logger.hasHandlers():
+ logger.handlers.clear()
+
+ logger = logging.getLogger(name)
+ logger.handlers = []
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ log_format = "[%(asctime)s.%(msecs)03d][%(name)s][%(levelname)s] - %(message)s"
+ date_format = "%Y-%m-%d %H:%M:%S"
+ formatter = logging.Formatter(fmt=log_format, datefmt=date_format)
+
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.INFO)
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+
+ if log_file_path:
+ Path(log_file_path).parent.mkdir(parents=True, exist_ok=True)
+ file_handler = logging.FileHandler(log_file_path)
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ file_only_logger = logging.getLogger(name + f"_")
+ file_only_logger.handlers = []
+ file_only_logger.setLevel(logging.DEBUG)
+ file_only_logger.addHandler(file_handler)
+ file_only_logger.propagate = False
+
+ if memory_logger is not None:
+ memory_logger.close()
+ memory_logger = DebugLogger(log_name=name, file_name=Path(log_file_path).stem if log_file_path is not None else None)
+
+def get_logger():
+ return logger
+
+class Dummy:
+ def __getattr__(self, name):
+ def method(*args, **kwargs):
+ pass
+ return method
+
+def get_logger_(main_process_only: bool) -> logging.Logger:
+ global logger
+ from decoupled_utils import get_rank, is_main_process
+ if is_main_process() or not main_process_only:
+ if logger is not None:
+ return logger
+ else:
+ return set_logger(__name__ + f"_rank_{get_rank()}")
+ else:
+ return Dummy()
+
+def combine_args(*args):
+ return " ".join((str(arg) for arg in args))
+
+def _always_debug_log(*args, **kwargs) -> logging.Logger:
+ from decoupled_utils import is_main_process
+ if not is_main_process() and file_only_logger is not None:
+ file_only_logger.debug(combine_args(*args), **kwargs)
+
+ log_memory(*args, **kwargs)
+
+def log_debug(*args, main_process_only: bool = True, **kwargs):
+ kwargs.pop("end", None)
+ if main_process_only: _always_debug_log(combine_args(*args), **kwargs)
+ get_logger_(main_process_only=main_process_only).debug(combine_args(*args), **kwargs)
+
+
+def log_info(*args, main_process_only: bool = True, **kwargs):
+ kwargs.pop("end", None)
+ if main_process_only: _always_debug_log(combine_args(*args), **kwargs)
+ get_logger_(main_process_only=main_process_only).info(combine_args(*args), **kwargs)
+
+
+def log_error(*args, main_process_only: bool = True, **kwargs):
+ kwargs.pop("end", None)
+ if main_process_only: _always_debug_log(combine_args(*args), **kwargs)
+ get_logger_(main_process_only=main_process_only).error(combine_args(*args), **kwargs)
+
+
+def log_warn(*args, main_process_only: bool = True, **kwargs):
+ kwargs.pop("end", None)
+ if main_process_only: _always_debug_log(combine_args(*args), **kwargs)
+ get_logger_(main_process_only=main_process_only).warning(combine_args(*args), **kwargs)
+
+def log_memory(*args, **kwargs):
+ kwargs.pop("end", None)
+ if memory_logger is not None:
+ memory_logger.log(combine_args(*args), **kwargs)
\ No newline at end of file
diff --git a/unidisc/utils/parallel_loader.py b/unidisc/utils/parallel_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec99056bcd0116971ae2a0dbfe5dc00c4a97b22
--- /dev/null
+++ b/unidisc/utils/parallel_loader.py
@@ -0,0 +1,255 @@
+import itertools
+import threading
+import torch
+import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.utils.keyd_queue as kq
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+from decoupled_utils import gprint
+
+
+class PerDeviceQueue(object):
+
+ def __init__(self, device, loader_prefetch_size, device_prefetch_size):
+ self.device = device
+ self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+ self.queue = kq.Queue(maxsize=device_prefetch_size)
+ self.close_queue_count = itertools.count()
+ gprint("PerDeviceQueue initialized")
+
+
+class PerDeviceLoader(object):
+
+ def __init__(self, loader, device):
+ self._loader = loader
+ self._device = device
+ self._mark_step_batch_count = loader.batches_per_execution - 1
+ self._batches_yielded = 0
+ gprint("PerDeviceLoader initialized")
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ def __len__(self):
+ return self._loader.per_device_samples()
+
+ def next(self):
+ gprint("Getting next item")
+ if xp.get_tracer_marked_step():
+ gprint("Marking step traced")
+ xp.set_tracer_marked_step(False)
+ self._batches_yielded += 1
+ else:
+ if self._mark_step_batch_count <= self._batches_yielded:
+ gprint(f"before Marking step, {self._batches_yielded}, {self._mark_step_batch_count}")
+ self._batches_yielded = 0
+ xm.mark_step()
+ gprint("Marking step")
+ else:
+ self._batches_yielded += 1
+ gprint("Not marking step, batches yielded: ", self._batches_yielded)
+
+ gprint("Getting next item")
+ item = self._loader.next_item(self._device)
+ gprint("Item retrieved")
+ if item is None:
+ gprint("Item is None, marking step", item)
+ xm.mark_step()
+ gprint("Marked step, exiting since item is None")
+ raise StopIteration
+ return item
+
+
+class ParallelLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+ Args:
+ loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+ devices (`torch.device`...): The list of devices where the data has to be
+ sent. The i-th sample returned by the `loader` will be sent to `devices[i
+ % len(devices)]`.
+ batchdim (int, optional): The dimension which is holding the batch size.
+ Default: 0
+ loader_prefetch_size (int, optional): The max capacity of the queue used by
+ the thread which is reading samples from the `loader`, to be processed by
+ the worker threads which upload data to the devices.
+ Default: 8
+ device_prefetch_size (int, optional): The max size of the per-device queues,
+ where the worker threads deposit tensors which have already been sent to
+ devices.
+ Default: 4
+ host_to_device_transfer_threads (int, optional): The number of threads that
+ work in parallel to transfer data from loader queue to device queue.
+ Default: 1
+ input_sharding (ShardingSpec, optional): Sharding spec to apply to
+ compatible input tensors after loading.
+ Default: None
+ """
+
+ def __init__(self,
+ loader,
+ devices,
+ batchdim=0,
+ batches_per_execution=1,
+ loader_prefetch_size=12,
+ device_prefetch_size=4,
+ host_to_device_transfer_threads=4,
+ input_sharding=None):
+ self._loader = loader
+ self._devices = [torch.device(x) for x in devices]
+ self._batchdim = batchdim
+ self._batches_per_execution = batches_per_execution
+ self._done = False
+ self._queues = dict()
+ self._input_sharding = input_sharding
+ for device in self._devices:
+ self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
+ device_prefetch_size)
+ thread = threading.Thread(target=self._loader_worker)
+ thread.daemon = True
+ thread.start()
+ for dqueue in self._queues.values():
+ for i in range(host_to_device_transfer_threads):
+ thread = threading.Thread(
+ target=self._worker,
+ args=(
+ dqueue,
+ host_to_device_transfer_threads,
+ ))
+ thread.daemon = True
+ thread.start()
+
+ gprint("ParallelLoader finished")
+
+ def per_device_loader(self, device):
+ """Retrieves the loader iterator object for the given device.
+
+ Args:
+ device (`torch.device`): The device whole loader is being requested.
+
+ Returns:
+ The loader iterator object for the `device`. This is not a
+ `torch.utils.data.DataLoader` interface, but a Python iterator which
+ returns the same tensor data structure as returned by the wrapped
+ `torch.utils.data.DataLoader`, but residing on XLA devices.
+ """
+ return PerDeviceLoader(self, torch.device(device))
+
+ def per_device_samples(self):
+ return len(self._loader) // len(self._devices)
+
+ def next_item(self, device):
+ dqueue = self._queues[device]
+ gprint("Getting item from queue")
+ return dqueue.queue.get()
+
+ def close(self):
+ self._done = True
+ for dqueue in self._queues.values():
+ dqueue.queue.close()
+ dqueue.loader_queue.close()
+
+ @property
+ def batches_per_execution(self):
+ return self._batches_per_execution
+
+ def _loader_worker(self):
+ queues = list(self._queues.values())
+ data_iter = enumerate(self._loader)
+ batch = []
+ while not self._done:
+ try:
+ gprint("Getting next item")
+ _, data = next(data_iter)
+ gprint("Item retrieved inside loader worker")
+ except StopIteration:
+ gprint("StopIteration")
+ break
+
+ gprint("Appending item to batch, type: ", type(data))
+ batch.append(data)
+ if len(batch) == len(self._devices):
+ gprint("Batch full, sending to queues")
+ for queue_no, device_batch in enumerate(batch):
+ queues[queue_no].loader_queue.put(device_batch)
+ batch = []
+
+ gprint(f"Current batch length: {len(batch)}")
+ gprint("Loader worker done")
+ for dqueue in queues:
+ dqueue.loader_queue.close_write()
+ gprint("Loader worker closed")
+
+ def _get_batch(self, dqueue):
+ batch = []
+ while dqueue.queue.max_size() > len(batch):
+ gprint("Getting item from loader queue")
+ item = dqueue.loader_queue.get()
+ gprint(f"Item retrieved")
+ if item is None:
+ gprint("Item is None, breaking", item)
+ break
+ batch.append(item)
+ gprint(f"Batch retrieved: length {len(batch)}")
+ return batch
+
+ def _worker(self, dqueue, host_to_device_transfer_threads):
+ device = torch.device(dqueue.device)
+ gprint("Worker initialized")
+ while True:
+ gprint("Getting batch")
+ batch = self._get_batch(dqueue)
+ gprint("Batch retrieved")
+ if not batch:
+ gprint("Batch empty, breaking, ", batch)
+ break
+
+ gprint("Sending batch to device")
+ batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
+ gprint("Batch sent to device")
+ for data in batch:
+ gprint("Putting data in queue")
+ if data is None:
+ print("Data is None! ", data)
+ dqueue.queue.put(data)
+ gprint("Data put in queue")
+ gprint("Closing queue")
+ close_queue_count = next(dqueue.close_queue_count)
+ gprint(f"Close queue count: {close_queue_count}")
+ if close_queue_count == host_to_device_transfer_threads - 1:
+ gprint("Closing queue")
+ dqueue.queue.close_write()
+ gprint("Queue closed")
+
+ gprint("Worker done!!")
+
+
+class MpDeviceLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+ This class should only be using with multi-processing data parallelism.
+
+ Args:
+ loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+ device (`torch.device`...): The device where the data has to be sent.
+ kwargs: Named arguments for the `ParallelLoader` constructor.
+ """
+
+ def __init__(self, loader, device, **kwargs):
+ self._loader = loader
+ self._device = device
+ self._parallel_loader_kwargs = kwargs
+
+ def __iter__(self):
+ parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
+ gprint("ParallelLoader initialized")
+ return parallel_loader.per_device_loader(self._device)
+
+ def __len__(self):
+ return len(self._loader)
diff --git a/unidisc/utils/simple_llm.py b/unidisc/utils/simple_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc99198a12bf0eadff0e3b2c25de38a417282aa
--- /dev/null
+++ b/unidisc/utils/simple_llm.py
@@ -0,0 +1,150 @@
+import functools
+import os
+import random
+import subprocess
+import time
+from contextlib import ExitStack
+from decoupled_utils import rprint
+
+OPENROUTER_BASE = "https://openrouter.ai"
+OPENROUTER_API_BASE = f"{OPENROUTER_BASE}/api/v1"
+OPENROUTER_REFERRER = "https://github.com/alexanderatallah/openrouter-streamlit"
+
+def get_groq_llama(model="llama-3.2-90b-text-preview"):
+ from langchain_groq import ChatGroq
+ groq_llm = ChatGroq(
+ temperature=0.8,
+ model=model,
+ max_retries=0,
+ request_timeout=30,
+ )
+ return groq_llm
+
+def get_openai_azure():
+ from langchain_openai import AzureChatOpenAI
+ # Need to also set AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT
+ # Key only works for gpt-4o
+ os.environ["AZURE_OPENAI_API_VERSION"] = '2024-06-01'
+ os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-4o"
+ llm = AzureChatOpenAI(
+ openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+ azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
+ )
+ return llm
+
+def get_openai_openrouter(model="gpt-4o-mini"):
+ from langchain_openai import ChatOpenAI
+ llm = ChatOpenAI(
+ temperature=0.8,
+ model=model,
+ openai_api_key=os.environ["OPENROUTER_API_KEY"],
+ openai_api_base=OPENROUTER_API_BASE,
+ timeout=15,
+ )
+ return llm
+
+def get_llm(llm_model_type, **kwargs):
+ from langchain_core.output_parsers import JsonOutputParser
+ output_parser = JsonOutputParser()
+
+ from langchain_core.prompts import ChatPromptTemplate
+ prompt = ChatPromptTemplate.from_messages([
+ ('system', 'You are a helpful assistant.'),
+ ('user', """
+ I am generating a set of incorrect captions for an image. Given the following prompt from a human user that corresponds to a real image, please generate a set of 12 incorrect prompts that modify the original prompt but maintains some of the original meaning or context. For example, you may add or remove an object, change the desired styling, the sentence structure, or reference a different proper noun. You might change the subject, time period, time of day, location, culture, camera angle, and other attributes. Make the prompts very simple and do not use very exotic or rare objects or words. For half of the captions, make them broken, have improper grammar or just be nonsensical. Do not generate NSFW prompts. Do not preface the output with any numbers or text. {format_instructions}. The output should have keys as indices and values as the prompts, and should be valid, parseable JSON. Make sure to escape quotes.
+
+ Original prompt: {prompt}
+ """)
+ ])
+
+
+ openai_llm = get_groq_llama("llama-3.2-90b-text-preview")
+ llm = openai_llm.with_fallbacks([
+ get_groq_llama("llama-3.2-11b-text-preview"),
+ get_groq_llama("gemma-7b-it"),
+ get_groq_llama("llama-3.2-3b-preview"),
+ get_groq_llama("llama-3.2-1b-preview"),
+ get_groq_llama("llama-3.2-11b-vision-preview"),
+ get_groq_llama("llama-3.2-90b-vision-preview"),
+ ])
+
+ chain = prompt | llm
+
+ return functools.partial(forward_llm, chain=chain, output_parser=output_parser, llm_model_type=llm_model_type)
+
+def forward_llm(prompt, chain, output_parser, llm_model_type, fake_openai_failure=False):
+ with ExitStack() as stack:
+ if "gpt" in llm_model_type:
+ from langchain_community.callbacks import get_openai_callback
+ cb = stack.enter_context(get_openai_callback())
+
+ if fake_openai_failure:
+ from unittest.mock import patch
+
+ import httpx
+ from openai import RateLimitError
+ request = httpx.Request("GET", "/")
+ response = httpx.Response(200, request=request)
+ error = RateLimitError("rate limit", response=response, body="")
+ stack.enter_context(patch("openai.resources.chat.completions.Completions.create", side_effect=error))
+
+ for i in range(10):
+ try:
+ start_time = time.time()
+ rprint(f"Calling LLM...")
+ output_message = chain.invoke({
+ "prompt": prompt,
+ "format_instructions": output_parser.get_format_instructions()
+ })
+
+ output = output_parser.invoke(output_message)
+ output = list(output.values())
+
+ if len([x for x in output if x is not None]) == 0:
+ raise ValueError("No output from LLM")
+
+ end_time = time.time()
+ rprint(f"LLM Time taken: {end_time - start_time:.2f} seconds")
+ break
+ except Exception as e:
+ rprint(f"Error, retrying: {i}, {e}")
+ if i == 9:
+ raise e
+ continue
+
+ try:
+ model_name = output_message.response_metadata['model_name']
+ rprint(f"Used model name: {model_name}")
+ except:
+ model_name = "Unknown"
+
+ if "gpt" in llm_model_type and i == 0:
+ rprint(cb)
+
+ output = [prompt for prompt in output if prompt is not None]
+
+ if len(output) == 0:
+ rprint("No output from LLM")
+ rprint(f"Raw: {output_message}")
+ output = []
+ else:
+ if any(x in output[0].lower() for x in [" here", "diverse"]):
+ rprint("Removing the first element.")
+ rprint(output[0])
+ output.pop(0)
+
+ output = [prompt.strip() for prompt in output]
+ output = [prompt for prompt in output if prompt != ""]
+
+ return output, model_name
+
+import json
+import random
+from pathlib import Path
+
+if __name__ == "__main__":
+ llm_func = get_llm(llm_model_type="")
+ res = llm_func("A red sailboat on a blue ocean with a yellow sun", fake_openai_failure=False)
+ breakpoint()
+
+
\ No newline at end of file
diff --git a/unidisc/utils/slurm_requeue.py b/unidisc/utils/slurm_requeue.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c2b54fca8f359769adc0d9bb6bb7c78b3261641
--- /dev/null
+++ b/unidisc/utils/slurm_requeue.py
@@ -0,0 +1,70 @@
+import subprocess
+import socket
+import os
+REQUEUE_COMMAND = "scontrol requeue JobId={job_id}; scontrol update JobId={job_id} ExcNodeList={exclude_list}"
+
+def get_hostname() -> str:
+ return socket.gethostname()
+
+def is_cuda_available() -> bool:
+ try: # test if CUDA is available
+ import torch
+
+ test_tensor = torch.rand(5, 3, device="cuda")
+ test_tensor.requires_grad = True
+ (test_tensor**2).sum().backward()
+ return True
+
+ except RuntimeError as e:
+ print(f"Cuda test failed: {e}")
+
+ return False
+
+
+def get_current_exclude_list(job_id: str) -> list:
+ try: # get the current exclude list
+ exclude_list = []
+ output = subprocess.run(["scontrol", "show", "job", job_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+
+ for line in output.splitlines():
+ if "ExcNodeList=" in line:
+ for host in line.split("ExcNodeList=")[1].split()[0].split(","):
+
+ if host != "(null)":
+ exclude_list.append(host)
+
+ except subprocess.CalledProcessError as e:
+ print(f"Failed to retrieve exclude list for job {job_id}: {e}")
+
+ return exclude_list
+
+
+def requeue_job_excluding_node(job_id: str, bad_node: str) -> bool:
+ try: # stop and requeue the job excluding the bad node
+ exclude_list = get_current_exclude_list(job_id)
+ exclude_list.append(bad_node)
+ exclude_list = ",".join(exclude_list)
+ subprocess.run(REQUEUE_COMMAND.format(job_id=job_id, exclude_list=exclude_list), shell=True, check=True)
+ print(f"Requeued job {job_id}, excluding {bad_node}.")
+ return True
+
+ except subprocess.CalledProcessError as e:
+ print(f"Failed to requeue job {job_id}: {e}")
+
+ return False
+
+def check_requeue() -> bool:
+ job_id = os.getenv("SLURM_JOB_ID", None)
+ if not job_id:
+ return False
+
+ if not (cuda_available := is_cuda_available()):
+ print("Attempting to requeue job.")
+ requeue_job_excluding_node(job_id, get_hostname())
+ else:
+ print("CUDA is available. Proceeding with the job.")
+
+ return cuda_available
+
+if __name__ == "__main__":
+ check_requeue()
diff --git a/unidisc/utils/standalone_metrics.py b/unidisc/utils/standalone_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb34c40b9fe559a19686520bdf48ee062d1b3067
--- /dev/null
+++ b/unidisc/utils/standalone_metrics.py
@@ -0,0 +1,1971 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import builtins
+import functools
+import inspect
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from contextlib import contextmanager
+from copy import deepcopy
+from typing import (Any, Callable, ClassVar, Dict, Generator, Hashable,
+ Iterable, Iterator, List, Optional, Sequence, Tuple, Union)
+
+import torch
+from lightning_utilities import apply_to_collection
+from torch import Tensor
+from torch.nn import Module, ModuleDict
+from torchmetrics.metric import Metric
+from torchmetrics.utilities import rank_zero_warn
+from torchmetrics.utilities.data import (_flatten, _flatten_dict,
+ _squeeze_if_scalar, allclose,
+ dim_zero_cat, dim_zero_max,
+ dim_zero_mean, dim_zero_min,
+ dim_zero_sum)
+from torchmetrics.utilities.distributed import gather_all_tensors
+from torchmetrics.utilities.exceptions import TorchMetricsUserError
+from torchmetrics.utilities.imports import (_MATPLOTLIB_AVAILABLE,
+ _TORCH_GREATER_EQUAL_2_1)
+from torchmetrics.utilities.plot import (_AX_TYPE, _PLOT_OUT_TYPE,
+ plot_single_or_multi_val)
+from torchmetrics.utilities.prints import rank_zero_warn
+from typing_extensions import Literal
+from decoupled_utils import is_torch_xla_available
+
+def jit_distributed_available() -> bool:
+ """Determine if distributed mode is initialized."""
+ return not is_torch_xla_available()
+
+class Metric(Module, ABC):
+ """Base class for all metrics present in the Metrics API.
+
+ This class is inherited by all metrics and implements the following functionality:
+ 1. Handles the transfer of metric states to correct device
+ 2. Handles the synchronization of metric states across processes
+
+ The three core methods of the base class are
+ * ``add_state()``
+ * ``forward()``
+ * ``reset()``
+
+ which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
+ * ``update()``
+ * ``compute()``
+
+
+ Args:
+ kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+ - compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
+ - dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
+ - process_group: The process group on which the synchronization is called. Default is the world.
+ - dist_sync_fn: Function that performs the allgather option on the metric state. Default is an custom
+ implementation that calls ``torch.distributed.all_gather`` internally.
+ - distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
+ check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
+ - sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
+ - compute_with_cache: If results from ``compute`` should be cached. Default is ``True``
+
+ """
+
+ __jit_ignored_attributes__: ClassVar[List[str]] = ["device"]
+ __jit_unused_properties__: ClassVar[List[str]] = [
+ "is_differentiable",
+ "higher_is_better",
+ "plot_lower_bound",
+ "plot_upper_bound",
+ "plot_legend_name",
+ "metric_state",
+ "_update_called",
+ ]
+ is_differentiable: Optional[bool] = None
+ higher_is_better: Optional[bool] = None
+ full_state_update: Optional[bool] = None
+
+ plot_lower_bound: Optional[float] = None
+ plot_upper_bound: Optional[float] = None
+ plot_legend_name: Optional[str] = None
+
+ def __init__(
+ self,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+
+ # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
+ # torch/nn/modules/module.py#L227)
+ torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")
+ # magic patch for `RuntimeError: DataLoader worker (pid(s) 104) exited unexpectedly`
+ self._TORCH_GREATER_EQUAL_2_1 = bool(_TORCH_GREATER_EQUAL_2_1)
+ self._device = torch.device("cpu")
+ self._dtype = torch.get_default_dtype()
+
+ self.compute_on_cpu = kwargs.pop("compute_on_cpu", False)
+ if not isinstance(self.compute_on_cpu, bool):
+ raise ValueError(
+ f"Expected keyword argument `compute_on_cpu` to be an `bool` but got {self.compute_on_cpu}"
+ )
+
+ self.dist_sync_on_step = kwargs.pop("dist_sync_on_step", False)
+ if not isinstance(self.dist_sync_on_step, bool):
+ raise ValueError(
+ f"Expected keyword argument `dist_sync_on_step` to be an `bool` but got {self.dist_sync_on_step}"
+ )
+
+ self.process_group = kwargs.pop("process_group", None)
+
+ self.dist_sync_fn = kwargs.pop("dist_sync_fn", None)
+ if self.dist_sync_fn is not None and not callable(self.dist_sync_fn):
+ raise ValueError(
+ f"Expected keyword argument `dist_sync_fn` to be an callable function but got {self.dist_sync_fn}"
+ )
+
+ self.distributed_available_fn = kwargs.pop("distributed_available_fn", None) or jit_distributed_available
+
+ self.sync_on_compute = kwargs.pop("sync_on_compute", True)
+ if not isinstance(self.sync_on_compute, bool):
+ raise ValueError(
+ f"Expected keyword argument `sync_on_compute` to be a `bool` but got {self.sync_on_compute}"
+ )
+ self.compute_with_cache = kwargs.pop("compute_with_cache", True)
+ if not isinstance(self.compute_with_cache, bool):
+ raise ValueError(
+ f"Expected keyword argument `compute_with_cache` to be a `bool` but got {self.compute_with_cache}"
+ )
+
+ if kwargs:
+ kwargs_ = [f"`{a}`" for a in sorted(kwargs)]
+ raise ValueError(f"Unexpected keyword arguments: {', '.join(kwargs_)}")
+
+ # initialize
+ self._update_signature = inspect.signature(self.update)
+ self.update: Callable = self._wrap_update(self.update) # type: ignore[method-assign]
+ self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[method-assign]
+ self._computed = None
+ self._forward_cache = None
+ self._update_count = 0
+ self._to_sync = self.sync_on_compute
+ self._should_unsync = True
+ self._enable_grad = False
+ self._dtype_convert = False
+
+ # initialize state
+ self._defaults: Dict[str, Union[List, Tensor]] = {}
+ self._persistent: Dict[str, bool] = {}
+ self._reductions: Dict[str, Union[str, Callable[..., Any], None]] = {}
+
+ # state management
+ self._is_synced = False
+ self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None
+
+ @property
+ def _update_called(self) -> bool:
+ rank_zero_warn(
+ "This property will be removed in 2.0.0. Use `Metric.updated_called` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.update_called
+
+ @property
+ def update_called(self) -> bool:
+ """Returns `True` if `update` or `forward` has been called initialization or last `reset`."""
+ return self._update_count > 0
+
+ @property
+ def update_count(self) -> int:
+ """Get the number of times `update` and/or `forward` has been called since initialization or last `reset`."""
+ return self._update_count
+
+ @property
+ def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]:
+ """Get the current state of the metric."""
+ return {attr: getattr(self, attr) for attr in self._defaults}
+
+ def add_state(
+ self,
+ name: str,
+ default: Union[list, Tensor],
+ dist_reduce_fx: Optional[Union[str, Callable]] = None,
+ persistent: bool = False,
+ ) -> None:
+ """Add metric state variable. Only used by subclasses.
+
+ Metric state variables are either `:class:`~torch.Tensor` or an empty list, which can be appended to by the
+ metric. Each state variable must have a unique name associated with it. State variables are accessible as
+ attributes of the metric i.e, if ``name`` is ``"my_state"`` then its value can be accessed from an instance
+ ``metric`` as ``metric.my_state``. Metric states behave like buffers and parameters of :class:`~torch.nn.Module`
+ as they are also updated when ``.to()`` is called. Unlike parameters and buffers, metric states are not by
+ default saved in the modules :attr:`~torch.nn.Module.state_dict`.
+
+ Args:
+ name: The name of the state variable. The variable will then be accessible at ``self.name``.
+ default: Default value of the state; can either be a :class:`~torch.Tensor` or an empty list.
+ The state will be reset to this value when ``self.reset()`` is called.
+ dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
+ If value is ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"`` we will use ``torch.sum``,
+ ``torch.mean``, ``torch.cat``, ``torch.min`` and ``torch.max``` respectively, each with argument
+ ``dim=0``. Note that the ``"cat"`` reduction only makes sense if the state is a list, and not
+ a tensor. The user can also pass a custom function in this parameter.
+ persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
+ Default is ``False``.
+
+ Note:
+ Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
+ However, there won't be any reduction function applied to the synchronized metric state.
+
+ The metric states would be synced as follows
+
+ - If the metric state is :class:`~torch.Tensor`, the synced value will be a stacked :class:`~torch.Tensor`
+ across the process dimension if the metric state was a :class:`~torch.Tensor`. The original
+ :class:`~torch.Tensor` metric state retains dimension and hence the synchronized output will be of shape
+ ``(num_process, ...)``.
+
+ - If the metric state is a ``list``, the synced value will be a ``list`` containing the
+ combined elements from all processes.
+
+ Note:
+ When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
+ the format discussed in the above note.
+
+ Note:
+ The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
+ device memory to be automatically reallocated, but may produce unexpected effects when referencing list
+ states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
+ object.
+
+ Raises:
+ ValueError:
+ If ``default`` is not a ``tensor`` or an ``empty list``.
+ ValueError:
+ If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``"min"``,
+ ``"max"`` or ``None``.
+
+ """
+ if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default):
+ raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")
+
+ if dist_reduce_fx == "sum":
+ dist_reduce_fx = dim_zero_sum
+ elif dist_reduce_fx == "mean":
+ dist_reduce_fx = dim_zero_mean
+ elif dist_reduce_fx == "max":
+ dist_reduce_fx = dim_zero_max
+ elif dist_reduce_fx == "min":
+ dist_reduce_fx = dim_zero_min
+ elif dist_reduce_fx == "cat":
+ dist_reduce_fx = dim_zero_cat
+ elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
+ raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', 'min', 'max', None]")
+
+ if isinstance(default, Tensor):
+ default = default.contiguous()
+
+ setattr(self, name, default)
+
+ self._defaults[name] = deepcopy(default)
+ self._persistent[name] = persistent
+ self._reductions[name] = dist_reduce_fx
+
+ @torch.jit.unused
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
+ """Aggregate and evaluate batch input directly.
+
+ Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch
+ statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding
+ ``update`` method. The returned output is the exact same as the output of ``compute``.
+
+ Args:
+ args: Any arguments as required by the metric ``update`` method.
+ kwargs: Any keyword arguments as required by the metric ``update`` method.
+
+ Returns:
+ The output of the ``compute`` method evaluated on the current batch.
+
+ Raises:
+ TorchMetricsUserError:
+ If the metric is already synced and ``forward`` is called again.
+
+ """
+ # check if states are already synced
+ if self._is_synced:
+ raise TorchMetricsUserError(
+ "The Metric shouldn't be synced when performing ``forward``. "
+ "HINT: Did you forget to call ``unsync`` ?."
+ )
+
+ if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
+ self._forward_cache = self._forward_full_state_update(*args, **kwargs)
+ else:
+ self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
+
+ return self._forward_cache
+
+ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
+ """Forward computation using two calls to `update`.
+
+ Doing this secures that metrics that need access to the full metric state during `update` works as expected.
+ This is the most safe method to use for any metric but also the slower version of the two forward
+ implementations.
+
+ """
+ # global accumulation
+ self.update(*args, **kwargs)
+ _update_count = self._update_count
+
+ self._to_sync = self.dist_sync_on_step
+ # skip restore cache operation from compute as cache is stored below.
+ self._should_unsync = False
+ # skip computing on cpu for the batch
+ _temp_compute_on_cpu = self.compute_on_cpu
+ self.compute_on_cpu = False
+
+ # save context before switch
+ cache = self._copy_state_dict()
+
+ # call reset, update, compute, on single batch
+ self._enable_grad = True # allow grads for batch computation
+ self.reset()
+ self.update(*args, **kwargs)
+ batch_val = self.compute()
+
+ # restore context
+ for attr, val in cache.items():
+ setattr(self, attr, val)
+ self._update_count = _update_count
+
+ # restore context
+ self._is_synced = False
+ self._should_unsync = True
+ self._to_sync = self.sync_on_compute
+ self._computed = None
+ self._enable_grad = False
+ self.compute_on_cpu = _temp_compute_on_cpu
+ if self.compute_on_cpu:
+ self._move_list_states_to_cpu()
+
+ return batch_val
+
+ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
+ """Forward computation using single call to `update`.
+
+ This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for
+ certain metric cases but is also the fastest way to both accumulate globally and compute locally.
+
+ """
+ # store global state and reset to default
+ global_state = self._copy_state_dict()
+ _update_count = self._update_count
+ self.reset()
+
+ # local synchronization settings
+ self._to_sync = self.dist_sync_on_step
+ self._should_unsync = False
+ _temp_compute_on_cpu = self.compute_on_cpu
+ self.compute_on_cpu = False
+ self._enable_grad = True # allow grads for batch computation
+
+ # calculate batch state and compute batch value
+ self.update(*args, **kwargs)
+ batch_val = self.compute()
+
+ # reduce batch and global state
+ self._update_count = _update_count + 1
+ with torch.no_grad():
+ self._reduce_states(global_state)
+
+ # restore context
+ self._is_synced = False
+ self._should_unsync = True
+ self._to_sync = self.sync_on_compute
+ self._computed = None
+ self._enable_grad = False
+ self.compute_on_cpu = _temp_compute_on_cpu
+ if self.compute_on_cpu:
+ self._move_list_states_to_cpu()
+
+ return batch_val
+
+ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
+ """Add an incoming metric state to the current state of the metric.
+
+ Args:
+ incoming_state: a dict containing a metric state similar metric itself
+
+ """
+ for attr in self._defaults:
+ local_state = getattr(self, attr)
+ global_state = incoming_state[attr]
+ reduce_fn = self._reductions[attr]
+ if reduce_fn == dim_zero_sum:
+ reduced = global_state + local_state
+ elif reduce_fn == dim_zero_mean:
+ reduced = ((self._update_count - 1) * global_state + local_state).float() / self._update_count
+ elif reduce_fn == dim_zero_max:
+ reduced = torch.max(global_state, local_state)
+ elif reduce_fn == dim_zero_min:
+ reduced = torch.min(global_state, local_state)
+ elif reduce_fn == dim_zero_cat:
+ if isinstance(global_state, Tensor):
+ reduced = torch.cat([global_state, local_state])
+ else:
+ reduced = global_state + local_state
+ elif reduce_fn is None and isinstance(global_state, Tensor):
+ reduced = torch.stack([global_state, local_state])
+ elif reduce_fn is None and isinstance(global_state, list):
+ reduced = _flatten([global_state, local_state])
+ elif reduce_fn and callable(reduce_fn):
+ reduced = reduce_fn(torch.stack([global_state, local_state]))
+ else:
+ raise TypeError(f"Unsupported reduce_fn: {reduce_fn}")
+ setattr(self, attr, reduced)
+
+ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
+ input_dict = {attr: getattr(self, attr) for attr in self._reductions}
+
+ for attr, reduction_fn in self._reductions.items():
+ # pre-concatenate metric states that are lists to reduce number of all_gather operations
+ if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
+ input_dict[attr] = [dim_zero_cat(input_dict[attr])]
+
+ # cornor case in distributed settings where a rank have not received any data, create empty to concatenate
+ if (
+ self._TORCH_GREATER_EQUAL_2_1
+ and reduction_fn == dim_zero_cat
+ and isinstance(input_dict[attr], list)
+ and len(input_dict[attr]) == 0
+ ):
+ input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)]
+
+ output_dict = apply_to_collection(
+ input_dict,
+ Tensor,
+ dist_sync_fn,
+ group=process_group or self.process_group,
+ )
+
+ for attr, reduction_fn in self._reductions.items():
+ # pre-processing ops (stack or flatten for inputs)
+
+ if isinstance(output_dict[attr], list) and len(output_dict[attr]) == 0:
+ setattr(self, attr, [])
+ continue
+
+ if isinstance(output_dict[attr][0], Tensor):
+ output_dict[attr] = torch.stack(output_dict[attr])
+ elif isinstance(output_dict[attr][0], list):
+ output_dict[attr] = _flatten(output_dict[attr])
+
+ if not (callable(reduction_fn) or reduction_fn is None):
+ raise TypeError("reduction_fn must be callable or None")
+ reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
+ setattr(self, attr, reduced)
+
+ def _wrap_update(self, update: Callable) -> Callable:
+ @functools.wraps(update)
+ def wrapped_func(*args: Any, **kwargs: Any) -> None:
+ self._computed = None
+ self._update_count += 1
+ with torch.set_grad_enabled(self._enable_grad):
+ try:
+ update(*args, **kwargs)
+ except RuntimeError as err:
+ if "Expected all tensors to be on" in str(err):
+ raise RuntimeError(
+ "Encountered different devices in metric calculation (see stacktrace for details)."
+ " This could be due to the metric class not being on the same device as input."
+ f" Instead of `metric={self.__class__.__name__}(...)` try to do"
+ f" `metric={self.__class__.__name__}(...).to(device)` where"
+ " device corresponds to the device of the input."
+ ) from err
+ raise err
+
+ if self.compute_on_cpu:
+ self._move_list_states_to_cpu()
+
+ return wrapped_func
+
+ def _move_list_states_to_cpu(self) -> None:
+ """Move list states to cpu to save GPU memory."""
+ for key in self._defaults:
+ current_val = getattr(self, key)
+ if isinstance(current_val, Sequence):
+ setattr(self, key, [cur_v.to("cpu") for cur_v in current_val])
+
+ def sync(
+ self,
+ dist_sync_fn: Optional[Callable] = None,
+ process_group: Optional[Any] = None,
+ should_sync: bool = True,
+ distributed_available: Optional[Callable] = None,
+ ) -> None:
+ """Sync function for manually controlling when metrics states should be synced across processes.
+
+ Args:
+ dist_sync_fn: Function to be used to perform states synchronization
+ process_group:
+ Specify the process group on which synchronization is called.
+ default: `None` (which selects the entire world)
+ should_sync: Whether to apply to state synchronization. This will have an impact
+ only when running in a distributed setting.
+ distributed_available: Function to determine if we are running inside a distributed setting
+
+ Raises:
+ TorchMetricsUserError:
+ If the metric is already synced and ``sync`` is called again.
+
+ """
+ if self._is_synced and should_sync:
+ raise TorchMetricsUserError("The Metric has already been synced.")
+
+ if distributed_available is None and self.distributed_available_fn is not None:
+ distributed_available = self.distributed_available_fn
+
+ is_distributed = distributed_available() if callable(distributed_available) else None
+
+ if not should_sync or not is_distributed:
+ return
+
+ if dist_sync_fn is None:
+ dist_sync_fn = gather_all_tensors
+
+ # cache prior to syncing
+ self._cache = self._copy_state_dict()
+
+ # sync
+ self._sync_dist(dist_sync_fn, process_group=process_group)
+ self._is_synced = True
+
+ def unsync(self, should_unsync: bool = True) -> None:
+ """Unsync function for manually controlling when metrics states should be reverted back to their local states.
+
+ Args:
+ should_unsync: Whether to perform unsync
+
+ """
+ if not should_unsync:
+ return
+
+ if not self._is_synced:
+ raise TorchMetricsUserError("The Metric has already been un-synced.")
+
+ if self._cache is None:
+ raise TorchMetricsUserError("The internal cache should exist to unsync the Metric.")
+
+ # if we synced, restore to cache so that we can continue to accumulate un-synced state
+ for attr, val in self._cache.items():
+ setattr(self, attr, val)
+ self._is_synced = False
+ self._cache = None
+
+ @contextmanager
+ def sync_context(
+ self,
+ dist_sync_fn: Optional[Callable] = None,
+ process_group: Optional[Any] = None,
+ should_sync: bool = True,
+ should_unsync: bool = True,
+ distributed_available: Optional[Callable] = None,
+ ) -> Generator:
+ """Context manager to synchronize states.
+
+ This context manager is used in distributed setting and makes sure that the local cache states are restored
+ after yielding the synchronized state.
+
+ Args:
+ dist_sync_fn: Function to be used to perform states synchronization
+ process_group:
+ Specify the process group on which synchronization is called.
+ default: `None` (which selects the entire world)
+ should_sync: Whether to apply to state synchronization. This will have an impact
+ only when running in a distributed setting.
+ should_unsync: Whether to restore the cache state so that the metrics can
+ continue to be accumulated.
+ distributed_available: Function to determine if we are running inside a distributed setting
+
+ """
+ self.sync(
+ dist_sync_fn=dist_sync_fn,
+ process_group=process_group,
+ should_sync=should_sync,
+ distributed_available=distributed_available,
+ )
+
+ yield
+
+ self.unsync(should_unsync=self._is_synced and should_unsync)
+
+ def _wrap_compute(self, compute: Callable) -> Callable:
+ @functools.wraps(compute)
+ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
+ if not self.update_called:
+ rank_zero_warn(
+ f"The ``compute`` method of metric {self.__class__.__name__}"
+ " was called before the ``update`` method which may lead to errors,"
+ " as metric states have not yet been updated.",
+ UserWarning,
+ )
+
+ # return cached value
+ if self._computed is not None:
+ return self._computed
+
+ # compute relies on the sync context manager to gather the states across processes and apply reduction
+ # if synchronization happened, the current rank accumulated states will be restored to keep
+ # accumulation going if ``should_unsync=True``,
+ with self.sync_context(
+ dist_sync_fn=self.dist_sync_fn,
+ should_sync=self._to_sync,
+ should_unsync=self._should_unsync,
+ ):
+ value = _squeeze_if_scalar(compute(*args, **kwargs))
+ # clone tensor to avoid in-place operations after compute, altering already computed results
+ value = apply_to_collection(value, Tensor, lambda x: x.clone())
+
+ if self.compute_with_cache:
+ self._computed = value
+
+ return value
+
+ return wrapped_func
+
+ @abstractmethod
+ def update(self, *_: Any, **__: Any) -> None:
+ """Override this method to update the state variables of your metric class."""
+
+ @abstractmethod
+ def compute(self) -> Any:
+ """Override this method to compute the final metric value.
+
+ This method will automatically synchronize state variables when running in distributed backend.
+
+ """
+
+ def plot(self, *_: Any, **__: Any) -> Any:
+ """Override this method plot the metric value."""
+ raise NotImplementedError
+
+ def _plot(
+ self,
+ val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None,
+ ax: Optional[_AX_TYPE] = None,
+ ) -> _PLOT_OUT_TYPE:
+ """Plot a single or multiple values from the metric.
+
+ Args:
+ val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+ If no value is provided, will automatically call `metric.compute` and plot that result.
+ ax: An matplotlib axis object. If provided will add plot to that axis
+
+ Returns:
+ Figure and Axes object
+
+ Raises:
+ ModuleNotFoundError:
+ If `matplotlib` is not installed
+
+ """
+ val = val if val is not None else self.compute()
+ fig, ax = plot_single_or_multi_val(
+ val,
+ ax=ax,
+ higher_is_better=self.higher_is_better,
+ name=self.__class__.__name__,
+ lower_bound=self.plot_lower_bound,
+ upper_bound=self.plot_upper_bound,
+ legend_name=self.plot_legend_name,
+ )
+ return fig, ax
+
+ def reset(self) -> None:
+ """Reset metric state variables to their default value."""
+ self._update_count = 0
+ self._forward_cache = None
+ self._computed = None
+
+ for attr, default in self._defaults.items():
+ current_val = getattr(self, attr)
+ if isinstance(default, Tensor):
+ setattr(self, attr, default.detach().clone().to(current_val.device))
+ else:
+ getattr(self, attr).clear() # delete/free list items
+
+ # reset internal states
+ self._cache = None
+ self._is_synced = False
+
+ def clone(self) -> "Metric":
+ """Make a copy of the metric."""
+ return deepcopy(self)
+
+ def __getstate__(self) -> Dict[str, Any]:
+ """Get the current state, including all metric states, for the metric.
+
+ Used for loading and saving a metric.
+
+ """
+ # ignore update and compute functions for pickling
+ return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]}
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """Set the state of the metric, based on a input state.
+
+ Used for loading and saving a metric.
+
+ """
+ # manually restore update and compute functions for pickling
+ self.__dict__.update(state)
+ self._update_signature = inspect.signature(self.update)
+ self.update: Callable = self._wrap_update(self.update) # type: ignore[method-assign]
+ self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[method-assign]
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ """Overwrite default method to prevent specific attributes from being set by user."""
+ if name in (
+ "higher_is_better",
+ "is_differentiable",
+ "full_state_update",
+ "plot_lower_bound",
+ "plot_upper_bound",
+ "plot_legend_name",
+ ):
+ raise RuntimeError(f"Can't change const `{name}`.")
+ super().__setattr__(name, value)
+
+ @property
+ def device(self) -> "torch.device":
+ """Return the device of the metric."""
+ return self._device
+
+ @property
+ def dtype(self) -> "torch.dtype":
+ """Return the default dtype of the metric."""
+ return self._dtype
+
+ def type(self, dst_type: Union[str, torch.dtype]) -> "Metric":
+ """Override default and prevent dtype casting.
+
+ Please use :meth:`Metric.set_dtype` instead.
+
+ """
+ return self
+
+ def float(self) -> "Metric":
+ """Override default and prevent dtype casting.
+
+ Please use :meth:`Metric.set_dtype` instead.
+
+ """
+ return self
+
+ def double(self) -> "Metric":
+ """Override default and prevent dtype casting.
+
+ Please use :meth:`Metric.set_dtype` instead.
+
+ """
+ return self
+
+ def half(self) -> "Metric":
+ """Override default and prevent dtype casting.
+
+ Please use :meth:`Metric.set_dtype` instead.
+
+ """
+ return self
+
+ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
+ """Transfer all metric state to specific dtype. Special version of standard `type` method.
+
+ Arguments:
+ dst_type: the desired type as string or dtype object
+
+ """
+ self._dtype_convert = True
+ out = super().type(dst_type)
+ out._dtype_convert = False
+ return out
+
+ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
+ """Overwrite `_apply` function such that we can also move metric states to the correct device.
+
+ This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
+ are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
+
+ Args:
+ fn: the function to apply
+ exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
+ by the metric class itself.
+
+ """
+ this = super()._apply(fn)
+ fs = str(fn)
+ cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
+ if not self._dtype_convert and cond:
+ return this
+
+ # Also apply fn to metric states and defaults
+ for key, value in this._defaults.items():
+ if key in exclude_state:
+ continue
+
+ if isinstance(value, Tensor):
+ this._defaults[key] = fn(value)
+ elif isinstance(value, Sequence):
+ this._defaults[key] = [fn(v) for v in value]
+
+ current_val = getattr(this, key)
+ if isinstance(current_val, Tensor):
+ setattr(this, key, fn(current_val))
+ elif isinstance(current_val, Sequence):
+ setattr(this, key, [fn(cur_v) for cur_v in current_val])
+ else:
+ raise TypeError(
+ f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
+ )
+
+ # make sure to update the device attribute
+ # if the dummy tensor moves device by fn function we should also update the attribute
+ _dummy_tensor = fn(torch.zeros(1, device=self.device))
+ self._device = _dummy_tensor.device
+ self._dtype = _dummy_tensor.dtype
+
+ # Additional apply to forward cache and computed attributes (may be nested)
+ if this._computed is not None:
+ this._computed = apply_to_collection(this._computed, Tensor, fn)
+ if this._forward_cache is not None:
+ this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)
+
+ return this
+
+ def persistent(self, mode: bool = False) -> None:
+ """Change post-init if metric states should be saved to its state_dict."""
+ for key in self._persistent:
+ self._persistent[key] = mode
+
+ def state_dict( # type: ignore[override] # todo
+ self,
+ destination: Optional[Dict[str, Any]] = None,
+ prefix: str = "",
+ keep_vars: bool = False,
+ ) -> Dict[str, Any]:
+ """Get the current state of metric as an dictionary.
+
+ Args:
+ destination: Optional dictionary, that if provided, the state of module will be updated into the dict and
+ the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned.
+ prefix: optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
+ keep_vars: by default the :class:`~torch.Tensor` returned in the state dict are detached from autograd.
+ If set to ``True``, detaching will not be performed.
+
+ """
+ destination: Dict[str, Union[torch.Tensor, List, Any]] = super().state_dict(
+ destination=destination, # type: ignore[arg-type]
+ prefix=prefix,
+ keep_vars=keep_vars,
+ )
+ # Register metric states to be part of the state_dict
+ for key in self._defaults:
+ if not self._persistent[key]:
+ continue
+ current_val = getattr(self, key)
+ if not keep_vars:
+ if isinstance(current_val, Tensor):
+ current_val = current_val.detach()
+ elif isinstance(current_val, list):
+ current_val = [cur_v.detach() if isinstance(cur_v, Tensor) else cur_v for cur_v in current_val]
+ destination[prefix + key] = deepcopy(current_val)
+ return destination
+
+ def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
+ """Copy the current state values."""
+ cache: Dict[str, Union[Tensor, List[Any]]] = {}
+ for attr in self._defaults:
+ current_value = getattr(self, attr)
+
+ if isinstance(current_value, Tensor):
+ cache[attr] = current_value.detach().clone().to(current_value.device)
+ else:
+ cache[attr] = [ # safely copy (non-graph leaf) Tensor elements
+ _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
+ ]
+
+ return cache
+
+ def _load_from_state_dict(
+ self,
+ state_dict: dict,
+ prefix: str,
+ local_metadata: dict,
+ strict: bool,
+ missing_keys: List[str],
+ unexpected_keys: List[str],
+ error_msgs: List[str],
+ ) -> None:
+ """Load metric states from state_dict."""
+ for key in self._defaults:
+ name = prefix + key
+ if name in state_dict:
+ setattr(self, key, state_dict.pop(name))
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+ """Filter kwargs such that they match the update signature of the metric."""
+ # filter all parameters based on update signature except those of
+ # types `VAR_POSITIONAL` for `* args` and `VAR_KEYWORD` for `** kwargs`
+ _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
+ _sign_params = self._update_signature.parameters
+ filtered_kwargs = {
+ k: v for k, v in kwargs.items() if (k in _sign_params and _sign_params[k].kind not in _params)
+ }
+
+ exists_var_keyword = any(v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values())
+ # if no kwargs filtered, return all kwargs as default
+ if not filtered_kwargs and not exists_var_keyword:
+ # no kwargs in update signature -> don't return any kwargs
+ return {}
+ if exists_var_keyword:
+ # kwargs found in update signature -> return all kwargs to be sure to not omit any.
+ # filtering logic is likely implemented within the update call.
+ return kwargs
+ return filtered_kwargs
+
+ def __hash__(self) -> int:
+ """Return an unique hash of the metric.
+
+ The hash depends on both the class itself but also the current metric state, which therefore enforces that two
+ instances of the same metrics never have the same hash even if they have been updated on the same data.
+
+ """
+ # we need to add the id here, since PyTorch requires a module hash to be unique.
+ # Internally, PyTorch nn.Module relies on that for children discovery
+ # (see https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1544)
+ # For metrics that include tensors it is not a problem,
+ # since their hash is unique based on the memory location but we cannot rely on that for every metric.
+ hash_vals = [self.__class__.__name__, id(self)]
+
+ for key in self._defaults:
+ val = getattr(self, key)
+ # Special case: allow list values, so long
+ # as their elements are hashable
+ if hasattr(val, "__iter__") and not isinstance(val, Tensor):
+ hash_vals.extend(val)
+ else:
+ hash_vals.append(val)
+
+ return hash(tuple(hash_vals))
+
+ def __add__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the addition operator."""
+ return CompositionalMetric(torch.add, self, other)
+
+ def __and__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical and operator."""
+ return CompositionalMetric(torch.bitwise_and, self, other)
+
+ def __eq__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[override]
+ """Construct compositional metric using the equal operator."""
+ return CompositionalMetric(torch.eq, self, other)
+
+ def __floordiv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the floor division operator."""
+ return CompositionalMetric(torch.floor_divide, self, other)
+
+ def __ge__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the greater than or equal operator."""
+ return CompositionalMetric(torch.ge, self, other)
+
+ def __gt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the greater than operator."""
+ return CompositionalMetric(torch.gt, self, other)
+
+ def __le__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the less than or equal operator."""
+ return CompositionalMetric(torch.le, self, other)
+
+ def __lt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the less than operator."""
+ return CompositionalMetric(torch.lt, self, other)
+
+ def __matmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the matrix multiplication operator."""
+ return CompositionalMetric(torch.matmul, self, other)
+
+ def __mod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the remainder operator."""
+ return CompositionalMetric(torch.fmod, self, other)
+
+ def __mul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the multiplication operator."""
+ return CompositionalMetric(torch.mul, self, other)
+
+ def __ne__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[override]
+ """Construct compositional metric using the not equal operator."""
+ return CompositionalMetric(torch.ne, self, other)
+
+ def __or__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical or operator."""
+ return CompositionalMetric(torch.bitwise_or, self, other)
+
+ def __pow__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the exponential/power operator."""
+ return CompositionalMetric(torch.pow, self, other)
+
+ def __radd__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the addition operator."""
+ return CompositionalMetric(torch.add, other, self)
+
+ def __rand__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical and operator."""
+ # swap them since bitwise_and only supports that way and it's commutative
+ return CompositionalMetric(torch.bitwise_and, self, other)
+
+ def __rfloordiv__(self, other: "CompositionalMetric") -> "Metric":
+ """Construct compositional metric using the floor division operator."""
+ return CompositionalMetric(torch.floor_divide, other, self)
+
+ def __rmatmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the matrix multiplication operator."""
+ return CompositionalMetric(torch.matmul, other, self)
+
+ def __rmod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the remainder operator."""
+ return CompositionalMetric(torch.fmod, other, self)
+
+ def __rmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the multiplication operator."""
+ return CompositionalMetric(torch.mul, other, self)
+
+ def __ror__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical or operator."""
+ return CompositionalMetric(torch.bitwise_or, other, self)
+
+ def __rpow__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the exponential/power operator."""
+ return CompositionalMetric(torch.pow, other, self)
+
+ def __rsub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the subtraction operator."""
+ return CompositionalMetric(torch.sub, other, self)
+
+ def __rtruediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the true divide operator."""
+ return CompositionalMetric(torch.true_divide, other, self)
+
+ def __rxor__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical xor operator."""
+ return CompositionalMetric(torch.bitwise_xor, other, self)
+
+ def __sub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the subtraction operator."""
+ return CompositionalMetric(torch.sub, self, other)
+
+ def __truediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the true divide operator."""
+ return CompositionalMetric(torch.true_divide, self, other)
+
+ def __xor__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric":
+ """Construct compositional metric using the logical xor operator."""
+ return CompositionalMetric(torch.bitwise_xor, self, other)
+
+ def __abs__(self) -> "CompositionalMetric":
+ """Construct compositional metric using the absolute operator."""
+ return CompositionalMetric(torch.abs, self, None)
+
+ def __inv__(self) -> "CompositionalMetric":
+ """Construct compositional metric using the not operator."""
+ return CompositionalMetric(torch.bitwise_not, self, None)
+
+ def __invert__(self) -> "CompositionalMetric":
+ """Construct compositional metric using the not operator."""
+ return self.__inv__()
+
+ def __neg__(self) -> "CompositionalMetric":
+ """Construct compositional metric using absolute negative operator."""
+ return CompositionalMetric(_neg, self, None)
+
+ def __pos__(self) -> "CompositionalMetric":
+ """Construct compositional metric using absolute operator."""
+ return CompositionalMetric(torch.abs, self, None)
+
+ def __getitem__(self, idx: int) -> "CompositionalMetric":
+ """Construct compositional metric using the get item operator."""
+ return CompositionalMetric(lambda x: x[idx], self, None)
+
+ def __getnewargs__(self) -> Tuple:
+ """Needed method for construction of new metrics __new__ method."""
+ return tuple(
+ Metric.__str__(self),
+ )
+
+ __iter__ = None
+
+
+def _neg(x: Tensor) -> Tensor:
+ return -torch.abs(x)
+
+
+class BaseAggregator(Metric):
+ """Base class for aggregation metrics.
+
+ Args:
+ fn: string specifying the reduction function
+ default_value: default tensor value to use for the metric state
+ nan_strategy: options:
+ - ``'error'``: if any `nan` values are encountered will give a RuntimeError
+ - ``'warn'``: if any `nan` values are encountered will give a warning and continue
+ - ``'ignore'``: all `nan` values are silently removed
+ - a float: if a float is provided will impute any `nan` values with this value
+
+ state_name: name of the metric state
+ kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+ Raises:
+ ValueError:
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
+
+ """
+
+ is_differentiable = None
+ higher_is_better = None
+ full_state_update: bool = False
+
+ def __init__(
+ self,
+ fn: Union[Callable, str],
+ default_value: Union[Tensor, List],
+ nan_strategy: Union[str, float] = "error",
+ state_name: str = "value",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ allowed_nan_strategy = ("error", "warn", "ignore")
+ if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
+ raise ValueError(
+ f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}"
+ f" but got {nan_strategy}."
+ )
+
+ self.nan_strategy = nan_strategy
+ self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
+ self.state_name = state_name
+
+ def _cast_and_nan_check_input(
+ self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
+ ) -> Tuple[Tensor, Tensor]:
+ """Convert input ``x`` to a tensor and check for Nans."""
+ if not isinstance(x, Tensor):
+ x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
+ if weight is not None and not isinstance(weight, Tensor):
+ weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
+
+ nans = torch.isnan(x)
+ if weight is not None:
+ nans_weight = torch.isnan(weight)
+ else:
+ nans_weight = torch.zeros_like(nans).bool()
+ weight = torch.ones_like(x)
+ if nans.any() or nans_weight.any():
+ if self.nan_strategy == "error":
+ raise RuntimeError("Encountered `nan` values in tensor")
+ if self.nan_strategy in ("ignore", "warn"):
+ if self.nan_strategy == "warn":
+ rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
+ x = x[~(nans | nans_weight)]
+ weight = weight[~(nans | nans_weight)]
+ else:
+ if not isinstance(self.nan_strategy, float):
+ raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
+ x[nans | nans_weight] = self.nan_strategy
+ weight[nans | nans_weight] = self.nan_strategy
+
+ return x.to(self.dtype), weight.to(self.dtype)
+
+ def update(self, value: Union[float, Tensor]) -> None:
+ """Overwrite in child class."""
+
+ def compute(self) -> Tensor:
+ """Compute the aggregated value."""
+ return getattr(self, self.state_name)
+
+class MeanMetric(BaseAggregator):
+ """Aggregate a stream of value into their mean value.
+
+ As input to ``forward`` and ``update`` the metric accepts the following input
+
+ - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
+ arbitrary shape ``(...,)``.
+ - ``weight`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float value with
+ arbitrary shape ``(...,)``. Needs to be broadcastable with the shape of ``value`` tensor.
+
+ As output of `forward` and `compute` the metric returns the following output
+
+ - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received
+
+ Args:
+ nan_strategy: options:
+ - ``'error'``: if any `nan` values are encountered will give a RuntimeError
+ - ``'warn'``: if any `nan` values are encountered will give a warning and continue
+ - ``'ignore'``: all `nan` values are silently removed
+ - a float: if a float is provided will impute any `nan` values with this value
+
+ kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+ Raises:
+ ValueError:
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
+
+ Example:
+ >>> from torchmetrics.aggregation import MeanMetric
+ >>> metric = MeanMetric()
+ >>> metric.update(1)
+ >>> metric.update(torch.tensor([2, 3]))
+ >>> metric.compute()
+ tensor(2.)
+
+ """
+
+ mean_value: Tensor
+
+ def __init__(
+ self,
+ nan_strategy: Union[str, float] = "warn",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ "sum",
+ torch.tensor(0.0, dtype=torch.get_default_dtype()),
+ nan_strategy,
+ state_name="mean_value",
+ **kwargs,
+ )
+ self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")
+
+ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None:
+ """Update state with data.
+
+ Args:
+ value: Either a float or tensor containing data. Additional tensor
+ dimensions will be flattened
+ weight: Either a float or tensor containing weights for calculating
+ the average. Shape of weight should be able to broadcast with
+ the shape of `value`. Default to `1.0` corresponding to simple
+ harmonic average.
+
+ """
+ # broadcast weight to value shape
+ if not isinstance(value, Tensor):
+ value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
+ if weight is not None and not isinstance(weight, Tensor):
+ weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
+ weight = torch.broadcast_to(weight, value.shape)
+
+ # OLD:
+ # value, weight = self._cast_and_nan_check_input(value, weight)
+
+ # NEW:
+ value, weight = value.to(self.dtype), weight.to(self.dtype)
+ # value, weight = torch.where(torch.isnan(value), torch.tensor(0.0, dtype=self.dtype, device=self.device), value), torch.where(torch.isnan(weight), torch.tensor(0.0, dtype=self.dtype, device=self.device), weight)
+
+ self.mean_value += (value * weight).sum()
+ self.weight += weight.sum()
+
+ def compute(self) -> Tensor:
+ """Compute the aggregated value."""
+ return self.mean_value / self.weight
+
+ def plot(
+ self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+ ) -> _PLOT_OUT_TYPE:
+ """Plot a single or multiple values from the metric.
+
+ Args:
+ val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+ If no value is provided, will automatically call `metric.compute` and plot that result.
+ ax: An matplotlib axis object. If provided will add plot to that axis
+
+ Returns:
+ Figure and Axes object
+
+ Raises:
+ ModuleNotFoundError:
+ If `matplotlib` is not installed
+
+ .. plot::
+ :scale: 75
+
+ >>> # Example plotting a single value
+ >>> from torchmetrics.aggregation import MeanMetric
+ >>> metric = MeanMetric()
+ >>> metric.update([1, 2, 3])
+ >>> fig_, ax_ = metric.plot()
+
+ .. plot::
+ :scale: 75
+
+ >>> # Example plotting multiple values
+ >>> from torchmetrics.aggregation import MeanMetric
+ >>> metric = MeanMetric()
+ >>> values = [ ]
+ >>> for i in range(10):
+ ... values.append(metric([i, i+1]))
+ >>> fig_, ax_ = metric.plot(values)
+
+ """
+ return self._plot(val, ax)
+
+
+class MetricCollection(ModuleDict):
+ """MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
+
+ Args:
+ metrics: One of the following
+
+ * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
+ as key for output dict. Therefore, two metrics of the same class cannot be chained this way.
+
+ * arguments: similar to passing in as a list, metrics passed in as arguments will use their metric
+ class name as key for the output dict.
+
+ * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
+ Use this format if you want to chain together multiple of the same metric with different parameters.
+ Note that the keys in the output dict will be sorted alphabetically.
+
+ prefix: a string to append in front of the keys of the output dict
+
+ postfix: a string to append after the keys of the output dict
+
+ compute_groups:
+ By default the MetricCollection will try to reduce the computations needed for the metrics in the collection
+ by checking if they belong to the same **compute group**. All metrics in a compute group share the same
+ metric state and are therefore only different in their compute step e.g. accuracy, precision and recall
+ can all be computed from the true positives/negatives and false positives/negatives. By default,
+ this argument is ``True`` which enables this feature. Set this argument to `False` for disabling
+ this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself.
+
+ .. note::
+ The compute groups feature can significantly speedup the calculation of metrics under the right conditions.
+ First, the feature is only available when calling the ``update`` method and not when calling ``forward`` method
+ due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
+ states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
+ reference and a copy of states are instead returned in this case (reference will be reestablished on the next
+ call to ``update``).
+
+ .. note::
+ Metric collections can be nested at initialization (see last example) but the output of the collection will
+ still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection.
+
+ Raises:
+ ValueError:
+ If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
+ ValueError:
+ If two elements in ``metrics`` have the same ``name``.
+ ValueError:
+ If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
+ ValueError:
+ If ``metrics`` is ``dict`` and additional_metrics are passed in.
+ ValueError:
+ If ``prefix`` is set and it is not a string.
+ ValueError:
+ If ``postfix`` is set and it is not a string.
+
+ Example::
+ In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be
+ the same as the class name of the metric:
+
+ >>> from torch import tensor
+ >>> from pprint import pprint
+ >>> from torchmetrics import MetricCollection
+ >>> from torchmetrics.regression import MeanSquaredError
+ >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
+ >>> target = tensor([0, 2, 0, 2, 0, 1, 0, 2])
+ >>> preds = tensor([2, 1, 2, 0, 1, 2, 2, 2])
+ >>> metrics = MetricCollection([MulticlassAccuracy(num_classes=3, average='micro'),
+ ... MulticlassPrecision(num_classes=3, average='macro'),
+ ... MulticlassRecall(num_classes=3, average='macro')])
+ >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE
+ {'MulticlassAccuracy': tensor(0.1250),
+ 'MulticlassPrecision': tensor(0.0667),
+ 'MulticlassRecall': tensor(0.1111)}
+
+ Example::
+ Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the
+ class name of the metric:
+
+ >>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'),
+ ... MulticlassPrecision(num_classes=3, average='macro'),
+ ... MulticlassRecall(num_classes=3, average='macro'))
+ >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE
+ {'MulticlassAccuracy': tensor(0.1250),
+ 'MulticlassPrecision': tensor(0.0667),
+ 'MulticlassRecall': tensor(0.1111)}
+
+ Example::
+ If multiple of the same metric class (with different parameters) should be chained together, metrics can be
+ passed in as a dict and the output dict will have the same keys as the input dict:
+
+ >>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
+ ... 'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
+ >>> same_metric = metrics.clone()
+ >>> pprint(metrics(preds, target))
+ {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
+ >>> pprint(same_metric(preds, target))
+ {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
+
+ Example::
+ Metric collections can also be nested up to a single time. The output of the collection will still be a single
+ dict with the prefix and postfix arguments from the nested collection:
+
+ >>> metrics = MetricCollection([
+ ... MetricCollection([
+ ... MulticlassAccuracy(num_classes=3, average='macro'),
+ ... MulticlassPrecision(num_classes=3, average='macro')
+ ... ], postfix='_macro'),
+ ... MetricCollection([
+ ... MulticlassAccuracy(num_classes=3, average='micro'),
+ ... MulticlassPrecision(num_classes=3, average='micro')
+ ... ], postfix='_micro'),
+ ... ], prefix='valmetrics/')
+ >>> pprint(metrics(preds, target)) # doctest: +NORMALIZE_WHITESPACE
+ {'valmetrics/MulticlassAccuracy_macro': tensor(0.1111),
+ 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
+ 'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
+ 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}
+
+ Example::
+ The `compute_groups` argument allow you to specify which metrics should share metric state. By default, this
+ will automatically be derived but can also be set manually.
+
+ >>> metrics = MetricCollection(
+ ... MulticlassRecall(num_classes=3, average='macro'),
+ ... MulticlassPrecision(num_classes=3, average='macro'),
+ ... MeanSquaredError(),
+ ... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
+ ... )
+ >>> metrics.update(preds, target)
+ >>> pprint(metrics.compute())
+ {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
+ >>> pprint(metrics.compute_groups)
+ {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}
+
+ """
+
+ _modules: Dict[str, Metric] # type: ignore[assignment]
+ _groups: Dict[int, List[str]]
+
+ def __init__(
+ self,
+ metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
+ *additional_metrics: Metric,
+ prefix: Optional[str] = None,
+ postfix: Optional[str] = None,
+ compute_groups: Union[bool, List[List[str]]] = True,
+ ) -> None:
+ super().__init__()
+
+ self.prefix = self._check_arg(prefix, "prefix")
+ self.postfix = self._check_arg(postfix, "postfix")
+ print(f"Metrics compute_groups: {compute_groups}")
+ self._enable_compute_groups = compute_groups
+ self._groups_checked: bool = False
+ self._state_is_copy: bool = False
+
+ self.add_metrics(metrics, *additional_metrics)
+
+ @property
+ def metric_state(self) -> Dict[str, Dict[str, Any]]:
+ """Get the current state of the metric."""
+ return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)}
+
+ @torch.jit.unused
+ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+ """Call forward for each metric sequentially.
+
+ Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
+ will be filtered based on the signature of the individual metric.
+
+ """
+ return self._compute_and_reduce("forward", *args, **kwargs)
+
+ def update(self, *args: Any, **kwargs: Any) -> None:
+ """Call update for each metric sequentially.
+
+ Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
+ will be filtered based on the signature of the individual metric.
+
+ """
+ # Use compute groups if already initialized and checked
+ if self._groups_checked:
+ # Delete the cache of all metrics to invalidate the cache and therefore recent compute calls, forcing new
+ # compute calls to recompute
+ for k in self.keys(keep_base=True):
+ mi = getattr(self, str(k))
+ mi._computed = None
+ for cg in self._groups.values():
+ # only update the first member
+ m0 = getattr(self, cg[0])
+ m0.update(*args, **m0._filter_kwargs(**kwargs))
+ if self._state_is_copy:
+ # If we have deep copied state in between updates, reestablish link
+ self._compute_groups_create_state_ref()
+ self._state_is_copy = False
+ else: # the first update always do per metric to form compute groups
+ for m in self.values(copy_state=False):
+ m_kwargs = m._filter_kwargs(**kwargs)
+ m.update(*args, **m_kwargs)
+
+ if self._enable_compute_groups:
+ self._merge_compute_groups()
+ # create reference between states
+ self._compute_groups_create_state_ref()
+ self._groups_checked = True
+
+ def _merge_compute_groups(self) -> None:
+ """Iterate over the collection of metrics, checking if the state of each metric matches another.
+
+ If so, their compute groups will be merged into one. The complexity of the method is approximately
+ ``O(number_of_metrics_in_collection ** 2)``, as all metrics need to be compared to all other metrics.
+
+ """
+ num_groups = len(self._groups)
+ while True:
+ for cg_idx1, cg_members1 in deepcopy(self._groups).items():
+ for cg_idx2, cg_members2 in deepcopy(self._groups).items():
+ if cg_idx1 == cg_idx2:
+ continue
+
+ metric1 = getattr(self, cg_members1[0])
+ metric2 = getattr(self, cg_members2[0])
+
+ if self._equal_metric_states(metric1, metric2):
+ self._groups[cg_idx1].extend(self._groups.pop(cg_idx2))
+ break
+
+ # Start over if we merged groups
+ if len(self._groups) != num_groups:
+ break
+
+ # Stop when we iterate over everything and do not merge any groups
+ if len(self._groups) == num_groups:
+ break
+ num_groups = len(self._groups)
+
+ # Re-index groups
+ temp = deepcopy(self._groups)
+ self._groups = {}
+ for idx, values in enumerate(temp.values()):
+ self._groups[idx] = values
+
+ @staticmethod
+ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
+ """Check if the metric state of two metrics are the same."""
+ # empty state
+ if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
+ return False
+
+ if metric1._defaults.keys() != metric2._defaults.keys():
+ return False
+
+ for key in metric1._defaults:
+ state1 = getattr(metric1, key)
+ state2 = getattr(metric2, key)
+
+ if type(state1) != type(state2):
+ return False
+
+ if isinstance(state1, Tensor) and isinstance(state2, Tensor):
+ return state1.shape == state2.shape and allclose(state1, state2)
+
+ if isinstance(state1, list) and isinstance(state2, list):
+ return all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))
+
+ return True
+
+ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
+ """Create reference between metrics in the same compute group.
+
+ Args:
+ copy: If `True` the metric state will between members will be copied instead
+ of just passed by reference
+
+ """
+ if not self._state_is_copy:
+ for cg in self._groups.values():
+ m0 = getattr(self, cg[0])
+ for i in range(1, len(cg)):
+ mi = getattr(self, cg[i])
+ for state in m0._defaults:
+ m0_state = getattr(m0, state)
+ # Determine if we just should set a reference or a full copy
+ setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
+ mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
+ self._state_is_copy = copy
+
+ def compute(self) -> Dict[str, Any]:
+ """Compute the result for each metric in the collection."""
+ return self._compute_and_reduce("compute")
+
+ def _compute_and_reduce(
+ self, method_name: Literal["compute", "forward"], *args: Any, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Compute result from collection and reduce into a single dictionary.
+
+ Args:
+ method_name: The method to call on each metric in the collection.
+ Should be either `compute` or `forward`.
+ args: Positional arguments to pass to each metric (if method_name is `forward`)
+ kwargs: Keyword arguments to pass to each metric (if method_name is `forward`)
+
+ Raises:
+ ValueError:
+ If method_name is not `compute` or `forward`.
+
+ """
+ result = {}
+ for k, m in self.items(keep_base=True, copy_state=False):
+ if method_name == "compute":
+ res = m.compute()
+ elif method_name == "forward":
+ res = m(*args, **m._filter_kwargs(**kwargs))
+ else:
+ raise ValueError(f"method_name should be either 'compute' or 'forward', but got {method_name}")
+ result[k] = res
+
+ _, duplicates = _flatten_dict(result)
+
+ flattened_results = {}
+ for k, m in self.items(keep_base=True, copy_state=False):
+ res = result[k]
+ if isinstance(res, dict):
+ for key, v in res.items():
+ # if duplicates of keys we need to add unique prefix to each key
+ if duplicates:
+ stripped_k = k.replace(getattr(m, "prefix", ""), "")
+ stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
+ key = f"{stripped_k}_{key}"
+ if getattr(m, "_from_collection", None) and m.prefix is not None:
+ key = f"{m.prefix}{key}"
+ if getattr(m, "_from_collection", None) and m.postfix is not None:
+ key = f"{key}{m.postfix}"
+ flattened_results[key] = v
+ else:
+ flattened_results[k] = res
+ return {self._set_name(k): v for k, v in flattened_results.items()}
+
+ def reset(self) -> None:
+ """Call reset for each metric sequentially."""
+ for m in self.values(copy_state=False):
+ m.reset()
+ if self._enable_compute_groups and self._groups_checked:
+ # reset state reference
+ self._compute_groups_create_state_ref()
+
+ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection":
+ """Make a copy of the metric collection.
+
+ Args:
+ prefix: a string to append in front of the metric keys
+ postfix: a string to append after the keys of the output dict.
+
+ """
+ mc = deepcopy(self)
+ if prefix:
+ mc.prefix = self._check_arg(prefix, "prefix")
+ if postfix:
+ mc.postfix = self._check_arg(postfix, "postfix")
+ return mc
+
+ def persistent(self, mode: bool = True) -> None:
+ """Change if metric states should be saved to its state_dict after initialization."""
+ for m in self.values(copy_state=False):
+ m.persistent(mode)
+
+ def add_metrics(
+ self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric
+ ) -> None:
+ """Add new metrics to Metric Collection."""
+ if isinstance(metrics, Metric):
+ # set compatible with original type expectations
+ metrics = [metrics]
+ if isinstance(metrics, Sequence):
+ # prepare for optional additions
+ metrics = list(metrics)
+ remain: list = []
+ for m in additional_metrics:
+ sel = metrics if isinstance(m, Metric) else remain
+ sel.append(m)
+
+ if remain:
+ rank_zero_warn(
+ f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
+ )
+ elif additional_metrics:
+ raise ValueError(
+ f"You have passes extra arguments {additional_metrics} which are not compatible"
+ f" with first passed dictionary {metrics} so they will be ignored."
+ )
+
+ if isinstance(metrics, dict):
+ # Check all values are metrics
+ # Make sure that metrics are added in deterministic order
+ for name in sorted(metrics.keys()):
+ metric = metrics[name]
+ if not isinstance(metric, (Metric, MetricCollection)):
+ raise ValueError(
+ f"Value {metric} belonging to key {name} is not an instance of"
+ " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
+ )
+ if isinstance(metric, Metric):
+ self[name] = metric
+ else:
+ for k, v in metric.items(keep_base=False):
+ v.postfix = metric.postfix
+ v.prefix = metric.prefix
+ v._from_collection = True
+ self[f"{name}_{k}"] = v
+ elif isinstance(metrics, Sequence):
+ for metric in metrics:
+ if not isinstance(metric, (Metric, MetricCollection)):
+ raise ValueError(
+ f"Input {metric} to `MetricCollection` is not a instance of"
+ " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
+ )
+ if isinstance(metric, Metric):
+ name = metric.__class__.__name__
+ if name in self:
+ raise ValueError(f"Encountered two metrics both named {name}")
+ self[name] = metric
+ else:
+ for k, v in metric.items(keep_base=False):
+ v.postfix = metric.postfix
+ v.prefix = metric.prefix
+ v._from_collection = True
+ self[k] = v
+ else:
+ raise ValueError(
+ "Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
+ f" previous, but got {metrics}"
+ )
+
+ self._groups_checked = False
+ if self._enable_compute_groups:
+ self._init_compute_groups()
+ else:
+ self._groups = {}
+
+ def _init_compute_groups(self) -> None:
+ """Initialize compute groups.
+
+ If user provided a list, we check that all metrics in the list are also in the collection. If set to `True` we
+ simply initialize each metric in the collection as its own group
+
+ """
+ if isinstance(self._enable_compute_groups, list):
+ self._groups = dict(enumerate(self._enable_compute_groups))
+ for v in self._groups.values():
+ for metric in v:
+ if metric not in self:
+ raise ValueError(
+ f"Input {metric} in `compute_groups` argument does not match a metric in the collection."
+ f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}"
+ )
+ self._groups_checked = True
+ else:
+ # Initialize all metrics as their own compute group
+ self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}
+
+ @property
+ def compute_groups(self) -> Dict[int, List[str]]:
+ """Return a dict with the current compute groups in the collection."""
+ return self._groups
+
+ def _set_name(self, base: str) -> str:
+ """Adjust name of metric with both prefix and postfix."""
+ name = base if self.prefix is None else self.prefix + base
+ return name if self.postfix is None else name + self.postfix
+
+ def _to_renamed_ordered_dict(self) -> OrderedDict:
+ od = OrderedDict()
+ for k, v in self._modules.items():
+ od[self._set_name(k)] = v
+ return od
+
+ def __iter__(self) -> Iterator[Hashable]:
+ """Return an iterator over the keys of the MetricDict."""
+ return iter(self.keys())
+
+ # TODO: redefine this as native python dict
+ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
+ r"""Return an iterable of the ModuleDict key.
+
+ Args:
+ keep_base: Whether to add prefix/postfix on the items collection.
+
+ """
+ if keep_base:
+ return self._modules.keys()
+ return self._to_renamed_ordered_dict().keys()
+
+ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]:
+ r"""Return an iterable of the ModuleDict key/value pairs.
+
+ Args:
+ keep_base: Whether to add prefix/postfix on the collection.
+ copy_state:
+ If metric states should be copied between metrics in the same compute group or just passed by reference
+
+ """
+ self._compute_groups_create_state_ref(copy_state)
+ if keep_base:
+ return self._modules.items()
+ return self._to_renamed_ordered_dict().items()
+
+ def values(self, copy_state: bool = True) -> Iterable[Metric]:
+ """Return an iterable of the ModuleDict values.
+
+ Args:
+ copy_state:
+ If metric states should be copied between metrics in the same compute group or just passed by reference
+
+ """
+ self._compute_groups_create_state_ref(copy_state)
+ return self._modules.values()
+
+ def __getitem__(self, key: str, copy_state: bool = True) -> Metric:
+ """Retrieve a single metric from the collection.
+
+ Args:
+ key: name of metric to retrieve
+ copy_state:
+ If metric states should be copied between metrics in the same compute group or just passed by reference
+
+ """
+ self._compute_groups_create_state_ref(copy_state)
+ if self.prefix:
+ key = key.removeprefix(self.prefix)
+ if self.postfix:
+ key = key.removesuffix(self.postfix)
+ return self._modules[key]
+
+ @staticmethod
+ def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
+ if arg is None or isinstance(arg, str):
+ return arg
+ raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")
+
+ def __repr__(self) -> str:
+ """Return the representation of the metric collection including all metrics in the collection."""
+ repr_str = super().__repr__()[:-2]
+ if self.prefix:
+ repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
+ if self.postfix:
+ repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
+ return repr_str + "\n)"
+
+ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
+ """Transfer all metric state to specific dtype. Special version of standard `type` method.
+
+ Arguments:
+ dst_type: the desired type as ``torch.dtype`` or string.
+
+ """
+ for m in self.values(copy_state=False):
+ m.set_dtype(dst_type)
+ return self
+
+ def plot(
+ self,
+ val: Optional[Union[Dict, Sequence[Dict]]] = None,
+ ax: Optional[Union[_AX_TYPE, Sequence[_AX_TYPE]]] = None,
+ together: bool = False,
+ ) -> Sequence[_PLOT_OUT_TYPE]:
+ """Plot a single or multiple values from the metric.
+
+ The plot method has two modes of operation. If argument `together` is set to `False` (default), the `.plot`
+ method of each metric will be called individually and the result will be list of figures. If `together` is set
+ to `True`, the values of all metrics will instead be plotted in the same figure.
+
+ Args:
+ val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+ If no value is provided, will automatically call `metric.compute` and plot that result.
+ ax: Either a single instance of matplotlib axis object or an sequence of matplotlib axis objects. If
+ provided, will add the plots to the provided axis objects. If not provided, will create a new. If
+ argument `together` is set to `True`, a single object is expected. If `together` is set to `False`,
+ the number of axis objects needs to be the same length as the number of metrics in the collection.
+ together: If `True`, will plot all metrics in the same axis. If `False`, will plot each metric in a separate
+
+ Returns:
+ Either install tuple of Figure and Axes object or an sequence of tuples with Figure and Axes object for each
+ metric in the collection.
+
+ Raises:
+ ModuleNotFoundError:
+ If `matplotlib` is not installed
+ ValueError:
+ If `together` is not an bool
+ ValueError:
+ If `ax` is not an instance of matplotlib axis object or a sequence of matplotlib axis objects
+
+ .. plot::
+ :scale: 75
+
+ >>> # Example plotting a single value
+ >>> import torch
+ >>> from torchmetrics import MetricCollection
+ >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
+ >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
+ >>> metrics.update(torch.rand(10), torch.randint(2, (10,)))
+ >>> fig_ax_ = metrics.plot()
+
+ .. plot::
+ :scale: 75
+
+ >>> # Example plotting multiple values
+ >>> import torch
+ >>> from torchmetrics import MetricCollection
+ >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
+ >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
+ >>> values = []
+ >>> for _ in range(10):
+ ... values.append(metrics(torch.rand(10), torch.randint(2, (10,))))
+ >>> fig_, ax_ = metrics.plot(values, together=True)
+
+ """
+ if not isinstance(together, bool):
+ raise ValueError(f"Expected argument `together` to be a boolean, but got {type(together)}")
+ if ax is not None:
+ if together and not isinstance(ax, _AX_TYPE):
+ raise ValueError(
+ f"Expected argument `ax` to be a matplotlib axis object, but got {type(ax)} when `together=True`"
+ )
+ if not together and not (
+ isinstance(ax, Sequence) and all(isinstance(a, _AX_TYPE) for a in ax) and len(ax) == len(self)
+ ):
+ raise ValueError(
+ f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
+ f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
+ )
+ val = val or self.compute()
+ if together:
+ return plot_single_or_multi_val(val, ax=ax)
+ fig_axs = []
+ for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
+ if isinstance(val, dict):
+ f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
+ elif isinstance(val, Sequence):
+ f, a = m.plot([v[k] for v in val], ax=ax[i] if ax is not None else ax)
+ fig_axs.append((f, a))
+ return fig_axs
diff --git a/unidisc/utils/tensor_utils.py b/unidisc/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ec67aad96f95fb81ee5047d36c7000de90a96bf
--- /dev/null
+++ b/unidisc/utils/tensor_utils.py
@@ -0,0 +1,106 @@
+import torch
+import math
+
+def get_interleaved_indices(modality):
+ modality_mask = modality.bool()
+
+ # Pad input_mask with zeros at both ends along the sequence dimension
+ pad_input_mask = torch.nn.functional.pad(modality_mask, (1, 1), mode='constant', value=0) # Shape: [B, N+2]
+
+ # Compute the difference along the sequence dimension to find transitions
+ diff = pad_input_mask[:, 1:].float() - pad_input_mask[:, :-1].float() # Shape: [B, N+1]
+
+ # Find start/end positions
+ starts = (diff == 1).nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+ ends = (diff == -1).nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+
+ # Extract batch indices and positions
+ batch_indices = starts[:, 0] # Batch indices
+ start_positions = starts[:, 1] # Start positions in [0, N+1]
+ end_positions = ends[:, 1] # End positions in [0, N+1]
+
+ return batch_indices, start_positions, end_positions
+
+def get_contiguous_blocks(sample_ids):
+ # modality: [B, N], integer tensor
+ # Compute where the value changes along the sequence dimension
+ diff = sample_ids[:, 1:] != sample_ids[:, :-1] # Shape: [B, N-1]
+ diff = torch.nn.functional.pad(diff, (1, 0), mode='constant', value=True) # Pad at the beginning
+
+ # Find start positions where the value changes (including the first position)
+ starts = diff.nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+
+ # Compute end positions by shifting diff to the left and padding at the end
+ diff_end = torch.nn.functional.pad(diff[:, 1:], (0, 1), mode='constant', value=True) # Shape: [B, N]
+ ends = diff_end.nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+
+ # Extract batch indices and positions
+ batch_indices = starts[:, 0] # Batch indices
+ start_positions = starts[:, 1] # Start positions in [0, N)
+ end_positions = ends[:, 1] + 1 # End positions in [0, N)
+
+ valid_mask = sample_ids[batch_indices, start_positions] >= 0
+
+ return batch_indices[valid_mask], start_positions[valid_mask], end_positions[valid_mask]
+
+def get_contiguous_blocks_per_sample(modality, sample_ids):
+ # modality: [B, N], integer tensor
+ # Compute where the value changes along the sequence dimension
+ # Detect changes in either modality or sample_ids
+ diff_modality = modality[:, 1:] != modality[:, :-1] # Shape: [B, N-1]
+ diff_sample_ids = sample_ids[:, 1:] != sample_ids[:, :-1] # Shape: [B, N-1]
+ diff = diff_modality | diff_sample_ids # Changes in either signal count as transitions
+ diff = torch.nn.functional.pad(diff, (1, 0), mode='constant', value=True) # Pad at the beginning
+
+ # Find start positions where either value changes (including the first position)
+ starts = diff.nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+
+ # Compute end positions by shifting diff to the left and padding at the end
+ diff_end = torch.nn.functional.pad(diff[:, 1:], (0, 1), mode='constant', value=True) # Shape: [B, N]
+ ends = diff_end.nonzero(as_tuple=False) # Shape: [num_blocks, 2], columns: [batch_idx, position]
+
+ # Extract batch indices and positions
+ batch_indices = starts[:, 0] # Batch indices
+ start_positions = starts[:, 1] # Start positions in [0, N)
+ end_positions = ends[:, 1] + 1 # End positions in [0, N)
+
+ valid_mask = sample_ids[batch_indices, start_positions] >= 0
+
+ return batch_indices[valid_mask], start_positions[valid_mask], end_positions[valid_mask]
+
+
+def tensor_dim_slice(tensor, dim, dim_slice):
+ return tensor[(dim if dim >= 0 else dim + tensor.dim()) * (slice(None), ) + (dim_slice, )]
+
+def packshape(shape, dim : int = -1, mask : int = 0b00000001, dtype = torch.uint8, pack = True):
+ dim = dim if dim >= 0 else dim + len(shape)
+ bits, nibble = (8 if dtype is torch.uint8 else 16 if dtype is torch.int16 else 32 if dtype is torch.int32 else 64 if dtype is torch.int64 else 0), (1 if mask == 0b00000001 else 2 if mask == 0b00000011 else 4 if mask == 0b00001111 else 8 if mask == 0b11111111 else 0)
+ # bits = torch.iinfo(dtype).bits # does not JIT compile
+ assert nibble <= bits and bits % nibble == 0
+ nibbles = bits // nibble
+ shape = (shape[:dim] + (int(math.ceil(shape[dim] / nibbles)), ) + shape[1 + dim:]) if pack else (shape[:dim] + (shape[dim] * nibbles, ) + shape[1 + dim:])
+ return shape, nibbles, nibble
+
+def packbits(tensor, dim : int = -1, mask : int = 0b00000001, out = None, dtype = torch.uint8):
+ dim = dim if dim >= 0 else dim + tensor.dim()
+ shape, nibbles, nibble = packshape(tensor.shape, dim = dim, mask = mask, dtype = dtype, pack = True)
+ out = out if out is not None else torch.empty(shape, device = tensor.device, dtype = dtype)
+ assert out.shape == shape
+
+ assert tensor.shape[dim] % nibbles == 0
+ shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype = torch.uint8, device = tensor.device)
+ shift = shift.view(nibbles, *((1, ) * (tensor.dim() - dim - 1)))
+ torch.sum(tensor.view(*tensor.shape[:dim], -1, nibbles, *tensor.shape[1 + dim:]) << shift, dim = 1 + dim, out = out)
+ return out
+
+def unpackbits(tensor, dim : int = -1, mask : int = 0b00000001, shape = None, out = None, dtype = torch.uint8):
+ dim = dim if dim >= 0 else dim + tensor.dim()
+ shape_, nibbles, nibble = packshape(tensor.shape, dim = dim, mask = mask, dtype = tensor.dtype, pack = False)
+ shape = shape if shape is not None else shape_
+ out = out if out is not None else torch.empty(shape, device = tensor.device, dtype = dtype)
+ assert out.shape == shape
+
+ assert shape[dim] % nibbles == 0
+ shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype = torch.uint8, device = tensor.device)
+ shift = shift.view(nibbles, *((1, ) * (tensor.dim() - dim - 1)))
+ return torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out = out)
\ No newline at end of file
diff --git a/unidisc/utils/throughput_monitor.py b/unidisc/utils/throughput_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb052dc0f5ebcf9f3718594e14e7e3b99253f00
--- /dev/null
+++ b/unidisc/utils/throughput_monitor.py
@@ -0,0 +1,736 @@
+import time
+from collections import deque
+from typing import (TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional,
+ TypeVar, Union)
+
+import torch
+import wandb
+from torchtnt.framework.callback import Callback
+from torchtnt.framework.state import State
+from torchtnt.framework.unit import TTrainUnit
+from decoupled_utils import is_main_process, rank_zero_fn, try_except
+from typing_extensions import override
+
+_THROUGHPUT_METRICS = Dict[str, Union[int, float]]
+
+
+# The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's
+# no need for synchronization or reduction as it doesn't use Tensors at all.
+class Throughput:
+ """Computes throughput.
+
+ +------------------------+-------------------------------------------------------------------------------------+
+ | Key | Value |
+ +========================+=====================================================================================+
+ | batches_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of batches |
+ | | processed per second |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | samples_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of samples |
+ | | processed per second |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | items_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of items |
+ | | processed per second |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | flpps_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of flops |
+ | | processed per second |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | device/batches_per_sec | batches_per_sec divided by world size |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | device/samples_per_sec | samples_per_sec divided by world size |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | device/items_per_sec | items_per_sec divided by world size. This may include padding depending on the data |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | device/flops_per_sec | flops_per_sec divided by world size. |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | device/mfu | device/flops_per_sec divided by world size. |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | time | Total elapsed time |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | batches | Total batches seen |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | samples | Total samples seen |
+ +--------------------------+-----------------------------------------------------------------------------------+
+ | lengths | Total items seen |
+ +--------------------------+-----------------------------------------------------------------------------------+
+
+ Example::
+
+ throughput = Throughput()
+ t0 = time()
+ for i in range(1000):
+ do_work()
+ if torch.cuda.is_available(): torch.cuda.synchronize() # required or else time() won't be correct
+ throughput.update(time=time() - t0, samples=i)
+ if i % 10 == 0:
+ print(throughput.compute())
+
+ Notes:
+ - The implementation assumes that devices FLOPs are all the same as it normalizes by the world size and only
+ takes a single ``available_flops`` value.
+ - items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using
+ samples_per_sec or batches_per_sec to measure throughput under this circumstance.
+
+ Args:
+ available_flops: Number of theoretical flops available for a single device.
+ world_size: Number of devices available across hosts. Global metrics are not included if the world size is 1.
+ window_size: Number of batches to use for a rolling average.
+ separator: Key separator to use when creating per-device and global metrics.
+
+ """
+
+ def __init__(
+ self, world_size: int = 1, window_size: int = 100, separator: str = "/", available_flops=None
+ ) -> None:
+ self.separator = separator
+ assert world_size > 0
+ self.world_size = world_size
+ self.available_flops = available_flops
+
+ # throughput is computed over a window of values. at least 2 is enforced since it looks at the difference
+ # between the first and last elements
+ assert window_size > 1
+ # custom class instead of `deque(maxlen=)` because it's easy for users to mess up their timer/counters and log
+ # values that do not increase monotonically. this class will raise an error if that happens.
+ self._time: _MonotonicWindow[float] = _MonotonicWindow(maxlen=window_size)
+ self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+ self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+ self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+ self._steps: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+ if available_flops is not None:
+ self._flops: Deque[int] = deque(maxlen=window_size)
+
+ def update(
+ self,
+ *,
+ time: float,
+ steps: int,
+ batches: int,
+ samples: int,
+ lengths: Optional[int] = None,
+ flops: Optional[int] = None,
+ ) -> None:
+ """Update throughput metrics.
+
+ Args:
+ time: Total elapsed time in seconds. It should monotonically increase by the iteration time with each
+ call.
+ batches: Total batches seen per device. It should monotonically increase with each call.
+ samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
+ lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
+ each call.
+ increment_step: Flag to increment the step counter.
+
+ """
+ self._time.append(time)
+ self._steps.append(steps)
+ if samples < batches:
+ raise ValueError(f"Expected samples ({samples}) to be greater or equal than batches ({batches})")
+ self._batches.append(batches)
+ self._samples.append(samples)
+ self._lengths.append(lengths)
+ if self.available_flops is not None:
+ self._flops.append(flops)
+
+ def compute(self) -> _THROUGHPUT_METRICS:
+ """Compute throughput metrics."""
+ metrics = {
+ "time": self._time[-1],
+ "steps": self._steps[-1],
+ "batches": self._batches[-1] * self.world_size,
+ "samples": self._samples[-1] * self.world_size,
+
+ }
+ if self._lengths:
+ metrics["lengths"] = self._lengths[-1]
+
+ # a different but valid design choice would be to still compute all these metrics even if the window of values
+ # has not been filled
+ if len(self._time) == self._time.maxlen:
+ elapsed_time = self._time[-1] - self._time[0]
+ elapsed_batches = self._batches[-1] - self._batches[0]
+ elapsed_samples = self._samples[-1] - self._samples[0]
+ elapsed_steps = self._steps[-1] - self._steps[0]
+ # we are safe from ZeroDivisionError thanks to `_MonotonicWindow`
+ dev_samples_per_sec = elapsed_samples / elapsed_time
+ dev_batches_per_sec = elapsed_batches / elapsed_time
+ dev_steps_per_sec = elapsed_steps / elapsed_time
+ metrics.update({
+ f"device{self.separator}batches_per_sec": dev_batches_per_sec,
+ f"device{self.separator}samples_per_sec": dev_samples_per_sec,
+ f"device{self.separator}steps_per_sec": dev_steps_per_sec,
+ })
+ metrics.update({
+ "batches_per_sec": dev_batches_per_sec * self.world_size,
+ "samples_per_sec": dev_samples_per_sec * self.world_size,
+ "steps_per_sec": dev_steps_per_sec * self.world_size,
+ })
+
+ if len(self._lengths) == self._lengths.maxlen:
+ elapsed_lengths = self._lengths[-1] - self._lengths[0]
+ dev_items_per_sec = elapsed_lengths / elapsed_time
+ metrics[f"device{self.separator}items_per_sec"] = dev_items_per_sec
+ metrics["items_per_sec"] = dev_items_per_sec * self.world_size
+
+ if self.available_flops is not None:
+ elapsed_flops = sum(self._flops) - self._flops[0]
+ elapsed_time = self._time[-1] - self._time[0]
+ dev_flops_per_sec = (elapsed_flops / elapsed_time) if elapsed_time > 0 else 0
+ flops_per_sec = dev_flops_per_sec * self.world_size
+ metrics["flops_per_sec"] = flops_per_sec
+ metrics[f"device{self.separator}flops_per_sec"] = dev_flops_per_sec
+ metrics[f"device{self.separator}mfu"] = dev_flops_per_sec / self.available_flops
+
+ return metrics
+
+ def reset(self) -> None:
+ self._time.clear()
+ self._batches.clear()
+ self._samples.clear()
+ self._lengths.clear()
+ self._steps.clear()
+
+
+
+T = TypeVar("T", bound=float)
+
+
+class _MonotonicWindow(List[T]):
+ """Custom fixed size list that only supports right-append and ensures that all values increase monotonically."""
+
+ def __init__(self, maxlen: int) -> None:
+ super().__init__()
+ self.maxlen = maxlen
+
+ @property
+ def last(self) -> Optional[T]:
+ if len(self) > 0:
+ return self[-1]
+ return None
+
+ @override
+ def append(self, x: T) -> None:
+ last = self.last
+ if last is not None and last >= x:
+ pass
+ # print(f"Expected the value to increase, last: {last}, current: {x}")
+ list.append(self, x)
+ # truncate excess
+ if len(self) > self.maxlen:
+ del self[0]
+
+ @override
+ def __setitem__(self, key: Any, value: Any) -> None:
+ # assigning is not implemented since we don't use it. it could be by checking all previous values
+ raise NotImplementedError("__setitem__ is not supported")
+
+
+class ThroughputMonitor(Callback):
+ def __init__(
+ self, batch_size_fn: Callable[[Any], int], length_fn: Optional[Callable[[Any], int]] = None, world_size: int = 1, log_every_n_steps=10, flops_per_sample=None, device = None, dtype = None, **kwargs: Any
+ ) -> None:
+ super().__init__()
+ self.kwargs = kwargs
+ self.batch_size_fn = batch_size_fn
+ self.length_fn = length_fn
+ self._throughputs: dict = {}
+ self._t0s: dict = {}
+ self.inference_max_batch_size = 0
+ self.stage = 'train'
+ self.log_every_n_steps = log_every_n_steps
+ self.flops_per_sample = flops_per_sample
+ self.last_samples = 0
+
+ if flops_per_sample is not None:
+ self.available_flops = get_available_flops(device, dtype)
+
+ throughput = Throughput(world_size=world_size, available_flops=self.available_flops, **self.kwargs)
+ self._throughputs[self.stage] = throughput
+
+ self._throughputs[self.stage].reset()
+ self._t0s[self.stage] = time.perf_counter()
+
+ @rank_zero_fn
+ @try_except(write_error_to_file=True)
+ @torch.inference_mode()
+ def on_train_step_end(
+ self, state: State, unit: TTrainUnit
+ ) -> None:
+
+ global_step = unit.global_step
+ if global_step % self.log_every_n_steps != 0:
+ return
+
+ if is_torch_cuda_available():
+ torch.cuda.synchronize()
+
+ stage = self.stage
+ throughput = self._throughputs[stage]
+ elapsed = time.perf_counter() - self._t0s[stage]
+ tokens_per_sample = unit.num_tokens_per_sample
+
+ if self.batch_size_fn is not None:
+ batch_size = self.batch_size_fn(state.batch) * unit.gradient_accumulation_steps
+ else:
+ batch_size = unit.step_batch_size
+
+ if self.length_fn is not None:
+ batch = state.batch
+ _length = self.length_fn(batch)
+ else:
+ _length = tokens_per_sample * batch_size
+
+ if self.available_flops is not None:
+ _flops = self.flops_per_sample * ((global_step * batch_size) - self.last_samples)
+ self.last_samples = global_step * batch_size
+ else:
+ _flops = None
+
+ throughput.update(
+ time=elapsed,
+ steps=global_step,
+ batches=global_step * unit.gradient_accumulation_steps,
+ samples=global_step * batch_size,
+ lengths=global_step * _length,
+ flops=_flops
+ )
+
+ throughput = self._throughputs[stage]
+ metrics = throughput.compute()
+ metrics = {f"{stage}_metrics/{k}": v for k, v in metrics.items()}
+ if is_main_process():
+ wandb.log(dict(**metrics, **{"trainer/global_step": global_step}))
+
+
+_CUDA_FLOPS: Dict[str, Dict[Union[str, torch.dtype], float]] = {
+ # Hopper
+ # source: https://resources.nvidia.com/en-us-tensor-core
+ "h100 nvl": {
+ torch.float64: 67e12,
+ torch.float32: 133.8e12,
+ "tfloat32": 989.4e12,
+ torch.bfloat16: 1978.8e12,
+ torch.float16: 1978.8e12,
+ torch.int8: 3957.8e12,
+ },
+ "h100 sxm": {
+ torch.float64: 33.5e12,
+ torch.float32: 66.9e12,
+ "tfloat32": 494.7e12,
+ torch.bfloat16: 989.4e12,
+ torch.float16: 989.4e12,
+ torch.int8: 1978.9e12,
+ },
+ "h100 pcie": {
+ torch.float64: 25.6e12,
+ torch.float32: 51.2e12,
+ "tfloat32": 378e12,
+ torch.bfloat16: 756e12,
+ torch.float16: 756e12,
+ torch.int8: 1513e12,
+ },
+ # Ada
+ # source: https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf
+ "rtx 4090": {
+ torch.float32: 82.6e12,
+ "tfloat32": 82.6e12,
+ torch.bfloat16: 82.6e12,
+ torch.float16: 82.6e12,
+ torch.int8: 660.6e12,
+ "int4": 1321.2e12,
+ },
+ "rtx 4080": {
+ torch.float32: 48.7e12,
+ "tfloat32": 48.7e12,
+ torch.bfloat16: 48.7e12,
+ torch.float16: 48.7e12,
+ torch.int8: 389.9e12,
+ "int4": 779.8e12,
+ },
+ "l4": {
+ torch.float32: 30.3e12,
+ "tfloat32": 60e12,
+ torch.bfloat16: 121e12,
+ torch.float16: 121e12,
+ torch.int8: 242e12,
+ "int4": 484e12,
+ },
+ "l40": {
+ torch.float32: 90.5e12,
+ "tfloat32": 90.5e12,
+ torch.bfloat16: 181e12,
+ torch.float16: 181e12,
+ torch.int8: 362e12,
+ "int4": 724e12,
+ },
+ # Ampere
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+ # sxm and pcie have same flop counts
+ "a100": {
+ torch.float64: 9.7e12,
+ torch.float32: 19.5e12,
+ "tfloat32": 156e12,
+ torch.bfloat16: 312e12,
+ torch.float16: 312e12,
+ torch.int8: 624e12,
+ },
+ "a6000": {
+ torch.float32: 38.7e12,
+ "tfloat32": 77.4e12,
+ torch.bfloat16: 38.7e12,
+ torch.float16: 38.7e12,
+ torch.int8: 309.7e12,
+ "int4": 619.3e12,
+ },
+ "6000ada": {
+ torch.float32: 91.1e12,
+ "tfloat32": 182.1e12,
+ torch.bfloat16: 91.1e12,
+ torch.float16: 91.1e12,
+ torch.int8: 728.5e12,
+ "int4": 1457.0e12,
+ },
+ "a5000": {
+ torch.float32: 27.8e12,
+ torch.bfloat16: 27.8e12,
+ torch.float16: 27.8e12,
+ },
+ "a5500": {
+ torch.float32: 34.1e12,
+ torch.bfloat16: 34.1e12,
+ torch.float16: 34.1e12,
+ },
+ "a40": {
+ torch.float32: 37.4e12,
+ "tfloat32": 74.8e12,
+ torch.bfloat16: 37.4e12,
+ torch.float16: 37.4e12,
+ torch.int8: 299.3e12,
+ "int4": 598.7e12,
+ },
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
+ "a10g": {
+ torch.float32: 31.2e12,
+ "tfloat32": 62.5e12,
+ torch.bfloat16: 125e12,
+ torch.float16: 125e12,
+ torch.int8: 250e12,
+ "int4": 500e12,
+ },
+ "rtx 3090 ti": {
+ torch.float32: 40e12,
+ "tfloat32": 40e12,
+ torch.bfloat16: 40e12,
+ torch.float16: 40e12,
+ torch.int8: 320e12,
+ "int4": 640e12,
+ },
+ "rtx 3090": {
+ torch.float32: 35.6e12,
+ "tfloat32": 35.6e12,
+ torch.bfloat16: 35.6e12,
+ torch.float16: 35.6e12,
+ torch.int8: 284e12,
+ "int4": 568e12,
+ },
+ "rtx 3080 ti": {
+ torch.float32: 34.1e12,
+ "tfloat32": 34.1e12,
+ torch.bfloat16: 34.1e12,
+ torch.float16: 34.1e12,
+ torch.int8: 272.8e12,
+ "int4": 546.6e12,
+ },
+ "rtx 3080": {
+ torch.float32: 29.8e12,
+ "tfloat32": 29.8e12,
+ torch.bfloat16: 29.8e12,
+ torch.float16: 29.8e12,
+ torch.int8: 238e12,
+ "int4": 476e12,
+ },
+ "rtx 3070": {
+ torch.float32: 20.3e12,
+ "tfloat32": 20.3e12,
+ torch.bfloat16: 20.3e12,
+ torch.float16: 20.3e12,
+ torch.int8: 162.6e12,
+ "int4": 325.2e12,
+ },
+ # Turing
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
+ # sxm and pcie have same flop counts
+ "t4": {
+ torch.float32: 8.1e12,
+ torch.float16: 65e12,
+ torch.int8: 130e12,
+ "int4": 260e12,
+ },
+ # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
+ "quadro rtx 5000": {
+ torch.float32: 11.2e12,
+ torch.float16: 89.2e12,
+ },
+ "rtx 2080 super": {
+ torch.float32: 11.2e12,
+ torch.float16: 22.3e12,
+ torch.int8: 178.4e12,
+ "int4": 356.8e12,
+ },
+ "rtx 2080 ti": {
+ torch.float32: 14.2e12,
+ torch.float16: 28.5e12,
+ torch.int8: 227.7e12,
+ "int4": 455.4e12,
+ },
+ "rtx 2080": {
+ torch.float32: 10.6e12,
+ torch.float16: 21.2e12,
+ torch.int8: 169.6e12,
+ "int4": 339.1e12,
+ },
+ # https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.pdf
+ "rtx 2070 super": {
+ torch.float32: 9.1e12,
+ torch.float16: 18.1e12,
+ torch.int8: 145e12,
+ "int4": 290e12,
+ },
+ "titan rtx": {
+ torch.float32: 16.3e12,
+ torch.float16: 32.6e12,
+ torch.int8: 261e12,
+ "int4": 522e12,
+ },
+ # Volta
+ # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
+ "v100 sxm": {
+ torch.float64: 7.8e12,
+ torch.float32: 15.7e12,
+ torch.float16: 125e12,
+ },
+ "v100 pcie": {
+ torch.float64: 7e12,
+ torch.float32: 14e12,
+ torch.float16: 112e12,
+ },
+ "v100s pcie": {
+ torch.float64: 8.2e12,
+ torch.float32: 16.4e12,
+ torch.float16: 130e12,
+ },
+ "l40s": {
+ torch.float32: 91.6e12,
+ torch.bfloat16: 362.05e12,
+ torch.float16: 362.05e12,
+ },
+}
+
+_TPU_FLOPS = {
+ # flop count for each TPU generation is the same for all precisions
+ # since bfloat16 precision is always used for performing matrix operations
+ # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
+ # source: https://arxiv.org/pdf/1907.10701.pdf
+ "v2": 45e12,
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
+ "v3": 123e12,
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
+ "v4": 275e12,
+ # source: https://cloud.google.com/tpu/docs/v5e-training
+ "v5litepod": 197e12,
+}
+
+from decoupled_utils import is_torch_cuda_available, rprint, synchronize_device
+
+def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:
+ major, _ = torch.cuda.get_device_capability(device)
+ return major >= 8 # Ampere and later leverage tensor cores, where this setting becomes useful
+
+
+def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) -> Optional[int]:
+ """Returns the available theoretical FLOPs.
+
+ This is an optimistic upper limit that could only be achievable if only thick matmuls were run in a benchmark
+ environment.
+
+ """
+ if device.type == "cuda":
+ device_name = torch.cuda.get_device_name(device)
+ chip = device_name.lower()
+ if "h100" in chip:
+ if "hbm3" in chip:
+ chip = "h100 sxm"
+ elif "nvl" in chip:
+ chip = "h100 nvl"
+ elif "pcie" in chip or "hbm2e" in chip:
+ chip = "h100 pcie"
+ elif "l40s" in chip:
+ chip = "l40s"
+ elif "l4" in chip:
+ chip = "l40" if "tesla" in chip else "l4"
+ elif "geforce rtx" in chip:
+ number = chip.split(" ")[3]
+ extra = ""
+ if "super" in chip:
+ extra = " super"
+ elif "ti" in chip:
+ extra = " ti"
+ chip = f"rtx {number}{extra}"
+ elif "a6000" in chip:
+ chip = "a6000"
+ elif "6000 ada" in chip:
+ chip = "6000ada"
+ elif "a5500" in chip:
+ chip = "a5500"
+ elif "a5000" in chip:
+ chip = "a5000"
+ elif "a100" in chip:
+ chip = "a100"
+ elif "a40" in chip:
+ chip = "a40"
+ elif "a10g" in chip:
+ chip = "a10g"
+ elif "t4" in chip:
+ chip = "t4"
+ elif "quadro rtx 5000" in chip:
+ chip = "quadro rtx 5000"
+ elif "titan rtx" in chip:
+ chip = "titan rtx"
+ elif "v100-sxm" in chip:
+ chip = "v100 sxm"
+ elif "v100-pcie" in chip:
+ chip = "v100 pcie"
+ elif "v100s-pcie" in chip:
+ chip = "v100s pcie"
+ else:
+ # the flops list is not exhaustive, return with a warning
+ rprint(f"FLOPs not found for {device_name!r}")
+ return 0
+ if chip not in _CUDA_FLOPS:
+ # parsing is implemented but we don't have the stats
+ rprint(f"FLOPs not found for {device_name!r}, chip is {chip!r}")
+ return 0
+ dtype_to_flops = _CUDA_FLOPS[chip]
+ if dtype is torch.float32:
+ if _is_ampere_or_later() and torch.get_float32_matmul_precision() != "highest":
+ dtype = "tfloat32"
+ if dtype not in dtype_to_flops:
+ # for example, T4 doesn't support bfloat16. it might also be that we are missing this dtype from the list
+ rprint(f"{device_name!r} does not support {dtype}")
+ return 0
+ return int(dtype_to_flops[dtype])
+
+ if device.type == "xla":
+ from torch_xla._internal import tpu
+ tpu_env = tpu.get_tpu_env()
+ # not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8"
+ device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
+ chip = device_name.lower()
+ assert isinstance(device_name, str)
+ if chip not in _TPU_FLOPS:
+ rprint(f"FLOPs not found for TPU {device_name!r} with {dtype}")
+ return 0
+ return int(_TPU_FLOPS[chip])
+
+
+# https://github.com/karpathy/llm.c/blob/master/llmc/mfu.h
+class PerfData:
+ def __init__(self, TF_32, BF_16_32, FP_16_32, FP_16_16, FP_8_32, FP_8_16, CLOCK, CORES):
+ self.TF_32 = TF_32
+ self.BF_16_32 = BF_16_32
+ self.FP_16_32 = FP_16_32
+ self.FP_16_16 = FP_16_16
+ self.FP_8_32 = FP_8_32
+ self.FP_8_16 = FP_8_16
+ self.CLOCK = CLOCK
+ self.CORES = CORES
+
+class GPUEntry:
+ def __init__(self, name, perf_data, new_cores, new_mhz):
+ self.name = name
+ self.perf_data = perf_data
+ self.new_cores = new_cores
+ self.new_mhz = new_mhz
+
+VOLTA = PerfData(125.0, -1.0, 125.0, -1.0, -1.0, -1.0, 1530.0, 640.0)
+AMPERE_DATACENTER = PerfData(156.0, 312.0, 312.0, 312.0, -1.0, -1.0, 1410.0, 432.0)
+AMPERE_CONSUMER = PerfData(40.0, 80.0, 80.0, 160.0, -1.0, -1.0, 1860.0, 336.0)
+HOPPER = PerfData(378.0, 756.0, 756.0, 756.0, 1513.0, 1513.0, 1620.0, 456.0)
+ADA = PerfData(82.6, 165.2, 165.2, 330.3, 330.3, 660.6, 2520.0, 512.0)
+
+gpu_db = [
+ GPUEntry("Tesla V100-SXM2-16GB", VOLTA, 640, 1530),
+ GPUEntry("Tesla V100-PCIE-32GB", VOLTA, 640, 1530),
+ GPUEntry("NVIDIA A100-PCIE-40GB", AMPERE_DATACENTER, 432, 1410),
+ GPUEntry("NVIDIA A100-PCIE-80GB", AMPERE_DATACENTER, 432, 1410),
+ GPUEntry("NVIDIA A100-SXM4-40GB", AMPERE_DATACENTER, 432, 1410),
+ GPUEntry("NVIDIA A100-SXM4-80GB", AMPERE_DATACENTER, 432, 1410),
+ GPUEntry("NVIDIA RTX A2000", AMPERE_CONSUMER, 104, 1200),
+ GPUEntry("NVIDIA RTX A4000", AMPERE_CONSUMER, 192, 1560),
+ GPUEntry("NVIDIA RTX A4500", AMPERE_CONSUMER, 224, 1650),
+ GPUEntry("NVIDIA RTX A5000", AMPERE_CONSUMER, 256, 1695),
+ GPUEntry("NVIDIA RTX A5500", AMPERE_CONSUMER, 320, 1770),
+ GPUEntry("NVIDIA RTX A6000", AMPERE_CONSUMER, 336, 1800),
+ GPUEntry("NVIDIA GeForce RTX 3090 Ti", AMPERE_CONSUMER, 336, 1860),
+ GPUEntry("NVIDIA GeForce RTX 3090", AMPERE_CONSUMER, 328, 1695),
+ GPUEntry("NVIDIA GeForce RTX 3080 Ti", AMPERE_CONSUMER, 320, 1665),
+ GPUEntry("NVIDIA GeForce RTX 3080", AMPERE_CONSUMER, 272, 1710),
+ GPUEntry("NVIDIA GeForce RTX 3070 Ti", AMPERE_CONSUMER, 192, 1770),
+ GPUEntry("NVIDIA GeForce RTX 3070", AMPERE_CONSUMER, 184, 1725),
+ GPUEntry("NVIDIA GeForce RTX 3060 Ti", AMPERE_CONSUMER, 152, 1665),
+ GPUEntry("NVIDIA GeForce RTX 3060", AMPERE_CONSUMER, 112, 1777),
+ GPUEntry("NVIDIA RTX A2000 ADA", ADA, 88, 2130),
+ GPUEntry("NVIDIA RTX A4000 ADA", ADA, 192, 2175),
+ GPUEntry("NVIDIA RTX A4500 ADA", ADA, 224, 2580),
+ GPUEntry("NVIDIA RTX A5000 ADA", ADA, 400, 2550),
+ GPUEntry("NVIDIA RTX A5880 ADA", ADA, 440, 2460),
+ GPUEntry("NVIDIA RTX A6000 ADA", ADA, 568, 2505),
+ GPUEntry("NVIDIA GeForce RTX 4090", ADA, 512, 2520),
+ GPUEntry("NVIDIA GeForce RTX 4080 SUPER", ADA, 320, 2550),
+ GPUEntry("NVIDIA GeForce RTX 4080", ADA, 304, 2505),
+ GPUEntry("NVIDIA GeForce RTX 4070 Ti SUPER", ADA, 264, 2610),
+ GPUEntry("NVIDIA GeForce RTX 4070 Ti", ADA, 240, 2610),
+ GPUEntry("NVIDIA GeForce RTX 4070 SUPER", ADA, 224, 2475),
+ GPUEntry("NVIDIA GeForce RTX 4070", ADA, 184, 2475),
+ GPUEntry("NVIDIA GeForce RTX 4060 Ti", ADA, 136, 2535),
+ GPUEntry("NVIDIA GeForce RTX 4060", ADA, 96, 2460),
+ GPUEntry("NVIDIA H100 PCIe", HOPPER, 456, 1620),
+ GPUEntry("NVIDIA H100 80GB HBM3", HOPPER, 528, 1830)
+]
+
+MFUH_PRECISION_FP32 = 0
+MFUH_PRECISION_FP16 = 1
+MFUH_PRECISION_BF16 = 2
+
+def get_flops_promised(device, precision_mode):
+ """
+ This function is used to estimate the Model Flops Utilization (MFU)
+ basically we have to figure out how many flops the GPU can do per second.
+ Note that this is not a simple endeavor and may well go wrong! The details are tricky.
+ The returned value is in units of 1e12.
+ """
+ if precision_mode not in [MFUH_PRECISION_FP32, MFUH_PRECISION_FP16, MFUH_PRECISION_BF16]:
+ print(f"Invalid precision mode: {precision_mode}")
+ return -1.0
+
+ for entry in gpu_db:
+ if entry.name == device:
+ perf_data = entry.perf_data
+
+ value = -1.0
+ if precision_mode == MFUH_PRECISION_BF16:
+ value = perf_data.BF_16_32
+ if precision_mode == MFUH_PRECISION_FP32:
+ value = perf_data.TF_32
+ if precision_mode == MFUH_PRECISION_FP16:
+ value = perf_data.FP_16_32
+
+ if value < 0.0:
+ print(f"No data for GPU {device} and precision mode {precision_mode}")
+ return -1.0
+
+ new_cores = entry.new_cores
+ new_mhz = entry.new_mhz
+ adjusted = value * (new_cores / perf_data.CORES) * (new_mhz / perf_data.CLOCK)
+ return adjusted
+
+ return -1.0
\ No newline at end of file
diff --git a/unidisc/utils/trainer_utils.py b/unidisc/utils/trainer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1fbfa0ae199672030a2293e8a25ee8073fb9d90
--- /dev/null
+++ b/unidisc/utils/trainer_utils.py
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+import os
+import shutil
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import wraps
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Any
+
+import torch
+import torch.nn as nn
+from accelerate import Accelerator
+from accelerate.state import PartialState
+from accelerate.utils import extract_model_from_parallel
+from image_utils import Im
+
+from decoupled_utils import is_main_process, gprint, rprint
+
+log_info = gprint
+log_error = gprint
+
+
+def load_from_ckpt(cfg, accelerator: Optional[Accelerator], model: nn.Module, load_model: bool, load_accelerator_state: bool = False) -> int:
+ """
+ Loads the model [or just returns the checkpoint global step]
+ """
+ if cfg.trainer.ckpt == "latest":
+ # Get the most recent checkpoint
+ dirs = os.listdir(cfg.checkpoint_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("_")[-1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+ else:
+ path = Path(cfg.trainer.ckpt)
+
+ if path.is_dir() and any(child.is_dir() and child.name == "state" for child in path.iterdir()):
+ path = path / "state"
+
+ if path is None:
+ log_error(f"Checkpoint '{cfg.trainer.ckpt}' does not exist. Exiting.")
+ raise FileNotFoundError
+ else:
+ log_info(f"Resuming from checkpoint {path}")
+
+ # TODO: @Tsung-Wei Ke tested this and found that it doesn't work, at least in some of the cases we used.
+ # We should see if we can still load the optimizer states.
+
+ # from accelerate.utils.modeling import load_checkpoint_in_model
+ # if path.is_file() or cfg.trainer.load_weights_only_no_state:
+ # load_checkpoint_in_model(model, str(path))
+ # else:
+
+ if load_model:
+ if accelerator is not None and path.parent.stem == "state" and load_accelerator_state:
+ log_info("Loading accelerator state!")
+ accelerator.load_state(path.parent)
+
+ state_dict = torch.load(path, map_location='cpu')
+ if cfg.trainer.ignore_clip_weights:
+ state_dict = {k:v for k,v in state_dict.items() if 'clip' not in k and 'mapper.position_embedding' not in k and 'up_proj' not in k}
+ if cfg.trainer.ignore_pos_emb_weights:
+ state_dict = {k:v for k,v in state_dict.items() if 'cross_attn_pos_emb' not in k}
+ model.load_state_dict(state_dict, strict=cfg.trainer.strict_load)
+ try:
+ if path.is_file():
+ global_step = int(path.parent.parent.name.split("_")[-1])
+ else:
+ global_step = int(path.name.split("_")[-1] if "_" in path.name else path.parent.name.split("_")[-1])
+ except:
+ log_error(f"Could not parse global step from checkpoint path {path}. Setting to 0.")
+ global_step = 0
+
+ # first_epoch = global_step // num_update_steps_per_epoch
+ first_epoch = 0
+ log_info(f"Continuing from epoch {first_epoch} and global step {global_step}")
+ return global_step
+
+
+def handle_checkpointing_dirs(cfg, prefix: str):
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if cfg.checkpointing.checkpoints_total_limit is not None:
+ if not os.path.exists(cfg.checkpointing.save_dir):
+ os.makedirs(cfg.checkpointing.save_dir, exist_ok=True)
+
+ checkpoints = os.listdir(cfg.checkpointing.save_dir)
+ checkpoints = [
+ d for d in checkpoints
+ if d.startswith(f"{prefix}_")
+ and len(os.listdir(os.path.join(cfg.checkpointing.save_dir, d))) >= 1
+ and sum(f.stat().st_size for f in Path(os.path.join(cfg.checkpointing.save_dir, d)).rglob('*')) > 10 * 1024 * 1024
+ ]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= cfg.checkpointing.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - cfg.checkpointing.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ log_info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
+ log_info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(cfg.checkpointing.save_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+
+@dataclass
+class TrainingState:
+ epoch_step: int # Step in the current epoch. Resets every epoch.
+ num_epoch_steps: int # Total number of steps in the current epoch. [E.g., dataloader size on a single GPU]
+ global_step: int # Current number of steps which does not reset.
+ true_step: int
+ epoch: int
+ split: Optional[Any] = None
+ current_run_global_step: Optional[int] = None
+ batch: Optional[dict] = None
+
+
+class Trainable(nn.Module, ABC):
+ @abstractmethod
+ def forward(self, batch: dict, state: TrainingState) -> dict:
+ ...
+
+ @abstractmethod
+ def set_training_mode(self):
+ ...
+
+ @abstractmethod
+ def set_inference_mode(self):
+ ...
+
+ @abstractmethod
+ def checkpoint(self, accelerator: Accelerator, state: TrainingState, path: Path):
+ ...
+
+ def run_inference(self, batch: dict, state: TrainingState, accelerator: Optional[Accelerator] = None) -> dict[str, Im]:
+ ...
+
+ def on_sync_gradients(self):
+ pass
+
+ def get_param_groups(self) -> Optional[dict[str, Any]]:
+ return None
+
+ def process_input(self, batch: dict) -> Any:
+ return batch
+
+
+def check_every_n_steps(
+ state: TrainingState,
+ n: Optional[int],
+ run_first: bool = False,
+ all_processes: bool = False,
+ decay_steps: bool = False,
+ max_eval_interval: Optional[int] = None,
+ decrease_n_runs: Optional[int] = None,
+):
+ if n is None or n <= 0: return False
+ if decay_steps:
+ max_eval_interval = max_eval_interval or n * 2
+ decrease_n_runs = decrease_n_runs or 5
+ n = min(n * ((state.global_step // (decrease_n_runs * n)) + 1), max_eval_interval)
+
+ return ((state.global_step % n == 0 or (state.current_run_global_step is not None and state.current_run_global_step == 0)) and (run_first or (state.global_step > 0 and state.current_run_global_step > 0))) and (is_main_process() or all_processes)
+
+
+def check_every_n_epochs(state: TrainingState, n: Optional[int], run_first: bool = False, all_processes: bool = False):
+ # Check if the current step is the last one in the epoch. We always want to run on the last step of the epoch. If we have n=5, then we run at the end of epochs 0 [if except_first == False], 5, 10, 15, etc.
+ return (
+ n is not None
+ and (state.epoch_step == state.num_epoch_steps - 1)
+ and ((state.epoch + 1) % n == 0 or (state.epoch == 0 and run_first))
+ and (is_main_process() or all_processes)
+ )
+
+
+def every_n_steps(func, *wrapper_args, **wrapper_kwargs):
+ @wraps(func)
+ def wrapper(state: TrainingState, *args, **kwargs):
+ if check_every_n_steps(state, *wrapper_args, **wrapper_kwargs):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def every_n_epochs(func, *wrapper_args, **wrapper_kwargs):
+ @wraps(func)
+ def wrapper(state: TrainingState, *args, **kwargs):
+ if check_every_n_epochs(state, *wrapper_args, **wrapper_kwargs):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def unwrap(model):
+ """
+ In DDP/torch.compile and some other situations, our nn.Module is wrapped so to access class attributes we often need to unwrap it.
+ """
+ # equiv to. unwrap
+ if PartialState._shared_state == {}:
+ # Accelerate is initialized
+ return extract_model_from_parallel(model)
+ else:
+ from torch.nn.parallel import DistributedDataParallel
+
+ if isinstance(model, DistributedDataParallel):
+ return model.module
+ else:
+ return model
+
+def linear_warmup(current_step: int, warmup_steps: int, final_value: float, initial_value: float = 0.0, start_step: int = 0):
+ current_step = max(0, current_step - start_step)
+ if current_step < warmup_steps:
+ return initial_value + (final_value - initial_value) * (current_step / max(1, warmup_steps))
+ else:
+ return final_value
+
+def get_parameters(module):
+ params = []
+
+ def recursive_collect(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Embedding) or child.__class__.__name__ == "EmbeddingLayer":
+ continue
+ else:
+ params.extend(list(child.parameters(recurse=False)))
+ recursive_collect(child)
+
+ recursive_collect(module)
+ return params
+
+def count_parameters(module):
+ return sum(p.numel() for p in get_parameters(module) if p.requires_grad)
+
+def incremental_dict_update(data, new_data):
+ data.update(new_data)
+ return data
+ for key, value in new_data.items():
+ if key in data and isinstance(value, torch.Tensor):
+ if value.numel() == 1:
+ data[key] = (data[key] + value) / 2
+ else:
+ data[key] = torch.cat([data[key], value])
+ else:
+ data[key] = value
+ return data
+
+if __name__ == "__main__":
+ for i in range(50000):
+ if check_every_n_steps(TrainingState(epoch_step=i, num_epoch_steps=10, global_step=i, epoch=0, true_step=i), 500, decay_steps=True):
+ print(i)
\ No newline at end of file
diff --git a/unidisc/utils/viz_utils.py b/unidisc/utils/viz_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e497cc037e412cb255022f6989f22f75900eba4
--- /dev/null
+++ b/unidisc/utils/viz_utils.py
@@ -0,0 +1,317 @@
+from PIL import Image, ImageDraw, ImageFont
+import textwrap
+
+
+def create_text_image(
+ text,
+ desired_width,
+ font_path="/usr/share/fonts/dejavu/DejaVuSansMono.ttf",
+ font_size=20,
+ text_color=(0, 0, 0),
+ bg_color=(255, 255, 255),
+ line_spacing=5
+):
+ """
+ Creates a Pillow Image with the given text wrapped to fit the desired width.
+
+ Parameters:
+ - text (str): The text to render on the image.
+ - desired_width (int): The width of the image in pixels.
+ - font_path (str): Path to the .ttf font file.
+ - font_size (int): Size of the font.
+ - text_color (tuple): RGB color tuple for the text.
+ - bg_color (tuple): RGB color tuple for the background.
+ - line_spacing (int): Space between lines in pixels.
+
+ Returns:
+ - Image: A Pillow Image object with the rendered text.
+ """
+ # Load the font
+ try:
+ font = ImageFont.truetype(font_path, font_size)
+ except IOError:
+ raise IOError(f"Font file not found at path: {font_path}")
+
+ # Create a dummy image to get drawing context
+ dummy_img = Image.new('RGB', (desired_width, 1))
+ draw = ImageDraw.Draw(dummy_img)
+
+ # Wrap text based on pixel width
+ lines = []
+ words = text.split()
+ if not words:
+ lines = ['']
+ else:
+ line = words[0]
+ for word in words[1:]:
+ test_line = f"{line} {word}"
+ # Use font.getlength instead of font.getsize
+ try:
+ line_width = font.getlength(test_line)
+ except AttributeError:
+ # Fallback for older Pillow versions
+ bbox = font.getbbox(test_line)
+ line_width = bbox[2] - bbox[0]
+
+ if line_width <= desired_width:
+ line = test_line
+ else:
+ lines.append(line)
+ line = word
+ lines.append(line)
+
+ # Calculate the height required for the text
+ ascent, descent = font.getmetrics()
+ line_height = ascent + descent + line_spacing
+ img_height = line_height * len(lines) + line_spacing
+
+ # Create the final image
+ img = Image.new('RGB', (desired_width, img_height), color=bg_color)
+ draw = ImageDraw.Draw(img)
+
+ # Draw each line of text
+ y_text = line_spacing
+ for line in lines:
+ draw.text((0, y_text), line, font=font, fill=text_color)
+ y_text += line_height
+
+ return img
+
+
+
+
+"""
+import os
+import random
+from PIL import Image, ImageEnhance
+import numpy as np
+import fiftyone as fo
+import fiftyone.zoo as foz
+from glob import glob
+
+coco_categories = [
+ "car", "motorcycle", "airplane", "bus", "train", "truck",
+ "boat", "fire hydrant", "cat", "dog", "horse",
+ "elephant", "bear", "zebra", "bowl", "banana",
+ "pizza", "couch", "bed",
+]
+
+dataset = foz.load_zoo_dataset(
+ "coco-2017",
+ split="validation", # Options: "train", "validation", "test"
+ label_types=["detections", "segmentations"],
+ classes=coco_categories,
+ max_samples=100 # Adjust as needed
+)
+"""
+
+def extract_objects_coco(dataset, output_dir, mask_ratio_threshold=0.7):
+ # extracted_objects_dir = "./objects"
+ # os.makedirs(extracted_objects_dir, exist_ok=True)
+ # extract_objects_coco(dataset, extracted_objects_dir)
+ """
+ Extracts objects from the COCO dataset based on segmentation masks.
+
+ Parameters:
+ - dataset: The FiftyOne dataset to process.
+ - output_dir: Directory where extracted objects will be saved.
+ - mask_ratio_threshold: Minimum ratio of mask area to bounding box area to include the object.
+ """
+ for sample in dataset:
+ # Load the original image
+ try:
+ image = Image.open(sample.filepath).convert("RGBA")
+ except Exception as e:
+ print(f"Error loading image {sample.filepath}: {e}")
+ continue
+
+ width, height = image.size
+
+ # Access segmentations
+ if not hasattr(sample, 'segmentations') or sample.segmentations is None:
+ continue # Skip samples without segmentations
+
+ segmentations = sample.segmentations.detections
+
+ for idx, segmentation in enumerate(segmentations):
+ # Get the mask as a boolean numpy array
+ mask_array = segmentation.mask # Shape: (mask_height, mask_width), dtype: bool
+
+ if mask_array is None:
+ continue # Skip if mask is not available
+
+ # Get the bounding box in absolute pixel coordinates
+ bbox = segmentation.bounding_box # [x_min, y_min, width, height] in relative coords
+ x_min = int(bbox[0] * width)
+ y_min = int(bbox[1] * height)
+ bbox_width = int(bbox[2] * width)
+ bbox_height = int(bbox[3] * height)
+
+ # Calculate mask area and bounding box area
+ mask_area = np.sum(mask_array)
+ bbox_area = bbox_width * bbox_height
+
+ if bbox_area == 0:
+ print(f"Bounding box has zero area for sample {sample.id}, detection {idx}. Skipping.")
+ continue
+
+ mask_ratio = mask_area / bbox_area
+
+ if mask_ratio < mask_ratio_threshold:
+ print(f"Mask ratio {mask_ratio:.2f} below threshold for sample {sample.id}, detection {idx}. Skipping.")
+ continue # Skip masks that don't meet the area ratio threshold
+
+ # Ensure the mask size matches the bounding box size
+ mask_height, mask_width = mask_array.shape
+ if (mask_width, mask_height) != (bbox_width, bbox_height):
+ print(f"Mask size {mask_array.shape} does not match bounding box size {(bbox_height, bbox_width)} for sample {sample.id}. Resizing mask.")
+ # Resize the mask to match the bounding box dimensions
+ mask_image = Image.fromarray(mask_array.astype(np.uint8) * 255, mode='L')
+ mask_image = mask_image.resize((bbox_width, bbox_height), Image.NEAREST)
+ mask_array = np.array(mask_image) > 0
+ else:
+ # Convert boolean mask to uint8
+ mask_uint8 = (mask_array * 255).astype(np.uint8)
+ # Create a PIL Image from the mask
+ mask_image = Image.fromarray(mask_uint8, mode='L')
+
+ # Create a full-sized mask and paste the object mask into it
+ full_mask = Image.new("L", (width, height))
+ try:
+ full_mask.paste(mask_image, (x_min, y_min))
+ except ValueError as ve:
+ print(f"Error pasting mask for sample {sample.id}, detection {idx}: {ve}")
+ continue
+
+ # Create an RGBA image for the object with transparency
+ object_image = Image.new("RGBA", (width, height))
+ try:
+ object_image.paste(image, mask=full_mask)
+ except ValueError as ve:
+ print(f"Error pasting image with mask for sample {sample.id}, detection {idx}: {ve}")
+ continue
+
+ # Calculate absolute bounding box coordinates
+ x_max = x_min + bbox_width
+ y_max = y_min + bbox_height
+
+ # Ensure coordinates are within image boundaries
+ x_min = max(x_min, 0)
+ y_min = max(y_min, 0)
+ x_max = min(x_max, width)
+ y_max = min(y_max, height)
+
+ diff_y = y_max - y_min
+ diff_x = x_max - x_min
+ if diff_x < 128 or diff_y < 128:
+ continue # Skip if the bounding box is too small
+
+ # Crop the object to its bounding box
+ object_crop = object_image.crop((x_min, y_min, x_max, y_max))
+
+ # Optional: Further crop using the mask to tightly bound the object
+ cropped_mask = np.array(full_mask)[y_min:y_max, x_min:x_max]
+ if not np.any(cropped_mask):
+ print(f"Empty mask after cropping for sample {sample.id}, detection {idx}. Skipping.")
+ continue # Skip if the mask is empty after cropping
+
+ # Find the bounding box of the non-zero regions in the cropped mask
+ ys, xs = np.where(cropped_mask)
+ if len(xs) == 0 or len(ys) == 0:
+ print(f"No non-zero pixels found in mask for sample {sample.id}, detection {idx}. Skipping.")
+ continue # Skip if no non-zero pixels
+
+ tight_x_min = xs.min()
+ tight_y_min = ys.min()
+ tight_x_max = xs.max()
+ tight_y_max = ys.max()
+
+ # Further crop the image
+ object_crop = object_crop.crop((tight_x_min, tight_y_min, tight_x_max + 1, tight_y_max + 1))
+
+ # Save the object image
+ object_class = segmentation.label.replace(" ", "_")
+ object_filename = f"{object_class}_{sample.id}_{idx}.png"
+ object_filepath = os.path.join(output_dir, object_filename)
+ try:
+ object_crop.save(object_filepath)
+ print(f"Saved object {object_filepath}")
+ except Exception as e:
+ print(f"Error saving object {object_filepath}: {e}")
+ continue
+
+def augment_image_with_random_object_coco(original_image, extracted_objects_dir):
+ import os
+ import random
+ from PIL import Image, ImageEnhance
+ import numpy as np
+ from glob import glob
+
+ try:
+ if isinstance(original_image, str):
+ original_image = Image.open(original_image).convert('RGBA')
+ elif isinstance(original_image, Image.Image):
+ original_image = original_image.convert('RGBA')
+ else:
+ raise ValueError(f"Unsupported type for original_image: {type(original_image)}")
+ except Exception as e:
+ print(f"Error loading original image {original_image}: {e}")
+ return
+
+ width, height = original_image.size
+
+ # Get a list of extracted object images
+ object_image_paths = glob(os.path.join(extracted_objects_dir, '*.png'))
+
+ if not object_image_paths:
+ print("No extracted object images found.")
+ return
+
+ # Choose a random object image
+ object_image_path = random.choice(object_image_paths)
+ try:
+ object_image = Image.open(object_image_path).convert('RGBA')
+ except Exception as e:
+ print(f"Error loading object image {object_image_path}: {e}")
+ return
+
+ # Resize object image
+ obj_width, obj_height = object_image.size
+
+ # Modify the scaling logic to ensure object fits
+ max_scale = min(width / obj_width, height / obj_height, 0.8) # Never scale larger than 70%
+ min_scale = min(0.3, max_scale) # Use either 0.2 or max_scale, whichever is smaller
+
+ if max_scale < 0.3: # If even 20% is too big, try another object
+ print("Selected object too large for image, choosing another...")
+ return None # Return None to indicate we need to try again
+
+ scaling_factor = random.uniform(min_scale, max_scale)
+ new_obj_width = int(obj_width * scaling_factor)
+ new_obj_height = int(obj_height * scaling_factor)
+ object_image = object_image.resize((new_obj_width, new_obj_height), Image.LANCZOS)
+
+ # Optional: Adjust brightness for better blending
+ brightness_factor = random.uniform(0.8, 1.2) # Slightly darken or brighten
+ enhancer = ImageEnhance.Brightness(object_image)
+ object_image = enhancer.enhance(brightness_factor)
+
+ # Randomize position
+ max_x = width - new_obj_width
+ max_y = height - new_obj_height
+ position = (random.randint(0, max_x), random.randint(0, max_y))
+
+ # Overlay object image
+ composite_image = original_image.copy()
+ try:
+ composite_image.paste(object_image, position, object_image)
+ except ValueError as ve:
+ print(f"Error pasting object onto original image: {ve}")
+ return
+
+ return composite_image.convert('RGB')
+
+# # Example usage
+# original_image_path = '00e36460e7a9adde.jpg' # Replace with your image path
+# output_image_path = 'fixed.png'
+# augment_image_with_random_object_coco(original_image_path, extracted_objects_dir, output_image_path)
diff --git a/unidisc/utils/xla_utils.py b/unidisc/utils/xla_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ddd777bdf7554514951fa89bc244a72954acbe8
--- /dev/null
+++ b/unidisc/utils/xla_utils.py
@@ -0,0 +1,293 @@
+from decoupled_utils import is_torch_xla_available, rprint, gprint
+import torch
+from functools import partial
+
+# Wrap the base model with an outer FSDP wrapper
+def shard_output(output, mesh):
+ import torch_xla.distributed.spmd as xs
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ real_output = None
+ if isinstance(output, torch.Tensor):
+ real_output = output
+ elif isinstance(output, tuple):
+ real_output = output[0]
+ elif isinstance(output, CausalLMOutputWithPast):
+ real_output = output.logits
+
+ xs.mark_sharding(real_output, mesh, (('dcn', 'fsdp'), None, None))
+
+
+from typing import Set, Type
+import torch.nn as nn
+
+def _module_wrap_policy(
+ module: nn.Module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ module_classes: Set[Type[nn.Module]],
+ min_num_params: int,
+) -> bool:
+ """
+ This auto wrap policy wraps every module that is an instance of any type in
+ ``module_classes`` as its own FSDP instance. The root module given by
+ ``module`` is always wrapped as an FSDP instance regardless. Since the
+ wrapping proceeds bottom up, each FSDP instance manages the parameters in
+ its subtree excluding any already managed by a child FSDP instance.
+
+ Args:
+ module (nn.Module): Current module being considered.
+ recurse (bool): If ``False``, then this function must decide whether
+ ``module`` should be wrapped as an FSDP instance or not. If
+ ``True``, then the function is still recursing down the module
+ tree as a part of the DFS.
+ nonwrapped_numel (int): Parameter numel not yet wrapped.
+ module_classes (Set[Type[nn.Module]]): Set of module classes that are
+ wrapped as FSDP instances.
+
+ Returns:
+ ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
+ if ``recurse=False``.
+ """
+
+ print(f"Found {module.__class__.__name__} with {nonwrapped_numel} parameters; we have min_num_params={min_num_params}")
+
+ if recurse and nonwrapped_numel >= min_num_params:
+ print(f"Recursing down {module.__class__.__name__}")
+ return True # always recurse
+
+ if nonwrapped_numel >= min_num_params and isinstance(module, tuple(module_classes)):
+ print(f"Wrapping {module.__class__.__name__}")
+ return isinstance(module, tuple(module_classes)) and nonwrapped_numel >= min_num_params
+
+def transformer_auto_wrap_policy(
+ module: nn.Module,
+ recurse: bool,
+ unwrapped_params: int,
+ transformer_layer_cls: Set[Type[nn.Module]],
+ min_num_params: int = int(1e6)
+) -> bool:
+ """
+ See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
+ same as ``module_classes``. Note that shared parameters must be wrapped in
+ the same FSDP instance, so this auto wrap policy can help wrap shared
+ embeddings into the same FSDP instance for transformer models.
+ """
+ return _module_wrap_policy(module, recurse, unwrapped_params, transformer_layer_cls, min_num_params)
+
+# Taken from HF Transformers
+def wrap_xla_fsdp(config, model):
+ import torch_xla.distributed.spmd as spmd
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.runtime as xr
+ import torch_xla
+
+ is_fsdp_xla_v2_enabled = True
+ try:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+ from torch_xla.distributed.fsdp import checkpoint_module
+ from torch_xla.distributed.fsdp.wrap import (
+ size_based_auto_wrap_policy,
+ transformer_auto_wrap_policy
+ )
+
+ if is_fsdp_xla_v2_enabled:
+ from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
+ SpmdFullyShardedDataParallel as FSDPv2,
+ )
+ except ImportError:
+ raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+
+ # See: https://github.com/pytorch-tpu/transformers/blob/alanwaketan/flash_attention/examples/pytorch/language-modeling/run_clm.py
+ auto_wrap_policy = None
+ auto_wrapper_callable = None
+ fsdp_transformer_layer_cls_to_wrap = ["DDiTBlock", "ChameleonDecoderLayer", "OpenELMDecoderLayer"]
+
+ if getattr(config.trainer, "fsdp_size_based_auto_wrap", False):
+ auto_wrap_policy = partial(
+ size_based_auto_wrap_policy,
+ min_num_params=int(1e6),
+ )
+ elif fsdp_transformer_layer_cls_to_wrap is not None:
+ from transformers.trainer_pt_utils import get_module_class_from_name
+ transformer_cls_to_wrap = set()
+ found_valid_layer = False
+ for layer_class in fsdp_transformer_layer_cls_to_wrap:
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is not None:
+ transformer_cls_to_wrap.add(transformer_cls)
+ rprint(f"Found valid layer: {layer_class}")
+ found_valid_layer = True
+
+ if not found_valid_layer:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+
+ auto_wrap_policy = partial(
+ transformer_auto_wrap_policy,
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+
+ gradient_checkpointing = config.trainer.use_gradient_checkpointing
+ if gradient_checkpointing:
+ def auto_wrapper_callable(m, *args, **kwargs):
+ target_cls = FSDP if not is_fsdp_xla_v2_enabled else FSDPv2
+ return target_cls(checkpoint_module(m), *args, **kwargs)
+
+ patch_xla_linear = getattr(config.trainer, "patch_xla_linear", False)
+ if patch_xla_linear:
+ rprint("WARNING!!!! Patching XLA Linear")
+ from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
+ model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
+
+ # Patch `xm.optimizer_step` should not reduce gradients in this case,
+ # as FSDP does not need gradient reduction over sharded parameters.
+ patch_xla_optimizer_step = getattr(config.trainer, "patch_xla_optimizer_step", True)
+ if patch_xla_optimizer_step:
+ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ xm.mark_step()
+ return loss
+ xm.optimizer_step = patched_optimizer_step
+
+ custom_spmd_wrap = getattr(config.trainer, "custom_spmd_wrap", False)
+ custom_chameleon_spmd_wrap = getattr(config.trainer, "custom_chameleon_spmd_wrap", False)
+ auto_spmd = getattr(config.trainer, "auto_spmd", False)
+
+ if auto_spmd:
+ pass
+ elif custom_spmd_wrap:
+ # Replace the meta tensor parameter with the initialized XLA tensor
+ # Shard each parameter in the model based on the sharding strategy provided.
+ gprint("Custom SPMD wrap enabled")
+ spmd_mesh = xs.get_global_mesh()
+ spmd_fsdp_sharding = True
+ spmd_2d_sharding = 0
+ for name, param in model.named_parameters():
+ if spmd_fsdp_sharding:
+ print('> [FSDP] Sharding tensor', name, param.shape, param.dtype)
+ # We don't care about layernorm's weights, and
+ # LLaMA doesn't use biases.
+ if len(param.shape) == 1:
+ continue
+ assert len(param.shape) == 2
+
+ # Shard the largest dimension
+ if param.shape[0] > param.shape[1]:
+ partition_spec = ('fsdp', None)
+ else:
+ partition_spec = (None, 'fsdp')
+ xs.mark_sharding(param, spmd_mesh, partition_spec)
+ elif spmd_2d_sharding > 0:
+ # Apply 2D sharding:
+ print('> [2D] Sharding tensor', name, param.shape)
+
+ # We don't care about layernorm's weights, and
+ # LLaMA doesn't use biases.
+ if len(param.shape) == 1:
+ continue
+
+ if 'embed_tokens' in name:
+ xs.mark_sharding(param, spmd_mesh, ('model', 'fsdp'))
+ elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
+ xs.mark_sharding(param, spmd_mesh, ('fsdp', 'model'))
+ elif 'o_proj' in name:
+ xs.mark_sharding(param, spmd_mesh, ('model', 'fsdp'))
+ elif 'gate_proj' in name or 'up_proj' in name:
+ xs.mark_sharding(param, spmd_mesh, ('model', 'fsdp'))
+ elif 'down_proj' in name:
+ xs.mark_sharding(param, spmd_mesh, ('fsdp', 'model'))
+ elif 'lm_head' in name:
+ xs.mark_sharding(param, spmd_mesh, ('model', 'fsdp'))
+
+ print(f'{name} {torch_xla._XLAC._get_xla_sharding_spec(param)}')
+
+ for i in range(len(model.model.blocks)):
+ spmd.xla_sharding.apply_backward_optimization_barrier(model.blocks[i])
+
+ elif custom_chameleon_spmd_wrap:
+ # Replace the meta tensor parameter with the initialized XLA tensor
+ # Shard each parameter in the model based on the sharding strategy provided.
+ gprint("Custom Chameleon SPMD wrap enabled")
+ spmd_mesh = xs.get_global_mesh()
+ spmd_fsdp_sharding = True
+ spmd_2d_sharding = 0
+ for name, param in model.named_parameters():
+ # We don't care about layernorm's weights, and
+ # LLaMA doesn't use biases.
+ if len(param.shape) == 1:
+ gprint(f"Skipping shard of {name} with {param.numel()} elements because it is 1D, shape: {param.shape}")
+ continue
+
+ if param.requires_grad is False and getattr(config.trainer, "no_shard_grad_false", False):
+ gprint(f"Skipping shard of {name} with {param.numel()} elements because requires_grad is False, shape: {param.shape}")
+ continue
+
+ if param.numel() < int(1e6) and getattr(config.trainer, "no_shard_small", False):
+ gprint(f"Skipping shard of {name} with {param.numel()} elements, shape: {param.shape}")
+ continue
+
+ assert len(param.shape) == 2
+
+ print('> [FSDP] Sharding tensor', name, param.shape, param.dtype)
+
+ # Shard the largest dimension
+ if param.shape[0] > param.shape[1]:
+ partition_spec = ('fsdp', None)
+ else:
+ partition_spec = (None, 'fsdp')
+
+ xs.mark_sharding(param, spmd_mesh, partition_spec)
+ print(f'{name} {torch_xla._XLAC._get_xla_sharding_spec(param)}')
+
+ gprint(f"Model: {model.__class__.__name__}")
+ for block in (model.base_model.model.model.layers if config.model.use_lora else model.model.layers):
+ gprint(f"Applying barrier to {block.__class__.__name__}")
+ spmd.xla_sharding.apply_backward_optimization_barrier(block)
+
+ else:
+ gprint(f"Using FSDPv2, {xs.get_global_mesh()}")
+ model = FSDPv2(
+ model,
+ shard_output=shard_output,
+ auto_wrap_policy=auto_wrap_policy,
+ auto_wrapper_callable=auto_wrapper_callable,
+ )
+
+ for name, param in model.named_parameters():
+ if param.requires_grad is False or param.numel() < int(1e6):
+ xs.clear_sharding(param)
+ xs.mark_sharding(param, xs.get_global_mesh(), tuple([None] * len(param.shape)))
+
+ if torch_xla._XLAC._get_xla_sharding_spec(param) != "":
+ gprint(f'Sharding {name} {param.shape} requires_grad={param.requires_grad} numel={param.numel()} {torch_xla._XLAC._get_xla_sharding_spec(param)}')
+
+ return model
+
+
+def tpu_spmd_dataloader(dataloader, device):
+ if is_torch_xla_available():
+ import torch_xla.distributed.spmd as xs
+ import torch_xla
+ from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
+ sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), (('dcn', 'fsdp'), None))
+ if isinstance(dataloader, MpDeviceLoader):
+ rprint("Modifying existing MpDeviceLoader")
+ dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
+ else:
+ rprint("Creating MpDeviceLoader")
+ rprint(f"Drop Last: {dataloader.drop_last}")
+ loader = MpDeviceLoader(
+ dataloader,
+ device=torch_xla.device(),
+ input_sharding=sharding_spec,
+ )
+ loader.dataset = dataloader.dataset
+ dataloader = loader
+ return dataloader
+ else:
+ return dataloader
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..571c8390ba62a49c55a852992612c03bcfd09e6e
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,527 @@
+"""Console logger utilities.
+
+Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
+Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
+"""
+
+import math
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import fsspec
+import torch
+from timm.scheduler import CosineLRScheduler
+import omegaconf
+import rich
+import rich.syntax
+import rich.tree
+
+from decoupled_utils import rank_zero_fn, rprint
+from decoupled_utils import (get_hostname, get_num_devices, get_tpu_devices, gprint,
+ is_torch_cuda_available, is_torch_xla_available, rprint)
+
+
+def print_trainable_parameters(model):
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
+ if param.requires_grad:
+ trainable_params += param.numel()
+ print(
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+ )
+
+def fsspec_exists(filename):
+ """Check if a file exists using fsspec."""
+ fs, _ = fsspec.core.url_to_fs(filename)
+ return fs.exists(filename)
+
+
+def fsspec_listdir(dirname):
+ """Listdir in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ return fs.ls(dirname)
+
+
+def fsspec_mkdirs(dirname, exist_ok=True):
+ """Mkdirs in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ fs.makedirs(dirname, exist_ok=exist_ok)
+
+
+def print_nans(tensor, name):
+ if torch.isnan(tensor).any():
+ gprint(f"{name} has nans: {tensor}")
+
+
+class CosineDecayWarmupLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
+ """Wrap timm.scheduler.CosineLRScheduler
+ Enables calling scheduler.step() without passing in epoch.
+ Supports resuming as well.
+ Adapted from:
+ https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._last_epoch = -1
+ self.step(epoch=0)
+
+ def step(self, epoch=None):
+ if epoch is None:
+ self._last_epoch += 1
+ else:
+ self._last_epoch = epoch
+ # We call either step or step_update, depending on
+ # whether we're using the scheduler every epoch or every
+ # step.
+ # Otherwise, lightning will always call step (i.e.,
+ # meant for each epoch), and if we set scheduler
+ # interval to "step", then the learning rate update will
+ # be wrong.
+ if self.t_in_epochs:
+ super().step(epoch=self._last_epoch)
+ else:
+ super().step_update(num_updates=self._last_epoch)
+
+
+class Sampler:
+ def __init__(self, shape):
+ self.shape = shape
+
+ def _sampling_noise(self):
+ pass
+
+ def _hard_sample(self, logits):
+ pass
+
+ def _soft_sample(self, logits):
+ return 0
+
+ def sample(self, logits):
+ noise = self._sampling_noise()
+ noise = noise[: logits.shape[0], :]
+ logits = logits + noise.to(dtype=logits.dtype, device=logits.device)
+ hard_sample = self._hard_sample(logits)
+ soft_sample = self._soft_sample(logits)
+ return soft_sample + (hard_sample - soft_sample).detach()
+
+
+class TopKSampler(Sampler):
+ def __init__(self, k, shape, gamma_tau=1.0):
+ super().__init__(shape)
+ self.k = k
+ self.gamma_tau = gamma_tau
+ self.num_betas = 10
+ self.sampler = torch.distributions.gamma.Gamma(1 / k * torch.ones(self.num_betas, *self.shape), 1.0)
+
+ def _sampling_noise(self):
+ noise = self.sampler.sample()
+ beta = self.k / torch.arange(1, self.num_betas + 1, 1, dtype=torch.float32)
+ beta = beta[:, None, None]
+ assert beta.ndim == noise.ndim
+ s = noise / beta
+ s = torch.sum(s, axis=0)
+ s = s - math.log(10.0)
+ s = self.gamma_tau * (s / self.k)
+ return s
+
+ def _hard_sample(self, logits):
+ assert logits.ndim == 2
+ thresholds, _ = torch.sort(logits, dim=-1)
+ thresholds = thresholds[:, -self.k][:, None]
+ return (logits >= thresholds).type(logits.dtype)
+
+ def _soft_sample(self, logits):
+ soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True)
+ return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True)
+
+
+class DeterministicTopK(TopKSampler):
+ def __init__(self, k):
+ super().__init__(k, shape=(1, 1))
+
+ def _sampling_noise(self):
+ return 0
+
+ def discreize(self, x):
+ hard_sample = self._hard_sample(x)
+ soft_sample = self._soft_sample(x)
+ return soft_sample + (hard_sample - soft_sample).detach()
+
+
+class GumbelSampler(Sampler):
+
+ def __init__(self, shape, temperature=1.0):
+ super().__init__(shape)
+ self.temperature = temperature
+
+ def _sampling_noise(self):
+ return -(1e-10 - (torch.rand(*self.shape) + 1e-10).log()).log()
+
+ def _hard_sample(self, logits):
+ assert logits.ndim == 2
+ indices = torch.argmax(logits, dim=-1)
+ zeros = logits * 0
+ ones = torch.ones_like(logits[:, :, :1])
+ return torch.scatter(zeros, -1, indices[:, :, None], ones)
+
+ def _soft_sample(self, logits):
+ return torch.nn.functional.softmax(logits / self.temperature, dim=-1)
+
+
+class BinarySampler(GumbelSampler):
+
+ def sample(self, probs):
+ # TODO(subhamsahoo): use the temperature parameter.
+ pos_noise = self._sampling_noise().to(dtype=probs.dtype, device=probs.device)
+ neg_noise = self._sampling_noise().to(dtype=probs.dtype, device=probs.device)
+ del_noise_exp = (neg_noise - pos_noise).exp()
+ hard_sample = (probs * (1 + del_noise_exp) > 1).to(probs.dtype)
+ soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
+ return soft_sample + (hard_sample - soft_sample).detach()
+
+
+class GaussianSampler:
+ def __init__(self):
+ self.softplus = torch.nn.Softplus()
+
+ def sample(self, x):
+ assert x.ndim == 2
+ n = x.shape[-1] // 2
+ mu = x[:, :n]
+ sigma = self.softplus(x[:, n:]).sqrt()
+ return mu + sigma * torch.randn_like(mu)
+
+
+def is_global_rank_zero():
+ """Helper function to determine if the current process is global_rank 0 (the main process)"""
+ # Try to get the pytorch RANK env var
+ # RANK is set by torch.distributed.launch
+ rank = os.environ.get("RANK", None)
+ if rank is not None:
+ return rank == 0
+
+ # Try to get the SLURM global rank env var
+ # SLURM_PROCID is set by SLURM
+ slurm_rank = os.environ.get("SLURM_PROCID", None)
+ if slurm_rank is not None:
+ return slurm_rank == 0
+
+ # Try to get the MPI global rank env var
+ mpi_rank = os.environ.get("OMPI_COMM_WORLD_RANK", None)
+ if mpi_rank is not None:
+ return mpi_rank == 0
+
+ # if neither pytorch, SLURM nor MPI env vars are set
+ # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars
+ # assume global_rank is zero if undefined
+ node_rank = os.environ.get("NODE_RANK", os.environ.get("GROUP_RANK", 0))
+ local_rank = os.environ.get("LOCAL_RANK", 0)
+ return node_rank == 0 and local_rank == 0
+
+
+def get_rank():
+ """Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise it returns 0."""
+
+ if is_global_rank_zero():
+ return 0
+ else:
+ return torch.distributed.get_rank()
+
+
+def set_numa_affinity(gpu_index, verbose=False):
+ import pynvml as nvml
+
+ nvml.nvmlInit()
+ """This util will assign to the current process the cpu cores set that resides on the same NUMA
+ node as the GPU. Typically if you have 8 GPUs, then the first 4 are on the first NUMA node and
+ the remaining 4 are on the second.
+
+ `gpu_index` is typically the same as `LOCAL_RANK` in the distributed training, but beware that
+ `CUDA_VISIBLE_DEVICES` could impact that. e.g. `CUDA_VISIBLE_DEVICES=0,7` won't do the right
+ thing - then you will probably want to remap the ids with something like:
+
+ ```
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
+ gpu_index = ids[gpu_index] # remap
+ ```
+
+ """
+
+ num_elements = math.ceil(os.cpu_count() / 64)
+ handle = nvml.nvmlDeviceGetHandleByIndex(gpu_index)
+ affinity_string = ""
+ for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
+ # assume nvml returns list of 64 bit ints
+ affinity_string = f"{j:064b}{affinity_string}"
+ affinity_list = [int(x) for x in affinity_string]
+ affinity_list.reverse() # so core 0 is the 0th element
+ affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]
+
+ if verbose:
+ cores = os.sched_getaffinity(0)
+ gprint(f"before: {len(cores)} visible cpu cores: {cores}")
+
+ try:
+ os.sched_setaffinity(0, affinity_to_set)
+ except Exception as e:
+ gprint(f"Failed to set affinity: {e}")
+
+ if verbose:
+ cores = os.sched_getaffinity(0)
+ gprint(f"after: {len(cores)} visible cpu cores: {cores}")
+
+
+from typing import Dict, Union
+
+import torch
+from torch.nn import Module
+
+
+def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]:
+ """Compute each parameter's gradient's norm and their overall norm.
+
+ The overall norm is computed over all gradients together, as if they
+ were concatenated into a single vector.
+
+ Args:
+ module: :class:`torch.nn.Module` to inspect.
+ norm_type: The type of the used p-norm, cast to float if necessary.
+ Can be ``'inf'`` for infinity norm.
+ group_separator: The separator string used by the logger to group
+ the gradients norms in their own subfolder instead of the logs one.
+
+ Return:
+ norms: The dictionary of p-norms of each parameter's gradient and
+ a special entry for the total p-norm of the gradients viewed
+ as a single vector.
+
+ """
+ norm_type = float(norm_type)
+ if norm_type <= 0:
+ raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
+
+ norms = {f"{group_separator}{name}": p.grad.data.norm(norm_type) for name, p in module.named_parameters() if p.grad is not None}
+ total_norm = torch.tensor(list(norms.values())).norm(norm_type)
+ return norms, total_norm
+
+has_set_omega_conf_resolvers = False
+
+def set_omega_conf_resolvers():
+ global has_set_omega_conf_resolvers
+ if has_set_omega_conf_resolvers:
+ return
+ has_set_omega_conf_resolvers = True
+ import omegaconf
+ from omegaconf import OmegaConf
+
+ def get_dir_name(_root_):
+ if str(_root_.mode) == "eval":
+ return "eval"
+ elif _root_.debug:
+ return "debug"
+ else:
+ return _root_.data.train
+
+ def getpythoncmd(_root_):
+ return _root_.python_orig + "--multi_gpu \\\n" if (_root_.trainer.devices * _root_.trainer.num_nodes > 1) else _root_.python_orig
+
+ def custom_batch_size():
+ if is_torch_cuda_available() and torch.cuda.get_device_properties(0).total_memory >= 23 * 1024 * 1024 * 1024:
+ return 64
+ elif is_torch_cuda_available() and torch.cuda.get_device_properties(0).total_memory >= 10 * 1024 * 1024 * 1024:
+ return 32
+ else:
+ return 28
+
+ def get_slurm_name(_root_):
+ return _root_.slurm_name if hasattr(_root_, "slurm_name") and _root_.slurm_name is not None else _root_.wandb.project
+
+ partition_time_limit_min = {
+ "partition_name": 60 * 6,
+ }
+
+ gpu_constraints = {
+ "cluster_name": "gpu_constraints", # e.g. "A5000|A6000"
+ }
+
+ partitions = {
+ "cluster_name": "partition_name",
+ }
+
+
+ babel_exclude_nodes = set()
+ if os.environ.get("BAD_NODES", None) is not None:
+ babel_exclude_nodes.update(os.environ.get("BAD_NODES").split(","))
+
+ exclude_nodes = {
+ "cluster_name": "nodes_to_exclude",
+ }
+
+ def get_hostname_split():
+ return get_hostname().split("-")[0].split(".")[0]
+
+ omegaconf.OmegaConf.register_new_resolver("getpythoncmd", getpythoncmd)
+ omegaconf.OmegaConf.register_new_resolver("get_dir_name", get_dir_name)
+ omegaconf.OmegaConf.register_new_resolver("cwd", os.getcwd)
+ omegaconf.OmegaConf.register_new_resolver("device_count", get_num_devices)
+ omegaconf.OmegaConf.register_new_resolver("eval", eval)
+ omegaconf.OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)
+ omegaconf.OmegaConf.register_new_resolver("find_grad_accum", lambda x, y: round(x / y))
+ omegaconf.OmegaConf.register_new_resolver("find_partition", lambda: partitions[get_hostname_split()] if get_hostname_split() in partitions else "all")
+ omegaconf.OmegaConf.register_new_resolver("find_constraint", lambda: gpu_constraints[get_hostname_split()] if get_hostname_split() in gpu_constraints else "")
+ omegaconf.OmegaConf.register_new_resolver("is_ar", lambda parameterization: parameterization == "ar")
+ omegaconf.OmegaConf.register_new_resolver("kv_cache_batch_size", lambda eval_batch_size, cfg: eval_batch_size * 2 if cfg is not None else eval_batch_size)
+ omegaconf.OmegaConf.register_new_resolver("exclude_nodes", lambda: exclude_nodes[get_hostname_split()] if get_hostname_split() in exclude_nodes else "")
+ omegaconf.OmegaConf.register_new_resolver("get_slurm_name", get_slurm_name)
+
+
+ def adjust_n_blocks(_root_):
+ return (
+ (_root_.model.base_n_blocks - 1 if _root_.model.base_n_blocks < 24 else _root_.model.base_n_blocks - 2)
+ if str(_root_.backbone) == "maskdit"
+ else _root_.model.base_n_blocks
+ )
+
+ omegaconf.OmegaConf.register_new_resolver("adjust_n_blocks", adjust_n_blocks)
+ omegaconf.OmegaConf.register_new_resolver("partition_limit", lambda x: partition_time_limit_min[x] if x in partition_time_limit_min else 60 * 6)
+ omegaconf.OmegaConf.register_new_resolver("custom_batch_size", custom_batch_size)
+ omegaconf.OmegaConf.register_new_resolver("get_repo_dir", lambda: os.getenv("UNIDISC_DIR", str(Path(__file__).parent)))
+
+
+@rank_zero_fn
+def _print_config(config, resolve: bool = True, save_cfg: bool = True) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ config (DictConfig): Configuration composed by Hydra.
+ resolve (bool): Whether to resolve reference fields of DictConfig.
+ save_cfg (bool): Whether to save the configuration tree to a file.
+ """
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ fields = config.keys()
+ for field in fields:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_section = config.get(field)
+ branch_content = str(config_section)
+ if isinstance(config_section, omegaconf.DictConfig):
+ branch_content = omegaconf.OmegaConf.to_yaml(config_section, resolve=resolve)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ rich.print(tree)
+ if save_cfg:
+ with fsspec.open("config_tree.txt", "w") as fp:
+ rich.print(tree, file=fp)
+
+def set_torch_defaults(benchmark=True):
+ torch.set_float32_matmul_precision("medium")
+ if is_torch_cuda_available():
+ rprint(f"Setting torch defaults")
+ exec("import torch.backends.cuda as cuda")
+ exec("import torch.backends.cudnn as cudnn")
+ exec("cudnn.enabled = True")
+ if benchmark:
+ exec("cudnn.benchmark = True")
+ else:
+ rprint(f"Warning: Not benchmarking")
+ exec("cudnn.allow_tf32 = True")
+ exec("cuda.matmul.allow_tf32 = True")
+ exec("cudnn.deterministic = False")
+ else:
+ rprint(f"Warning: CUDA not available. Not setting defaults.")
+
+from torch.distributed.elastic.multiprocessing.errors import (ChildFailedError,
+ record)
+from torch.distributed.elastic.multiprocessing.errors.handlers import \
+ get_error_handler
+
+_NOT_AVAILABLE = ""
+class ErrorHandler:
+ def __init__(self, error_handler=None):
+ self.error_handler = error_handler or get_error_handler()
+
+ def __enter__(self):
+ assert self.error_handler is not None
+ self.error_handler.initialize()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_type is not None:
+ if issubclass(exc_type, SystemExit) and exc_value.code == 0:
+ return True # Prevents SystemExit with code 0 from stopping the program
+ elif issubclass(exc_type, ChildFailedError):
+ rank, failure = exc_value.get_first_failure()
+ if failure.error_file != _NOT_AVAILABLE:
+ self.error_handler.dump_error_file(failure.error_file, failure.exitcode)
+ else:
+ rprint(
+ "local_rank %s FAILED with no error file. "
+ "Decorate your entrypoint fn with @record for traceback info. "
+ "See: https://pytorch.org/docs/stable/elastic/errors.html",
+ rank
+ )
+ return False # Re-raises the exception
+ self.error_handler.record_exception(exc_value)
+ return False # Any other exceptions will be re-raised
+
+
+def convert_state_dict_keys(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if "attn_out" in k:
+ new_key = k.replace("attn_out", "attention.attn_out")
+ elif "attn_qkv" in k:
+ new_key = k.replace("attn_qkv", "attention.attn_qkv")
+ else:
+ new_key = k
+ new_state_dict[new_key] = v
+ return new_state_dict
+
+from accelerate.utils import extract_model_from_parallel
+def apply_compile(model, **compile_kwargs):
+ """
+ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
+ """
+ for layer_id, transformer_block in extract_model_from_parallel(model).blocks.named_children():
+ transformer_block = torch.compile(transformer_block, **compile_kwargs)
+ extract_model_from_parallel(model).blocks.register_module(layer_id, transformer_block)
+
+ output_layer = torch.compile(extract_model_from_parallel(model).output_layer, **compile_kwargs)
+ extract_model_from_parallel(model).register_module("output_layer", output_layer)
+
+def compile_model(config, model):
+ compile_kwargs = dict()
+
+ if config.backbone == "maskdit":
+ compile_kwargs["dynamic"] = True
+
+ compile_kwargs["mode"] = config.trainer.compile_mode
+ rprint(f"Using compile mode: {config.trainer.compile_mode}")
+
+ if getattr(config.trainer, "sd3_compile_config", True):
+ torch._inductor.config.conv_1x1_as_mm = True
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.epilogue_fusion = False
+ torch._inductor.config.coordinate_descent_check_all_directions = True
+ rprint(f"Using SD3 compile config")
+
+ if config.trainer.compile_fullgraph:
+ compile_kwargs["fullgraph"] = True
+ rprint(f"Using fullgraph compile")
+
+ if getattr(config.trainer, "compile_per_layer", False):
+ apply_compile(model, **compile_kwargs)
+ else:
+ model = torch.compile(model, **compile_kwargs)
+
+ return model
diff --git a/uv.lock b/uv.lock
new file mode 100644
index 0000000000000000000000000000000000000000..c0151ab2271ff400b53972809f70440801375fc6
--- /dev/null
+++ b/uv.lock
@@ -0,0 +1,3915 @@
+version = 1
+revision = 1
+requires-python = ">=3.10, <3.13"
+resolution-markers = [
+ "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and sys_platform == 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'win32'",
+ "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and sys_platform == 'win32'",
+ "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and sys_platform == 'win32'",
+ "python_full_version >= '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version == '3.11.*' and sys_platform == 'darwin'",
+ "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version < '3.11' and sys_platform == 'darwin'",
+ "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+]
+
+[[package]]
+name = "absl-py"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7a/8f/fc001b92ecc467cc32ab38398bd0bfb45df46e7523bf33c2ad22a505f06e/absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff", size = 118055 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", size = 133706 },
+]
+
+[[package]]
+name = "accelerate"
+version = "1.5.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "huggingface-hub" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "psutil" },
+ { name = "pyyaml" },
+ { name = "safetensors" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a9/4c/a61132924da12cef62a88c04b5825246ab83dcc1bae6291d098cfcb0b72d/accelerate-1.5.2.tar.gz", hash = "sha256:a1cf39473edc0e42772a9d9a18c9eb1ce8ffd9e1719dc0ab80670f5c1fd4dc43", size = 352341 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/70/83/167d4b638bb758a966828eb8d23c5e7047825edfdf768ff5f4fb01440063/accelerate-1.5.2-py3-none-any.whl", hash = "sha256:68a3b272f6a6ffebb457bdc138581a2bf52efad6a5e0214dc46675f3edd98792", size = 345146 },
+]
+
+[[package]]
+name = "aiohappyeyeballs"
+version = "2.4.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7f/55/e4373e888fdacb15563ef6fa9fa8c8252476ea071e96fb46defac9f18bf2/aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745", size = 21977 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b9/74/fbb6559de3607b3300b9be3cc64e97548d55678e44623db17820dbd20002/aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8", size = 14756 },
+]
+
+[[package]]
+name = "aiohttp"
+version = "3.11.12"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "aiohappyeyeballs" },
+ { name = "aiosignal" },
+ { name = "async-timeout", marker = "python_full_version < '3.11'" },
+ { name = "attrs" },
+ { name = "frozenlist" },
+ { name = "multidict" },
+ { name = "propcache" },
+ { name = "yarl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/37/4b/952d49c73084fb790cb5c6ead50848c8e96b4980ad806cf4d2ad341eaa03/aiohttp-3.11.12.tar.gz", hash = "sha256:7603ca26d75b1b86160ce1bbe2787a0b706e592af5b2504e12caa88a217767b0", size = 7673175 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/65/42/3880e133590820aa7bc6d068eb7d8e0ad9fdce9b4663f92b821d3f6b5601/aiohttp-3.11.12-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:aa8a8caca81c0a3e765f19c6953416c58e2f4cc1b84829af01dd1c771bb2f91f", size = 708721 },
+ { url = "https://files.pythonhosted.org/packages/d8/8c/04869803bed108b25afad75f94c651b287851843caacbec6677d8f2d572b/aiohttp-3.11.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:84ede78acde96ca57f6cf8ccb8a13fbaf569f6011b9a52f870c662d4dc8cd854", size = 468596 },
+ { url = "https://files.pythonhosted.org/packages/4f/f4/9074011f0d1335b161c953fb32545b6667cf24465e1932b9767874995c7e/aiohttp-3.11.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:584096938a001378484aa4ee54e05dc79c7b9dd933e271c744a97b3b6f644957", size = 455758 },
+ { url = "https://files.pythonhosted.org/packages/fd/68/06298c57ef8f534065930b805e6dbd83613f0534447922782fb9920fce28/aiohttp-3.11.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:392432a2dde22b86f70dd4a0e9671a349446c93965f261dbaecfaf28813e5c42", size = 1584797 },
+ { url = "https://files.pythonhosted.org/packages/bd/1e/cee6b51fcb3b1c4185a7dc62b3113bc136fae07f39386c88c90b7f79f199/aiohttp-3.11.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:88d385b8e7f3a870146bf5ea31786ef7463e99eb59e31db56e2315535d811f55", size = 1632535 },
+ { url = "https://files.pythonhosted.org/packages/71/1f/42424462b7a09da362e1711090db9f8d68a37a33f0aab51307335517c599/aiohttp-3.11.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b10a47e5390c4b30a0d58ee12581003be52eedd506862ab7f97da7a66805befb", size = 1668484 },
+ { url = "https://files.pythonhosted.org/packages/f6/79/0e25542bbe3c2bfd7a12c7a49c7bce73b09a836f65079e4b77bc2bafc89e/aiohttp-3.11.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b5263dcede17b6b0c41ef0c3ccce847d82a7da98709e75cf7efde3e9e3b5cae", size = 1589708 },
+ { url = "https://files.pythonhosted.org/packages/d1/13/93ae26b75e23f7d3a613872e472fae836ca100dc5bde5936ebc93ada8890/aiohttp-3.11.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50c5c7b8aa5443304c55c262c5693b108c35a3b61ef961f1e782dd52a2f559c7", size = 1544752 },
+ { url = "https://files.pythonhosted.org/packages/cf/5e/48847fad1b014ef92ef18ea1339a3b58eb81d3bc717b94c3627f5d2a42c5/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d1c031a7572f62f66f1257db37ddab4cb98bfaf9b9434a3b4840bf3560f5e788", size = 1529417 },
+ { url = "https://files.pythonhosted.org/packages/ae/56/fbd4ea019303f4877f0e0b8c9de92e9db24338e7545570d3f275f3c74c53/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:7e44eba534381dd2687be50cbd5f2daded21575242ecfdaf86bbeecbc38dae8e", size = 1557808 },
+ { url = "https://files.pythonhosted.org/packages/f1/43/112189cf6b3c482ecdd6819b420eaa0c2033426f28d741bb7f19db5dd2bb/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:145a73850926018ec1681e734cedcf2716d6a8697d90da11284043b745c286d5", size = 1536765 },
+ { url = "https://files.pythonhosted.org/packages/30/12/59986547de8306e06c7b30e547ccda02d29636e152366caba2dd8627bfe1/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:2c311e2f63e42c1bf86361d11e2c4a59f25d9e7aabdbdf53dc38b885c5435cdb", size = 1607621 },
+ { url = "https://files.pythonhosted.org/packages/aa/9b/af3b323b20df3318ed20d701d8242e523d59c842ca93f23134b05c9d5054/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:ea756b5a7bac046d202a9a3889b9a92219f885481d78cd318db85b15cc0b7bcf", size = 1628977 },
+ { url = "https://files.pythonhosted.org/packages/36/62/adf5a331a7bda475cc326dde393fa2bc5849060b1b37ac3d1bee1953f2cd/aiohttp-3.11.12-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:526c900397f3bbc2db9cb360ce9c35134c908961cdd0ac25b1ae6ffcaa2507ff", size = 1564455 },
+ { url = "https://files.pythonhosted.org/packages/90/c4/4a24291f22f111a854dfdb54dc94d4e0a5229ccbb7bc7f0bed972aa50410/aiohttp-3.11.12-cp310-cp310-win32.whl", hash = "sha256:b8d3bb96c147b39c02d3db086899679f31958c5d81c494ef0fc9ef5bb1359b3d", size = 416768 },
+ { url = "https://files.pythonhosted.org/packages/51/69/5221c8006acb7bb10d9e8e2238fb216571bddc2e00a8d95bcfbe2f579c57/aiohttp-3.11.12-cp310-cp310-win_amd64.whl", hash = "sha256:7fe3d65279bfbee8de0fb4f8c17fc4e893eed2dba21b2f680e930cc2b09075c5", size = 442170 },
+ { url = "https://files.pythonhosted.org/packages/9c/38/35311e70196b6a63cfa033a7f741f800aa8a93f57442991cbe51da2394e7/aiohttp-3.11.12-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:87a2e00bf17da098d90d4145375f1d985a81605267e7f9377ff94e55c5d769eb", size = 708797 },
+ { url = "https://files.pythonhosted.org/packages/44/3e/46c656e68cbfc4f3fc7cb5d2ba4da6e91607fe83428208028156688f6201/aiohttp-3.11.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b34508f1cd928ce915ed09682d11307ba4b37d0708d1f28e5774c07a7674cac9", size = 468669 },
+ { url = "https://files.pythonhosted.org/packages/a0/d6/2088fb4fd1e3ac2bfb24bc172223babaa7cdbb2784d33c75ec09e66f62f8/aiohttp-3.11.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:936d8a4f0f7081327014742cd51d320296b56aa6d324461a13724ab05f4b2933", size = 455739 },
+ { url = "https://files.pythonhosted.org/packages/e7/dc/c443a6954a56f4a58b5efbfdf23cc6f3f0235e3424faf5a0c56264d5c7bb/aiohttp-3.11.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de1378f72def7dfb5dbd73d86c19eda0ea7b0a6873910cc37d57e80f10d64e1", size = 1685858 },
+ { url = "https://files.pythonhosted.org/packages/25/67/2d5b3aaade1d5d01c3b109aa76e3aa9630531252cda10aa02fb99b0b11a1/aiohttp-3.11.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9d45dbb3aaec05cf01525ee1a7ac72de46a8c425cb75c003acd29f76b1ffe94", size = 1743829 },
+ { url = "https://files.pythonhosted.org/packages/90/9b/9728fe9a3e1b8521198455d027b0b4035522be18f504b24c5d38d59e7278/aiohttp-3.11.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:930ffa1925393381e1e0a9b82137fa7b34c92a019b521cf9f41263976666a0d6", size = 1785587 },
+ { url = "https://files.pythonhosted.org/packages/ce/cf/28fbb43d4ebc1b4458374a3c7b6db3b556a90e358e9bbcfe6d9339c1e2b6/aiohttp-3.11.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8340def6737118f5429a5df4e88f440746b791f8f1c4ce4ad8a595f42c980bd5", size = 1675319 },
+ { url = "https://files.pythonhosted.org/packages/e5/d2/006c459c11218cabaa7bca401f965c9cc828efbdea7e1615d4644eaf23f7/aiohttp-3.11.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4016e383f91f2814e48ed61e6bda7d24c4d7f2402c75dd28f7e1027ae44ea204", size = 1619982 },
+ { url = "https://files.pythonhosted.org/packages/9d/83/ca425891ebd37bee5d837110f7fddc4d808a7c6c126a7d1b5c3ad72fc6ba/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c0600bcc1adfaaac321422d615939ef300df81e165f6522ad096b73439c0f58", size = 1654176 },
+ { url = "https://files.pythonhosted.org/packages/25/df/047b1ce88514a1b4915d252513640184b63624e7914e41d846668b8edbda/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:0450ada317a65383b7cce9576096150fdb97396dcfe559109b403c7242faffef", size = 1660198 },
+ { url = "https://files.pythonhosted.org/packages/d3/cc/6ecb8e343f0902528620b9dbd567028a936d5489bebd7dbb0dd0914f4fdb/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:850ff6155371fd802a280f8d369d4e15d69434651b844bde566ce97ee2277420", size = 1650186 },
+ { url = "https://files.pythonhosted.org/packages/f8/f8/453df6dd69256ca8c06c53fc8803c9056e2b0b16509b070f9a3b4bdefd6c/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8fd12d0f989c6099e7b0f30dc6e0d1e05499f3337461f0b2b0dadea6c64b89df", size = 1733063 },
+ { url = "https://files.pythonhosted.org/packages/55/f8/540160787ff3000391de0e5d0d1d33be4c7972f933c21991e2ea105b2d5e/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:76719dd521c20a58a6c256d058547b3a9595d1d885b830013366e27011ffe804", size = 1755306 },
+ { url = "https://files.pythonhosted.org/packages/30/7d/49f3bfdfefd741576157f8f91caa9ff61a6f3d620ca6339268327518221b/aiohttp-3.11.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:97fe431f2ed646a3b56142fc81d238abcbaff08548d6912acb0b19a0cadc146b", size = 1692909 },
+ { url = "https://files.pythonhosted.org/packages/40/9c/8ce00afd6f6112ce9a2309dc490fea376ae824708b94b7b5ea9cba979d1d/aiohttp-3.11.12-cp311-cp311-win32.whl", hash = "sha256:e10c440d142fa8b32cfdb194caf60ceeceb3e49807072e0dc3a8887ea80e8c16", size = 416584 },
+ { url = "https://files.pythonhosted.org/packages/35/97/4d3c5f562f15830de472eb10a7a222655d750839943e0e6d915ef7e26114/aiohttp-3.11.12-cp311-cp311-win_amd64.whl", hash = "sha256:246067ba0cf5560cf42e775069c5d80a8989d14a7ded21af529a4e10e3e0f0e6", size = 442674 },
+ { url = "https://files.pythonhosted.org/packages/4d/d0/94346961acb476569fca9a644cc6f9a02f97ef75961a6b8d2b35279b8d1f/aiohttp-3.11.12-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e392804a38353900c3fd8b7cacbea5132888f7129f8e241915e90b85f00e3250", size = 704837 },
+ { url = "https://files.pythonhosted.org/packages/a9/af/05c503f1cc8f97621f199ef4b8db65fb88b8bc74a26ab2adb74789507ad3/aiohttp-3.11.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8fa1510b96c08aaad49303ab11f8803787c99222288f310a62f493faf883ede1", size = 464218 },
+ { url = "https://files.pythonhosted.org/packages/f2/48/b9949eb645b9bd699153a2ec48751b985e352ab3fed9d98c8115de305508/aiohttp-3.11.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc065a4285307607df3f3686363e7f8bdd0d8ab35f12226362a847731516e42c", size = 456166 },
+ { url = "https://files.pythonhosted.org/packages/14/fb/980981807baecb6f54bdd38beb1bd271d9a3a786e19a978871584d026dcf/aiohttp-3.11.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cddb31f8474695cd61fc9455c644fc1606c164b93bff2490390d90464b4655df", size = 1682528 },
+ { url = "https://files.pythonhosted.org/packages/90/cb/77b1445e0a716914e6197b0698b7a3640590da6c692437920c586764d05b/aiohttp-3.11.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dec0000d2d8621d8015c293e24589d46fa218637d820894cb7356c77eca3259", size = 1737154 },
+ { url = "https://files.pythonhosted.org/packages/ff/24/d6fb1f4cede9ccbe98e4def6f3ed1e1efcb658871bbf29f4863ec646bf38/aiohttp-3.11.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e3552fe98e90fdf5918c04769f338a87fa4f00f3b28830ea9b78b1bdc6140e0d", size = 1793435 },
+ { url = "https://files.pythonhosted.org/packages/17/e2/9f744cee0861af673dc271a3351f59ebd5415928e20080ab85be25641471/aiohttp-3.11.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dfe7f984f28a8ae94ff3a7953cd9678550dbd2a1f9bda5dd9c5ae627744c78e", size = 1692010 },
+ { url = "https://files.pythonhosted.org/packages/90/c4/4a1235c1df544223eb57ba553ce03bc706bdd065e53918767f7fa1ff99e0/aiohttp-3.11.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a481a574af914b6e84624412666cbfbe531a05667ca197804ecc19c97b8ab1b0", size = 1619481 },
+ { url = "https://files.pythonhosted.org/packages/60/70/cf12d402a94a33abda86dd136eb749b14c8eb9fec1e16adc310e25b20033/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1987770fb4887560363b0e1a9b75aa303e447433c41284d3af2840a2f226d6e0", size = 1641578 },
+ { url = "https://files.pythonhosted.org/packages/1b/25/7211973fda1f5e833fcfd98ccb7f9ce4fbfc0074e3e70c0157a751d00db8/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:a4ac6a0f0f6402854adca4e3259a623f5c82ec3f0c049374133bcb243132baf9", size = 1684463 },
+ { url = "https://files.pythonhosted.org/packages/93/60/b5905b4d0693f6018b26afa9f2221fefc0dcbd3773fe2dff1a20fb5727f1/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c96a43822f1f9f69cc5c3706af33239489a6294be486a0447fb71380070d4d5f", size = 1646691 },
+ { url = "https://files.pythonhosted.org/packages/b4/fc/ba1b14d6fdcd38df0b7c04640794b3683e949ea10937c8a58c14d697e93f/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a5e69046f83c0d3cb8f0d5bd9b8838271b1bc898e01562a04398e160953e8eb9", size = 1702269 },
+ { url = "https://files.pythonhosted.org/packages/5e/39/18c13c6f658b2ba9cc1e0c6fb2d02f98fd653ad2addcdf938193d51a9c53/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:68d54234c8d76d8ef74744f9f9fc6324f1508129e23da8883771cdbb5818cbef", size = 1734782 },
+ { url = "https://files.pythonhosted.org/packages/9f/d2/ccc190023020e342419b265861877cd8ffb75bec37b7ddd8521dd2c6deb8/aiohttp-3.11.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9fd9dcf9c91affe71654ef77426f5cf8489305e1c66ed4816f5a21874b094b9", size = 1694740 },
+ { url = "https://files.pythonhosted.org/packages/3f/54/186805bcada64ea90ea909311ffedcd74369bfc6e880d39d2473314daa36/aiohttp-3.11.12-cp312-cp312-win32.whl", hash = "sha256:0ed49efcd0dc1611378beadbd97beb5d9ca8fe48579fc04a6ed0844072261b6a", size = 411530 },
+ { url = "https://files.pythonhosted.org/packages/3d/63/5eca549d34d141bcd9de50d4e59b913f3641559460c739d5e215693cb54a/aiohttp-3.11.12-cp312-cp312-win_amd64.whl", hash = "sha256:54775858c7f2f214476773ce785a19ee81d1294a6bedc5cc17225355aab74802", size = 437860 },
+]
+
+[[package]]
+name = "aiosignal"
+version = "1.3.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "frozenlist" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ba/b5/6d55e80f6d8a08ce22b982eafa278d823b541c925f11ee774b0b9c43473d/aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54", size = 19424 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597 },
+]
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 },
+]
+
+[[package]]
+name = "antlr4-python3-runtime"
+version = "4.9.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034 }
+
+[[package]]
+name = "anyio"
+version = "4.8.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+ { name = "idna" },
+ { name = "sniffio" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
+]
+
+[[package]]
+name = "apsw"
+version = "3.48.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f6/db/06485deedcac623addcfd13dceee54904d5477e66b406b0bf3f0418701c2/apsw-3.48.0.0.tar.gz", hash = "sha256:7c4492a55bd5c9f63821edd0162d6177f383b4733cfe421bd3bde5151e80c49b", size = 1037214 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/43/1e/5b06b5127148c1e98b3dbb9515e9b776edc6e797e9b3be5dd7426905f9d8/apsw-3.48.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e35ec9c5a91b1c683c8ca8deeffc30d0dfa3bd5e2d50113e2dbe7198fe4b7b1b", size = 1832003 },
+ { url = "https://files.pythonhosted.org/packages/be/cd/0d4447706138302d96de2797aaa6bf7926087fe465dad23ec0ed62c4e4bf/apsw-3.48.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:236e7da6a1d6262e4cade801de98bb11704e61bd13714daa842bf3f9cff25ae2", size = 1771501 },
+ { url = "https://files.pythonhosted.org/packages/b7/8a/30c615c0f31fc70e4f010999a46cbc06547bff47f24156ba273170e1cf8a/apsw-3.48.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90d2a0eb609ba25253260685855513b304db522a19f9852ba6dc76a53c9a19a0", size = 6318082 },
+ { url = "https://files.pythonhosted.org/packages/03/9b/d035ca2bd2b4ce92a61ec941923ec4b519215dbb72a3b5f1224758a0f187/apsw-3.48.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df9921d81b69179df8fea18d0d2e07d4ca34803e083f74b3204878c96dcd3ba6", size = 6244569 },
+ { url = "https://files.pythonhosted.org/packages/6e/d6/a671e05c571f50e2352027bc542e03d0fdf91445673c85f21fc0cb2c296e/apsw-3.48.0.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bbd80f11b0a733b9e6660a3818a7e5faf8c940ce91137cd0dcda4d96563afc8", size = 6210184 },
+ { url = "https://files.pythonhosted.org/packages/cd/80/1feeebe84aa4cabfc719f3de2df7af02ecc5caef74737093bcb666ff16b5/apsw-3.48.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:96ed9190196596516b5872b5b1ef250fcbfe016b4563abe11ec255a7f2a1a9d7", size = 6370041 },
+ { url = "https://files.pythonhosted.org/packages/a5/1c/0713c6f833b12731e1084f5a2278c7ee5cc0ff03913e57560103f61c46ce/apsw-3.48.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:664058b1a45e62ecdc5ea8ba79e002c38b66b721fe540cb349afad5cc02bdc48", size = 6335684 },
+ { url = "https://files.pythonhosted.org/packages/83/a4/9a71eaf3b70de3c7b60fadd4de735563eb20fefc50d74e2241316a25ba39/apsw-3.48.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3e378cefc3e38e0fb0f0df3248286ca5f18a5e9f9c48377279e9495f23775d62", size = 6307331 },
+ { url = "https://files.pythonhosted.org/packages/18/f1/8aa76f85dd554e5074623505e15ad2b3a369d31accff2b7c53abbc33198d/apsw-3.48.0.0-cp310-cp310-win32.whl", hash = "sha256:07e3f59e3abd6414d4e50f79b553c43fe6b6a86acf51612a083fd3a6c2b5e99b", size = 1500885 },
+ { url = "https://files.pythonhosted.org/packages/b8/a1/8072b014398c94209d24a855e0193ac1e7f000837c6f43a2a9fe99d179dd/apsw-3.48.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:19846928f14af69cfb96ffc4b2325becda0e42ac8858d8551ce5688824333453", size = 1657578 },
+ { url = "https://files.pythonhosted.org/packages/b0/73/c142fc1e5be62476e00f30b64558c688c721380d10330e53644a676cab44/apsw-3.48.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:584ab5feeab24349f3931fb0062de1f7612725f5ae08e334bafc73dcccb638dc", size = 1834231 },
+ { url = "https://files.pythonhosted.org/packages/6d/7e/a672d31e5d6b5c75aae3482fce7b29895e31a1966a60a9d6a6591eee5bba/apsw-3.48.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e28d98f680fbfb65a8a2a6c26344115d2d95567d56737ee9f45d4b1153cdf075", size = 1772952 },
+ { url = "https://files.pythonhosted.org/packages/ae/e5/40124c46e7eee3b5dc62cd1654a5de5b8437c3dcd76e34569a2cb51df9cf/apsw-3.48.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06fa4f53b3f59c3ff859e7fe425c543cf0cc3b2b300a84241ad9de0b0ff4a19", size = 6527609 },
+ { url = "https://files.pythonhosted.org/packages/0c/55/10dc21be42fb50660ad90b4ae5f706b621938b234663b9207d9a93c18164/apsw-3.48.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:647f5b4c38d7fd43b8bdba8f0e384e44a7426ee4062689e120e7f1e7a423af75", size = 6442666 },
+ { url = "https://files.pythonhosted.org/packages/de/ce/83fb818c31122374ac29bc8679d81f69e55d1581e55ed93f7d4b6cf56141/apsw-3.48.0.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87a383db31053a3ccaaa724e93dd6af8ee0a39df942edf4cfb24b7ebaef8afb7", size = 6399865 },
+ { url = "https://files.pythonhosted.org/packages/a5/ec/005279c1770abb6896c876333a5fa87b66e68d10449d6aebec6d1101bf73/apsw-3.48.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e733b1422873acb115a5d6acd4515039e7977721fe9821c4c644571ad2fa412e", size = 6543281 },
+ { url = "https://files.pythonhosted.org/packages/1c/6b/9eaf1babd0f498913d7aef84f57361db5ba202bc124010d90813c2387730/apsw-3.48.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6aa02fbd9d1cbc299923824ecebbf7146fb9f349dbe24dbb1268d85427134bae", size = 6459805 },
+ { url = "https://files.pythonhosted.org/packages/92/e7/f3854606e6288c41b35b3b21361b65e94d062af461a5f8f21e14c0dd321d/apsw-3.48.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58f3ffaa7430e030bb33da0186673ffd162b05f38fa2ff9f5088aa0c9520b846", size = 6419859 },
+ { url = "https://files.pythonhosted.org/packages/b2/a9/4355a022cfff3010644aeb9d5653711ccd14182d6005b37f6de7e1d7f3c1/apsw-3.48.0.0-cp311-cp311-win32.whl", hash = "sha256:3de4e939d2b4ec536ee8ec0b263aa2a6047976e081de22ba477b16d465fed473", size = 1494748 },
+ { url = "https://files.pythonhosted.org/packages/c8/cc/62e42c4461397268394a9e7a1f8a9549158833246d2a41e14a792f1ab9ed/apsw-3.48.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:8fc888fe2382aa09a659aa7b67ff4afa59cfd5b3ab0adec32b8b9a7151df8846", size = 1654600 },
+ { url = "https://files.pythonhosted.org/packages/d0/e4/3925f809d560b1bd2d0eb3d86b2444322077a0ef2a2c5574f71862e4c853/apsw-3.48.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f9e0010654a5c777105528394078d77128126657ca250af29cc6418fe4fbdccd", size = 1835648 },
+ { url = "https://files.pythonhosted.org/packages/f0/64/b9c88b89936d4fdfcc0266981d4af92080f43f5809a06b39fa205b3c9295/apsw-3.48.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1045520cb9c113cc3c1328896c58522b45f622dd66bce81eb8eaf8d671eb16d7", size = 1773007 },
+ { url = "https://files.pythonhosted.org/packages/c6/6d/3a6cd9c2167ce99c3a64285802007a15140437d34e4f848b5839d43484cc/apsw-3.48.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0bb77cb8a62089cba2101d84fe6052f6ece6e5f28cc8c724d9b332d4f2e1993", size = 6518416 },
+ { url = "https://files.pythonhosted.org/packages/8e/1c/02caa68f1213ff765b5bfb3a8af365236720436939f564a49b04336ebdc8/apsw-3.48.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80b1590b66a419984898796e1ba0251f81fbdf97ab4f29d7a427648dd0bb89a7", size = 6435935 },
+ { url = "https://files.pythonhosted.org/packages/d2/b4/d86e1ac821540fe2cd4010a610e121e11bd04373466cd4c854c260729dcb/apsw-3.48.0.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5389511068e4977179a4978d92df0dd943a28635ba5e289b895387adc436184", size = 6387775 },
+ { url = "https://files.pythonhosted.org/packages/ab/91/c45d24e413428b45d089bb8736e7e8c941d64ffd32f5d7b7d2d493fc9970/apsw-3.48.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:45670822d98d3545f63203b1db2efde5eb08380989d8c61506e43207badf1be6", size = 6533327 },
+ { url = "https://files.pythonhosted.org/packages/4d/fe/f6a5e890b99c42e040ffa54e7d8f4934c621f534a7cd67763c7df1728ab1/apsw-3.48.0.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4b3b402ea4dd4e8943129ca98f7a4167341fd748525c74f6cd6310d247e3e3ca", size = 6436618 },
+ { url = "https://files.pythonhosted.org/packages/1f/ed/53eee29c9c702419ca104ed67f4f7d8439e262d5340a081f5102597844fa/apsw-3.48.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bdcf0e772adbd69eeafa7f6dab5635623cf171282608b6e9354eb076cfb4afe", size = 6417649 },
+ { url = "https://files.pythonhosted.org/packages/20/b4/ba348c577e335326fadd2752e21eb5f71e007f428cfcfbd98ae431d5623f/apsw-3.48.0.0-cp312-cp312-win32.whl", hash = "sha256:3bfc2ee0f930d994848993c10bb5973b50b8542d07d0e479ab6665ea609a06fb", size = 1494385 },
+ { url = "https://files.pythonhosted.org/packages/3c/01/99deccb56796a31cfcc2bf480e764d3a165faa992d5da4880bec2a896d82/apsw-3.48.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:b0fa43d5624bfcec017ee87d2d296d0bcead9396afcfff2314dfa516b82d8274", size = 1653735 },
+]
+
+[[package]]
+name = "apswutils"
+version = "0.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "apsw" },
+ { name = "fastcore" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ab/55/b18acc76ecb52b9703a89a17a8646ea6aa28355e1358420b07edace93884/apswutils-0.0.2.tar.gz", hash = "sha256:146b3d1f18d08551d2a0eb8f0b7325c2904978e42105284434c667428f98356c", size = 50854 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7d/e5/2c36e0e7b2c79fdc6ae40dcad3790b420a3bd2a88ae5bccbba3e67cb904d/apswutils-0.0.2-py3-none-any.whl", hash = "sha256:8f98661f7110868fe509ebc5241ec01a9ea33dbce22c284717cded5402c4b864", size = 80452 },
+]
+
+[[package]]
+name = "args"
+version = "0.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e5/1c/b701b3f4bd8d3667df8342f311b3efaeab86078a840fb826bd204118cc6b/args-0.1.0.tar.gz", hash = "sha256:a785b8d837625e9b61c39108532d95b85274acd679693b71ebb5156848fcf814", size = 3048 }
+
+[[package]]
+name = "asttokens"
+version = "3.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 },
+]
+
+[[package]]
+name = "async-timeout"
+version = "5.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 },
+]
+
+[[package]]
+name = "attrs"
+version = "25.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/49/7c/fdf464bcc51d23881d110abd74b512a42b3d5d376a55a831b44c603ae17f/attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e", size = 810562 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fc/30/d4986a882011f9df997a55e6becd864812ccfcd821d64aac8570ee39f719/attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a", size = 63152 },
+]
+
+[[package]]
+name = "beautifulsoup4"
+version = "4.13.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "soupsieve" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f0/3c/adaf39ce1fb4afdd21b611e3d530b183bb7759c9b673d60db0e347fd4439/beautifulsoup4-4.13.3.tar.gz", hash = "sha256:1bd32405dacc920b42b83ba01644747ed77456a65760e285fbc47633ceddaf8b", size = 619516 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f9/49/6abb616eb3cbab6a7cca303dc02fdf3836de2e0b834bf966a7f5271a34d8/beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16", size = 186015 },
+]
+
+[[package]]
+name = "bitsandbytes"
+version = "0.45.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/db/9d/9382259196d7ad7f3550702390081224e673a705e75b5660ee377b592fc0/bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:ba3a720187f518b172ebce4081049c682ae3fd8284947e22499b256ff99a2bc3", size = 69680042 },
+ { url = "https://files.pythonhosted.org/packages/cb/33/550bcfe84f08ee20f3bcc1b129dcadaf7f2d1a065ce9812476fc7be2874a/bitsandbytes-0.45.2-py3-none-win_amd64.whl", hash = "sha256:e1893170455422924156760bae9e210ae26e8bd2ce7b21065d24b47decbe6963", size = 69124143 },
+]
+
+[[package]]
+name = "blinker"
+version = "1.9.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458 },
+]
+
+[[package]]
+name = "braceexpand"
+version = "0.1.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/54/93/badd4f5ccf25209f3fef2573073da9fe4a45a3da99fca2f800f942130c0f/braceexpand-0.1.7.tar.gz", hash = "sha256:e6e539bd20eaea53547472ff94f4fb5c3d3bf9d0a89388c4b56663aba765f705", size = 7777 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fa/93/e8c04e80e82391a6e51f218ca49720f64236bc824e92152a2633b74cf7ab/braceexpand-0.1.7-py2.py3-none-any.whl", hash = "sha256:91332d53de7828103dcae5773fb43bc34950b0c8160e35e0f44c4427a3b85014", size = 5923 },
+]
+
+[[package]]
+name = "certifi"
+version = "2025.1.31"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/1c/ab/c9f1e32b7b1bf505bf26f0ef697775960db7932abeb7b516de930ba2705f/certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651", size = 167577 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 },
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0d/58/5580c1716040bc89206c77d8f74418caf82ce519aae06450393ca73475d1/charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de", size = 198013 },
+ { url = "https://files.pythonhosted.org/packages/d0/11/00341177ae71c6f5159a08168bcb98c6e6d196d372c94511f9f6c9afe0c6/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176", size = 141285 },
+ { url = "https://files.pythonhosted.org/packages/01/09/11d684ea5819e5a8f5100fb0b38cf8d02b514746607934134d31233e02c8/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037", size = 151449 },
+ { url = "https://files.pythonhosted.org/packages/08/06/9f5a12939db324d905dc1f70591ae7d7898d030d7662f0d426e2286f68c9/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f", size = 143892 },
+ { url = "https://files.pythonhosted.org/packages/93/62/5e89cdfe04584cb7f4d36003ffa2936681b03ecc0754f8e969c2becb7e24/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a", size = 146123 },
+ { url = "https://files.pythonhosted.org/packages/a9/ac/ab729a15c516da2ab70a05f8722ecfccc3f04ed7a18e45c75bbbaa347d61/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a", size = 147943 },
+ { url = "https://files.pythonhosted.org/packages/03/d2/3f392f23f042615689456e9a274640c1d2e5dd1d52de36ab8f7955f8f050/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247", size = 142063 },
+ { url = "https://files.pythonhosted.org/packages/f2/e3/e20aae5e1039a2cd9b08d9205f52142329f887f8cf70da3650326670bddf/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408", size = 150578 },
+ { url = "https://files.pythonhosted.org/packages/8d/af/779ad72a4da0aed925e1139d458adc486e61076d7ecdcc09e610ea8678db/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb", size = 153629 },
+ { url = "https://files.pythonhosted.org/packages/c2/b6/7aa450b278e7aa92cf7732140bfd8be21f5f29d5bf334ae987c945276639/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d", size = 150778 },
+ { url = "https://files.pythonhosted.org/packages/39/f4/d9f4f712d0951dcbfd42920d3db81b00dd23b6ab520419626f4023334056/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807", size = 146453 },
+ { url = "https://files.pythonhosted.org/packages/49/2b/999d0314e4ee0cff3cb83e6bc9aeddd397eeed693edb4facb901eb8fbb69/charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f", size = 95479 },
+ { url = "https://files.pythonhosted.org/packages/2d/ce/3cbed41cff67e455a386fb5e5dd8906cdda2ed92fbc6297921f2e4419309/charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f", size = 102790 },
+ { url = "https://files.pythonhosted.org/packages/72/80/41ef5d5a7935d2d3a773e3eaebf0a9350542f2cab4eac59a7a4741fbbbbe/charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125", size = 194995 },
+ { url = "https://files.pythonhosted.org/packages/7a/28/0b9fefa7b8b080ec492110af6d88aa3dea91c464b17d53474b6e9ba5d2c5/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1", size = 139471 },
+ { url = "https://files.pythonhosted.org/packages/71/64/d24ab1a997efb06402e3fc07317e94da358e2585165930d9d59ad45fcae2/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3", size = 149831 },
+ { url = "https://files.pythonhosted.org/packages/37/ed/be39e5258e198655240db5e19e0b11379163ad7070962d6b0c87ed2c4d39/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd", size = 142335 },
+ { url = "https://files.pythonhosted.org/packages/88/83/489e9504711fa05d8dde1574996408026bdbdbd938f23be67deebb5eca92/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00", size = 143862 },
+ { url = "https://files.pythonhosted.org/packages/c6/c7/32da20821cf387b759ad24627a9aca289d2822de929b8a41b6241767b461/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12", size = 145673 },
+ { url = "https://files.pythonhosted.org/packages/68/85/f4288e96039abdd5aeb5c546fa20a37b50da71b5cf01e75e87f16cd43304/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77", size = 140211 },
+ { url = "https://files.pythonhosted.org/packages/28/a3/a42e70d03cbdabc18997baf4f0227c73591a08041c149e710045c281f97b/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146", size = 148039 },
+ { url = "https://files.pythonhosted.org/packages/85/e4/65699e8ab3014ecbe6f5c71d1a55d810fb716bbfd74f6283d5c2aa87febf/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd", size = 151939 },
+ { url = "https://files.pythonhosted.org/packages/b1/82/8e9fe624cc5374193de6860aba3ea8070f584c8565ee77c168ec13274bd2/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6", size = 149075 },
+ { url = "https://files.pythonhosted.org/packages/3d/7b/82865ba54c765560c8433f65e8acb9217cb839a9e32b42af4aa8e945870f/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8", size = 144340 },
+ { url = "https://files.pythonhosted.org/packages/b5/b6/9674a4b7d4d99a0d2df9b215da766ee682718f88055751e1e5e753c82db0/charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b", size = 95205 },
+ { url = "https://files.pythonhosted.org/packages/1e/ab/45b180e175de4402dcf7547e4fb617283bae54ce35c27930a6f35b6bef15/charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76", size = 102441 },
+ { url = "https://files.pythonhosted.org/packages/0a/9a/dd1e1cdceb841925b7798369a09279bd1cf183cef0f9ddf15a3a6502ee45/charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545", size = 196105 },
+ { url = "https://files.pythonhosted.org/packages/d3/8c/90bfabf8c4809ecb648f39794cf2a84ff2e7d2a6cf159fe68d9a26160467/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7", size = 140404 },
+ { url = "https://files.pythonhosted.org/packages/ad/8f/e410d57c721945ea3b4f1a04b74f70ce8fa800d393d72899f0a40526401f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757", size = 150423 },
+ { url = "https://files.pythonhosted.org/packages/f0/b8/e6825e25deb691ff98cf5c9072ee0605dc2acfca98af70c2d1b1bc75190d/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa", size = 143184 },
+ { url = "https://files.pythonhosted.org/packages/3e/a2/513f6cbe752421f16d969e32f3583762bfd583848b763913ddab8d9bfd4f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d", size = 145268 },
+ { url = "https://files.pythonhosted.org/packages/74/94/8a5277664f27c3c438546f3eb53b33f5b19568eb7424736bdc440a88a31f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616", size = 147601 },
+ { url = "https://files.pythonhosted.org/packages/7c/5f/6d352c51ee763623a98e31194823518e09bfa48be2a7e8383cf691bbb3d0/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b", size = 141098 },
+ { url = "https://files.pythonhosted.org/packages/78/d4/f5704cb629ba5ab16d1d3d741396aec6dc3ca2b67757c45b0599bb010478/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d", size = 149520 },
+ { url = "https://files.pythonhosted.org/packages/c5/96/64120b1d02b81785f222b976c0fb79a35875457fa9bb40827678e54d1bc8/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a", size = 152852 },
+ { url = "https://files.pythonhosted.org/packages/84/c9/98e3732278a99f47d487fd3468bc60b882920cef29d1fa6ca460a1fdf4e6/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9", size = 150488 },
+ { url = "https://files.pythonhosted.org/packages/13/0e/9c8d4cb99c98c1007cc11eda969ebfe837bbbd0acdb4736d228ccaabcd22/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1", size = 146192 },
+ { url = "https://files.pythonhosted.org/packages/b2/21/2b6b5b860781a0b49427309cb8670785aa543fb2178de875b87b9cc97746/charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35", size = 95550 },
+ { url = "https://files.pythonhosted.org/packages/21/5b/1b390b03b1d16c7e382b561c5329f83cc06623916aab983e8ab9239c7d5c/charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f", size = 102785 },
+ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
+]
+
+[[package]]
+name = "clean-fid"
+version = "0.1.35"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "pillow" },
+ { name = "requests" },
+ { name = "scipy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/70/56/0dfc83e0fe455cfe6272b23a65039b4101c63a4e7446801e26178b675fbf/clean_fid-0.1.35-py3-none-any.whl", hash = "sha256:757cf49d75debe9b07ab14955fe59c845a296deaf0616153b40c5e75b3cf87fb", size = 26008 },
+]
+
+[[package]]
+name = "click"
+version = "8.1.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 },
+]
+
+[[package]]
+name = "clint"
+version = "0.5.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "args" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3d/b4/41ecb1516f1ba728f39ee7062b9dac1352d39823f513bb6f9e8aeb86e26d/clint-0.5.1.tar.gz", hash = "sha256:05224c32b1075563d0b16d0015faaf9da43aa214e4a2140e51f08789e7a4c5aa", size = 29355 }
+
+[[package]]
+name = "clip"
+version = "1.0"
+source = { git = "ssh://git@github.com/openai/CLIP.git?rev=dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1#dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1" }
+dependencies = [
+ { name = "ftfy" },
+ { name = "packaging" },
+ { name = "regex" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+
+[[package]]
+name = "cloudpickle"
+version = "3.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992 },
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
+]
+
+[[package]]
+name = "contourpy"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/25/c2/fc7193cc5383637ff390a712e88e4ded0452c9fbcf84abe3de5ea3df1866/contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699", size = 13465753 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b2/a3/80937fe3efe0edacf67c9a20b955139a1a622730042c1ea991956f2704ad/contourpy-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a045f341a77b77e1c5de31e74e966537bba9f3c4099b35bf4c2e3939dd54cdab", size = 268466 },
+ { url = "https://files.pythonhosted.org/packages/82/1d/e3eaebb4aa2d7311528c048350ca8e99cdacfafd99da87bc0a5f8d81f2c2/contourpy-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:500360b77259914f7805af7462e41f9cb7ca92ad38e9f94d6c8641b089338124", size = 253314 },
+ { url = "https://files.pythonhosted.org/packages/de/f3/d796b22d1a2b587acc8100ba8c07fb7b5e17fde265a7bb05ab967f4c935a/contourpy-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2f926efda994cdf3c8d3fdb40b9962f86edbc4457e739277b961eced3d0b4c1", size = 312003 },
+ { url = "https://files.pythonhosted.org/packages/bf/f5/0e67902bc4394daee8daa39c81d4f00b50e063ee1a46cb3938cc65585d36/contourpy-1.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:adce39d67c0edf383647a3a007de0a45fd1b08dedaa5318404f1a73059c2512b", size = 351896 },
+ { url = "https://files.pythonhosted.org/packages/1f/d6/e766395723f6256d45d6e67c13bb638dd1fa9dc10ef912dc7dd3dcfc19de/contourpy-1.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abbb49fb7dac584e5abc6636b7b2a7227111c4f771005853e7d25176daaf8453", size = 320814 },
+ { url = "https://files.pythonhosted.org/packages/a9/57/86c500d63b3e26e5b73a28b8291a67c5608d4aa87ebd17bd15bb33c178bc/contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0cffcbede75c059f535725c1680dfb17b6ba8753f0c74b14e6a9c68c29d7ea3", size = 324969 },
+ { url = "https://files.pythonhosted.org/packages/b8/62/bb146d1289d6b3450bccc4642e7f4413b92ebffd9bf2e91b0404323704a7/contourpy-1.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ab29962927945d89d9b293eabd0d59aea28d887d4f3be6c22deaefbb938a7277", size = 1265162 },
+ { url = "https://files.pythonhosted.org/packages/18/04/9f7d132ce49a212c8e767042cc80ae390f728060d2eea47058f55b9eff1c/contourpy-1.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:974d8145f8ca354498005b5b981165b74a195abfae9a8129df3e56771961d595", size = 1324328 },
+ { url = "https://files.pythonhosted.org/packages/46/23/196813901be3f97c83ababdab1382e13e0edc0bb4e7b49a7bff15fcf754e/contourpy-1.3.1-cp310-cp310-win32.whl", hash = "sha256:ac4578ac281983f63b400f7fe6c101bedc10651650eef012be1ccffcbacf3697", size = 173861 },
+ { url = "https://files.pythonhosted.org/packages/e0/82/c372be3fc000a3b2005061ca623a0d1ecd2eaafb10d9e883a2fc8566e951/contourpy-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:174e758c66bbc1c8576992cec9599ce8b6672b741b5d336b5c74e35ac382b18e", size = 218566 },
+ { url = "https://files.pythonhosted.org/packages/12/bb/11250d2906ee2e8b466b5f93e6b19d525f3e0254ac8b445b56e618527718/contourpy-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8b974d8db2c5610fb4e76307e265de0edb655ae8169e8b21f41807ccbeec4b", size = 269555 },
+ { url = "https://files.pythonhosted.org/packages/67/71/1e6e95aee21a500415f5d2dbf037bf4567529b6a4e986594d7026ec5ae90/contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20914c8c973f41456337652a6eeca26d2148aa96dd7ac323b74516988bea89fc", size = 254549 },
+ { url = "https://files.pythonhosted.org/packages/31/2c/b88986e8d79ac45efe9d8801ae341525f38e087449b6c2f2e6050468a42c/contourpy-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d40d37c1c3a4961b4619dd9d77b12124a453cc3d02bb31a07d58ef684d3d86", size = 313000 },
+ { url = "https://files.pythonhosted.org/packages/c4/18/65280989b151fcf33a8352f992eff71e61b968bef7432fbfde3a364f0730/contourpy-1.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:113231fe3825ebf6f15eaa8bc1f5b0ddc19d42b733345eae0934cb291beb88b6", size = 352925 },
+ { url = "https://files.pythonhosted.org/packages/f5/c7/5fd0146c93220dbfe1a2e0f98969293b86ca9bc041d6c90c0e065f4619ad/contourpy-1.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dbbc03a40f916a8420e420d63e96a1258d3d1b58cbdfd8d1f07b49fcbd38e85", size = 323693 },
+ { url = "https://files.pythonhosted.org/packages/85/fc/7fa5d17daf77306840a4e84668a48ddff09e6bc09ba4e37e85ffc8e4faa3/contourpy-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a04ecd68acbd77fa2d39723ceca4c3197cb2969633836ced1bea14e219d077c", size = 326184 },
+ { url = "https://files.pythonhosted.org/packages/ef/e7/104065c8270c7397c9571620d3ab880558957216f2b5ebb7e040f85eeb22/contourpy-1.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c414fc1ed8ee1dbd5da626cf3710c6013d3d27456651d156711fa24f24bd1291", size = 1268031 },
+ { url = "https://files.pythonhosted.org/packages/e2/4a/c788d0bdbf32c8113c2354493ed291f924d4793c4a2e85b69e737a21a658/contourpy-1.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:31c1b55c1f34f80557d3830d3dd93ba722ce7e33a0b472cba0ec3b6535684d8f", size = 1325995 },
+ { url = "https://files.pythonhosted.org/packages/a6/e6/a2f351a90d955f8b0564caf1ebe4b1451a3f01f83e5e3a414055a5b8bccb/contourpy-1.3.1-cp311-cp311-win32.whl", hash = "sha256:f611e628ef06670df83fce17805c344710ca5cde01edfdc72751311da8585375", size = 174396 },
+ { url = "https://files.pythonhosted.org/packages/a8/7e/cd93cab453720a5d6cb75588cc17dcdc08fc3484b9de98b885924ff61900/contourpy-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:b2bdca22a27e35f16794cf585832e542123296b4687f9fd96822db6bae17bfc9", size = 219787 },
+ { url = "https://files.pythonhosted.org/packages/37/6b/175f60227d3e7f5f1549fcb374592be311293132207e451c3d7c654c25fb/contourpy-1.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0ffa84be8e0bd33410b17189f7164c3589c229ce5db85798076a3fa136d0e509", size = 271494 },
+ { url = "https://files.pythonhosted.org/packages/6b/6a/7833cfae2c1e63d1d8875a50fd23371394f540ce809d7383550681a1fa64/contourpy-1.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805617228ba7e2cbbfb6c503858e626ab528ac2a32a04a2fe88ffaf6b02c32bc", size = 255444 },
+ { url = "https://files.pythonhosted.org/packages/7f/b3/7859efce66eaca5c14ba7619791b084ed02d868d76b928ff56890d2d059d/contourpy-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade08d343436a94e633db932e7e8407fe7de8083967962b46bdfc1b0ced39454", size = 307628 },
+ { url = "https://files.pythonhosted.org/packages/48/b2/011415f5e3f0a50b1e285a0bf78eb5d92a4df000553570f0851b6e309076/contourpy-1.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:47734d7073fb4590b4a40122b35917cd77be5722d80683b249dac1de266aac80", size = 347271 },
+ { url = "https://files.pythonhosted.org/packages/84/7d/ef19b1db0f45b151ac78c65127235239a8cf21a59d1ce8507ce03e89a30b/contourpy-1.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2ba94a401342fc0f8b948e57d977557fbf4d515f03c67682dd5c6191cb2d16ec", size = 318906 },
+ { url = "https://files.pythonhosted.org/packages/ba/99/6794142b90b853a9155316c8f470d2e4821fe6f086b03e372aca848227dd/contourpy-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efa874e87e4a647fd2e4f514d5e91c7d493697127beb95e77d2f7561f6905bd9", size = 323622 },
+ { url = "https://files.pythonhosted.org/packages/3c/0f/37d2c84a900cd8eb54e105f4fa9aebd275e14e266736778bb5dccbf3bbbb/contourpy-1.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1bf98051f1045b15c87868dbaea84f92408337d4f81d0e449ee41920ea121d3b", size = 1266699 },
+ { url = "https://files.pythonhosted.org/packages/3a/8a/deb5e11dc7d9cc8f0f9c8b29d4f062203f3af230ba83c30a6b161a6effc9/contourpy-1.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:61332c87493b00091423e747ea78200659dc09bdf7fd69edd5e98cef5d3e9a8d", size = 1326395 },
+ { url = "https://files.pythonhosted.org/packages/1a/35/7e267ae7c13aaf12322ccc493531f1e7f2eb8fba2927b9d7a05ff615df7a/contourpy-1.3.1-cp312-cp312-win32.whl", hash = "sha256:e914a8cb05ce5c809dd0fe350cfbb4e881bde5e2a38dc04e3afe1b3e58bd158e", size = 175354 },
+ { url = "https://files.pythonhosted.org/packages/a1/35/c2de8823211d07e8a79ab018ef03960716c5dff6f4d5bff5af87fd682992/contourpy-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:08d9d449a61cf53033612cb368f3a1b26cd7835d9b8cd326647efe43bca7568d", size = 220971 },
+ { url = "https://files.pythonhosted.org/packages/3e/4f/e56862e64b52b55b5ddcff4090085521fc228ceb09a88390a2b103dccd1b/contourpy-1.3.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b457d6430833cee8e4b8e9b6f07aa1c161e5e0d52e118dc102c8f9bd7dd060d6", size = 265605 },
+ { url = "https://files.pythonhosted.org/packages/b0/2e/52bfeeaa4541889f23d8eadc6386b442ee2470bd3cff9baa67deb2dd5c57/contourpy-1.3.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb76c1a154b83991a3cbbf0dfeb26ec2833ad56f95540b442c73950af2013750", size = 315040 },
+ { url = "https://files.pythonhosted.org/packages/52/94/86bfae441707205634d80392e873295652fc313dfd93c233c52c4dc07874/contourpy-1.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53", size = 218221 },
+]
+
+[[package]]
+name = "cycler"
+version = "0.12.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 },
+]
+
+[[package]]
+name = "datasets"
+version = "3.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "aiohttp" },
+ { name = "dill" },
+ { name = "filelock" },
+ { name = "fsspec", extra = ["http"] },
+ { name = "huggingface-hub" },
+ { name = "multiprocess" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "pandas" },
+ { name = "pyarrow" },
+ { name = "pyyaml" },
+ { name = "requests" },
+ { name = "tqdm" },
+ { name = "xxhash" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fc/48/744286c044e2b942d4fa67f92816126522ad1f0675def0ea3264e6242005/datasets-3.2.0.tar.gz", hash = "sha256:9a6e1a356052866b5dbdd9c9eedb000bf3fc43d986e3584d9b028f4976937229", size = 558366 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d7/84/0df6c5981f5fc722381662ff8cfbdf8aad64bec875f75d80b55bfef394ce/datasets-3.2.0-py3-none-any.whl", hash = "sha256:f3d2ba2698b7284a4518019658596a6a8bc79f31e51516524249d6c59cf0fe2a", size = 480647 },
+]
+
+[[package]]
+name = "decorator"
+version = "5.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 },
+]
+
+[[package]]
+name = "deepspeed"
+version = "0.16.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "einops" },
+ { name = "hjson" },
+ { name = "msgpack" },
+ { name = "ninja" },
+ { name = "numpy" },
+ { name = "nvidia-ml-py" },
+ { name = "packaging" },
+ { name = "psutil" },
+ { name = "py-cpuinfo" },
+ { name = "pydantic" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/69/c5/9617171e726b11a7a753315906ae0ee8944925c9b6f975f22845c382eda7/deepspeed-0.16.3.tar.gz", hash = "sha256:3877b4d825fc940e3752f4e905f911a45623aa6f9f81b072a0d5deb87a2b785a", size = 1428532 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/23/cfe2166f06f1817b73b8695addd2fded55f4189322bcf181cfb7feaba550/deepspeed-0.16.3-cp310-cp310-win_amd64.whl", hash = "sha256:3c083645561b6d2d7246c30d63c900a1edf07fc302e27744fe264aca5fdaa456", size = 42055395 },
+ { url = "https://files.pythonhosted.org/packages/ce/92/853de2a32b3f9c0945e0c9e145eddac98e570c1faebcd58aa85e6dcd89ba/deepspeed-0.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:719b1a417c2863f32f482d9948cb575f7fa08cc6d548250e6557057101c9aeab", size = 42071961 },
+ { url = "https://files.pythonhosted.org/packages/55/8f/57ee424e114478a667c604b2f3f05bbc1addc0afc089b9ab6a64b233063d/deepspeed-0.16.3-cp312-cp312-win_amd64.whl", hash = "sha256:8b1016f8d89f1617d9e92c3f03a93dfc32bbfd173aecbb06d209fe6d593ed766", size = 42072463 },
+]
+
+[[package]]
+name = "diffusers"
+version = "0.32.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "filelock" },
+ { name = "huggingface-hub" },
+ { name = "importlib-metadata" },
+ { name = "numpy" },
+ { name = "pillow" },
+ { name = "regex" },
+ { name = "requests" },
+ { name = "safetensors" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ee/72/d3f715e71a77f14a48f1ac081cea5bb6c6e6cd41c55b1291f401b1504679/diffusers-0.32.2.tar.gz", hash = "sha256:eb1e36b326aabb0675729af7c626caf7a76ce7ced3a126e879331790b1eaa230", size = 2614622 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/88/c9/2b2e822d871b06605363e52555be5d7ff0996f4eeaf0f7e5fda15adedfbe/diffusers-0.32.2-py3-none-any.whl", hash = "sha256:d7f182b49c7f428737ee3bf6397d463ec03b85f4f3b2c9470bd1d73292b609ff", size = 3226075 },
+]
+
+[[package]]
+name = "dill"
+version = "0.3.8"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 },
+]
+
+[[package]]
+name = "docker-pycreds"
+version = "0.4.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c5/e6/d1f6c00b7221e2d7c4b470132c931325c8b22c51ca62417e300f5ce16009/docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4", size = 8754 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49", size = 8982 },
+]
+
+[[package]]
+name = "einops"
+version = "0.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/79/ca/9f5dcb8bead39959454c3912266bedc4c315839cee0e0ca9f4328f4588c1/einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85", size = 58861 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f", size = 43223 },
+]
+
+[[package]]
+name = "evaluate"
+version = "0.4.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "datasets" },
+ { name = "dill" },
+ { name = "fsspec", extra = ["http"] },
+ { name = "huggingface-hub" },
+ { name = "multiprocess" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "pandas" },
+ { name = "requests" },
+ { name = "tqdm" },
+ { name = "xxhash" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5a/a0/10a56e0939ece94c54276e81459cb4101f46f0e9a6f54fc31a35f64e8854/evaluate-0.4.3.tar.gz", hash = "sha256:3a5700cf83aabee9549264e1e5666f116367c61dbd4d38352015e859a5e2098d", size = 65679 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a2/e7/cbca9e2d2590eb9b5aa8f7ebabe1beb1498f9462d2ecede5c9fd9735faaf/evaluate-0.4.3-py3-none-any.whl", hash = "sha256:47d8770bdea76e2c2ed0d40189273027d1a41ccea861bcc7ba12d30ec5d1e517", size = 84010 },
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 },
+]
+
+[[package]]
+name = "executing"
+version = "2.2.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
+]
+
+[[package]]
+name = "faiss-cpu"
+version = "1.10.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/91/1b/6fe5dbe5be0240cfd82b52bd7c186655c578d935c0ce2e713c100e6f8cce/faiss_cpu-1.10.0.tar.gz", hash = "sha256:5bdca555f24bc036f4d67f8a5a4d6cc91b8d2126d4e78de496ca23ccd46e479d", size = 69159 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8b/56/87eb506d8634f08fc7c63d1ca5631aeec7d6b9afbfabedf2cb7a2a804b13/faiss_cpu-1.10.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:6693474be296a7142ade1051ea18e7d85cedbfdee4b7eac9c52f83fed0467855", size = 7693034 },
+ { url = "https://files.pythonhosted.org/packages/51/46/f4d9de34ed1b06300b1a75b824d4857963216f5826de33f291af78088e39/faiss_cpu-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:70ebe60a560414dc8dd6cfe8fed105c8f002c0d11f765f5adfe8d63d42c0467f", size = 3234656 },
+ { url = "https://files.pythonhosted.org/packages/74/3a/e146861019d9290e0198b3470b8d13a658c3b5f228abefc3658ce0afd63d/faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:74c5712d4890f15c661ab7b1b75867812e9596e1469759956fad900999bedbb5", size = 3663789 },
+ { url = "https://files.pythonhosted.org/packages/aa/40/624f0002bb777e37aac1aadfadec1eb4391be6ad05b7fcfbf66049b99a48/faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:473d158fbd638d6ad5fb64469ba79a9f09d3494b5f4e8dfb4f40ce2fc335dca4", size = 30673545 },
+ { url = "https://files.pythonhosted.org/packages/d6/39/298ffcbefd899e84a43e63df217a6dc800d52bca37ebe0d1155ff367886a/faiss_cpu-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:dcd0cb2ec84698cbe3df9ed247d2392f09bda041ad34b92d38fa916cd019ad4b", size = 13684176 },
+ { url = "https://files.pythonhosted.org/packages/78/93/81800f41cb2c719c199d3eb534fcc154853123261d841e37482e8e468619/faiss_cpu-1.10.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ff6924b0f00df278afe70940ae86302066466580724c2f3238860039e9946f1", size = 7693037 },
+ { url = "https://files.pythonhosted.org/packages/8d/83/fc9028f6d6aec2c2f219f53a5d4a2b279434715643242e59a2e9755b1ce0/faiss_cpu-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cb80b530a9ded44a7d4031a7355a237aaa0ff1f150c1176df050e0254ea5f6f6", size = 3234657 },
+ { url = "https://files.pythonhosted.org/packages/af/45/588a02e60daa73f6052611334fbbdffcedf37122320f1c91cb90f3e69b96/faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:7a9fef4039ed877d40e41d5563417b154c7f8cd57621487dad13c4eb4f32515f", size = 3663710 },
+ { url = "https://files.pythonhosted.org/packages/cb/cf/9caa08ca4e21ab935f82be0713e5d60566140414c3fff7932d9427c8fd72/faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49b6647aa9e159a2c4603cbff2e1b313becd98ad6e851737ab325c74fe8e0278", size = 30673629 },
+ { url = "https://files.pythonhosted.org/packages/2c/2d/d2a4171a9cca9a7c04cd9d6f9441a37f1e0558724b90bf7fc7db08553601/faiss_cpu-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:6f8c0ef8b615c12c7bf612bd1fc51cffa49c1ddaa6207c6981f01ab6782e6b3b", size = 13683966 },
+ { url = "https://files.pythonhosted.org/packages/bd/cc/f6aa1288dbb40b2a4f101d16900885e056541f37d8d08ec70462e92cf277/faiss_cpu-1.10.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:2aca486fe2d680ea64a18d356206c91ff85db99fd34c19a757298c67c23262b1", size = 7720242 },
+ { url = "https://files.pythonhosted.org/packages/be/56/40901306324a17fbc1eee8a6e86ba67bd99a67e768ce9908f271e648e9e0/faiss_cpu-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c1108a4059c66c37c403183e566ca1ed0974a6af7557c92d49207639aab661bc", size = 3239223 },
+ { url = "https://files.pythonhosted.org/packages/2e/34/5b1463c450c9a6de3109caf8f38fbf0c329ef940ed1973fcf8c8ec7fa27e/faiss_cpu-1.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:449f3eb778d6d937e01a16a3170de4bb8aabfe87c7cb479b458fb790276310c5", size = 3671461 },
+ { url = "https://files.pythonhosted.org/packages/78/d9/0b78c474289f23b31283d8fb64c8e6a522a7fa47b131a3c6c141c8e6639d/faiss_cpu-1.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9899c340f92bd94071d6faf4bef0ccb5362843daea42144d4ba857a2a1f67511", size = 30663859 },
+ { url = "https://files.pythonhosted.org/packages/17/f0/194727b9e6e282e2877bc001ba886228f6af52e9a6730bbdb223e38591c3/faiss_cpu-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:345a52dbfa980d24b93c94410eadf82d1eef359c6a42e5e0768cca96539f1c3c", size = 13687087 },
+]
+
+[[package]]
+name = "fastapi"
+version = "0.115.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "starlette" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a2/b2/5a5dc4affdb6661dea100324e19a7721d5dc524b464fe8e366c093fd7d87/fastapi-0.115.8.tar.gz", hash = "sha256:0ce9111231720190473e222cdf0f07f7206ad7e53ea02beb1d2dc36e2f0741e9", size = 295403 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8f/7d/2d6ce181d7a5f51dedb8c06206cbf0ec026a99bf145edd309f9e17c3282f/fastapi-0.115.8-py3-none-any.whl", hash = "sha256:753a96dd7e036b34eeef8babdfcfe3f28ff79648f86551eb36bfc1b0bf4a8cbf", size = 94814 },
+]
+
+[[package]]
+name = "fastcore"
+version = "1.7.29"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a1/a6/f457241a8a5c42b80ef50b96e7cc515dd93bdb9ea273133004bbc8a6aa96/fastcore-1.7.29.tar.gz", hash = "sha256:e7e734cbe58805a22c205341c6671de562a8abba54b13eeb24cdb4486d066e31", size = 80514 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d7/3a/a0b1c764426622287c9b6547d4ea637c406bc884141814df4a5ebab3ab9b/fastcore-1.7.29-py3-none-any.whl", hash = "sha256:76fd4815eabbed704faca3abfea4b7e1f98b6351ba6c869a2d405f37bc4b0074", size = 84208 },
+]
+
+[[package]]
+name = "fastlite"
+version = "0.1.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "apswutils" },
+ { name = "fastcore" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/15/d6/0a6dc989095fc973e9d97f61c2d637042a59b53e6ae4bab23c68dfbfbcd8/fastlite-0.1.1.tar.gz", hash = "sha256:cbbbc70b3a58189416627a5eaa8f3f88c8c93fa2262e151c1be5705d657177c8", size = 20888 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1d/cc/e0997edb370cadd4bb2e211da368ed6c4fa5cfe9d983cf38becc3e637eb9/fastlite-0.1.1-py3-none-any.whl", hash = "sha256:4e1988d9dc720a97f9717999b67a6ff45c0e3e323cf53af48e45cb9ac91b7e5a", size = 16644 },
+]
+
+[[package]]
+name = "filelock"
+version = "3.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/dc/9c/0b15fb47b464e1b663b1acd1253a062aa5feecb07d4e597daea542ebd2b5/filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e", size = 18027 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/89/ec/00d68c4ddfedfe64159999e5f8a98fb8442729a63e2077eb9dcd89623d27/filelock-3.17.0-py3-none-any.whl", hash = "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338", size = 16164 },
+]
+
+[[package]]
+name = "flash-attn"
+version = "2.7.4.post1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "einops" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/11/34/9bf60e736ed7bbe15055ac2dab48ec67d9dbd088d2b4ae318fd77190ab4e/flash_attn-2.7.4.post1.tar.gz", hash = "sha256:f03485c9a49a4d68d0733acdcad80ab0e72afa025a777fdc2966ceccf9d51765", size = 5986610 }
+
+[[package]]
+name = "flask"
+version = "3.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "blinker" },
+ { name = "click" },
+ { name = "itsdangerous" },
+ { name = "jinja2" },
+ { name = "werkzeug" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/89/50/dff6380f1c7f84135484e176e0cac8690af72fa90e932ad2a0a60e28c69b/flask-3.1.0.tar.gz", hash = "sha256:5f873c5184c897c8d9d1b05df1e3d01b14910ce69607a117bd3277098a5836ac", size = 680824 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/af/47/93213ee66ef8fae3b93b3e29206f6b251e65c97bd91d8e1c5596ef15af0a/flask-3.1.0-py3-none-any.whl", hash = "sha256:d667207822eb83f1c4b50949b1623c8fc8d51f2341d65f72e1a1815397551136", size = 102979 },
+]
+
+[[package]]
+name = "fonttools"
+version = "4.55.8"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f1/24/de7e40adc99be2aa5adc6321bbdf3cf58dbe751b87343da658dd3fc7d946/fonttools-4.55.8.tar.gz", hash = "sha256:54d481d456dcd59af25d4a9c56b2c4c3f20e9620b261b84144e5950f33e8df17", size = 3458915 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/b8/82b3444cb081798eabb8397452ddf73680e623d7fdf9c575594a2240b8a2/fonttools-4.55.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d11600f5343092697d7434f3bf77a393c7ae74be206fe30e577b9a195fd53165", size = 2752288 },
+ { url = "https://files.pythonhosted.org/packages/86/8f/9c5f2172e9f6dcf52bb6477bcd5a023d056114787c8184b683c34996f5a1/fonttools-4.55.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c96f2506ce1a0beeaa9595f9a8b7446477eb133f40c0e41fc078744c28149f80", size = 2280718 },
+ { url = "https://files.pythonhosted.org/packages/c6/a6/b7cd7b54412bb7a27e282ee54459cae24524ad0eab6f81ead2a91d435287/fonttools-4.55.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b5f05ef72e846e9f49ccdd74b9da4309901a4248434c63c1ee9321adcb51d65", size = 4562177 },
+ { url = "https://files.pythonhosted.org/packages/0e/16/eff3be24cecb9336639148c40507f949c193642d8369352af480597633fb/fonttools-4.55.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba45b637da80a262b55b7657aec68da2ac54b8ae7891cd977a5dbe5fd26db429", size = 4604843 },
+ { url = "https://files.pythonhosted.org/packages/b5/95/737574364439cbcc5e6d4f3e000f15432141680ca8cb5c216b619a3d1cab/fonttools-4.55.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:edcffaeadba9a334c1c3866e275d7dd495465e7dbd296f688901bdbd71758113", size = 4559127 },
+ { url = "https://files.pythonhosted.org/packages/5f/07/ea90834742f9b3e51a05f0f15f7c817eb7aab3d6ebf4f06c4626825ccb89/fonttools-4.55.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b9f9fce3c9b2196e162182ec5db8af8eb3acd0d76c2eafe9fdba5f370044e556", size = 4728575 },
+ { url = "https://files.pythonhosted.org/packages/93/74/0c816d83cd2945a25aed592b0cb3c9ba32e8b259781bf41dc112204129d9/fonttools-4.55.8-cp310-cp310-win32.whl", hash = "sha256:f089e8da0990cfe2d67e81d9cf581ff372b48dc5acf2782701844211cd1f0eb3", size = 2155662 },
+ { url = "https://files.pythonhosted.org/packages/78/bc/f5a24229edd8cdd7494f2099e1c62fca288dad4c8637ee62df04459db27e/fonttools-4.55.8-cp310-cp310-win_amd64.whl", hash = "sha256:01ea3901b0802fc5f9e854f5aeb5bc27770dd9dd24c28df8f74ba90f8b3f5915", size = 2200126 },
+ { url = "https://files.pythonhosted.org/packages/0a/e3/834e0919b34b40a6a2895f533323231bba3b8f5ae22c19ab725b84cf84c0/fonttools-4.55.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:95f5a1d4432b3cea6571f5ce4f4e9b25bf36efbd61c32f4f90130a690925d6ee", size = 2753424 },
+ { url = "https://files.pythonhosted.org/packages/b6/f9/9cf7fc04da85d37cfa1c287f0a25c274d6940dad259dbaa9fd796b87bd3c/fonttools-4.55.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d20f152de7625a0008ba1513f126daaaa0de3b4b9030aa72dd5c27294992260", size = 2281635 },
+ { url = "https://files.pythonhosted.org/packages/35/1f/25330293a5bb6bd50825725270c587c2b25c2694020a82d2c424d2fd5469/fonttools-4.55.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a3ff5bb95fd5a3962b2754f8435e6d930c84fc9e9921c51e802dddf40acd56", size = 4869363 },
+ { url = "https://files.pythonhosted.org/packages/f2/e0/e58b10ef50830145ba94dbeb64b70773af61cfccea663d485c7fae2aab65/fonttools-4.55.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b99d4fd2b6d0a00c7336c8363fccc7a11eccef4b17393af75ca6e77cf93ff413", size = 4898604 },
+ { url = "https://files.pythonhosted.org/packages/e0/66/b59025011dbae1ea10dcb60f713a10e54d17cde5c8dc48db75af79dc2088/fonttools-4.55.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d637e4d33e46619c79d1a6c725f74d71b574cd15fb5bbb9b6f3eba8f28363573", size = 4877804 },
+ { url = "https://files.pythonhosted.org/packages/67/76/abbbae972af55d54f83fcaeb90e26aaac937c8711b5a32d7c63768c37891/fonttools-4.55.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0f38bfb6b7a39c4162c3eb0820a0bdf8e3bdd125cd54e10ba242397d15e32439", size = 5045913 },
+ { url = "https://files.pythonhosted.org/packages/8b/f2/5eb68b5202731b008ccfd4ad6d82af9a8abdec411609e76fdd6c43881f2c/fonttools-4.55.8-cp311-cp311-win32.whl", hash = "sha256:acfec948de41cd5e640d5c15d0200e8b8e7c5c6bb82afe1ca095cbc4af1188ee", size = 2154525 },
+ { url = "https://files.pythonhosted.org/packages/42/d6/96dc2462006ffa16c8d475244e372abdc47d03a7bd38be0f29e7ae552af4/fonttools-4.55.8-cp311-cp311-win_amd64.whl", hash = "sha256:604c805b41241b4880e2dc86cf2d4754c06777371c8299799ac88d836cb18c3b", size = 2201043 },
+ { url = "https://files.pythonhosted.org/packages/e9/ce/8358af1c353d890d4c6cbcc3d64242631f91a93f8384b76bc49db800f1de/fonttools-4.55.8-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:63403ee0f2fa4e1de28e539f8c24f2bdca1d8ecb503fa9ea2d231d9f1e729809", size = 2747851 },
+ { url = "https://files.pythonhosted.org/packages/1b/3d/7a906f58f80c1ed37bbdf7b3f9b6792906156cb9143b067bf54c38405134/fonttools-4.55.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:302e1003a760b222f711d5ba6d1ad7fd5f7f713eb872cd6a3eb44390bc9770af", size = 2279102 },
+ { url = "https://files.pythonhosted.org/packages/0a/0a/91a923a9de012e0f751ef8e13e1a5ea10f3a1b8416ae9afd5db1ad351b20/fonttools-4.55.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e72a7816ff8a759be9ca36ca46934f8ccf4383711ef597d9240306fe1878cb8d", size = 4784092 },
+ { url = "https://files.pythonhosted.org/packages/e8/07/4b8a5c8a746cc8c8103c6462d057d8806bd925347ac3905055686dd40e94/fonttools-4.55.8-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03c2b50b54e6e8b3564b232e57e8f58be217cf441cf0155745d9e44a76f9c30f", size = 4855206 },
+ { url = "https://files.pythonhosted.org/packages/37/df/09bf09ff8eae1e74bf16f9df514fd60af9f3d994e3edb0339f7d0bbc59e2/fonttools-4.55.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a7230f7590f9570d26ee903b6a4540274494e200fae978df0d9325b7b9144529", size = 4762599 },
+ { url = "https://files.pythonhosted.org/packages/84/58/a80d97818a3bede7e4b58318302e89e749b9639c890ecbc972a6e533201f/fonttools-4.55.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:466a78984f0572305c3c48377f4e3f7f4e909f1209f45ef8e7041d5c8a744a56", size = 4990188 },
+ { url = "https://files.pythonhosted.org/packages/a8/e3/1f1b1a70527ab9a1b9bfe1829a783a042c108ab3357af626e8e69a21f0e2/fonttools-4.55.8-cp312-cp312-win32.whl", hash = "sha256:243cbfc0b7cb1c307af40e321f8343a48d0a080bc1f9466cf2b5468f776ef108", size = 2142995 },
+ { url = "https://files.pythonhosted.org/packages/61/cf/08c4954c944799458690eb0e498209fb6a2e79e20a869189f56d18e909b6/fonttools-4.55.8-cp312-cp312-win_amd64.whl", hash = "sha256:a19059aa892676822c1f05cb5a67296ecdfeb267fe7c47d4758f3e8e942c2b2a", size = 2189833 },
+ { url = "https://files.pythonhosted.org/packages/cc/e6/efdcd5d6858b951c29d56de31a19355579d826712bf390d964a21b076ddb/fonttools-4.55.8-py3-none-any.whl", hash = "sha256:07636dae94f7fe88561f9da7a46b13d8e3f529f87fdb221b11d85f91eabceeb7", size = 1089900 },
+]
+
+[[package]]
+name = "frozenlist"
+version = "1.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/8f/ed/0f4cec13a93c02c47ec32d81d11c0c1efbadf4a471e3f3ce7cad366cbbd3/frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817", size = 39930 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/79/29d44c4af36b2b240725dce566b20f63f9b36ef267aaaa64ee7466f4f2f8/frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a", size = 94451 },
+ { url = "https://files.pythonhosted.org/packages/47/47/0c999aeace6ead8a44441b4f4173e2261b18219e4ad1fe9a479871ca02fc/frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb", size = 54301 },
+ { url = "https://files.pythonhosted.org/packages/8d/60/107a38c1e54176d12e06e9d4b5d755b677d71d1219217cee063911b1384f/frozenlist-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15538c0cbf0e4fa11d1e3a71f823524b0c46299aed6e10ebb4c2089abd8c3bec", size = 52213 },
+ { url = "https://files.pythonhosted.org/packages/17/62/594a6829ac5679c25755362a9dc93486a8a45241394564309641425d3ff6/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5", size = 240946 },
+ { url = "https://files.pythonhosted.org/packages/7e/75/6c8419d8f92c80dd0ee3f63bdde2702ce6398b0ac8410ff459f9b6f2f9cb/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76", size = 264608 },
+ { url = "https://files.pythonhosted.org/packages/88/3e/82a6f0b84bc6fb7e0be240e52863c6d4ab6098cd62e4f5b972cd31e002e8/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17", size = 261361 },
+ { url = "https://files.pythonhosted.org/packages/fd/85/14e5f9ccac1b64ff2f10c927b3ffdf88772aea875882406f9ba0cec8ad84/frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba", size = 231649 },
+ { url = "https://files.pythonhosted.org/packages/ee/59/928322800306f6529d1852323014ee9008551e9bb027cc38d276cbc0b0e7/frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d", size = 241853 },
+ { url = "https://files.pythonhosted.org/packages/7d/bd/e01fa4f146a6f6c18c5d34cab8abdc4013774a26c4ff851128cd1bd3008e/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2", size = 243652 },
+ { url = "https://files.pythonhosted.org/packages/a5/bd/e4771fd18a8ec6757033f0fa903e447aecc3fbba54e3630397b61596acf0/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f", size = 241734 },
+ { url = "https://files.pythonhosted.org/packages/21/13/c83821fa5544af4f60c5d3a65d054af3213c26b14d3f5f48e43e5fb48556/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c", size = 260959 },
+ { url = "https://files.pythonhosted.org/packages/71/f3/1f91c9a9bf7ed0e8edcf52698d23f3c211d8d00291a53c9f115ceb977ab1/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab", size = 262706 },
+ { url = "https://files.pythonhosted.org/packages/4c/22/4a256fdf5d9bcb3ae32622c796ee5ff9451b3a13a68cfe3f68e2c95588ce/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5", size = 250401 },
+ { url = "https://files.pythonhosted.org/packages/af/89/c48ebe1f7991bd2be6d5f4ed202d94960c01b3017a03d6954dd5fa9ea1e8/frozenlist-1.5.0-cp310-cp310-win32.whl", hash = "sha256:977701c081c0241d0955c9586ffdd9ce44f7a7795df39b9151cd9a6fd0ce4cfb", size = 45498 },
+ { url = "https://files.pythonhosted.org/packages/28/2f/cc27d5f43e023d21fe5c19538e08894db3d7e081cbf582ad5ed366c24446/frozenlist-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:189f03b53e64144f90990d29a27ec4f7997d91ed3d01b51fa39d2dbe77540fd4", size = 51622 },
+ { url = "https://files.pythonhosted.org/packages/79/43/0bed28bf5eb1c9e4301003b74453b8e7aa85fb293b31dde352aac528dafc/frozenlist-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fd74520371c3c4175142d02a976aee0b4cb4a7cc912a60586ffd8d5929979b30", size = 94987 },
+ { url = "https://files.pythonhosted.org/packages/bb/bf/b74e38f09a246e8abbe1e90eb65787ed745ccab6eaa58b9c9308e052323d/frozenlist-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2f3f7a0fbc219fb4455264cae4d9f01ad41ae6ee8524500f381de64ffaa077d5", size = 54584 },
+ { url = "https://files.pythonhosted.org/packages/2c/31/ab01375682f14f7613a1ade30149f684c84f9b8823a4391ed950c8285656/frozenlist-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f47c9c9028f55a04ac254346e92977bf0f166c483c74b4232bee19a6697e4778", size = 52499 },
+ { url = "https://files.pythonhosted.org/packages/98/a8/d0ac0b9276e1404f58fec3ab6e90a4f76b778a49373ccaf6a563f100dfbc/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a", size = 276357 },
+ { url = "https://files.pythonhosted.org/packages/ad/c9/c7761084fa822f07dac38ac29f841d4587570dd211e2262544aa0b791d21/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869", size = 287516 },
+ { url = "https://files.pythonhosted.org/packages/a1/ff/cd7479e703c39df7bdab431798cef89dc75010d8aa0ca2514c5b9321db27/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d", size = 283131 },
+ { url = "https://files.pythonhosted.org/packages/59/a0/370941beb47d237eca4fbf27e4e91389fd68699e6f4b0ebcc95da463835b/frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45", size = 261320 },
+ { url = "https://files.pythonhosted.org/packages/b8/5f/c10123e8d64867bc9b4f2f510a32042a306ff5fcd7e2e09e5ae5100ee333/frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d", size = 274877 },
+ { url = "https://files.pythonhosted.org/packages/fa/79/38c505601ae29d4348f21706c5d89755ceded02a745016ba2f58bd5f1ea6/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3", size = 269592 },
+ { url = "https://files.pythonhosted.org/packages/19/e2/39f3a53191b8204ba9f0bb574b926b73dd2efba2a2b9d2d730517e8f7622/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a", size = 265934 },
+ { url = "https://files.pythonhosted.org/packages/d5/c9/3075eb7f7f3a91f1a6b00284af4de0a65a9ae47084930916f5528144c9dd/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9", size = 283859 },
+ { url = "https://files.pythonhosted.org/packages/05/f5/549f44d314c29408b962fa2b0e69a1a67c59379fb143b92a0a065ffd1f0f/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2", size = 287560 },
+ { url = "https://files.pythonhosted.org/packages/9d/f8/cb09b3c24a3eac02c4c07a9558e11e9e244fb02bf62c85ac2106d1eb0c0b/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf", size = 277150 },
+ { url = "https://files.pythonhosted.org/packages/37/48/38c2db3f54d1501e692d6fe058f45b6ad1b358d82cd19436efab80cfc965/frozenlist-1.5.0-cp311-cp311-win32.whl", hash = "sha256:237f6b23ee0f44066219dae14c70ae38a63f0440ce6750f868ee08775073f942", size = 45244 },
+ { url = "https://files.pythonhosted.org/packages/ca/8c/2ddffeb8b60a4bce3b196c32fcc30d8830d4615e7b492ec2071da801b8ad/frozenlist-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0cc974cc93d32c42e7b0f6cf242a6bd941c57c61b618e78b6c0a96cb72788c1d", size = 51634 },
+ { url = "https://files.pythonhosted.org/packages/79/73/fa6d1a96ab7fd6e6d1c3500700963eab46813847f01ef0ccbaa726181dd5/frozenlist-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:31115ba75889723431aa9a4e77d5f398f5cf976eea3bdf61749731f62d4a4a21", size = 94026 },
+ { url = "https://files.pythonhosted.org/packages/ab/04/ea8bf62c8868b8eada363f20ff1b647cf2e93377a7b284d36062d21d81d1/frozenlist-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7437601c4d89d070eac8323f121fcf25f88674627505334654fd027b091db09d", size = 54150 },
+ { url = "https://files.pythonhosted.org/packages/d0/9a/8e479b482a6f2070b26bda572c5e6889bb3ba48977e81beea35b5ae13ece/frozenlist-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7948140d9f8ece1745be806f2bfdf390127cf1a763b925c4a805c603df5e697e", size = 51927 },
+ { url = "https://files.pythonhosted.org/packages/e3/12/2aad87deb08a4e7ccfb33600871bbe8f0e08cb6d8224371387f3303654d7/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a", size = 282647 },
+ { url = "https://files.pythonhosted.org/packages/77/f2/07f06b05d8a427ea0060a9cef6e63405ea9e0d761846b95ef3fb3be57111/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a", size = 289052 },
+ { url = "https://files.pythonhosted.org/packages/bd/9f/8bf45a2f1cd4aa401acd271b077989c9267ae8463e7c8b1eb0d3f561b65e/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee", size = 291719 },
+ { url = "https://files.pythonhosted.org/packages/41/d1/1f20fd05a6c42d3868709b7604c9f15538a29e4f734c694c6bcfc3d3b935/frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6", size = 267433 },
+ { url = "https://files.pythonhosted.org/packages/af/f2/64b73a9bb86f5a89fb55450e97cd5c1f84a862d4ff90d9fd1a73ab0f64a5/frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e", size = 283591 },
+ { url = "https://files.pythonhosted.org/packages/29/e2/ffbb1fae55a791fd6c2938dd9ea779509c977435ba3940b9f2e8dc9d5316/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9", size = 273249 },
+ { url = "https://files.pythonhosted.org/packages/2e/6e/008136a30798bb63618a114b9321b5971172a5abddff44a100c7edc5ad4f/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039", size = 271075 },
+ { url = "https://files.pythonhosted.org/packages/ae/f0/4e71e54a026b06724cec9b6c54f0b13a4e9e298cc8db0f82ec70e151f5ce/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784", size = 285398 },
+ { url = "https://files.pythonhosted.org/packages/4d/36/70ec246851478b1c0b59f11ef8ade9c482ff447c1363c2bd5fad45098b12/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631", size = 294445 },
+ { url = "https://files.pythonhosted.org/packages/37/e0/47f87544055b3349b633a03c4d94b405956cf2437f4ab46d0928b74b7526/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f", size = 280569 },
+ { url = "https://files.pythonhosted.org/packages/f9/7c/490133c160fb6b84ed374c266f42800e33b50c3bbab1652764e6e1fc498a/frozenlist-1.5.0-cp312-cp312-win32.whl", hash = "sha256:29d94c256679247b33a3dc96cce0f93cbc69c23bf75ff715919332fdbb6a32b8", size = 44721 },
+ { url = "https://files.pythonhosted.org/packages/b1/56/4e45136ffc6bdbfa68c29ca56ef53783ef4c2fd395f7cbf99a2624aa9aaa/frozenlist-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:8969190d709e7c48ea386db202d708eb94bdb29207a1f269bab1196ce0dcca1f", size = 51329 },
+ { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 },
+]
+
+[[package]]
+name = "fsspec"
+version = "2024.9.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/62/7c/12b0943011daaaa9c35c2a2e22e5eb929ac90002f08f1259d69aedad84de/fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8", size = 286206 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 },
+]
+
+[package.optional-dependencies]
+http = [
+ { name = "aiohttp" },
+]
+
+[[package]]
+name = "ftfy"
+version = "6.3.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "wcwidth" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821 },
+]
+
+[[package]]
+name = "gitdb"
+version = "4.0.12"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "smmap" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794 },
+]
+
+[[package]]
+name = "gitpython"
+version = "3.1.44"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "gitdb" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c0/89/37df0b71473153574a5cdef8f242de422a0f5d26d7a9e231e6f169b4ad14/gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269", size = 214196 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 },
+]
+
+[[package]]
+name = "glob2"
+version = "0.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/a5/bbbc3b74a94fbdbd7915e7ad030f16539bfdc1362f7e9003b594f0537950/glob2-0.7.tar.gz", hash = "sha256:85c3dbd07c8aa26d63d7aacee34fa86e9a91a3873bc30bf62ec46e531f92ab8c", size = 10697 }
+
+[[package]]
+name = "grpcio"
+version = "1.70.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/69/e1/4b21b5017c33f3600dcc32b802bb48fe44a4d36d6c066f52650c7c2690fa/grpcio-1.70.0.tar.gz", hash = "sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56", size = 12788932 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/10/e9/f72408bac1f7b05b25e4df569b02d6b200c8e7857193aa9f1df7a3744add/grpcio-1.70.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851", size = 5229736 },
+ { url = "https://files.pythonhosted.org/packages/b3/17/e65139ea76dac7bcd8a3f17cbd37e3d1a070c44db3098d0be5e14c5bd6a1/grpcio-1.70.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf", size = 11432751 },
+ { url = "https://files.pythonhosted.org/packages/a0/12/42de6082b4ab14a59d30b2fc7786882fdaa75813a4a4f3d4a8c4acd6ed59/grpcio-1.70.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5", size = 5711439 },
+ { url = "https://files.pythonhosted.org/packages/34/f8/b5a19524d273cbd119274a387bb72d6fbb74578e13927a473bc34369f079/grpcio-1.70.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f", size = 6330777 },
+ { url = "https://files.pythonhosted.org/packages/1a/67/3d6c0ad786238aac7fa93b79246fc452978fbfe9e5f86f70da8e8a2797d0/grpcio-1.70.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295", size = 5944639 },
+ { url = "https://files.pythonhosted.org/packages/76/0d/d9f7cbc41c2743cf18236a29b6a582f41bd65572a7144d92b80bc1e68479/grpcio-1.70.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f", size = 6643543 },
+ { url = "https://files.pythonhosted.org/packages/fc/24/bdd7e606b3400c14330e33a4698fa3a49e38a28c9e0a831441adbd3380d2/grpcio-1.70.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3", size = 6199897 },
+ { url = "https://files.pythonhosted.org/packages/d1/33/8132eb370087960c82d01b89faeb28f3e58f5619ffe19889f57c58a19c18/grpcio-1.70.0-cp310-cp310-win32.whl", hash = "sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199", size = 3617513 },
+ { url = "https://files.pythonhosted.org/packages/99/bc/0fce5cfc0ca969df66f5dca6cf8d2258abb88146bf9ab89d8cf48e970137/grpcio-1.70.0-cp310-cp310-win_amd64.whl", hash = "sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1", size = 4303342 },
+ { url = "https://files.pythonhosted.org/packages/65/c4/1f67d23d6bcadd2fd61fb460e5969c52b3390b4a4e254b5e04a6d1009e5e/grpcio-1.70.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a", size = 5229017 },
+ { url = "https://files.pythonhosted.org/packages/e4/bd/cc36811c582d663a740fb45edf9f99ddbd99a10b6ba38267dc925e1e193a/grpcio-1.70.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386", size = 11472027 },
+ { url = "https://files.pythonhosted.org/packages/7e/32/8538bb2ace5cd72da7126d1c9804bf80b4fe3be70e53e2d55675c24961a8/grpcio-1.70.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b", size = 5707785 },
+ { url = "https://files.pythonhosted.org/packages/ce/5c/a45f85f2a0dfe4a6429dee98717e0e8bd7bd3f604315493c39d9679ca065/grpcio-1.70.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77", size = 6331599 },
+ { url = "https://files.pythonhosted.org/packages/9f/e5/5316b239380b8b2ad30373eb5bb25d9fd36c0375e94a98a0a60ea357d254/grpcio-1.70.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea", size = 5940834 },
+ { url = "https://files.pythonhosted.org/packages/05/33/dbf035bc6d167068b4a9f2929dfe0b03fb763f0f861ecb3bb1709a14cb65/grpcio-1.70.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839", size = 6641191 },
+ { url = "https://files.pythonhosted.org/packages/4c/c4/684d877517e5bfd6232d79107e5a1151b835e9f99051faef51fed3359ec4/grpcio-1.70.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd", size = 6198744 },
+ { url = "https://files.pythonhosted.org/packages/e9/43/92fe5eeaf340650a7020cfb037402c7b9209e7a0f3011ea1626402219034/grpcio-1.70.0-cp311-cp311-win32.whl", hash = "sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113", size = 3617111 },
+ { url = "https://files.pythonhosted.org/packages/55/15/b6cf2c9515c028aff9da6984761a3ab484a472b0dc6435fcd07ced42127d/grpcio-1.70.0-cp311-cp311-win_amd64.whl", hash = "sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca", size = 4304604 },
+ { url = "https://files.pythonhosted.org/packages/4c/a4/ddbda79dd176211b518f0f3795af78b38727a31ad32bc149d6a7b910a731/grpcio-1.70.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff", size = 5198135 },
+ { url = "https://files.pythonhosted.org/packages/30/5c/60eb8a063ea4cb8d7670af8fac3f2033230fc4b75f62669d67c66ac4e4b0/grpcio-1.70.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40", size = 11447529 },
+ { url = "https://files.pythonhosted.org/packages/fb/b9/1bf8ab66729f13b44e8f42c9de56417d3ee6ab2929591cfee78dce749b57/grpcio-1.70.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e", size = 5664484 },
+ { url = "https://files.pythonhosted.org/packages/d1/06/2f377d6906289bee066d96e9bdb91e5e96d605d173df9bb9856095cccb57/grpcio-1.70.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898", size = 6303739 },
+ { url = "https://files.pythonhosted.org/packages/ae/50/64c94cfc4db8d9ed07da71427a936b5a2bd2b27c66269b42fbda82c7c7a4/grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597", size = 5910417 },
+ { url = "https://files.pythonhosted.org/packages/53/89/8795dfc3db4389c15554eb1765e14cba8b4c88cc80ff828d02f5572965af/grpcio-1.70.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c", size = 6626797 },
+ { url = "https://files.pythonhosted.org/packages/9c/b2/6a97ac91042a2c59d18244c479ee3894e7fb6f8c3a90619bb5a7757fa30c/grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f", size = 6190055 },
+ { url = "https://files.pythonhosted.org/packages/86/2b/28db55c8c4d156053a8c6f4683e559cd0a6636f55a860f87afba1ac49a51/grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528", size = 3600214 },
+ { url = "https://files.pythonhosted.org/packages/17/c3/a7a225645a965029ed432e5b5e9ed959a574e62100afab553eef58be0e37/grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655", size = 4292538 },
+]
+
+[[package]]
+name = "h11"
+version = "0.14.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 },
+]
+
+[[package]]
+name = "h5py"
+version = "3.12.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/cc/0c/5c2b0a88158682aeafb10c1c2b735df5bc31f165bfe192f2ee9f2a23b5f1/h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf", size = 411457 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/df/7d/b21045fbb004ad8bb6fb3be4e6ca903841722706f7130b9bba31ef2f88e3/h5py-3.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f0f1a382cbf494679c07b4371f90c70391dedb027d517ac94fa2c05299dacda", size = 3402133 },
+ { url = "https://files.pythonhosted.org/packages/29/a7/3c2a33fba1da64a0846744726fd067a92fb8abb887875a0dd8e3bac8b45d/h5py-3.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cb65f619dfbdd15e662423e8d257780f9a66677eae5b4b3fc9dca70b5fd2d2a3", size = 2866436 },
+ { url = "https://files.pythonhosted.org/packages/1e/d0/4bf67c3937a2437c20844165766ddd1a1817ae6b9544c3743050d8e0f403/h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2", size = 5168596 },
+ { url = "https://files.pythonhosted.org/packages/85/bc/e76f4b2096e0859225f5441d1b7f5e2041fffa19fc2c16756c67078417aa/h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307", size = 5341537 },
+ { url = "https://files.pythonhosted.org/packages/99/bd/fb8ed45308bb97e04c02bd7aed324ba11e6a4bf9ed73967ca2a168e9cf92/h5py-3.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:577d618d6b6dea3da07d13cc903ef9634cde5596b13e832476dd861aaf651f3e", size = 2990575 },
+ { url = "https://files.pythonhosted.org/packages/33/61/c463dc5fc02fbe019566d067a9d18746cd3c664f29c9b8b3c3f9ed025365/h5py-3.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ccd9006d92232727d23f784795191bfd02294a4f2ba68708825cb1da39511a93", size = 3410828 },
+ { url = "https://files.pythonhosted.org/packages/95/9d/eb91a9076aa998bb2179d6b1788055ea09cdf9d6619cd967f1d3321ed056/h5py-3.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad8a76557880aed5234cfe7279805f4ab5ce16b17954606cca90d578d3e713ef", size = 2872586 },
+ { url = "https://files.pythonhosted.org/packages/b0/62/e2b1f9723ff713e3bd3c16dfeceec7017eadc21ef063d8b7080c0fcdc58a/h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e", size = 5273038 },
+ { url = "https://files.pythonhosted.org/packages/e1/89/118c3255d6ff2db33b062ec996a762d99ae50c21f54a8a6047ae8eda1b9f/h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166", size = 5452688 },
+ { url = "https://files.pythonhosted.org/packages/1d/4d/cbd3014eb78d1e449b29beba1f3293a841aa8086c6f7968c383c2c7ff076/h5py-3.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fdf95092d60e8130ba6ae0ef7a9bd4ade8edbe3569c13ebbaf39baefffc5ba4", size = 3006095 },
+ { url = "https://files.pythonhosted.org/packages/d4/e1/ea9bfe18a3075cdc873f0588ff26ce394726047653557876d7101bf0c74e/h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed", size = 3372538 },
+ { url = "https://files.pythonhosted.org/packages/0d/74/1009b663387c025e8fa5f3ee3cf3cd0d99b1ad5c72eeb70e75366b1ce878/h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351", size = 2868104 },
+ { url = "https://files.pythonhosted.org/packages/af/52/c604adc06280c15a29037d4aa79a24fe54d8d0b51085e81ed24b2fa995f7/h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834", size = 5194606 },
+ { url = "https://files.pythonhosted.org/packages/fa/63/eeaacff417b393491beebabb8a3dc5342950409eb6d7b39d437289abdbae/h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9", size = 5413256 },
+ { url = "https://files.pythonhosted.org/packages/86/f7/bb465dcb92ca3521a15cbe1031f6d18234dbf1fb52a6796a00bfaa846ebf/h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc", size = 2993055 },
+]
+
+[[package]]
+name = "hf-transfer"
+version = "0.1.9"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/1a/eb/8fc64f40388c29ce8ce3b2b180a089d4d6b25b1d0d232d016704cb852104/hf_transfer-0.1.9.tar.gz", hash = "sha256:035572865dab29d17e783fbf1e84cf1cb24f3fcf8f1b17db1cfc7fdf139f02bf", size = 25201 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/81/f5/461d2e5f307e5048289b1168d5c642ae3bb2504e88dff1a38b92ed990a21/hf_transfer-0.1.9-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e66acf91df4a8b72f60223059df3003062a5ae111757187ed1a06750a30e911b", size = 1393046 },
+ { url = "https://files.pythonhosted.org/packages/41/ba/8d9fd9f1083525edfcb389c93738c802f3559cb749324090d7109c8bf4c2/hf_transfer-0.1.9-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8669dbcc7a3e2e8d61d42cd24da9c50d57770bd74b445c65123291ca842a7e7a", size = 1348126 },
+ { url = "https://files.pythonhosted.org/packages/8e/a2/cd7885bc9959421065a6fae0fe67b6c55becdeda4e69b873e52976f9a9f0/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd0167c4407a3bc4cdd0307e65ada2294ec04f1813d8a69a5243e379b22e9d8", size = 3728604 },
+ { url = "https://files.pythonhosted.org/packages/f6/2e/a072cf196edfeda3310c9a5ade0a0fdd785e6154b3ce24fc738c818da2a7/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee8b10afedcb75f71091bcc197c526a6ebf5c58bbbadb34fdeee6160f55f619f", size = 3064995 },
+ { url = "https://files.pythonhosted.org/packages/c2/84/aec9ef4c0fab93c1ea2b1badff38c78b4b2f86f0555b26d2051dbc920cde/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5828057e313de59300dd1abb489444bc452efe3f479d3c55b31a8f680936ba42", size = 3580908 },
+ { url = "https://files.pythonhosted.org/packages/29/63/b560d39651a56603d64f1a0212d0472a44cbd965db2fa62b99d99cb981bf/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc6bd19e1cc177c66bdef15ef8636ad3bde79d5a4f608c158021153b4573509d", size = 3400839 },
+ { url = "https://files.pythonhosted.org/packages/d6/d8/f87ea6f42456254b48915970ed98e993110521e9263472840174d32c880d/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdca9bfb89e6f8f281890cc61a8aff2d3cecaff7e1a4d275574d96ca70098557", size = 3552664 },
+ { url = "https://files.pythonhosted.org/packages/d6/56/1267c39b65fc8f4e2113b36297320f102718bf5799b544a6cbe22013aa1d/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:89a23f58b7b7effbc047b8ca286f131b17728c99a9f972723323003ffd1bb916", size = 4073732 },
+ { url = "https://files.pythonhosted.org/packages/82/1a/9c748befbe3decf7cb415e34f8a0c3789a0a9c55910dea73d581e48c0ce5/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:dc7fff1345980d6c0ebb92c811d24afa4b98b3e07ed070c8e38cc91fd80478c5", size = 3390096 },
+ { url = "https://files.pythonhosted.org/packages/72/85/4c03da147b6b4b7cb12e074d3d44eee28604a387ed0eaf7eaaead5069c57/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1a6bd16c667ebe89a069ca163060127a794fa3a3525292c900b8c8cc47985b0d", size = 3664743 },
+ { url = "https://files.pythonhosted.org/packages/e7/6e/e597b04f753f1b09e6893075d53a82a30c13855cbaa791402695b01e369f/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d2fde99d502093ade3ab1b53f80da18480e9902aa960dab7f74fb1b9e5bc5746", size = 3695243 },
+ { url = "https://files.pythonhosted.org/packages/09/89/d4e234727a26b2546c8fb70a276cd924260d60135f2165bf8b9ed67bb9a4/hf_transfer-0.1.9-cp38-abi3-win32.whl", hash = "sha256:435cc3cdc8524ce57b074032b8fd76eed70a4224d2091232fa6a8cef8fd6803e", size = 1086605 },
+ { url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240 },
+]
+
+[[package]]
+name = "hjson"
+version = "3.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/e5/0b56d723a76ca67abadbf7fb71609fb0ea7e6926e94fcca6c65a85b36a0e/hjson-3.1.0.tar.gz", hash = "sha256:55af475a27cf83a7969c808399d7bccdec8fb836a07ddbd574587593b9cdcf75", size = 40541 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1f/7f/13cd798d180af4bf4c0ceddeefba2b864a63c71645abc0308b768d67bb81/hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89", size = 54018 },
+]
+
+[[package]]
+name = "hpsv2x"
+version = "1.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "braceexpand" },
+ { name = "clint" },
+ { name = "einops" },
+ { name = "fsspec" },
+ { name = "ftfy" },
+ { name = "huggingface-hub" },
+ { name = "pandas" },
+ { name = "protobuf" },
+ { name = "pyarrow" },
+ { name = "pytest" },
+ { name = "pytest-split" },
+ { name = "regex" },
+ { name = "requests" },
+ { name = "sentencepiece" },
+ { name = "timm" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+ { name = "transformers" },
+ { name = "webdataset" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/4e/87/7ff2e74350a04827802239e2146c38a27067ffcd8f73399131fa2fb5466e/hpsv2x-1.2.0.tar.gz", hash = "sha256:66220d040a34183e6cb53b22023a034278126f052506e54a9d89608d9e766f39", size = 9005708 }
+
+[[package]]
+name = "httpcore"
+version = "1.0.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "certifi" },
+ { name = "h11" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 },
+]
+
+[[package]]
+name = "httptools"
+version = "0.6.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a7/9a/ce5e1f7e131522e6d3426e8e7a490b3a01f39a6696602e1c4f33f9e94277/httptools-0.6.4.tar.gz", hash = "sha256:4e93eee4add6493b59a5c514da98c939b244fce4a0d8879cd3f466562f4b7d5c", size = 240639 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3b/6f/972f8eb0ea7d98a1c6be436e2142d51ad2a64ee18e02b0e7ff1f62171ab1/httptools-0.6.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3c73ce323711a6ffb0d247dcd5a550b8babf0f757e86a52558fe5b86d6fefcc0", size = 198780 },
+ { url = "https://files.pythonhosted.org/packages/6a/b0/17c672b4bc5c7ba7f201eada4e96c71d0a59fbc185e60e42580093a86f21/httptools-0.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:345c288418f0944a6fe67be8e6afa9262b18c7626c3ef3c28adc5eabc06a68da", size = 103297 },
+ { url = "https://files.pythonhosted.org/packages/92/5e/b4a826fe91971a0b68e8c2bd4e7db3e7519882f5a8ccdb1194be2b3ab98f/httptools-0.6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deee0e3343f98ee8047e9f4c5bc7cedbf69f5734454a94c38ee829fb2d5fa3c1", size = 443130 },
+ { url = "https://files.pythonhosted.org/packages/b0/51/ce61e531e40289a681a463e1258fa1e05e0be54540e40d91d065a264cd8f/httptools-0.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca80b7485c76f768a3bc83ea58373f8db7b015551117375e4918e2aa77ea9b50", size = 442148 },
+ { url = "https://files.pythonhosted.org/packages/ea/9e/270b7d767849b0c96f275c695d27ca76c30671f8eb8cc1bab6ced5c5e1d0/httptools-0.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:90d96a385fa941283ebd231464045187a31ad932ebfa541be8edf5b3c2328959", size = 415949 },
+ { url = "https://files.pythonhosted.org/packages/81/86/ced96e3179c48c6f656354e106934e65c8963d48b69be78f355797f0e1b3/httptools-0.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:59e724f8b332319e2875efd360e61ac07f33b492889284a3e05e6d13746876f4", size = 417591 },
+ { url = "https://files.pythonhosted.org/packages/75/73/187a3f620ed3175364ddb56847d7a608a6fc42d551e133197098c0143eca/httptools-0.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:c26f313951f6e26147833fc923f78f95604bbec812a43e5ee37f26dc9e5a686c", size = 88344 },
+ { url = "https://files.pythonhosted.org/packages/7b/26/bb526d4d14c2774fe07113ca1db7255737ffbb119315839af2065abfdac3/httptools-0.6.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f47f8ed67cc0ff862b84a1189831d1d33c963fb3ce1ee0c65d3b0cbe7b711069", size = 199029 },
+ { url = "https://files.pythonhosted.org/packages/a6/17/3e0d3e9b901c732987a45f4f94d4e2c62b89a041d93db89eafb262afd8d5/httptools-0.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0614154d5454c21b6410fdf5262b4a3ddb0f53f1e1721cfd59d55f32138c578a", size = 103492 },
+ { url = "https://files.pythonhosted.org/packages/b7/24/0fe235d7b69c42423c7698d086d4db96475f9b50b6ad26a718ef27a0bce6/httptools-0.6.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8787367fbdfccae38e35abf7641dafc5310310a5987b689f4c32cc8cc3ee975", size = 462891 },
+ { url = "https://files.pythonhosted.org/packages/b1/2f/205d1f2a190b72da6ffb5f41a3736c26d6fa7871101212b15e9b5cd8f61d/httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b0f7fe4fd38e6a507bdb751db0379df1e99120c65fbdc8ee6c1d044897a636", size = 459788 },
+ { url = "https://files.pythonhosted.org/packages/6e/4c/d09ce0eff09057a206a74575ae8f1e1e2f0364d20e2442224f9e6612c8b9/httptools-0.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40a5ec98d3f49904b9fe36827dcf1aadfef3b89e2bd05b0e35e94f97c2b14721", size = 433214 },
+ { url = "https://files.pythonhosted.org/packages/3e/d2/84c9e23edbccc4a4c6f96a1b8d99dfd2350289e94f00e9ccc7aadde26fb5/httptools-0.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dacdd3d10ea1b4ca9df97a0a303cbacafc04b5cd375fa98732678151643d4988", size = 434120 },
+ { url = "https://files.pythonhosted.org/packages/d0/46/4d8e7ba9581416de1c425b8264e2cadd201eb709ec1584c381f3e98f51c1/httptools-0.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:288cd628406cc53f9a541cfaf06041b4c71d751856bab45e3702191f931ccd17", size = 88565 },
+ { url = "https://files.pythonhosted.org/packages/bb/0e/d0b71465c66b9185f90a091ab36389a7352985fe857e352801c39d6127c8/httptools-0.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:df017d6c780287d5c80601dafa31f17bddb170232d85c066604d8558683711a2", size = 200683 },
+ { url = "https://files.pythonhosted.org/packages/e2/b8/412a9bb28d0a8988de3296e01efa0bd62068b33856cdda47fe1b5e890954/httptools-0.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:85071a1e8c2d051b507161f6c3e26155b5c790e4e28d7f236422dbacc2a9cc44", size = 104337 },
+ { url = "https://files.pythonhosted.org/packages/9b/01/6fb20be3196ffdc8eeec4e653bc2a275eca7f36634c86302242c4fbb2760/httptools-0.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69422b7f458c5af875922cdb5bd586cc1f1033295aa9ff63ee196a87519ac8e1", size = 508796 },
+ { url = "https://files.pythonhosted.org/packages/f7/d8/b644c44acc1368938317d76ac991c9bba1166311880bcc0ac297cb9d6bd7/httptools-0.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e603a3bff50db08cd578d54f07032ca1631450ceb972c2f834c2b860c28ea2", size = 510837 },
+ { url = "https://files.pythonhosted.org/packages/52/d8/254d16a31d543073a0e57f1c329ca7378d8924e7e292eda72d0064987486/httptools-0.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec4f178901fa1834d4a060320d2f3abc5c9e39766953d038f1458cb885f47e81", size = 485289 },
+ { url = "https://files.pythonhosted.org/packages/5f/3c/4aee161b4b7a971660b8be71a92c24d6c64372c1ab3ae7f366b3680df20f/httptools-0.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb89ecf8b290f2e293325c646a211ff1c2493222798bb80a530c5e7502494f", size = 489779 },
+ { url = "https://files.pythonhosted.org/packages/12/b7/5cae71a8868e555f3f67a50ee7f673ce36eac970f029c0c5e9d584352961/httptools-0.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:db78cb9ca56b59b016e64b6031eda5653be0589dba2b1b43453f6e8b405a0970", size = 88634 },
+]
+
+[[package]]
+name = "httpx"
+version = "0.28.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "anyio" },
+ { name = "certifi" },
+ { name = "httpcore" },
+ { name = "idna" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
+]
+
+[[package]]
+name = "huggingface-hub"
+version = "0.28.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "filelock" },
+ { name = "fsspec" },
+ { name = "packaging" },
+ { name = "pyyaml" },
+ { name = "requests" },
+ { name = "tqdm" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e7/ce/a734204aaae6c35a22f9956ebcd8d8708ae5b842e15d6f42bd6f49e634a4/huggingface_hub-0.28.1.tar.gz", hash = "sha256:893471090c98e3b6efbdfdacafe4052b20b84d59866fb6f54c33d9af18c303ae", size = 387074 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ea/da/6c2bea5327b640920267d3bf2c9fc114cfbd0a5de234d81cda80cc9e33c8/huggingface_hub-0.28.1-py3-none-any.whl", hash = "sha256:aa6b9a3ffdae939b72c464dbb0d7f99f56e649b55c3d52406f49e0a5a620c0a7", size = 464068 },
+]
+
+[[package]]
+name = "hydra-core"
+version = "1.3.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "antlr4-python3-runtime" },
+ { name = "omegaconf" },
+ { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6d/8e/07e42bc434a847154083b315779b0a81d567154504624e181caf2c71cd98/hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", size = 3263494 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547 },
+]
+
+[[package]]
+name = "idna"
+version = "3.10"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
+]
+
+[[package]]
+name = "image-utilities"
+version = "0.0.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "einops" },
+ { name = "jaxtyping" },
+ { name = "numpy" },
+ { name = "pillow" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e5/42/bf3373a34d7b799daae784434008f805a56c8e3f834b5d6f2a5cc4fd5a13/image_utilities-0.0.3.tar.gz", hash = "sha256:b6b72cdd0c5bfea3cfb68ca3846b844bde344e0abbdbe9635963286d5851a283", size = 151133 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/81/7c/8916c5cd10171ca0af46c2ce4ed2f017b2386142c49e059711917602f87d/image_utilities-0.0.3-py3-none-any.whl", hash = "sha256:cd3aa77cbb1d3f071395cf76d256223e4c8c4fa71b7da0d418fa51198bac08da", size = 25709 },
+]
+
+[[package]]
+name = "importlib-metadata"
+version = "8.6.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "zipp" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/33/08/c1395a292bb23fd03bdf572a1357c5a733d3eecbab877641ceacab23db6e/importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580", size = 55767 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 },
+]
+
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
+]
+
+[[package]]
+name = "ipdb"
+version = "0.13.13"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "decorator" },
+ { name = "ipython" },
+ { name = "tomli", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3d/1b/7e07e7b752017f7693a0f4d41c13e5ca29ce8cbcfdcc1fd6c4ad8c0a27a0/ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726", size = 17042 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0c/4c/b075da0092003d9a55cf2ecc1cae9384a1ca4f650d51b00fc59875fe76f6/ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4", size = 12130 },
+]
+
+[[package]]
+name = "ipython"
+version = "8.32.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "decorator" },
+ { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+ { name = "jedi" },
+ { name = "matplotlib-inline" },
+ { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
+ { name = "prompt-toolkit" },
+ { name = "pygments" },
+ { name = "stack-data" },
+ { name = "traitlets" },
+ { name = "typing-extensions", marker = "python_full_version < '3.12'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/36/80/4d2a072e0db7d250f134bc11676517299264ebe16d62a8619d49a78ced73/ipython-8.32.0.tar.gz", hash = "sha256:be2c91895b0b9ea7ba49d33b23e2040c352b33eb6a519cca7ce6e0c743444251", size = 5507441 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e7/e1/f4474a7ecdb7745a820f6f6039dc43c66add40f1bcc66485607d93571af6/ipython-8.32.0-py3-none-any.whl", hash = "sha256:cae85b0c61eff1fc48b0a8002de5958b6528fa9c8defb1894da63f42613708aa", size = 825524 },
+]
+
+[[package]]
+name = "itsdangerous"
+version = "2.2.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234 },
+]
+
+[[package]]
+name = "jaxtyping"
+version = "0.2.37"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "wadler-lindig" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/56/d1/acd3abf3587487ae131a022328c44068c5ca96e766ed2841c292c2a70ffc/jaxtyping-0.2.37.tar.gz", hash = "sha256:ae8c124abbea61b8c56455fb8b42f1183ab63353572aedfb1de355bb5b40e951", size = 45615 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/89/7b/623d15216d70e6c1cf149163ee4559e6da23ea0118f8cd7c72fdfb951298/jaxtyping-0.2.37-py3-none-any.whl", hash = "sha256:519f694ce569ba5000a1f8823e742df2d86e6da3956d2218769b6bdc5ad2d23a", size = 56263 },
+]
+
+[[package]]
+name = "jedi"
+version = "0.19.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "parso" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 },
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.5"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markupsafe" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
+]
+
+[[package]]
+name = "joblib"
+version = "1.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
+]
+
+[[package]]
+name = "kiwisolver"
+version = "1.4.8"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/59/7c91426a8ac292e1cdd53a63b6d9439abd573c875c3f92c146767dd33faf/kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e", size = 97538 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/47/5f/4d8e9e852d98ecd26cdf8eaf7ed8bc33174033bba5e07001b289f07308fd/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88c6f252f6816a73b1f8c904f7bbe02fd67c09a69f7cb8a0eecdbf5ce78e63db", size = 124623 },
+ { url = "https://files.pythonhosted.org/packages/1d/70/7f5af2a18a76fe92ea14675f8bd88ce53ee79e37900fa5f1a1d8e0b42998/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72941acb7b67138f35b879bbe85be0f6c6a70cab78fe3ef6db9c024d9223e5b", size = 66720 },
+ { url = "https://files.pythonhosted.org/packages/c6/13/e15f804a142353aefd089fadc8f1d985561a15358c97aca27b0979cb0785/kiwisolver-1.4.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce2cf1e5688edcb727fdf7cd1bbd0b6416758996826a8be1d958f91880d0809d", size = 65413 },
+ { url = "https://files.pythonhosted.org/packages/ce/6d/67d36c4d2054e83fb875c6b59d0809d5c530de8148846b1370475eeeece9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c8bf637892dc6e6aad2bc6d4d69d08764166e5e3f69d469e55427b6ac001b19d", size = 1650826 },
+ { url = "https://files.pythonhosted.org/packages/de/c6/7b9bb8044e150d4d1558423a1568e4f227193662a02231064e3824f37e0a/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c", size = 1628231 },
+ { url = "https://files.pythonhosted.org/packages/b6/38/ad10d437563063eaaedbe2c3540a71101fc7fb07a7e71f855e93ea4de605/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d47b28d1dfe0793d5e96bce90835e17edf9a499b53969b03c6c47ea5985844c3", size = 1408938 },
+ { url = "https://files.pythonhosted.org/packages/52/ce/c0106b3bd7f9e665c5f5bc1e07cc95b5dabd4e08e3dad42dbe2faad467e7/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb158fe28ca0c29f2260cca8c43005329ad58452c36f0edf298204de32a9a3ed", size = 1422799 },
+ { url = "https://files.pythonhosted.org/packages/d0/87/efb704b1d75dc9758087ba374c0f23d3254505edaedd09cf9d247f7878b9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5536185fce131780ebd809f8e623bf4030ce1b161353166c49a3c74c287897f", size = 1354362 },
+ { url = "https://files.pythonhosted.org/packages/eb/b3/fd760dc214ec9a8f208b99e42e8f0130ff4b384eca8b29dd0efc62052176/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:369b75d40abedc1da2c1f4de13f3482cb99e3237b38726710f4a793432b1c5ff", size = 2222695 },
+ { url = "https://files.pythonhosted.org/packages/a2/09/a27fb36cca3fc01700687cc45dae7a6a5f8eeb5f657b9f710f788748e10d/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:641f2ddf9358c80faa22e22eb4c9f54bd3f0e442e038728f500e3b978d00aa7d", size = 2370802 },
+ { url = "https://files.pythonhosted.org/packages/3d/c3/ba0a0346db35fe4dc1f2f2cf8b99362fbb922d7562e5f911f7ce7a7b60fa/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d561d2d8883e0819445cfe58d7ddd673e4015c3c57261d7bdcd3710d0d14005c", size = 2334646 },
+ { url = "https://files.pythonhosted.org/packages/41/52/942cf69e562f5ed253ac67d5c92a693745f0bed3c81f49fc0cbebe4d6b00/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1732e065704b47c9afca7ffa272f845300a4eb959276bf6970dc07265e73b605", size = 2467260 },
+ { url = "https://files.pythonhosted.org/packages/32/26/2d9668f30d8a494b0411d4d7d4ea1345ba12deb6a75274d58dd6ea01e951/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcb1ebc3547619c3b58a39e2448af089ea2ef44b37988caf432447374941574e", size = 2288633 },
+ { url = "https://files.pythonhosted.org/packages/98/99/0dd05071654aa44fe5d5e350729961e7bb535372935a45ac89a8924316e6/kiwisolver-1.4.8-cp310-cp310-win_amd64.whl", hash = "sha256:89c107041f7b27844179ea9c85d6da275aa55ecf28413e87624d033cf1f6b751", size = 71885 },
+ { url = "https://files.pythonhosted.org/packages/6c/fc/822e532262a97442989335394d441cd1d0448c2e46d26d3e04efca84df22/kiwisolver-1.4.8-cp310-cp310-win_arm64.whl", hash = "sha256:b5773efa2be9eb9fcf5415ea3ab70fc785d598729fd6057bea38d539ead28271", size = 65175 },
+ { url = "https://files.pythonhosted.org/packages/da/ed/c913ee28936c371418cb167b128066ffb20bbf37771eecc2c97edf8a6e4c/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a4d3601908c560bdf880f07d94f31d734afd1bb71e96585cace0e38ef44c6d84", size = 124635 },
+ { url = "https://files.pythonhosted.org/packages/4c/45/4a7f896f7467aaf5f56ef093d1f329346f3b594e77c6a3c327b2d415f521/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:856b269c4d28a5c0d5e6c1955ec36ebfd1651ac00e1ce0afa3e28da95293b561", size = 66717 },
+ { url = "https://files.pythonhosted.org/packages/5f/b4/c12b3ac0852a3a68f94598d4c8d569f55361beef6159dce4e7b624160da2/kiwisolver-1.4.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c2b9a96e0f326205af81a15718a9073328df1173a2619a68553decb7097fd5d7", size = 65413 },
+ { url = "https://files.pythonhosted.org/packages/a9/98/1df4089b1ed23d83d410adfdc5947245c753bddfbe06541c4aae330e9e70/kiwisolver-1.4.8-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5020c83e8553f770cb3b5fc13faac40f17e0b205bd237aebd21d53d733adb03", size = 1343994 },
+ { url = "https://files.pythonhosted.org/packages/8d/bf/b4b169b050c8421a7c53ea1ea74e4ef9c335ee9013216c558a047f162d20/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dace81d28c787956bfbfbbfd72fdcef014f37d9b48830829e488fdb32b49d954", size = 1434804 },
+ { url = "https://files.pythonhosted.org/packages/66/5a/e13bd341fbcf73325ea60fdc8af752addf75c5079867af2e04cc41f34434/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11e1022b524bd48ae56c9b4f9296bce77e15a2e42a502cceba602f804b32bb79", size = 1450690 },
+ { url = "https://files.pythonhosted.org/packages/9b/4f/5955dcb376ba4a830384cc6fab7d7547bd6759fe75a09564910e9e3bb8ea/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b9b4d2892fefc886f30301cdd80debd8bb01ecdf165a449eb6e78f79f0fabd6", size = 1376839 },
+ { url = "https://files.pythonhosted.org/packages/3a/97/5edbed69a9d0caa2e4aa616ae7df8127e10f6586940aa683a496c2c280b9/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a96c0e790ee875d65e340ab383700e2b4891677b7fcd30a699146f9384a2bb0", size = 1435109 },
+ { url = "https://files.pythonhosted.org/packages/13/fc/e756382cb64e556af6c1809a1bbb22c141bbc2445049f2da06b420fe52bf/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:23454ff084b07ac54ca8be535f4174170c1094a4cff78fbae4f73a4bcc0d4dab", size = 2245269 },
+ { url = "https://files.pythonhosted.org/packages/76/15/e59e45829d7f41c776d138245cabae6515cb4eb44b418f6d4109c478b481/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:87b287251ad6488e95b4f0b4a79a6d04d3ea35fde6340eb38fbd1ca9cd35bbbc", size = 2393468 },
+ { url = "https://files.pythonhosted.org/packages/e9/39/483558c2a913ab8384d6e4b66a932406f87c95a6080112433da5ed668559/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b21dbe165081142b1232a240fc6383fd32cdd877ca6cc89eab93e5f5883e1c25", size = 2355394 },
+ { url = "https://files.pythonhosted.org/packages/01/aa/efad1fbca6570a161d29224f14b082960c7e08268a133fe5dc0f6906820e/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:768cade2c2df13db52475bd28d3a3fac8c9eff04b0e9e2fda0f3760f20b3f7fc", size = 2490901 },
+ { url = "https://files.pythonhosted.org/packages/c9/4f/15988966ba46bcd5ab9d0c8296914436720dd67fca689ae1a75b4ec1c72f/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d47cfb2650f0e103d4bf68b0b5804c68da97272c84bb12850d877a95c056bd67", size = 2312306 },
+ { url = "https://files.pythonhosted.org/packages/2d/27/bdf1c769c83f74d98cbc34483a972f221440703054894a37d174fba8aa68/kiwisolver-1.4.8-cp311-cp311-win_amd64.whl", hash = "sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34", size = 71966 },
+ { url = "https://files.pythonhosted.org/packages/4a/c9/9642ea855604aeb2968a8e145fc662edf61db7632ad2e4fb92424be6b6c0/kiwisolver-1.4.8-cp311-cp311-win_arm64.whl", hash = "sha256:16523b40aab60426ffdebe33ac374457cf62863e330a90a0383639ce14bf44b2", size = 65311 },
+ { url = "https://files.pythonhosted.org/packages/fc/aa/cea685c4ab647f349c3bc92d2daf7ae34c8e8cf405a6dcd3a497f58a2ac3/kiwisolver-1.4.8-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d6af5e8815fd02997cb6ad9bbed0ee1e60014438ee1a5c2444c96f87b8843502", size = 124152 },
+ { url = "https://files.pythonhosted.org/packages/c5/0b/8db6d2e2452d60d5ebc4ce4b204feeb16176a851fd42462f66ade6808084/kiwisolver-1.4.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bade438f86e21d91e0cf5dd7c0ed00cda0f77c8c1616bd83f9fc157fa6760d31", size = 66555 },
+ { url = "https://files.pythonhosted.org/packages/60/26/d6a0db6785dd35d3ba5bf2b2df0aedc5af089962c6eb2cbf67a15b81369e/kiwisolver-1.4.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b83dc6769ddbc57613280118fb4ce3cd08899cc3369f7d0e0fab518a7cf37fdb", size = 65067 },
+ { url = "https://files.pythonhosted.org/packages/c9/ed/1d97f7e3561e09757a196231edccc1bcf59d55ddccefa2afc9c615abd8e0/kiwisolver-1.4.8-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:111793b232842991be367ed828076b03d96202c19221b5ebab421ce8bcad016f", size = 1378443 },
+ { url = "https://files.pythonhosted.org/packages/29/61/39d30b99954e6b46f760e6289c12fede2ab96a254c443639052d1b573fbc/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:257af1622860e51b1a9d0ce387bf5c2c4f36a90594cb9514f55b074bcc787cfc", size = 1472728 },
+ { url = "https://files.pythonhosted.org/packages/0c/3e/804163b932f7603ef256e4a715e5843a9600802bb23a68b4e08c8c0ff61d/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69b5637c3f316cab1ec1c9a12b8c5f4750a4c4b71af9157645bf32830e39c03a", size = 1478388 },
+ { url = "https://files.pythonhosted.org/packages/8a/9e/60eaa75169a154700be74f875a4d9961b11ba048bef315fbe89cb6999056/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:782bb86f245ec18009890e7cb8d13a5ef54dcf2ebe18ed65f795e635a96a1c6a", size = 1413849 },
+ { url = "https://files.pythonhosted.org/packages/bc/b3/9458adb9472e61a998c8c4d95cfdfec91c73c53a375b30b1428310f923e4/kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc978a80a0db3a66d25767b03688f1147a69e6237175c0f4ffffaaedf744055a", size = 1475533 },
+ { url = "https://files.pythonhosted.org/packages/e4/7a/0a42d9571e35798de80aef4bb43a9b672aa7f8e58643d7bd1950398ffb0a/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:36dbbfd34838500a31f52c9786990d00150860e46cd5041386f217101350f0d3", size = 2268898 },
+ { url = "https://files.pythonhosted.org/packages/d9/07/1255dc8d80271400126ed8db35a1795b1a2c098ac3a72645075d06fe5c5d/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:eaa973f1e05131de5ff3569bbba7f5fd07ea0595d3870ed4a526d486fe57fa1b", size = 2425605 },
+ { url = "https://files.pythonhosted.org/packages/84/df/5a3b4cf13780ef6f6942df67b138b03b7e79e9f1f08f57c49957d5867f6e/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a66f60f8d0c87ab7f59b6fb80e642ebb29fec354a4dfad687ca4092ae69d04f4", size = 2375801 },
+ { url = "https://files.pythonhosted.org/packages/8f/10/2348d068e8b0f635c8c86892788dac7a6b5c0cb12356620ab575775aad89/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:858416b7fb777a53f0c59ca08190ce24e9abbd3cffa18886a5781b8e3e26f65d", size = 2520077 },
+ { url = "https://files.pythonhosted.org/packages/32/d8/014b89fee5d4dce157d814303b0fce4d31385a2af4c41fed194b173b81ac/kiwisolver-1.4.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:085940635c62697391baafaaeabdf3dd7a6c3643577dde337f4d66eba021b2b8", size = 2338410 },
+ { url = "https://files.pythonhosted.org/packages/bd/72/dfff0cc97f2a0776e1c9eb5bef1ddfd45f46246c6533b0191887a427bca5/kiwisolver-1.4.8-cp312-cp312-win_amd64.whl", hash = "sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50", size = 71853 },
+ { url = "https://files.pythonhosted.org/packages/dc/85/220d13d914485c0948a00f0b9eb419efaf6da81b7d72e88ce2391f7aed8d/kiwisolver-1.4.8-cp312-cp312-win_arm64.whl", hash = "sha256:a3c44cb68861de93f0c4a8175fbaa691f0aa22550c331fefef02b618a9dcb476", size = 65424 },
+ { url = "https://files.pythonhosted.org/packages/1f/f9/ae81c47a43e33b93b0a9819cac6723257f5da2a5a60daf46aa5c7226ea85/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e7a019419b7b510f0f7c9dceff8c5eae2392037eae483a7f9162625233802b0a", size = 60403 },
+ { url = "https://files.pythonhosted.org/packages/58/ca/f92b5cb6f4ce0c1ebfcfe3e2e42b96917e16f7090e45b21102941924f18f/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:286b18e86682fd2217a48fc6be6b0f20c1d0ed10958d8dc53453ad58d7be0bf8", size = 58657 },
+ { url = "https://files.pythonhosted.org/packages/80/28/ae0240f732f0484d3a4dc885d055653c47144bdf59b670aae0ec3c65a7c8/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4191ee8dfd0be1c3666ccbac178c5a05d5f8d689bbe3fc92f3c4abec817f8fe0", size = 84948 },
+ { url = "https://files.pythonhosted.org/packages/5d/eb/78d50346c51db22c7203c1611f9b513075f35c4e0e4877c5dde378d66043/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd2785b9391f2873ad46088ed7599a6a71e762e1ea33e87514b1a441ed1da1c", size = 81186 },
+ { url = "https://files.pythonhosted.org/packages/43/f8/7259f18c77adca88d5f64f9a522792e178b2691f3748817a8750c2d216ef/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c07b29089b7ba090b6f1a669f1411f27221c3662b3a1b7010e67b59bb5a6f10b", size = 80279 },
+ { url = "https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 },
+]
+
+[[package]]
+name = "lightning-utilities"
+version = "0.12.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "packaging" },
+ { name = "setuptools" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d5/26/a449b858a6beaaf779d56775a5c675d636af11e32004e4420506a48eb7f4/lightning_utilities-0.12.0.tar.gz", hash = "sha256:95b5f22a0b69eb27ca0929c6c1d510592a70080e1733a055bf154903c0343b60", size = 29677 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/50/8d/da77ceb92ed674da93959184a2777d08ccbd872559fb52aba16b91686b7e/lightning_utilities-0.12.0-py3-none-any.whl", hash = "sha256:b827f5768607e81ccc7b2ada1f50628168d1cc9f839509c7e87c04b59079e66c", size = 28487 },
+]
+
+[[package]]
+name = "lovely-numpy"
+version = "0.2.13"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "fastcore" },
+ { name = "ipython" },
+ { name = "matplotlib" },
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f7/46/ba51024fbf97d8370afa8a764fe18bf5c663e64badf129dfb2949a7ccbdd/lovely_numpy-0.2.13.tar.gz", hash = "sha256:0ed56660986731db3d3d7ff130e85e8d162e96c96042a0bd9a3992d32a3b34e2", size = 23920 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/92/f7/a74742c9f9823a6a1fb18eb11fe77da61ea9426770ede6406f2f97a9622a/lovely_numpy-0.2.13-py3-none-any.whl", hash = "sha256:2e696d145d301264390f790cd602560ca5e525534d9059e0fcd3f313b63b5786", size = 24357 },
+]
+
+[[package]]
+name = "lovely-tensors"
+version = "0.1.18"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "lovely-numpy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/20/c4/b109b6c912929a471cf03c6a2f02d8c5ab2889ec57177692d7b971c4ec70/lovely_tensors-0.1.18.tar.gz", hash = "sha256:afb07d52de9ec6e560e77d9b3e01758bbc921a6f25c661b1c51ebbb3f75ee0c5", size = 21962 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/71/fb/688604e72694c597cf236e5ccd9a4bcc34bacc5f8e5dbb2cf7d7ddccb25b/lovely_tensors-0.1.18-py3-none-any.whl", hash = "sha256:91dc30f0d6224364851e6f14497e1677076cd3a59bae4d92c78bbe8684f6f22b", size = 19303 },
+]
+
+[[package]]
+name = "lpips"
+version = "0.1.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "scipy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e8/2d/4b8148d32f5bd461eb7d5daa54fcc998f86eaa709a57f4ef6aa4c62f024f/lpips-0.1.4.tar.gz", hash = "sha256:3846331df6c69688aec3d300a5eeef6c529435bc8460bd58201c3d62e56188fa", size = 18029 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9b/13/1df50c7925d9d2746702719f40e864f51ed66f307b20ad32392f1ad2bb87/lpips-0.1.4-py3-none-any.whl", hash = "sha256:fd537af5828b69d2e6ffc0a397bd506dbc28ca183543617690844c08e102ec5e", size = 53763 },
+]
+
+[[package]]
+name = "lxml"
+version = "5.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/6b/20c3a4b24751377aaa6307eb230b66701024012c29dd374999cc92983269/lxml-5.3.0.tar.gz", hash = "sha256:4e109ca30d1edec1ac60cdbe341905dc3b8f55b16855e03a54aaf59e51ec8c6f", size = 3679318 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a1/ce/2789e39eddf2b13fac29878bfa465f0910eb6b0096e29090e5176bc8cf43/lxml-5.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd36439be765e2dde7660212b5275641edbc813e7b24668831a5c8ac91180656", size = 8124570 },
+ { url = "https://files.pythonhosted.org/packages/24/a8/f4010166a25d41715527129af2675981a50d3bbf7df09c5d9ab8ca24fbf9/lxml-5.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ae5fe5c4b525aa82b8076c1a59d642c17b6e8739ecf852522c6321852178119d", size = 4413042 },
+ { url = "https://files.pythonhosted.org/packages/41/a4/7e45756cecdd7577ddf67a68b69c1db0f5ddbf0c9f65021ee769165ffc5a/lxml-5.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:501d0d7e26b4d261fca8132854d845e4988097611ba2531408ec91cf3fd9d20a", size = 5139213 },
+ { url = "https://files.pythonhosted.org/packages/02/e2/ecf845b12323c92748077e1818b64e8b4dba509a4cb12920b3762ebe7552/lxml-5.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66442c2546446944437df74379e9cf9e9db353e61301d1a0e26482f43f0dd8", size = 4838814 },
+ { url = "https://files.pythonhosted.org/packages/12/91/619f9fb72cf75e9ceb8700706f7276f23995f6ad757e6d400fbe35ca4990/lxml-5.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e41506fec7a7f9405b14aa2d5c8abbb4dbbd09d88f9496958b6d00cb4d45330", size = 5425084 },
+ { url = "https://files.pythonhosted.org/packages/25/3b/162a85a8f0fd2a3032ec3f936636911c6e9523a8e263fffcfd581ce98b54/lxml-5.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f7d4a670107d75dfe5ad080bed6c341d18c4442f9378c9f58e5851e86eb79965", size = 4875993 },
+ { url = "https://files.pythonhosted.org/packages/43/af/dd3f58cc7d946da6ae42909629a2b1d5dd2d1b583334d4af9396697d6863/lxml-5.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41ce1f1e2c7755abfc7e759dc34d7d05fd221723ff822947132dc934d122fe22", size = 5012462 },
+ { url = "https://files.pythonhosted.org/packages/69/c1/5ea46b2d4c98f5bf5c83fffab8a0ad293c9bc74df9ecfbafef10f77f7201/lxml-5.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:44264ecae91b30e5633013fb66f6ddd05c006d3e0e884f75ce0b4755b3e3847b", size = 4815288 },
+ { url = "https://files.pythonhosted.org/packages/1d/51/a0acca077ad35da458f4d3f729ef98effd2b90f003440d35fc36323f8ae6/lxml-5.3.0-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:3c174dc350d3ec52deb77f2faf05c439331d6ed5e702fc247ccb4e6b62d884b7", size = 5472435 },
+ { url = "https://files.pythonhosted.org/packages/4d/6b/0989c9368986961a6b0f55b46c80404c4b758417acdb6d87bfc3bd5f4967/lxml-5.3.0-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:2dfab5fa6a28a0b60a20638dc48e6343c02ea9933e3279ccb132f555a62323d8", size = 4976354 },
+ { url = "https://files.pythonhosted.org/packages/05/9e/87492d03ff604fbf656ed2bf3e2e8d28f5d58ea1f00ff27ac27b06509079/lxml-5.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b1c8c20847b9f34e98080da785bb2336ea982e7f913eed5809e5a3c872900f32", size = 5029973 },
+ { url = "https://files.pythonhosted.org/packages/f9/cc/9ae1baf5472af88e19e2c454b3710c1be9ecafb20eb474eeabcd88a055d2/lxml-5.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2c86bf781b12ba417f64f3422cfc302523ac9cd1d8ae8c0f92a1c66e56ef2e86", size = 4888837 },
+ { url = "https://files.pythonhosted.org/packages/d2/10/5594ffaec8c120d75b17e3ad23439b740a51549a9b5fd7484b2179adfe8f/lxml-5.3.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c162b216070f280fa7da844531169be0baf9ccb17263cf5a8bf876fcd3117fa5", size = 5530555 },
+ { url = "https://files.pythonhosted.org/packages/ea/9b/de17f05377c8833343b629905571fb06cff2028f15a6f58ae2267662e341/lxml-5.3.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:36aef61a1678cb778097b4a6eeae96a69875d51d1e8f4d4b491ab3cfb54b5a03", size = 5405314 },
+ { url = "https://files.pythonhosted.org/packages/8a/b4/227be0f1f3cca8255925985164c3838b8b36e441ff0cc10c1d3c6bdba031/lxml-5.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f65e5120863c2b266dbcc927b306c5b78e502c71edf3295dfcb9501ec96e5fc7", size = 5079303 },
+ { url = "https://files.pythonhosted.org/packages/5c/ee/19abcebb7fc40319bb71cd6adefa1ad94d09b5660228715854d6cc420713/lxml-5.3.0-cp310-cp310-win32.whl", hash = "sha256:ef0c1fe22171dd7c7c27147f2e9c3e86f8bdf473fed75f16b0c2e84a5030ce80", size = 3475126 },
+ { url = "https://files.pythonhosted.org/packages/a1/35/183d32551447e280032b2331738cd850da435a42f850b71ebeaab42c1313/lxml-5.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:052d99051e77a4f3e8482c65014cf6372e61b0a6f4fe9edb98503bb5364cfee3", size = 3805065 },
+ { url = "https://files.pythonhosted.org/packages/5c/a8/449faa2a3cbe6a99f8d38dcd51a3ee8844c17862841a6f769ea7c2a9cd0f/lxml-5.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:74bcb423462233bc5d6066e4e98b0264e7c1bed7541fff2f4e34fe6b21563c8b", size = 8141056 },
+ { url = "https://files.pythonhosted.org/packages/ac/8a/ae6325e994e2052de92f894363b038351c50ee38749d30cc6b6d96aaf90f/lxml-5.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a3d819eb6f9b8677f57f9664265d0a10dd6551d227afb4af2b9cd7bdc2ccbf18", size = 4425238 },
+ { url = "https://files.pythonhosted.org/packages/f8/fb/128dddb7f9086236bce0eeae2bfb316d138b49b159f50bc681d56c1bdd19/lxml-5.3.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b8f5db71b28b8c404956ddf79575ea77aa8b1538e8b2ef9ec877945b3f46442", size = 5095197 },
+ { url = "https://files.pythonhosted.org/packages/b4/f9/a181a8ef106e41e3086629c8bdb2d21a942f14c84a0e77452c22d6b22091/lxml-5.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3406b63232fc7e9b8783ab0b765d7c59e7c59ff96759d8ef9632fca27c7ee4", size = 4809809 },
+ { url = "https://files.pythonhosted.org/packages/25/2f/b20565e808f7f6868aacea48ddcdd7e9e9fb4c799287f21f1a6c7c2e8b71/lxml-5.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ecdd78ab768f844c7a1d4a03595038c166b609f6395e25af9b0f3f26ae1230f", size = 5407593 },
+ { url = "https://files.pythonhosted.org/packages/23/0e/caac672ec246d3189a16c4d364ed4f7d6bf856c080215382c06764058c08/lxml-5.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:168f2dfcfdedf611eb285efac1516c8454c8c99caf271dccda8943576b67552e", size = 4866657 },
+ { url = "https://files.pythonhosted.org/packages/67/a4/1f5fbd3f58d4069000522196b0b776a014f3feec1796da03e495cf23532d/lxml-5.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa617107a410245b8660028a7483b68e7914304a6d4882b5ff3d2d3eb5948d8c", size = 4967017 },
+ { url = "https://files.pythonhosted.org/packages/ee/73/623ecea6ca3c530dd0a4ed0d00d9702e0e85cd5624e2d5b93b005fe00abd/lxml-5.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:69959bd3167b993e6e710b99051265654133a98f20cec1d9b493b931942e9c16", size = 4810730 },
+ { url = "https://files.pythonhosted.org/packages/1d/ce/fb84fb8e3c298f3a245ae3ea6221c2426f1bbaa82d10a88787412a498145/lxml-5.3.0-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:bd96517ef76c8654446fc3db9242d019a1bb5fe8b751ba414765d59f99210b79", size = 5455154 },
+ { url = "https://files.pythonhosted.org/packages/b1/72/4d1ad363748a72c7c0411c28be2b0dc7150d91e823eadad3b91a4514cbea/lxml-5.3.0-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:ab6dd83b970dc97c2d10bc71aa925b84788c7c05de30241b9e96f9b6d9ea3080", size = 4969416 },
+ { url = "https://files.pythonhosted.org/packages/42/07/b29571a58a3a80681722ea8ed0ba569211d9bb8531ad49b5cacf6d409185/lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eec1bb8cdbba2925bedc887bc0609a80e599c75b12d87ae42ac23fd199445654", size = 5013672 },
+ { url = "https://files.pythonhosted.org/packages/b9/93/bde740d5a58cf04cbd38e3dd93ad1e36c2f95553bbf7d57807bc6815d926/lxml-5.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6a7095eeec6f89111d03dabfe5883a1fd54da319c94e0fb104ee8f23616b572d", size = 4878644 },
+ { url = "https://files.pythonhosted.org/packages/56/b5/645c8c02721d49927c93181de4017164ec0e141413577687c3df8ff0800f/lxml-5.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6f651ebd0b21ec65dfca93aa629610a0dbc13dbc13554f19b0113da2e61a4763", size = 5511531 },
+ { url = "https://files.pythonhosted.org/packages/85/3f/6a99a12d9438316f4fc86ef88c5d4c8fb674247b17f3173ecadd8346b671/lxml-5.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f422a209d2455c56849442ae42f25dbaaba1c6c3f501d58761c619c7836642ec", size = 5402065 },
+ { url = "https://files.pythonhosted.org/packages/80/8a/df47bff6ad5ac57335bf552babfb2408f9eb680c074ec1ba412a1a6af2c5/lxml-5.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:62f7fdb0d1ed2065451f086519865b4c90aa19aed51081979ecd05a21eb4d1be", size = 5069775 },
+ { url = "https://files.pythonhosted.org/packages/08/ae/e7ad0f0fbe4b6368c5ee1e3ef0c3365098d806d42379c46c1ba2802a52f7/lxml-5.3.0-cp311-cp311-win32.whl", hash = "sha256:c6379f35350b655fd817cd0d6cbeef7f265f3ae5fedb1caae2eb442bbeae9ab9", size = 3474226 },
+ { url = "https://files.pythonhosted.org/packages/c3/b5/91c2249bfac02ee514ab135e9304b89d55967be7e53e94a879b74eec7a5c/lxml-5.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c52100e2c2dbb0649b90467935c4b0de5528833c76a35ea1a2691ec9f1ee7a1", size = 3814971 },
+ { url = "https://files.pythonhosted.org/packages/eb/6d/d1f1c5e40c64bf62afd7a3f9b34ce18a586a1cccbf71e783cd0a6d8e8971/lxml-5.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e99f5507401436fdcc85036a2e7dc2e28d962550afe1cbfc07c40e454256a859", size = 8171753 },
+ { url = "https://files.pythonhosted.org/packages/bd/83/26b1864921869784355459f374896dcf8b44d4af3b15d7697e9156cb2de9/lxml-5.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:384aacddf2e5813a36495233b64cb96b1949da72bef933918ba5c84e06af8f0e", size = 4441955 },
+ { url = "https://files.pythonhosted.org/packages/e0/d2/e9bff9fb359226c25cda3538f664f54f2804f4b37b0d7c944639e1a51f69/lxml-5.3.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:874a216bf6afaf97c263b56371434e47e2c652d215788396f60477540298218f", size = 5050778 },
+ { url = "https://files.pythonhosted.org/packages/88/69/6972bfafa8cd3ddc8562b126dd607011e218e17be313a8b1b9cc5a0ee876/lxml-5.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65ab5685d56914b9a2a34d67dd5488b83213d680b0c5d10b47f81da5a16b0b0e", size = 4748628 },
+ { url = "https://files.pythonhosted.org/packages/5d/ea/a6523c7c7f6dc755a6eed3d2f6d6646617cad4d3d6d8ce4ed71bfd2362c8/lxml-5.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aac0bbd3e8dd2d9c45ceb82249e8bdd3ac99131a32b4d35c8af3cc9db1657179", size = 5322215 },
+ { url = "https://files.pythonhosted.org/packages/99/37/396fbd24a70f62b31d988e4500f2068c7f3fd399d2fd45257d13eab51a6f/lxml-5.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b369d3db3c22ed14c75ccd5af429086f166a19627e84a8fdade3f8f31426e52a", size = 4813963 },
+ { url = "https://files.pythonhosted.org/packages/09/91/e6136f17459a11ce1757df864b213efbeab7adcb2efa63efb1b846ab6723/lxml-5.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24037349665434f375645fa9d1f5304800cec574d0310f618490c871fd902b3", size = 4923353 },
+ { url = "https://files.pythonhosted.org/packages/1d/7c/2eeecf87c9a1fca4f84f991067c693e67340f2b7127fc3eca8fa29d75ee3/lxml-5.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:62d172f358f33a26d6b41b28c170c63886742f5b6772a42b59b4f0fa10526cb1", size = 4740541 },
+ { url = "https://files.pythonhosted.org/packages/3b/ed/4c38ba58defca84f5f0d0ac2480fdcd99fc7ae4b28fc417c93640a6949ae/lxml-5.3.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:c1f794c02903c2824fccce5b20c339a1a14b114e83b306ff11b597c5f71a1c8d", size = 5346504 },
+ { url = "https://files.pythonhosted.org/packages/a5/22/bbd3995437e5745cb4c2b5d89088d70ab19d4feabf8a27a24cecb9745464/lxml-5.3.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:5d6a6972b93c426ace71e0be9a6f4b2cfae9b1baed2eed2006076a746692288c", size = 4898077 },
+ { url = "https://files.pythonhosted.org/packages/0a/6e/94537acfb5b8f18235d13186d247bca478fea5e87d224644e0fe907df976/lxml-5.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3879cc6ce938ff4eb4900d901ed63555c778731a96365e53fadb36437a131a99", size = 4946543 },
+ { url = "https://files.pythonhosted.org/packages/8d/e8/4b15df533fe8e8d53363b23a41df9be907330e1fa28c7ca36893fad338ee/lxml-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:74068c601baff6ff021c70f0935b0c7bc528baa8ea210c202e03757c68c5a4ff", size = 4816841 },
+ { url = "https://files.pythonhosted.org/packages/1a/e7/03f390ea37d1acda50bc538feb5b2bda6745b25731e4e76ab48fae7106bf/lxml-5.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ecd4ad8453ac17bc7ba3868371bffb46f628161ad0eefbd0a855d2c8c32dd81a", size = 5417341 },
+ { url = "https://files.pythonhosted.org/packages/ea/99/d1133ab4c250da85a883c3b60249d3d3e7c64f24faff494cf0fd23f91e80/lxml-5.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7e2f58095acc211eb9d8b5771bf04df9ff37d6b87618d1cbf85f92399c98dae8", size = 5327539 },
+ { url = "https://files.pythonhosted.org/packages/7d/ed/e6276c8d9668028213df01f598f385b05b55a4e1b4662ee12ef05dab35aa/lxml-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e63601ad5cd8f860aa99d109889b5ac34de571c7ee902d6812d5d9ddcc77fa7d", size = 5012542 },
+ { url = "https://files.pythonhosted.org/packages/36/88/684d4e800f5aa28df2a991a6a622783fb73cf0e46235cfa690f9776f032e/lxml-5.3.0-cp312-cp312-win32.whl", hash = "sha256:17e8d968d04a37c50ad9c456a286b525d78c4a1c15dd53aa46c1d8e06bf6fa30", size = 3486454 },
+ { url = "https://files.pythonhosted.org/packages/fc/82/ace5a5676051e60355bd8fb945df7b1ba4f4fb8447f2010fb816bfd57724/lxml-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:c1a69e58a6bb2de65902051d57fde951febad631a20a64572677a1052690482f", size = 3816857 },
+ { url = "https://files.pythonhosted.org/packages/99/f7/b73a431c8500565aa500e99e60b448d305eaf7c0b4c893c7c5a8a69cc595/lxml-5.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7b1cd427cb0d5f7393c31b7496419da594fe600e6fdc4b105a54f82405e6626c", size = 3925431 },
+ { url = "https://files.pythonhosted.org/packages/db/48/4a206623c0d093d0e3b15f415ffb4345b0bdf661a3d0b15a112948c033c7/lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51806cfe0279e06ed8500ce19479d757db42a30fd509940b1701be9c86a5ff9a", size = 4216683 },
+ { url = "https://files.pythonhosted.org/packages/54/47/577820c45dd954523ae8453b632d91e76da94ca6d9ee40d8c98dd86f916b/lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee70d08fd60c9565ba8190f41a46a54096afa0eeb8f76bd66f2c25d3b1b83005", size = 4326732 },
+ { url = "https://files.pythonhosted.org/packages/68/de/96cb6d3269bc994b4f5ede8ca7bf0840f5de0a278bc6e50cb317ff71cafa/lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:8dc2c0395bea8254d8daebc76dcf8eb3a95ec2a46fa6fae5eaccee366bfe02ce", size = 4218377 },
+ { url = "https://files.pythonhosted.org/packages/a5/43/19b1ef6cbffa4244a217f95cc5f41a6cb4720fed33510a49670b03c5f1a0/lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6ba0d3dcac281aad8a0e5b14c7ed6f9fa89c8612b47939fc94f80b16e2e9bc83", size = 4351237 },
+ { url = "https://files.pythonhosted.org/packages/ba/b2/6a22fb5c0885da3b00e116aee81f0b829ec9ac8f736cd414b4a09413fc7d/lxml-5.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6e91cf736959057f7aac7adfc83481e03615a8e8dd5758aa1d95ea69e8931dba", size = 3487557 },
+]
+
+[[package]]
+name = "markdown"
+version = "3.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/54/28/3af612670f82f4c056911fbbbb42760255801b3068c48de792d354ff4472/markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2", size = 357086 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3f/08/83871f3c50fc983b88547c196d11cf8c3340e37c32d2e9d6152abe2c61f7/Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803", size = 106349 },
+]
+
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mdurl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
+]
+
+[[package]]
+name = "markupsafe"
+version = "3.0.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357 },
+ { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", size = 12393 },
+ { url = "https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732 },
+ { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866 },
+ { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964 },
+ { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977 },
+ { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366 },
+ { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091 },
+ { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065 },
+ { url = "https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514 },
+ { url = "https://files.pythonhosted.org/packages/6b/28/bbf83e3f76936960b850435576dd5e67034e200469571be53f69174a2dfd/MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d", size = 14353 },
+ { url = "https://files.pythonhosted.org/packages/6c/30/316d194b093cde57d448a4c3209f22e3046c5bb2fb0820b118292b334be7/MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93", size = 12392 },
+ { url = "https://files.pythonhosted.org/packages/f2/96/9cdafba8445d3a53cae530aaf83c38ec64c4d5427d975c974084af5bc5d2/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832", size = 23984 },
+ { url = "https://files.pythonhosted.org/packages/f1/a4/aefb044a2cd8d7334c8a47d3fb2c9f328ac48cb349468cc31c20b539305f/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84", size = 23120 },
+ { url = "https://files.pythonhosted.org/packages/8d/21/5e4851379f88f3fad1de30361db501300d4f07bcad047d3cb0449fc51f8c/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca", size = 23032 },
+ { url = "https://files.pythonhosted.org/packages/00/7b/e92c64e079b2d0d7ddf69899c98842f3f9a60a1ae72657c89ce2655c999d/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798", size = 24057 },
+ { url = "https://files.pythonhosted.org/packages/f9/ac/46f960ca323037caa0a10662ef97d0a4728e890334fc156b9f9e52bcc4ca/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e", size = 23359 },
+ { url = "https://files.pythonhosted.org/packages/69/84/83439e16197337b8b14b6a5b9c2105fff81d42c2a7c5b58ac7b62ee2c3b1/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4", size = 23306 },
+ { url = "https://files.pythonhosted.org/packages/9a/34/a15aa69f01e2181ed8d2b685c0d2f6655d5cca2c4db0ddea775e631918cd/MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d", size = 15094 },
+ { url = "https://files.pythonhosted.org/packages/da/b8/3a3bd761922d416f3dc5d00bfbed11f66b1ab89a0c2b6e887240a30b0f6b/MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b", size = 15521 },
+ { url = "https://files.pythonhosted.org/packages/22/09/d1f21434c97fc42f09d290cbb6350d44eb12f09cc62c9476effdb33a18aa/MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", size = 14274 },
+ { url = "https://files.pythonhosted.org/packages/6b/b0/18f76bba336fa5aecf79d45dcd6c806c280ec44538b3c13671d49099fdd0/MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", size = 12348 },
+ { url = "https://files.pythonhosted.org/packages/e0/25/dd5c0f6ac1311e9b40f4af06c78efde0f3b5cbf02502f8ef9501294c425b/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", size = 24149 },
+ { url = "https://files.pythonhosted.org/packages/f3/f0/89e7aadfb3749d0f52234a0c8c7867877876e0a20b60e2188e9850794c17/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", size = 23118 },
+ { url = "https://files.pythonhosted.org/packages/d5/da/f2eeb64c723f5e3777bc081da884b414671982008c47dcc1873d81f625b6/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", size = 22993 },
+ { url = "https://files.pythonhosted.org/packages/da/0e/1f32af846df486dce7c227fe0f2398dc7e2e51d4a370508281f3c1c5cddc/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", size = 24178 },
+ { url = "https://files.pythonhosted.org/packages/c4/f6/bb3ca0532de8086cbff5f06d137064c8410d10779c4c127e0e47d17c0b71/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", size = 23319 },
+ { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352 },
+ { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097 },
+ { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601 },
+]
+
+[[package]]
+name = "matplotlib"
+version = "3.10.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "contourpy" },
+ { name = "cycler" },
+ { name = "fonttools" },
+ { name = "kiwisolver" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "pillow" },
+ { name = "pyparsing" },
+ { name = "python-dateutil" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/68/dd/fa2e1a45fce2d09f4aea3cee169760e672c8262325aa5796c49d543dc7e6/matplotlib-3.10.0.tar.gz", hash = "sha256:b886d02a581b96704c9d1ffe55709e49b4d2d52709ccebc4be42db856e511278", size = 36686418 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/09/ec/3cdff7b5239adaaacefcc4f77c316dfbbdf853c4ed2beec467e0fec31b9f/matplotlib-3.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6", size = 8160551 },
+ { url = "https://files.pythonhosted.org/packages/41/f2/b518f2c7f29895c9b167bf79f8529c63383ae94eaf49a247a4528e9a148d/matplotlib-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2a43cbefe22d653ab34bb55d42384ed30f611bcbdea1f8d7f431011a2e1c62e", size = 8034853 },
+ { url = "https://files.pythonhosted.org/packages/ed/8d/45754b4affdb8f0d1a44e4e2bcd932cdf35b256b60d5eda9f455bb293ed0/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:607b16c8a73943df110f99ee2e940b8a1cbf9714b65307c040d422558397dac5", size = 8446724 },
+ { url = "https://files.pythonhosted.org/packages/09/5a/a113495110ae3e3395c72d82d7bc4802902e46dc797f6b041e572f195c56/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6", size = 8583905 },
+ { url = "https://files.pythonhosted.org/packages/12/b1/8b1655b4c9ed4600c817c419f7eaaf70082630efd7556a5b2e77a8a3cdaf/matplotlib-3.10.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5e6c6461e1fc63df30bf6f80f0b93f5b6784299f721bc28530477acd51bfc3d1", size = 9395223 },
+ { url = "https://files.pythonhosted.org/packages/5a/85/b9a54d64585a6b8737a78a61897450403c30f39e0bd3214270bb0b96f002/matplotlib-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:994c07b9d9fe8d25951e3202a68c17900679274dadfc1248738dcfa1bd40d7f3", size = 8025355 },
+ { url = "https://files.pythonhosted.org/packages/0c/f1/e37f6c84d252867d7ddc418fff70fc661cfd363179263b08e52e8b748e30/matplotlib-3.10.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fd44fc75522f58612ec4a33958a7e5552562b7705b42ef1b4f8c0818e304a363", size = 8171677 },
+ { url = "https://files.pythonhosted.org/packages/c7/8b/92e9da1f28310a1f6572b5c55097b0c0ceb5e27486d85fb73b54f5a9b939/matplotlib-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c58a9622d5dbeb668f407f35f4e6bfac34bb9ecdcc81680c04d0258169747997", size = 8044945 },
+ { url = "https://files.pythonhosted.org/packages/c5/cb/49e83f0fd066937a5bd3bc5c5d63093703f3637b2824df8d856e0558beef/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:845d96568ec873be63f25fa80e9e7fae4be854a66a7e2f0c8ccc99e94a8bd4ef", size = 8458269 },
+ { url = "https://files.pythonhosted.org/packages/b2/7d/2d873209536b9ee17340754118a2a17988bc18981b5b56e6715ee07373ac/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5439f4c5a3e2e8eab18e2f8c3ef929772fd5641876db71f08127eed95ab64683", size = 8599369 },
+ { url = "https://files.pythonhosted.org/packages/b8/03/57d6cbbe85c61fe4cbb7c94b54dce443d68c21961830833a1f34d056e5ea/matplotlib-3.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4673ff67a36152c48ddeaf1135e74ce0d4bce1bbf836ae40ed39c29edf7e2765", size = 9405992 },
+ { url = "https://files.pythonhosted.org/packages/14/cf/e382598f98be11bf51dd0bc60eca44a517f6793e3dc8b9d53634a144620c/matplotlib-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e8632baebb058555ac0cde75db885c61f1212e47723d63921879806b40bec6a", size = 8034580 },
+ { url = "https://files.pythonhosted.org/packages/44/c7/6b2d8cb7cc251d53c976799cacd3200add56351c175ba89ab9cbd7c1e68a/matplotlib-3.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4659665bc7c9b58f8c00317c3c2a299f7f258eeae5a5d56b4c64226fca2f7c59", size = 8172465 },
+ { url = "https://files.pythonhosted.org/packages/42/2a/6d66d0fba41e13e9ca6512a0a51170f43e7e7ed3a8dfa036324100775612/matplotlib-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d44cb942af1693cced2604c33a9abcef6205601c445f6d0dc531d813af8a2f5a", size = 8043300 },
+ { url = "https://files.pythonhosted.org/packages/90/60/2a60342b27b90a16bada939a85e29589902b41073f59668b904b15ea666c/matplotlib-3.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a994f29e968ca002b50982b27168addfd65f0105610b6be7fa515ca4b5307c95", size = 8448936 },
+ { url = "https://files.pythonhosted.org/packages/a7/b2/d872fc3d753516870d520595ddd8ce4dd44fa797a240999f125f58521ad7/matplotlib-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b0558bae37f154fffda54d779a592bc97ca8b4701f1c710055b609a3bac44c8", size = 8594151 },
+ { url = "https://files.pythonhosted.org/packages/f4/bd/b2f60cf7f57d014ab33e4f74602a2b5bdc657976db8196bbc022185f6f9c/matplotlib-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:503feb23bd8c8acc75541548a1d709c059b7184cde26314896e10a9f14df5f12", size = 9400347 },
+ { url = "https://files.pythonhosted.org/packages/9f/6e/264673e64001b99d747aff5a288eca82826c024437a3694e19aed1decf46/matplotlib-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:c40ba2eb08b3f5de88152c2333c58cee7edcead0a2a0d60fcafa116b17117adc", size = 8039144 },
+ { url = "https://files.pythonhosted.org/packages/32/5f/29def7ce4e815ab939b56280976ee35afffb3bbdb43f332caee74cb8c951/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:81713dd0d103b379de4516b861d964b1d789a144103277769238c732229d7f03", size = 8155500 },
+ { url = "https://files.pythonhosted.org/packages/de/6d/d570383c9f7ca799d0a54161446f9ce7b17d6c50f2994b653514bcaa108f/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:359f87baedb1f836ce307f0e850d12bb5f1936f70d035561f90d41d305fdacea", size = 8032398 },
+ { url = "https://files.pythonhosted.org/packages/c9/b4/680aa700d99b48e8c4393fa08e9ab8c49c0555ee6f4c9c0a5e8ea8dfde5d/matplotlib-3.10.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae80dc3a4add4665cf2faa90138384a7ffe2a4e37c58d83e115b54287c4f06ef", size = 8587361 },
+]
+
+[[package]]
+name = "matplotlib-inline"
+version = "0.1.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "traitlets" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 },
+]
+
+[[package]]
+name = "mauve-text"
+version = "0.4.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "faiss-cpu" },
+ { name = "numpy" },
+ { name = "requests" },
+ { name = "scikit-learn" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/53/f1/790ce5858951689e17cf2d767a951fdd40dd22b33b8ae01aecee182d2ad2/mauve-text-0.4.0.tar.gz", hash = "sha256:a9cd29587d1acdfeb006274839c44ac65aec378fb89cceb368094f4a264fd94f", size = 22281 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1f/0a/7fa7d797479b762e800b57064c5fe861743fe12722292c36de220fc964a2/mauve_text-0.4.0-py3-none-any.whl", hash = "sha256:ceafe978fd66adf4a2a8bad8c47be417e7306f43a4a4ab9121de01f81fc7e47b", size = 21534 },
+]
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
+]
+
+[[package]]
+name = "mistletoe"
+version = "1.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/11/96/ea46a376a7c4cd56955ecdfff0ea68de43996a4e6d1aee4599729453bd11/mistletoe-1.4.0.tar.gz", hash = "sha256:1630f906e5e4bbe66fdeb4d29d277e2ea515d642bb18a9b49b136361a9818c9d", size = 107203 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2a/0f/b5e545f0c7962be90366af3418989b12cf441d9da1e5d89d88f2f3e5cf8f/mistletoe-1.4.0-py3-none-any.whl", hash = "sha256:44a477803861de1237ba22e375c6b617690a31d2902b47279d1f8f7ed498a794", size = 51304 },
+]
+
+[[package]]
+name = "ml-collections"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "pyyaml" },
+ { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/31/f9/74689ff3e3ff6e4ec8616887cb00c9c66bca7e6243fd328358ea3665d547/ml_collections-1.0.0.tar.gz", hash = "sha256:00b11a1a339dd6c2d9b7f0daab47ab17e10e29ca1b2a656058605e2b7210897f", size = 61151 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5b/3c/2663b8b41a6f7dae1f1058cc75d9b1d09cf58e6482cb562976d4babe483c/ml_collections-1.0.0-py3-none-any.whl", hash = "sha256:17dbca4d83aba64f56b4b96e59637026d99d9e922569118b8a7f2e0ca6d203a6", size = 76451 },
+]
+
+[[package]]
+name = "monsterui"
+version = "0.0.34"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "fastcore" },
+ { name = "lxml" },
+ { name = "mistletoe" },
+ { name = "python-fasthtml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5b/c2/ffcc66649dc3088f91002bd7f0a4ca522377ae47577be8017ceba15f22fc/monsterui-0.0.34.tar.gz", hash = "sha256:66b98d4bfec7367d49dcf50fafe27f168d4438997b801fc4e40034f754695626", size = 29941 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1c/f7/f269b3ee4eb99ff28941be3cd4c01bce4d653c0be5d4de77c824fe861687/MonsterUI-0.0.34-py3-none-any.whl", hash = "sha256:679876f7c5613459c8f67edfe74e66bd98564da0becd7a6176554466b0c5626d", size = 29812 },
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 },
+]
+
+[[package]]
+name = "msgpack"
+version = "1.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/cb/d0/7555686ae7ff5731205df1012ede15dd9d927f6227ea151e901c7406af4f/msgpack-1.1.0.tar.gz", hash = "sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e", size = 167260 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4b/f9/a892a6038c861fa849b11a2bb0502c07bc698ab6ea53359e5771397d883b/msgpack-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ad442d527a7e358a469faf43fda45aaf4ac3249c8310a82f0ccff9164e5dccd", size = 150428 },
+ { url = "https://files.pythonhosted.org/packages/df/7a/d174cc6a3b6bb85556e6a046d3193294a92f9a8e583cdbd46dc8a1d7e7f4/msgpack-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74bed8f63f8f14d75eec75cf3d04ad581da6b914001b474a5d3cd3372c8cc27d", size = 84131 },
+ { url = "https://files.pythonhosted.org/packages/08/52/bf4fbf72f897a23a56b822997a72c16de07d8d56d7bf273242f884055682/msgpack-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:914571a2a5b4e7606997e169f64ce53a8b1e06f2cf2c3a7273aa106236d43dd5", size = 81215 },
+ { url = "https://files.pythonhosted.org/packages/02/95/dc0044b439b518236aaf012da4677c1b8183ce388411ad1b1e63c32d8979/msgpack-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c921af52214dcbb75e6bdf6a661b23c3e6417f00c603dd2070bccb5c3ef499f5", size = 371229 },
+ { url = "https://files.pythonhosted.org/packages/ff/75/09081792db60470bef19d9c2be89f024d366b1e1973c197bb59e6aabc647/msgpack-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8ce0b22b890be5d252de90d0e0d119f363012027cf256185fc3d474c44b1b9e", size = 378034 },
+ { url = "https://files.pythonhosted.org/packages/32/d3/c152e0c55fead87dd948d4b29879b0f14feeeec92ef1fd2ec21b107c3f49/msgpack-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73322a6cc57fcee3c0c57c4463d828e9428275fb85a27aa2aa1a92fdc42afd7b", size = 363070 },
+ { url = "https://files.pythonhosted.org/packages/d9/2c/82e73506dd55f9e43ac8aa007c9dd088c6f0de2aa19e8f7330e6a65879fc/msgpack-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e1f3c3d21f7cf67bcf2da8e494d30a75e4cf60041d98b3f79875afb5b96f3a3f", size = 359863 },
+ { url = "https://files.pythonhosted.org/packages/cb/a0/3d093b248837094220e1edc9ec4337de3443b1cfeeb6e0896af8ccc4cc7a/msgpack-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:64fc9068d701233effd61b19efb1485587560b66fe57b3e50d29c5d78e7fef68", size = 368166 },
+ { url = "https://files.pythonhosted.org/packages/e4/13/7646f14f06838b406cf5a6ddbb7e8dc78b4996d891ab3b93c33d1ccc8678/msgpack-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:42f754515e0f683f9c79210a5d1cad631ec3d06cea5172214d2176a42e67e19b", size = 370105 },
+ { url = "https://files.pythonhosted.org/packages/67/fa/dbbd2443e4578e165192dabbc6a22c0812cda2649261b1264ff515f19f15/msgpack-1.1.0-cp310-cp310-win32.whl", hash = "sha256:3df7e6b05571b3814361e8464f9304c42d2196808e0119f55d0d3e62cd5ea044", size = 68513 },
+ { url = "https://files.pythonhosted.org/packages/24/ce/c2c8fbf0ded750cb63cbcbb61bc1f2dfd69e16dca30a8af8ba80ec182dcd/msgpack-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:685ec345eefc757a7c8af44a3032734a739f8c45d1b0ac45efc5d8977aa4720f", size = 74687 },
+ { url = "https://files.pythonhosted.org/packages/b7/5e/a4c7154ba65d93be91f2f1e55f90e76c5f91ccadc7efc4341e6f04c8647f/msgpack-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d364a55082fb2a7416f6c63ae383fbd903adb5a6cf78c5b96cc6316dc1cedc7", size = 150803 },
+ { url = "https://files.pythonhosted.org/packages/60/c2/687684164698f1d51c41778c838d854965dd284a4b9d3a44beba9265c931/msgpack-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:79ec007767b9b56860e0372085f8504db5d06bd6a327a335449508bbee9648fa", size = 84343 },
+ { url = "https://files.pythonhosted.org/packages/42/ae/d3adea9bb4a1342763556078b5765e666f8fdf242e00f3f6657380920972/msgpack-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ad622bf7756d5a497d5b6836e7fc3752e2dd6f4c648e24b1803f6048596f701", size = 81408 },
+ { url = "https://files.pythonhosted.org/packages/dc/17/6313325a6ff40ce9c3207293aee3ba50104aed6c2c1559d20d09e5c1ff54/msgpack-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e59bca908d9ca0de3dc8684f21ebf9a690fe47b6be93236eb40b99af28b6ea6", size = 396096 },
+ { url = "https://files.pythonhosted.org/packages/a8/a1/ad7b84b91ab5a324e707f4c9761633e357820b011a01e34ce658c1dda7cc/msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1da8f11a3dd397f0a32c76165cf0c4eb95b31013a94f6ecc0b280c05c91b59", size = 403671 },
+ { url = "https://files.pythonhosted.org/packages/bb/0b/fd5b7c0b308bbf1831df0ca04ec76fe2f5bf6319833646b0a4bd5e9dc76d/msgpack-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:452aff037287acb1d70a804ffd022b21fa2bb7c46bee884dbc864cc9024128a0", size = 387414 },
+ { url = "https://files.pythonhosted.org/packages/f0/03/ff8233b7c6e9929a1f5da3c7860eccd847e2523ca2de0d8ef4878d354cfa/msgpack-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8da4bf6d54ceed70e8861f833f83ce0814a2b72102e890cbdfe4b34764cdd66e", size = 383759 },
+ { url = "https://files.pythonhosted.org/packages/1f/1b/eb82e1fed5a16dddd9bc75f0854b6e2fe86c0259c4353666d7fab37d39f4/msgpack-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:41c991beebf175faf352fb940bf2af9ad1fb77fd25f38d9142053914947cdbf6", size = 394405 },
+ { url = "https://files.pythonhosted.org/packages/90/2e/962c6004e373d54ecf33d695fb1402f99b51832631e37c49273cc564ffc5/msgpack-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a52a1f3a5af7ba1c9ace055b659189f6c669cf3657095b50f9602af3a3ba0fe5", size = 396041 },
+ { url = "https://files.pythonhosted.org/packages/f8/20/6e03342f629474414860c48aeffcc2f7f50ddaf351d95f20c3f1c67399a8/msgpack-1.1.0-cp311-cp311-win32.whl", hash = "sha256:58638690ebd0a06427c5fe1a227bb6b8b9fdc2bd07701bec13c2335c82131a88", size = 68538 },
+ { url = "https://files.pythonhosted.org/packages/aa/c4/5a582fc9a87991a3e6f6800e9bb2f3c82972912235eb9539954f3e9997c7/msgpack-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd2906780f25c8ed5d7b323379f6138524ba793428db5d0e9d226d3fa6aa1788", size = 74871 },
+ { url = "https://files.pythonhosted.org/packages/e1/d6/716b7ca1dbde63290d2973d22bbef1b5032ca634c3ff4384a958ec3f093a/msgpack-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d46cf9e3705ea9485687aa4001a76e44748b609d260af21c4ceea7f2212a501d", size = 152421 },
+ { url = "https://files.pythonhosted.org/packages/70/da/5312b067f6773429cec2f8f08b021c06af416bba340c912c2ec778539ed6/msgpack-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5dbad74103df937e1325cc4bfeaf57713be0b4f15e1c2da43ccdd836393e2ea2", size = 85277 },
+ { url = "https://files.pythonhosted.org/packages/28/51/da7f3ae4462e8bb98af0d5bdf2707f1b8c65a0d4f496e46b6afb06cbc286/msgpack-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:58dfc47f8b102da61e8949708b3eafc3504509a5728f8b4ddef84bd9e16ad420", size = 82222 },
+ { url = "https://files.pythonhosted.org/packages/33/af/dc95c4b2a49cff17ce47611ca9ba218198806cad7796c0b01d1e332c86bb/msgpack-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676e5be1b472909b2ee6356ff425ebedf5142427842aa06b4dfd5117d1ca8a2", size = 392971 },
+ { url = "https://files.pythonhosted.org/packages/f1/54/65af8de681fa8255402c80eda2a501ba467921d5a7a028c9c22a2c2eedb5/msgpack-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17fb65dd0bec285907f68b15734a993ad3fc94332b5bb21b0435846228de1f39", size = 401403 },
+ { url = "https://files.pythonhosted.org/packages/97/8c/e333690777bd33919ab7024269dc3c41c76ef5137b211d776fbb404bfead/msgpack-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a51abd48c6d8ac89e0cfd4fe177c61481aca2d5e7ba42044fd218cfd8ea9899f", size = 385356 },
+ { url = "https://files.pythonhosted.org/packages/57/52/406795ba478dc1c890559dd4e89280fa86506608a28ccf3a72fbf45df9f5/msgpack-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2137773500afa5494a61b1208619e3871f75f27b03bcfca7b3a7023284140247", size = 383028 },
+ { url = "https://files.pythonhosted.org/packages/e7/69/053b6549bf90a3acadcd8232eae03e2fefc87f066a5b9fbb37e2e608859f/msgpack-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:398b713459fea610861c8a7b62a6fec1882759f308ae0795b5413ff6a160cf3c", size = 391100 },
+ { url = "https://files.pythonhosted.org/packages/23/f0/d4101d4da054f04274995ddc4086c2715d9b93111eb9ed49686c0f7ccc8a/msgpack-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:06f5fd2f6bb2a7914922d935d3b8bb4a7fff3a9a91cfce6d06c13bc42bec975b", size = 394254 },
+ { url = "https://files.pythonhosted.org/packages/1c/12/cf07458f35d0d775ff3a2dc5559fa2e1fcd06c46f1ef510e594ebefdca01/msgpack-1.1.0-cp312-cp312-win32.whl", hash = "sha256:ad33e8400e4ec17ba782f7b9cf868977d867ed784a1f5f2ab46e7ba53b6e1e1b", size = 69085 },
+ { url = "https://files.pythonhosted.org/packages/73/80/2708a4641f7d553a63bc934a3eb7214806b5b39d200133ca7f7afb0a53e8/msgpack-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:115a7af8ee9e8cddc10f87636767857e7e3717b7a2e97379dc2054712693e90f", size = 75347 },
+]
+
+[[package]]
+name = "multidict"
+version = "6.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d6/be/504b89a5e9ca731cd47487e91c469064f8ae5af93b7259758dcfc2b9c848/multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a", size = 64002 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/29/68/259dee7fd14cf56a17c554125e534f6274c2860159692a414d0b402b9a6d/multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60", size = 48628 },
+ { url = "https://files.pythonhosted.org/packages/50/79/53ba256069fe5386a4a9e80d4e12857ced9de295baf3e20c68cdda746e04/multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1", size = 29327 },
+ { url = "https://files.pythonhosted.org/packages/ff/10/71f1379b05b196dae749b5ac062e87273e3f11634f447ebac12a571d90ae/multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53", size = 29689 },
+ { url = "https://files.pythonhosted.org/packages/71/45/70bac4f87438ded36ad4793793c0095de6572d433d98575a5752629ef549/multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5", size = 126639 },
+ { url = "https://files.pythonhosted.org/packages/80/cf/17f35b3b9509b4959303c05379c4bfb0d7dd05c3306039fc79cf035bbac0/multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581", size = 134315 },
+ { url = "https://files.pythonhosted.org/packages/ef/1f/652d70ab5effb33c031510a3503d4d6efc5ec93153562f1ee0acdc895a57/multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56", size = 129471 },
+ { url = "https://files.pythonhosted.org/packages/a6/64/2dd6c4c681688c0165dea3975a6a4eab4944ea30f35000f8b8af1df3148c/multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429", size = 124585 },
+ { url = "https://files.pythonhosted.org/packages/87/56/e6ee5459894c7e554b57ba88f7257dc3c3d2d379cb15baaa1e265b8c6165/multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748", size = 116957 },
+ { url = "https://files.pythonhosted.org/packages/36/9e/616ce5e8d375c24b84f14fc263c7ef1d8d5e8ef529dbc0f1df8ce71bb5b8/multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db", size = 128609 },
+ { url = "https://files.pythonhosted.org/packages/8c/4f/4783e48a38495d000f2124020dc96bacc806a4340345211b1ab6175a6cb4/multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056", size = 123016 },
+ { url = "https://files.pythonhosted.org/packages/3e/b3/4950551ab8fc39862ba5e9907dc821f896aa829b4524b4deefd3e12945ab/multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76", size = 133542 },
+ { url = "https://files.pythonhosted.org/packages/96/4d/f0ce6ac9914168a2a71df117935bb1f1781916acdecbb43285e225b484b8/multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160", size = 130163 },
+ { url = "https://files.pythonhosted.org/packages/be/72/17c9f67e7542a49dd252c5ae50248607dfb780bcc03035907dafefb067e3/multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7", size = 126832 },
+ { url = "https://files.pythonhosted.org/packages/71/9f/72d719e248cbd755c8736c6d14780533a1606ffb3fbb0fbd77da9f0372da/multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0", size = 26402 },
+ { url = "https://files.pythonhosted.org/packages/04/5a/d88cd5d00a184e1ddffc82aa2e6e915164a6d2641ed3606e766b5d2f275a/multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d", size = 28800 },
+ { url = "https://files.pythonhosted.org/packages/93/13/df3505a46d0cd08428e4c8169a196131d1b0c4b515c3649829258843dde6/multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6", size = 48570 },
+ { url = "https://files.pythonhosted.org/packages/f0/e1/a215908bfae1343cdb72f805366592bdd60487b4232d039c437fe8f5013d/multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156", size = 29316 },
+ { url = "https://files.pythonhosted.org/packages/70/0f/6dc70ddf5d442702ed74f298d69977f904960b82368532c88e854b79f72b/multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb", size = 29640 },
+ { url = "https://files.pythonhosted.org/packages/d8/6d/9c87b73a13d1cdea30b321ef4b3824449866bd7f7127eceed066ccb9b9ff/multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b", size = 131067 },
+ { url = "https://files.pythonhosted.org/packages/cc/1e/1b34154fef373371fd6c65125b3d42ff5f56c7ccc6bfff91b9b3c60ae9e0/multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72", size = 138507 },
+ { url = "https://files.pythonhosted.org/packages/fb/e0/0bc6b2bac6e461822b5f575eae85da6aae76d0e2a79b6665d6206b8e2e48/multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304", size = 133905 },
+ { url = "https://files.pythonhosted.org/packages/ba/af/73d13b918071ff9b2205fcf773d316e0f8fefb4ec65354bbcf0b10908cc6/multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351", size = 129004 },
+ { url = "https://files.pythonhosted.org/packages/74/21/23960627b00ed39643302d81bcda44c9444ebcdc04ee5bedd0757513f259/multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb", size = 121308 },
+ { url = "https://files.pythonhosted.org/packages/8b/5c/cf282263ffce4a596ed0bb2aa1a1dddfe1996d6a62d08842a8d4b33dca13/multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3", size = 132608 },
+ { url = "https://files.pythonhosted.org/packages/d7/3e/97e778c041c72063f42b290888daff008d3ab1427f5b09b714f5a8eff294/multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399", size = 127029 },
+ { url = "https://files.pythonhosted.org/packages/47/ac/3efb7bfe2f3aefcf8d103e9a7162572f01936155ab2f7ebcc7c255a23212/multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423", size = 137594 },
+ { url = "https://files.pythonhosted.org/packages/42/9b/6c6e9e8dc4f915fc90a9b7798c44a30773dea2995fdcb619870e705afe2b/multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3", size = 134556 },
+ { url = "https://files.pythonhosted.org/packages/1d/10/8e881743b26aaf718379a14ac58572a240e8293a1c9d68e1418fb11c0f90/multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753", size = 130993 },
+ { url = "https://files.pythonhosted.org/packages/45/84/3eb91b4b557442802d058a7579e864b329968c8d0ea57d907e7023c677f2/multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80", size = 26405 },
+ { url = "https://files.pythonhosted.org/packages/9f/0b/ad879847ecbf6d27e90a6eabb7eff6b62c129eefe617ea45eae7c1f0aead/multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926", size = 28795 },
+ { url = "https://files.pythonhosted.org/packages/fd/16/92057c74ba3b96d5e211b553895cd6dc7cc4d1e43d9ab8fafc727681ef71/multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa", size = 48713 },
+ { url = "https://files.pythonhosted.org/packages/94/3d/37d1b8893ae79716179540b89fc6a0ee56b4a65fcc0d63535c6f5d96f217/multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436", size = 29516 },
+ { url = "https://files.pythonhosted.org/packages/a2/12/adb6b3200c363062f805275b4c1e656be2b3681aada66c80129932ff0bae/multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761", size = 29557 },
+ { url = "https://files.pythonhosted.org/packages/47/e9/604bb05e6e5bce1e6a5cf80a474e0f072e80d8ac105f1b994a53e0b28c42/multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e", size = 130170 },
+ { url = "https://files.pythonhosted.org/packages/7e/13/9efa50801785eccbf7086b3c83b71a4fb501a4d43549c2f2f80b8787d69f/multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef", size = 134836 },
+ { url = "https://files.pythonhosted.org/packages/bf/0f/93808b765192780d117814a6dfcc2e75de6dcc610009ad408b8814dca3ba/multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95", size = 133475 },
+ { url = "https://files.pythonhosted.org/packages/d3/c8/529101d7176fe7dfe1d99604e48d69c5dfdcadb4f06561f465c8ef12b4df/multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925", size = 131049 },
+ { url = "https://files.pythonhosted.org/packages/ca/0c/fc85b439014d5a58063e19c3a158a889deec399d47b5269a0f3b6a2e28bc/multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966", size = 120370 },
+ { url = "https://files.pythonhosted.org/packages/db/46/d4416eb20176492d2258fbd47b4abe729ff3b6e9c829ea4236f93c865089/multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305", size = 125178 },
+ { url = "https://files.pythonhosted.org/packages/5b/46/73697ad7ec521df7de5531a32780bbfd908ded0643cbe457f981a701457c/multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2", size = 119567 },
+ { url = "https://files.pythonhosted.org/packages/cd/ed/51f060e2cb0e7635329fa6ff930aa5cffa17f4c7f5c6c3ddc3500708e2f2/multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2", size = 129822 },
+ { url = "https://files.pythonhosted.org/packages/df/9e/ee7d1954b1331da3eddea0c4e08d9142da5f14b1321c7301f5014f49d492/multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6", size = 128656 },
+ { url = "https://files.pythonhosted.org/packages/77/00/8538f11e3356b5d95fa4b024aa566cde7a38aa7a5f08f4912b32a037c5dc/multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3", size = 125360 },
+ { url = "https://files.pythonhosted.org/packages/be/05/5d334c1f2462d43fec2363cd00b1c44c93a78c3925d952e9a71caf662e96/multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133", size = 26382 },
+ { url = "https://files.pythonhosted.org/packages/a3/bf/f332a13486b1ed0496d624bcc7e8357bb8053823e8cd4b9a18edc1d97e73/multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1", size = 28529 },
+ { url = "https://files.pythonhosted.org/packages/99/b7/b9e70fde2c0f0c9af4cc5277782a89b66d35948ea3369ec9f598358c3ac5/multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506", size = 10051 },
+]
+
+[[package]]
+name = "multiprocess"
+version = "0.70.16"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "dill" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ef/76/6e712a2623d146d314f17598df5de7224c85c0060ef63fd95cc15a25b3fa/multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee", size = 134980 },
+ { url = "https://files.pythonhosted.org/packages/0f/ab/1e6e8009e380e22254ff539ebe117861e5bdb3bff1fc977920972237c6c7/multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec", size = 134982 },
+ { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824 },
+ { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519 },
+ { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741 },
+ { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628 },
+ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351 },
+]
+
+[[package]]
+name = "mypy-extensions"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 },
+]
+
+[[package]]
+name = "networkx"
+version = "3.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 },
+]
+
+[[package]]
+name = "ninja"
+version = "1.11.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bd/8f/21a2701f95b7d0d5137736561b3427ece0c4a1e085d4a223b92d16ab7d8b/ninja-1.11.1.3.tar.gz", hash = "sha256:edfa0d2e9d7ead1635b03e40a32ad56cc8f56798b6e2e9848d8300b174897076", size = 129532 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ea/ba/0069cd4a83d68f7b0308be70e219b15d675e50c8ea28763a3f0373c45bfc/ninja-1.11.1.3-py3-none-macosx_10_9_universal2.whl", hash = "sha256:2b4879ea3f1169f3d855182c57dcc84d1b5048628c8b7be0d702b81882a37237", size = 279132 },
+ { url = "https://files.pythonhosted.org/packages/72/6b/3805be87df8417a0c7b21078c8045f2a1e59b34f371bfe4cb4fb0d6df7f2/ninja-1.11.1.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc3ebc8b2e47716149f3541742b5cd8e0b08f51013b825c05baca3e34854370d", size = 472101 },
+ { url = "https://files.pythonhosted.org/packages/6b/35/a8e38d54768e67324e365e2a41162be298f51ec93e6bd4b18d237d7250d8/ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a27e78ca71316c8654965ee94b286a98c83877bfebe2607db96897bbfe458af0", size = 422884 },
+ { url = "https://files.pythonhosted.org/packages/2f/99/7996457319e139c02697fb2aa28e42fe32bb0752cef492edc69d56a3552e/ninja-1.11.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2883ea46b3c5079074f56820f9989c6261fcc6fd873d914ee49010ecf283c3b2", size = 157046 },
+ { url = "https://files.pythonhosted.org/packages/6d/8b/93f38e5cddf76ccfdab70946515b554f25d2b4c95ef9b2f9cfbc43fa7cc1/ninja-1.11.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c4bdb9fd2d0c06501ae15abfd23407660e95659e384acd36e013b6dd7d8a8e4", size = 180014 },
+ { url = "https://files.pythonhosted.org/packages/7d/1d/713884d0fa3c972164f69d552e0701d30e2bf25eba9ef160bfb3dc69926a/ninja-1.11.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:114ed5c61c8474df6a69ab89097a20749b769e2c219a452cb2fadc49b0d581b0", size = 157098 },
+ { url = "https://files.pythonhosted.org/packages/c7/22/ecb0f70e77c9e22ee250aa717a608a142756833a34d43943d7d658ee0e56/ninja-1.11.1.3-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fa2247fce98f683bc712562d82b22b8a0a5c000738a13147ca2d1b68c122298", size = 130089 },
+ { url = "https://files.pythonhosted.org/packages/ec/a6/3ee846c20ab6ad95b90c5c8703c76cb1f39cc8ce2d1ae468956e3b1b2581/ninja-1.11.1.3-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:a38c6c6c8032bed68b70c3b065d944c35e9f903342875d3a3218c1607987077c", size = 372508 },
+ { url = "https://files.pythonhosted.org/packages/95/0d/aa44abe4141f29148ce671ac8c92045878906b18691c6f87a29711c2ff1c/ninja-1.11.1.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:56ada5d33b8741d298836644042faddebc83ee669782d661e21563034beb5aba", size = 419369 },
+ { url = "https://files.pythonhosted.org/packages/f7/ec/48bf5105568ac9bd2016b701777bdd5000cc09a14ac837fef9f15e8d634e/ninja-1.11.1.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:53409151da081f3c198bb0bfc220a7f4e821e022c5b7d29719adda892ddb31bb", size = 420304 },
+ { url = "https://files.pythonhosted.org/packages/18/e5/69df63976cf971a03379899f8520a036c9dbab26330b37197512aed5b3df/ninja-1.11.1.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:1ad2112c2b0159ed7c4ae3731595191b1546ba62316fc40808edecd0306fefa3", size = 416056 },
+ { url = "https://files.pythonhosted.org/packages/6f/4f/bdb401af7ed0e24a3fef058e13a149f2de1ce4b176699076993615d55610/ninja-1.11.1.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:28aea3c1c280cba95b8608d50797169f3a34280e3e9a6379b6e340f0c9eaeeb0", size = 379725 },
+ { url = "https://files.pythonhosted.org/packages/bd/68/05e7863bf13128c61652eeb3ec7096c3d3a602f32f31752dbfb034e3fa07/ninja-1.11.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b6966f83064a88a51693073eea3decd47e08c3965241e09578ef7aa3a7738329", size = 434881 },
+ { url = "https://files.pythonhosted.org/packages/bd/ad/edc0d1efe77f29f45bbca2e1dab07ef597f61a88de6e4bccffc0aec2256c/ninja-1.11.1.3-py3-none-win32.whl", hash = "sha256:a4a3b71490557e18c010cbb26bd1ea9a0c32ee67e8f105e9731515b6e0af792e", size = 255988 },
+ { url = "https://files.pythonhosted.org/packages/03/93/09a9f7672b4f97438aca6217ac54212a63273f1cd3b46b731d0bb22c53e7/ninja-1.11.1.3-py3-none-win_amd64.whl", hash = "sha256:04d48d14ea7ba11951c156599ab526bdda575450797ff57c6fdf99b2554d09c7", size = 296502 },
+ { url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 },
+]
+
+[[package]]
+name = "numpy"
+version = "2.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ec/d0/c12ddfd3a02274be06ffc71f3efc6d0e457b0409c4481596881e748cb264/numpy-2.2.2.tar.gz", hash = "sha256:ed6906f61834d687738d25988ae117683705636936cc605be0bb208b23df4d8f", size = 20233295 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/70/2a/69033dc22d981ad21325314f8357438078f5c28310a6d89fb3833030ec8a/numpy-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7079129b64cb78bdc8d611d1fd7e8002c0a2565da6a47c4df8062349fee90e3e", size = 21215825 },
+ { url = "https://files.pythonhosted.org/packages/31/2c/39f91e00bbd3d5639b027ac48c55dc5f2992bd2b305412d26be4c830862a/numpy-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ec6c689c61df613b783aeb21f945c4cbe6c51c28cb70aae8430577ab39f163e", size = 14354996 },
+ { url = "https://files.pythonhosted.org/packages/0a/2c/d468ebd253851af10de5b3e8f3418ebabfaab5f0337a75299fbeb8b8c17a/numpy-2.2.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:40c7ff5da22cd391944a28c6a9c638a5eef77fcf71d6e3a79e1d9d9e82752715", size = 5393621 },
+ { url = "https://files.pythonhosted.org/packages/7f/f4/3d8a5a0da297034106c5de92be881aca7079cde6058934215a1de91334f6/numpy-2.2.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:995f9e8181723852ca458e22de5d9b7d3ba4da3f11cc1cb113f093b271d7965a", size = 6928931 },
+ { url = "https://files.pythonhosted.org/packages/47/a7/029354ab56edd43dd3f5efbfad292b8844f98b93174f322f82353fa46efa/numpy-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b78ea78450fd96a498f50ee096f69c75379af5138f7881a51355ab0e11286c97", size = 14333157 },
+ { url = "https://files.pythonhosted.org/packages/e3/d7/11fc594838d35c43519763310c316d4fd56f8600d3fc80a8e13e325b5c5c/numpy-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fbe72d347fbc59f94124125e73fc4976a06927ebc503ec5afbfb35f193cd957", size = 16381794 },
+ { url = "https://files.pythonhosted.org/packages/af/d4/dd9b19cd4aff9c79d3f54d17f8be815407520d3116004bc574948336981b/numpy-2.2.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8e6da5cffbbe571f93588f562ed130ea63ee206d12851b60819512dd3e1ba50d", size = 15543990 },
+ { url = "https://files.pythonhosted.org/packages/30/97/ab96b7650f27f684a9b1e46757a7294ecc50cab27701d05f146e9f779627/numpy-2.2.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:09d6a2032faf25e8d0cadde7fd6145118ac55d2740132c1d845f98721b5ebcfd", size = 18170896 },
+ { url = "https://files.pythonhosted.org/packages/81/9b/bae9618cab20db67a2ca9d711795cad29b2ca4b73034dd3b5d05b962070a/numpy-2.2.2-cp310-cp310-win32.whl", hash = "sha256:159ff6ee4c4a36a23fe01b7c3d07bd8c14cc433d9720f977fcd52c13c0098160", size = 6573458 },
+ { url = "https://files.pythonhosted.org/packages/92/9b/95678092febd14070cfb7906ea7932e71e9dd5a6ab3ee948f9ed975e905d/numpy-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:64bd6e1762cd7f0986a740fee4dff927b9ec2c5e4d9a28d056eb17d332158014", size = 12915812 },
+ { url = "https://files.pythonhosted.org/packages/21/67/32c68756eed84df181c06528ff57e09138f893c4653448c4967311e0f992/numpy-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:642199e98af1bd2b6aeb8ecf726972d238c9877b0f6e8221ee5ab945ec8a2189", size = 21220002 },
+ { url = "https://files.pythonhosted.org/packages/3b/89/f43bcad18f2b2e5814457b1c7f7b0e671d0db12c8c0e43397ab8cb1831ed/numpy-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d9fc9d812c81e6168b6d405bf00b8d6739a7f72ef22a9214c4241e0dc70b323", size = 14391215 },
+ { url = "https://files.pythonhosted.org/packages/9c/e6/efb8cd6122bf25e86e3dd89d9dbfec9e6861c50e8810eed77d4be59b51c6/numpy-2.2.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c7d1fd447e33ee20c1f33f2c8e6634211124a9aabde3c617687d8b739aa69eac", size = 5391918 },
+ { url = "https://files.pythonhosted.org/packages/47/e2/fccf89d64d9b47ffb242823d4e851fc9d36fa751908c9aac2807924d9b4e/numpy-2.2.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:451e854cfae0febe723077bd0cf0a4302a5d84ff25f0bfece8f29206c7bed02e", size = 6933133 },
+ { url = "https://files.pythonhosted.org/packages/34/22/5ece749c0e5420a9380eef6fbf83d16a50010bd18fef77b9193d80a6760e/numpy-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd249bc894af67cbd8bad2c22e7cbcd46cf87ddfca1f1289d1e7e54868cc785c", size = 14338187 },
+ { url = "https://files.pythonhosted.org/packages/5b/86/caec78829311f62afa6fa334c8dfcd79cffb4d24bcf96ee02ae4840d462b/numpy-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02935e2c3c0c6cbe9c7955a8efa8908dd4221d7755644c59d1bba28b94fd334f", size = 16393429 },
+ { url = "https://files.pythonhosted.org/packages/c8/4e/0c25f74c88239a37924577d6ad780f3212a50f4b4b5f54f5e8c918d726bd/numpy-2.2.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a972cec723e0563aa0823ee2ab1df0cb196ed0778f173b381c871a03719d4826", size = 15559103 },
+ { url = "https://files.pythonhosted.org/packages/d4/bd/d557f10fa50dc4d5871fb9606af563249b66af2fc6f99041a10e8757c6f1/numpy-2.2.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6d6a0910c3b4368d89dde073e630882cdb266755565155bc33520283b2d9df8", size = 18182967 },
+ { url = "https://files.pythonhosted.org/packages/30/e9/66cc0f66386d78ed89e45a56e2a1d051e177b6e04477c4a41cd590ef4017/numpy-2.2.2-cp311-cp311-win32.whl", hash = "sha256:860fd59990c37c3ef913c3ae390b3929d005243acca1a86facb0773e2d8d9e50", size = 6571499 },
+ { url = "https://files.pythonhosted.org/packages/66/a3/4139296b481ae7304a43581046b8f0a20da6a0dfe0ee47a044cade796603/numpy-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:da1eeb460ecce8d5b8608826595c777728cdf28ce7b5a5a8c8ac8d949beadcf2", size = 12919805 },
+ { url = "https://files.pythonhosted.org/packages/0c/e6/847d15770ab7a01e807bdfcd4ead5bdae57c0092b7dc83878171b6af97bb/numpy-2.2.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ac9bea18d6d58a995fac1b2cb4488e17eceeac413af014b1dd26170b766d8467", size = 20912636 },
+ { url = "https://files.pythonhosted.org/packages/d1/af/f83580891577b13bd7e261416120e036d0d8fb508c8a43a73e38928b794b/numpy-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23ae9f0c2d889b7b2d88a3791f6c09e2ef827c2446f1c4a3e3e76328ee4afd9a", size = 14098403 },
+ { url = "https://files.pythonhosted.org/packages/2b/86/d019fb60a9d0f1d4cf04b014fe88a9135090adfadcc31c1fadbb071d7fa7/numpy-2.2.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:3074634ea4d6df66be04f6728ee1d173cfded75d002c75fac79503a880bf3825", size = 5128938 },
+ { url = "https://files.pythonhosted.org/packages/7a/1b/50985edb6f1ec495a1c36452e860476f5b7ecdc3fc59ea89ccad3c4926c5/numpy-2.2.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8ec0636d3f7d68520afc6ac2dc4b8341ddb725039de042faf0e311599f54eb37", size = 6661937 },
+ { url = "https://files.pythonhosted.org/packages/f4/1b/17efd94cad1b9d605c3f8907fb06bcffc4ce4d1d14d46b95316cccccf2b9/numpy-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ffbb1acd69fdf8e89dd60ef6182ca90a743620957afb7066385a7bbe88dc748", size = 14049518 },
+ { url = "https://files.pythonhosted.org/packages/5b/73/65d2f0b698df1731e851e3295eb29a5ab8aa06f763f7e4188647a809578d/numpy-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0349b025e15ea9d05c3d63f9657707a4e1d471128a3b1d876c095f328f8ff7f0", size = 16099146 },
+ { url = "https://files.pythonhosted.org/packages/d5/69/308f55c0e19d4b5057b5df286c5433822e3c8039ede06d4051d96f1c2c4e/numpy-2.2.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:463247edcee4a5537841d5350bc87fe8e92d7dd0e8c71c995d2c6eecb8208278", size = 15246336 },
+ { url = "https://files.pythonhosted.org/packages/f0/d8/d8d333ad0d8518d077a21aeea7b7c826eff766a2b1ce1194dea95ca0bacf/numpy-2.2.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9dd47ff0cb2a656ad69c38da850df3454da88ee9a6fde0ba79acceee0e79daba", size = 17863507 },
+ { url = "https://files.pythonhosted.org/packages/82/6e/0b84ad3103ffc16d6673e63b5acbe7901b2af96c2837174c6318c98e27ab/numpy-2.2.2-cp312-cp312-win32.whl", hash = "sha256:4525b88c11906d5ab1b0ec1f290996c0020dd318af8b49acaa46f198b1ffc283", size = 6276491 },
+ { url = "https://files.pythonhosted.org/packages/fc/84/7f801a42a67b9772a883223a0a1e12069a14626c81a732bd70aac57aebc1/numpy-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:5acea83b801e98541619af398cc0109ff48016955cc0818f478ee9ef1c5c3dcb", size = 12616372 },
+ { url = "https://files.pythonhosted.org/packages/96/7e/1dd770ee68916ed358991ab62c2cc353ffd98d0b75b901d52183ca28e8bb/numpy-2.2.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b0531f0b0e07643eb089df4c509d30d72c9ef40defa53e41363eca8a8cc61495", size = 21047291 },
+ { url = "https://files.pythonhosted.org/packages/d1/3c/ccd08578dc532a8e6927952339d4a02682b776d5e85be49ed0760308433e/numpy-2.2.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:e9e82dcb3f2ebbc8cb5ce1102d5f1c5ed236bf8a11730fb45ba82e2841ec21df", size = 6792494 },
+ { url = "https://files.pythonhosted.org/packages/7c/28/8754b9aee4f97199f9a047f73bb644b5a2014994a6d7b061ba67134a42de/numpy-2.2.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0d4142eb40ca6f94539e4db929410f2a46052a0fe7a2c1c59f6179c39938d2a", size = 16197312 },
+ { url = "https://files.pythonhosted.org/packages/26/96/deb93f871f401045a684ca08a009382b247d14996d7a94fea6aa43c67b94/numpy-2.2.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:356ca982c188acbfa6af0d694284d8cf20e95b1c3d0aefa8929376fea9146f60", size = 12822674 },
+]
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.4.5.8"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.4.127"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.4.127"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.4.127"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "9.1.0.70"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
+]
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.2.1.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
+]
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.5.147"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.6.1.9"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
+]
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.3.1.170"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
+]
+
+[[package]]
+name = "nvidia-cusparselt-cu12"
+version = "0.6.2"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 },
+]
+
+[[package]]
+name = "nvidia-ml-py"
+version = "12.570.86"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ad/6e/7b0c9b88c7d520fb8639024a1a3b6dd1db03bf2c17ae85040c8758d2eb6f/nvidia_ml_py-12.570.86.tar.gz", hash = "sha256:0508d4a0c7b6d015cf574530b95a62ed4fc89da3b8b47e1aefe6777db170ec8b", size = 43147 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d8/a8/ec37169be4e2b7063b9076ed3fe0661e87335fbca665eed3f48c415cb234/nvidia_ml_py-12.570.86-py3-none-any.whl", hash = "sha256:58907de35a845abd13dcb227f18298f3b5dd94a72d04c9e594e77711e95c0b51", size = 44442 },
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.21.5"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414 },
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.4.127"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.4.127"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
+]
+
+[[package]]
+name = "oauthlib"
+version = "3.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 },
+]
+
+[[package]]
+name = "omegaconf"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "antlr4-python3-runtime" },
+ { name = "pyyaml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500 },
+]
+
+[[package]]
+name = "open-clip-torch"
+version = "2.30.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "ftfy" },
+ { name = "huggingface-hub" },
+ { name = "regex" },
+ { name = "safetensors" },
+ { name = "timm" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/28/71/133f3eb549d61a937e488805046baaee9eda4acfa8f8cbf01f43f64d2654/open_clip_torch-2.30.0.tar.gz", hash = "sha256:9a635e542a4fb83b268ec8ba2585698e2d5badcb1a517d26dcb49dff1a64c49f", size = 1485046 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/be/86/6ba3921b9fc0c83fd1838b1fb197973245994258586887876625eda732f8/open_clip_torch-2.30.0-py3-none-any.whl", hash = "sha256:68343092181a03a6a0b3ba8a3529856e40299d4c06bc83082ce73e0ba438187a", size = 1514664 },
+]
+
+[[package]]
+name = "opencv-python"
+version = "4.11.0.86"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/17/06/68c27a523103dad5837dc5b87e71285280c4f098c60e4fe8a8db6486ab09/opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4", size = 95171956 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/05/4d/53b30a2a3ac1f75f65a59eb29cf2ee7207ce64867db47036ad61743d5a23/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:432f67c223f1dc2824f5e73cdfcd9db0efc8710647d4e813012195dc9122a52a", size = 37326322 },
+ { url = "https://files.pythonhosted.org/packages/3b/84/0a67490741867eacdfa37bc18df96e08a9d579583b419010d7f3da8ff503/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:9d05ef13d23fe97f575153558653e2d6e87103995d54e6a35db3f282fe1f9c66", size = 56723197 },
+ { url = "https://files.pythonhosted.org/packages/f3/bd/29c126788da65c1fb2b5fb621b7fed0ed5f9122aa22a0868c5e2c15c6d23/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b92ae2c8852208817e6776ba1ea0d6b1e0a1b5431e971a2a0ddd2a8cc398202", size = 42230439 },
+ { url = "https://files.pythonhosted.org/packages/2c/8b/90eb44a40476fa0e71e05a0283947cfd74a5d36121a11d926ad6f3193cc4/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b02611523803495003bd87362db3e1d2a0454a6a63025dc6658a9830570aa0d", size = 62986597 },
+ { url = "https://files.pythonhosted.org/packages/fb/d7/1d5941a9dde095468b288d989ff6539dd69cd429dbf1b9e839013d21b6f0/opencv_python-4.11.0.86-cp37-abi3-win32.whl", hash = "sha256:810549cb2a4aedaa84ad9a1c92fbfdfc14090e2749cedf2c1589ad8359aa169b", size = 29384337 },
+ { url = "https://files.pythonhosted.org/packages/a4/7d/f1c30a92854540bf789e9cd5dde7ef49bbe63f855b85a2e6b3db8135c591/opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl", hash = "sha256:085ad9b77c18853ea66283e98affefe2de8cc4c1f43eda4c100cf9b2721142ec", size = 39488044 },
+]
+
+[[package]]
+name = "orjson"
+version = "3.10.15"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ae/f9/5dea21763eeff8c1590076918a446ea3d6140743e0e36f58f369928ed0f4/orjson-3.10.15.tar.gz", hash = "sha256:05ca7fe452a2e9d8d9d706a2984c95b9c2ebc5db417ce0b7a49b91d50642a23e", size = 5282482 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/09/e5ff18ad009e6f97eb7edc5f67ef98b3ce0c189da9c3eaca1f9587cd4c61/orjson-3.10.15-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:552c883d03ad185f720d0c09583ebde257e41b9521b74ff40e08b7dec4559c04", size = 249532 },
+ { url = "https://files.pythonhosted.org/packages/bd/b8/a75883301fe332bd433d9b0ded7d2bb706ccac679602c3516984f8814fb5/orjson-3.10.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616e3e8d438d02e4854f70bfdc03a6bcdb697358dbaa6bcd19cbe24d24ece1f8", size = 125229 },
+ { url = "https://files.pythonhosted.org/packages/83/4b/22f053e7a364cc9c685be203b1e40fc5f2b3f164a9b2284547504eec682e/orjson-3.10.15-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c2c79fa308e6edb0ffab0a31fd75a7841bf2a79a20ef08a3c6e3b26814c8ca8", size = 150148 },
+ { url = "https://files.pythonhosted.org/packages/63/64/1b54fc75ca328b57dd810541a4035fe48c12a161d466e3cf5b11a8c25649/orjson-3.10.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cb85490aa6bf98abd20607ab5c8324c0acb48d6da7863a51be48505646c814", size = 139748 },
+ { url = "https://files.pythonhosted.org/packages/5e/ff/ff0c5da781807bb0a5acd789d9a7fbcb57f7b0c6e1916595da1f5ce69f3c/orjson-3.10.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763dadac05e4e9d2bc14938a45a2d0560549561287d41c465d3c58aec818b164", size = 154559 },
+ { url = "https://files.pythonhosted.org/packages/4e/9a/11e2974383384ace8495810d4a2ebef5f55aacfc97b333b65e789c9d362d/orjson-3.10.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a330b9b4734f09a623f74a7490db713695e13b67c959713b78369f26b3dee6bf", size = 130349 },
+ { url = "https://files.pythonhosted.org/packages/2d/c4/dd9583aea6aefee1b64d3aed13f51d2aadb014028bc929fe52936ec5091f/orjson-3.10.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a61a4622b7ff861f019974f73d8165be1bd9a0855e1cad18ee167acacabeb061", size = 138514 },
+ { url = "https://files.pythonhosted.org/packages/53/3e/dcf1729230654f5c5594fc752de1f43dcf67e055ac0d300c8cdb1309269a/orjson-3.10.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:acd271247691574416b3228db667b84775c497b245fa275c6ab90dc1ffbbd2b3", size = 130940 },
+ { url = "https://files.pythonhosted.org/packages/e8/2b/b9759fe704789937705c8a56a03f6c03e50dff7df87d65cba9a20fec5282/orjson-3.10.15-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:e4759b109c37f635aa5c5cc93a1b26927bfde24b254bcc0e1149a9fada253d2d", size = 414713 },
+ { url = "https://files.pythonhosted.org/packages/a7/6b/b9dfdbd4b6e20a59238319eb203ae07c3f6abf07eef909169b7a37ae3bba/orjson-3.10.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9e992fd5cfb8b9f00bfad2fd7a05a4299db2bbe92e6440d9dd2fab27655b3182", size = 141028 },
+ { url = "https://files.pythonhosted.org/packages/7c/b5/40f5bbea619c7caf75eb4d652a9821875a8ed04acc45fe3d3ef054ca69fb/orjson-3.10.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f95fb363d79366af56c3f26b71df40b9a583b07bbaaf5b317407c4d58497852e", size = 129715 },
+ { url = "https://files.pythonhosted.org/packages/38/60/2272514061cbdf4d672edbca6e59c7e01cd1c706e881427d88f3c3e79761/orjson-3.10.15-cp310-cp310-win32.whl", hash = "sha256:f9875f5fea7492da8ec2444839dcc439b0ef298978f311103d0b7dfd775898ab", size = 142473 },
+ { url = "https://files.pythonhosted.org/packages/11/5d/be1490ff7eafe7fef890eb4527cf5bcd8cfd6117f3efe42a3249ec847b60/orjson-3.10.15-cp310-cp310-win_amd64.whl", hash = "sha256:17085a6aa91e1cd70ca8533989a18b5433e15d29c574582f76f821737c8d5806", size = 133564 },
+ { url = "https://files.pythonhosted.org/packages/7a/a2/21b25ce4a2c71dbb90948ee81bd7a42b4fbfc63162e57faf83157d5540ae/orjson-3.10.15-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c4cc83960ab79a4031f3119cc4b1a1c627a3dc09df125b27c4201dff2af7eaa6", size = 249533 },
+ { url = "https://files.pythonhosted.org/packages/b2/85/2076fc12d8225698a51278009726750c9c65c846eda741e77e1761cfef33/orjson-3.10.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddbeef2481d895ab8be5185f2432c334d6dec1f5d1933a9c83014d188e102cef", size = 125230 },
+ { url = "https://files.pythonhosted.org/packages/06/df/a85a7955f11274191eccf559e8481b2be74a7c6d43075d0a9506aa80284d/orjson-3.10.15-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e590a0477b23ecd5b0ac865b1b907b01b3c5535f5e8a8f6ab0e503efb896334", size = 150148 },
+ { url = "https://files.pythonhosted.org/packages/37/b3/94c55625a29b8767c0eed194cb000b3787e3c23b4cdd13be17bae6ccbb4b/orjson-3.10.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6be38bd103d2fd9bdfa31c2720b23b5d47c6796bcb1d1b598e3924441b4298d", size = 139749 },
+ { url = "https://files.pythonhosted.org/packages/53/ba/c608b1e719971e8ddac2379f290404c2e914cf8e976369bae3cad88768b1/orjson-3.10.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff4f6edb1578960ed628a3b998fa54d78d9bb3e2eb2cfc5c2a09732431c678d0", size = 154558 },
+ { url = "https://files.pythonhosted.org/packages/b2/c4/c1fb835bb23ad788a39aa9ebb8821d51b1c03588d9a9e4ca7de5b354fdd5/orjson-3.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0482b21d0462eddd67e7fce10b89e0b6ac56570424662b685a0d6fccf581e13", size = 130349 },
+ { url = "https://files.pythonhosted.org/packages/78/14/bb2b48b26ab3c570b284eb2157d98c1ef331a8397f6c8bd983b270467f5c/orjson-3.10.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bb5cc3527036ae3d98b65e37b7986a918955f85332c1ee07f9d3f82f3a6899b5", size = 138513 },
+ { url = "https://files.pythonhosted.org/packages/4a/97/d5b353a5fe532e92c46467aa37e637f81af8468aa894cd77d2ec8a12f99e/orjson-3.10.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d569c1c462912acdd119ccbf719cf7102ea2c67dd03b99edcb1a3048651ac96b", size = 130942 },
+ { url = "https://files.pythonhosted.org/packages/b5/5d/a067bec55293cca48fea8b9928cfa84c623be0cce8141d47690e64a6ca12/orjson-3.10.15-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:1e6d33efab6b71d67f22bf2962895d3dc6f82a6273a965fab762e64fa90dc399", size = 414717 },
+ { url = "https://files.pythonhosted.org/packages/6f/9a/1485b8b05c6b4c4db172c438cf5db5dcfd10e72a9bc23c151a1137e763e0/orjson-3.10.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c33be3795e299f565681d69852ac8c1bc5c84863c0b0030b2b3468843be90388", size = 141033 },
+ { url = "https://files.pythonhosted.org/packages/f8/d2/fc67523656e43a0c7eaeae9007c8b02e86076b15d591e9be11554d3d3138/orjson-3.10.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eea80037b9fae5339b214f59308ef0589fc06dc870578b7cce6d71eb2096764c", size = 129720 },
+ { url = "https://files.pythonhosted.org/packages/79/42/f58c7bd4e5b54da2ce2ef0331a39ccbbaa7699b7f70206fbf06737c9ed7d/orjson-3.10.15-cp311-cp311-win32.whl", hash = "sha256:d5ac11b659fd798228a7adba3e37c010e0152b78b1982897020a8e019a94882e", size = 142473 },
+ { url = "https://files.pythonhosted.org/packages/00/f8/bb60a4644287a544ec81df1699d5b965776bc9848d9029d9f9b3402ac8bb/orjson-3.10.15-cp311-cp311-win_amd64.whl", hash = "sha256:cf45e0214c593660339ef63e875f32ddd5aa3b4adc15e662cdb80dc49e194f8e", size = 133570 },
+ { url = "https://files.pythonhosted.org/packages/66/85/22fe737188905a71afcc4bf7cc4c79cd7f5bbe9ed1fe0aac4ce4c33edc30/orjson-3.10.15-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9d11c0714fc85bfcf36ada1179400862da3288fc785c30e8297844c867d7505a", size = 249504 },
+ { url = "https://files.pythonhosted.org/packages/48/b7/2622b29f3afebe938a0a9037e184660379797d5fd5234e5998345d7a5b43/orjson-3.10.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dba5a1e85d554e3897fa9fe6fbcff2ed32d55008973ec9a2b992bd9a65d2352d", size = 125080 },
+ { url = "https://files.pythonhosted.org/packages/ce/8f/0b72a48f4403d0b88b2a41450c535b3e8989e8a2d7800659a967efc7c115/orjson-3.10.15-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7723ad949a0ea502df656948ddd8b392780a5beaa4c3b5f97e525191b102fff0", size = 150121 },
+ { url = "https://files.pythonhosted.org/packages/06/ec/acb1a20cd49edb2000be5a0404cd43e3c8aad219f376ac8c60b870518c03/orjson-3.10.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6fd9bc64421e9fe9bd88039e7ce8e58d4fead67ca88e3a4014b143cec7684fd4", size = 139796 },
+ { url = "https://files.pythonhosted.org/packages/33/e1/f7840a2ea852114b23a52a1c0b2bea0a1ea22236efbcdb876402d799c423/orjson-3.10.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dadba0e7b6594216c214ef7894c4bd5f08d7c0135f4dd0145600be4fbcc16767", size = 154636 },
+ { url = "https://files.pythonhosted.org/packages/fa/da/31543337febd043b8fa80a3b67de627669b88c7b128d9ad4cc2ece005b7a/orjson-3.10.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48f59114fe318f33bbaee8ebeda696d8ccc94c9e90bc27dbe72153094e26f41", size = 130621 },
+ { url = "https://files.pythonhosted.org/packages/ed/78/66115dc9afbc22496530d2139f2f4455698be444c7c2475cb48f657cefc9/orjson-3.10.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:035fb83585e0f15e076759b6fedaf0abb460d1765b6a36f48018a52858443514", size = 138516 },
+ { url = "https://files.pythonhosted.org/packages/22/84/cd4f5fb5427ffcf823140957a47503076184cb1ce15bcc1165125c26c46c/orjson-3.10.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d13b7fe322d75bf84464b075eafd8e7dd9eae05649aa2a5354cfa32f43c59f17", size = 130762 },
+ { url = "https://files.pythonhosted.org/packages/93/1f/67596b711ba9f56dd75d73b60089c5c92057f1130bb3a25a0f53fb9a583b/orjson-3.10.15-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7066b74f9f259849629e0d04db6609db4cf5b973248f455ba5d3bd58a4daaa5b", size = 414700 },
+ { url = "https://files.pythonhosted.org/packages/7c/0c/6a3b3271b46443d90efb713c3e4fe83fa8cd71cda0d11a0f69a03f437c6e/orjson-3.10.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:88dc3f65a026bd3175eb157fea994fca6ac7c4c8579fc5a86fc2114ad05705b7", size = 141077 },
+ { url = "https://files.pythonhosted.org/packages/3b/9b/33c58e0bfc788995eccd0d525ecd6b84b40d7ed182dd0751cd4c1322ac62/orjson-3.10.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b342567e5465bd99faa559507fe45e33fc76b9fb868a63f1642c6bc0735ad02a", size = 129898 },
+ { url = "https://files.pythonhosted.org/packages/01/c1/d577ecd2e9fa393366a1ea0a9267f6510d86e6c4bb1cdfb9877104cac44c/orjson-3.10.15-cp312-cp312-win32.whl", hash = "sha256:0a4f27ea5617828e6b58922fdbec67b0aa4bb844e2d363b9244c47fa2180e665", size = 142566 },
+ { url = "https://files.pythonhosted.org/packages/ed/eb/a85317ee1732d1034b92d56f89f1de4d7bf7904f5c8fb9dcdd5b1c83917f/orjson-3.10.15-cp312-cp312-win_amd64.whl", hash = "sha256:ef5b87e7aa9545ddadd2309efe6824bd3dd64ac101c15dae0f2f597911d46eaa", size = 133732 },
+]
+
+[[package]]
+name = "packaging"
+version = "24.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 },
+]
+
+[[package]]
+name = "pandas"
+version = "2.2.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "python-dateutil" },
+ { name = "pytz" },
+ { name = "tzdata" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9c/d6/9f8431bacc2e19dca897724cd097b1bb224a6ad5433784a44b587c7c13af/pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667", size = 4399213 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/aa/70/c853aec59839bceed032d52010ff5f1b8d87dc3114b762e4ba2727661a3b/pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5", size = 12580827 },
+ { url = "https://files.pythonhosted.org/packages/99/f2/c4527768739ffa4469b2b4fff05aa3768a478aed89a2f271a79a40eee984/pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348", size = 11303897 },
+ { url = "https://files.pythonhosted.org/packages/ed/12/86c1747ea27989d7a4064f806ce2bae2c6d575b950be087837bdfcabacc9/pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed", size = 66480908 },
+ { url = "https://files.pythonhosted.org/packages/44/50/7db2cd5e6373ae796f0ddad3675268c8d59fb6076e66f0c339d61cea886b/pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57", size = 13064210 },
+ { url = "https://files.pythonhosted.org/packages/61/61/a89015a6d5536cb0d6c3ba02cebed51a95538cf83472975275e28ebf7d0c/pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42", size = 16754292 },
+ { url = "https://files.pythonhosted.org/packages/ce/0d/4cc7b69ce37fac07645a94e1d4b0880b15999494372c1523508511b09e40/pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f", size = 14416379 },
+ { url = "https://files.pythonhosted.org/packages/31/9e/6ebb433de864a6cd45716af52a4d7a8c3c9aaf3a98368e61db9e69e69a9c/pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645", size = 11598471 },
+ { url = "https://files.pythonhosted.org/packages/a8/44/d9502bf0ed197ba9bf1103c9867d5904ddcaf869e52329787fc54ed70cc8/pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039", size = 12602222 },
+ { url = "https://files.pythonhosted.org/packages/52/11/9eac327a38834f162b8250aab32a6781339c69afe7574368fffe46387edf/pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd", size = 11321274 },
+ { url = "https://files.pythonhosted.org/packages/45/fb/c4beeb084718598ba19aa9f5abbc8aed8b42f90930da861fcb1acdb54c3a/pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698", size = 15579836 },
+ { url = "https://files.pythonhosted.org/packages/cd/5f/4dba1d39bb9c38d574a9a22548c540177f78ea47b32f99c0ff2ec499fac5/pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc", size = 13058505 },
+ { url = "https://files.pythonhosted.org/packages/b9/57/708135b90391995361636634df1f1130d03ba456e95bcf576fada459115a/pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3", size = 16744420 },
+ { url = "https://files.pythonhosted.org/packages/86/4a/03ed6b7ee323cf30404265c284cee9c65c56a212e0a08d9ee06984ba2240/pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32", size = 14440457 },
+ { url = "https://files.pythonhosted.org/packages/ed/8c/87ddf1fcb55d11f9f847e3c69bb1c6f8e46e2f40ab1a2d2abadb2401b007/pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5", size = 11617166 },
+ { url = "https://files.pythonhosted.org/packages/17/a3/fb2734118db0af37ea7433f57f722c0a56687e14b14690edff0cdb4b7e58/pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9", size = 12529893 },
+ { url = "https://files.pythonhosted.org/packages/e1/0c/ad295fd74bfac85358fd579e271cded3ac969de81f62dd0142c426b9da91/pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4", size = 11363475 },
+ { url = "https://files.pythonhosted.org/packages/c6/2a/4bba3f03f7d07207481fed47f5b35f556c7441acddc368ec43d6643c5777/pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3", size = 15188645 },
+ { url = "https://files.pythonhosted.org/packages/38/f8/d8fddee9ed0d0c0f4a2132c1dfcf0e3e53265055da8df952a53e7eaf178c/pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319", size = 12739445 },
+ { url = "https://files.pythonhosted.org/packages/20/e8/45a05d9c39d2cea61ab175dbe6a2de1d05b679e8de2011da4ee190d7e748/pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8", size = 16359235 },
+ { url = "https://files.pythonhosted.org/packages/1d/99/617d07a6a5e429ff90c90da64d428516605a1ec7d7bea494235e1c3882de/pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a", size = 14056756 },
+ { url = "https://files.pythonhosted.org/packages/29/d4/1244ab8edf173a10fd601f7e13b9566c1b525c4f365d6bee918e68381889/pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13", size = 11504248 },
+]
+
+[[package]]
+name = "parso"
+version = "0.8.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
+]
+
+[[package]]
+name = "peft"
+version = "0.14.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "accelerate" },
+ { name = "huggingface-hub" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "psutil" },
+ { name = "pyyaml" },
+ { name = "safetensors" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+ { name = "transformers" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/21/33/fb0c31eaa8162c01e9250b21aa65d46a5339f17a818a97c68391db2ff44b/peft-0.14.0.tar.gz", hash = "sha256:546d69af7b42f5ef715a3d3261ed818bc917ae6055e5d7e187ed3f2c76ad72dc", size = 411902 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/88/05/e58e3aaa36544d30a917814e336fc65a746f708e5874945e92999bc22fa3/peft-0.14.0-py3-none-any.whl", hash = "sha256:2f04f3a870c3baf30f15e7dcaa5dd70d3e54cfdd146d3c6c187735d3ae0a0700", size = 374831 },
+]
+
+[[package]]
+name = "pexpect"
+version = "4.9.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "ptyprocess", marker = "sys_platform != 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 },
+]
+
+[[package]]
+name = "pillow"
+version = "11.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f3/af/c097e544e7bd278333db77933e535098c259609c4eb3b85381109602fb5b/pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20", size = 46742715 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/50/1c/2dcea34ac3d7bc96a1fd1bd0a6e06a57c67167fec2cff8d95d88229a8817/pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8", size = 3229983 },
+ { url = "https://files.pythonhosted.org/packages/14/ca/6bec3df25e4c88432681de94a3531cc738bd85dea6c7aa6ab6f81ad8bd11/pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192", size = 3101831 },
+ { url = "https://files.pythonhosted.org/packages/d4/2c/668e18e5521e46eb9667b09e501d8e07049eb5bfe39d56be0724a43117e6/pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2", size = 4314074 },
+ { url = "https://files.pythonhosted.org/packages/02/80/79f99b714f0fc25f6a8499ecfd1f810df12aec170ea1e32a4f75746051ce/pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26", size = 4394933 },
+ { url = "https://files.pythonhosted.org/packages/81/aa/8d4ad25dc11fd10a2001d5b8a80fdc0e564ac33b293bdfe04ed387e0fd95/pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07", size = 4353349 },
+ { url = "https://files.pythonhosted.org/packages/84/7a/cd0c3eaf4a28cb2a74bdd19129f7726277a7f30c4f8424cd27a62987d864/pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482", size = 4476532 },
+ { url = "https://files.pythonhosted.org/packages/8f/8b/a907fdd3ae8f01c7670dfb1499c53c28e217c338b47a813af8d815e7ce97/pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e", size = 4279789 },
+ { url = "https://files.pythonhosted.org/packages/6f/9a/9f139d9e8cccd661c3efbf6898967a9a337eb2e9be2b454ba0a09533100d/pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269", size = 4413131 },
+ { url = "https://files.pythonhosted.org/packages/a8/68/0d8d461f42a3f37432203c8e6df94da10ac8081b6d35af1c203bf3111088/pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49", size = 2291213 },
+ { url = "https://files.pythonhosted.org/packages/14/81/d0dff759a74ba87715509af9f6cb21fa21d93b02b3316ed43bda83664db9/pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a", size = 2625725 },
+ { url = "https://files.pythonhosted.org/packages/ce/1f/8d50c096a1d58ef0584ddc37e6f602828515219e9d2428e14ce50f5ecad1/pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65", size = 2375213 },
+ { url = "https://files.pythonhosted.org/packages/dd/d6/2000bfd8d5414fb70cbbe52c8332f2283ff30ed66a9cde42716c8ecbe22c/pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457", size = 3229968 },
+ { url = "https://files.pythonhosted.org/packages/d9/45/3fe487010dd9ce0a06adf9b8ff4f273cc0a44536e234b0fad3532a42c15b/pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35", size = 3101806 },
+ { url = "https://files.pythonhosted.org/packages/e3/72/776b3629c47d9d5f1c160113158a7a7ad177688d3a1159cd3b62ded5a33a/pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2", size = 4322283 },
+ { url = "https://files.pythonhosted.org/packages/e4/c2/e25199e7e4e71d64eeb869f5b72c7ddec70e0a87926398785ab944d92375/pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070", size = 4402945 },
+ { url = "https://files.pythonhosted.org/packages/c1/ed/51d6136c9d5911f78632b1b86c45241c712c5a80ed7fa7f9120a5dff1eba/pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6", size = 4361228 },
+ { url = "https://files.pythonhosted.org/packages/48/a4/fbfe9d5581d7b111b28f1d8c2762dee92e9821bb209af9fa83c940e507a0/pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1", size = 4484021 },
+ { url = "https://files.pythonhosted.org/packages/39/db/0b3c1a5018117f3c1d4df671fb8e47d08937f27519e8614bbe86153b65a5/pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2", size = 4287449 },
+ { url = "https://files.pythonhosted.org/packages/d9/58/bc128da7fea8c89fc85e09f773c4901e95b5936000e6f303222490c052f3/pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96", size = 4419972 },
+ { url = "https://files.pythonhosted.org/packages/5f/bb/58f34379bde9fe197f51841c5bbe8830c28bbb6d3801f16a83b8f2ad37df/pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f", size = 2291201 },
+ { url = "https://files.pythonhosted.org/packages/3a/c6/fce9255272bcf0c39e15abd2f8fd8429a954cf344469eaceb9d0d1366913/pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761", size = 2625686 },
+ { url = "https://files.pythonhosted.org/packages/c8/52/8ba066d569d932365509054859f74f2a9abee273edcef5cd75e4bc3e831e/pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71", size = 2375194 },
+ { url = "https://files.pythonhosted.org/packages/95/20/9ce6ed62c91c073fcaa23d216e68289e19d95fb8188b9fb7a63d36771db8/pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a", size = 3226818 },
+ { url = "https://files.pythonhosted.org/packages/b9/d8/f6004d98579a2596c098d1e30d10b248798cceff82d2b77aa914875bfea1/pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b", size = 3101662 },
+ { url = "https://files.pythonhosted.org/packages/08/d9/892e705f90051c7a2574d9f24579c9e100c828700d78a63239676f960b74/pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3", size = 4329317 },
+ { url = "https://files.pythonhosted.org/packages/8c/aa/7f29711f26680eab0bcd3ecdd6d23ed6bce180d82e3f6380fb7ae35fcf3b/pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a", size = 4412999 },
+ { url = "https://files.pythonhosted.org/packages/c8/c4/8f0fe3b9e0f7196f6d0bbb151f9fba323d72a41da068610c4c960b16632a/pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1", size = 4368819 },
+ { url = "https://files.pythonhosted.org/packages/38/0d/84200ed6a871ce386ddc82904bfadc0c6b28b0c0ec78176871a4679e40b3/pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f", size = 4496081 },
+ { url = "https://files.pythonhosted.org/packages/84/9c/9bcd66f714d7e25b64118e3952d52841a4babc6d97b6d28e2261c52045d4/pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91", size = 4296513 },
+ { url = "https://files.pythonhosted.org/packages/db/61/ada2a226e22da011b45f7104c95ebda1b63dcbb0c378ad0f7c2a710f8fd2/pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c", size = 4431298 },
+ { url = "https://files.pythonhosted.org/packages/e7/c4/fc6e86750523f367923522014b821c11ebc5ad402e659d8c9d09b3c9d70c/pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6", size = 2291630 },
+ { url = "https://files.pythonhosted.org/packages/08/5c/2104299949b9d504baf3f4d35f73dbd14ef31bbd1ddc2c1b66a5b7dfda44/pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf", size = 2626369 },
+ { url = "https://files.pythonhosted.org/packages/37/f3/9b18362206b244167c958984b57c7f70a0289bfb59a530dd8af5f699b910/pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5", size = 2375240 },
+ { url = "https://files.pythonhosted.org/packages/fa/c5/389961578fb677b8b3244fcd934f720ed25a148b9a5cc81c91bdf59d8588/pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90", size = 3198345 },
+ { url = "https://files.pythonhosted.org/packages/c4/fa/803c0e50ffee74d4b965229e816af55276eac1d5806712de86f9371858fd/pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb", size = 3072938 },
+ { url = "https://files.pythonhosted.org/packages/dc/67/2a3a5f8012b5d8c63fe53958ba906c1b1d0482ebed5618057ef4d22f8076/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442", size = 3400049 },
+ { url = "https://files.pythonhosted.org/packages/e5/a0/514f0d317446c98c478d1872497eb92e7cde67003fed74f696441e647446/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83", size = 3422431 },
+ { url = "https://files.pythonhosted.org/packages/cd/00/20f40a935514037b7d3f87adfc87d2c538430ea625b63b3af8c3f5578e72/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f", size = 3446208 },
+ { url = "https://files.pythonhosted.org/packages/28/3c/7de681727963043e093c72e6c3348411b0185eab3263100d4490234ba2f6/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73", size = 3509746 },
+ { url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 },
+]
+
+[[package]]
+name = "platformdirs"
+version = "4.3.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
+]
+
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
+]
+
+[[package]]
+name = "prompt-toolkit"
+version = "3.0.50"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "wcwidth" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a1/e1/bd15cb8ffdcfeeb2bdc215de3c3cffca11408d829e4b8416dcfe71ba8854/prompt_toolkit-3.0.50.tar.gz", hash = "sha256:544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab", size = 429087 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 },
+]
+
+[[package]]
+name = "propcache"
+version = "0.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/c8/2a13f78d82211490855b2fb303b6721348d0787fdd9a12ac46d99d3acde1/propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64", size = 41735 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a7/a5/0ea64c9426959ef145a938e38c832fc551843481d356713ececa9a8a64e8/propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6", size = 79296 },
+ { url = "https://files.pythonhosted.org/packages/76/5a/916db1aba735f55e5eca4733eea4d1973845cf77dfe67c2381a2ca3ce52d/propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2", size = 45622 },
+ { url = "https://files.pythonhosted.org/packages/2d/62/685d3cf268b8401ec12b250b925b21d152b9d193b7bffa5fdc4815c392c2/propcache-0.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6445804cf4ec763dc70de65a3b0d9954e868609e83850a47ca4f0cb64bd79fea", size = 45133 },
+ { url = "https://files.pythonhosted.org/packages/4d/3d/31c9c29ee7192defc05aa4d01624fd85a41cf98e5922aaed206017329944/propcache-0.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9479aa06a793c5aeba49ce5c5692ffb51fcd9a7016e017d555d5e2b0045d212", size = 204809 },
+ { url = "https://files.pythonhosted.org/packages/10/a1/e4050776f4797fc86140ac9a480d5dc069fbfa9d499fe5c5d2fa1ae71f07/propcache-0.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9631c5e8b5b3a0fda99cb0d29c18133bca1e18aea9effe55adb3da1adef80d3", size = 219109 },
+ { url = "https://files.pythonhosted.org/packages/c9/c0/e7ae0df76343d5e107d81e59acc085cea5fd36a48aa53ef09add7503e888/propcache-0.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3156628250f46a0895f1f36e1d4fbe062a1af8718ec3ebeb746f1d23f0c5dc4d", size = 217368 },
+ { url = "https://files.pythonhosted.org/packages/fc/e1/e0a2ed6394b5772508868a977d3238f4afb2eebaf9976f0b44a8d347ad63/propcache-0.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6fb63ae352e13748289f04f37868099e69dba4c2b3e271c46061e82c745634", size = 205124 },
+ { url = "https://files.pythonhosted.org/packages/50/c1/e388c232d15ca10f233c778bbdc1034ba53ede14c207a72008de45b2db2e/propcache-0.2.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:887d9b0a65404929641a9fabb6452b07fe4572b269d901d622d8a34a4e9043b2", size = 195463 },
+ { url = "https://files.pythonhosted.org/packages/0a/fd/71b349b9def426cc73813dbd0f33e266de77305e337c8c12bfb0a2a82bfb/propcache-0.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a96dc1fa45bd8c407a0af03b2d5218392729e1822b0c32e62c5bf7eeb5fb3958", size = 198358 },
+ { url = "https://files.pythonhosted.org/packages/02/f2/d7c497cd148ebfc5b0ae32808e6c1af5922215fe38c7a06e4e722fe937c8/propcache-0.2.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a7e65eb5c003a303b94aa2c3852ef130230ec79e349632d030e9571b87c4698c", size = 195560 },
+ { url = "https://files.pythonhosted.org/packages/bb/57/f37041bbe5e0dfed80a3f6be2612a3a75b9cfe2652abf2c99bef3455bbad/propcache-0.2.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:999779addc413181912e984b942fbcc951be1f5b3663cd80b2687758f434c583", size = 196895 },
+ { url = "https://files.pythonhosted.org/packages/83/36/ae3cc3e4f310bff2f064e3d2ed5558935cc7778d6f827dce74dcfa125304/propcache-0.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:19a0f89a7bb9d8048d9c4370c9c543c396e894c76be5525f5e1ad287f1750ddf", size = 207124 },
+ { url = "https://files.pythonhosted.org/packages/8c/c4/811b9f311f10ce9d31a32ff14ce58500458443627e4df4ae9c264defba7f/propcache-0.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1ac2f5fe02fa75f56e1ad473f1175e11f475606ec9bd0be2e78e4734ad575034", size = 210442 },
+ { url = "https://files.pythonhosted.org/packages/18/dd/a1670d483a61ecac0d7fc4305d91caaac7a8fc1b200ea3965a01cf03bced/propcache-0.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:574faa3b79e8ebac7cb1d7930f51184ba1ccf69adfdec53a12f319a06030a68b", size = 203219 },
+ { url = "https://files.pythonhosted.org/packages/f9/2d/30ced5afde41b099b2dc0c6573b66b45d16d73090e85655f1a30c5a24e07/propcache-0.2.1-cp310-cp310-win32.whl", hash = "sha256:03ff9d3f665769b2a85e6157ac8b439644f2d7fd17615a82fa55739bc97863f4", size = 40313 },
+ { url = "https://files.pythonhosted.org/packages/23/84/bd9b207ac80da237af77aa6e153b08ffa83264b1c7882495984fcbfcf85c/propcache-0.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:2d3af2e79991102678f53e0dbf4c35de99b6b8b58f29a27ca0325816364caaba", size = 44428 },
+ { url = "https://files.pythonhosted.org/packages/bc/0f/2913b6791ebefb2b25b4efd4bb2299c985e09786b9f5b19184a88e5778dd/propcache-0.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1ffc3cca89bb438fb9c95c13fc874012f7b9466b89328c3c8b1aa93cdcfadd16", size = 79297 },
+ { url = "https://files.pythonhosted.org/packages/cf/73/af2053aeccd40b05d6e19058419ac77674daecdd32478088b79375b9ab54/propcache-0.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f174bbd484294ed9fdf09437f889f95807e5f229d5d93588d34e92106fbf6717", size = 45611 },
+ { url = "https://files.pythonhosted.org/packages/3c/09/8386115ba7775ea3b9537730e8cf718d83bbf95bffe30757ccf37ec4e5da/propcache-0.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:70693319e0b8fd35dd863e3e29513875eb15c51945bf32519ef52927ca883bc3", size = 45146 },
+ { url = "https://files.pythonhosted.org/packages/03/7a/793aa12f0537b2e520bf09f4c6833706b63170a211ad042ca71cbf79d9cb/propcache-0.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b480c6a4e1138e1aa137c0079b9b6305ec6dcc1098a8ca5196283e8a49df95a9", size = 232136 },
+ { url = "https://files.pythonhosted.org/packages/f1/38/b921b3168d72111769f648314100558c2ea1d52eb3d1ba7ea5c4aa6f9848/propcache-0.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d27b84d5880f6d8aa9ae3edb253c59d9f6642ffbb2c889b78b60361eed449787", size = 239706 },
+ { url = "https://files.pythonhosted.org/packages/14/29/4636f500c69b5edea7786db3c34eb6166f3384b905665ce312a6e42c720c/propcache-0.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:857112b22acd417c40fa4595db2fe28ab900c8c5fe4670c7989b1c0230955465", size = 238531 },
+ { url = "https://files.pythonhosted.org/packages/85/14/01fe53580a8e1734ebb704a3482b7829a0ef4ea68d356141cf0994d9659b/propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf6c4150f8c0e32d241436526f3c3f9cbd34429492abddbada2ffcff506c51af", size = 231063 },
+ { url = "https://files.pythonhosted.org/packages/33/5c/1d961299f3c3b8438301ccfbff0143b69afcc30c05fa28673cface692305/propcache-0.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d4cfda1d8ed687daa4bc0274fcfd5267873db9a5bc0418c2da19273040eeb7", size = 220134 },
+ { url = "https://files.pythonhosted.org/packages/00/d0/ed735e76db279ba67a7d3b45ba4c654e7b02bc2f8050671ec365d8665e21/propcache-0.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c2f992c07c0fca81655066705beae35fc95a2fa7366467366db627d9f2ee097f", size = 220009 },
+ { url = "https://files.pythonhosted.org/packages/75/90/ee8fab7304ad6533872fee982cfff5a53b63d095d78140827d93de22e2d4/propcache-0.2.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:4a571d97dbe66ef38e472703067021b1467025ec85707d57e78711c085984e54", size = 212199 },
+ { url = "https://files.pythonhosted.org/packages/eb/ec/977ffaf1664f82e90737275873461695d4c9407d52abc2f3c3e24716da13/propcache-0.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bb6178c241278d5fe853b3de743087be7f5f4c6f7d6d22a3b524d323eecec505", size = 214827 },
+ { url = "https://files.pythonhosted.org/packages/57/48/031fb87ab6081764054821a71b71942161619549396224cbb242922525e8/propcache-0.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ad1af54a62ffe39cf34db1aa6ed1a1873bd548f6401db39d8e7cd060b9211f82", size = 228009 },
+ { url = "https://files.pythonhosted.org/packages/1a/06/ef1390f2524850838f2390421b23a8b298f6ce3396a7cc6d39dedd4047b0/propcache-0.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e7048abd75fe40712005bcfc06bb44b9dfcd8e101dda2ecf2f5aa46115ad07ca", size = 231638 },
+ { url = "https://files.pythonhosted.org/packages/38/2a/101e6386d5a93358395da1d41642b79c1ee0f3b12e31727932b069282b1d/propcache-0.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:160291c60081f23ee43d44b08a7e5fb76681221a8e10b3139618c5a9a291b84e", size = 222788 },
+ { url = "https://files.pythonhosted.org/packages/db/81/786f687951d0979007e05ad9346cd357e50e3d0b0f1a1d6074df334b1bbb/propcache-0.2.1-cp311-cp311-win32.whl", hash = "sha256:819ce3b883b7576ca28da3861c7e1a88afd08cc8c96908e08a3f4dd64a228034", size = 40170 },
+ { url = "https://files.pythonhosted.org/packages/cf/59/7cc7037b295d5772eceb426358bb1b86e6cab4616d971bd74275395d100d/propcache-0.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:edc9fc7051e3350643ad929df55c451899bb9ae6d24998a949d2e4c87fb596d3", size = 44404 },
+ { url = "https://files.pythonhosted.org/packages/4c/28/1d205fe49be8b1b4df4c50024e62480a442b1a7b818e734308bb0d17e7fb/propcache-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:081a430aa8d5e8876c6909b67bd2d937bfd531b0382d3fdedb82612c618bc41a", size = 79588 },
+ { url = "https://files.pythonhosted.org/packages/21/ee/fc4d893f8d81cd4971affef2a6cb542b36617cd1d8ce56b406112cb80bf7/propcache-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2ccec9ac47cf4e04897619c0e0c1a48c54a71bdf045117d3a26f80d38ab1fb0", size = 45825 },
+ { url = "https://files.pythonhosted.org/packages/4a/de/bbe712f94d088da1d237c35d735f675e494a816fd6f54e9db2f61ef4d03f/propcache-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:14d86fe14b7e04fa306e0c43cdbeebe6b2c2156a0c9ce56b815faacc193e320d", size = 45357 },
+ { url = "https://files.pythonhosted.org/packages/7f/14/7ae06a6cf2a2f1cb382586d5a99efe66b0b3d0c6f9ac2f759e6f7af9d7cf/propcache-0.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:049324ee97bb67285b49632132db351b41e77833678432be52bdd0289c0e05e4", size = 241869 },
+ { url = "https://files.pythonhosted.org/packages/cc/59/227a78be960b54a41124e639e2c39e8807ac0c751c735a900e21315f8c2b/propcache-0.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cd9a1d071158de1cc1c71a26014dcdfa7dd3d5f4f88c298c7f90ad6f27bb46d", size = 247884 },
+ { url = "https://files.pythonhosted.org/packages/84/58/f62b4ffaedf88dc1b17f04d57d8536601e4e030feb26617228ef930c3279/propcache-0.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98110aa363f1bb4c073e8dcfaefd3a5cea0f0834c2aab23dda657e4dab2f53b5", size = 248486 },
+ { url = "https://files.pythonhosted.org/packages/1c/07/ebe102777a830bca91bbb93e3479cd34c2ca5d0361b83be9dbd93104865e/propcache-0.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:647894f5ae99c4cf6bb82a1bb3a796f6e06af3caa3d32e26d2350d0e3e3faf24", size = 243649 },
+ { url = "https://files.pythonhosted.org/packages/ed/bc/4f7aba7f08f520376c4bb6a20b9a981a581b7f2e385fa0ec9f789bb2d362/propcache-0.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfd3223c15bebe26518d58ccf9a39b93948d3dcb3e57a20480dfdd315356baff", size = 229103 },
+ { url = "https://files.pythonhosted.org/packages/fe/d5/04ac9cd4e51a57a96f78795e03c5a0ddb8f23ec098b86f92de028d7f2a6b/propcache-0.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d71264a80f3fcf512eb4f18f59423fe82d6e346ee97b90625f283df56aee103f", size = 226607 },
+ { url = "https://files.pythonhosted.org/packages/e3/f0/24060d959ea41d7a7cc7fdbf68b31852331aabda914a0c63bdb0e22e96d6/propcache-0.2.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e73091191e4280403bde6c9a52a6999d69cdfde498f1fdf629105247599b57ec", size = 221153 },
+ { url = "https://files.pythonhosted.org/packages/77/a7/3ac76045a077b3e4de4859a0753010765e45749bdf53bd02bc4d372da1a0/propcache-0.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3935bfa5fede35fb202c4b569bb9c042f337ca4ff7bd540a0aa5e37131659348", size = 222151 },
+ { url = "https://files.pythonhosted.org/packages/e7/af/5e29da6f80cebab3f5a4dcd2a3240e7f56f2c4abf51cbfcc99be34e17f0b/propcache-0.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f508b0491767bb1f2b87fdfacaba5f7eddc2f867740ec69ece6d1946d29029a6", size = 233812 },
+ { url = "https://files.pythonhosted.org/packages/8c/89/ebe3ad52642cc5509eaa453e9f4b94b374d81bae3265c59d5c2d98efa1b4/propcache-0.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1672137af7c46662a1c2be1e8dc78cb6d224319aaa40271c9257d886be4363a6", size = 238829 },
+ { url = "https://files.pythonhosted.org/packages/e9/2f/6b32f273fa02e978b7577159eae7471b3cfb88b48563b1c2578b2d7ca0bb/propcache-0.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b74c261802d3d2b85c9df2dfb2fa81b6f90deeef63c2db9f0e029a3cac50b518", size = 230704 },
+ { url = "https://files.pythonhosted.org/packages/5c/2e/f40ae6ff5624a5f77edd7b8359b208b5455ea113f68309e2b00a2e1426b6/propcache-0.2.1-cp312-cp312-win32.whl", hash = "sha256:d09c333d36c1409d56a9d29b3a1b800a42c76a57a5a8907eacdbce3f18768246", size = 40050 },
+ { url = "https://files.pythonhosted.org/packages/3b/77/a92c3ef994e47180862b9d7d11e37624fb1c00a16d61faf55115d970628b/propcache-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:c214999039d4f2a5b2073ac506bba279945233da8c786e490d411dfc30f855c1", size = 44117 },
+ { url = "https://files.pythonhosted.org/packages/41/b6/c5319caea262f4821995dca2107483b94a3345d4607ad797c76cb9c36bcc/propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54", size = 11818 },
+]
+
+[[package]]
+name = "protobuf"
+version = "3.20.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/28/55/b80e8567ec327c060fa39b242392e25690c8899c489ecd7bb65b46b7bb55/protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99", size = 918427 },
+ { url = "https://files.pythonhosted.org/packages/31/be/80a9c6f16dfa4d41be3edbe655349778ae30882407fa8275eb46b4d34854/protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e", size = 1051042 },
+ { url = "https://files.pythonhosted.org/packages/db/96/948d3fcc1fa816e7ae1d27af59b9d8c5c5e582f3994fd14394f31da95b99/protobuf-3.20.3-cp310-cp310-win32.whl", hash = "sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c", size = 780167 },
+ { url = "https://files.pythonhosted.org/packages/6f/5e/fc6feb366b0a9f28e0a2de3b062667c521cd9517d4ff55077b8f351ba2f3/protobuf-3.20.3-cp310-cp310-win_amd64.whl", hash = "sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7", size = 904029 },
+ { url = "https://files.pythonhosted.org/packages/8d/14/619e24a4c70df2901e1f4dbc50a6291eb63a759172558df326347dce1f0d/protobuf-3.20.3-py2.py3-none-any.whl", hash = "sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db", size = 162128 },
+]
+
+[[package]]
+name = "psutil"
+version = "6.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/1f/5a/07871137bb752428aa4b659f910b399ba6f291156bdea939be3e96cae7cb/psutil-6.1.1.tar.gz", hash = "sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5", size = 508502 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/61/99/ca79d302be46f7bdd8321089762dd4476ee725fce16fc2b2e1dbba8cac17/psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8", size = 247511 },
+ { url = "https://files.pythonhosted.org/packages/0b/6b/73dbde0dd38f3782905d4587049b9be64d76671042fdcaf60e2430c6796d/psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377", size = 248985 },
+ { url = "https://files.pythonhosted.org/packages/17/38/c319d31a1d3f88c5b79c68b3116c129e5133f1822157dd6da34043e32ed6/psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003", size = 284488 },
+ { url = "https://files.pythonhosted.org/packages/9c/39/0f88a830a1c8a3aba27fededc642da37613c57cbff143412e3536f89784f/psutil-6.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160", size = 287477 },
+ { url = "https://files.pythonhosted.org/packages/47/da/99f4345d4ddf2845cb5b5bd0d93d554e84542d116934fde07a0c50bd4e9f/psutil-6.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3", size = 289017 },
+ { url = "https://files.pythonhosted.org/packages/38/53/bd755c2896f4461fd4f36fa6a6dcb66a88a9e4b9fd4e5b66a77cf9d4a584/psutil-6.1.1-cp37-abi3-win32.whl", hash = "sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53", size = 250602 },
+ { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
+]
+
+[[package]]
+name = "ptyprocess"
+version = "0.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 },
+]
+
+[[package]]
+name = "pure-eval"
+version = "0.2.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 },
+]
+
+[[package]]
+name = "py-cpuinfo"
+version = "9.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 },
+]
+
+[[package]]
+name = "pyarrow"
+version = "19.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7b/01/fe1fd04744c2aa038e5a11c7a4adb3d62bce09798695e54f7274b5977134/pyarrow-19.0.0.tar.gz", hash = "sha256:8d47c691765cf497aaeed4954d226568563f1b3b74ff61139f2d77876717084b", size = 1129096 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1c/02/1ad80ffd3c558916858a49c83b6e494a9d93009bbebc603cf0cb8263bea7/pyarrow-19.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c318eda14f6627966997a7d8c374a87d084a94e4e38e9abbe97395c215830e0c", size = 30686262 },
+ { url = "https://files.pythonhosted.org/packages/1b/f0/adab5f142eb8203db8bfbd3a816816e37a85423ae684567e7f3555658315/pyarrow-19.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:62ef8360ff256e960f57ce0299090fb86423afed5e46f18f1225f960e05aae3d", size = 32100005 },
+ { url = "https://files.pythonhosted.org/packages/94/8b/e674083610e5efc48d2f205c568d842cdfdf683d12f9ff0d546e38757722/pyarrow-19.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2795064647add0f16563e57e3d294dbfc067b723f0fd82ecd80af56dad15f503", size = 41144815 },
+ { url = "https://files.pythonhosted.org/packages/d5/fb/2726241a792b7f8a58789e5a63d1be9a5a4059206318fd0ff9485a578952/pyarrow-19.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a218670b26fb1bc74796458d97bcab072765f9b524f95b2fccad70158feb8b17", size = 42180380 },
+ { url = "https://files.pythonhosted.org/packages/7d/09/7aef12446d8e7002dfc07bb7bc71f594c1d5844ca78b364a49f07efb65b1/pyarrow-19.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:66732e39eaa2247996a6b04c8aa33e3503d351831424cdf8d2e9a0582ac54b34", size = 40515021 },
+ { url = "https://files.pythonhosted.org/packages/31/55/f05fc5608cc96060c2b24de505324d641888bd62d4eed2fa1dacd872a1e1/pyarrow-19.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:e675a3ad4732b92d72e4d24009707e923cab76b0d088e5054914f11a797ebe44", size = 42067488 },
+ { url = "https://files.pythonhosted.org/packages/f0/01/097653cec7a944c16313cb748a326771133c142034b252076bd84743b98d/pyarrow-19.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:f094742275586cdd6b1a03655ccff3b24b2610c3af76f810356c4c71d24a2a6c", size = 25276726 },
+ { url = "https://files.pythonhosted.org/packages/82/42/fba3a35bef5833bf88ed35e6a810dc1781236e1d4f808d2df824a7d21819/pyarrow-19.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8e3a839bf36ec03b4315dc924d36dcde5444a50066f1c10f8290293c0427b46a", size = 30711936 },
+ { url = "https://files.pythonhosted.org/packages/88/7a/0da93a3eaaf251a30e32f3221e874263cdcd366c2cd6b7c05293aad91152/pyarrow-19.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:ce42275097512d9e4e4a39aade58ef2b3798a93aa3026566b7892177c266f735", size = 32133182 },
+ { url = "https://files.pythonhosted.org/packages/2f/df/fe43b1c50d3100d0de53f988344118bc20362d0de005f8a407454fa565f8/pyarrow-19.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9348a0137568c45601b031a8d118275069435f151cbb77e6a08a27e8125f59d4", size = 41145489 },
+ { url = "https://files.pythonhosted.org/packages/45/bb/6f73b41b342a0342f2516a02db4aa97a4f9569cc35482a5c288090140cd4/pyarrow-19.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a0144a712d990d60f7f42b7a31f0acaccf4c1e43e957f7b1ad58150d6f639c1", size = 42177823 },
+ { url = "https://files.pythonhosted.org/packages/23/7b/f038a96f421e453a71bd7a0f78d62b1b2ae9bcac06ed51179ca532e6a0a2/pyarrow-19.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2a1a109dfda558eb011e5f6385837daffd920d54ca00669f7a11132d0b1e6042", size = 40530609 },
+ { url = "https://files.pythonhosted.org/packages/b8/39/a2a6714b471c000e6dd6af4495dce00d7d1332351b8e3170dfb9f91dad1f/pyarrow-19.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:be686bf625aa7b9bada18defb3a3ea3981c1099697239788ff111d87f04cd263", size = 42081534 },
+ { url = "https://files.pythonhosted.org/packages/6c/a3/8396fb06ca05d807e89980c177be26617aad15211ece3184e0caa730b8a6/pyarrow-19.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:239ca66d9a05844bdf5af128861af525e14df3c9591bcc05bac25918e650d3a2", size = 25281090 },
+ { url = "https://files.pythonhosted.org/packages/bc/2e/152885f5ef421e80dae68b9c133ab261934f93a6d5e16b61d79c0ed597fb/pyarrow-19.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:a7bbe7109ab6198688b7079cbad5a8c22de4d47c4880d8e4847520a83b0d1b68", size = 30667964 },
+ { url = "https://files.pythonhosted.org/packages/80/c2/08bbee9a8610a47c9a1466845f405baf53a639ddd947c5133d8ba13544b6/pyarrow-19.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:4624c89d6f777c580e8732c27bb8e77fd1433b89707f17c04af7635dd9638351", size = 32125039 },
+ { url = "https://files.pythonhosted.org/packages/d2/56/06994df823212f5688d3c8bf4294928b12c9be36681872853655724d28c6/pyarrow-19.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b6d3ce4288793350dc2d08d1e184fd70631ea22a4ff9ea5c4ff182130249d9b", size = 41140729 },
+ { url = "https://files.pythonhosted.org/packages/94/65/38ad577c98140a9db71e9e1e594b6adb58a7478a5afec6456a8ca2df7f70/pyarrow-19.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:450a7d27e840e4d9a384b5c77199d489b401529e75a3b7a3799d4cd7957f2f9c", size = 42202267 },
+ { url = "https://files.pythonhosted.org/packages/b6/1f/966b722251a7354114ccbb71cf1a83922023e69efd8945ebf628a851ec4c/pyarrow-19.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a08e2a8a039a3f72afb67a6668180f09fddaa38fe0d21f13212b4aba4b5d2451", size = 40505858 },
+ { url = "https://files.pythonhosted.org/packages/3b/5e/6bc81aa7fc9affc7d1c03b912fbcc984ca56c2a18513684da267715dab7b/pyarrow-19.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f43f5aef2a13d4d56adadae5720d1fed4c1356c993eda8b59dace4b5983843c1", size = 42084973 },
+ { url = "https://files.pythonhosted.org/packages/53/c3/2f56da818b6a4758cbd514957c67bd0f078ebffa5390ee2e2bf0f9e8defc/pyarrow-19.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f672f5364b2d7829ef7c94be199bb88bf5661dd485e21d2d37de12ccb78a136", size = 25241976 },
+]
+
+[[package]]
+name = "pydantic"
+version = "2.10.6"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "annotated-types" },
+ { name = "pydantic-core" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b7/ae/d5220c5c52b158b1de7ca89fc5edb72f304a70a4c540c84c8844bf4008de/pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236", size = 761681 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f4/3c/8cc1cc84deffa6e25d2d0c688ebb80635dfdbf1dbea3e30c541c8cf4d860/pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584", size = 431696 },
+]
+
+[[package]]
+name = "pydantic-core"
+version = "2.27.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fc/01/f3e5ac5e7c25833db5eb555f7b7ab24cd6f8c322d3a3ad2d67a952dc0abc/pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39", size = 413443 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3a/bc/fed5f74b5d802cf9a03e83f60f18864e90e3aed7223adaca5ffb7a8d8d64/pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa", size = 1895938 },
+ { url = "https://files.pythonhosted.org/packages/71/2a/185aff24ce844e39abb8dd680f4e959f0006944f4a8a0ea372d9f9ae2e53/pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c", size = 1815684 },
+ { url = "https://files.pythonhosted.org/packages/c3/43/fafabd3d94d159d4f1ed62e383e264f146a17dd4d48453319fd782e7979e/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a", size = 1829169 },
+ { url = "https://files.pythonhosted.org/packages/a2/d1/f2dfe1a2a637ce6800b799aa086d079998959f6f1215eb4497966efd2274/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5", size = 1867227 },
+ { url = "https://files.pythonhosted.org/packages/7d/39/e06fcbcc1c785daa3160ccf6c1c38fea31f5754b756e34b65f74e99780b5/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c", size = 2037695 },
+ { url = "https://files.pythonhosted.org/packages/7a/67/61291ee98e07f0650eb756d44998214231f50751ba7e13f4f325d95249ab/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7", size = 2741662 },
+ { url = "https://files.pythonhosted.org/packages/32/90/3b15e31b88ca39e9e626630b4c4a1f5a0dfd09076366f4219429e6786076/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a", size = 1993370 },
+ { url = "https://files.pythonhosted.org/packages/ff/83/c06d333ee3a67e2e13e07794995c1535565132940715931c1c43bfc85b11/pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236", size = 1996813 },
+ { url = "https://files.pythonhosted.org/packages/7c/f7/89be1c8deb6e22618a74f0ca0d933fdcb8baa254753b26b25ad3acff8f74/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962", size = 2005287 },
+ { url = "https://files.pythonhosted.org/packages/b7/7d/8eb3e23206c00ef7feee17b83a4ffa0a623eb1a9d382e56e4aa46fd15ff2/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9", size = 2128414 },
+ { url = "https://files.pythonhosted.org/packages/4e/99/fe80f3ff8dd71a3ea15763878d464476e6cb0a2db95ff1c5c554133b6b83/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af", size = 2155301 },
+ { url = "https://files.pythonhosted.org/packages/2b/a3/e50460b9a5789ca1451b70d4f52546fa9e2b420ba3bfa6100105c0559238/pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4", size = 1816685 },
+ { url = "https://files.pythonhosted.org/packages/57/4c/a8838731cb0f2c2a39d3535376466de6049034d7b239c0202a64aaa05533/pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31", size = 1982876 },
+ { url = "https://files.pythonhosted.org/packages/c2/89/f3450af9d09d44eea1f2c369f49e8f181d742f28220f88cc4dfaae91ea6e/pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc", size = 1893421 },
+ { url = "https://files.pythonhosted.org/packages/9e/e3/71fe85af2021f3f386da42d291412e5baf6ce7716bd7101ea49c810eda90/pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7", size = 1814998 },
+ { url = "https://files.pythonhosted.org/packages/a6/3c/724039e0d848fd69dbf5806894e26479577316c6f0f112bacaf67aa889ac/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15", size = 1826167 },
+ { url = "https://files.pythonhosted.org/packages/2b/5b/1b29e8c1fb5f3199a9a57c1452004ff39f494bbe9bdbe9a81e18172e40d3/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306", size = 1865071 },
+ { url = "https://files.pythonhosted.org/packages/89/6c/3985203863d76bb7d7266e36970d7e3b6385148c18a68cc8915fd8c84d57/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99", size = 2036244 },
+ { url = "https://files.pythonhosted.org/packages/0e/41/f15316858a246b5d723f7d7f599f79e37493b2e84bfc789e58d88c209f8a/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459", size = 2737470 },
+ { url = "https://files.pythonhosted.org/packages/a8/7c/b860618c25678bbd6d1d99dbdfdf0510ccb50790099b963ff78a124b754f/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048", size = 1992291 },
+ { url = "https://files.pythonhosted.org/packages/bf/73/42c3742a391eccbeab39f15213ecda3104ae8682ba3c0c28069fbcb8c10d/pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d", size = 1994613 },
+ { url = "https://files.pythonhosted.org/packages/94/7a/941e89096d1175d56f59340f3a8ebaf20762fef222c298ea96d36a6328c5/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b", size = 2002355 },
+ { url = "https://files.pythonhosted.org/packages/6e/95/2359937a73d49e336a5a19848713555605d4d8d6940c3ec6c6c0ca4dcf25/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474", size = 2126661 },
+ { url = "https://files.pythonhosted.org/packages/2b/4c/ca02b7bdb6012a1adef21a50625b14f43ed4d11f1fc237f9d7490aa5078c/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6", size = 2153261 },
+ { url = "https://files.pythonhosted.org/packages/72/9d/a241db83f973049a1092a079272ffe2e3e82e98561ef6214ab53fe53b1c7/pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c", size = 1812361 },
+ { url = "https://files.pythonhosted.org/packages/e8/ef/013f07248041b74abd48a385e2110aa3a9bbfef0fbd97d4e6d07d2f5b89a/pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc", size = 1982484 },
+ { url = "https://files.pythonhosted.org/packages/10/1c/16b3a3e3398fd29dca77cea0a1d998d6bde3902fa2706985191e2313cc76/pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4", size = 1867102 },
+ { url = "https://files.pythonhosted.org/packages/d6/74/51c8a5482ca447871c93e142d9d4a92ead74de6c8dc5e66733e22c9bba89/pydantic_core-2.27.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9e0c8cfefa0ef83b4da9588448b6d8d2a2bf1a53c3f1ae5fca39eb3061e2f0b0", size = 1893127 },
+ { url = "https://files.pythonhosted.org/packages/d3/f3/c97e80721735868313c58b89d2de85fa80fe8dfeeed84dc51598b92a135e/pydantic_core-2.27.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83097677b8e3bd7eaa6775720ec8e0405f1575015a463285a92bfdfe254529ef", size = 1811340 },
+ { url = "https://files.pythonhosted.org/packages/9e/91/840ec1375e686dbae1bd80a9e46c26a1e0083e1186abc610efa3d9a36180/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172fce187655fece0c90d90a678424b013f8fbb0ca8b036ac266749c09438cb7", size = 1822900 },
+ { url = "https://files.pythonhosted.org/packages/f6/31/4240bc96025035500c18adc149aa6ffdf1a0062a4b525c932065ceb4d868/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:519f29f5213271eeeeb3093f662ba2fd512b91c5f188f3bb7b27bc5973816934", size = 1869177 },
+ { url = "https://files.pythonhosted.org/packages/fa/20/02fbaadb7808be578317015c462655c317a77a7c8f0ef274bc016a784c54/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05e3a55d124407fffba0dd6b0c0cd056d10e983ceb4e5dbd10dda135c31071d6", size = 2038046 },
+ { url = "https://files.pythonhosted.org/packages/06/86/7f306b904e6c9eccf0668248b3f272090e49c275bc488a7b88b0823444a4/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c3ed807c7b91de05e63930188f19e921d1fe90de6b4f5cd43ee7fcc3525cb8c", size = 2685386 },
+ { url = "https://files.pythonhosted.org/packages/8d/f0/49129b27c43396581a635d8710dae54a791b17dfc50c70164866bbf865e3/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb4aadc0b9a0c063206846d603b92030eb6f03069151a625667f982887153e2", size = 1997060 },
+ { url = "https://files.pythonhosted.org/packages/0d/0f/943b4af7cd416c477fd40b187036c4f89b416a33d3cc0ab7b82708a667aa/pydantic_core-2.27.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:28ccb213807e037460326424ceb8b5245acb88f32f3d2777427476e1b32c48c4", size = 2004870 },
+ { url = "https://files.pythonhosted.org/packages/35/40/aea70b5b1a63911c53a4c8117c0a828d6790483f858041f47bab0b779f44/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:de3cd1899e2c279b140adde9357c4495ed9d47131b4a4eaff9052f23398076b3", size = 1999822 },
+ { url = "https://files.pythonhosted.org/packages/f2/b3/807b94fd337d58effc5498fd1a7a4d9d59af4133e83e32ae39a96fddec9d/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:220f892729375e2d736b97d0e51466252ad84c51857d4d15f5e9692f9ef12be4", size = 2130364 },
+ { url = "https://files.pythonhosted.org/packages/fc/df/791c827cd4ee6efd59248dca9369fb35e80a9484462c33c6649a8d02b565/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0fcd29cd6b4e74fe8ddd2c90330fd8edf2e30cb52acda47f06dd615ae72da57", size = 2158303 },
+ { url = "https://files.pythonhosted.org/packages/9b/67/4e197c300976af185b7cef4c02203e175fb127e414125916bf1128b639a9/pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = "sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc", size = 1834064 },
+ { url = "https://files.pythonhosted.org/packages/1f/ea/cd7209a889163b8dcca139fe32b9687dd05249161a3edda62860430457a5/pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = "sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9", size = 1989046 },
+ { url = "https://files.pythonhosted.org/packages/bc/49/c54baab2f4658c26ac633d798dab66b4c3a9bbf47cff5284e9c182f4137a/pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b", size = 1885092 },
+ { url = "https://files.pythonhosted.org/packages/46/72/af70981a341500419e67d5cb45abe552a7c74b66326ac8877588488da1ac/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e", size = 1891159 },
+ { url = "https://files.pythonhosted.org/packages/ad/3d/c5913cccdef93e0a6a95c2d057d2c2cba347815c845cda79ddd3c0f5e17d/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8", size = 1768331 },
+ { url = "https://files.pythonhosted.org/packages/f6/f0/a3ae8fbee269e4934f14e2e0e00928f9346c5943174f2811193113e58252/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3", size = 1822467 },
+ { url = "https://files.pythonhosted.org/packages/d7/7a/7bbf241a04e9f9ea24cd5874354a83526d639b02674648af3f350554276c/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f", size = 1979797 },
+ { url = "https://files.pythonhosted.org/packages/4f/5f/4784c6107731f89e0005a92ecb8a2efeafdb55eb992b8e9d0a2be5199335/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133", size = 1987839 },
+ { url = "https://files.pythonhosted.org/packages/6d/a7/61246562b651dff00de86a5f01b6e4befb518df314c54dec187a78d81c84/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc", size = 1998861 },
+ { url = "https://files.pythonhosted.org/packages/86/aa/837821ecf0c022bbb74ca132e117c358321e72e7f9702d1b6a03758545e2/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50", size = 2116582 },
+ { url = "https://files.pythonhosted.org/packages/81/b0/5e74656e95623cbaa0a6278d16cf15e10a51f6002e3ec126541e95c29ea3/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9", size = 2151985 },
+ { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 },
+]
+
+[[package]]
+name = "pygments"
+version = "2.19.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
+]
+
+[[package]]
+name = "pynvml"
+version = "12.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-ml-py" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/26/6f/6b5880ed0239e85b9a39aed103b65b2ef81425beef9f45e5c035bf008330/pynvml-12.0.0.tar.gz", hash = "sha256:299ce2451a6a17e6822d6faee750103e25b415f06f59abb8db65d30f794166f5", size = 33636 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ed/df/f7cf07a65a96dd11d71f346f9c2863accdd4784da83af7181b067d556cbc/pynvml-12.0.0-py3-none-any.whl", hash = "sha256:fdff84b62a27dbe98e08e1a647eb77342bef1aebe0878bcd15e99a83fcbecb9e", size = 26560 },
+]
+
+[[package]]
+name = "pyparsing"
+version = "3.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/8b/1a/3544f4f299a47911c2ab3710f534e52fea62a633c96806995da5d25be4b2/pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a", size = 1067694 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1c/a7/c8a2d361bf89c0d9577c934ebb7421b25dc84bf3a8e3ac0a40aed9acc547/pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1", size = 107716 },
+]
+
+[[package]]
+name = "pyre-extensions"
+version = "0.0.32"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+ { name = "typing-inspect" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a7/53/5bc2532536e921c48366ad1047c1344ccef6afa5e84053f0f6e20a453767/pyre_extensions-0.0.32.tar.gz", hash = "sha256:5396715f14ea56c4d5fd0a88c57ca7e44faa468f905909edd7de4ad90ed85e55", size = 10852 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a4/7a/9812cb8be9828ab688203c5ac5f743c60652887f0c00995a6f6f19f912bd/pyre_extensions-0.0.32-py3-none-any.whl", hash = "sha256:a63ba6883ab02f4b1a9f372ed4eb4a2f4c6f3d74879aa2725186fdfcfe3e5c68", size = 12766 },
+]
+
+[[package]]
+name = "pytest"
+version = "7.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "attrs" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "tomli", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0b/21/055f39bf8861580b43f845f9e8270c7786fe629b2f8562ff09007132e2e7/pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59", size = 1300608 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/67/68/a5eb36c3a8540594b6035e6cdae40c1ef1b6a2bfacbecc3d1a544583c078/pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71", size = 316791 },
+]
+
+[[package]]
+name = "pytest-split"
+version = "0.8.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/04/10/c317f5e9682a6fa184a9f598c987c8cef42edbd8ba8534184cf0c1918473/pytest-split-0.8.0.tar.gz", hash = "sha256:8571a3f60ca8656c698ed86b0a3212bb9e79586ecb201daef9988c336ff0e6ff", size = 13913 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8a/d4/dcebd4d75cc2a0faed3ed615a47bd179d6c3873959ccda31a144be68c6f4/pytest_split-0.8.0-py3-none-any.whl", hash = "sha256:2e06b8b1ab7ceb19d0b001548271abaf91d12415a8687086cf40581c555d309f", size = 11708 },
+]
+
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 },
+]
+
+[[package]]
+name = "python-dotenv"
+version = "1.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 },
+]
+
+[[package]]
+name = "python-fasthtml"
+version = "0.12.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "beautifulsoup4" },
+ { name = "fastcore" },
+ { name = "fastlite" },
+ { name = "httpx" },
+ { name = "itsdangerous" },
+ { name = "oauthlib" },
+ { name = "python-dateutil" },
+ { name = "python-multipart" },
+ { name = "starlette" },
+ { name = "uvicorn", extra = ["standard"] },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c5/68/ad5465c174ff934814303af17f6839a901a4a3e287857e83080807471cfd/python_fasthtml-0.12.1.tar.gz", hash = "sha256:29b54df1bed9063c32fa320f57487887ce8bf846838b441a67d6c1ebe1caadc5", size = 58697 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/62/2a/2906ab68d03d7d6bf397cfd402649d6a27b4065addbff30efc2d147e9f2c/python_fasthtml-0.12.1-py3-none-any.whl", hash = "sha256:6eabaadf826b19a07851750361a160cc8e986ad1d96c42cc64439315dd3c4c99", size = 61009 },
+]
+
+[[package]]
+name = "python-multipart"
+version = "0.0.20"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 },
+]
+
+[[package]]
+name = "pytz"
+version = "2025.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5f/57/df1c9157c8d5a05117e455d66fd7cf6dbc46974f832b1058ed4856785d8a/pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e", size = 319617 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/eb/38/ac33370d784287baa1c3d538978b5e2ea064d4c1b93ffbd12826c190dd10/pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57", size = 507930 },
+]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9b/95/a3fac87cb7158e231b5a6012e438c647e1a87f09f8e0d123acec8ab8bf71/PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", size = 184199 },
+ { url = "https://files.pythonhosted.org/packages/c7/7a/68bd47624dab8fd4afbfd3c48e3b79efe09098ae941de5b58abcbadff5cb/PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", size = 171758 },
+ { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463 },
+ { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280 },
+ { url = "https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239 },
+ { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802 },
+ { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527 },
+ { url = "https://files.pythonhosted.org/packages/be/aa/5afe99233fb360d0ff37377145a949ae258aaab831bde4792b32650a4378/PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", size = 144052 },
+ { url = "https://files.pythonhosted.org/packages/b5/84/0fa4b06f6d6c958d207620fc60005e241ecedceee58931bb20138e1e5776/PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", size = 161774 },
+ { url = "https://files.pythonhosted.org/packages/f8/aa/7af4e81f7acba21a4c6be026da38fd2b872ca46226673c89a758ebdc4fd2/PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", size = 184612 },
+ { url = "https://files.pythonhosted.org/packages/8b/62/b9faa998fd185f65c1371643678e4d58254add437edb764a08c5a98fb986/PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", size = 172040 },
+ { url = "https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829 },
+ { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167 },
+ { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952 },
+ { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301 },
+ { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638 },
+ { url = "https://files.pythonhosted.org/packages/22/5f/956f0f9fc65223a58fbc14459bf34b4cc48dec52e00535c79b8db361aabd/PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", size = 143850 },
+ { url = "https://files.pythonhosted.org/packages/ed/23/8da0bbe2ab9dcdd11f4f4557ccaf95c10b9811b13ecced089d43ce59c3c8/PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", size = 161980 },
+ { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873 },
+ { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302 },
+ { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154 },
+ { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223 },
+ { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542 },
+ { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164 },
+ { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 },
+ { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591 },
+ { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 },
+]
+
+[[package]]
+name = "regex"
+version = "2024.11.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/8e/5f/bd69653fbfb76cf8604468d3b4ec4c403197144c7bfe0e6a5fc9e02a07cb/regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519", size = 399494 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/95/3c/4651f6b130c6842a8f3df82461a8950f923925db8b6961063e82744bddcc/regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91", size = 482674 },
+ { url = "https://files.pythonhosted.org/packages/15/51/9f35d12da8434b489c7b7bffc205c474a0a9432a889457026e9bc06a297a/regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0", size = 287684 },
+ { url = "https://files.pythonhosted.org/packages/bd/18/b731f5510d1b8fb63c6b6d3484bfa9a59b84cc578ac8b5172970e05ae07c/regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e", size = 284589 },
+ { url = "https://files.pythonhosted.org/packages/78/a2/6dd36e16341ab95e4c6073426561b9bfdeb1a9c9b63ab1b579c2e96cb105/regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde", size = 782511 },
+ { url = "https://files.pythonhosted.org/packages/1b/2b/323e72d5d2fd8de0d9baa443e1ed70363ed7e7b2fb526f5950c5cb99c364/regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e", size = 821149 },
+ { url = "https://files.pythonhosted.org/packages/90/30/63373b9ea468fbef8a907fd273e5c329b8c9535fee36fc8dba5fecac475d/regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2", size = 809707 },
+ { url = "https://files.pythonhosted.org/packages/f2/98/26d3830875b53071f1f0ae6d547f1d98e964dd29ad35cbf94439120bb67a/regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf", size = 781702 },
+ { url = "https://files.pythonhosted.org/packages/87/55/eb2a068334274db86208ab9d5599ffa63631b9f0f67ed70ea7c82a69bbc8/regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c", size = 771976 },
+ { url = "https://files.pythonhosted.org/packages/74/c0/be707bcfe98254d8f9d2cff55d216e946f4ea48ad2fd8cf1428f8c5332ba/regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86", size = 697397 },
+ { url = "https://files.pythonhosted.org/packages/49/dc/bb45572ceb49e0f6509f7596e4ba7031f6819ecb26bc7610979af5a77f45/regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67", size = 768726 },
+ { url = "https://files.pythonhosted.org/packages/5a/db/f43fd75dc4c0c2d96d0881967897926942e935d700863666f3c844a72ce6/regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d", size = 775098 },
+ { url = "https://files.pythonhosted.org/packages/99/d7/f94154db29ab5a89d69ff893159b19ada89e76b915c1293e98603d39838c/regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2", size = 839325 },
+ { url = "https://files.pythonhosted.org/packages/f7/17/3cbfab1f23356fbbf07708220ab438a7efa1e0f34195bf857433f79f1788/regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008", size = 843277 },
+ { url = "https://files.pythonhosted.org/packages/7e/f2/48b393b51900456155de3ad001900f94298965e1cad1c772b87f9cfea011/regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62", size = 773197 },
+ { url = "https://files.pythonhosted.org/packages/45/3f/ef9589aba93e084cd3f8471fded352826dcae8489b650d0b9b27bc5bba8a/regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e", size = 261714 },
+ { url = "https://files.pythonhosted.org/packages/42/7e/5f1b92c8468290c465fd50c5318da64319133231415a8aa6ea5ab995a815/regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519", size = 274042 },
+ { url = "https://files.pythonhosted.org/packages/58/58/7e4d9493a66c88a7da6d205768119f51af0f684fe7be7bac8328e217a52c/regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638", size = 482669 },
+ { url = "https://files.pythonhosted.org/packages/34/4c/8f8e631fcdc2ff978609eaeef1d6994bf2f028b59d9ac67640ed051f1218/regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7", size = 287684 },
+ { url = "https://files.pythonhosted.org/packages/c5/1b/f0e4d13e6adf866ce9b069e191f303a30ab1277e037037a365c3aad5cc9c/regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20", size = 284589 },
+ { url = "https://files.pythonhosted.org/packages/25/4d/ab21047f446693887f25510887e6820b93f791992994f6498b0318904d4a/regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114", size = 792121 },
+ { url = "https://files.pythonhosted.org/packages/45/ee/c867e15cd894985cb32b731d89576c41a4642a57850c162490ea34b78c3b/regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3", size = 831275 },
+ { url = "https://files.pythonhosted.org/packages/b3/12/b0f480726cf1c60f6536fa5e1c95275a77624f3ac8fdccf79e6727499e28/regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f", size = 818257 },
+ { url = "https://files.pythonhosted.org/packages/bf/ce/0d0e61429f603bac433910d99ef1a02ce45a8967ffbe3cbee48599e62d88/regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0", size = 792727 },
+ { url = "https://files.pythonhosted.org/packages/e4/c1/243c83c53d4a419c1556f43777ccb552bccdf79d08fda3980e4e77dd9137/regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55", size = 780667 },
+ { url = "https://files.pythonhosted.org/packages/c5/f4/75eb0dd4ce4b37f04928987f1d22547ddaf6c4bae697623c1b05da67a8aa/regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89", size = 776963 },
+ { url = "https://files.pythonhosted.org/packages/16/5d/95c568574e630e141a69ff8a254c2f188b4398e813c40d49228c9bbd9875/regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d", size = 784700 },
+ { url = "https://files.pythonhosted.org/packages/8e/b5/f8495c7917f15cc6fee1e7f395e324ec3e00ab3c665a7dc9d27562fd5290/regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34", size = 848592 },
+ { url = "https://files.pythonhosted.org/packages/1c/80/6dd7118e8cb212c3c60b191b932dc57db93fb2e36fb9e0e92f72a5909af9/regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d", size = 852929 },
+ { url = "https://files.pythonhosted.org/packages/11/9b/5a05d2040297d2d254baf95eeeb6df83554e5e1df03bc1a6687fc4ba1f66/regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45", size = 781213 },
+ { url = "https://files.pythonhosted.org/packages/26/b7/b14e2440156ab39e0177506c08c18accaf2b8932e39fb092074de733d868/regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9", size = 261734 },
+ { url = "https://files.pythonhosted.org/packages/80/32/763a6cc01d21fb3819227a1cc3f60fd251c13c37c27a73b8ff4315433a8e/regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60", size = 274052 },
+ { url = "https://files.pythonhosted.org/packages/ba/30/9a87ce8336b172cc232a0db89a3af97929d06c11ceaa19d97d84fa90a8f8/regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a", size = 483781 },
+ { url = "https://files.pythonhosted.org/packages/01/e8/00008ad4ff4be8b1844786ba6636035f7ef926db5686e4c0f98093612add/regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9", size = 288455 },
+ { url = "https://files.pythonhosted.org/packages/60/85/cebcc0aff603ea0a201667b203f13ba75d9fc8668fab917ac5b2de3967bc/regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2", size = 284759 },
+ { url = "https://files.pythonhosted.org/packages/94/2b/701a4b0585cb05472a4da28ee28fdfe155f3638f5e1ec92306d924e5faf0/regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4", size = 794976 },
+ { url = "https://files.pythonhosted.org/packages/4b/bf/fa87e563bf5fee75db8915f7352e1887b1249126a1be4813837f5dbec965/regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577", size = 833077 },
+ { url = "https://files.pythonhosted.org/packages/a1/56/7295e6bad94b047f4d0834e4779491b81216583c00c288252ef625c01d23/regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3", size = 823160 },
+ { url = "https://files.pythonhosted.org/packages/fb/13/e3b075031a738c9598c51cfbc4c7879e26729c53aa9cca59211c44235314/regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e", size = 796896 },
+ { url = "https://files.pythonhosted.org/packages/24/56/0b3f1b66d592be6efec23a795b37732682520b47c53da5a32c33ed7d84e3/regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe", size = 783997 },
+ { url = "https://files.pythonhosted.org/packages/f9/a1/eb378dada8b91c0e4c5f08ffb56f25fcae47bf52ad18f9b2f33b83e6d498/regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e", size = 781725 },
+ { url = "https://files.pythonhosted.org/packages/83/f2/033e7dec0cfd6dda93390089864732a3409246ffe8b042e9554afa9bff4e/regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29", size = 789481 },
+ { url = "https://files.pythonhosted.org/packages/83/23/15d4552ea28990a74e7696780c438aadd73a20318c47e527b47a4a5a596d/regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39", size = 852896 },
+ { url = "https://files.pythonhosted.org/packages/e3/39/ed4416bc90deedbfdada2568b2cb0bc1fdb98efe11f5378d9892b2a88f8f/regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51", size = 860138 },
+ { url = "https://files.pythonhosted.org/packages/93/2d/dd56bb76bd8e95bbce684326302f287455b56242a4f9c61f1bc76e28360e/regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad", size = 787692 },
+ { url = "https://files.pythonhosted.org/packages/0b/55/31877a249ab7a5156758246b9c59539abbeba22461b7d8adc9e8475ff73e/regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54", size = 262135 },
+ { url = "https://files.pythonhosted.org/packages/38/ec/ad2d7de49a600cdb8dd78434a1aeffe28b9d6fc42eb36afab4a27ad23384/regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b", size = 273567 },
+]
+
+[[package]]
+name = "requests"
+version = "2.32.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "certifi" },
+ { name = "charset-normalizer" },
+ { name = "idna" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
+]
+
+[[package]]
+name = "rich"
+version = "13.9.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markdown-it-py" },
+ { name = "pygments" },
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 },
+]
+
+[[package]]
+name = "safetensors"
+version = "0.5.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f4/4f/2ef9ef1766f8c194b01b67a63a444d2e557c8fe1d82faf3ebd85f370a917/safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8", size = 66957 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/96/d1/017e31e75e274492a11a456a9e7c171f8f7911fe50735b4ec6ff37221220/safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2", size = 427067 },
+ { url = "https://files.pythonhosted.org/packages/24/84/e9d3ff57ae50dd0028f301c9ee064e5087fe8b00e55696677a0413c377a7/safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae", size = 408856 },
+ { url = "https://files.pythonhosted.org/packages/f1/1d/fe95f5dd73db16757b11915e8a5106337663182d0381811c81993e0014a9/safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c", size = 450088 },
+ { url = "https://files.pythonhosted.org/packages/cf/21/e527961b12d5ab528c6e47b92d5f57f33563c28a972750b238b871924e49/safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e", size = 458966 },
+ { url = "https://files.pythonhosted.org/packages/a5/8b/1a037d7a57f86837c0b41905040369aea7d8ca1ec4b2a77592372b2ec380/safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869", size = 509915 },
+ { url = "https://files.pythonhosted.org/packages/61/3d/03dd5cfd33839df0ee3f4581a20bd09c40246d169c0e4518f20b21d5f077/safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a", size = 527664 },
+ { url = "https://files.pythonhosted.org/packages/c5/dc/8952caafa9a10a3c0f40fa86bacf3190ae7f55fa5eef87415b97b29cb97f/safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5", size = 461978 },
+ { url = "https://files.pythonhosted.org/packages/60/da/82de1fcf1194e3dbefd4faa92dc98b33c06bed5d67890e0962dd98e18287/safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975", size = 491253 },
+ { url = "https://files.pythonhosted.org/packages/5a/9a/d90e273c25f90c3ba1b0196a972003786f04c39e302fbd6649325b1272bb/safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e", size = 628644 },
+ { url = "https://files.pythonhosted.org/packages/70/3c/acb23e05aa34b4f5edd2e7f393f8e6480fbccd10601ab42cd03a57d4ab5f/safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f", size = 721648 },
+ { url = "https://files.pythonhosted.org/packages/71/45/eaa3dba5253a7c6931230dc961641455710ab231f8a89cb3c4c2af70f8c8/safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf", size = 659588 },
+ { url = "https://files.pythonhosted.org/packages/b0/71/2f9851164f821064d43b481ddbea0149c2d676c4f4e077b178e7eeaa6660/safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76", size = 632533 },
+ { url = "https://files.pythonhosted.org/packages/00/f1/5680e2ef61d9c61454fad82c344f0e40b8741a9dbd1e31484f0d31a9b1c3/safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2", size = 291167 },
+ { url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
+]
+
+[[package]]
+name = "scikit-learn"
+version = "1.6.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "joblib" },
+ { name = "numpy" },
+ { name = "scipy" },
+ { name = "threadpoolctl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9e/a5/4ae3b3a0755f7b35a280ac90b28817d1f380318973cff14075ab41ef50d9/scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e", size = 7068312 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2e/3a/f4597eb41049110b21ebcbb0bcb43e4035017545daa5eedcfeb45c08b9c5/scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e", size = 12067702 },
+ { url = "https://files.pythonhosted.org/packages/37/19/0423e5e1fd1c6ec5be2352ba05a537a473c1677f8188b9306097d684b327/scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36", size = 11112765 },
+ { url = "https://files.pythonhosted.org/packages/70/95/d5cb2297a835b0f5fc9a77042b0a2d029866379091ab8b3f52cc62277808/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5", size = 12643991 },
+ { url = "https://files.pythonhosted.org/packages/b7/91/ab3c697188f224d658969f678be86b0968ccc52774c8ab4a86a07be13c25/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b", size = 13497182 },
+ { url = "https://files.pythonhosted.org/packages/17/04/d5d556b6c88886c092cc989433b2bab62488e0f0dafe616a1d5c9cb0efb1/scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002", size = 11125517 },
+ { url = "https://files.pythonhosted.org/packages/6c/2a/e291c29670795406a824567d1dfc91db7b699799a002fdaa452bceea8f6e/scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33", size = 12102620 },
+ { url = "https://files.pythonhosted.org/packages/25/92/ee1d7a00bb6b8c55755d4984fd82608603a3cc59959245068ce32e7fb808/scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d", size = 11116234 },
+ { url = "https://files.pythonhosted.org/packages/30/cd/ed4399485ef364bb25f388ab438e3724e60dc218c547a407b6e90ccccaef/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2", size = 12592155 },
+ { url = "https://files.pythonhosted.org/packages/a8/f3/62fc9a5a659bb58a03cdd7e258956a5824bdc9b4bb3c5d932f55880be569/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8", size = 13497069 },
+ { url = "https://files.pythonhosted.org/packages/a1/a6/c5b78606743a1f28eae8f11973de6613a5ee87366796583fb74c67d54939/scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415", size = 11139809 },
+ { url = "https://files.pythonhosted.org/packages/0a/18/c797c9b8c10380d05616db3bfb48e2a3358c767affd0857d56c2eb501caa/scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b", size = 12104516 },
+ { url = "https://files.pythonhosted.org/packages/c4/b7/2e35f8e289ab70108f8cbb2e7a2208f0575dc704749721286519dcf35f6f/scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2", size = 11167837 },
+ { url = "https://files.pythonhosted.org/packages/a4/f6/ff7beaeb644bcad72bcfd5a03ff36d32ee4e53a8b29a639f11bcb65d06cd/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f", size = 12253728 },
+ { url = "https://files.pythonhosted.org/packages/29/7a/8bce8968883e9465de20be15542f4c7e221952441727c4dad24d534c6d99/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86", size = 13147700 },
+ { url = "https://files.pythonhosted.org/packages/62/27/585859e72e117fe861c2079bcba35591a84f801e21bc1ab85bce6ce60305/scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52", size = 11110613 },
+]
+
+[[package]]
+name = "scipy"
+version = "1.15.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/76/c6/8eb0654ba0c7d0bb1bf67bf8fbace101a8e4f250f7722371105e8b6f68fc/scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6", size = 59407493 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/86/53/b204ce5a4433f1864001b9d16f103b9c25f5002a602ae83585d0ea5f9c4a/scipy-1.15.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:c64ded12dcab08afff9e805a67ff4480f5e69993310e093434b10e85dc9d43e1", size = 41414518 },
+ { url = "https://files.pythonhosted.org/packages/c7/fc/54ffa7a8847f7f303197a6ba65a66104724beba2e38f328135a78f0dc480/scipy-1.15.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5b190b935e7db569960b48840e5bef71dc513314cc4e79a1b7d14664f57fd4ff", size = 32519265 },
+ { url = "https://files.pythonhosted.org/packages/f1/77/a98b8ba03d6f371dc31a38719affd53426d4665729dcffbed4afe296784a/scipy-1.15.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:4b17d4220df99bacb63065c76b0d1126d82bbf00167d1730019d2a30d6ae01ea", size = 24792859 },
+ { url = "https://files.pythonhosted.org/packages/a7/78/70bb9f0df7444b18b108580934bfef774822e28fd34a68e5c263c7d2828a/scipy-1.15.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:63b9b6cd0333d0eb1a49de6f834e8aeaefe438df8f6372352084535ad095219e", size = 27886506 },
+ { url = "https://files.pythonhosted.org/packages/14/a7/f40f6033e06de4176ddd6cc8c3ae9f10a226c3bca5d6b4ab883bc9914a14/scipy-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f151e9fb60fbf8e52426132f473221a49362091ce7a5e72f8aa41f8e0da4f25", size = 38375041 },
+ { url = "https://files.pythonhosted.org/packages/17/03/390a1c5c61fd76b0fa4b3c5aa3bdd7e60f6c46f712924f1a9df5705ec046/scipy-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e10b1dd56ce92fba3e786007322542361984f8463c6d37f6f25935a5a6ef52", size = 40597556 },
+ { url = "https://files.pythonhosted.org/packages/4e/70/fa95b3ae026b97eeca58204a90868802e5155ac71b9d7bdee92b68115dd3/scipy-1.15.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5dff14e75cdbcf07cdaa1c7707db6017d130f0af9ac41f6ce443a93318d6c6e0", size = 42938505 },
+ { url = "https://files.pythonhosted.org/packages/d6/07/427859116bdd71847c898180f01802691f203c3e2455a1eb496130ff07c5/scipy-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:f82fcf4e5b377f819542fbc8541f7b5fbcf1c0017d0df0bc22c781bf60abc4d8", size = 43909663 },
+ { url = "https://files.pythonhosted.org/packages/8e/2e/7b71312da9c2dabff53e7c9a9d08231bc34d9d8fdabe88a6f1155b44591c/scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2", size = 41424362 },
+ { url = "https://files.pythonhosted.org/packages/81/8c/ab85f1aa1cc200c796532a385b6ebf6a81089747adc1da7482a062acc46c/scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0", size = 32535910 },
+ { url = "https://files.pythonhosted.org/packages/3b/9c/6f4b787058daa8d8da21ddff881b4320e28de4704a65ec147adb50cb2230/scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf", size = 24809398 },
+ { url = "https://files.pythonhosted.org/packages/16/2b/949460a796df75fc7a1ee1becea202cf072edbe325ebe29f6d2029947aa7/scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac", size = 27918045 },
+ { url = "https://files.pythonhosted.org/packages/5f/36/67fe249dd7ccfcd2a38b25a640e3af7e59d9169c802478b6035ba91dfd6d/scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df", size = 38332074 },
+ { url = "https://files.pythonhosted.org/packages/fc/da/452e1119e6f720df3feb588cce3c42c5e3d628d4bfd4aec097bd30b7de0c/scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7", size = 40588469 },
+ { url = "https://files.pythonhosted.org/packages/7f/71/5f94aceeac99a4941478af94fe9f459c6752d497035b6b0761a700f5f9ff/scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a", size = 42965214 },
+ { url = "https://files.pythonhosted.org/packages/af/25/caa430865749d504271757cafd24066d596217e83326155993980bc22f97/scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b", size = 43896034 },
+ { url = "https://files.pythonhosted.org/packages/d8/6e/a9c42d0d39e09ed7fd203d0ac17adfea759cba61ab457671fe66e523dbec/scipy-1.15.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c09aa9d90f3500ea4c9b393ee96f96b0ccb27f2f350d09a47f533293c78ea776", size = 41478318 },
+ { url = "https://files.pythonhosted.org/packages/04/ee/e3e535c81828618878a7433992fecc92fa4df79393f31a8fea1d05615091/scipy-1.15.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0ac102ce99934b162914b1e4a6b94ca7da0f4058b6d6fd65b0cef330c0f3346f", size = 32596696 },
+ { url = "https://files.pythonhosted.org/packages/c4/5e/b1b0124be8e76f87115f16b8915003eec4b7060298117715baf13f51942c/scipy-1.15.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:09c52320c42d7f5c7748b69e9f0389266fd4f82cf34c38485c14ee976cb8cb04", size = 24870366 },
+ { url = "https://files.pythonhosted.org/packages/14/36/c00cb73eefda85946172c27913ab995c6ad4eee00fa4f007572e8c50cd51/scipy-1.15.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:cdde8414154054763b42b74fe8ce89d7f3d17a7ac5dd77204f0e142cdc9239e9", size = 28007461 },
+ { url = "https://files.pythonhosted.org/packages/68/94/aff5c51b3799349a9d1e67a056772a0f8a47db371e83b498d43467806557/scipy-1.15.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c9d8fc81d6a3b6844235e6fd175ee1d4c060163905a2becce8e74cb0d7554ce", size = 38068174 },
+ { url = "https://files.pythonhosted.org/packages/b0/3c/0de11ca154e24a57b579fb648151d901326d3102115bc4f9a7a86526ce54/scipy-1.15.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb57b30f0017d4afa5fe5f5b150b8f807618819287c21cbe51130de7ccdaed2", size = 40249869 },
+ { url = "https://files.pythonhosted.org/packages/15/09/472e8d0a6b33199d1bb95e49bedcabc0976c3724edd9b0ef7602ccacf41e/scipy-1.15.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491d57fe89927fa1aafbe260f4cfa5ffa20ab9f1435025045a5315006a91b8f5", size = 42629068 },
+ { url = "https://files.pythonhosted.org/packages/ff/ba/31c7a8131152822b3a2cdeba76398ffb404d81d640de98287d236da90c49/scipy-1.15.1-cp312-cp312-win_amd64.whl", hash = "sha256:900f3fa3db87257510f011c292a5779eb627043dd89731b9c461cd16ef76ab3d", size = 43621992 },
+]
+
+[[package]]
+name = "sentence-transformers"
+version = "3.4.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "huggingface-hub" },
+ { name = "pillow" },
+ { name = "scikit-learn" },
+ { name = "scipy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+ { name = "transformers" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/16/74/aca6f8a2b8d62b4daf8c9a0c49d2aa573381caf47dc35cbb343389229376/sentence_transformers-3.4.1.tar.gz", hash = "sha256:68daa57504ff548340e54ff117bd86c1d2f784b21e0fb2689cf3272b8937b24b", size = 223898 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/05/89/7eb147a37b7f31d3c815543df539d8b8d0425e93296c875cc87719d65232/sentence_transformers-3.4.1-py3-none-any.whl", hash = "sha256:e026dc6d56801fd83f74ad29a30263f401b4b522165c19386d8bc10dcca805da", size = 275896 },
+]
+
+[[package]]
+name = "sentencepiece"
+version = "0.2.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c9/d2/b9c7ca067c26d8ff085d252c89b5f69609ca93fb85a00ede95f4857865d4/sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843", size = 2632106 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f6/71/98648c3b64b23edb5403f74bcc906ad21766872a6e1ada26ea3f1eb941ab/sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227", size = 2408979 },
+ { url = "https://files.pythonhosted.org/packages/77/9f/7efbaa6d4c0c718a9affbecc536b03ca62f99f421bdffb531c16030e2d2b/sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452", size = 1238845 },
+ { url = "https://files.pythonhosted.org/packages/1c/e4/c2541027a43ec6962ba9b601805d17ba3f86b38bdeae0e8ac65a2981e248/sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3", size = 1181472 },
+ { url = "https://files.pythonhosted.org/packages/fd/46/316c1ba6c52b97de76aff7b9da678f7afbb52136afb2987c474d95630e65/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a", size = 1259151 },
+ { url = "https://files.pythonhosted.org/packages/aa/5a/3c48738a0835d76dd06c62b6ac48d39c923cde78dd0f587353bdcbb99851/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e", size = 1355931 },
+ { url = "https://files.pythonhosted.org/packages/a6/27/33019685023221ca8ed98e8ceb7ae5e166032686fa3662c68f1f1edf334e/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040", size = 1301537 },
+ { url = "https://files.pythonhosted.org/packages/ca/e4/55f97cef14293171fef5f96e96999919ab5b4d1ce95b53547ad653d7e3bf/sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d", size = 936747 },
+ { url = "https://files.pythonhosted.org/packages/85/f4/4ef1a6e0e9dbd8a60780a91df8b7452ada14cfaa0e17b3b8dfa42cecae18/sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2", size = 991525 },
+ { url = "https://files.pythonhosted.org/packages/32/43/8f8885168a47a02eba1455bd3f4f169f50ad5b8cebd2402d0f5e20854d04/sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c", size = 2409036 },
+ { url = "https://files.pythonhosted.org/packages/0f/35/e63ba28062af0a3d688a9f128e407a1a2608544b2f480cb49bf7f4b1cbb9/sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e", size = 1238921 },
+ { url = "https://files.pythonhosted.org/packages/de/42/ae30952c4a0bd773e90c9bf2579f5533037c886dfc8ec68133d5694f4dd2/sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6", size = 1181477 },
+ { url = "https://files.pythonhosted.org/packages/e3/ac/2f2ab1d60bb2d795d054eebe5e3f24b164bc21b5a9b75fba7968b3b91b5a/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb", size = 1259182 },
+ { url = "https://files.pythonhosted.org/packages/45/fb/14633c6ecf262c468759ffcdb55c3a7ee38fe4eda6a70d75ee7c7d63c58b/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553", size = 1355537 },
+ { url = "https://files.pythonhosted.org/packages/fb/12/2f5c8d4764b00033cf1c935b702d3bb878d10be9f0b87f0253495832d85f/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d", size = 1301464 },
+ { url = "https://files.pythonhosted.org/packages/4e/b1/67afc0bde24f6dcb3acdea0dd8dcdf4b8b0db240f6bacd39378bd32d09f8/sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75", size = 936749 },
+ { url = "https://files.pythonhosted.org/packages/a2/f6/587c62fd21fc988555b85351f50bbde43a51524caafd63bc69240ded14fd/sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36", size = 991520 },
+ { url = "https://files.pythonhosted.org/packages/27/5a/141b227ed54293360a9ffbb7bf8252b4e5efc0400cdeac5809340e5d2b21/sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2", size = 2409370 },
+ { url = "https://files.pythonhosted.org/packages/2e/08/a4c135ad6fc2ce26798d14ab72790d66e813efc9589fd30a5316a88ca8d5/sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c", size = 1239288 },
+ { url = "https://files.pythonhosted.org/packages/49/0a/2fe387f825ac5aad5a0bfe221904882106cac58e1b693ba7818785a882b6/sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f", size = 1181597 },
+ { url = "https://files.pythonhosted.org/packages/cc/38/e4698ee2293fe4835dc033c49796a39b3eebd8752098f6bd0aa53a14af1f/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08", size = 1259220 },
+ { url = "https://files.pythonhosted.org/packages/12/24/fd7ef967c9dad2f6e6e5386d0cadaf65cda8b7be6e3861a9ab3121035139/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7", size = 1355962 },
+ { url = "https://files.pythonhosted.org/packages/4f/d2/18246f43ca730bb81918f87b7e886531eda32d835811ad9f4657c54eee35/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109", size = 1301706 },
+ { url = "https://files.pythonhosted.org/packages/8a/47/ca237b562f420044ab56ddb4c278672f7e8c866e183730a20e413b38a989/sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251", size = 936941 },
+ { url = "https://files.pythonhosted.org/packages/c6/97/d159c32642306ee2b70732077632895438867b3b6df282354bd550cf2a67/sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f", size = 991994 },
+]
+
+[[package]]
+name = "sentry-sdk"
+version = "2.20.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "certifi" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/68/e8/6a366c0cd5e129dda6ecb20ff097f70b18182c248d4c27e813c21f98992a/sentry_sdk-2.20.0.tar.gz", hash = "sha256:afa82713a92facf847df3c6f63cec71eb488d826a50965def3d7722aa6f0fdab", size = 300125 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e6/0f/6f7e6cd0f4a141752caef3f79300148422fdf2b8b68b531f30b2b0c0cbda/sentry_sdk-2.20.0-py2.py3-none-any.whl", hash = "sha256:c359a1edf950eb5e80cffd7d9111f3dbeef57994cb4415df37d39fda2cf22364", size = 322576 },
+]
+
+[[package]]
+name = "setproctitle"
+version = "1.3.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ae/4e/b09341b19b9ceb8b4c67298ab4a08ef7a4abdd3016c7bb152e9b6379031d/setproctitle-1.3.4.tar.gz", hash = "sha256:3b40d32a3e1f04e94231ed6dfee0da9e43b4f9c6b5450d53e6dd7754c34e0c50", size = 26456 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/f4/95937eb5c5370324a942ba90174c6d0fc7c5ad2f7f8ea989ccdbd6e1be5e/setproctitle-1.3.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0f6661a69c68349172ba7b4d5dd65fec2b0917abc99002425ad78c3e58cf7595", size = 16855 },
+ { url = "https://files.pythonhosted.org/packages/32/a6/d49dbb0d75d02d11db49151469e1fee740efa45de7288bffcc4d88d0c290/setproctitle-1.3.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:754bac5e470adac7f7ec2239c485cd0b75f8197ca8a5b86ffb20eb3a3676cc42", size = 11627 },
+ { url = "https://files.pythonhosted.org/packages/2e/cd/73a0fc913db50c3b736750ce67824f1108c2173e5d043a16ef9874b4f988/setproctitle-1.3.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7bc7088c15150745baf66db62a4ced4507d44419eb66207b609f91b64a682af", size = 31187 },
+ { url = "https://files.pythonhosted.org/packages/63/0f/74f9112f7f506acc01f085811c6d135751b6fa42d30207f53b25579d043a/setproctitle-1.3.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a46ef3ecf61e4840fbc1145fdd38acf158d0da7543eda7b773ed2b30f75c2830", size = 32534 },
+ { url = "https://files.pythonhosted.org/packages/3b/88/53eec2373745069d4c8a59d41ee2ef4a48949b77cccd0077c270261b238a/setproctitle-1.3.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffcb09d5c0ffa043254ec9a734a73f3791fec8bf6333592f906bb2e91ed2af1a", size = 29657 },
+ { url = "https://files.pythonhosted.org/packages/50/1c/a4d3d8c20bf3bbafd8c5038e7da09043a9d21450b6a73694ada11c01b58a/setproctitle-1.3.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06c16b7a91cdc5d700271899e4383384a61aae83a3d53d0e2e5a266376083342", size = 30695 },
+ { url = "https://files.pythonhosted.org/packages/a2/2a/9f290f0d10ea87a266d63025078eabfa040ad29ea10d815e167a5104de00/setproctitle-1.3.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9f9732e59863eaeedd3feef94b2b216cb86d40dda4fad2d0f0aaec3b31592716", size = 30340 },
+ { url = "https://files.pythonhosted.org/packages/38/c4/5bfe02d4cdd16338973d452c7c6042abdd2827d90f7ce4e21bc003f2edb1/setproctitle-1.3.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e152f4ab9ea1632b5fecdd87cee354f2b2eb6e2dfc3aceb0eb36a01c1e12f94c", size = 29352 },
+ { url = "https://files.pythonhosted.org/packages/b3/41/0dd85cef0e5a5a332bdda7b55738e950c2edecea3ae45c658990553d50f8/setproctitle-1.3.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:020ea47a79b2bbd7bd7b94b85ca956ba7cb026e82f41b20d2e1dac4008cead25", size = 31819 },
+ { url = "https://files.pythonhosted.org/packages/d7/23/fbfacfed8805983a83324099484e47b9028d0d3c07a0fe017123eee3f580/setproctitle-1.3.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8c52b12b10e4057fc302bd09cb3e3f28bb382c30c044eb3396e805179a8260e4", size = 29745 },
+ { url = "https://files.pythonhosted.org/packages/68/37/e18c5a00bfd1c4c2c815536d5c63a470e4364b571bd5096d38d0fe277bf5/setproctitle-1.3.4-cp310-cp310-win32.whl", hash = "sha256:a65a147f545f3fac86f11acb2d0b316d3e78139a9372317b7eb50561b2817ba0", size = 11358 },
+ { url = "https://files.pythonhosted.org/packages/52/fd/1fae8c4c13af22d8d17816c44421085509a08dfa77f573d31447d6cd540c/setproctitle-1.3.4-cp310-cp310-win_amd64.whl", hash = "sha256:66821fada6426998762a3650a37fba77e814a249a95b1183011070744aff47f6", size = 12072 },
+ { url = "https://files.pythonhosted.org/packages/5d/1a/1fb7d622195bcb3ce7b04366a833e51cfa5ad632c5dafe32e0763cd3fdc9/setproctitle-1.3.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f0f749f07002c2d6fecf37cedc43207a88e6c651926a470a5f229070cf791879", size = 16851 },
+ { url = "https://files.pythonhosted.org/packages/46/54/e3aa4f46eddf795f10452ea878ff85c3496d36409636530f9a37e2de3cbe/setproctitle-1.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:90ea8d302a5d30b948451d146e94674a3c5b020cc0ced9a1c28f8ddb0f203a5d", size = 11620 },
+ { url = "https://files.pythonhosted.org/packages/61/47/80988221679dfd93c464248abb71c2a96338f2ca3f8e3288d0ecb7422f4d/setproctitle-1.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f859c88193ed466bee4eb9d45fbc29d2253e6aa3ccd9119c9a1d8d95f409a60d", size = 31519 },
+ { url = "https://files.pythonhosted.org/packages/2c/72/14984c127f708597e412f1a8cf7cac809b9bca50a267a6b01b221b094330/setproctitle-1.3.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b3afa5a0ed08a477ded239c05db14c19af585975194a00adf594d48533b23701", size = 32860 },
+ { url = "https://files.pythonhosted.org/packages/16/9d/34ea09295620fddae65cf7caeac81bbfc386a3ae6ce26a4dcadbb54c134d/setproctitle-1.3.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a78fce9018cc3e9a772b6537bbe3fe92380acf656c9f86db2f45e685af376e", size = 30029 },
+ { url = "https://files.pythonhosted.org/packages/44/bf/a447a51054ceed23f69d4f7370289044b4508569f11da6db2eec087bc174/setproctitle-1.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d758e2eed2643afac5f2881542fbb5aa97640b54be20d0a5ed0691d02f0867d", size = 31017 },
+ { url = "https://files.pythonhosted.org/packages/ec/46/adcffde6fb8d95458da0a568afdf0dabbbff6470299d94014676e1ab43c0/setproctitle-1.3.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ef133a1a2ee378d549048a12d56f4ef0e2b9113b0b25b6b77821e9af94d50634", size = 30762 },
+ { url = "https://files.pythonhosted.org/packages/a3/cd/747a67ce1f6ef8fd1fa46b0b13ba0e007b80914bd549318830b8691ab9f6/setproctitle-1.3.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1d2a154b79d5fb42d1eff06e05e22f0e8091261d877dd47b37d31352b74ecc37", size = 29753 },
+ { url = "https://files.pythonhosted.org/packages/3d/86/5939546e57238462a7839ae78399a635d1cfc5d125c7a12a28face111730/setproctitle-1.3.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:202eae632815571297833876a0f407d0d9c7ad9d843b38adbe687fe68c5192ee", size = 32161 },
+ { url = "https://files.pythonhosted.org/packages/62/83/9194a4baed06e0e90a69e2e4a77a75e5a3ff008046870c79bc36a5c45e1c/setproctitle-1.3.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2b0080819859e80a7776ac47cf6accb4b7ad313baf55fabac89c000480dcd103", size = 30104 },
+ { url = "https://files.pythonhosted.org/packages/ac/cd/08928fec23cbf4dae2a7b245b72d86e6458d64f4e7e6956cd80a9fda8c80/setproctitle-1.3.4-cp311-cp311-win32.whl", hash = "sha256:9c9d7d1267dee8c6627963d9376efa068858cfc8f573c083b1b6a2d297a8710f", size = 11349 },
+ { url = "https://files.pythonhosted.org/packages/aa/19/240c4b99d57e045d3b2e2effa5924e810eabb18c56ef9c2336a7746dffe4/setproctitle-1.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:475986ddf6df65d619acd52188336a20f616589403f5a5ceb3fc70cdc137037a", size = 12071 },
+ { url = "https://files.pythonhosted.org/packages/94/1f/02fb3c6038c819d86765316d2a911281fc56c7dd3a9355dceb3f26a5bf7b/setproctitle-1.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d06990dcfcd41bb3543c18dd25c8476fbfe1f236757f42fef560f6aa03ac8dfc", size = 16842 },
+ { url = "https://files.pythonhosted.org/packages/b8/0c/d69e1f91c8f3d3aa74394e9e6ebb818f7d323e2d138ce1127e9462d09ebc/setproctitle-1.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:317218c9d8b17a010ab2d2f0851e8ef584077a38b1ba2b7c55c9e44e79a61e73", size = 11614 },
+ { url = "https://files.pythonhosted.org/packages/86/ed/8031871d275302054b2f1b94b7cf5e850212cc412fe968f0979e64c1b838/setproctitle-1.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb5fefb53b9d9f334a5d9ec518a36b92a10b936011ac8a6b6dffd60135f16459", size = 31840 },
+ { url = "https://files.pythonhosted.org/packages/45/b7/04f5d221cbdcff35d6cdf74e2a852e69dc8d8e746eb1b314be6b57b79c41/setproctitle-1.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0855006261635e8669646c7c304b494b6df0a194d2626683520103153ad63cc9", size = 33271 },
+ { url = "https://files.pythonhosted.org/packages/25/b2/8dff0d2a72076e5535f117f33458d520538b5a0900b90a9f59a278f0d3f6/setproctitle-1.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a88e466fcaee659679c1d64dcb2eddbcb4bfadffeb68ba834d9c173a25b6184", size = 30509 },
+ { url = "https://files.pythonhosted.org/packages/4b/cf/4f19cdc7fdff3eaeb3064ce6eeb27c63081dba3123fbf904ac6bf0de440c/setproctitle-1.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f963b6ed8ba33eda374a98d979e8a0eaf21f891b6e334701693a2c9510613c4c", size = 31543 },
+ { url = "https://files.pythonhosted.org/packages/9b/a7/5f9c3c70dc5573f660f978fb3bb4847cd26ede95a5fc294d3f1cf6779800/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:122c2e05697fa91f5d23f00bbe98a9da1bd457b32529192e934095fadb0853f1", size = 31268 },
+ { url = "https://files.pythonhosted.org/packages/26/ab/bbde90ea0ed6a062ef94fe1c609b68077f7eb586133a62fa62d0c8dd9f8c/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:1bba0a866f5895d5b769d8c36b161271c7fd407e5065862ab80ff91c29fbe554", size = 30232 },
+ { url = "https://files.pythonhosted.org/packages/36/0e/817be9934eda4cf63c96c694c3383cb0d2e5d019a2871af7dbd2202f7a58/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:97f1f861998e326e640708488c442519ad69046374b2c3fe9bcc9869b387f23c", size = 32739 },
+ { url = "https://files.pythonhosted.org/packages/b0/76/9b4877850c9c5f41c4bacae441285dead7c192bebf4fcbf3b3eb0e8033cc/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:726aee40357d4bdb70115442cb85ccc8e8bc554fc0bbbaa3a57cbe81df42287d", size = 30778 },
+ { url = "https://files.pythonhosted.org/packages/b2/fa/bbc7ab32f253b9700ac20d78ba0d5fbdc4ea5789d33e1adb236cdf20b23a/setproctitle-1.3.4-cp312-cp312-win32.whl", hash = "sha256:04d6ba8b816dbb0bfd62000b0c3e583160893e6e8c4233e1dca1a9ae4d95d924", size = 11355 },
+ { url = "https://files.pythonhosted.org/packages/44/5c/6e6665b5fd800206a9e537ab0d2630d7b9b31b4697d931ed468837cc9cf5/setproctitle-1.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:9c76e43cb351ba8887371240b599925cdf3ecececc5dfb7125c71678e7722c55", size = 12069 },
+ { url = "https://files.pythonhosted.org/packages/2f/d0/775418662081d44b91da236ed4503e10e7008c9c5fd30193e13db388fbef/setproctitle-1.3.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:939d364a187b2adfbf6ae488664277e717d56c7951a4ddeb4f23b281bc50bfe5", size = 11153 },
+ { url = "https://files.pythonhosted.org/packages/fd/1f/b3b82633336cd9908bf74cbc06dd533025b3d3c202437c4e3d0bc871ca13/setproctitle-1.3.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb8a6a19be0cbf6da6fcbf3698b76c8af03fe83e4bd77c96c3922be3b88bf7da", size = 13310 },
+ { url = "https://files.pythonhosted.org/packages/f5/89/887c6872ceed5ca344d25c8cc8a3f9b99bbcb25613c4b680476b29aabe23/setproctitle-1.3.4-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:779006f9e1aade9522a40e8d9635115ab15dd82b7af8e655967162e9c01e2573", size = 12911 },
+ { url = "https://files.pythonhosted.org/packages/b0/8d/9e4a4651b1c5845a9aec0d2c08c65768ba5ca2ec76598124b45d384a5f46/setproctitle-1.3.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5519f2a7b8c535b0f1f77b30441476571373add72008230c81211ee17b423b57", size = 12105 },
+]
+
+[[package]]
+name = "setuptools"
+version = "75.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/92/ec/089608b791d210aec4e7f97488e67ab0d33add3efccb83a056cbafe3a2a6/setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6", size = 1343222 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/69/8a/b9dc7678803429e4a3bc9ba462fa3dd9066824d3c607490235c6a796be5a/setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3", size = 1228782 },
+]
+
+[[package]]
+name = "shellingham"
+version = "1.5.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 },
+]
+
+[[package]]
+name = "simple-slurm"
+version = "0.3.5"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/cc/57/dc898cd6ea39407cff09e89015a2bae1e42579da98a585cd3c6b796ec8d0/simple_slurm-0.3.5.tar.gz", hash = "sha256:a85c7f8cca25ece364b6ed19c0bfb9e194374cfb2ffc157f0d5811618e957458", size = 30936 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e6/f3/f9e247faca74934168674430f6ecb871df3958f63b8dab23638056c4ac03/simple_slurm-0.3.5-py3-none-any.whl", hash = "sha256:b6635407096de727bbc2d97b85037625305a80b09300fec24a2a3f233836671a", size = 29980 },
+]
+
+[[package]]
+name = "six"
+version = "1.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 },
+]
+
+[[package]]
+name = "smmap"
+version = "5.0.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 },
+]
+
+[[package]]
+name = "sniffio"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
+]
+
+[[package]]
+name = "soupsieve"
+version = "2.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/ce/fbaeed4f9fb8b2daa961f90591662df6a86c1abf25c548329a86920aedfb/soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb", size = 101569 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9", size = 36186 },
+]
+
+[[package]]
+name = "stack-data"
+version = "0.6.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "asttokens" },
+ { name = "executing" },
+ { name = "pure-eval" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 },
+]
+
+[[package]]
+name = "starlette"
+version = "0.45.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "anyio" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ff/fb/2984a686808b89a6781526129a4b51266f678b2d2b97ab2d325e56116df8/starlette-0.45.3.tar.gz", hash = "sha256:2cbcba2a75806f8a41c722141486f37c28e30a0921c5f6fe4346cb0dcee1302f", size = 2574076 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d9/61/f2b52e107b1fc8944b33ef56bf6ac4ebbe16d91b94d2b87ce013bf63fb84/starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d", size = 71507 },
+]
+
+[[package]]
+name = "sympy"
+version = "1.13.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mpmath" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
+]
+
+[[package]]
+name = "t2ibenchmark"
+version = "0.1.0"
+source = { git = "ssh://git@github.com/boomb0om/text2image-benchmark.git?rev=532229f679d7e97ecba61914db7276f95733e707#532229f679d7e97ecba61914db7276f95733e707" }
+dependencies = [
+ { name = "datasets" },
+ { name = "ftfy" },
+ { name = "glob2" },
+ { name = "numpy" },
+ { name = "opencv-python" },
+ { name = "pillow" },
+ { name = "regex" },
+ { name = "scipy" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+]
+
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 },
+]
+
+[[package]]
+name = "tensorboard"
+version = "2.18.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "grpcio" },
+ { name = "markdown" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "protobuf" },
+ { name = "setuptools" },
+ { name = "six" },
+ { name = "tensorboard-data-server" },
+ { name = "werkzeug" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b1/de/021c1d407befb505791764ad2cbd56ceaaa53a746baed01d2e2143f05f18/tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab", size = 5503036 },
+]
+
+[[package]]
+name = "tensorboard-data-server"
+version = "0.7.2"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356 },
+ { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598 },
+ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363 },
+]
+
+[[package]]
+name = "tensordict"
+version = "0.7.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cloudpickle" },
+ { name = "numpy" },
+ { name = "orjson" },
+ { name = "packaging" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c7/dd/b1ab94161882a5691181470f45491000eb5dc44b07097c52dc31721955aa/tensordict-0.7.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e651ea0288f8c1651cb0cb26618afabe2a995c39949ecede2077e0ff045165fb", size = 708764 },
+ { url = "https://files.pythonhosted.org/packages/98/97/afb16f3076459839b0dc648f04af64ae6e54afd6ec1c1fc67644dfb938e0/tensordict-0.7.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:1ced5cb1104a92df0b88a92f5828e9f395b7dddc6cbc01d839ca3f8635338e20", size = 400457 },
+ { url = "https://files.pythonhosted.org/packages/82/ce/9bf7919f0f4a7b4e3a17a928a6f0bb0fdc254076c66a922427c2a03bcbdb/tensordict-0.7.2-cp310-cp310-win_amd64.whl", hash = "sha256:2991f0659bc60666e16b3f1075f4e164f87308e930b865e79f230c597d7ba3fb", size = 390665 },
+ { url = "https://files.pythonhosted.org/packages/7b/75/fe909cf5e156b8c26a3cb13cf53967aa69750ba19cf6968c8c96320ba465/tensordict-0.7.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14c9ad93eb0330e162e51e053ae6ffebeb06f9de63263040da37fc3c2e580236", size = 710087 },
+ { url = "https://files.pythonhosted.org/packages/4e/2f/e0609cc783a31d43d4b595a3dae7c31c88d3c1c00109758698f522f67bc9/tensordict-0.7.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b3ee17b965bbfab6d1e1ad619cefec0bf7d20c9c933bc82f320091a8ef43c457", size = 400739 },
+ { url = "https://files.pythonhosted.org/packages/5c/ff/5f11886023f5fe2319e21c33011627b01e73bd45df5815933e92973b07cb/tensordict-0.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:db287eac7943d09e6146be4ba8018b868adddeda63a9474c7483caa6b23b997f", size = 391871 },
+ { url = "https://files.pythonhosted.org/packages/e9/19/9b6f2df7a5308db28122a2d6ef1b7e43ad1f6883798aeeb22ac87db4c181/tensordict-0.7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13b338995b76079dc797f0d21c2d34cb3fa32201e6de41aefcdcecf9c2e09395", size = 709183 },
+ { url = "https://files.pythonhosted.org/packages/ff/5d/78978f91e155946dc406b1b973fe03b094b3479540797274b2887d52f74e/tensordict-0.7.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:dad90ef545aa613df0706dc4bdb8d05fd2fc15c4b1a2f461ebf12bb3ec4b3381", size = 401328 },
+ { url = "https://files.pythonhosted.org/packages/ee/b9/cc234c1f85548e4fdc49fa1bc00a6fd8aae62b474554445d04d255739e1b/tensordict-0.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:bc2148a21286db6f42cbc308b59d552f90338cb3d0702b3ee4f294c0321146e3", size = 392250 },
+]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 },
+]
+
+[[package]]
+name = "timm"
+version = "1.0.15"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "huggingface-hub" },
+ { name = "pyyaml" },
+ { name = "safetensors" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/bc/0c/66b0f9b4a4cb9ffdac7b52b17b37c7d3c4f75623b469e388b0c6d89b4e88/timm-1.0.15.tar.gz", hash = "sha256:756a3bc30c96565f056e608a9b559daed904617eaadb6be536f96874879b1055", size = 2230258 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6c/d0/179abca8b984b3deefd996f362b612c39da73b60f685921e6cd58b6125b4/timm-1.0.15-py3-none-any.whl", hash = "sha256:5a3dc460c24e322ecc7fd1f3e3eb112423ddee320cb059cc1956fbc9731748ef", size = 2361373 },
+]
+
+[[package]]
+name = "tokenizers"
+version = "0.21.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "huggingface-hub" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/20/41/c2be10975ca37f6ec40d7abd7e98a5213bb04f284b869c1a24e6504fd94d/tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4", size = 343021 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b0/5c/8b09607b37e996dc47e70d6a7b6f4bdd4e4d5ab22fe49d7374565c7fefaf/tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2", size = 2647461 },
+ { url = "https://files.pythonhosted.org/packages/22/7a/88e58bb297c22633ed1c9d16029316e5b5ac5ee44012164c2edede599a5e/tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e", size = 2563639 },
+ { url = "https://files.pythonhosted.org/packages/f7/14/83429177c19364df27d22bc096d4c2e431e0ba43e56c525434f1f9b0fd00/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193", size = 2903304 },
+ { url = "https://files.pythonhosted.org/packages/7e/db/3433eab42347e0dc5452d8fcc8da03f638c9accffefe5a7c78146666964a/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e", size = 2804378 },
+ { url = "https://files.pythonhosted.org/packages/57/8b/7da5e6f89736c2ade02816b4733983fca1c226b0c42980b1ae9dc8fcf5cc/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e", size = 3095488 },
+ { url = "https://files.pythonhosted.org/packages/4d/f6/5ed6711093dc2c04a4e03f6461798b12669bc5a17c8be7cce1240e0b5ce8/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba", size = 3121410 },
+ { url = "https://files.pythonhosted.org/packages/81/42/07600892d48950c5e80505b81411044a2d969368cdc0d929b1c847bf6697/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273", size = 3388821 },
+ { url = "https://files.pythonhosted.org/packages/22/06/69d7ce374747edaf1695a4f61b83570d91cc8bbfc51ccfecf76f56ab4aac/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04", size = 3008868 },
+ { url = "https://files.pythonhosted.org/packages/c8/69/54a0aee4d576045b49a0eb8bffdc495634309c823bf886042e6f46b80058/tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e", size = 8975831 },
+ { url = "https://files.pythonhosted.org/packages/f7/f3/b776061e4f3ebf2905ba1a25d90380aafd10c02d406437a8ba22d1724d76/tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b", size = 8920746 },
+ { url = "https://files.pythonhosted.org/packages/d8/ee/ce83d5ec8b6844ad4c3ecfe3333d58ecc1adc61f0878b323a15355bcab24/tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74", size = 9161814 },
+ { url = "https://files.pythonhosted.org/packages/18/07/3e88e65c0ed28fa93aa0c4d264988428eef3df2764c3126dc83e243cb36f/tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff", size = 9357138 },
+ { url = "https://files.pythonhosted.org/packages/15/b0/dc4572ca61555fc482ebc933f26cb407c6aceb3dc19c301c68184f8cad03/tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a", size = 2202266 },
+ { url = "https://files.pythonhosted.org/packages/44/69/d21eb253fa91622da25585d362a874fa4710be600f0ea9446d8d0217cec1/tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c", size = 2389192 },
+]
+
+[[package]]
+name = "tomli"
+version = "2.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 },
+ { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 },
+ { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 },
+ { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 },
+ { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 },
+ { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 },
+ { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 },
+ { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 },
+ { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 },
+ { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 },
+ { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 },
+ { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 },
+ { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 },
+ { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 },
+ { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 },
+ { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 },
+ { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 },
+ { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 },
+ { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 },
+ { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 },
+ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
+]
+
+[[package]]
+name = "torch"
+version = "2.6.0"
+source = { registry = "https://pypi.org/simple" }
+resolution-markers = [
+ "python_full_version >= '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version == '3.11.*' and sys_platform == 'darwin'",
+ "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version < '3.11' and sys_platform == 'darwin'",
+ "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+]
+dependencies = [
+ { name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "fsspec", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "jinja2", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "networkx", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "sympy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e5/16/ea1b7842413a7b8a5aaa5e99e8eaf3da3183cc3ab345ad025a07ff636301/torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628", size = 66520221 },
+ { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 },
+ { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 },
+]
+
+[[package]]
+name = "torch"
+version = "2.6.0+cu124"
+source = { registry = "https://download.pytorch.org/whl/cu124" }
+resolution-markers = [
+ "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and sys_platform == 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'win32'",
+ "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and sys_platform == 'win32'",
+ "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and sys_platform == 'win32'",
+]
+dependencies = [
+ { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')" },
+ { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+wheels = [
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl", hash = "sha256:7f2ba7f7c0459320a521696f6b5bccc187f59890b23c9dfb6c49b0b87c6bfc97" },
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-win_amd64.whl", hash = "sha256:7cc45c5b39d74875cfafe908b7f55c544147cc16b01e795feb2fe766583efe78" },
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl", hash = "sha256:d4c3e9a8d31a7c0fcbb9da17c31a1917e1fac26c566a4cfbd8c9568ad7cade79" },
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-win_amd64.whl", hash = "sha256:6a1fb2714e9323f11edb6e8abf7aad5f79e45ad25c081cde87681a18d99c29eb" },
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-linux_x86_64.whl", hash = "sha256:a393b506844035c0dac2f30ea8478c343b8e95a429f06f3b3cadfc7f53adb597" },
+ { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-win_amd64.whl", hash = "sha256:3313061c1fec4c7310cf47944e84513dcd27b6173b72a349bb7ca68d0ee6e9c0" },
+]
+
+[[package]]
+name = "torchinfo"
+version = "1.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9", size = 25880 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46", size = 23377 },
+]
+
+[[package]]
+name = "torchmetrics"
+version = "1.6.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "lightning-utilities" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/14/c5/8d916585d4d6eb158105c21b28cd4b0ed296d74e499bf8f104368de16619/torchmetrics-1.6.1.tar.gz", hash = "sha256:a5dc236694b392180949fdd0a0fcf2b57135c8b600e557c725e077eb41e53e64", size = 540022 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9d/e1/84066ff60a20dfa63f4d9d8ddc280d5ed323b7f06504dbb51c523b690116/torchmetrics-1.6.1-py3-none-any.whl", hash = "sha256:c3090aa2341129e994c0a659abb6d4140ae75169a6ebf45bffc16c5cb553b38e", size = 927305 },
+]
+
+[[package]]
+name = "torchtnt"
+version = "0.2.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "fsspec" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "psutil" },
+ { name = "pyre-extensions" },
+ { name = "setuptools" },
+ { name = "tabulate" },
+ { name = "tensorboard" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "tqdm" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3c/b9/c440c3bf39edddf1b72cbf88245a8ca3345caae0cbeabb2601ddcc679be5/torchtnt-0.2.4.tar.gz", hash = "sha256:26cf4e718965afc293e76158b47283722055bcf79e94d0e67e70b5dbf61c0c9b", size = 115176 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b1/20/d8b33d11f38515ef88a84ac32e1809de110440278cd35c30d3f1ca511397/torchtnt-0.2.4-py3-none-any.whl", hash = "sha256:c9b738232090c5e3453f202b1feb45e23d5cd2f23a04ca8b53e9b28ad3755ddb", size = 163459 },
+]
+
+[[package]]
+name = "torchvision"
+version = "0.21.0"
+source = { registry = "https://pypi.org/simple" }
+resolution-markers = [
+ "python_full_version >= '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version == '3.11.*' and sys_platform == 'darwin'",
+ "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "python_full_version < '3.11' and sys_platform == 'darwin'",
+ "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'",
+]
+dependencies = [
+ { name = "numpy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "pillow", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8e/0d/143bd264876fad17c82096b6c2d433f1ac9b29cdc69ee45023096976ee3d/torchvision-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:044ea420b8c6c3162a234cada8e2025b9076fa82504758cd11ec5d0f8cd9fa37", size = 1784140 },
+ { url = "https://files.pythonhosted.org/packages/29/88/00c69db213ee2443ada8886ec60789b227e06bb869d85ee324578221a7f7/torchvision-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:110d115333524d60e9e474d53c7d20f096dbd8a080232f88dddb90566f90064c", size = 1784141 },
+ { url = "https://files.pythonhosted.org/packages/6e/1b/28f527b22d5e8800184d0bc847f801ae92c7573a8c15979d92b7091c0751/torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:97a5814a93c793aaf0179cfc7f916024f4b63218929aee977b645633d074a49f", size = 1784140 },
+]
+
+[[package]]
+name = "torchvision"
+version = "0.21.0+cu124"
+source = { registry = "https://download.pytorch.org/whl/cu124" }
+resolution-markers = [
+ "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version >= '3.12.4' and sys_platform == 'win32'",
+ "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'win32'",
+ "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version == '3.11.*' and sys_platform == 'win32'",
+ "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux'",
+ "python_full_version < '3.11' and sys_platform == 'win32'",
+]
+dependencies = [
+ { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "pillow", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+wheels = [
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp310-cp310-linux_x86_64.whl", hash = "sha256:3d3e74018eaa7837c73e3764dad3b7792b7544401c25a42977e9744303731bd3" },
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp310-cp310-win_amd64.whl", hash = "sha256:0c6aefb70ab2b312065240c804e459ac7b0e449867afd469b38d2fd47f9391a7" },
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-linux_x86_64.whl", hash = "sha256:137376805aca5ba57bd2c7a3ecb8569df961dbe82b128aac9b3b0a7125ef9385" },
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-win_amd64.whl", hash = "sha256:000a013584ad2304ab30496318145f284ac364622addb5ee3a5abd2769ba146f" },
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp312-cp312-linux_x86_64.whl", hash = "sha256:efb53ea0af7bf09b7b53e2a18b9be6d245f7d46a90b51d5cf97f37e9b929a991" },
+ { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp312-cp312-win_amd64.whl", hash = "sha256:ec63c2ee792757492da40590e34b14f2fceda29050558c215f0c1f3b08149c0f" },
+]
+
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
+]
+
+[[package]]
+name = "traitlets"
+version = "5.14.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 },
+]
+
+[[package]]
+name = "transformers"
+version = "4.49.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "filelock" },
+ { name = "huggingface-hub" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "pyyaml" },
+ { name = "regex" },
+ { name = "requests" },
+ { name = "safetensors" },
+ { name = "tokenizers" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/79/50/46573150944f46df8ec968eda854023165a84470b42f69f67c7d475dabc5/transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e", size = 8610952 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/20/37/1f29af63e9c30156a3ed6ebc2754077016577c094f31de7b2631e5d379eb/transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03", size = 9970275 },
+]
+
+[[package]]
+name = "triton"
+version = "3.2.0"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/01/65/3ffa90e158a2c82f0716eee8d26a725d241549b7d7aaf7e4f44ac03ebd89/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62", size = 253090354 },
+ { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 },
+ { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 },
+]
+
+[[package]]
+name = "typer"
+version = "0.15.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "rich" },
+ { name = "shellingham" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/dca7b219718afd37a0068f4f2530a727c2b74a8b6e8e0c0080a4c0de4fcd/typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a", size = 99789 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d0/cc/0a838ba5ca64dc832aa43f727bd586309846b0ffb2ce52422543e6075e8a/typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847", size = 44908 },
+]
+
+[[package]]
+name = "typing-extensions"
+version = "4.12.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 },
+]
+
+[[package]]
+name = "typing-inspect"
+version = "0.9.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mypy-extensions" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827 },
+]
+
+[[package]]
+name = "tzdata"
+version = "2025.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/43/0f/fa4723f22942480be4ca9527bbde8d43f6c3f2fe8412f00e7f5f6746bc8b/tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694", size = 194950 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0f/dd/84f10e23edd882c6f968c21c2434fe67bd4a528967067515feca9e611e5e/tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639", size = 346762 },
+]
+
+[[package]]
+name = "unidisc"
+version = "0.0.1"
+source = { virtual = "." }
+dependencies = [
+ { name = "accelerate" },
+ { name = "datasets" },
+ { name = "diffusers" },
+ { name = "einops" },
+ { name = "fsspec" },
+ { name = "hf-transfer" },
+ { name = "hydra-core" },
+ { name = "image-utilities" },
+ { name = "ipdb" },
+ { name = "ipython" },
+ { name = "jaxtyping" },
+ { name = "lightning-utilities" },
+ { name = "lovely-tensors" },
+ { name = "ml-collections" },
+ { name = "numpy" },
+ { name = "omegaconf" },
+ { name = "pandas" },
+ { name = "rich" },
+ { name = "scikit-learn" },
+ { name = "sentencepiece" },
+ { name = "setuptools" },
+ { name = "tensordict" },
+ { name = "timm" },
+ { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "torchinfo" },
+ { name = "torchmetrics" },
+ { name = "torchtnt" },
+ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+ { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+ { name = "transformers" },
+ { name = "typer" },
+ { name = "wandb" },
+]
+
+[package.dev-dependencies]
+dev = [
+ { name = "braceexpand" },
+ { name = "clean-fid" },
+ { name = "clip" },
+ { name = "evaluate" },
+ { name = "fastapi" },
+ { name = "flash-attn" },
+ { name = "h5py" },
+ { name = "hpsv2x" },
+ { name = "mauve-text" },
+ { name = "monsterui" },
+ { name = "open-clip-torch" },
+ { name = "peft" },
+ { name = "pynvml" },
+ { name = "python-fasthtml" },
+ { name = "t2ibenchmark" },
+]
+misc = [
+ { name = "bitsandbytes" },
+ { name = "deepspeed" },
+ { name = "flask" },
+ { name = "ftfy" },
+ { name = "lpips" },
+ { name = "opencv-python" },
+ { name = "requests" },
+ { name = "sentence-transformers" },
+ { name = "simple-slurm" },
+ { name = "werkzeug" },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "accelerate", specifier = "~=1.5.2" },
+ { name = "datasets", specifier = "~=3.2.0" },
+ { name = "diffusers", specifier = "~=0.32.2" },
+ { name = "einops", specifier = "~=0.8.0" },
+ { name = "fsspec" },
+ { name = "hf-transfer" },
+ { name = "hydra-core", specifier = "~=1.3.2" },
+ { name = "image-utilities", specifier = "==0.0.3.*" },
+ { name = "ipdb" },
+ { name = "ipython" },
+ { name = "jaxtyping", specifier = "~=0.2.37" },
+ { name = "lightning-utilities", specifier = "~=0.12.0" },
+ { name = "lovely-tensors" },
+ { name = "ml-collections" },
+ { name = "numpy", specifier = "~=2.2" },
+ { name = "omegaconf", specifier = "~=2.3.0" },
+ { name = "pandas" },
+ { name = "rich", specifier = "~=13.9.4" },
+ { name = "scikit-learn" },
+ { name = "sentencepiece" },
+ { name = "setuptools", specifier = ">=75.8.0" },
+ { name = "tensordict", specifier = "~=0.7.2" },
+ { name = "timm", specifier = "~=1.0.15" },
+ { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.6.0" },
+ { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cu124" },
+ { name = "torchinfo" },
+ { name = "torchmetrics", specifier = "==1.6.1" },
+ { name = "torchtnt", specifier = "~=0.2.4" },
+ { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.21.0" },
+ { name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cu124" },
+ { name = "transformers", specifier = "~=4.49.0" },
+ { name = "typer", specifier = "~=0.15.1" },
+ { name = "wandb", specifier = "~=0.19.6" },
+]
+
+[package.metadata.requires-dev]
+dev = [
+ { name = "braceexpand" },
+ { name = "clean-fid" },
+ { name = "clip", git = "ssh://git@github.com/openai/CLIP.git?rev=dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1" },
+ { name = "evaluate" },
+ { name = "fastapi", specifier = "~=0.115.8" },
+ { name = "flash-attn", specifier = "~=2.7.4" },
+ { name = "h5py" },
+ { name = "hpsv2x", specifier = "==1.2.0" },
+ { name = "mauve-text" },
+ { name = "monsterui", specifier = "~=0.0.34" },
+ { name = "open-clip-torch" },
+ { name = "peft" },
+ { name = "pynvml" },
+ { name = "python-fasthtml", specifier = "~=0.12.1" },
+ { name = "t2ibenchmark", git = "ssh://git@github.com/boomb0om/text2image-benchmark.git?rev=532229f679d7e97ecba61914db7276f95733e707" },
+]
+misc = [
+ { name = "bitsandbytes" },
+ { name = "deepspeed" },
+ { name = "flask" },
+ { name = "ftfy" },
+ { name = "lpips" },
+ { name = "opencv-python" },
+ { name = "requests" },
+ { name = "sentence-transformers" },
+ { name = "simple-slurm" },
+ { name = "werkzeug" },
+]
+
+[[package]]
+name = "urllib3"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
+]
+
+[[package]]
+name = "uvicorn"
+version = "0.34.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "h11" },
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 },
+]
+
+[package.optional-dependencies]
+standard = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "httptools" },
+ { name = "python-dotenv" },
+ { name = "pyyaml" },
+ { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'" },
+ { name = "watchfiles" },
+ { name = "websockets" },
+]
+
+[[package]]
+name = "uvloop"
+version = "0.21.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/af/c0/854216d09d33c543f12a44b393c402e89a920b1a0a7dc634c42de91b9cf6/uvloop-0.21.0.tar.gz", hash = "sha256:3bf12b0fda68447806a7ad847bfa591613177275d35b6724b1ee573faa3704e3", size = 2492741 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3d/76/44a55515e8c9505aa1420aebacf4dd82552e5e15691654894e90d0bd051a/uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f", size = 1442019 },
+ { url = "https://files.pythonhosted.org/packages/35/5a/62d5800358a78cc25c8a6c72ef8b10851bdb8cca22e14d9c74167b7f86da/uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d", size = 801898 },
+ { url = "https://files.pythonhosted.org/packages/f3/96/63695e0ebd7da6c741ccd4489b5947394435e198a1382349c17b1146bb97/uvloop-0.21.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f38b2e090258d051d68a5b14d1da7203a3c3677321cf32a95a6f4db4dd8b6f26", size = 3827735 },
+ { url = "https://files.pythonhosted.org/packages/61/e0/f0f8ec84979068ffae132c58c79af1de9cceeb664076beea86d941af1a30/uvloop-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c43e0f13022b998eb9b973b5e97200c8b90823454d4bc06ab33829e09fb9bb", size = 3825126 },
+ { url = "https://files.pythonhosted.org/packages/bf/fe/5e94a977d058a54a19df95f12f7161ab6e323ad49f4dabc28822eb2df7ea/uvloop-0.21.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:10d66943def5fcb6e7b37310eb6b5639fd2ccbc38df1177262b0640c3ca68c1f", size = 3705789 },
+ { url = "https://files.pythonhosted.org/packages/26/dd/c7179618e46092a77e036650c1f056041a028a35c4d76945089fcfc38af8/uvloop-0.21.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:67dd654b8ca23aed0a8e99010b4c34aca62f4b7fce88f39d452ed7622c94845c", size = 3800523 },
+ { url = "https://files.pythonhosted.org/packages/57/a7/4cf0334105c1160dd6819f3297f8700fda7fc30ab4f61fbf3e725acbc7cc/uvloop-0.21.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c0f3fa6200b3108919f8bdabb9a7f87f20e7097ea3c543754cabc7d717d95cf8", size = 1447410 },
+ { url = "https://files.pythonhosted.org/packages/8c/7c/1517b0bbc2dbe784b563d6ab54f2ef88c890fdad77232c98ed490aa07132/uvloop-0.21.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0878c2640cf341b269b7e128b1a5fed890adc4455513ca710d77d5e93aa6d6a0", size = 805476 },
+ { url = "https://files.pythonhosted.org/packages/ee/ea/0bfae1aceb82a503f358d8d2fa126ca9dbdb2ba9c7866974faec1cb5875c/uvloop-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9fb766bb57b7388745d8bcc53a359b116b8a04c83a2288069809d2b3466c37e", size = 3960855 },
+ { url = "https://files.pythonhosted.org/packages/8a/ca/0864176a649838b838f36d44bf31c451597ab363b60dc9e09c9630619d41/uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a375441696e2eda1c43c44ccb66e04d61ceeffcd76e4929e527b7fa401b90fb", size = 3973185 },
+ { url = "https://files.pythonhosted.org/packages/30/bf/08ad29979a936d63787ba47a540de2132169f140d54aa25bc8c3df3e67f4/uvloop-0.21.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:baa0e6291d91649c6ba4ed4b2f982f9fa165b5bbd50a9e203c416a2797bab3c6", size = 3820256 },
+ { url = "https://files.pythonhosted.org/packages/da/e2/5cf6ef37e3daf2f06e651aae5ea108ad30df3cb269102678b61ebf1fdf42/uvloop-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4509360fcc4c3bd2c70d87573ad472de40c13387f5fda8cb58350a1d7475e58d", size = 3937323 },
+ { url = "https://files.pythonhosted.org/packages/8c/4c/03f93178830dc7ce8b4cdee1d36770d2f5ebb6f3d37d354e061eefc73545/uvloop-0.21.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:359ec2c888397b9e592a889c4d72ba3d6befba8b2bb01743f72fffbde663b59c", size = 1471284 },
+ { url = "https://files.pythonhosted.org/packages/43/3e/92c03f4d05e50f09251bd8b2b2b584a2a7f8fe600008bcc4523337abe676/uvloop-0.21.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7089d2dc73179ce5ac255bdf37c236a9f914b264825fdaacaded6990a7fb4c2", size = 821349 },
+ { url = "https://files.pythonhosted.org/packages/a6/ef/a02ec5da49909dbbfb1fd205a9a1ac4e88ea92dcae885e7c961847cd51e2/uvloop-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa4dcdbd9ae0a372f2167a207cd98c9f9a1ea1188a8a526431eef2f8116cc8d", size = 4580089 },
+ { url = "https://files.pythonhosted.org/packages/06/a7/b4e6a19925c900be9f98bec0a75e6e8f79bb53bdeb891916609ab3958967/uvloop-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86975dca1c773a2c9864f4c52c5a55631038e387b47eaf56210f873887b6c8dc", size = 4693770 },
+ { url = "https://files.pythonhosted.org/packages/ce/0c/f07435a18a4b94ce6bd0677d8319cd3de61f3a9eeb1e5f8ab4e8b5edfcb3/uvloop-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:461d9ae6660fbbafedd07559c6a2e57cd553b34b0065b6550685f6653a98c1cb", size = 4451321 },
+ { url = "https://files.pythonhosted.org/packages/8f/eb/f7032be105877bcf924709c97b1bf3b90255b4ec251f9340cef912559f28/uvloop-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:183aef7c8730e54c9a3ee3227464daed66e37ba13040bb3f350bc2ddc040f22f", size = 4659022 },
+]
+
+[[package]]
+name = "wadler-lindig"
+version = "0.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7a/a2/e8fc843e14aec55d18572a93c5443cc89a6b7d537b90d804f7b373301d9f/wadler_lindig-0.1.3.tar.gz", hash = "sha256:476fb7015135f714cef8f8eac7c44b164c8b993345e651a9b6f25b7b112440c9", size = 15197 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/39/3b/5b918a0da0d6920e7f7328cf0ab00df31b905d709f458596304f09096785/wadler_lindig-0.1.3-py3-none-any.whl", hash = "sha256:3018e4e6b115a7ef21c77414a41cbe7e03e83f6b5e25004958e33432a17f3c94", size = 20140 },
+]
+
+[[package]]
+name = "wandb"
+version = "0.19.6"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "docker-pycreds" },
+ { name = "gitpython" },
+ { name = "platformdirs" },
+ { name = "protobuf" },
+ { name = "psutil" },
+ { name = "pydantic" },
+ { name = "pyyaml" },
+ { name = "requests" },
+ { name = "sentry-sdk" },
+ { name = "setproctitle" },
+ { name = "setuptools" },
+ { name = "typing-extensions", marker = "python_full_version < '3.12'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/41/a2/63fbebc6ed670a7d834ca76552b8c6382211874b23ee8a718ba26a342a4a/wandb-0.19.6.tar.gz", hash = "sha256:4661856ee070fe8a123caece5b372d495d3cf9f58176a8f981bd716830eefc49", size = 39203528 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/bd/4f/5b77e20f10e643404df871557610a6618383e036de65e9c34b3a8354f2ac/wandb-0.19.6-py3-none-any.whl", hash = "sha256:0b174b5f190999a8238961c63c134622bf2173147a1301ea298a9ec58abbd7d4", size = 6387720 },
+ { url = "https://files.pythonhosted.org/packages/25/aa/824a171586f3fa1549f9f946d32187362c8d06ff67540d9f1be694ee9094/wandb-0.19.6-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:ad2887dd916207ead5a9f36e4aebc1b6624265f29033e4e883bb6fbd5b674080", size = 20776552 },
+ { url = "https://files.pythonhosted.org/packages/ad/3b/222e2a27ee3df3a973d8f165fa47f3e3bb25dc6d9ac1d3ec79b083c5ee09/wandb-0.19.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ca90dd5519de1a48963536f02d6e14c150475807173b7af1d8ebe3e2f9e3afba", size = 19933524 },
+ { url = "https://files.pythonhosted.org/packages/65/76/1d69145ac3c9c6b63545e684c39b95711c3632c34d452626fd831227089d/wandb-0.19.6-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:3cb10bd1e1c0b568464a017c88eb95e0c8c3e9c1283d9ad4ee717c8977d491c1", size = 20791479 },
+ { url = "https://files.pythonhosted.org/packages/88/96/4411c4aa29cfb0bc8e310480181d79779b423231420bbcf5e61ff8c44ff7/wandb-0.19.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fe6e7bedd396b2b5f92c7fab3d364f7e0e8cb9f645d0f0c27ba7be94e720931", size = 19539263 },
+ { url = "https://files.pythonhosted.org/packages/bc/89/2e414951d35e55caf6d8ac5758a82c61c1b8330f77852fbc733c833196eb/wandb-0.19.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd9ae9a7f08e4d3972ba341c42af787e951689e0d1a76c111aa66d09bcdadafd", size = 20861187 },
+ { url = "https://files.pythonhosted.org/packages/3a/5e/7517c9fa9aa0075160c04e467f6d0e5d1b9bb6b91c4ffd6dd6fa23dd3dd0/wandb-0.19.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ff0973ca26cd06bc5451ae7ba469ad98f74024f5678dfa0d6dc78ca36eb950b6", size = 19549095 },
+ { url = "https://files.pythonhosted.org/packages/bd/be/ef3c78ab14a631558f639ab3a8379efee6f7d529e3bbf9efb0e17472495b/wandb-0.19.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2e8dc997eb3ae5f22f5a1c3d4f3b30c28398dda45b9dbada9ff20b8d3984d3e2", size = 20938943 },
+ { url = "https://files.pythonhosted.org/packages/b6/43/2f9c71a1fe77a97e9d32b4828f1dd685ac545442f8dfbf703eac8128056f/wandb-0.19.6-py3-none-win32.whl", hash = "sha256:c0127d99e98202dc2471d44b920129c2c9242fb3a6b52a7aa8bbf9ffa35173e7", size = 20230403 },
+ { url = "https://files.pythonhosted.org/packages/fd/b2/a9ffa91c43dbe2a6687467f3aa196947b7532592879738665be5c0db17c3/wandb-0.19.6-py3-none-win_amd64.whl", hash = "sha256:8688a4f724d37a90075312e8dccffd948adbe8b6bcb82f9d2b38b764b53269fb", size = 20230407 },
+]
+
+[[package]]
+name = "watchfiles"
+version = "1.0.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "anyio" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f5/26/c705fc77d0a9ecdb9b66f1e2976d95b81df3cae518967431e7dbf9b5e219/watchfiles-1.0.4.tar.gz", hash = "sha256:6ba473efd11062d73e4f00c2b730255f9c1bdd73cd5f9fe5b5da8dbd4a717205", size = 94625 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/14/02/22fcaed0396730b0d362bc8d1ffb3be2658fd473eecbb2ba84243e157f11/watchfiles-1.0.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ba5bb3073d9db37c64520681dd2650f8bd40902d991e7b4cfaeece3e32561d08", size = 395212 },
+ { url = "https://files.pythonhosted.org/packages/e9/3d/ec5a2369a46edf3ebe092c39d9ae48e8cb6dacbde51c4b4f98936c524269/watchfiles-1.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f25d0ba0fe2b6d2c921cf587b2bf4c451860086534f40c384329fb96e2044d1", size = 384815 },
+ { url = "https://files.pythonhosted.org/packages/df/b4/898991cececbe171e67142c31905510203649569d9817848f47c4177ee42/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47eb32ef8c729dbc4f4273baece89398a4d4b5d21a1493efea77a17059f4df8a", size = 450680 },
+ { url = "https://files.pythonhosted.org/packages/58/f7/d4aa3000e812cfb5e5c2c6c0a3ec9d0a46a42489a8727edd160631c4e210/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:076f293100db3b0b634514aa0d294b941daa85fc777f9c698adb1009e5aca0b1", size = 455923 },
+ { url = "https://files.pythonhosted.org/packages/dd/95/7e2e4c6aba1b02fb5c76d2f6a450b85215921ec5f8f7ad5efd075369563f/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1eacd91daeb5158c598fe22d7ce66d60878b6294a86477a4715154990394c9b3", size = 482339 },
+ { url = "https://files.pythonhosted.org/packages/bb/67/4265b0fabcc2ef2c9e3e8802ba7908cf718a357ebfb49c72e53787156a48/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13c2ce7b72026cfbca120d652f02c7750f33b4c9395d79c9790b27f014c8a5a2", size = 519908 },
+ { url = "https://files.pythonhosted.org/packages/0d/96/b57802d5f8164bdf070befb4fd3dec4edba5a364ec0670965a97eb8098ce/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:90192cdc15ab7254caa7765a98132a5a41471cf739513cc9bcf7d2ffcc0ec7b2", size = 501410 },
+ { url = "https://files.pythonhosted.org/packages/8b/18/6db0de4e8911ba14e31853201b40c0fa9fea5ecf3feb86b0ad58f006dfc3/watchfiles-1.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:278aaa395f405972e9f523bd786ed59dfb61e4b827856be46a42130605fd0899", size = 452876 },
+ { url = "https://files.pythonhosted.org/packages/df/df/092a961815edf723a38ba2638c49491365943919c3526cc9cf82c42786a6/watchfiles-1.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a462490e75e466edbb9fc4cd679b62187153b3ba804868452ef0577ec958f5ff", size = 615353 },
+ { url = "https://files.pythonhosted.org/packages/f3/cf/b85fe645de4ff82f3f436c5e9032379fce37c303f6396a18f9726cc34519/watchfiles-1.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8d0d0630930f5cd5af929040e0778cf676a46775753e442a3f60511f2409f48f", size = 613187 },
+ { url = "https://files.pythonhosted.org/packages/f6/d4/a9fea27aef4dd69689bc3556718c1157a7accb72aa035ece87c1fa8483b5/watchfiles-1.0.4-cp310-cp310-win32.whl", hash = "sha256:cc27a65069bcabac4552f34fd2dce923ce3fcde0721a16e4fb1b466d63ec831f", size = 270799 },
+ { url = "https://files.pythonhosted.org/packages/df/02/dbe9d4439f15dd4ad0720b6e039bde9d66d1f830331f34c18eb70fa6608e/watchfiles-1.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:8b1f135238e75d075359cf506b27bf3f4ca12029c47d3e769d8593a2024ce161", size = 284145 },
+ { url = "https://files.pythonhosted.org/packages/0f/bb/8461adc4b1fed009546fb797fc0d5698dcfe5e289cb37e1b8f16a93cdc30/watchfiles-1.0.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2a9f93f8439639dc244c4d2902abe35b0279102bca7bbcf119af964f51d53c19", size = 394869 },
+ { url = "https://files.pythonhosted.org/packages/55/88/9ebf36b3547176d1709c320de78c1fa3263a46be31b5b1267571d9102686/watchfiles-1.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9eea33ad8c418847dd296e61eb683cae1c63329b6d854aefcd412e12d94ee235", size = 384905 },
+ { url = "https://files.pythonhosted.org/packages/03/8a/04335ce23ef78d8c69f0913e8b20cf7d9233e3986543aeef95ef2d6e43d2/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31f1a379c9dcbb3f09cf6be1b7e83b67c0e9faabed0471556d9438a4a4e14202", size = 449944 },
+ { url = "https://files.pythonhosted.org/packages/17/4e/c8d5dcd14fe637f4633616dabea8a4af0a10142dccf3b43e0f081ba81ab4/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab594e75644421ae0a2484554832ca5895f8cab5ab62de30a1a57db460ce06c6", size = 456020 },
+ { url = "https://files.pythonhosted.org/packages/5e/74/3e91e09e1861dd7fbb1190ce7bd786700dc0fbc2ccd33bb9fff5de039229/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc2eb5d14a8e0d5df7b36288979176fbb39672d45184fc4b1c004d7c3ce29317", size = 482983 },
+ { url = "https://files.pythonhosted.org/packages/a1/3d/e64de2d1ce4eb6a574fd78ce3a28c279da263be9ef3cfcab6f708df192f2/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3f68d8e9d5a321163ddacebe97091000955a1b74cd43724e346056030b0bacee", size = 520320 },
+ { url = "https://files.pythonhosted.org/packages/2c/bd/52235f7063b57240c66a991696ed27e2a18bd6fcec8a1ea5a040b70d0611/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9ce064e81fe79faa925ff03b9f4c1a98b0bbb4a1b8c1b015afa93030cb21a49", size = 500988 },
+ { url = "https://files.pythonhosted.org/packages/3a/b0/ff04194141a5fe650c150400dd9e42667916bc0f52426e2e174d779b8a74/watchfiles-1.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b77d5622ac5cc91d21ae9c2b284b5d5c51085a0bdb7b518dba263d0af006132c", size = 452573 },
+ { url = "https://files.pythonhosted.org/packages/3d/9d/966164332c5a178444ae6d165082d4f351bd56afd9c3ec828eecbf190e6a/watchfiles-1.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1941b4e39de9b38b868a69b911df5e89dc43767feeda667b40ae032522b9b5f1", size = 615114 },
+ { url = "https://files.pythonhosted.org/packages/94/df/f569ae4c1877f96ad4086c153a8eee5a19a3b519487bf5c9454a3438c341/watchfiles-1.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f8c4998506241dedf59613082d1c18b836e26ef2a4caecad0ec41e2a15e4226", size = 613076 },
+ { url = "https://files.pythonhosted.org/packages/15/ae/8ce5f29e65d5fa5790e3c80c289819c55e12be2e1b9f5b6a0e55e169b97d/watchfiles-1.0.4-cp311-cp311-win32.whl", hash = "sha256:4ebbeca9360c830766b9f0df3640b791be569d988f4be6c06d6fae41f187f105", size = 271013 },
+ { url = "https://files.pythonhosted.org/packages/a4/c6/79dc4a7c598a978e5fafa135090aaf7bbb03b8dec7bada437dfbe578e7ed/watchfiles-1.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:05d341c71f3d7098920f8551d4df47f7b57ac5b8dad56558064c3431bdfc0b74", size = 284229 },
+ { url = "https://files.pythonhosted.org/packages/37/3d/928633723211753f3500bfb138434f080363b87a1b08ca188b1ce54d1e05/watchfiles-1.0.4-cp311-cp311-win_arm64.whl", hash = "sha256:32b026a6ab64245b584acf4931fe21842374da82372d5c039cba6bf99ef722f3", size = 276824 },
+ { url = "https://files.pythonhosted.org/packages/5b/1a/8f4d9a1461709756ace48c98f07772bc6d4519b1e48b5fa24a4061216256/watchfiles-1.0.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:229e6ec880eca20e0ba2f7e2249c85bae1999d330161f45c78d160832e026ee2", size = 391345 },
+ { url = "https://files.pythonhosted.org/packages/bc/d2/6750b7b3527b1cdaa33731438432e7238a6c6c40a9924049e4cebfa40805/watchfiles-1.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5717021b199e8353782dce03bd8a8f64438832b84e2885c4a645f9723bf656d9", size = 381515 },
+ { url = "https://files.pythonhosted.org/packages/4e/17/80500e42363deef1e4b4818729ed939aaddc56f82f4e72b2508729dd3c6b/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0799ae68dfa95136dde7c472525700bd48777875a4abb2ee454e3ab18e9fc712", size = 449767 },
+ { url = "https://files.pythonhosted.org/packages/10/37/1427fa4cfa09adbe04b1e97bced19a29a3462cc64c78630787b613a23f18/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:43b168bba889886b62edb0397cab5b6490ffb656ee2fcb22dec8bfeb371a9e12", size = 455677 },
+ { url = "https://files.pythonhosted.org/packages/c5/7a/39e9397f3a19cb549a7d380412fd9e507d4854eddc0700bfad10ef6d4dba/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb2c46e275fbb9f0c92e7654b231543c7bbfa1df07cdc4b99fa73bedfde5c844", size = 482219 },
+ { url = "https://files.pythonhosted.org/packages/45/2d/7113931a77e2ea4436cad0c1690c09a40a7f31d366f79c6f0a5bc7a4f6d5/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:857f5fc3aa027ff5e57047da93f96e908a35fe602d24f5e5d8ce64bf1f2fc733", size = 518830 },
+ { url = "https://files.pythonhosted.org/packages/f9/1b/50733b1980fa81ef3c70388a546481ae5fa4c2080040100cd7bf3bf7b321/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55ccfd27c497b228581e2838d4386301227fc0cb47f5a12923ec2fe4f97b95af", size = 497997 },
+ { url = "https://files.pythonhosted.org/packages/2b/b4/9396cc61b948ef18943e7c85ecfa64cf940c88977d882da57147f62b34b1/watchfiles-1.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c11ea22304d17d4385067588123658e9f23159225a27b983f343fcffc3e796a", size = 452249 },
+ { url = "https://files.pythonhosted.org/packages/fb/69/0c65a5a29e057ad0dc691c2fa6c23b2983c7dabaa190ba553b29ac84c3cc/watchfiles-1.0.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:74cb3ca19a740be4caa18f238298b9d472c850f7b2ed89f396c00a4c97e2d9ff", size = 614412 },
+ { url = "https://files.pythonhosted.org/packages/7f/b9/319fcba6eba5fad34327d7ce16a6b163b39741016b1996f4a3c96b8dd0e1/watchfiles-1.0.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c7cce76c138a91e720d1df54014a047e680b652336e1b73b8e3ff3158e05061e", size = 611982 },
+ { url = "https://files.pythonhosted.org/packages/f1/47/143c92418e30cb9348a4387bfa149c8e0e404a7c5b0585d46d2f7031b4b9/watchfiles-1.0.4-cp312-cp312-win32.whl", hash = "sha256:b045c800d55bc7e2cadd47f45a97c7b29f70f08a7c2fa13241905010a5493f94", size = 271822 },
+ { url = "https://files.pythonhosted.org/packages/ea/94/b0165481bff99a64b29e46e07ac2e0df9f7a957ef13bec4ceab8515f44e3/watchfiles-1.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:c2acfa49dd0ad0bf2a9c0bb9a985af02e89345a7189be1efc6baa085e0f72d7c", size = 285441 },
+ { url = "https://files.pythonhosted.org/packages/11/de/09fe56317d582742d7ca8c2ca7b52a85927ebb50678d9b0fa8194658f536/watchfiles-1.0.4-cp312-cp312-win_arm64.whl", hash = "sha256:22bb55a7c9e564e763ea06c7acea24fc5d2ee5dfc5dafc5cfbedfe58505e9f90", size = 277141 },
+ { url = "https://files.pythonhosted.org/packages/6f/06/175d5ac6b838fb319008c0cd981d7bf289317c510154d411d3584ca2b67b/watchfiles-1.0.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cdcc92daeae268de1acf5b7befcd6cfffd9a047098199056c72e4623f531de18", size = 396269 },
+ { url = "https://files.pythonhosted.org/packages/86/ee/5db93b0b57dc0587abdbac4149296ee73275f615d790a82cb5598af0557f/watchfiles-1.0.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d8d3d9203705b5797f0af7e7e5baa17c8588030aaadb7f6a86107b7247303817", size = 386010 },
+ { url = "https://files.pythonhosted.org/packages/75/61/fe0dc5fedf152bfc085a53711f740701f6bdb8ab6b5c950402b681d4858b/watchfiles-1.0.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdef5a1be32d0b07dcea3318a0be95d42c98ece24177820226b56276e06b63b0", size = 450913 },
+ { url = "https://files.pythonhosted.org/packages/9f/dd/3c7731af3baf1a9957afc643d176f94480921a690ec3237c9f9d11301c08/watchfiles-1.0.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:342622287b5604ddf0ed2d085f3a589099c9ae8b7331df3ae9845571586c4f3d", size = 453474 },
+]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
+]
+
+[[package]]
+name = "webdataset"
+version = "0.2.100"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "braceexpand" },
+ { name = "numpy" },
+ { name = "pyyaml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/df/5c/0334e75215a215d9d714f7f80b62306d3c1498c6170c0b8d9d9d980089f1/webdataset-0.2.100.tar.gz", hash = "sha256:798e30ff700277f0b963dc0395f3b9de4971a67cffc7cb6d0cb9225df7b68e42", size = 85368 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8e/84/cf2319c375f4e061f27354685295905dc81105d2a2d2239baaf6f6e73c87/webdataset-0.2.100-py3-none-any.whl", hash = "sha256:f70a8e1f6d4f5268b364bd6f77fe8a1168ea14e7e9ed455d71f8d29585fd86af", size = 74796 },
+]
+
+[[package]]
+name = "websockets"
+version = "14.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/54/8359678c726243d19fae38ca14a334e740782336c9f19700858c4eb64a1e/websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5", size = 164394 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/28/fa/76607eb7dcec27b2d18d63f60a32e60e2b8629780f343bb83a4dbb9f4350/websockets-14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e8179f95323b9ab1c11723e5d91a89403903f7b001828161b480a7810b334885", size = 163089 },
+ { url = "https://files.pythonhosted.org/packages/9e/00/ad2246b5030575b79e7af0721810fdaecaf94c4b2625842ef7a756fa06dd/websockets-14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d8c3e2cdb38f31d8bd7d9d28908005f6fa9def3324edb9bf336d7e4266fd397", size = 160741 },
+ { url = "https://files.pythonhosted.org/packages/72/f7/60f10924d333a28a1ff3fcdec85acf226281331bdabe9ad74947e1b7fc0a/websockets-14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:714a9b682deb4339d39ffa674f7b674230227d981a37d5d174a4a83e3978a610", size = 160996 },
+ { url = "https://files.pythonhosted.org/packages/63/7c/c655789cf78648c01ac6ecbe2d6c18f91b75bdc263ffee4d08ce628d12f0/websockets-14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e53c72052f2596fb792a7acd9704cbc549bf70fcde8a99e899311455974ca3", size = 169974 },
+ { url = "https://files.pythonhosted.org/packages/fb/5b/013ed8b4611857ac92ac631079c08d9715b388bd1d88ec62e245f87a39df/websockets-14.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3fbd68850c837e57373d95c8fe352203a512b6e49eaae4c2f4088ef8cf21980", size = 168985 },
+ { url = "https://files.pythonhosted.org/packages/cd/33/aa3e32fd0df213a5a442310754fe3f89dd87a0b8e5b4e11e0991dd3bcc50/websockets-14.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b27ece32f63150c268593d5fdb82819584831a83a3f5809b7521df0685cd5d8", size = 169297 },
+ { url = "https://files.pythonhosted.org/packages/93/17/dae0174883d6399f57853ac44abf5f228eaba86d98d160f390ffabc19b6e/websockets-14.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4daa0faea5424d8713142b33825fff03c736f781690d90652d2c8b053345b0e7", size = 169677 },
+ { url = "https://files.pythonhosted.org/packages/42/e2/0375af7ac00169b98647c804651c515054b34977b6c1354f1458e4116c1e/websockets-14.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bc63cee8596a6ec84d9753fd0fcfa0452ee12f317afe4beae6b157f0070c6c7f", size = 169089 },
+ { url = "https://files.pythonhosted.org/packages/73/8d/80f71d2a351a44b602859af65261d3dde3a0ce4e76cf9383738a949e0cc3/websockets-14.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a570862c325af2111343cc9b0257b7119b904823c675b22d4ac547163088d0d", size = 169026 },
+ { url = "https://files.pythonhosted.org/packages/48/97/173b1fa6052223e52bb4054a141433ad74931d94c575e04b654200b98ca4/websockets-14.2-cp310-cp310-win32.whl", hash = "sha256:75862126b3d2d505e895893e3deac0a9339ce750bd27b4ba515f008b5acf832d", size = 163967 },
+ { url = "https://files.pythonhosted.org/packages/c0/5b/2fcf60f38252a4562b28b66077e0d2b48f91fef645d5f78874cd1dec807b/websockets-14.2-cp310-cp310-win_amd64.whl", hash = "sha256:cc45afb9c9b2dc0852d5c8b5321759cf825f82a31bfaf506b65bf4668c96f8b2", size = 164413 },
+ { url = "https://files.pythonhosted.org/packages/15/b6/504695fb9a33df0ca56d157f5985660b5fc5b4bf8c78f121578d2d653392/websockets-14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3bdc8c692c866ce5fefcaf07d2b55c91d6922ac397e031ef9b774e5b9ea42166", size = 163088 },
+ { url = "https://files.pythonhosted.org/packages/81/26/ebfb8f6abe963c795122439c6433c4ae1e061aaedfc7eff32d09394afbae/websockets-14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c93215fac5dadc63e51bcc6dceca72e72267c11def401d6668622b47675b097f", size = 160745 },
+ { url = "https://files.pythonhosted.org/packages/a1/c6/1435ad6f6dcbff80bb95e8986704c3174da8866ddb751184046f5c139ef6/websockets-14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c9b6535c0e2cf8a6bf938064fb754aaceb1e6a4a51a80d884cd5db569886910", size = 160995 },
+ { url = "https://files.pythonhosted.org/packages/96/63/900c27cfe8be1a1f2433fc77cd46771cf26ba57e6bdc7cf9e63644a61863/websockets-14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a52a6d7cf6938e04e9dceb949d35fbdf58ac14deea26e685ab6368e73744e4c", size = 170543 },
+ { url = "https://files.pythonhosted.org/packages/00/8b/bec2bdba92af0762d42d4410593c1d7d28e9bfd952c97a3729df603dc6ea/websockets-14.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f05702e93203a6ff5226e21d9b40c037761b2cfb637187c9802c10f58e40473", size = 169546 },
+ { url = "https://files.pythonhosted.org/packages/6b/a9/37531cb5b994f12a57dec3da2200ef7aadffef82d888a4c29a0d781568e4/websockets-14.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22441c81a6748a53bfcb98951d58d1af0661ab47a536af08920d129b4d1c3473", size = 169911 },
+ { url = "https://files.pythonhosted.org/packages/60/d5/a6eadba2ed9f7e65d677fec539ab14a9b83de2b484ab5fe15d3d6d208c28/websockets-14.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd9b868d78b194790e6236d9cbc46d68aba4b75b22497eb4ab64fa640c3af56", size = 170183 },
+ { url = "https://files.pythonhosted.org/packages/76/57/a338ccb00d1df881c1d1ee1f2a20c9c1b5b29b51e9e0191ee515d254fea6/websockets-14.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a5a20d5843886d34ff8c57424cc65a1deda4375729cbca4cb6b3353f3ce4142", size = 169623 },
+ { url = "https://files.pythonhosted.org/packages/64/22/e5f7c33db0cb2c1d03b79fd60d189a1da044e2661f5fd01d629451e1db89/websockets-14.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34277a29f5303d54ec6468fb525d99c99938607bc96b8d72d675dee2b9f5bf1d", size = 169583 },
+ { url = "https://files.pythonhosted.org/packages/aa/2e/2b4662237060063a22e5fc40d46300a07142afe30302b634b4eebd717c07/websockets-14.2-cp311-cp311-win32.whl", hash = "sha256:02687db35dbc7d25fd541a602b5f8e451a238ffa033030b172ff86a93cb5dc2a", size = 163969 },
+ { url = "https://files.pythonhosted.org/packages/94/a5/0cda64e1851e73fc1ecdae6f42487babb06e55cb2f0dc8904b81d8ef6857/websockets-14.2-cp311-cp311-win_amd64.whl", hash = "sha256:862e9967b46c07d4dcd2532e9e8e3c2825e004ffbf91a5ef9dde519ee2effb0b", size = 164408 },
+ { url = "https://files.pythonhosted.org/packages/c1/81/04f7a397653dc8bec94ddc071f34833e8b99b13ef1a3804c149d59f92c18/websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c", size = 163096 },
+ { url = "https://files.pythonhosted.org/packages/ec/c5/de30e88557e4d70988ed4d2eabd73fd3e1e52456b9f3a4e9564d86353b6d/websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967", size = 160758 },
+ { url = "https://files.pythonhosted.org/packages/e5/8c/d130d668781f2c77d106c007b6c6c1d9db68239107c41ba109f09e6c218a/websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990", size = 160995 },
+ { url = "https://files.pythonhosted.org/packages/a6/bc/f6678a0ff17246df4f06765e22fc9d98d1b11a258cc50c5968b33d6742a1/websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda", size = 170815 },
+ { url = "https://files.pythonhosted.org/packages/d8/b2/8070cb970c2e4122a6ef38bc5b203415fd46460e025652e1ee3f2f43a9a3/websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95", size = 169759 },
+ { url = "https://files.pythonhosted.org/packages/81/da/72f7caabd94652e6eb7e92ed2d3da818626e70b4f2b15a854ef60bf501ec/websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3", size = 170178 },
+ { url = "https://files.pythonhosted.org/packages/31/e0/812725b6deca8afd3a08a2e81b3c4c120c17f68c9b84522a520b816cda58/websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9", size = 170453 },
+ { url = "https://files.pythonhosted.org/packages/66/d3/8275dbc231e5ba9bb0c4f93144394b4194402a7a0c8ffaca5307a58ab5e3/websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267", size = 169830 },
+ { url = "https://files.pythonhosted.org/packages/a3/ae/e7d1a56755ae15ad5a94e80dd490ad09e345365199600b2629b18ee37bc7/websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe", size = 169824 },
+ { url = "https://files.pythonhosted.org/packages/b6/32/88ccdd63cb261e77b882e706108d072e4f1c839ed723bf91a3e1f216bf60/websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205", size = 163981 },
+ { url = "https://files.pythonhosted.org/packages/b3/7d/32cdb77990b3bdc34a306e0a0f73a1275221e9a66d869f6ff833c95b56ef/websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce", size = 164421 },
+ { url = "https://files.pythonhosted.org/packages/10/3d/91d3d2bb1325cd83e8e2c02d0262c7d4426dc8fa0831ef1aa4d6bf2041af/websockets-14.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7d9cafbccba46e768be8a8ad4635fa3eae1ffac4c6e7cb4eb276ba41297ed29", size = 160773 },
+ { url = "https://files.pythonhosted.org/packages/33/7c/cdedadfef7381939577858b1b5718a4ab073adbb584e429dd9d9dc9bfe16/websockets-14.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c76193c1c044bd1e9b3316dcc34b174bbf9664598791e6fb606d8d29000e070c", size = 161007 },
+ { url = "https://files.pythonhosted.org/packages/ca/35/7a20a3c450b27c04e50fbbfc3dfb161ed8e827b2a26ae31c4b59b018b8c6/websockets-14.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd475a974d5352390baf865309fe37dec6831aafc3014ffac1eea99e84e83fc2", size = 162264 },
+ { url = "https://files.pythonhosted.org/packages/e8/9c/e3f9600564b0c813f2448375cf28b47dc42c514344faed3a05d71fb527f9/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c6c0097a41968b2e2b54ed3424739aab0b762ca92af2379f152c1aef0187e1c", size = 161873 },
+ { url = "https://files.pythonhosted.org/packages/3f/37/260f189b16b2b8290d6ae80c9f96d8b34692cf1bb3475df54c38d3deb57d/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d7ff794c8b36bc402f2e07c0b2ceb4a2424147ed4785ff03e2a7af03711d60a", size = 161818 },
+ { url = "https://files.pythonhosted.org/packages/ff/1e/e47dedac8bf7140e59aa6a679e850c4df9610ae844d71b6015263ddea37b/websockets-14.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dec254fcabc7bd488dab64846f588fc5b6fe0d78f641180030f8ea27b76d72c3", size = 164465 },
+ { url = "https://files.pythonhosted.org/packages/7b/c8/d529f8a32ce40d98309f4470780631e971a5a842b60aec864833b3615786/websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b", size = 157416 },
+]
+
+[[package]]
+name = "werkzeug"
+version = "3.1.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markupsafe" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498 },
+]
+
+[[package]]
+name = "xxhash"
+version = "3.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/00/5e/d6e5258d69df8b4ed8c83b6664f2b47d30d2dec551a29ad72a6c69eafd31/xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f", size = 84241 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/bb/8a/0e9feca390d512d293afd844d31670e25608c4a901e10202aa98785eab09/xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212", size = 31970 },
+ { url = "https://files.pythonhosted.org/packages/16/e6/be5aa49580cd064a18200ab78e29b88b1127e1a8c7955eb8ecf81f2626eb/xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520", size = 30801 },
+ { url = "https://files.pythonhosted.org/packages/20/ee/b8a99ebbc6d1113b3a3f09e747fa318c3cde5b04bd9c197688fadf0eeae8/xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680", size = 220927 },
+ { url = "https://files.pythonhosted.org/packages/58/62/15d10582ef159283a5c2b47f6d799fc3303fe3911d5bb0bcc820e1ef7ff4/xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da", size = 200360 },
+ { url = "https://files.pythonhosted.org/packages/23/41/61202663ea9b1bd8e53673b8ec9e2619989353dba8cfb68e59a9cbd9ffe3/xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23", size = 428528 },
+ { url = "https://files.pythonhosted.org/packages/f2/07/d9a3059f702dec5b3b703737afb6dda32f304f6e9da181a229dafd052c29/xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196", size = 194149 },
+ { url = "https://files.pythonhosted.org/packages/eb/58/27caadf78226ecf1d62dbd0c01d152ed381c14c1ee4ad01f0d460fc40eac/xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c", size = 207703 },
+ { url = "https://files.pythonhosted.org/packages/b1/08/32d558ce23e1e068453c39aed7b3c1cdc690c177873ec0ca3a90d5808765/xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482", size = 216255 },
+ { url = "https://files.pythonhosted.org/packages/3f/d4/2b971e2d2b0a61045f842b622ef11e94096cf1f12cd448b6fd426e80e0e2/xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296", size = 202744 },
+ { url = "https://files.pythonhosted.org/packages/19/ae/6a6438864a8c4c39915d7b65effd85392ebe22710412902487e51769146d/xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415", size = 210115 },
+ { url = "https://files.pythonhosted.org/packages/48/7d/b3c27c27d1fc868094d02fe4498ccce8cec9fcc591825c01d6bcb0b4fc49/xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198", size = 414247 },
+ { url = "https://files.pythonhosted.org/packages/a1/05/918f9e7d2fbbd334b829997045d341d6239b563c44e683b9a7ef8fe50f5d/xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442", size = 191419 },
+ { url = "https://files.pythonhosted.org/packages/08/29/dfe393805b2f86bfc47c290b275f0b7c189dc2f4e136fd4754f32eb18a8d/xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da", size = 30114 },
+ { url = "https://files.pythonhosted.org/packages/7b/d7/aa0b22c4ebb7c3ccb993d4c565132abc641cd11164f8952d89eb6a501909/xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9", size = 30003 },
+ { url = "https://files.pythonhosted.org/packages/69/12/f969b81541ee91b55f1ce469d7ab55079593c80d04fd01691b550e535000/xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6", size = 26773 },
+ { url = "https://files.pythonhosted.org/packages/b8/c7/afed0f131fbda960ff15eee7f304fa0eeb2d58770fade99897984852ef23/xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1", size = 31969 },
+ { url = "https://files.pythonhosted.org/packages/8c/0c/7c3bc6d87e5235672fcc2fb42fd5ad79fe1033925f71bf549ee068c7d1ca/xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8", size = 30800 },
+ { url = "https://files.pythonhosted.org/packages/04/9e/01067981d98069eec1c20201f8c145367698e9056f8bc295346e4ea32dd1/xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166", size = 221566 },
+ { url = "https://files.pythonhosted.org/packages/d4/09/d4996de4059c3ce5342b6e1e6a77c9d6c91acce31f6ed979891872dd162b/xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7", size = 201214 },
+ { url = "https://files.pythonhosted.org/packages/62/f5/6d2dc9f8d55a7ce0f5e7bfef916e67536f01b85d32a9fbf137d4cadbee38/xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623", size = 429433 },
+ { url = "https://files.pythonhosted.org/packages/d9/72/9256303f10e41ab004799a4aa74b80b3c5977d6383ae4550548b24bd1971/xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a", size = 194822 },
+ { url = "https://files.pythonhosted.org/packages/34/92/1a3a29acd08248a34b0e6a94f4e0ed9b8379a4ff471f1668e4dce7bdbaa8/xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88", size = 208538 },
+ { url = "https://files.pythonhosted.org/packages/53/ad/7fa1a109663366de42f724a1cdb8e796a260dbac45047bce153bc1e18abf/xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c", size = 216953 },
+ { url = "https://files.pythonhosted.org/packages/35/02/137300e24203bf2b2a49b48ce898ecce6fd01789c0fcd9c686c0a002d129/xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2", size = 203594 },
+ { url = "https://files.pythonhosted.org/packages/23/03/aeceb273933d7eee248c4322b98b8e971f06cc3880e5f7602c94e5578af5/xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084", size = 210971 },
+ { url = "https://files.pythonhosted.org/packages/e3/64/ed82ec09489474cbb35c716b189ddc1521d8b3de12b1b5ab41ce7f70253c/xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d", size = 415050 },
+ { url = "https://files.pythonhosted.org/packages/71/43/6db4c02dcb488ad4e03bc86d70506c3d40a384ee73c9b5c93338eb1f3c23/xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839", size = 192216 },
+ { url = "https://files.pythonhosted.org/packages/22/6d/db4abec29e7a567455344433d095fdb39c97db6955bb4a2c432e486b4d28/xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da", size = 30120 },
+ { url = "https://files.pythonhosted.org/packages/52/1c/fa3b61c0cf03e1da4767213672efe186b1dfa4fc901a4a694fb184a513d1/xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58", size = 30003 },
+ { url = "https://files.pythonhosted.org/packages/6b/8e/9e6fc572acf6e1cc7ccb01973c213f895cb8668a9d4c2b58a99350da14b7/xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3", size = 26777 },
+ { url = "https://files.pythonhosted.org/packages/07/0e/1bfce2502c57d7e2e787600b31c83535af83746885aa1a5f153d8c8059d6/xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00", size = 31969 },
+ { url = "https://files.pythonhosted.org/packages/3f/d6/8ca450d6fe5b71ce521b4e5db69622383d039e2b253e9b2f24f93265b52c/xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9", size = 30787 },
+ { url = "https://files.pythonhosted.org/packages/5b/84/de7c89bc6ef63d750159086a6ada6416cc4349eab23f76ab870407178b93/xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84", size = 220959 },
+ { url = "https://files.pythonhosted.org/packages/fe/86/51258d3e8a8545ff26468c977101964c14d56a8a37f5835bc0082426c672/xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793", size = 200006 },
+ { url = "https://files.pythonhosted.org/packages/02/0a/96973bd325412feccf23cf3680fd2246aebf4b789122f938d5557c54a6b2/xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be", size = 428326 },
+ { url = "https://files.pythonhosted.org/packages/11/a7/81dba5010f7e733de88af9555725146fc133be97ce36533867f4c7e75066/xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6", size = 194380 },
+ { url = "https://files.pythonhosted.org/packages/fb/7d/f29006ab398a173f4501c0e4977ba288f1c621d878ec217b4ff516810c04/xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90", size = 207934 },
+ { url = "https://files.pythonhosted.org/packages/8a/6e/6e88b8f24612510e73d4d70d9b0c7dff62a2e78451b9f0d042a5462c8d03/xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27", size = 216301 },
+ { url = "https://files.pythonhosted.org/packages/af/51/7862f4fa4b75a25c3b4163c8a873f070532fe5f2d3f9b3fc869c8337a398/xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2", size = 203351 },
+ { url = "https://files.pythonhosted.org/packages/22/61/8d6a40f288f791cf79ed5bb113159abf0c81d6efb86e734334f698eb4c59/xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d", size = 210294 },
+ { url = "https://files.pythonhosted.org/packages/17/02/215c4698955762d45a8158117190261b2dbefe9ae7e5b906768c09d8bc74/xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab", size = 414674 },
+ { url = "https://files.pythonhosted.org/packages/31/5c/b7a8db8a3237cff3d535261325d95de509f6a8ae439a5a7a4ffcff478189/xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e", size = 192022 },
+ { url = "https://files.pythonhosted.org/packages/78/e3/dd76659b2811b3fd06892a8beb850e1996b63e9235af5a86ea348f053e9e/xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8", size = 30170 },
+ { url = "https://files.pythonhosted.org/packages/d9/6b/1c443fe6cfeb4ad1dcf231cdec96eb94fb43d6498b4469ed8b51f8b59a37/xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e", size = 30040 },
+ { url = "https://files.pythonhosted.org/packages/0f/eb/04405305f290173acc0350eba6d2f1a794b57925df0398861a20fbafa415/xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2", size = 26796 },
+ { url = "https://files.pythonhosted.org/packages/ab/9a/233606bada5bd6f50b2b72c45de3d9868ad551e83893d2ac86dc7bb8553a/xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c", size = 29732 },
+ { url = "https://files.pythonhosted.org/packages/0c/67/f75276ca39e2c6604e3bee6c84e9db8a56a4973fde9bf35989787cf6e8aa/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986", size = 36214 },
+ { url = "https://files.pythonhosted.org/packages/0f/f8/f6c61fd794229cc3848d144f73754a0c107854372d7261419dcbbd286299/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6", size = 32020 },
+ { url = "https://files.pythonhosted.org/packages/79/d3/c029c99801526f859e6b38d34ab87c08993bf3dcea34b11275775001638a/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b", size = 40515 },
+ { url = "https://files.pythonhosted.org/packages/62/e3/bef7b82c1997579c94de9ac5ea7626d01ae5858aa22bf4fcb38bf220cb3e/xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da", size = 30064 },
+]
+
+[[package]]
+name = "yarl"
+version = "1.18.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "idna" },
+ { name = "multidict" },
+ { name = "propcache" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b7/9d/4b94a8e6d2b51b599516a5cb88e5bc99b4d8d4583e468057eaa29d5f0918/yarl-1.18.3.tar.gz", hash = "sha256:ac1801c45cbf77b6c99242eeff4fffb5e4e73a800b5c4ad4fc0be5def634d2e1", size = 181062 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d2/98/e005bc608765a8a5569f58e650961314873c8469c333616eb40bff19ae97/yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34", size = 141458 },
+ { url = "https://files.pythonhosted.org/packages/df/5d/f8106b263b8ae8a866b46d9be869ac01f9b3fb7f2325f3ecb3df8003f796/yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7", size = 94365 },
+ { url = "https://files.pythonhosted.org/packages/56/3e/d8637ddb9ba69bf851f765a3ee288676f7cf64fb3be13760c18cbc9d10bd/yarl-1.18.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:602d98f2c2d929f8e697ed274fbadc09902c4025c5a9963bf4e9edfc3ab6f7ed", size = 92181 },
+ { url = "https://files.pythonhosted.org/packages/76/f9/d616a5c2daae281171de10fba41e1c0e2d8207166fc3547252f7d469b4e1/yarl-1.18.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c654d5207c78e0bd6d749f6dae1dcbbfde3403ad3a4b11f3c5544d9906969dde", size = 315349 },
+ { url = "https://files.pythonhosted.org/packages/bb/b4/3ea5e7b6f08f698b3769a06054783e434f6d59857181b5c4e145de83f59b/yarl-1.18.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5094d9206c64181d0f6e76ebd8fb2f8fe274950a63890ee9e0ebfd58bf9d787b", size = 330494 },
+ { url = "https://files.pythonhosted.org/packages/55/f1/e0fc810554877b1b67420568afff51b967baed5b53bcc983ab164eebf9c9/yarl-1.18.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35098b24e0327fc4ebdc8ffe336cee0a87a700c24ffed13161af80124b7dc8e5", size = 326927 },
+ { url = "https://files.pythonhosted.org/packages/a9/42/b1753949b327b36f210899f2dd0a0947c0c74e42a32de3f8eb5c7d93edca/yarl-1.18.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3236da9272872443f81fedc389bace88408f64f89f75d1bdb2256069a8730ccc", size = 319703 },
+ { url = "https://files.pythonhosted.org/packages/f0/6d/e87c62dc9635daefb064b56f5c97df55a2e9cc947a2b3afd4fd2f3b841c7/yarl-1.18.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2c08cc9b16f4f4bc522771d96734c7901e7ebef70c6c5c35dd0f10845270bcd", size = 310246 },
+ { url = "https://files.pythonhosted.org/packages/e3/ef/e2e8d1785cdcbd986f7622d7f0098205f3644546da7919c24b95790ec65a/yarl-1.18.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:80316a8bd5109320d38eef8833ccf5f89608c9107d02d2a7f985f98ed6876990", size = 319730 },
+ { url = "https://files.pythonhosted.org/packages/fc/15/8723e22345bc160dfde68c4b3ae8b236e868f9963c74015f1bc8a614101c/yarl-1.18.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c1e1cc06da1491e6734f0ea1e6294ce00792193c463350626571c287c9a704db", size = 321681 },
+ { url = "https://files.pythonhosted.org/packages/86/09/bf764e974f1516efa0ae2801494a5951e959f1610dd41edbfc07e5e0f978/yarl-1.18.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fea09ca13323376a2fdfb353a5fa2e59f90cd18d7ca4eaa1fd31f0a8b4f91e62", size = 324812 },
+ { url = "https://files.pythonhosted.org/packages/f6/4c/20a0187e3b903c97d857cf0272d687c1b08b03438968ae8ffc50fe78b0d6/yarl-1.18.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e3b9fd71836999aad54084906f8663dffcd2a7fb5cdafd6c37713b2e72be1760", size = 337011 },
+ { url = "https://files.pythonhosted.org/packages/c9/71/6244599a6e1cc4c9f73254a627234e0dad3883ece40cc33dce6265977461/yarl-1.18.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:757e81cae69244257d125ff31663249b3013b5dc0a8520d73694aed497fb195b", size = 338132 },
+ { url = "https://files.pythonhosted.org/packages/af/f5/e0c3efaf74566c4b4a41cb76d27097df424052a064216beccae8d303c90f/yarl-1.18.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b1771de9944d875f1b98a745bc547e684b863abf8f8287da8466cf470ef52690", size = 331849 },
+ { url = "https://files.pythonhosted.org/packages/8a/b8/3d16209c2014c2f98a8f658850a57b716efb97930aebf1ca0d9325933731/yarl-1.18.3-cp310-cp310-win32.whl", hash = "sha256:8874027a53e3aea659a6d62751800cf6e63314c160fd607489ba5c2edd753cf6", size = 84309 },
+ { url = "https://files.pythonhosted.org/packages/fd/b7/2e9a5b18eb0fe24c3a0e8bae994e812ed9852ab4fd067c0107fadde0d5f0/yarl-1.18.3-cp310-cp310-win_amd64.whl", hash = "sha256:93b2e109287f93db79210f86deb6b9bbb81ac32fc97236b16f7433db7fc437d8", size = 90484 },
+ { url = "https://files.pythonhosted.org/packages/40/93/282b5f4898d8e8efaf0790ba6d10e2245d2c9f30e199d1a85cae9356098c/yarl-1.18.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8503ad47387b8ebd39cbbbdf0bf113e17330ffd339ba1144074da24c545f0069", size = 141555 },
+ { url = "https://files.pythonhosted.org/packages/6d/9c/0a49af78df099c283ca3444560f10718fadb8a18dc8b3edf8c7bd9fd7d89/yarl-1.18.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02ddb6756f8f4517a2d5e99d8b2f272488e18dd0bfbc802f31c16c6c20f22193", size = 94351 },
+ { url = "https://files.pythonhosted.org/packages/5a/a1/205ab51e148fdcedad189ca8dd587794c6f119882437d04c33c01a75dece/yarl-1.18.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:67a283dd2882ac98cc6318384f565bffc751ab564605959df4752d42483ad889", size = 92286 },
+ { url = "https://files.pythonhosted.org/packages/ed/fe/88b690b30f3f59275fb674f5f93ddd4a3ae796c2b62e5bb9ece8a4914b83/yarl-1.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d980e0325b6eddc81331d3f4551e2a333999fb176fd153e075c6d1c2530aa8a8", size = 340649 },
+ { url = "https://files.pythonhosted.org/packages/07/eb/3b65499b568e01f36e847cebdc8d7ccb51fff716dbda1ae83c3cbb8ca1c9/yarl-1.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b643562c12680b01e17239be267bc306bbc6aac1f34f6444d1bded0c5ce438ca", size = 356623 },
+ { url = "https://files.pythonhosted.org/packages/33/46/f559dc184280b745fc76ec6b1954de2c55595f0ec0a7614238b9ebf69618/yarl-1.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c017a3b6df3a1bd45b9fa49a0f54005e53fbcad16633870104b66fa1a30a29d8", size = 354007 },
+ { url = "https://files.pythonhosted.org/packages/af/ba/1865d85212351ad160f19fb99808acf23aab9a0f8ff31c8c9f1b4d671fc9/yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75674776d96d7b851b6498f17824ba17849d790a44d282929c42dbb77d4f17ae", size = 344145 },
+ { url = "https://files.pythonhosted.org/packages/94/cb/5c3e975d77755d7b3d5193e92056b19d83752ea2da7ab394e22260a7b824/yarl-1.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ccaa3a4b521b780a7e771cc336a2dba389a0861592bbce09a476190bb0c8b4b3", size = 336133 },
+ { url = "https://files.pythonhosted.org/packages/19/89/b77d3fd249ab52a5c40859815765d35c91425b6bb82e7427ab2f78f5ff55/yarl-1.18.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d06d3005e668744e11ed80812e61efd77d70bb7f03e33c1598c301eea20efbb", size = 347967 },
+ { url = "https://files.pythonhosted.org/packages/35/bd/f6b7630ba2cc06c319c3235634c582a6ab014d52311e7d7c22f9518189b5/yarl-1.18.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:9d41beda9dc97ca9ab0b9888cb71f7539124bc05df02c0cff6e5acc5a19dcc6e", size = 346397 },
+ { url = "https://files.pythonhosted.org/packages/18/1a/0b4e367d5a72d1f095318344848e93ea70da728118221f84f1bf6c1e39e7/yarl-1.18.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ba23302c0c61a9999784e73809427c9dbedd79f66a13d84ad1b1943802eaaf59", size = 350206 },
+ { url = "https://files.pythonhosted.org/packages/b5/cf/320fff4367341fb77809a2d8d7fe75b5d323a8e1b35710aafe41fdbf327b/yarl-1.18.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6748dbf9bfa5ba1afcc7556b71cda0d7ce5f24768043a02a58846e4a443d808d", size = 362089 },
+ { url = "https://files.pythonhosted.org/packages/57/cf/aadba261d8b920253204085268bad5e8cdd86b50162fcb1b10c10834885a/yarl-1.18.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0b0cad37311123211dc91eadcb322ef4d4a66008d3e1bdc404808992260e1a0e", size = 366267 },
+ { url = "https://files.pythonhosted.org/packages/54/58/fb4cadd81acdee6dafe14abeb258f876e4dd410518099ae9a35c88d8097c/yarl-1.18.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fb2171a4486bb075316ee754c6d8382ea6eb8b399d4ec62fde2b591f879778a", size = 359141 },
+ { url = "https://files.pythonhosted.org/packages/9a/7a/4c571597589da4cd5c14ed2a0b17ac56ec9ee7ee615013f74653169e702d/yarl-1.18.3-cp311-cp311-win32.whl", hash = "sha256:61b1a825a13bef4a5f10b1885245377d3cd0bf87cba068e1d9a88c2ae36880e1", size = 84402 },
+ { url = "https://files.pythonhosted.org/packages/ae/7b/8600250b3d89b625f1121d897062f629883c2f45339623b69b1747ec65fa/yarl-1.18.3-cp311-cp311-win_amd64.whl", hash = "sha256:b9d60031cf568c627d028239693fd718025719c02c9f55df0a53e587aab951b5", size = 91030 },
+ { url = "https://files.pythonhosted.org/packages/33/85/bd2e2729752ff4c77338e0102914897512e92496375e079ce0150a6dc306/yarl-1.18.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1dd4bdd05407ced96fed3d7f25dbbf88d2ffb045a0db60dbc247f5b3c5c25d50", size = 142644 },
+ { url = "https://files.pythonhosted.org/packages/ff/74/1178322cc0f10288d7eefa6e4a85d8d2e28187ccab13d5b844e8b5d7c88d/yarl-1.18.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7c33dd1931a95e5d9a772d0ac5e44cac8957eaf58e3c8da8c1414de7dd27c576", size = 94962 },
+ { url = "https://files.pythonhosted.org/packages/be/75/79c6acc0261e2c2ae8a1c41cf12265e91628c8c58ae91f5ff59e29c0787f/yarl-1.18.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b411eddcfd56a2f0cd6a384e9f4f7aa3efee14b188de13048c25b5e91f1640", size = 92795 },
+ { url = "https://files.pythonhosted.org/packages/6b/32/927b2d67a412c31199e83fefdce6e645247b4fb164aa1ecb35a0f9eb2058/yarl-1.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436c4fc0a4d66b2badc6c5fc5ef4e47bb10e4fd9bf0c79524ac719a01f3607c2", size = 332368 },
+ { url = "https://files.pythonhosted.org/packages/19/e5/859fca07169d6eceeaa4fde1997c91d8abde4e9a7c018e371640c2da2b71/yarl-1.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e35ef8683211db69ffe129a25d5634319a677570ab6b2eba4afa860f54eeaf75", size = 342314 },
+ { url = "https://files.pythonhosted.org/packages/08/75/76b63ccd91c9e03ab213ef27ae6add2e3400e77e5cdddf8ed2dbc36e3f21/yarl-1.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84b2deecba4a3f1a398df819151eb72d29bfeb3b69abb145a00ddc8d30094512", size = 341987 },
+ { url = "https://files.pythonhosted.org/packages/1a/e1/a097d5755d3ea8479a42856f51d97eeff7a3a7160593332d98f2709b3580/yarl-1.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e5a1fea0fd4f5bfa7440a47eff01d9822a65b4488f7cff83155a0f31a2ecba", size = 336914 },
+ { url = "https://files.pythonhosted.org/packages/0b/42/e1b4d0e396b7987feceebe565286c27bc085bf07d61a59508cdaf2d45e63/yarl-1.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0e883008013c0e4aef84dcfe2a0b172c4d23c2669412cf5b3371003941f72bb", size = 325765 },
+ { url = "https://files.pythonhosted.org/packages/7e/18/03a5834ccc9177f97ca1bbb245b93c13e58e8225276f01eedc4cc98ab820/yarl-1.18.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5a3f356548e34a70b0172d8890006c37be92995f62d95a07b4a42e90fba54272", size = 344444 },
+ { url = "https://files.pythonhosted.org/packages/c8/03/a713633bdde0640b0472aa197b5b86e90fbc4c5bc05b727b714cd8a40e6d/yarl-1.18.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ccd17349166b1bee6e529b4add61727d3f55edb7babbe4069b5764c9587a8cc6", size = 340760 },
+ { url = "https://files.pythonhosted.org/packages/eb/99/f6567e3f3bbad8fd101886ea0276c68ecb86a2b58be0f64077396cd4b95e/yarl-1.18.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b958ddd075ddba5b09bb0be8a6d9906d2ce933aee81100db289badbeb966f54e", size = 346484 },
+ { url = "https://files.pythonhosted.org/packages/8e/a9/84717c896b2fc6cb15bd4eecd64e34a2f0a9fd6669e69170c73a8b46795a/yarl-1.18.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c7d79f7d9aabd6011004e33b22bc13056a3e3fb54794d138af57f5ee9d9032cb", size = 359864 },
+ { url = "https://files.pythonhosted.org/packages/1e/2e/d0f5f1bef7ee93ed17e739ec8dbcb47794af891f7d165fa6014517b48169/yarl-1.18.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4891ed92157e5430874dad17b15eb1fda57627710756c27422200c52d8a4e393", size = 364537 },
+ { url = "https://files.pythonhosted.org/packages/97/8a/568d07c5d4964da5b02621a517532adb8ec5ba181ad1687191fffeda0ab6/yarl-1.18.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce1af883b94304f493698b00d0f006d56aea98aeb49d75ec7d98cd4a777e9285", size = 357861 },
+ { url = "https://files.pythonhosted.org/packages/7d/e3/924c3f64b6b3077889df9a1ece1ed8947e7b61b0a933f2ec93041990a677/yarl-1.18.3-cp312-cp312-win32.whl", hash = "sha256:f91c4803173928a25e1a55b943c81f55b8872f0018be83e3ad4938adffb77dd2", size = 84097 },
+ { url = "https://files.pythonhosted.org/packages/34/45/0e055320daaabfc169b21ff6174567b2c910c45617b0d79c68d7ab349b02/yarl-1.18.3-cp312-cp312-win_amd64.whl", hash = "sha256:7e2ee16578af3b52ac2f334c3b1f92262f47e02cc6193c598502bd46f5cd1477", size = 90399 },
+ { url = "https://files.pythonhosted.org/packages/f5/4b/a06e0ec3d155924f77835ed2d167ebd3b211a7b0853da1cf8d8414d784ef/yarl-1.18.3-py3-none-any.whl", hash = "sha256:b57f4f58099328dfb26c6a771d09fb20dbbae81d20cfb66141251ea063bd101b", size = 45109 },
+]
+
+[[package]]
+name = "zipp"
+version = "3.21.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3f/50/bad581df71744867e9468ebd0bcd6505de3b275e06f202c2cb016e3ff56f/zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4", size = 24545 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b7/1a/7e4798e9339adc931158c9d69ecc34f5e6791489d469f5e50ec15e35f458/zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931", size = 9630 },
+]