Commit a8bf2f3 · verified · 0 parent(s)
Super-squash branch 'main' using huggingface_hub

Files changed:
- .gitattributes +35 -0
- README.md +100 -0
- common_spear.py +702 -0
- config.json +167 -0
- configuration_spear.py +347 -0
- generation_config.json +3 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_spear.py +0 -0
- processing_spear.py +1897 -0
    	
.gitattributes ADDED

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
    	
README.md ADDED

@@ -0,0 +1,100 @@
---
license: gemma
library_name: transformers
pipeline_tag: visual-question-answering
---
# SPEAR-1 model card

SPEAR-1 is a cutting-edge Vision-Language-Action (VLA) model capable of achieving performance __superior to or on par with state-of-the-art models such as pi0-FAST and pi0.5__ on multiple embodiments while being trained __on 20x less robot data__.

This model was developed by [INSAIT](https://insait.ai/), a special unit of Sofia University St. Kliment Ohridski, in Sofia, Bulgaria.

Code and model weights for SPEAR-1 models are free to use under the Gemma license.

This repo provides model weights fine-tuned for a Franka setup with one wrist and one external camera.

## Model description

The key to SPEAR-1's data efficiency is SPEAR-VLM, a 3D-aware VLM. SPEAR-VLM extends PaliGemma with the MoGe depth encoder and is trained on 3D VQA tasks using primarily non-robot data sources, such as EgoExo-4D.

SPEAR-1's architecture combines SPEAR-VLM with a DiT action expert. It is first pre-trained on a mixture of robot demonstration datasets from Open X-Embodiment and then fine-tuned for specific embodiments.
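
A purely schematic sketch of this composition is shown below (all module and argument names are hypothetical; the actual implementation ships in `modeling_spear.py` in this commit):

```python
import torch
from torch import nn


class SchematicSPEAR1(nn.Module):
    """Illustrative composition only: a 3D-aware VLM feeding a DiT-style action expert."""

    def __init__(self, spear_vlm: nn.Module, action_expert: nn.Module):
        super().__init__()
        # SPEAR-VLM: a PaliGemma-style VLM extended with a MoGe-based depth encoder,
        # trained on 3D VQA before any robot-specific training (as described above).
        self.spear_vlm = spear_vlm
        # DiT action expert: decodes an action chunk conditioned on the VLM features.
        self.action_expert = action_expert

    def forward(self, images, text_tokens, noisy_actions, timestep):
        context = self.spear_vlm(images, text_tokens)  # 3D-aware multimodal tokens
        return self.action_expert(noisy_actions, timestep, context)  # flow/denoising prediction
```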

## Use with 🤗 Transformers

We provide a fully `AutoModel`-compatible implementation of SPEAR-1 that can be used via transformers.

### Environment setup

The current implementation requires the following additional dependencies: `roma`, `timm`, `flash-attn`.

Here is a snippet to set up a working environment for inference via `uv`:

```
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
```

### Example usage

```python
from typing import Dict

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel

model = AutoModel.from_pretrained("INSAIT-Institute/spear1-franka")
model = model.to(dtype=torch.bfloat16, device="cuda").eval()

main_image = np.asarray(Image.open("path/to/main_image.png"))
wrist_image = np.asarray(Image.open("path/to/wrist_image.png"))

ee_translation = np.array([0.36, 0.0, 0.56])
ee_rotation = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
gripper = np.array(1.0)

model_input: Dict[str, np.ndarray | str | Dict[str, np.ndarray]] = {
    "images": {
        "main": main_image,  # (H, W, C)
        "wrist": wrist_image,  # (H, W, C)
    },
    "ee_translation": ee_translation,  # (3,)
    "ee_rotation": ee_rotation,  # (3, 3)
    "gripper": gripper,  # (1,)
    "language_instruction": "put the carrot on the blue plate",
    "dataset_name": "droid"
}

model_output: Dict[str, np.ndarray] = model.predict_action(model_input)

ctrl_translation: np.ndarray = model_output["translation"]  # (S, 3)
ctrl_rotation: np.ndarray = model_output["rotation"]  # (S, 3, 3)
ctrl_gripper: np.ndarray = model_output["gripper"]  # (S, 1)
```

## Action space

SPEAR-1 predicts action chunks of delta end-effector poses. Each step in the predicted action chunk is relative to the input state.

Given the current end-effector pose `[R, t]` and a model prediction `A_rel = [[R_1, t_1], ..., [R_n, t_n]]`, absolute end-effector pose commands can be computed as:
```
A_abs = [[R * R_1, t + t_1], ..., [R * R_n, t + t_n]]
```
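
For illustration, the same composition in numpy (a minimal sketch, assuming the `ee_*` inputs and `ctrl_*` outputs from the usage example above; this helper is not part of the repository):

```python
import numpy as np


def compose_absolute_commands(R: np.ndarray, t: np.ndarray,
                              rel_rotation: np.ndarray, rel_translation: np.ndarray):
    """Turn a predicted chunk of relative poses into absolute pose commands.

    R: (3, 3) current end-effector rotation, t: (3,) current translation.
    rel_rotation: (S, 3, 3) and rel_translation: (S, 3) as returned by predict_action.
    """
    abs_rotation = R[None] @ rel_rotation        # R * R_i for every step i
    abs_translation = t[None] + rel_translation  # t + t_i for every step i
    return abs_rotation, abs_translation


# e.g. abs_R, abs_t = compose_absolute_commands(ee_rotation, ee_translation, ctrl_rotation, ctrl_translation)
```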

## Community Feedback

We welcome feedback from the community to help improve SPEAR-1. If you have suggestions, encounter any issues, or have ideas for improvements, please contact us.

## Summary

- __Model type__: Vision-Language-Action with flow-matching action decoding
- __Contact__: [email protected]
- __License__: Gemma Terms of Use
    	
common_spear.py ADDED

@@ -0,0 +1,702 @@
import collections.abc
import dataclasses
import enum
import inspect
import types
from collections.abc import Mapping as MappingABC
from functools import cached_property
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch
import transformers


class StrEnum(str, enum.Enum):
    """
    A minimal drop-in replacement for backports.strenum.StrEnum
    """

    def __str__(self):
        return str(self.value)

    def __new__(cls, value):
        # Create new instance that properly handles string initialization
        if isinstance(value, str):
            obj = str.__new__(cls, value)
            obj._value_ = value
            return obj
        return super().__new__(cls, value)

    @classmethod
    def _missing_(cls, value):
        # Enhanced lookup by string value with better error handling
        if isinstance(value, str):
            for member in cls:
                if member.value == value:
                    return member
        # Return None to let enum handle the KeyError
        return None

    def __eq__(self, other):
        # Allow comparison with string values
        if isinstance(other, str):
            return self.value == other
        return super().__eq__(other)

    def __hash__(self):
        # Ensure consistent hashing
        return hash(self.value)


class _cached_classproperty:
    def __init__(self, func):
        self.func = func
        self._values = {}

    def __get__(self, obj, klass):
        if klass not in self._values.keys():
            self._values[klass] = self.func.__get__(obj, klass)()
        return self._values[klass]


def cached_classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _cached_classproperty(func)


@dataclasses.dataclass
class Dataclass:
    def __post_init__(self):
        pass

    @classmethod
    def make_empty(cls) -> "Dataclass":
        return cls(
            **{
                k: (v.make_empty() if inspect.isclass(v) and issubclass(v, Dataclass) else None)
                for (k, v) in cls.types.items()
            }
        )

    @cached_classproperty
    def fields(cls) -> Tuple[dataclasses.Field, ...]:
        """Returns a sorted list of the Field objects"""
        return tuple(sorted(dataclasses.fields(cls), key=lambda x: x.name))

    @cached_classproperty
    def types(cls) -> Dict[str, type]:
        return {f.name: f.type for f in cls.fields}

    def as_json(self, recursive: bool = True) -> dict:
        return {k: v.as_json() if isinstance(v, Dataclass) and recursive else v for (k, v) in self.items()}

    @classmethod
    def keys(cls) -> List[str]:
        return [field.name for field in cls.fields]

    def values(self):
        return [getattr(self, field.name) for field in self.fields]

    def items(self, recursive: bool = False):
        for key, value in zip(self.keys(), self.values(), strict=True):
            if recursive and isinstance(value, Dataclass):
                for subkey, subvalue in value.items(recursive=True):
                    yield (f"{key}.{subkey}", subvalue)
            else:
                yield (key, value)

    def replace(self, **kwargs):
        """
        Return a new instance of Dataclass with the kwargs overwritten.
        """
        kwargs = maybe_chained_keys_to_nested_dict(kwargs)
        data = self.as_json(recursive=False)
        for key, value in kwargs.items():
            value_type = self.types.get(key, None)
            if value_type is None:
                raise KeyError(f"Dataclass {self.__class__} does not have a field {key}")
            value_type = get_maybe_optional_type(value_type)
            if inspect.isclass(value_type) and issubclass(value_type, Dataclass):
                if isinstance(value, dict):
                    data[key] = data[key].replace(**value)
                else:
                    data[key] = value
            else:
                data[key] = value
        return self.__class__(**data)

    def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass":
        def fcn_wrapper(value: Any) -> Any:
            if value is None and skip_nones:
                return None
            if isinstance(value, dict) and recursive:
                return type(value)(**{k: fcn(v) for (k, v) in value.items()})
            if isinstance(value, (list, tuple)) and recursive:
                return type(value)([fcn(v) for v in value])
            if isinstance(value, Dataclass) and recursive:
                return value.apply(fcn, recursive=True, skip_nones=skip_nones)
            return fcn(value)

        return self.__class__(**{key: fcn_wrapper(value) for (key, value) in self.items()})

    def __getitem__(self, index) -> "Dataclass":
        def extract(obj):
            if obj is None:
                return None
            if isinstance(obj, torch.Tensor):
                return obj[index]
            raise ValueError(f"Cannot slice {obj.__class__.__name__} object")

        return self.apply(extract)


class Config:
    def __init__(self, **kwargs):
        self._apply_defaults()
        self._set_attributes(**kwargs)
        super().__init__()
        self.__post_init__()

    def _apply_defaults(self):
        """
        Initializes all annotated fields with defaults or sensible instances.
        """
        annotations = getattr(self, "__annotations__", {})
        for key, type_hint in annotations.items():
            # Skip if already set via class-level value or __init__ kwarg
            if hasattr(self, key):
                continue

            # Case 1: class variable has a default (declared at class level)
            if key in self.__class__.__dict__:
                setattr(self, key, getattr(self.__class__, key))
                continue

            # Case 2: if the type is another Config subclass, default-construct it
            if inspect.isclass(type_hint) and issubclass(type_hint, Config):
                setattr(self, key, type_hint())
                continue

            # Case 3: fallback None (or empty dict for mappings)
            if hasattr(type_hint, "__origin__") and type_hint.__origin__ in (
                dict,
                Dict,
                MappingABC,
            ):
                setattr(self, key, {})
            else:
                setattr(self, key, None)

    def _set_attributes(self, **kwargs):
        subconfig_types = self._subconfig_types
        for key, value in kwargs.items():
            if key in subconfig_types:
                if not isinstance(value, Mapping):
                    raise ValueError(
                        f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}"
                    )
                setattr(self, key, subconfig_types[key](**value))
            else:
                setattr(self, key, value)

    def keys(self) -> List[str]:
        """Get all annotated keys including those from parent classes."""
        all_keys = {}
        # Walk through MRO in reverse to respect inheritance order
        for cls in reversed(self.__class__.__mro__):
            if cls is object:
                continue
            all_keys.update(getattr(cls, "__annotations__", {}))
        return list(all_keys.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        for key in self.keys():
            yield (key, getattr(self, key))

    @cached_classproperty
    def _subconfig_types(cls) -> dict[str, Type]:
        keys = {
            key: value
            for (key, value) in cls.__annotations__.items()
            if inspect.isclass(value) and issubclass(value, Config)
        }
        for base in cls.__bases__:
            if not issubclass(base, Config):
                continue
            keys = {**keys, **base._subconfig_types}
        return keys

    def __post_init__(self):
        pass

    def as_json(self) -> dict:
        data = {}
        for key, value in self.items():
            if isinstance(value, Config):
                data[key] = value.as_json()
            elif (
                isinstance(value, collections.abc.Sequence)
                and len(value) > 0
                and isinstance(value[0], Config)
            ):
                data[key] = [v.as_json() for v in value]
            elif (
                isinstance(value, collections.abc.Mapping)
                and len(value) > 0
                and isinstance(next(iter(value.values())), Config)
            ):
                data[key] = {k: v.as_json() for k, v in value.items()}
            else:
                data[key] = value

        return data


class HFConfigMixin(transformers.PretrainedConfig):
    """
    Bridge between your Config system and HF PretrainedConfig.

    Usage:
        class SPEAR1Config(HFConfigMixin, Config):
            model_type = "spear1"
            processor_config: PaliGemmaProcessorConfig
            ...
    """

    def __init__(self, **kwargs):
        # Let HF's machinery initialize its own attributes / defaults first.
        # PretrainedConfig.__init__ will set things like `model_type`,
        # `_name_or_path`, `architectures`, and keep a `kwargs`->dict of extra items.
        super().__init__(**kwargs)

        # Now initialize your Config behavior: set defaults and construct nested configs.
        # We call Config.__init__ explicitly because HFConfigMixin inherits from PretrainedConfig,
        # and the user's concrete class will use multiple-inheritance with Config.
        # (This approach mirrors the earlier MRO design: class Concrete(HFConfigMixin, Config).)
        # We pass kwargs again so nested configs get overridden by user kwargs.
        # Note: Config.__init__ itself calls super().__init__() — but because we are calling
        # Config.__init__ directly (not via super()) the MRO won't re-call PretrainedConfig.__init__ here.
        # (I.e., we are deliberately calling the concrete base initializer.)
        Config.__init__(self, **kwargs)  # type: ignore[name-defined]

    def to_dict(self) -> Dict[str, Any]:
        """
        Merge HF PretrainedConfig serialization and Config.as_json().

        Strategy:
          1. Take HF dict (super().to_dict()) so HF metadata/defaults are present.
          2. Take our nested config dict (Config.as_json(self)).
          3. Update the HF dict with our nested config dict so annotated fields
             (nested configs, lists/dicts that should be recursively serialized)
             take precedence.
        """
        # HF's representation (contains model_type, etc.). This is trusted HF serialization.
        hf = super().to_dict()

        # Our nested config representation (recursively serializes Config objects).
        # Do not call self.to_dict() because that would recurse back here.
        cfg_json = Config.as_json(self)  # type: ignore[name-defined]

        # Merge: prefer cfg_json values for keys present in our config (so nested configs
        # are represented as dicts rather than raw objects or omitted).
        merged: Dict[str, Any] = dict(hf)
        merged.update(cfg_json)
        return merged

    @classmethod
    def from_dict(
        cls: Type["HFConfigMixin"],
        config_dict: Dict[str, Any],
        **kwargs,
    ) -> "HFConfigMixin":
        """
        Construct by delegating to the class constructor — that will instantiate nested configs.
        This is simple and consistent with PretrainedConfig.from_dict/from_pretrained behavior.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        instance = cls(**config_dict)

        if return_unused_kwargs:
            # Return tuple of (instance, unused_kwargs) if requested
            # Since we consume everything in __init__, unused is typically empty
            return instance, {}
        return instance


class Configurable:
    def __init__(self, config: Config):
        self._config = config

    @property
    def config(self) -> Config:
        return self._config


class RotationFormat(StrEnum):
    """Determines how rotations will be encoded in the loaded batch"""

    EULER = "euler"
    QUATERNION = "quaternion"
    ROTMAT = "rotmat"


class ResizeMode(StrEnum):
    """
    Different modes for resizing images.
    """

    MATCH_WIDTH = "match_width"
    MATCH_HEIGHT = "match_height"
    MATCH_MAX = "match_max"
    NAIVE = "naive"
    SMART = "smart"
    PAD = "pad"
    CROP = "crop"


class Normalization(StrEnum):
    """Action normalization types"""

    NONE = "none"
    BOUNDS = "bounds"
    BOUNDS_Q99 = "bounds_q99"
    MEAN = "mean"


def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
    """
    Expand the dimensions of `tensor` to `ndim` such that all new dimensions have size of 1
    Args:
        tensor: torch.Tensor of any shape
        ndim: Number of output dimensions. Must be >= `tensor.ndim`
        order: Sequence of size `tensor.ndim + 1`. Contains only values of 1 and a single value of -1,
            indicating where the new `ndim - tensor.ndim` dimensions will be inserted
    Returns:
        torch.Tensor with dimensions `ndim`, a view of `tensor`

    Ex:
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
        expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
    """
    assert tensor.ndim <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}"
    assert len(order) == tensor.ndim + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}"
    order = list(order)
    assert order.count(-1) == 1, "Order must have exactly one value of -1"
    assert order.count(1) == len(order) - 1, "Order must have exactly len(order) - 1 values of 1"
    if tensor.ndim == ndim:
        return tensor
    insert_index = order.index(-1)
    view = list(tensor.shape[:insert_index]) + [1] * (ndim - tensor.ndim) + list(tensor.shape[insert_index:])
    tensor = tensor.view(view)
    return tensor


def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merges dict_1 with dict_2 recursively.
    Handles clashing keys:
        1. If both values are dicts, merges them recursively
        2. If any value is not a dict, raises ValueError
    """
    merged = dict(dict_1)
    for key, value in dict_2.items():
        if key in merged:
            if not type(merged[key]) is type(value) is dict:
                raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}")
            merged[key] = merge_dicts_recursive(merged[key], value)
        else:
            merged[key] = value
    return merged


def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a dict with keys of the form "key1.key2.key3" to a nested dict"""
    unpacked_data: Dict[str, Any] = {}
    for key, value in data.items():
        if "." not in key:
            unpacked_data = merge_dicts_recursive(unpacked_data, {key: value})
        else:
            (mainkey, subkey) = key.split(".", maxsplit=1)
            nested_value = maybe_chained_keys_to_nested_dict({subkey: value})
            unpacked_data = merge_dicts_recursive(unpacked_data, {mainkey: nested_value})
    return unpacked_data


def annotation_is_union(type_value: Type) -> bool:
    return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType


def annotation_is_optional(type_value: Type) -> bool:
    if annotation_is_union(type_value):
        union_args = set(type_value.__args__)
        if len(union_args) == 2 and type(None) in union_args:
            return True
    return False


def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]:
    if annotation_is_optional(type_value):
        type_args = type_value.__args__
        if type_args[1] is type(None):
            return type_args[0]
        return type_args[1]
    return type_value


@dataclasses.dataclass
class RoboticsTarget(Dataclass):
    control_tokens_ids: Optional[torch.Tensor]
    text_tokens_ids: Optional[torch.Tensor]
    translation: torch.Tensor
    rotation: torch.Tensor
    gripper: torch.Tensor
    valid_mask: torch.Tensor


@dataclasses.dataclass
class RoboticsControlPlan(Dataclass):
    translation_m: torch.Tensor
    rotmat: torch.Tensor
    gripper_prob: torch.Tensor
    valid_mask: torch.Tensor

    def __post_init__(self):
        super().__post_init__()
        assert self.translation_m.ndim == 3, self.translation_m.shape
        assert self.rotmat.ndim == 3, self.rotmat.shape
        assert self.gripper_prob.ndim == 3, self.gripper_prob.shape


@dataclasses.dataclass
class RoboticsInput(Dataclass):
    images: Dict[str, torch.Tensor]
    input_ids: torch.Tensor
    attn_mask: torch.Tensor
    ee_pose_translation: torch.Tensor
    ee_pose_rotation: torch.Tensor
    gripper: torch.Tensor
    joints: torch.Tensor
    control_tokens_ids: Optional[torch.Tensor]

    @property
    def inputs_embeds(self) -> Optional[torch.Tensor]:
        return None

    @property
    def past_key_values(self) -> Optional[List[torch.Tensor]]:
        return None

    @cached_property
    def multimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are multimodal.
        Return shape is [B]
        """
        return torch.arange(self.input_ids.shape[0], dtype=torch.int64, device=self.input_ids.device)

    @cached_property
    def unimodal_indices(self) -> torch.Tensor:
        """
        Returns a torch.Tensor containing only the indices of the batch examples which are unimodal.
        Return shape is [B]
        """
        return torch.tensor([], dtype=torch.int64, device=self.input_ids.device)


@dataclasses.dataclass
class FlowInput(Dataclass):
    timestep: torch.Tensor
    translation_t: torch.Tensor
    rotation_t: torch.Tensor
    gripper_t: torch.Tensor
    translation_t0: torch.Tensor
    rotation_t0: torch.Tensor
    gripper_t0: torch.Tensor


@dataclasses.dataclass
class RoboticsFlowInput(RoboticsInput):
    """Input to the entire Robotics VLM"""

    flow_input: FlowInput
| 538 | 
            +
             | 
| 539 | 
            +
            @dataclasses.dataclass
         | 
| 540 | 
            +
            class DiffusionInput(Dataclass):
         | 
| 541 | 
            +
                timestep: torch.Tensor
         | 
| 542 | 
            +
                noised_translation: torch.Tensor
         | 
| 543 | 
            +
                noised_rotation: torch.Tensor
         | 
| 544 | 
            +
                noised_gripper: torch.Tensor
         | 
| 545 | 
            +
             | 
| 546 | 
            +
             | 
| 547 | 
            +
            @dataclasses.dataclass
         | 
| 548 | 
            +
            class LLMOutput(Dataclass):
         | 
| 549 | 
            +
                """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""
         | 
| 550 | 
            +
             | 
| 551 | 
            +
                input_ids: torch.Tensor
         | 
| 552 | 
            +
                logits: Optional[torch.Tensor]
         | 
| 553 | 
            +
                output_ids: Optional[torch.Tensor]
         | 
| 554 | 
            +
                loss: Optional[torch.Tensor]
         | 
| 555 | 
            +
                past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
         | 
| 556 | 
            +
                hidden_states: List[torch.Tensor]
         | 
| 557 | 
            +
                text_indices: torch.Tensor
         | 
| 558 | 
            +
                image_indices: torch.Tensor
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                @classmethod
         | 
| 561 | 
            +
                def from_transformers(
         | 
| 562 | 
            +
                    cls,
         | 
| 563 | 
            +
                    input_ids: torch.Tensor,
         | 
| 564 | 
            +
                    llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
         | 
| 565 | 
            +
                    text_indices: Optional[torch.Tensor],
         | 
| 566 | 
            +
                    image_indices: Optional[torch.Tensor],
         | 
| 567 | 
            +
                ) -> "LLMOutput":
         | 
| 568 | 
            +
                    return LLMOutput(
         | 
| 569 | 
            +
                        input_ids=input_ids,
         | 
| 570 | 
            +
                        logits=llm_output.logits,
         | 
| 571 | 
            +
                        output_ids=None,
         | 
| 572 | 
            +
                        loss=llm_output.loss,
         | 
| 573 | 
            +
                        past_key_values=(
         | 
| 574 | 
            +
                            list(llm_output.past_key_values) if llm_output.past_key_values is not None else []
         | 
| 575 | 
            +
                        ),
         | 
| 576 | 
            +
                        hidden_states=(list(llm_output.hidden_states) if llm_output.hidden_states is not None else []),
         | 
| 577 | 
            +
                        text_indices=text_indices,
         | 
| 578 | 
            +
                        image_indices=image_indices,
         | 
| 579 | 
            +
                    )
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                def compress(self) -> "LLMOutput":
         | 
| 582 | 
            +
                    """
         | 
| 583 | 
            +
                    Compress the data contained in the class so it can be moved between CPU and GPU or concatenated
         | 
| 584 | 
            +
                    much faster:
         | 
| 585 | 
            +
                        - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat
         | 
| 586 | 
            +
                        - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat
         | 
| 587 | 
            +
                        - logits - huge last dimension; takes a lot of CPU time to move across devices or concat
         | 
| 588 | 
            +
                    """
         | 
| 589 | 
            +
                    replace: Dict[str, Any] = {
         | 
| 590 | 
            +
                        "hidden_states": [],
         | 
| 591 | 
            +
                        "past_key_values": [],
         | 
| 592 | 
            +
                        "loss": None,
         | 
| 593 | 
            +
                        "input_ids": None,
         | 
| 594 | 
            +
                    }
         | 
| 595 | 
            +
                    if self.logits is not None:
         | 
| 596 | 
            +
                        replace["logits"] = None
         | 
| 597 | 
            +
                        if self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]:
         | 
| 598 | 
            +
                            replace["output_ids"] = (
         | 
| 599 | 
            +
                                torch.index_select(self.logits, dim=1, index=self.text_indices)
         | 
| 600 | 
            +
                                .argmax(dim=-1)
         | 
| 601 | 
            +
                                .to(dtype=torch.int64)
         | 
| 602 | 
            +
                            )
         | 
| 603 | 
            +
                    return self.replace(**replace)
         | 
| 604 | 
            +
             | 
| 605 | 
            +
             | 
| 606 | 
            +
            @dataclasses.dataclass
         | 
| 607 | 
            +
            class RoboticsOutput(Dataclass):
         | 
| 608 | 
            +
                translation: Optional[torch.Tensor]
         | 
| 609 | 
            +
                rotation: Optional[torch.Tensor]
         | 
| 610 | 
            +
                gripper: Optional[torch.Tensor]
         | 
| 611 | 
            +
                token_logits: Optional[torch.Tensor]
         | 
| 612 | 
            +
                token_ids: Optional[torch.Tensor]
         | 
| 613 | 
            +
                llm_output: LLMOutput
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                def compress(self) -> "RoboticsOutput":
         | 
| 616 | 
            +
                    """
         | 
| 617 | 
            +
                    Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
         | 
| 618 | 
            +
                    Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
         | 
| 619 | 
            +
                    can reach millions or billions of values for large vocab_size
         | 
| 620 | 
            +
                    """
         | 
| 621 | 
            +
                    replace: Dict[str, Any] = {
         | 
| 622 | 
            +
                        "llm_output": self.llm_output.compress(),
         | 
| 623 | 
            +
                        "token_logits": None,
         | 
| 624 | 
            +
                    }
         | 
| 625 | 
            +
                    if self.token_logits is not None and self.token_ids is None:
         | 
| 626 | 
            +
                        replace["token_ids"] = torch.argmax(self.token_logits, dim=-1)
         | 
| 627 | 
            +
                    return self.replace(**replace)
         | 
| 628 | 
            +
             | 
| 629 | 
            +
             | 
| 630 | 
            +
            @dataclasses.dataclass
         | 
| 631 | 
            +
            class VLMOutput(Dataclass):
         | 
| 632 | 
            +
                llm_output: LLMOutput
         | 
| 633 | 
            +
                vit_tokens: Optional[torch.Tensor]
         | 
| 634 | 
            +
                attn_mask: torch.Tensor
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                def compress(self) -> "VLMOutput":
         | 
| 637 | 
            +
                    """
         | 
| 638 | 
            +
                    Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
         | 
| 639 | 
            +
                    Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
         | 
| 640 | 
            +
                    can reach millions or billions of values for large vocab_size
         | 
| 641 | 
            +
                    """
         | 
| 642 | 
            +
                    return self.replace(llm_output=self.llm_output.compress())
         | 
| 643 | 
            +
             | 
| 644 | 
            +
             | 
| 645 | 
            +
            def is_quaternion(quaternion: torch.Tensor) -> bool:
         | 
| 646 | 
            +
                return quaternion.shape[-1] == 4
         | 
| 647 | 
            +
             | 
| 648 | 
            +
             | 
| 649 | 
            +
            def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
         | 
| 650 | 
            +
                """
         | 
| 651 | 
            +
                Flip quaternions so they cover only a half the space. If the q_w is negative, flip the quaternion.
         | 
| 652 | 
            +
                If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically,
         | 
| 653 | 
            +
                this doesn't correspond to a single hemisphere of the unit sphere. Follows
         | 
| 654 | 
            +
                https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
         | 
| 655 | 
            +
                """
         | 
| 656 | 
            +
                assert is_quaternion(quaternion), quaternion.shape
         | 
| 657 | 
            +
                with torch.no_grad():
         | 
| 658 | 
            +
                    is_zero = quaternion == 0
         | 
| 659 | 
            +
                    flip_condition = (
         | 
| 660 | 
            +
                        (quaternion[..., -1:] < 0)
         | 
| 661 | 
            +
                        | is_zero[..., -1:] & (quaternion[..., 0:1] < 0)
         | 
| 662 | 
            +
                        | is_zero[..., -1:] & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0)
         | 
| 663 | 
            +
                        | is_zero[..., -1:] & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0)
         | 
| 664 | 
            +
                    )
         | 
| 665 | 
            +
                quaternion = torch.where(flip_condition, -quaternion, quaternion)
         | 
| 666 | 
            +
                return quaternion
         | 
| 667 | 
            +
             | 
| 668 | 
            +
             | 
| 669 | 
            +
            def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
         | 
| 670 | 
            +
                return rotmat.shape[-2:] == torch.Size([3, 3])
         | 
| 671 | 
            +
             | 
| 672 | 
            +
             | 
| 673 | 
            +
            def is_rotmat_9(rotmat: torch.Tensor) -> bool:
         | 
| 674 | 
            +
                return rotmat.shape[-1] == 9
         | 
| 675 | 
            +
             | 
| 676 | 
            +
             | 
| 677 | 
            +
            def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
         | 
| 678 | 
            +
                """Convert any rotmat input to [..., 9] shape"""
         | 
| 679 | 
            +
                if is_rotmat_9(rotmat):
         | 
| 680 | 
            +
                    return rotmat
         | 
| 681 | 
            +
                if is_rotmat_3x3(rotmat):
         | 
| 682 | 
            +
                    return rotmat.reshape(*rotmat.shape[:-2], 9)
         | 
| 683 | 
            +
                raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
         | 
| 684 | 
            +
             | 
| 685 | 
            +
             | 
| 686 | 
            +
            def is_rotmat(rotmat: torch.Tensor) -> bool:
         | 
| 687 | 
            +
                """
         | 
| 688 | 
            +
                Checks if the tensor shape matches that of a rotmat. However, it's not guaranteed the data is a
         | 
| 689 | 
            +
                valid rotmat. `is_orthonormal_rotmat` performs this additional check.
         | 
| 690 | 
            +
                NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
         | 
| 691 | 
            +
                `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
         | 
| 692 | 
            +
                """
         | 
| 693 | 
            +
                return is_rotmat_3x3(rotmat) or is_rotmat_9(rotmat)
         | 
| 694 | 
            +
             | 
| 695 | 
            +
             | 
| 696 | 
            +
            def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
         | 
| 697 | 
            +
                """Convert any rotmat input to [..., 3, 3] shape"""
         | 
| 698 | 
            +
                if rotmat.shape[-1] == 9:
         | 
| 699 | 
            +
                    return rotmat.reshape(*rotmat.shape[:-1], 3, 3)
         | 
| 700 | 
            +
                if rotmat.shape[-2:] == torch.Size([3, 3]):
         | 
| 701 | 
            +
                    return rotmat
         | 
| 702 | 
            +
                raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
         | 
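For reference, a minimal usage sketch of the rotation helpers defined above (assumptions: torch is installed and common_spear.py is importable, e.g. from a checkout of this repository; the tensors are made up for illustration):

import torch

from common_spear import quaternion_half_cover, rotmat_as_3x3, rotmat_as_9

# Scalar-last quaternion [x, y, z, w]: a negative w component gets flipped to the positive cover.
q = torch.tensor([0.5, 0.5, 0.5, -0.5])
assert torch.equal(quaternion_half_cover(q), torch.tensor([-0.5, -0.5, -0.5, 0.5]))

# Rotation matrices round-trip between the [..., 3, 3] and flattened [..., 9] layouts.
rotmats = torch.eye(3).repeat(2, 1, 1)  # [2, 3, 3]
flat = rotmat_as_9(rotmats)             # [2, 9]
assert torch.equal(rotmat_as_3x3(flat), rotmats)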
    	
        config.json
    ADDED
    
@@ -0,0 +1,167 @@
{
  "_auto_class": null,
  "_name_or_path": "/scratch/giuliano_albanese/spear-hf",
  "architectures": [
    "SPEAR1"
  ],
  "attribute_map": {},
  "auto_map": {
    "AutoConfig": "configuration_spear.SPEAR1Config",
    "AutoModel": "modeling_spear.SPEAR1"
  },
  "autoclass": "barrel.pipes.vlams.models.vlams.vlam.VLAM",
  "base_config_key": "",
  "control_module_config": {
    "control_decoder_config": {
      "block_config": {
        "activation": "GELU",
        "attn_implementation": "sdpa",
        "dropout": 0.0,
        "feature_size": 1024,
        "head_dim": 256,
        "hidden_size": 4096,
        "norm": "RMSNorm",
        "num_heads": 8,
        "num_kv_heads": 1,
        "position_embed_config": {
          "base": 10000,
          "cached": true,
          "embedding_dim": 256,
          "num_embeddings": 512
        }
      },
      "num_blocks": 18
    },
    "noised_control_proj_config": {
      "activation": "SiLU",
      "layers": [
        8,
        2048,
        1024,
        1024
      ],
      "norm": null,
      "time_embed": {
        "activation": "SiLU",
        "layers": [],
        "learnable_features": false,
        "max_period": 10000.0,
        "norm": null,
        "num_features": 1024
      }
    },
    "robot_state_proj_config": {
      "activation": "SiLU",
      "fourier": false,
      "layers": [
        8,
        1024
      ],
      "mode": "ee_pose_gripper"
    },
    "rotation_components": 4,
    "token_size": 1024
  },
  "is_composition": false,
  "model_type": "spear1",
  "processor_config": {
    "control_io_config": {
      "future_control_offset_sec": 0.0,
      "future_controls_sequence_length": 5,
      "future_controls_sequence_stride_sec": 0.2,
      "future_frames_sequence_length": 1,
      "future_frames_sequence_stride_sec": null,
      "past_frames_sequence_length": 1,
      "past_frames_stride_sec": null,
      "past_scalars_sequence_length": 1,
      "past_scalars_stride_sec": null,
      "sequence_frames": 1,
      "sequence_frames_stride_sec": null
    },
    "control_stats_path": "barrel/pipes/vlams/types/control_stats.yaml",
    "control_tokenizer_config": {},
    "delta_controls": true,
    "distribution_hyperparams": {
      "alpha": 1.5,
      "beta": 1.0
    },
    "eef_control_frame": false,
    "image_resize": "smart",
    "joints_norm": {
      "high": [
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793,
        3.141592653589793
      ],
      "low": [
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793,
        -3.141592653589793
      ]
    },
    "num_inference_steps": 10,
    "obs_rotation_norm": "none",
    "obs_translation_norm": "bounds_q99",
    "observation_stats_path": "barrel/pipes/vlams/types/observation_stats.yaml",
    "r0_distribution": "uniform",
    "rotation_format": "quaternion",
    "rotation_norm": "none",
    "sig_min": 0.001,
    "timestep_distribution": "beta",
    "translation_norm": {
      "high": [
        0.04,
        0.04,
        0.04
      ],
      "low": [
        -0.04,
        -0.04,
        -0.04
      ]
    }
  },
  "sub_configs": {},
  "torch_dtype": "float32",
  "transformers_version": "4.47.0",
  "vlm_config": {
    "attn_implementation": "flash_attention_2",
    "depth_tokens": 1024,
    "lm_head": false,
    "mean_resizing": false,
    "model_id": "google/paligemma-3b-mix-224",
    "paligemma_3d_config": {
      "depth_config": {
        "hf_filename": "moge/moge-vit-large-patch-14-backbone.pt",
        "hf_hub_repo": "nikonikolov/vlams"
      },
      "depth_layers": 4,
      "depth_only": false,
      "mask_prob": 0.0,
      "projection": "features_add"
    },
    "processor_config": {
      "image_sizes": {
        "main": {
          "height": 210,
          "width": 280
        },
        "wrist": {
          "height": 112,
          "width": 112
        }
      },
      "image_token": "<image>",
      "max_language_tokens": 75
    },
    "train_only_depth_tokens": false
  }
}
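Given the auto_map entries above, the checkpoint is meant to be loaded through the transformers dynamic-module mechanism, which requires trust_remote_code=True so that configuration_spear.py and modeling_spear.py from this repository are executed. A minimal sketch (the repo id below is a placeholder for wherever this repository is hosted on the Hub):

from transformers import AutoConfig, AutoModel

repo_id = "<org>/<spear1-repo>"  # placeholder: the Hub id of this repository

# auto_map routes AutoConfig -> configuration_spear.SPEAR1Config and AutoModel -> modeling_spear.SPEAR1
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.model_type)  # "spear1"

Note that vlm_config requests attn_implementation="flash_attention_2", so flash-attn may need to be installed (or a different attention implementation selected, if the model supports it).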
    	
        configuration_spear.py
    ADDED
    
@@ -0,0 +1,347 @@
import collections
import collections.abc
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .common_spear import (
    Config,
    HFConfigMixin,
    Normalization,
    ResizeMode,
    RotationFormat,
)


class InputSequencingConfig(Config):
    """
    past_frames_sequence_length: number of past images needed in a single robot state
    past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc.,
        needed in a single robot state
    past_frames_stride_sec: sampling rate, determines how far apart in time each point in the sequence
        is. If None, ignored and takes the default data collection frequency from the dataset
    past_scalars_stride_sec: similar to past_frames_stride_sec

    sequence_frames: number of temporally-sequential points in a single example in the batch
    sequence_frames_stride_sec: sampling rate

    Understanding sequence_frames:
        TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
            but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary

        - past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
            future_frames_sequence_length are hyperparameters referring to a SINGLE dataset example / 'state'.
            It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
            number of observations that comprise a single 'state'
        - sequence_frames is a hyperparameter referring to the entire learning process. It controls the size
            of the sequence dimension in the batch. It's treated similarly to the batch dimension, with the
            difference that points in the sequence dimension are temporally aligned. Unlike `past_*`
            attributes, in supervised learning a label is loaded for every point in the sequence dimension
            and the loss is usually computed over the entire sequence dimension.
    """

    past_scalars_sequence_length: int = 1
    past_frames_sequence_length: int = 1
    past_scalars_stride_sec: Optional[float] = None
    past_frames_stride_sec: Optional[float] = None
    sequence_frames: int = 1
    sequence_frames_stride_sec: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        assert self.past_scalars_sequence_length >= 1, self.past_scalars_sequence_length
        assert self.past_frames_sequence_length >= 1, self.past_frames_sequence_length
        assert self.sequence_frames >= 1, self.sequence_frames
        if self.past_frames_stride_sec is not None:
            assert self.past_frames_stride_sec >= 0.0, self.past_frames_stride_sec
        if self.past_scalars_stride_sec is not None:
            assert self.past_scalars_stride_sec >= 0.0, self.past_scalars_stride_sec
        if self.sequence_frames_stride_sec is not None:
            assert self.sequence_frames_stride_sec >= 0.0, self.sequence_frames_stride_sec

    def assert_same_past(self) -> None:
        assert (
            self.past_frames_stride_sec == self.past_scalars_stride_sec
        ), f"{self.past_frames_stride_sec} != {self.past_scalars_stride_sec}"
        assert (
            self.past_frames_sequence_length == self.past_scalars_sequence_length
        ), f"{self.past_frames_sequence_length} != {self.past_scalars_sequence_length}"


class OutputSequencingConfig(Config):
    """
    future_controls_sequence_length: number of control steps in the future the model predicts
    future_frames_sequence_length: number of future frames the model predicts
        (only relevant for neural networks that learn some sort of a world model)

    future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
        that determines how far apart in time each point in the sequence is. If None,
        ignored and takes the default data collection frequency from the dataset

    future_control_offset_sec: time interval between the last observation and the first
    point at which control is predicted. Serves as a 'causality hyperparameter', allowing
    for predicting controls slightly further into the future in environments with dynamics
    where the observed effects of an action appear slightly later
    """

    future_controls_sequence_length: int = 1
    future_controls_sequence_stride_sec: Optional[float] = None
    future_frames_sequence_length: int = 1
    future_frames_sequence_stride_sec: Optional[float] = None
    future_control_offset_sec: float = 0.0

    def __post_init__(self):
        super().__post_init__()
        assert self.future_controls_sequence_length >= 1, self.future_controls_sequence_length
        assert self.future_frames_sequence_length >= 1, self.future_frames_sequence_length
        assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
        if self.future_controls_sequence_stride_sec is not None:
            assert self.future_controls_sequence_stride_sec >= 0.0, self.future_controls_sequence_stride_sec
        if self.future_frames_sequence_stride_sec is not None:
            assert self.future_frames_sequence_stride_sec >= 0.0, self.future_frames_sequence_stride_sec


class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
    pass

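# Example (illustrative; the values are the processor_config.control_io_config entries from the
# config.json shown above): the shipped checkpoint corresponds to
#
#     ControlDataIOConfig(
#         past_frames_sequence_length=1,            # a single current frame per state
#         past_scalars_sequence_length=1,           # a single scalar observation per state
#         future_controls_sequence_length=5,        # predict 5 future control steps
#         future_controls_sequence_stride_sec=0.2,  # spaced 0.2 s apart, i.e. a 1-second horizon
#         future_control_offset_sec=0.0,            # starting right after the last observation
#     )
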
class ControlTokenizerConfig(Config):
    pass


class EmptyTokenizerConfig(ControlTokenizerConfig):
    pass


class VLAMProcessorConfig(Config):
    control_io_config: ControlDataIOConfig = ControlDataIOConfig()
    obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    obs_rotation_norm: Normalization = Normalization.NONE
    translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
    rotation_norm: Normalization = Normalization.NONE
    joints_norm: Dict[str, Tuple[float, ...]] = {
        "low": (-np.pi,) * 7,
        "high": (np.pi,) * 7,
    }
    rotation_format: RotationFormat = RotationFormat.QUATERNION
    eef_control_frame: bool = False
    delta_controls: bool = False
    image_resize: ResizeMode = ResizeMode.SMART
    control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
    control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
    observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"

    def __post_init__(self):
        super().__post_init__()
        if isinstance(self.translation_norm, collections.abc.Mapping):
            assert all((len(value) == 3 for value in self.translation_norm.values())), self.translation_norm
            assert set(self.translation_norm.keys()) in (
                {"low", "high"},
                {"mean", "std"},
            ), self.translation_norm
        assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
        assert all((len(value) == 7 for value in self.joints_norm.values())), self.joints_norm
        assert set(self.joints_norm.keys()) in (
            {"low", "high"},
            {"mean", "std"},
        ), self.joints_norm


class RegressionProcessorConfig(VLAMProcessorConfig):
    pass


class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
    num_inference_steps: int
    r0_distribution: str = "uniform"
    timestep_distribution: str
    distribution_hyperparams: Dict[str, Any] = {}
    sig_min: float = 0.001

    def __post_init__(self):
        super().__post_init__()
        assert self.r0_distribution in ["normal", "uniform"]


class VLMConfig(Config):
    pass


class VLMProcessorConfig(Config):
    pass


class ImageSizeConfig(Config):
    width: int
    height: int

    def to_dict(self):
        return {"width": self.width, "height": self.height}


class PaliGemmaProcessorConfig(Config):
    image_token: str = "<image>"
    image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
    max_language_tokens: int = 75

    def __post_init__(self):
        super().__post_init__()
        self.image_sizes = {
            camera_name: (
                ImageSizeConfig(**camera_image_size)
                if not isinstance(camera_image_size, ImageSizeConfig)
                else camera_image_size
            )
            for camera_name, camera_image_size in self.image_sizes.items()
        }
        for camera_name, camera_image_size in self.image_sizes.items():
            assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
            assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"

    @property
    def num_image_tokens(self) -> Dict[str, int]:
        return {
            camera_name: camera_image_size.height // 14 * (camera_image_size.width // 14)
            for (camera_name, camera_image_size) in self.image_sizes.items()
        }

    @property
    def is_single_image_size(self) -> bool:
        return (
            len(self.image_sizes) == 1
            or len(set(((image_size.height, image_size.width) for image_size in self.image_sizes.values())))
            == 1
        )

    @property
    def camera_names(self) -> List[str]:
        return list(self.image_sizes.keys())

    def to_dict(self) -> Dict[str, Any]:
        base_dict = {
            "image_token": self.image_token,
            "max_language_tokens": self.max_language_tokens,
        }
        base_dict["image_sizes"] = {
            camera_name: camera_image_size.to_dict()
            for camera_name, camera_image_size in self.image_sizes.items()
        }
        return base_dict

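# Worked example (illustrative): with the image_sizes from the config.json shown above,
# num_image_tokens yields one token per 14x14 ViT patch:
#     main  (280 x 210): (210 // 14) * (280 // 14) = 15 * 20 = 300 tokens
#     wrist (112 x 112): (112 // 14) * (112 // 14) =  8 *  8 =  64 tokens
# which is why __post_init__ asserts that both height and width are divisible by 14.
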
class PaliGemmaVLMConfig(Config):
    model_id: str = "google/paligemma-3b-mix-224"
    attn_implementation: str = "flash_attention_2"
    processor_config: PaliGemmaProcessorConfig
    lm_head: bool = False
    paligemma_3d_config: Dict[str, Any] = {}
    depth_tokens: int = 0
    train_only_depth_tokens: bool = False
    mean_resizing: bool = False

    def __post_init__(self):
        super().__post_init__()
        if self.train_only_depth_tokens:
            assert self.depth_tokens > 0, self.depth_tokens
        if self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0:
            raise NotImplementedError(
                f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
            )

    @property
    def paligemma_3d_config_dict(self) -> Dict[str, Any]:
        if len(self.paligemma_3d_config) == 0:
            return {}
        config = dict(self.paligemma_3d_config)
        config["depth_config"] = dict(config["depth_config"])
        config["depth_config"]["image_sizes"] = {
            camera_name: camera_image_size.to_dict()
            for camera_name, camera_image_size in self.processor_config.image_sizes.items()
        }
        return config

    @property
    def with_depth(self) -> bool:
        return len(self.paligemma_3d_config) > 0


class FourierFeaturesConfig(Config):
    num_features: int = 256
    learnable_features: bool = False
    max_period: float = 10000.0
    layers: List[int] = [256, 512, 256]
    activation: str = "SiLU"
    norm: Optional[str] = None


class NoisedControlProjectorConfig(Config):
    time_embed: FourierFeaturesConfig
    layers: List[int] = []
    activation: str = "SiLU"
    norm: Optional[str] = None


class RobotStateProjectorConfig(Config):
    layers: List[int] = []
    mode: str = "none"
    activation: str = "GELU"
    fourier: bool = False

    def __post_init__(self):
        super().__post_init__()
        assert self.mode in [
            "ee_pose",
            "ee_pose_gripper",
            "ee_pose_joints",
            "joints",
            "all",
            "none",
        ], self.mode


class RotaryPositionalEncodingConfig(Config):
    num_embeddings: int
    embedding_dim: int
    base: int = 10000
    cached: bool = True


class PiZeroFlowMatchingDecoderBlockConfig(Config):
    feature_size: int
    head_dim: int = 128
    num_heads: int = 32
    num_kv_heads: int = 1
    hidden_size: int
    activation: str = "GELU"
    norm: str = "RMSNorm"
    dropout: float = 0.0
    attn_implementation: str = "sdpa"
    position_embed_config: RotaryPositionalEncodingConfig


class PiZeroFlowMatchingDecoderConfig(Config):
    num_blocks: int
    block_config: PiZeroFlowMatchingDecoderBlockConfig


class PiZeroFlowMatchingModuleConfig(Config):
    token_size: int = 1024
    noised_control_proj_config: NoisedControlProjectorConfig
         | 
| 330 | 
            +
                robot_state_proj_config: RobotStateProjectorConfig
         | 
| 331 | 
            +
                control_decoder_config: PiZeroFlowMatchingDecoderConfig
         | 
| 332 | 
            +
                rotation_components: int = 3
         | 
| 333 | 
            +
             | 
| 334 | 
            +
             | 
| 335 | 
            +
            class SPEAR1Config(HFConfigMixin, Config):
         | 
| 336 | 
            +
                model_type: str = "spear1"
         | 
| 337 | 
            +
                processor_config: PiZeroFlowProcessorConfig
         | 
| 338 | 
            +
                vlm_config: PaliGemmaVLMConfig
         | 
| 339 | 
            +
                control_module_config: PiZeroFlowMatchingModuleConfig
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                def __init__(self, **kwargs):
         | 
| 342 | 
            +
                    if "auto_map" not in kwargs:
         | 
| 343 | 
            +
                        kwargs["auto_map"] = {
         | 
| 344 | 
            +
                            "AutoConfig": "configuration_spear.SPEAR1Config",
         | 
| 345 | 
            +
                            "AutoModel": "modeling_spear.SPEAR1",
         | 
| 346 | 
            +
                        }
         | 
| 347 | 
            +
                    super().__init__(**kwargs)
         | 
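
Since SPEAR1Config registers an `auto_map` pointing at `configuration_spear.SPEAR1Config` and `modeling_spear.SPEAR1`, the checkpoint can be loaded through the transformers Auto classes with remote code enabled. A minimal sketch, not part of the commit; the path below is a placeholder for this repository's id or a local clone:

import transformers

path = "<repo-id-or-local-path>"  # placeholder, substitute the actual repository
config = transformers.AutoConfig.from_pretrained(path, trust_remote_code=True)
model = transformers.AutoModel.from_pretrained(path, trust_remote_code=True)
# `config` resolves to SPEAR1Config and `model` to SPEAR1 via the auto_map above.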
    	
        generation_config.json
    ADDED
    
{
  "transformers_version": "4.47.0"
}
    	
        model-00001-of-00003.safetensors
    ADDED
    
version https://git-lfs.github.com/spec/v1
oid sha256:a0992d3b5ffdc8b896812ed19801bc9ebda65708237681ced90e642c90e0a0d2
size 4962008480
    	
        model-00002-of-00003.safetensors
    ADDED
    
version https://git-lfs.github.com/spec/v1
oid sha256:db48d29ee9567705a81718181eac6c644d2d996f1e91c497e8c891702050c36e
size 4999821656
    	
        model-00003-of-00003.safetensors
    ADDED
    
version https://git-lfs.github.com/spec/v1
oid sha256:1c7e1d6dae46553546f53a3c9fa76a8a2d2e07664a575ce38962ae2930eb7562
size 4245980072
    	
        model.safetensors.index.json
    ADDED
    
The diff for this file is too large to render.
    	
        modeling_spear.py
    ADDED
    
The diff for this file is too large to render.
    	
        processing_spear.py
    ADDED
    
import collections
import collections.abc
import re
import warnings
from abc import abstractmethod
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar

import numpy as np
import PIL.Image
import roma
import torch
import torchvision.transforms.v2
import transformers
import yaml

from .common_spear import (
    Configurable,
    FlowInput,
    Normalization,
    ResizeMode,
    RoboticsControlPlan,
    RoboticsFlowInput,
    RoboticsInput,
    RoboticsOutput,
    RoboticsTarget,
    RotationFormat,
    expand_dims,
    is_quaternion,
    is_rotmat,
    is_rotmat_3x3,
    is_rotmat_9,
    quaternion_half_cover,
    rotmat_as_3x3,
    rotmat_as_9,
)
from .configuration_spear import (
    ControlDataIOConfig,
    ImageSizeConfig,
    PaliGemmaProcessorConfig,
)

class VLMProcessor(Configurable):
    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: ...

    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        pass

    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        pass

class EmptyTokenizer(Configurable):
    """
    Takes the LLM hidden states from `llm_layer_indices` and concatenates them to produce the
    desired result. Includes the hidden states for the image tokens.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        self.tokenizer = tokenizer

    def __call__(self, *_) -> str:
        return ""

def np_unique(
    data: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the unique elements of `data` and the corresponding indices.

    np.unique returns the unique values in sorted order, even if the source is not sorted, so its
    inverse indices refer to that sorted order rather than to the order of first occurrence. This
    helper returns the unique values in first-occurrence order, together with consistent
    first-occurrence indices, inverse indices, and counts.
    """
    (_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True)
    (_, indices_of_first_occurence, inverse_indices, counts) = np.unique(
        indices[inverse], return_index=True, return_inverse=True, return_counts=True
    )
    unique_ids = data[indices_of_first_occurence]
    return unique_ids, indices_of_first_occurence, inverse_indices, counts

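# Illustrative usage (not part of the original file): a quick check of how np_unique
# differs from plain np.unique on unsorted data.
def _np_unique_example() -> None:
    data = np.array([30, 10, 30, 20, 10])
    values, first_idx, inverse, counts = np_unique(data)
    # values keeps first-occurrence order; plain np.unique would return sorted [10, 20, 30].
    assert np.array_equal(values, np.array([30, 10, 20]))
    assert np.array_equal(first_idx, np.array([0, 1, 3]))
    assert np.array_equal(counts, np.array([2, 2, 1]))
    assert np.array_equal(data, values[inverse])
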
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Args:
        angles: Euler angles in radians in the format 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices
    """
    return roma.euler_to_rotmat(convention="xyz", angles=angles, degrees=False)


def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Args:
        angles: Euler angles in radians in the format 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 4] containing unit quaternions
    """
    return roma.euler_to_unitquat(convention="xyz", angles=angles, degrees=False, normalize=True)


def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Args:
        quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4]
        eps: Small constant to prevent division by zero
    Returns:
        torch.Tensor of shape [..., 4] of unit quaternions
    """
    return quaternion / (quaternion.norm(dim=-1, keepdim=True).detach() + eps)


def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3] containing Euler angles in radians ('xyz' convention)
    """
    unit_quat = normalize_quaternion(quaternion)
    euler = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False)
    return euler


def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    unit_quat = normalize_quaternion(quaternion)
    rotmat = roma.unitquat_to_rotmat(unit_quat)
    return rotmat


def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3]
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_unitquat(rotmat)


def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3]
    Returns:
        Batch of Euler angles in radians, shape [..., 3]
    """
    rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_euler(convention="xyz", rotmat=rotmat, as_tuple=False, degrees=False)

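# Illustrative usage (not part of the original file): a round trip through the
# converters above; quaternions follow roma's xyzw layout with w last.
def _rotation_roundtrip_example() -> None:
    angles = torch.tensor([[0.1, -0.2, 0.3]])      # [..., 3] Euler angles in radians
    quat = euler_to_unit_quaternion(angles)         # [..., 4] unit quaternion
    rotmat = quaternion_to_rotmat(quat)             # [..., 3, 3] rotation matrix
    recovered = quaternion_to_euler(quat)           # back to [..., 3]
    assert rotmat.shape == (1, 3, 3)
    assert torch.allclose(recovered, angles, atol=1e-05)
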
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
    """
    Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
        - Let SVD(M) = U \Sigma V^T
        - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
        - det(UV^T) ensures that det(SVD+(M)) = 1
        - The return value is the rotation matrix (orthonormal) with the least-squares distance to M

    Args:
        x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
    Returns:
        torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
        )
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            matrices = x.view(-1, 3, 3)
            matrices = matrices.to(dtype=torch.float32)
            (u, s, v) = torch.svd(matrices)
            vt = torch.transpose(v, 1, 2)
            det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
            diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
            result = torch.matmul(u, diag_vt)
            result = result.view(*x.shape)
    result = result.to(dtype=x.dtype)
    return result

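# Illustrative usage (not part of the original file): project a noisy matrix back
# onto SO(3) and check that the result is a proper rotation (orthonormal, det = 1).
def _symmetric_orthogonalization_example() -> None:
    rotmat = euler_to_rotmat(torch.tensor([0.3, 0.1, -0.4]))   # exact rotation, [3, 3]
    noisy = rotmat + 0.05 * torch.randn(3, 3)                  # no longer orthonormal
    projected = symmetric_orthogonalization(noisy)
    assert torch.allclose(projected @ projected.T, torch.eye(3), atol=1e-05)
    assert torch.isclose(torch.det(projected), torch.tensor(1.0), atol=1e-05)
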
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none"
) -> torch.Tensor | bool:
    """
    Check if a rotation matrix is orthonormal or not.
    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
            anything smaller than 1e-6 might incorrectly flag some orthonormal matrices as not orthonormal
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape or bool
    """
    assert is_rotmat(rotmat)
    rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
    if reduction == "none":
        return is_orthonormal
    if reduction == "all":
        return bool(torch.all(is_orthonormal).item())
    raise ValueError(f"Unknown reduction mode {reduction}")

def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3,
    also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles
    when accidentally `rotmat.shape[-2:] == [3, 3]`
    """
    return (
        is_rotmat_9(rotmat)
        or is_rotmat_3x3(rotmat)
        and is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all")
    )


def is_euler(euler: torch.Tensor) -> bool:
    return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler)


def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor:
    if is_quaternion(rotation):
        return normalize_quaternion(rotation)
    if is_euler(rotation):
        return rotation
    if is_rotmat(rotation):
        is_flat = is_rotmat_9(rotation)
        rotation = rotmat_as_3x3(rotation) if is_flat else rotation
        rotmat = roma.special_gramschmidt(rotation)
        rotmat = rotmat_as_9(rotmat) if is_flat else rotmat
        return rotmat
    raise ValueError(f"Unknown rotation format: {rotation.shape}")


def rotation_format_from_tensor(rotation) -> RotationFormat:
    if is_quaternion(rotation):
        return RotationFormat.QUATERNION
    if is_orthonormal_rotmat(rotation):
        return RotationFormat.ROTMAT
    if is_euler(rotation):
        return RotationFormat.EULER
    raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format")

def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none"
) -> torch.Tensor | bool:
    """
    Check if a quaternion is normalized or not.
    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Tolerance for numerical comparisons
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor with the same batch shape or bool
    """
    assert is_quaternion(quaternion)
    is_norm = torch.isclose(
        quaternion.norm(dim=-1, keepdim=True),
        torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
        atol=epsilon,
    )
    if reduction == "none":
        return is_norm
    if reduction == "all":
        return bool(torch.all(is_norm).item())
    raise ValueError(f"Unknown reduction mode {reduction}")

def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    is_np = isinstance(rotation, np.ndarray)
    if is_np:
        rotation = torch.from_numpy(rotation)
    if is_quaternion(rotation):
        if autonorm and not is_unit_quaternion(rotation, reduction="all"):
            rotation = normalize_quaternion(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotation
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(quaternion_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = quaternion_to_euler(rotation)
        else:
            raise NotImplementedError(f"Unsupported rotation format: {output_format}")
    elif is_orthonormal_rotmat(rotation):
        if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"):
            rotation = symmetric_orthogonalization(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotmat_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(rotation)
        elif output_format == RotationFormat.EULER:
            output = rotmat_to_euler(rotation)
        else:
            raise NotImplementedError(f"Unsupported rotation format: {output_format}")
    elif is_euler(rotation):
        if output_format == RotationFormat.QUATERNION:
            output = euler_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(euler_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = rotation
        else:
            raise NotImplementedError(f"Unsupported rotation format: {output_format}")
    else:
        raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}")
    if output_format == RotationFormat.QUATERNION and half_cover:
        output = quaternion_half_cover(output)
    if is_np:
        output = output.numpy()
    return output

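# Illustrative usage (not part of the original file): convert_rotation infers the
# input encoding from the tensor shape and also accepts numpy arrays.
def _convert_rotation_example() -> None:
    euler = np.array([[0.0, 0.0, np.pi / 2]], dtype=np.float32)    # [..., 3] Euler angles
    quat = convert_rotation(euler, RotationFormat.QUATERNION)      # [..., 4], xyzw layout
    rotmat9 = convert_rotation(quat, RotationFormat.ROTMAT)        # [..., 9], flattened 3x3
    assert quat.shape == (1, 4) and rotmat9.shape == (1, 9)
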
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations, each encoded w.r.t. the PREVIOUS rotation frame in the
    sequence, into rotations encoded w.r.t. the 0-th frame preceding the sequence.

    Ex:
        `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame,
            implicitly encoded in R_01, and R_10 converts from the R0 frame to the R1 frame
        Output: R_01, R_02, R_03, R_04

    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
            (R_01, R_02, R_03, R_04, ...)

    TODO: Can this be done without a for loop?
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    batch_dims = np.arange(rotation_sequence.ndim - 2)
    delta_rotations = torch.cat(
        [rotation_sequence[..., :1, :]]
        + [
            roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2))
            for i in range(2, rotation_sequence.shape[-2] + 1)
        ],
        dim=-2,
    )
    delta_rotations = convert_rotation(delta_rotations, rotation_format)
    return delta_rotations

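# Illustrative usage (not part of the original file): for a batch of B sequences of
# S per-step delta quaternions, the output keeps the same [B, S, 4] shape, with each
# entry now expressed relative to the frame preceding the sequence.
def _delta_to_relative_rotations_example() -> None:
    deltas = euler_to_unit_quaternion(0.01 * torch.randn(2, 5, 3))   # [B=2, S=5, 4]
    accumulated = delta_to_relative_rotations(deltas)
    assert accumulated.shape == deltas.shape
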
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """Make sure image is of type np.ndarray and HWC format"""
    if isinstance(image, PIL.Image.Image):
        image = np.asarray(image)
    assert isinstance(image, np.ndarray), type(image)
    assert image.ndim in [2, 3], image.shape
    if image.ndim == 3:
        assert image.shape[-1] <= 4, image.shape
    return image


def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    if isinstance(image, np.ndarray):
        (height, width) = image.shape[:2]
    else:
        (width, height) = image.size
    return height, width


def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image adding a symmetric border around the height/width."""
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size["width"], target_size["height"])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    (horizontal_pad, vertical_pad) = (
        int((target_width - width) / 2),
        int((target_height - height) / 2),
    )
    if isinstance(image, np.ndarray):
        padding = ((vertical_pad, vertical_pad), (horizontal_pad, horizontal_pad)) + ((0, 0),) * (
            image.ndim - 2
        )
        image = np.pad(image, padding, mode="constant", constant_values=pad_value)
    else:
        padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
        image = torchvision.transforms.v2.functional.pad(
            image, padding=padding, fill=pad_value, padding_mode="constant"
        )
    return image


def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image to a target aspect ratio."""
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        pad_size = {"width": round(height * target_wh_ratio), "height": height}
    else:
        pad_size = {"width": width, "height": round(width / target_wh_ratio)}
    image = pad_image(image, target_size=pad_size, pad_value=pad_value)
    return image

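# Illustrative usage (not part of the original file): padding a 480x640 HWC image
# to a square (1:1) aspect ratio adds a symmetric top/bottom border.
def _pad_image_to_ratio_example() -> None:
    image = np.zeros((480, 640, 3), dtype=np.uint8)
    padded = pad_image_to_ratio(image, target_wh_ratio=1.0, pad_value=0)
    assert padded.shape == (640, 640, 3)
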
| 443 | 
            +
            def crop_image(
         | 
| 444 | 
            +
                image: np.ndarray | PIL.Image.Image,
         | 
| 445 | 
            +
                start_height: int,
         | 
| 446 | 
            +
                start_width: int,
         | 
| 447 | 
            +
                target_height: int,
         | 
| 448 | 
            +
                target_width: int,
         | 
| 449 | 
            +
            ) -> np.ndarray | PIL.Image.Image:
         | 
| 450 | 
            +
                np_image = assert_np_hwc_or_hw_image(image)
         | 
| 451 | 
            +
                (height, width) = hw_from_image(image)
         | 
| 452 | 
            +
                assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
         | 
| 453 | 
            +
                 assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
         | 
| 454 | 
            +
                (start_height, start_width) = (round(start_height), round(start_width))
         | 
| 455 | 
            +
                (target_height, target_width) = (round(target_height), round(target_width))
         | 
| 456 | 
            +
                np_image = np_image[
         | 
| 457 | 
            +
                    start_height : start_height + target_height,
         | 
| 458 | 
            +
                    start_width : start_width + target_width,
         | 
| 459 | 
            +
                    ...,
         | 
| 460 | 
            +
                ]
         | 
| 461 | 
            +
                image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
         | 
| 462 | 
            +
                return image
         | 
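             # Usage sketch (hypothetical helper): cut a 4x5 patch out of a 6x8 frame, assuming
             # assert_np_hwc_or_hw_image accepts plain uint8 HWC arrays as the callers above imply.
             def _example_crop_image() -> None:
                 frame = np.arange(6 * 8 * 3, dtype=np.uint8).reshape(6, 8, 3)
                 patch = crop_image(frame, start_height=1, start_width=2, target_height=4, target_width=5)
                 assert patch.shape == (4, 5, 3)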
| 463 | 
            +
             | 
| 464 | 
            +
             | 
| 465 | 
            +
            def crop_image_center(
         | 
| 466 | 
            +
                image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
         | 
| 467 | 
            +
            ) -> np.ndarray | PIL.Image.Image:
         | 
| 468 | 
            +
                np_image = assert_np_hwc_or_hw_image(image)
         | 
| 469 | 
            +
                (height, width) = np_image.shape[:2]
         | 
| 470 | 
            +
                (target_height, target_width) = (target_size["height"], target_size["width"])
         | 
| 471 | 
            +
                assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
         | 
| 472 | 
            +
                 assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
         | 
| 473 | 
            +
                top = (height - target_height) // 2
         | 
| 474 | 
            +
                left = (width - target_width) // 2
         | 
| 475 | 
            +
                np_image = crop_image(np_image, top, left, target_height, target_width)
         | 
| 476 | 
            +
                image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
         | 
| 477 | 
            +
                return image
         | 
| 478 | 
            +
             | 
| 479 | 
            +
             | 
| 480 | 
            +
            def crop_image_to_ratio(
         | 
| 481 | 
            +
                image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
         | 
| 482 | 
            +
            ) -> PIL.Image.Image | np.ndarray:
         | 
| 483 | 
            +
                """Pad image to a target aspect ratio."""
         | 
| 484 | 
            +
                (height, width) = hw_from_image(image)
         | 
| 485 | 
            +
                wh_ratio = width / height
         | 
| 486 | 
            +
                if target_wh_ratio >= wh_ratio:
         | 
| 487 | 
            +
                    crop_size = {"width": width, "height": round(width / target_wh_ratio)}
         | 
| 488 | 
            +
                else:
         | 
| 489 | 
            +
                    crop_size = {"width": round(height * target_wh_ratio), "height": height}
         | 
| 490 | 
            +
                image = crop_image_center(image, target_size=crop_size)
         | 
| 491 | 
            +
                return image
         | 
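             # Usage sketch (hypothetical helper): center-crop a 400x300 portrait image to a
             # square (target_wh_ratio=1.0), which trims the height down to 300.
             def _example_crop_image_to_ratio() -> None:
                 portrait = np.zeros((400, 300, 3), dtype=np.uint8)
                 square = crop_image_to_ratio(portrait, target_wh_ratio=1.0)
                 assert square.shape[:2] == (300, 300)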
| 492 | 
            +
             | 
| 493 | 
            +
             | 
| 494 | 
            +
            def crop_and_pad_image_to_ratio(
         | 
| 495 | 
            +
                image: PIL.Image.Image | np.ndarray,
         | 
| 496 | 
            +
                target_wh_ratio: float,
         | 
| 497 | 
            +
                mode: ResizeMode | str,
         | 
| 498 | 
            +
                pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
         | 
| 499 | 
            +
            ) -> PIL.Image.Image | np.ndarray:
         | 
| 500 | 
            +
                """
         | 
| 501 | 
            +
                Crop and pad an image to a target size depending on the mode.
         | 
| 502 | 
            +
                It's expected that the source image and target size have different aspect ratios.
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                Args:
         | 
| 505 | 
            +
                    image: The image to crop and pad.
         | 
| 506 | 
            +
                     target_wh_ratio: The target aspect ratio (width / height) to crop and pad the image to.
         | 
| 507 | 
            +
                    mode: The mode to use for cropping and padding.
         | 
| 508 | 
            +
                """
         | 
| 509 | 
            +
                (height, width) = hw_from_image(image)
         | 
| 510 | 
            +
                wh_ratio = width / height
         | 
| 511 | 
            +
                if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
         | 
| 512 | 
            +
                    return image
         | 
| 513 | 
            +
                if mode == ResizeMode.SMART:
         | 
| 514 | 
            +
                    aspect_ratio = max(width, height) / min(width, height)
         | 
| 515 | 
            +
                    target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
         | 
| 516 | 
            +
                    if aspect_ratio == 1:
         | 
| 517 | 
            +
                        if target_ratio >= 4 / 3 - 0.01:
         | 
| 518 | 
            +
                            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
         | 
| 519 | 
            +
                            image = crop_image_to_ratio(image, crop_wh_ratio)
         | 
| 520 | 
            +
                        else:
         | 
| 521 | 
            +
                            pass
         | 
| 522 | 
            +
                    elif aspect_ratio <= 4 / 3 + 0.01:
         | 
| 523 | 
            +
                         if (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):  # source and target orientations differ
         | 
| 524 | 
            +
                            image = crop_image_to_ratio(image, 1.0)
         | 
| 525 | 
            +
                     elif (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):  # source and target orientations differ
         | 
| 526 | 
            +
                        image = crop_image_to_ratio(image, 1.0)
         | 
| 527 | 
            +
                    elif target_ratio >= 4 / 3 + 0.01:
         | 
| 528 | 
            +
                        pass
         | 
| 529 | 
            +
                    else:
         | 
| 530 | 
            +
                        crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
         | 
| 531 | 
            +
                        image = crop_image_to_ratio(image, crop_wh_ratio)
         | 
| 532 | 
            +
                    image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
         | 
| 533 | 
            +
                elif mode == ResizeMode.PAD:
         | 
| 534 | 
            +
                    image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
         | 
| 535 | 
            +
                elif mode == ResizeMode.CROP:
         | 
| 536 | 
            +
                    image = crop_image_to_ratio(image, target_wh_ratio)
         | 
| 537 | 
            +
                else:
         | 
| 538 | 
            +
                    raise ValueError(f"Mode {mode} not supported")
         | 
| 539 | 
            +
                return image
         | 
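             # Usage sketch (hypothetical helper) of the PAD and CROP modes on a 480x640 frame
             # targeting a square output; SMART additionally applies the 4:3 heuristics above
             # before padding.
             def _example_crop_and_pad_modes() -> None:
                 frame = np.zeros((480, 640, 3), dtype=np.uint8)
                 padded = crop_and_pad_image_to_ratio(frame, target_wh_ratio=1.0, mode=ResizeMode.PAD)
                 assert padded.shape[:2] == (640, 640)  # height is padded up to the width
                 cropped = crop_and_pad_image_to_ratio(frame, target_wh_ratio=1.0, mode=ResizeMode.CROP)
                 assert cropped.shape[:2] == (480, 480)  # width is center-cropped to the height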
| 540 | 
            +
             | 
| 541 | 
            +
             | 
| 542 | 
            +
            def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
         | 
| 543 | 
            +
                if isinstance(image, PIL.Image.Image):
         | 
| 544 | 
            +
                    return image.mode in [
         | 
| 545 | 
            +
                        "1",
         | 
| 546 | 
            +
                        "L",
         | 
| 547 | 
            +
                        "LA",
         | 
| 548 | 
            +
                        "La",
         | 
| 549 | 
            +
                        "P",
         | 
| 550 | 
            +
                        "PA",
         | 
| 551 | 
            +
                        "F",
         | 
| 552 | 
            +
                        "I",
         | 
| 553 | 
            +
                        "I;16",
         | 
| 554 | 
            +
                        "I;16L",
         | 
| 555 | 
            +
                        "I;16B",
         | 
| 556 | 
            +
                        "I;16N",
         | 
| 557 | 
            +
                    ]
         | 
| 558 | 
            +
                if isinstance(image, np.ndarray):
         | 
| 559 | 
            +
                    return image.ndim == 2 or image.ndim == 3 and image.shape[2] == 1
         | 
| 560 | 
            +
                raise ValueError(f"Unsupported image type: {type(image)}")
         | 
| 561 | 
            +
             | 
| 562 | 
            +
             | 
| 563 | 
            +
            def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
         | 
| 564 | 
            +
                image = np.asarray(image)
         | 
| 565 | 
            +
                return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1
         | 
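             # Usage sketch (hypothetical helper): a uint8 {0, 1} array counts as a single-channel
             # binary mask, while the same mask scaled to 255 does not.
             def _example_mask_predicates() -> None:
                 mask = np.zeros((4, 4), dtype=np.uint8)
                 mask[1:3, 1:3] = 1
                 assert is_single_channel_image(mask)
                 assert is_binary_mask(mask)
                 assert not is_binary_mask(mask * 255)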
| 566 | 
            +
             | 
| 567 | 
            +
             | 
| 568 | 
            +
            def resize_image(
         | 
| 569 | 
            +
                image: PIL.Image.Image | np.ndarray,
         | 
| 570 | 
            +
                target_size: dict[str, int],
         | 
| 571 | 
            +
                mode: ResizeMode | str,
         | 
| 572 | 
            +
                resample: PIL.Image.Resampling | str = "auto",
         | 
| 573 | 
            +
                pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
         | 
| 574 | 
            +
            ) -> PIL.Image.Image | np.ndarray:
         | 
| 575 | 
            +
                (target_width, target_height) = (target_size["width"], target_size["height"])
         | 
| 576 | 
            +
                (height, width) = hw_from_image(image)
         | 
| 577 | 
            +
                if height == target_height and width == target_width:
         | 
| 578 | 
            +
                    return image
         | 
| 579 | 
            +
                if resample == "auto":
         | 
| 580 | 
            +
                    if is_single_channel_image(image):
         | 
| 581 | 
            +
                        resample = PIL.Image.Resampling.BILINEAR
         | 
| 582 | 
            +
                    else:
         | 
| 583 | 
            +
                        resample = PIL.Image.Resampling.LANCZOS
         | 
| 584 | 
            +
                else:
         | 
| 585 | 
            +
                    assert isinstance(resample, PIL.Image.Resampling), resample
         | 
| 586 | 
            +
                    if is_single_channel_image(image) and resample not in [
         | 
| 587 | 
            +
                        PIL.Image.Resampling.BILINEAR,
         | 
| 588 | 
            +
                        PIL.Image.Resampling.BICUBIC,
         | 
| 589 | 
            +
                    ]:
         | 
| 590 | 
            +
                        raise ValueError(
         | 
| 591 | 
            +
                            f"Single channel images must be resized with bilinear or bicubic, but got {resample}"
         | 
| 592 | 
            +
                        )
         | 
| 593 | 
            +
                if is_bin_mask := is_binary_mask(image):
         | 
| 594 | 
            +
                    image = np.asarray(image).astype(np.uint8) * 255
         | 
| 595 | 
            +
                if mode == ResizeMode.SMART:
         | 
| 596 | 
            +
                    image = crop_and_pad_image_to_ratio(
         | 
| 597 | 
            +
                        image,
         | 
| 598 | 
            +
                        target_wh_ratio=target_width / target_height,
         | 
| 599 | 
            +
                        mode=mode,
         | 
| 600 | 
            +
                        pad_value=pad_value,
         | 
| 601 | 
            +
                    )
         | 
| 602 | 
            +
                pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
         | 
| 603 | 
            +
                if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
         | 
| 604 | 
            +
                    pil_image = pil_image.resize((target_width, target_height), resample=resample)
         | 
| 605 | 
            +
                else:
         | 
| 606 | 
            +
                    raise NotImplementedError(f"Mode {mode} not supported")
         | 
| 607 | 
            +
                image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
         | 
| 608 | 
            +
                if is_bin_mask:
         | 
| 609 | 
            +
                    image = image.astype(np.uint8) > 127
         | 
| 610 | 
            +
                return image
         | 
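             # Usage sketch (hypothetical helper): NAIVE resize of an RGB array, which picks
             # LANCZOS resampling through the "auto" default above.
             def _example_resize_image() -> None:
                 frame = np.zeros((480, 640, 3), dtype=np.uint8)
                 small = resize_image(frame, target_size={"width": 224, "height": 224}, mode=ResizeMode.NAIVE)
                 assert small.shape == (224, 224, 3)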
| 611 | 
            +
             | 
| 612 | 
            +
             | 
| 613 | 
            +
            def is_global_norm(
         | 
| 614 | 
            +
                norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
         | 
| 615 | 
            +
            ) -> bool:
         | 
| 616 | 
            +
                """Return true if norm is NONE or global for all datasets"""
         | 
| 617 | 
            +
                return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping)
         | 
| 618 | 
            +
             | 
| 619 | 
            +
             | 
| 620 | 
            +
            def is_mean_norm(
         | 
| 621 | 
            +
                norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
         | 
| 622 | 
            +
            ) -> bool:
         | 
| 623 | 
            +
                """Return true if norm is based on mean and std"""
         | 
| 624 | 
            +
                return (
         | 
| 625 | 
            +
                    norm == Normalization.MEAN
         | 
| 626 | 
            +
                    or isinstance(norm, collections.abc.Mapping)
         | 
| 627 | 
            +
                    and set(norm.keys()) == {"mean", "std"}
         | 
| 628 | 
            +
                )
         | 
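             # Usage sketch (hypothetical helper): a {"mean", "std"} mapping is treated as a global
             # mean/std normalization, while {"low", "high"} bounds are not.
             def _example_norm_predicates() -> None:
                 assert is_mean_norm({"mean": (0.0, 0.0, 0.0), "std": (1.0, 1.0, 1.0)})
                 assert not is_mean_norm({"low": (-1.0, -1.0, -1.0), "high": (1.0, 1.0, 1.0)})
                 assert is_global_norm({"low": (-1.0, -1.0, -1.0), "high": (1.0, 1.0, 1.0)})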
| 629 | 
            +
             | 
| 630 | 
            +
             | 
| 631 | 
            +
            def _broadcast_shapes(
         | 
| 632 | 
            +
                value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
         | 
| 633 | 
            +
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 634 | 
            +
                """
         | 
| 635 | 
            +
                Broadcast shapes for normalization:
         | 
| 636 | 
            +
                Args:
         | 
| 637 | 
            +
                    value: torch.Tensor of shape [..., num_components]. The entire shape might be:
         | 
| 638 | 
            +
                        - [num_components]: `value` has no batch dimension
         | 
| 639 | 
            +
                        - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
         | 
| 640 | 
            +
                            contained in `low` and `high`
         | 
| 641 | 
            +
                        - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
         | 
| 642 | 
            +
                            contained in `low` and `high`
         | 
| 643 | 
            +
                        - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
         | 
| 644 | 
            +
                            must be for a single dataset, i.e. `num_datasets = 1`
         | 
| 645 | 
            +
             | 
| 646 | 
            +
                    low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
         | 
| 647 | 
            +
                        contains normalization bounds for a single dataset
         | 
| 648 | 
            +
                    high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
         | 
| 649 | 
            +
                        contains normalization bounds for a single dataset
         | 
| 650 | 
            +
                Returns:
         | 
| 651 | 
            +
                    Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
         | 
| 652 | 
            +
                """
         | 
| 653 | 
            +
                assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2"
         | 
| 654 | 
            +
                assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}"
         | 
| 655 | 
            +
                if value.ndim == low.ndim == high.ndim:
         | 
| 656 | 
            +
                    return low, high
         | 
| 657 | 
            +
                if value.ndim < low.ndim:
         | 
| 658 | 
            +
                    assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}"
         | 
| 659 | 
            +
                    assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}"
         | 
| 660 | 
            +
                    (low, high) = (low.view(-1), high.view(-1))
         | 
| 661 | 
            +
                    return low, high
         | 
| 662 | 
            +
                if low.shape[0] == high.shape[0] == 1:
         | 
| 663 | 
            +
                    low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
         | 
| 664 | 
            +
                    high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
         | 
| 665 | 
            +
                else:
         | 
| 666 | 
            +
                    assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}"
         | 
| 667 | 
            +
                    low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
         | 
| 668 | 
            +
                    high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
         | 
| 669 | 
            +
                return low, high
         | 
| 670 | 
            +
             | 
| 671 | 
            +
             | 
| 672 | 
            +
            def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
         | 
| 673 | 
            +
                (mean, std) = _broadcast_shapes(value, mean, std)
         | 
| 674 | 
            +
                (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
         | 
| 675 | 
            +
                return value * (std + 1e-08) + mean
         | 
| 676 | 
            +
             | 
| 677 | 
            +
             | 
| 678 | 
            +
            def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
         | 
| 679 | 
            +
                (low, high) = _broadcast_shapes(value, low, high)
         | 
| 680 | 
            +
                (low, high) = (low.to(device=value.device), high.to(device=value.device))
         | 
| 681 | 
            +
                return 0.5 * (value + 1) * (high - low) + low
         | 
| 682 | 
            +
             | 
| 683 | 
            +
             | 
| 684 | 
            +
            def normalize_gripper_by_bounds(
         | 
| 685 | 
            +
                value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
         | 
| 686 | 
            +
            ) -> torch.Tensor:
         | 
| 687 | 
            +
                """
         | 
| 688 | 
            +
                If binary, normalize to [0, 1], otherwise normalize to [-1, 1]
         | 
| 689 | 
            +
                """
         | 
| 690 | 
            +
                (low, high) = _broadcast_shapes(value, low, high)
         | 
| 691 | 
            +
                (low, high) = (low.to(device=value.device), high.to(device=value.device))
         | 
| 692 | 
            +
                if binary:
         | 
| 693 | 
            +
                    return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0)
         | 
| 694 | 
            +
                return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
         | 
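             # Usage sketch (hypothetical helper): with binary=True the gripper value is mapped to
             # [0, 1]; with binary=False it is mapped to [-1, 1].
             def _example_normalize_gripper() -> None:
                 raw = torch.tensor([[0.0, 0.5, 1.0]])
                 low = torch.zeros(1, 3)
                 high = torch.ones(1, 3)
                 assert torch.allclose(normalize_gripper_by_bounds(raw, low, high, binary=True), raw)
                 assert torch.allclose(
                     normalize_gripper_by_bounds(raw, low, high, binary=False), torch.tensor([[-1.0, 0.0, 1.0]])
                 )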
| 695 | 
            +
             | 
| 696 | 
            +
             | 
| 697 | 
            +
            def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
         | 
| 698 | 
            +
                (mean, std) = _broadcast_shapes(value, mean, std)
         | 
| 699 | 
            +
                (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
         | 
| 700 | 
            +
                return (value - mean) / (std + 1e-08)
         | 
| 701 | 
            +
             | 
| 702 | 
            +
             | 
| 703 | 
            +
            def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
         | 
| 704 | 
            +
                (low, high) = _broadcast_shapes(value, low, high)
         | 
| 705 | 
            +
                (low, high) = (low.to(device=value.device), high.to(device=value.device))
         | 
| 706 | 
            +
                return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
         | 
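             # Usage sketch (hypothetical helper): bounds and moments normalization invert their
             # unnormalize counterparts (up to clamping at the bounds).
             def _example_normalization_round_trip() -> None:
                 value = torch.tensor([[0.2, 0.5, 0.8]])
                 low = torch.zeros(1, 3)
                 high = torch.ones(1, 3)
                 normed = normalize_by_bounds(value, low=low, high=high)
                 assert torch.allclose(unnormalize_by_bounds(normed, low=low, high=high), value, atol=1e-05)
                 mean = torch.zeros(1, 3)
                 std = torch.ones(1, 3)
                 assert torch.allclose(
                     unnormalize_by_moments(normalize_by_moments(value, mean, std), mean, std), value, atol=1e-05
                 )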
| 707 | 
            +
             | 
| 708 | 
            +
             | 
| 709 | 
            +
            def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
         | 
| 710 | 
            +
                if low < 0.0:
         | 
| 711 | 
            +
                    return np.clip(-gripper, low, high)
         | 
| 712 | 
            +
                return high - np.clip(gripper, low, high)
         | 
| 713 | 
            +
             | 
| 714 | 
            +
             | 
| 715 | 
            +
            GRIPPER_BOUNDS = {
         | 
| 716 | 
            +
                "bridge": (0.0, 1.0),
         | 
| 717 | 
            +
                "bridge_orig": (0.0, 1.0),
         | 
| 718 | 
            +
                "droid": (0.0, 1.0),
         | 
| 719 | 
            +
                "roboset": (0.0, 1.0),
         | 
| 720 | 
            +
            }
         | 
| 721 | 
            +
             | 
| 722 | 
            +
             | 
| 723 | 
            +
            def preprocess_gripper_observation(
         | 
| 724 | 
            +
                gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
         | 
| 725 | 
            +
            ) -> np.ndarray:
         | 
| 726 | 
            +
                """
         | 
| 727 | 
            +
                Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
         | 
| 728 | 
            +
                 or from the robot, and the output is a normalized continuous value.
         | 
| 729 | 
            +
                    - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
         | 
| 730 | 
            +
                    - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                Dataset-specific gripper observations:
         | 
| 733 | 
            +
                    bridge: continuous; ~[0=closed; 1=open]
         | 
| 734 | 
            +
                    bridge_orig: continuous; ~[0=closed; 1=open]
         | 
| 735 | 
            +
                    droid: continuous; [0=open, 1=closed]
         | 
| 736 | 
            +
                    roboset: continuous; [0=open, 1=closed]
         | 
| 737 | 
            +
                """
         | 
| 738 | 
            +
                if isinstance(dataset_name, np.ndarray):
         | 
| 739 | 
            +
                    assert np.unique(dataset_name).size == 1, dataset_name
         | 
| 740 | 
            +
                    dataset_name = str(dataset_name[0])
         | 
| 741 | 
            +
                if dataset_name in [
         | 
| 742 | 
            +
                    "droid",
         | 
| 743 | 
            +
                    "roboset",
         | 
| 744 | 
            +
                ]:
         | 
| 745 | 
            +
                    (low, high) = GRIPPER_BOUNDS[dataset_name]
         | 
| 746 | 
            +
                    gripper = normalize_gripper_by_bounds(
         | 
| 747 | 
            +
                        torch.from_numpy(invert_gripper(gripper, low=low, high=high)),
         | 
| 748 | 
            +
                        low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32),
         | 
| 749 | 
            +
                        high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32),
         | 
| 750 | 
            +
                        binary=binary,
         | 
| 751 | 
            +
                    ).numpy()
         | 
| 752 | 
            +
                elif dataset_name in [
         | 
| 753 | 
            +
                    "bridge",
         | 
| 754 | 
            +
                    "bridge_orig",
         | 
| 755 | 
            +
                ]:
         | 
| 756 | 
            +
                    (low, high) = GRIPPER_BOUNDS[dataset_name]
         | 
| 757 | 
            +
                    gripper = normalize_gripper_by_bounds(
         | 
| 758 | 
            +
                        torch.from_numpy(gripper),
         | 
| 759 | 
            +
                        low=torch.full(gripper.shape, low, dtype=torch.float32),
         | 
| 760 | 
            +
                        high=torch.full(gripper.shape, high, dtype=torch.float32),
         | 
| 761 | 
            +
                        binary=binary,
         | 
| 762 | 
            +
                    ).numpy()
         | 
| 763 | 
            +
                else:
         | 
| 764 | 
            +
                    raise NotImplementedError(f"Unknown dataset: {dataset_name}")
         | 
| 765 | 
            +
                return gripper
         | 
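             # Usage sketch (hypothetical helper): droid reports 0=open / 1=closed, so the raw
             # observation is inverted before normalization, yielding 1=open / 0=closed.
             def _example_preprocess_gripper() -> None:
                 droid_raw = np.array([[0.0], [1.0]], dtype=np.float32)
                 out = preprocess_gripper_observation(droid_raw, dataset_name="droid", binary=True)
                 assert np.allclose(out, [[1.0], [0.0]])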
| 766 | 
            +
             | 
| 767 | 
            +
             | 
| 768 | 
            +
            def rotation_norm_bounds(
         | 
| 769 | 
            +
                rotation_norm: Normalization,
         | 
| 770 | 
            +
                rotation_format: RotationFormat,
         | 
| 771 | 
            +
                stats: Dict[str, Dict[str, Dict[str, List[float]]]],
         | 
| 772 | 
            +
                dataset_names: List[str],
         | 
| 773 | 
            +
            ) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 774 | 
            +
                if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE:
         | 
| 775 | 
            +
                    if rotation_norm == Normalization.BOUNDS:
         | 
| 776 | 
            +
                        results = {
         | 
| 777 | 
            +
                            dataset_name: {
         | 
| 778 | 
            +
                                "low": torch.tensor(dataset_stats["euler"]["min"]),
         | 
| 779 | 
            +
                                "high": torch.tensor(dataset_stats["euler"]["max"]),
         | 
| 780 | 
            +
                            }
         | 
| 781 | 
            +
                            for (dataset_name, dataset_stats) in stats.items()
         | 
| 782 | 
            +
                        }
         | 
| 783 | 
            +
                    elif rotation_norm == Normalization.BOUNDS_Q99:
         | 
| 784 | 
            +
                        results = {
         | 
| 785 | 
            +
                            dataset_name: {
         | 
| 786 | 
            +
                                "low": torch.tensor(dataset_stats["euler"]["q01"]),
         | 
| 787 | 
            +
                                "high": torch.tensor(dataset_stats["euler"]["q99"]),
         | 
| 788 | 
            +
                            }
         | 
| 789 | 
            +
                            for (dataset_name, dataset_stats) in stats.items()
         | 
| 790 | 
            +
                        }
         | 
| 791 | 
            +
                    else:
         | 
| 792 | 
            +
                        raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented")
         | 
| 793 | 
            +
                else:
         | 
| 794 | 
            +
                    assert rotation_norm == Normalization.NONE, rotation_norm
         | 
| 795 | 
            +
                    if rotation_format == RotationFormat.EULER:
         | 
| 796 | 
            +
                        rotation_size = 3
         | 
| 797 | 
            +
                    elif rotation_format == RotationFormat.QUATERNION:
         | 
| 798 | 
            +
                        rotation_size = 4
         | 
| 799 | 
            +
                    else:
         | 
| 800 | 
            +
                        rotation_size = 9
         | 
| 801 | 
            +
                    results = {
         | 
| 802 | 
            +
                        dataset_name: {
         | 
| 803 | 
            +
                            "low": -1 * torch.ones(rotation_size, dtype=torch.float32),
         | 
| 804 | 
            +
                            "high": 1 * torch.ones(rotation_size, dtype=torch.float32),
         | 
| 805 | 
            +
                        }
         | 
| 806 | 
            +
                        for dataset_name in dataset_names
         | 
| 807 | 
            +
                    }
         | 
| 808 | 
            +
                return results
         | 
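             # Usage sketch (hypothetical helper): with Normalization.NONE the bounds default to
             # [-1, 1] per component, sized by the rotation format (4 for quaternions).
             def _example_rotation_bounds_none() -> None:
                 bounds = rotation_norm_bounds(
                     rotation_norm=Normalization.NONE,
                     rotation_format=RotationFormat.QUATERNION,
                     stats={},
                     dataset_names=["bridge"],
                 )
                 assert torch.equal(bounds["bridge"]["low"], -torch.ones(4))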
| 809 | 
            +
             | 
| 810 | 
            +
             | 
| 811 | 
            +
            def translation_norm_bounds(
         | 
| 812 | 
            +
                translation_norm: Normalization | tuple,
         | 
| 813 | 
            +
                stats: Dict[str, Dict[str, Dict[str, List[float]]]],
         | 
| 814 | 
            +
                dataset_names: List[str],
         | 
| 815 | 
            +
            ) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 816 | 
            +
                if isinstance(translation_norm, (Normalization, str)) and translation_norm != Normalization.NONE:
         | 
| 817 | 
            +
                    if translation_norm == Normalization.BOUNDS:
         | 
| 818 | 
            +
                        results = {
         | 
| 819 | 
            +
                            dataset_name: {
         | 
| 820 | 
            +
                                "low": torch.tensor(dataset_stats["translation"]["min"]),
         | 
| 821 | 
            +
                                "high": torch.tensor(dataset_stats["translation"]["max"]),
         | 
| 822 | 
            +
                            }
         | 
| 823 | 
            +
                            for (dataset_name, dataset_stats) in stats.items()
         | 
| 824 | 
            +
                        }
         | 
| 825 | 
            +
                    elif translation_norm == Normalization.BOUNDS_Q99:
         | 
| 826 | 
            +
                        results = {
         | 
| 827 | 
            +
                            dataset_name: {
         | 
| 828 | 
            +
                                "low": torch.tensor(dataset_stats["translation"]["q01"]),
         | 
| 829 | 
            +
                                "high": torch.tensor(dataset_stats["translation"]["q99"]),
         | 
| 830 | 
            +
                            }
         | 
| 831 | 
            +
                            for (dataset_name, dataset_stats) in stats.items()
         | 
| 832 | 
            +
                        }
         | 
| 833 | 
            +
                    elif translation_norm == Normalization.MEAN:
         | 
| 834 | 
            +
                        results = {
         | 
| 835 | 
            +
                            dataset_name: {
         | 
| 836 | 
            +
                                "mean": torch.tensor(dataset_stats["translation"]["mean"]),
         | 
| 837 | 
            +
                                "std": torch.tensor(dataset_stats["translation"]["std"]),
         | 
| 838 | 
            +
                            }
         | 
| 839 | 
            +
                            for (dataset_name, dataset_stats) in stats.items()
         | 
| 840 | 
            +
                        }
         | 
| 841 | 
            +
                    else:
         | 
| 842 | 
            +
                        raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented")
         | 
| 843 | 
            +
                elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE:
         | 
| 844 | 
            +
                    results = {
         | 
| 845 | 
            +
                        dataset_name: {
         | 
| 846 | 
            +
                            "low": -1 * torch.ones(3, dtype=torch.float32),
         | 
| 847 | 
            +
                            "high": 1 * torch.ones(3, dtype=torch.float32),
         | 
| 848 | 
            +
                        }
         | 
| 849 | 
            +
                        for dataset_name in dataset_names
         | 
| 850 | 
            +
                    }
         | 
| 851 | 
            +
                else:
         | 
| 852 | 
            +
                    assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
         | 
| 853 | 
            +
                    assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
         | 
| 854 | 
            +
                    assert set(translation_norm.keys()) in (
         | 
| 855 | 
            +
                        {"low", "high"},
         | 
| 856 | 
            +
                        {"mean", "std"},
         | 
| 857 | 
            +
                    ), translation_norm
         | 
| 858 | 
            +
                    results = {
         | 
| 859 | 
            +
                        dataset_name: {
         | 
| 860 | 
            +
                            key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
         | 
| 861 | 
            +
                        }
         | 
| 862 | 
            +
                        for dataset_name in dataset_names
         | 
| 863 | 
            +
                    }
         | 
| 864 | 
            +
                return results
         | 
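             # Usage sketch (hypothetical stats layout inferred from the usage above): per-dataset
             # q01/q99 statistics become low/high tensors keyed by dataset name.
             def _example_translation_bounds() -> None:
                 stats = {"bridge": {"translation": {"q01": [-0.05, -0.05, -0.05], "q99": [0.05, 0.05, 0.05]}}}
                 bounds = translation_norm_bounds(
                     translation_norm=Normalization.BOUNDS_Q99, stats=stats, dataset_names=["bridge"]
                 )
                 assert torch.equal(bounds["bridge"]["high"], torch.tensor([0.05, 0.05, 0.05]))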
| 865 | 
            +
             | 
| 866 | 
            +
             | 
| 867 | 
            +
            VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT")
         | 
| 868 | 
            +
             | 
| 869 | 
            +
             | 
| 870 | 
            +
            class VLAMProcessor(Configurable):
         | 
| 871 | 
            +
                def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
         | 
| 872 | 
            +
                    super().__init__(config)
         | 
| 873 | 
            +
                    self.vlm_processor = vlm_processor
         | 
| 874 | 
            +
                    self.control_tokenizer = EmptyTokenizer(
         | 
| 875 | 
            +
                        config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
         | 
| 876 | 
            +
                    )
         | 
| 877 | 
            +
                    self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
         | 
| 878 | 
            +
                        "obs_translation": self.obs_translation_norm_bounds,
         | 
| 879 | 
            +
                        "obs_rotation": self.obs_rotation_norm_bounds,
         | 
| 880 | 
            +
                        "translation": self.translation_norm_bounds,
         | 
| 881 | 
            +
                        "rotation": self.rotation_norm_bounds,
         | 
| 882 | 
            +
                        "joints": self.joints_norm_bounds,
         | 
| 883 | 
            +
                    }
         | 
| 884 | 
            +
             | 
| 885 | 
            +
                @property
         | 
| 886 | 
            +
                def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
         | 
| 887 | 
            +
                    return self.vlm_processor.tokenizer
         | 
| 888 | 
            +
             | 
| 889 | 
            +
                @property
         | 
| 890 | 
            +
                def image_sizes(self) -> Dict[str, ImageSizeConfig]:
         | 
| 891 | 
            +
                    return self.vlm_processor.image_sizes
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                @property
         | 
| 894 | 
            +
                def camera_names(self) -> List[str]:
         | 
| 895 | 
            +
                    return list(self.vlm_processor.image_sizes.keys())
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                @property
         | 
| 898 | 
            +
                def control_io_config(self) -> ControlDataIOConfig:
         | 
| 899 | 
            +
                    return self.config.control_io_config
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                @cached_property
         | 
| 902 | 
            +
                def rotation_components(self) -> int:
         | 
| 903 | 
            +
                    if self.config.rotation_format == RotationFormat.EULER:
         | 
| 904 | 
            +
                        return 3
         | 
| 905 | 
            +
                    if self.config.rotation_format == RotationFormat.QUATERNION:
         | 
| 906 | 
            +
                        return 4
         | 
| 907 | 
            +
                    if self.config.rotation_format == RotationFormat.ROTMAT:
         | 
| 908 | 
            +
                        return 9
         | 
| 909 | 
            +
                    raise NotImplementedError(self.config.rotation_format)
         | 
| 910 | 
            +
             | 
| 911 | 
            +
                @abstractmethod
         | 
| 912 | 
            +
                def policy_control_plan_from_model_target(
         | 
| 913 | 
            +
                    self, target: RoboticsTarget, dataset_name: np.ndarray
         | 
| 914 | 
            +
                ) -> RoboticsControlPlan:
         | 
| 915 | 
            +
                    pass
         | 
| 916 | 
            +
             | 
| 917 | 
            +
                @abstractmethod
         | 
| 918 | 
            +
                def policy_control_plan_from_model_output(
         | 
| 919 | 
            +
                    self,
         | 
| 920 | 
            +
                    model_output: RoboticsOutput,
         | 
| 921 | 
            +
                    dataset_name: np.ndarray,
         | 
| 922 | 
            +
                    valid_mask: torch.Tensor,
         | 
| 923 | 
            +
                ) -> RoboticsControlPlan:
         | 
| 924 | 
            +
                    pass
         | 
| 925 | 
            +
             | 
| 926 | 
            +
                def resize_image(
         | 
| 927 | 
            +
                    self, camera_name: str, image: PIL.Image.Image | np.ndarray
         | 
| 928 | 
            +
                ) -> PIL.Image.Image | np.ndarray:
         | 
| 929 | 
            +
                    return resize_image(
         | 
| 930 | 
            +
                        image,
         | 
| 931 | 
            +
                        target_size={
         | 
| 932 | 
            +
                            "width": self.image_sizes[camera_name].width,
         | 
| 933 | 
            +
                            "height": self.image_sizes[camera_name].height,
         | 
| 934 | 
            +
                        },
         | 
| 935 | 
            +
                        mode=self.config.image_resize,
         | 
| 936 | 
            +
                        resample=PIL.Image.Resampling.LANCZOS,
         | 
| 937 | 
            +
                    )
         | 
| 938 | 
            +
             | 
| 939 | 
            +
                def preprocess_inputs(
         | 
| 940 | 
            +
                    self,
         | 
| 941 | 
            +
                    chat: List[str],
         | 
| 942 | 
            +
                    images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
         | 
| 943 | 
            +
                    ee_pose_translation: np.ndarray,
         | 
| 944 | 
            +
                    ee_pose_rotation: np.ndarray,
         | 
| 945 | 
            +
                    gripper: np.ndarray,
         | 
| 946 | 
            +
                    joints: np.ndarray,
         | 
| 947 | 
            +
                    dataset_name: np.ndarray,
         | 
| 948 | 
            +
                    inference_mode: bool,
         | 
| 949 | 
            +
                    control_target: Optional[RoboticsTarget] = None,
         | 
| 950 | 
            +
                ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
         | 
| 951 | 
            +
                    """
         | 
| 952 | 
            +
                    Preprocess the inputs for a single example
         | 
| 953 | 
            +
                    Args:
         | 
| 954 | 
            +
                         chat: List of chat message strings containing the language instruction
         | 
| 955 | 
            +
                        images: History of input images with increasing timestamps
         | 
| 956 | 
            +
                        ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
         | 
| 957 | 
            +
                        ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
         | 
| 958 | 
            +
                         gripper: np.ndarray, raw gripper observation (normalized via preprocess_gripper_observation)
                         joints: np.ndarray, shape [..., num_past_scalars, <= 7]
         | 
| 959 | 
            +
                        dataset_name: 1D np.ndarray
         | 
| 960 | 
            +
                         inference_mode: If True, prepare the input for inference (e.g. don't include any target
         | 
| 961 | 
            +
                             tokens in the input if relevant). If control_target is available, it should
         | 
| 962 | 
            +
                            still be preprocessed for test dataset comparison
         | 
| 963 | 
            +
                        control_target: RoboticsTarget, each component of shape
         | 
| 964 | 
            +
                            [..., num_control_steps, num_control_components]. Provided only when available, usually
         | 
| 965 | 
            +
                             during training and dataset testing
         | 
| 966 | 
            +
                    Returns:
         | 
| 967 | 
            +
                        Dict containing torch.Tensor with inputs
         | 
| 968 | 
            +
                    """
         | 
| 969 | 
            +
                    del control_target
         | 
| 970 | 
            +
                    del inference_mode
         | 
| 971 | 
            +
                    inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
         | 
| 972 | 
            +
                    images: Dict[str, torch.Tensor] = inputs["images"]
         | 
| 973 | 
            +
                    input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length]
         | 
| 974 | 
            +
                    target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length]
         | 
| 975 | 
            +
                    attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
         | 
| 976 | 
            +
                    ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
         | 
| 977 | 
            +
                    ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
         | 
| 978 | 
            +
                    ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
         | 
| 979 | 
            +
                    gripper = preprocess_gripper_observation(gripper, dataset_name)
         | 
| 980 | 
            +
                    gripper = torch.tensor(gripper, dtype=torch.float32)
         | 
| 981 | 
            +
                    ee_pose_translation = self.normalize(
         | 
| 982 | 
            +
                        ee_pose_translation, dataset_name=dataset_name, key="obs_translation"
         | 
| 983 | 
            +
                    )
         | 
| 984 | 
            +
                    ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation")
         | 
| 985 | 
            +
                    joints = torch.tensor(joints, dtype=torch.float32)
         | 
| 986 | 
            +
                    if joints.shape[-1] < 7:
         | 
| 987 | 
            +
                        missing_size = 7 - joints.shape[-1]
         | 
| 988 | 
            +
                        joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
         | 
| 989 | 
            +
                    joints = self.normalize(joints, dataset_name=dataset_name, key="joints")
         | 
| 990 | 
            +
                    outputs = {
         | 
| 991 | 
            +
                        "images": images,
         | 
| 992 | 
            +
                        "input_ids": input_ids,
         | 
| 993 | 
            +
                        "target_text_tokens_ids": target_text_tokens_ids,
         | 
| 994 | 
            +
                        "attn_mask": attn_mask,
         | 
| 995 | 
            +
                        "ee_pose_translation": ee_pose_translation,
         | 
| 996 | 
            +
                        "ee_pose_rotation": ee_pose_rotation,
         | 
| 997 | 
            +
                        "gripper": gripper,
         | 
| 998 | 
            +
                        "joints": joints,
         | 
| 999 | 
            +
                        "control_tokens_ids": None,
         | 
| 1000 | 
            +
                        "target_control_tokens_ids": None,
         | 
| 1001 | 
            +
                    }
         | 
| 1002 | 
            +
                    return outputs
         | 
| 1003 | 
            +
             | 
| 1004 | 
            +
                def create_input(
         | 
| 1005 | 
            +
                    self,
         | 
| 1006 | 
            +
                    chat: List[str],
         | 
| 1007 | 
            +
                    images: Dict[str, List[PIL.Image.Image]],
         | 
| 1008 | 
            +
                    ee_pose_translation: np.ndarray,
         | 
| 1009 | 
            +
                    ee_pose_rotation: np.ndarray,
         | 
| 1010 | 
            +
                    gripper: np.ndarray,
         | 
| 1011 | 
            +
                    joints: np.ndarray,
         | 
| 1012 | 
            +
                    dataset_name: np.ndarray,
         | 
| 1013 | 
            +
                    inference_mode: bool,
         | 
| 1014 | 
            +
                    control_target: Optional[RoboticsTarget] = None,
         | 
| 1015 | 
            +
                ) -> RoboticsInput:
         | 
| 1016 | 
            +
                    inputs = self.preprocess_inputs(
         | 
| 1017 | 
            +
                        chat=chat,
         | 
| 1018 | 
            +
                        images=images,
         | 
| 1019 | 
            +
                        ee_pose_translation=ee_pose_translation,
         | 
| 1020 | 
            +
                        ee_pose_rotation=ee_pose_rotation,
         | 
| 1021 | 
            +
                        gripper=gripper,
         | 
| 1022 | 
            +
                        joints=joints,
         | 
| 1023 | 
            +
                        dataset_name=dataset_name,
         | 
| 1024 | 
            +
                        inference_mode=inference_mode,
         | 
| 1025 | 
            +
                        control_target=control_target,
         | 
| 1026 | 
            +
                    )
         | 
| 1027 | 
            +
                    inputs.pop("target_text_tokens_ids")
         | 
| 1028 | 
            +
                    inputs.pop("target_control_tokens_ids")
         | 
| 1029 | 
            +
                    return RoboticsInput(**inputs)
         | 
| 1030 | 
            +
             | 
| 1031 | 
            +
                def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
         | 
| 1032 | 
            +
                    if is_mean_norm(getattr(self.config, f"{key}_norm")):
         | 
| 1033 | 
            +
                        (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
         | 
| 1034 | 
            +
                        output = normalize_by_moments(value, mean=mean, std=std)
         | 
| 1035 | 
            +
                    else:
         | 
| 1036 | 
            +
                        (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
         | 
| 1037 | 
            +
                        output = normalize_by_bounds(value, low=low, high=high)
         | 
| 1038 | 
            +
                    return output
         | 
| 1039 | 
            +
             | 
| 1040 | 
            +
                def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
         | 
| 1041 | 
            +
                    if is_mean_norm(getattr(self.config, f"{key}_norm")):
         | 
| 1042 | 
            +
                        (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
         | 
| 1043 | 
            +
                        output = unnormalize_by_moments(value, mean=mean, std=std)
         | 
| 1044 | 
            +
                    else:
         | 
| 1045 | 
            +
                        (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
         | 
| 1046 | 
            +
                        output = unnormalize_by_bounds(value, low=low, high=high)
         | 
| 1047 | 
            +
                    return output
         | 
| 1048 | 
            +
             | 
| 1049 | 
            +
                def _norm_bounds_from_dataset_name(
         | 
| 1050 | 
            +
                    self, dataset_name: np.ndarray, component_key: str
         | 
| 1051 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 1052 | 
            +
                    """
         | 
| 1053 | 
            +
                    Create an array of normalization bounds corresponding to dataset names
         | 
| 1054 | 
            +
                    Args:
         | 
| 1055 | 
            +
                        dataset_name: Array of shape [B] of dataset names for which to fetch the low and high
         | 
| 1056 | 
            +
                            normalization bounds. Note the values can be repeating
         | 
| 1057 | 
            +
                         component_key: str. One of 'obs_translation', 'obs_rotation', 'translation', 'rotation'
                             or 'joints'. Indicates for which control to
         | 
| 1058 | 
            +
                            compute the normalization bounds
         | 
| 1059 | 
            +
                    Returns:
         | 
| 1060 | 
            +
                         Tuple of low and high bounds or mean and std, each of shape [B, -1]
         | 
| 1061 | 
            +
                    """
         | 
| 1062 | 
            +
                    norm = getattr(self.config, f"{component_key}_norm")
         | 
| 1063 | 
            +
                    if is_mean_norm(norm):
         | 
| 1064 | 
            +
                        (stats_key_1, stats_key_2) = ("mean", "std")
         | 
| 1065 | 
            +
                    else:
         | 
| 1066 | 
            +
                        (stats_key_1, stats_key_2) = ("low", "high")
         | 
| 1067 | 
            +
                    if component_key == "joints":
         | 
| 1068 | 
            +
                        if not isinstance(norm, collections.abc.Mapping):
         | 
| 1069 | 
            +
                            raise NotImplementedError()
         | 
| 1070 | 
            +
                        stats = {
         | 
| 1071 | 
            +
                            key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
         | 
| 1072 | 
            +
                            for (key, value) in self.joints_norm_bounds["ANY"].items()
         | 
| 1073 | 
            +
                        }
         | 
| 1074 | 
            +
                        return tuple(stats.values())
         | 
| 1075 | 
            +
                    component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1]
         | 
| 1076 | 
            +
                    if self.dataset_names == ["ANY"]:
         | 
| 1077 | 
            +
                        stats_1 = self.norm_bounds[component_key]["ANY"][stats_key_1]
         | 
| 1078 | 
            +
                        stats_2 = self.norm_bounds[component_key]["ANY"][stats_key_2]
         | 
| 1079 | 
            +
                        stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
         | 
| 1080 | 
            +
                        stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
         | 
| 1081 | 
            +
                    else:
         | 
| 1082 | 
            +
                        (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
         | 
| 1083 | 
            +
                        stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
         | 
| 1084 | 
            +
                        stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
         | 
| 1085 | 
            +
                        for i, ds_name in enumerate(unique_names):
         | 
| 1086 | 
            +
                            stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1].numpy()
         | 
| 1087 | 
            +
                            stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2].numpy()
         | 
| 1088 | 
            +
                        stats_1 = stats_1[inverse_indices]
         | 
| 1089 | 
            +
                        stats_2 = stats_2[inverse_indices]
         | 
| 1090 | 
            +
                    return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
         | 
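                # Illustrative note (not part of the original source): the method above broadcasts
                # per-dataset statistics to per-sample statistics. With hypothetical inputs such as
                #   dataset_name = np.array(["droid", "droid", "bridge"]) and component_key = "translation",
                # it would return a (stats_1, stats_2) pair of torch.float32 tensors, each of shape [3, 3]:
                # (low, high) bounds for min/max-style normalization, or (mean, std) when is_mean_norm(...)
                # holds for the configured norm.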
| 1091 | 
            +
             | 
| 1092 | 
            +
                @cached_property
         | 
| 1093 | 
            +
                def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 1094 | 
            +
                    return rotation_norm_bounds(
         | 
| 1095 | 
            +
                        rotation_norm=self.config.obs_rotation_norm,
         | 
| 1096 | 
            +
                        rotation_format=self.config.rotation_format,
         | 
| 1097 | 
            +
                        stats=self._observation_stats,
         | 
| 1098 | 
            +
                        dataset_names=self.dataset_names,
         | 
| 1099 | 
            +
                    )
         | 
| 1100 | 
            +
             | 
| 1101 | 
            +
                @cached_property
         | 
| 1102 | 
            +
                def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 1103 | 
            +
                    return translation_norm_bounds(
         | 
| 1104 | 
            +
                        translation_norm=self.config.obs_translation_norm,
         | 
| 1105 | 
            +
                        stats=self._observation_stats,
         | 
| 1106 | 
            +
                        dataset_names=self.dataset_names,
         | 
| 1107 | 
            +
                    )
         | 
| 1108 | 
            +
             | 
| 1109 | 
            +
                @cached_property
         | 
| 1110 | 
            +
                def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 1111 | 
            +
                    return rotation_norm_bounds(
         | 
| 1112 | 
            +
                        rotation_norm=self.config.rotation_norm,
         | 
| 1113 | 
            +
                        rotation_format=self.config.rotation_format,
         | 
| 1114 | 
            +
                        stats=self._control_stats,
         | 
| 1115 | 
            +
                        dataset_names=self.dataset_names,
         | 
| 1116 | 
            +
                    )
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
                @cached_property
         | 
| 1119 | 
            +
                def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 1120 | 
            +
                    return translation_norm_bounds(
         | 
| 1121 | 
            +
                        translation_norm=self.config.translation_norm,
         | 
| 1122 | 
            +
                        stats=self._control_stats,
         | 
| 1123 | 
            +
                        dataset_names=self.dataset_names,
         | 
| 1124 | 
            +
                    )
         | 
| 1125 | 
            +
             | 
| 1126 | 
            +
                @cached_property
         | 
| 1127 | 
            +
                def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
         | 
| 1128 | 
            +
                    """
         | 
| 1129 | 
            +
                    NOTE:
         | 
| 1130 | 
            +
                        - Joint values across all joints and all datasets vary in the range [-2pi; 2pi]
         | 
| 1131 | 
            +
                        - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi]
         | 
| 1132 | 
            +
                        - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint
         | 
| 1133 | 
            +
                    """
         | 
| 1134 | 
            +
                    low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32)
         | 
| 1135 | 
            +
                    high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32)
         | 
| 1136 | 
            +
                    results = {"ANY": {"low": low, "high": high}}
         | 
| 1137 | 
            +
                    return results
         | 
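                # Illustrative note (not part of the original source): joint bounds are not dataset
                # specific here; the single "ANY" entry means every dataset is normalized with the
                # low/high vectors taken directly from config.joints_norm, in line with the NOTE above.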
| 1138 | 
            +
             | 
| 1139 | 
            +
                @cached_property
         | 
| 1140 | 
            +
                def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
         | 
| 1141 | 
            +
                    return {
         | 
| 1142 | 
            +
                        "bridge": {
         | 
| 1143 | 
            +
                            "euler": {
         | 
| 1144 | 
            +
                                "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
         | 
| 1145 | 
            +
                                "mean": [
         | 
| 1146 | 
            +
                                    -0.25754162314671525,
         | 
| 1147 | 
            +
                                    -0.12370228389510128,
         | 
| 1148 | 
            +
                                    0.1620053749182691,
         | 
| 1149 | 
            +
                                ],
         | 
| 1150 | 
            +
                                "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
         | 
| 1151 | 
            +
                                "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
         | 
| 1152 | 
            +
                                "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
         | 
| 1153 | 
            +
                                "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
         | 
| 1154 | 
            +
                            },
         | 
| 1155 | 
            +
                            "gripper": {
         | 
| 1156 | 
            +
                                "max": [1.0370277166366577],
         | 
| 1157 | 
            +
                                "min": [0.04637829214334488],
         | 
| 1158 | 
            +
                                "q01": [0.05192930996417999],
         | 
| 1159 | 
            +
                                "q99": [1.0118417739868164],
         | 
| 1160 | 
            +
                            },
         | 
| 1161 | 
            +
                            "joints": {
         | 
| 1162 | 
            +
                                "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1163 | 
            +
                                "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1164 | 
            +
                                "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1165 | 
            +
                                "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1166 | 
            +
                                "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1167 | 
            +
                                "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1168 | 
            +
                            },
         | 
| 1169 | 
            +
                            "translation": {
         | 
| 1170 | 
            +
                                "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
         | 
| 1171 | 
            +
                                "mean": [
         | 
| 1172 | 
            +
                                    0.309032678604126,
         | 
| 1173 | 
            +
                                    0.03403777256608009,
         | 
| 1174 | 
            +
                                    0.061277542263269424,
         | 
| 1175 | 
            +
                                ],
         | 
| 1176 | 
            +
                                "min": [
         | 
| 1177 | 
            +
                                    -0.04167502000927925,
         | 
| 1178 | 
            +
                                    -0.2889411449432373,
         | 
| 1179 | 
            +
                                    -0.13934996724128723,
         | 
| 1180 | 
            +
                                ],
         | 
| 1181 | 
            +
                                "q01": [
         | 
| 1182 | 
            +
                                    0.1711955964565277,
         | 
| 1183 | 
            +
                                    -0.15639324486255646,
         | 
| 1184 | 
            +
                                    -0.048255354166030884,
         | 
| 1185 | 
            +
                                ],
         | 
| 1186 | 
            +
                                "q99": [
         | 
| 1187 | 
            +
                                    0.4604376256465912,
         | 
| 1188 | 
            +
                                    0.24112474918365479,
         | 
| 1189 | 
            +
                                    0.18886254727840424,
         | 
| 1190 | 
            +
                                ],
         | 
| 1191 | 
            +
                                "std": [
         | 
| 1192 | 
            +
                                    0.0635896623134613,
         | 
| 1193 | 
            +
                                    0.09153717756271362,
         | 
| 1194 | 
            +
                                    0.049334850162267685,
         | 
| 1195 | 
            +
                                ],
         | 
| 1196 | 
            +
                            },
         | 
| 1197 | 
            +
                        },
         | 
| 1198 | 
            +
                        "bridge_orig": {
         | 
| 1199 | 
            +
                            "euler": {
         | 
| 1200 | 
            +
                                "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
         | 
| 1201 | 
            +
                                "mean": [
         | 
| 1202 | 
            +
                                    -0.25754162314671525,
         | 
| 1203 | 
            +
                                    -0.12370228389510128,
         | 
| 1204 | 
            +
                                    0.1620053749182691,
         | 
| 1205 | 
            +
                                ],
         | 
| 1206 | 
            +
                                "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
         | 
| 1207 | 
            +
                                "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
         | 
| 1208 | 
            +
                                "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
         | 
| 1209 | 
            +
                                "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
         | 
| 1210 | 
            +
                            },
         | 
| 1211 | 
            +
                            "gripper": {
         | 
| 1212 | 
            +
                                "max": [1.0370277166366577],
         | 
| 1213 | 
            +
                                "min": [0.04637829214334488],
         | 
| 1214 | 
            +
                                "q01": [0.05192930996417999],
         | 
| 1215 | 
            +
                                "q99": [1.0118417739868164],
         | 
| 1216 | 
            +
                            },
         | 
| 1217 | 
            +
                            "joints": {
         | 
| 1218 | 
            +
                                "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1219 | 
            +
                                "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1220 | 
            +
                                "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1221 | 
            +
                                "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1222 | 
            +
                                "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1223 | 
            +
                                "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         | 
| 1224 | 
            +
                            },
         | 
| 1225 | 
            +
                            "translation": {
         | 
| 1226 | 
            +
                                "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
         | 
| 1227 | 
            +
                                "mean": [
         | 
| 1228 | 
            +
                                    0.309032678604126,
         | 
| 1229 | 
            +
                                    0.03403777256608009,
         | 
| 1230 | 
            +
                                    0.061277542263269424,
         | 
| 1231 | 
            +
                                ],
         | 
| 1232 | 
            +
                                "min": [
         | 
| 1233 | 
            +
                                    -0.04167502000927925,
         | 
| 1234 | 
            +
                                    -0.2889411449432373,
         | 
| 1235 | 
            +
                                    -0.13934996724128723,
         | 
| 1236 | 
            +
                                ],
         | 
| 1237 | 
            +
                                "q01": [
         | 
| 1238 | 
            +
                                    0.1711955964565277,
         | 
| 1239 | 
            +
                                    -0.15639324486255646,
         | 
| 1240 | 
            +
                                    -0.048255354166030884,
         | 
| 1241 | 
            +
                                ],
         | 
| 1242 | 
            +
                                "q99": [
         | 
| 1243 | 
            +
                                    0.4604376256465912,
         | 
| 1244 | 
            +
                                    0.24112474918365479,
         | 
| 1245 | 
            +
                                    0.18886254727840424,
         | 
| 1246 | 
            +
                                ],
         | 
| 1247 | 
            +
                                "std": [
         | 
| 1248 | 
            +
                                    0.0635896623134613,
         | 
| 1249 | 
            +
                                    0.09153717756271362,
         | 
| 1250 | 
            +
                                    0.049334850162267685,
         | 
| 1251 | 
            +
                                ],
         | 
| 1252 | 
            +
                            },
         | 
| 1253 | 
            +
                        },
         | 
| 1254 | 
            +
                        "droid": {
         | 
| 1255 | 
            +
                            "euler": {
         | 
| 1256 | 
            +
                                "max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957],
         | 
| 1257 | 
            +
                                "mean": [
         | 
| 1258 | 
            +
                                    0.3140628098409554,
         | 
| 1259 | 
            +
                                    -0.09296274023036387,
         | 
| 1260 | 
            +
                                    -0.07227215454779846,
         | 
| 1261 | 
            +
                                ],
         | 
| 1262 | 
            +
                                "min": [
         | 
| 1263 | 
            +
                                    -3.141592502593994,
         | 
| 1264 | 
            +
                                    -1.5691150426864624,
         | 
| 1265 | 
            +
                                    -3.1415374279022217,
         | 
| 1266 | 
            +
                                ],
         | 
| 1267 | 
            +
                                "q01": [
         | 
| 1268 | 
            +
                                    -3.1378602981567383,
         | 
| 1269 | 
            +
                                    -1.2125312042236327,
         | 
| 1270 | 
            +
                                    -2.1614069032669065,
         | 
| 1271 | 
            +
                                ],
         | 
| 1272 | 
            +
                                "q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364],
         | 
| 1273 | 
            +
                                "std": [2.926265757944871, 0.363273475703332, 0.7576065217938824],
         | 
| 1274 | 
            +
                            },
         | 
| 1275 | 
            +
                            "gripper": {
         | 
| 1276 | 
            +
                                "max": [1.0],
         | 
| 1277 | 
            +
                                "min": [0.0],
         | 
| 1278 | 
            +
                                "q01": [0.0],
         | 
| 1279 | 
            +
                                "q99": [0.9911894202232361],
         | 
| 1280 | 
            +
                            },
         | 
| 1281 | 
            +
                            "joints": {
         | 
| 1282 | 
            +
                                "max": [
         | 
| 1283 | 
            +
                                    2.668445110321045,
         | 
| 1284 | 
            +
                                    1.5691218376159668,
         | 
| 1285 | 
            +
                                    2.666306734085083,
         | 
| 1286 | 
            +
                                    -0.3114914000034332,
         | 
| 1287 | 
            +
                                    2.6624162197113037,
         | 
| 1288 | 
            +
                                    4.28157901763916,
         | 
| 1289 | 
            +
                                    2.752457857131958,
         | 
| 1290 | 
            +
                                ],
         | 
| 1291 | 
            +
                                "mean": [
         | 
| 1292 | 
            +
                                    0.023137084334640106,
         | 
| 1293 | 
            +
                                    0.2704989977282293,
         | 
| 1294 | 
            +
                                    -0.01451389357228282,
         | 
| 1295 | 
            +
                                    -2.018709403792315,
         | 
| 1296 | 
            +
                                    -0.042720520800030394,
         | 
| 1297 | 
            +
                                    2.350281188152209,
         | 
| 1298 | 
            +
                                    0.12424663946659845,
         | 
| 1299 | 
            +
                                ],
         | 
| 1300 | 
            +
                                "min": [
         | 
| 1301 | 
            +
                                    -2.6536705493927,
         | 
| 1302 | 
            +
                                    -1.547789216041565,
         | 
| 1303 | 
            +
                                    -2.6781487464904785,
         | 
| 1304 | 
            +
                                    -2.9409868717193604,
         | 
| 1305 | 
            +
                                    -2.6705946922302246,
         | 
| 1306 | 
            +
                                    0.24893812835216522,
         | 
| 1307 | 
            +
                                    -2.7615714073181152,
         | 
| 1308 | 
            +
                                ],
         | 
| 1309 | 
            +
                                "q01": [
         | 
| 1310 | 
            +
                                    -0.9026106441020965,
         | 
| 1311 | 
            +
                                    -0.8547340619564057,
         | 
| 1312 | 
            +
                                    -0.9028875434398651,
         | 
| 1313 | 
            +
                                    -2.7698556280136106,
         | 
| 1314 | 
            +
                                    -1.6851656341552732,
         | 
| 1315 | 
            +
                                    1.2335169839859008,
         | 
| 1316 | 
            +
                                    -1.9587260699272155,
         | 
| 1317 | 
            +
                                ],
         | 
| 1318 | 
            +
                                "q99": [
         | 
| 1319 | 
            +
                                    0.9569852340221403,
         | 
| 1320 | 
            +
                                    1.4148830294609054,
         | 
| 1321 | 
            +
                                    0.7693877756595566,
         | 
| 1322 | 
            +
                                    -0.4545914208889008,
         | 
| 1323 | 
            +
                                    1.5623322343826267,
         | 
| 1324 | 
            +
                                    3.475611729621887,
         | 
| 1325 | 
            +
                                    2.263479118347167,
         | 
| 1326 | 
            +
                                ],
         | 
| 1327 | 
            +
                                "std": [
         | 
| 1328 | 
            +
                                    0.31695080251469465,
         | 
| 1329 | 
            +
                                    0.49522214687158767,
         | 
| 1330 | 
            +
                                    0.27993538230553827,
         | 
| 1331 | 
            +
                                    0.478161574676113,
         | 
| 1332 | 
            +
                                    0.4969961591445458,
         | 
| 1333 | 
            +
                                    0.45101008525403846,
         | 
| 1334 | 
            +
                                    0.7287264344068457,
         | 
| 1335 | 
            +
                                ],
         | 
| 1336 | 
            +
                            },
         | 
| 1337 | 
            +
                            "translation": {
         | 
| 1338 | 
            +
                                "max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553],
         | 
| 1339 | 
            +
                                "mean": [
         | 
| 1340 | 
            +
                                    0.5283099395864883,
         | 
| 1341 | 
            +
                                    0.005363794653877434,
         | 
| 1342 | 
            +
                                    0.3120132207021294,
         | 
| 1343 | 
            +
                                ],
         | 
| 1344 | 
            +
                                "min": [
         | 
| 1345 | 
            +
                                    -0.15604186058044434,
         | 
| 1346 | 
            +
                                    -0.827903687953949,
         | 
| 1347 | 
            +
                                    -0.2347021996974945,
         | 
| 1348 | 
            +
                                ],
         | 
| 1349 | 
            +
                                "q01": [
         | 
| 1350 | 
            +
                                    0.26669957995414734,
         | 
| 1351 | 
            +
                                    -0.43774398624897004,
         | 
| 1352 | 
            +
                                    -0.048167889714241026,
         | 
| 1353 | 
            +
                                ],
         | 
| 1354 | 
            +
                                "q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619],
         | 
| 1355 | 
            +
                                "std": [
         | 
| 1356 | 
            +
                                    0.1148424841779685,
         | 
| 1357 | 
            +
                                    0.17489566608140428,
         | 
| 1358 | 
            +
                                    0.16541062032731538,
         | 
| 1359 | 
            +
                                ],
         | 
| 1360 | 
            +
                            },
         | 
| 1361 | 
            +
                        },
         | 
| 1362 | 
            +
                        "roboset": {
         | 
| 1363 | 
            +
                            "euler": {
         | 
| 1364 | 
            +
                                "max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582],
         | 
| 1365 | 
            +
                                "mean": [
         | 
| 1366 | 
            +
                                    -0.0398455755412464,
         | 
| 1367 | 
            +
                                    1.0518070390619125,
         | 
| 1368 | 
            +
                                    -0.015345692503002759,
         | 
| 1369 | 
            +
                                ],
         | 
| 1370 | 
            +
                                "min": [
         | 
| 1371 | 
            +
                                    -3.1415813300509536,
         | 
| 1372 | 
            +
                                    -1.5222832468962035,
         | 
| 1373 | 
            +
                                    -3.141575300866071,
         | 
| 1374 | 
            +
                                ],
         | 
| 1375 | 
            +
                                "q01": [
         | 
| 1376 | 
            +
                                    -2.9414386317311187,
         | 
| 1377 | 
            +
                                    -0.24976770655101155,
         | 
| 1378 | 
            +
                                    -2.985256521212579,
         | 
| 1379 | 
            +
                                ],
         | 
| 1380 | 
            +
                                "q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025],
         | 
| 1381 | 
            +
                                "std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616],
         | 
| 1382 | 
            +
                            },
         | 
| 1383 | 
            +
                            "gripper": {
         | 
| 1384 | 
            +
                                "max": [0.83056640625],
         | 
| 1385 | 
            +
                                "min": [0.0001499652862548828],
         | 
| 1386 | 
            +
                                "q01": [0.0001499652862548828],
         | 
| 1387 | 
            +
                                "q99": [0.82666015625],
         | 
| 1388 | 
            +
                            },
         | 
| 1389 | 
            +
                            "joints": {
         | 
| 1390 | 
            +
                                "max": [
         | 
| 1391 | 
            +
                                    0.96240234375,
         | 
| 1392 | 
            +
                                    1.1162109375,
         | 
| 1393 | 
            +
                                    1.1064453125,
         | 
| 1394 | 
            +
                                    -0.98095703125,
         | 
| 1395 | 
            +
                                    2.30859375,
         | 
| 1396 | 
            +
                                    1.576171875,
         | 
| 1397 | 
            +
                                    1.7412109375,
         | 
| 1398 | 
            +
                                ],
         | 
| 1399 | 
            +
                                "mean": [
         | 
| 1400 | 
            +
                                    0.005913593806326389,
         | 
| 1401 | 
            +
                                    0.1877261847257614,
         | 
| 1402 | 
            +
                                    0.04653879255056381,
         | 
| 1403 | 
            +
                                    -2.0529513359069824,
         | 
| 1404 | 
            +
                                    -0.011298442259430885,
         | 
| 1405 | 
            +
                                    0.6185526251792908,
         | 
| 1406 | 
            +
                                    -0.01701134257018566,
         | 
| 1407 | 
            +
                                ],
         | 
| 1408 | 
            +
                                "min": [
         | 
| 1409 | 
            +
                                    -0.8330078125,
         | 
| 1410 | 
            +
                                    -0.74658203125,
         | 
| 1411 | 
            +
                                    -0.8642578125,
         | 
| 1412 | 
            +
                                    -2.892578125,
         | 
| 1413 | 
            +
                                    -1.390625,
         | 
| 1414 | 
            +
                                    -0.24658203125,
         | 
| 1415 | 
            +
                                    -2.953125,
         | 
| 1416 | 
            +
                                ],
         | 
| 1417 | 
            +
                                "q01": [
         | 
| 1418 | 
            +
                                    -0.41015625,
         | 
| 1419 | 
            +
                                    -0.5302734375,
         | 
| 1420 | 
            +
                                    -0.6455078125,
         | 
| 1421 | 
            +
                                    -2.57421875,
         | 
| 1422 | 
            +
                                    -0.76416015625,
         | 
| 1423 | 
            +
                                    -0.0386962890625,
         | 
| 1424 | 
            +
                                    -1.435546875,
         | 
| 1425 | 
            +
                                ],
         | 
| 1426 | 
            +
                                "q99": [
         | 
| 1427 | 
            +
                                    0.66455078125,
         | 
| 1428 | 
            +
                                    0.9501953125,
         | 
| 1429 | 
            +
                                    0.7529296875,
         | 
| 1430 | 
            +
                                    -1.251953125,
         | 
| 1431 | 
            +
                                    0.75244140625,
         | 
| 1432 | 
            +
                                    1.2314453125,
         | 
| 1433 | 
            +
                                    1.384765625,
         | 
| 1434 | 
            +
                                ],
         | 
| 1435 | 
            +
                                "std": [
         | 
| 1436 | 
            +
                                    0.17915399372577667,
         | 
| 1437 | 
            +
                                    0.32234326004981995,
         | 
| 1438 | 
            +
                                    0.26069700717926025,
         | 
| 1439 | 
            +
                                    0.31767210364341736,
         | 
| 1440 | 
            +
                                    0.205329030752182,
         | 
| 1441 | 
            +
                                    0.33385637402534485,
         | 
| 1442 | 
            +
                                    0.6263682842254639,
         | 
| 1443 | 
            +
                                ],
         | 
| 1444 | 
            +
                            },
         | 
| 1445 | 
            +
                            "translation": {
         | 
| 1446 | 
            +
                                "max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794],
         | 
| 1447 | 
            +
                                "mean": [
         | 
| 1448 | 
            +
                                    0.3331542909145355,
         | 
| 1449 | 
            +
                                    0.019357483834028244,
         | 
| 1450 | 
            +
                                    0.37330344319343567,
         | 
| 1451 | 
            +
                                ],
         | 
| 1452 | 
            +
                                "min": [
         | 
| 1453 | 
            +
                                    0.09978063404560089,
         | 
| 1454 | 
            +
                                    -0.29593944549560547,
         | 
| 1455 | 
            +
                                    0.10065606236457825,
         | 
| 1456 | 
            +
                                ],
         | 
| 1457 | 
            +
                                "q01": [
         | 
| 1458 | 
            +
                                    0.18437016010284424,
         | 
| 1459 | 
            +
                                    -0.25699371099472046,
         | 
| 1460 | 
            +
                                    0.15134164690971375,
         | 
| 1461 | 
            +
                                ],
         | 
| 1462 | 
            +
                                "q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
         | 
| 1463 | 
            +
                                "std": [
         | 
| 1464 | 
            +
                                    0.07849054038524628,
         | 
| 1465 | 
            +
                                    0.12241040915250778,
         | 
| 1466 | 
            +
                                    0.1460595279932022,
         | 
| 1467 | 
            +
                                ],
         | 
| 1468 | 
            +
                            },
         | 
| 1469 | 
            +
                        },
         | 
| 1470 | 
            +
                    }
         | 
| 1471 | 
            +
             | 
| 1472 | 
            +
                @cached_property
         | 
| 1473 | 
            +
                def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
         | 
| 1474 | 
            +
                    if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
         | 
| 1475 | 
            +
                        return {}
         | 
| 1476 | 
            +
                    with open(self.config.control_stats_path, "r") as file:
         | 
| 1477 | 
            +
                        stats = yaml.safe_load(file)
         | 
| 1478 | 
            +
                        if self.config.delta_controls:
         | 
| 1479 | 
            +
                            if self.control_io_config.future_controls_sequence_stride_sec is None:
         | 
| 1480 | 
            +
                                horizon = 0.0
         | 
| 1481 | 
            +
                            else:
         | 
| 1482 | 
            +
                                horizon = self.control_io_config.future_controls_sequence_stride_sec
         | 
| 1483 | 
            +
                        elif self.control_io_config.future_controls_sequence_stride_sec is None:
         | 
| 1484 | 
            +
                            if self.control_io_config.future_controls_sequence_length == 1:
         | 
| 1485 | 
            +
                                horizon = 0.0
         | 
| 1486 | 
            +
                            else:
         | 
| 1487 | 
            +
                                raise NotImplementedError()
         | 
| 1488 | 
            +
                        else:
         | 
| 1489 | 
            +
                            horizon = (
         | 
| 1490 | 
            +
                                self.control_io_config.future_controls_sequence_length
         | 
| 1491 | 
            +
                                * self.control_io_config.future_controls_sequence_stride_sec
         | 
| 1492 | 
            +
                            )
         | 
| 1493 | 
            +
                        key = f"horizon_{round(horizon, 2)}s"
         | 
| 1494 | 
            +
                        if key in stats:
         | 
| 1495 | 
            +
                            stats = stats[key]
         | 
| 1496 | 
            +
                        else:
         | 
| 1497 | 
            +
                            raise ValueError(
         | 
| 1498 | 
            +
                                f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]"
         | 
| 1499 | 
            +
                            )
         | 
| 1500 | 
            +
                    return stats
         | 
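                # Illustrative note (not part of the original source): the stats key encodes the control
                # horizon in seconds. With hypothetical settings delta_controls=False,
                # future_controls_sequence_length=8 and future_controls_sequence_stride_sec=0.25,
                # horizon = 8 * 0.25 = 2.0 and the lookup key becomes "horizon_2.0s"; a stats YAML
                # without that key triggers the ValueError above.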
| 1501 | 
            +
             | 
| 1502 | 
            +
                @cached_property
         | 
| 1503 | 
            +
                def dataset_names(self) -> List[str]:
         | 
| 1504 | 
            +
                    if (
         | 
| 1505 | 
            +
                        is_global_norm(self.config.rotation_norm)
         | 
| 1506 | 
            +
                        and is_global_norm(self.config.obs_rotation_norm)
         | 
| 1507 | 
            +
                        and is_global_norm(self.config.translation_norm)
         | 
| 1508 | 
            +
                        and is_global_norm(self.config.obs_translation_norm)
         | 
| 1509 | 
            +
                    ):
         | 
| 1510 | 
            +
                        return ["ANY"]
         | 
| 1511 | 
            +
                    return list(set(self._control_stats.keys()) | set(self._observation_stats.keys()))
         | 
| 1512 | 
            +
             | 
| 1513 | 
            +
             | 
| 1514 | 
            +
            def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
         | 
| 1515 | 
            +
                """
         | 
| 1516 | 
            +
                Transform a sequence of translation vectors encoded w.r.t. PREVIOUS frame in the sequence to encoding
         | 
| 1517 | 
            +
                w.r.t. the 0-th element preceding the sequence
         | 
| 1518 | 
            +
                Ex:
         | 
| 1519 | 
            +
                    Sequence of points: T1, T2, T3, T4
         | 
| 1520 | 
            +
                    `translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame,
         | 
| 1521 | 
            +
                    implicitly encoded in T0T1
         | 
| 1522 | 
            +
                    Output: T0T1, T0T2, T0T3, T0T4
         | 
| 1523 | 
            +
             | 
| 1524 | 
            +
                Args:
         | 
| 1525 | 
            +
                    translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S
         | 
| 1526 | 
            +
                        corresponds to the sequence dimension
         | 
| 1527 | 
            +
                Returns:
         | 
| 1528 | 
            +
                    torch.Tensor of the same shape as translation_sequence, containing translations relative to the base (0-th) frame
         | 
| 1529 | 
            +
                """
         | 
| 1530 | 
            +
                assert translation_sequence.ndim >= 3, translation_sequence.shape
         | 
| 1531 | 
            +
                relative_translations = torch.cumsum(translation_sequence, dim=-2)
         | 
| 1532 | 
            +
                return relative_translations
         | 
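            # Illustrative note (not part of the original source): the cumulative sum above turns
            # per-step deltas into offsets from the base frame. For a hypothetical batch of one
            # sequence of deltas [[1, 0, 0], [0, 1, 0], [0, 0, 1]] (shape [1, 3, 3]), the output is
            # [[1, 0, 0], [1, 1, 0], [1, 1, 1]], i.e. T0T1, T0T2, T0T3 in the docstring's notation.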
| 1533 | 
            +
             | 
| 1534 | 
            +
             | 
| 1535 | 
            +
            class RegressionProcessor(VLAMProcessor):
         | 
| 1536 | 
            +
                def policy_control_plan_from_model_target(
         | 
| 1537 | 
            +
                    self, target: RoboticsTarget, dataset_name: np.ndarray
         | 
| 1538 | 
            +
                ) -> RoboticsControlPlan:
         | 
| 1539 | 
            +
                    translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation")
         | 
| 1540 | 
            +
                    rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation")
         | 
| 1541 | 
            +
                    rotmat = convert_rotation(rotation, RotationFormat.ROTMAT)
         | 
| 1542 | 
            +
                    gripper_prob = target.gripper
         | 
| 1543 | 
            +
                    if self.config.delta_controls:
         | 
| 1544 | 
            +
                        translation_m = delta_to_relative_translations(translation_m)
         | 
| 1545 | 
            +
                        rotmat = delta_to_relative_rotations(rotmat)
         | 
| 1546 | 
            +
                    return RoboticsControlPlan(
         | 
| 1547 | 
            +
                        translation_m=translation_m,
         | 
| 1548 | 
            +
                        rotmat=rotmat,
         | 
| 1549 | 
            +
                        gripper_prob=gripper_prob,
         | 
| 1550 | 
            +
                        valid_mask=target.valid_mask,
         | 
| 1551 | 
            +
                    )
         | 
| 1552 | 
            +
             | 
| 1553 | 
            +
                def policy_control_plan_from_model_output(
         | 
| 1554 | 
            +
                    self,
         | 
| 1555 | 
            +
                    model_output: RoboticsOutput,
         | 
| 1556 | 
            +
                    dataset_name: np.ndarray,
         | 
| 1557 | 
            +
                    valid_mask: torch.Tensor,
         | 
| 1558 | 
            +
                ) -> RoboticsControlPlan:
         | 
| 1559 | 
            +
                    """Called during inference to create control plan from model output"""
         | 
| 1560 | 
            +
                    translation_m = self.unnormalize(
         | 
| 1561 | 
            +
                        model_output.translation, dataset_name=dataset_name, key="translation"
         | 
| 1562 | 
            +
                    )
         | 
| 1563 | 
            +
                    rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation")
         | 
| 1564 | 
            +
                    rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True)
         | 
| 1565 | 
            +
                    gripper_prob = torch.sigmoid(model_output.gripper)
         | 
| 1566 | 
            +
                    if self.config.delta_controls:
         | 
| 1567 | 
            +
                        translation_m = delta_to_relative_translations(translation_m)
         | 
| 1568 | 
            +
                        rotmat = delta_to_relative_rotations(rotmat)
         | 
| 1569 | 
            +
                    return RoboticsControlPlan(
         | 
| 1570 | 
            +
                        translation_m=translation_m,
         | 
| 1571 | 
            +
                        rotmat=rotmat,
         | 
| 1572 | 
            +
                        gripper_prob=gripper_prob,
         | 
| 1573 | 
            +
                        valid_mask=valid_mask,
         | 
| 1574 | 
            +
                    )
         | 
| 1575 | 
            +
             | 
| 1576 | 
            +
             | 
| 1577 | 
            +
            class PiZeroFlowMatchingProcessor(RegressionProcessor):
         | 
| 1578 | 
            +
                def __init__(self, **kwargs):
         | 
| 1579 | 
            +
                    super().__init__(**kwargs)
         | 
| 1580 | 
            +
                    self.generator: torch.Generator = torch.Generator()
         | 
| 1581 | 
            +
             | 
| 1582 | 
            +
                @cached_property
         | 
| 1583 | 
            +
                def beta_distribution(self) -> torch.distributions.Beta:
         | 
| 1584 | 
            +
                    return torch.distributions.Beta(
         | 
| 1585 | 
            +
                        self.config.distribution_hyperparams.get("alpha", 1.5),
         | 
| 1586 | 
            +
                        self.config.distribution_hyperparams.get("beta", 1.0),
         | 
| 1587 | 
            +
                    )
         | 
| 1588 | 
            +
             | 
| 1589 | 
            +
                def create_input(self, *args, **kwargs) -> RoboticsFlowInput:
         | 
| 1590 | 
            +
                    """In practice used only during inference"""
         | 
| 1591 | 
            +
                    inputs = super().create_input(*args, **kwargs)
         | 
| 1592 | 
            +
                    flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu"))
         | 
| 1593 | 
            +
                    inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...])
         | 
| 1594 | 
            +
                    return inputs
         | 
| 1595 | 
            +
             | 
| 1596 | 
            +
                def sample_timestep(self, batch_size: int) -> torch.Tensor:
         | 
| 1597 | 
            +
                    if self.config.timestep_distribution.lower() == "uniform":
         | 
| 1598 | 
            +
                        eps = 1e-05
         | 
| 1599 | 
            +
                        sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % (
         | 
| 1600 | 
            +
                            1 - eps
         | 
| 1601 | 
            +
                        )
         | 
| 1602 | 
            +
                    elif self.config.timestep_distribution.lower() == "beta":
         | 
| 1603 | 
            +
                        sample = self.beta_distribution.sample([batch_size, 1, 1])
         | 
| 1604 | 
            +
                        sample = (1 - self.config.sig_min) * (1 - sample)
         | 
| 1605 | 
            +
                    else:
         | 
| 1606 | 
            +
                        raise NotImplementedError(self.config.timestep_distribution)
         | 
| 1607 | 
            +
                    sample = sample.view(batch_size, 1, 1)
         | 
| 1608 | 
            +
                    return sample
         | 
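                # Illustrative note (not part of the original source): the "uniform" branch above draws a
                # single random offset and adds an evenly spaced grid, i.e. stratified sampling of [0, 1).
                # For a hypothetical batch_size=4 and offset u, the timesteps are
                # {u, u + 0.25, u + 0.5, u + 0.75} modulo (1 - eps), reshaped to [4, 1, 1].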
| 1609 | 
            +
             | 
| 1610 | 
            +
                def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
         | 
| 1611 | 
            +
                    return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1
         | 
| 1612 | 
            +
             | 
| 1613 | 
            +
                def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
         | 
| 1614 | 
            +
                    return x_1 - (1 - self.config.sig_min) * x_0
         | 
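                # Illustrative note (not part of the original source): _psi_t and _dpsi_dt implement the
                # flow-matching interpolation psi_t(x0, x1) = (1 - (1 - sig_min) * t) * x0 + t * x1,
                # so psi_0 = x0 and psi_1 = sig_min * x0 + x1, with constant velocity
                # d(psi_t)/dt = x1 - (1 - sig_min) * x0 serving as the regression target.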
| 1615 | 
            +
             | 
| 1616 | 
            +
                def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput:
         | 
| 1617 | 
            +
                    if self.config.r0_distribution == "normal":
         | 
| 1618 | 
            +
                        controls_t0 = torch.randn(
         | 
| 1619 | 
            +
                            [
         | 
| 1620 | 
            +
                                batch_size,
         | 
| 1621 | 
            +
                                self.config.control_io_config.future_controls_sequence_length,
         | 
| 1622 | 
            +
                                3 + self.rotation_components + 1,
         | 
| 1623 | 
            +
                            ],
         | 
| 1624 | 
            +
                            generator=self.generator,
         | 
| 1625 | 
            +
                        ).to(device=device)
         | 
| 1626 | 
            +
                        (translation_t0, rotation_t0, gripper_t0) = torch.split(
         | 
| 1627 | 
            +
                            controls_t0, [3, self.rotation_components, 1], dim=-1
         | 
| 1628 | 
            +
                        )
         | 
| 1629 | 
            +
                        rotation_t0 = normalize_rotation(rotation_t0)
         | 
| 1630 | 
            +
                    elif self.config.r0_distribution == "uniform":
         | 
| 1631 | 
            +
                        controls_t0 = torch.randn(
         | 
| 1632 | 
            +
                            [
         | 
| 1633 | 
            +
                                batch_size,
         | 
| 1634 | 
            +
                                self.config.control_io_config.future_controls_sequence_length,
         | 
| 1635 | 
            +
                                4,
         | 
| 1636 | 
            +
                            ],
         | 
| 1637 | 
            +
                            generator=self.generator,
         | 
| 1638 | 
            +
                        ).to(device=device)
         | 
| 1639 | 
            +
                        (translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1)
         | 
| 1640 | 
            +
                        rotation_t0 = convert_rotation(
         | 
| 1641 | 
            +
                            roma.random_unitquat(
         | 
| 1642 | 
            +
                                (
         | 
| 1643 | 
            +
                                    batch_size,
         | 
| 1644 | 
            +
                                    self.config.control_io_config.future_controls_sequence_length,
         | 
| 1645 | 
            +
                                ),
         | 
| 1646 | 
            +
                                device=device,
         | 
| 1647 | 
            +
                            ),
         | 
| 1648 | 
            +
                            self.config.rotation_format,
         | 
| 1649 | 
            +
                        )
         | 
| 1650 | 
            +
                    else:
         | 
| 1651 | 
            +
                        raise NotImplementedError(self.config.r0_distribution)
         | 
| 1652 | 
            +
                    if self.config.rotation_format == RotationFormat.QUATERNION:
         | 
| 1653 | 
            +
                        rotation_t0 = quaternion_half_cover(rotation_t0)
         | 
| 1654 | 
            +
                    timestep = torch.zeros([batch_size, 1, 1], device=device)
         | 
| 1655 | 
            +
                    return FlowInput(
         | 
| 1656 | 
            +
                        timestep=timestep,
         | 
| 1657 | 
            +
                        translation_t0=translation_t0,
         | 
| 1658 | 
            +
                        rotation_t0=rotation_t0,
         | 
| 1659 | 
            +
                        gripper_t0=gripper_t0,
         | 
| 1660 | 
            +
                        translation_t=None,
         | 
| 1661 | 
            +
                        rotation_t=None,
         | 
| 1662 | 
            +
                        gripper_t=None,
         | 
| 1663 | 
            +
                    )
         | 
| 1664 | 
            +
             | 
| 1665 | 
            +
                def policy_control_plan_from_model_output(
         | 
| 1666 | 
            +
                    self,
         | 
| 1667 | 
            +
                    model_output: RoboticsOutput,
         | 
| 1668 | 
            +
                    dataset_name: np.ndarray,
         | 
| 1669 | 
            +
                    valid_mask: torch.Tensor,
         | 
| 1670 | 
            +
                ) -> RoboticsControlPlan:
         | 
| 1671 | 
            +
                    if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm):
         | 
| 1672 | 
            +
                        model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1))
         | 
| 1673 | 
            +
                    if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm):
         | 
| 1674 | 
            +
                        model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1))
         | 
| 1675 | 
            +
                    control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask)
         | 
| 1676 | 
            +
                    control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1))
         | 
| 1677 | 
            +
                    return control_plan
         | 
| 1678 | 
            +
             | 
| 1679 | 
            +
             | 
| 1680 | 
            +
            def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
         | 
| 1681 | 
            +
                """
         | 
| 1682 | 
            +
                Create a causal attention mask of shape `shape`
         | 
| 1683 | 
            +
                Args:
         | 
| 1684 | 
            +
                    shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len]
         | 
| 1685 | 
            +
                Returns:
         | 
| 1686 | 
            +
                    torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend
         | 
| 1687 | 
            +
                        to the corresponding column (i.e. key)
         | 
| 1688 | 
            +
             | 
| 1689 | 
            +
                Example:
         | 
| 1690 | 
            +
                    shape = (3, 5) -> the upper-triangular part is masked out:
         | 
| 1691 | 
            +
                    [
         | 
| 1692 | 
            +
                        [ 1, 0, 0, 0, 0],
         | 
| 1693 | 
            +
                        [ 1, 1, 0, 0, 0],
         | 
| 1694 | 
            +
                        [ 1, 1, 1, 0, 0]
         | 
| 1695 | 
            +
                    ]
         | 
| 1696 | 
            +
                """
         | 
| 1697 | 
            +
                return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0)
         | 
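            # Illustrative note (not part of the original source): make_causal_mask([3, 5]) returns the
            # boolean lower-triangular mask shown in the docstring example; leading batch/head
            # dimensions can be included directly in `shape`, e.g. make_causal_mask([1, 8, 8]), since
            # torch.tril operates on the last two dimensions.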
| 1698 | 
            +
             | 
| 1699 | 
            +
             | 
| 1700 | 
            +
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
    """
    Enable full bi-directional attention in `attn_mask` inside specific blocks

    Args:
        attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool
            where False values indicate disabled attention
        full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
            positions where full bi-directional attention should be enabled

    Example:
            1, 0, 0, 0, 0, 0, 0, 0,                 1, 1, 1, 0, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0, 0,                 1, 1, 1, 0, 0, 0, 0, 0,
            1, 1, 1, 0, 0, 0, 0, 0,                 1, 1, 1, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 0, 0, 0, 0,      ->         1, 1, 1, 1, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 0, 0, 0,                 1, 1, 1, 1, 1, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 0, 0,                 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 0,                 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1,                 1, 1, 1, 1, 1, 1, 1, 1,
    """
    assert full_attn.dtype == torch.bool, full_attn.dtype
    assert full_attn.ndim == 1, full_attn.shape
    assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}"
    if attn_mask.shape[-1] != attn_mask.shape[-2]:
        raise NotImplementedError("Only self-attention supported right now.")
    # Pairwise full-attention candidates, plus the causal structure.
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
    # The row-wise cumulative product removes links from an earlier block to a later one
    # that would jump across a gap of causal-only positions; symmetrizing keeps only
    # mutual pairs.
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    x = x & x.permute(1, 0)
    # Zero out the isolated diagonal entries of positions that are not inside any block,
    # so that `x` only keeps within-block connections before it is OR-ed into `attn_mask`.
    mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    x[mask_indices, mask_indices] = 0
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask


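def _example_enable_full_attn_blocks() -> torch.Tensor:
    # Usage sketch (illustrative helper, not part of the original module): reproduces the
    # docstring example above. Positions 0-2 and 5-7 form full-attention blocks (e.g.
    # image / instruction prefixes), while positions 3-4 remain strictly causal.
    full_attn = torch.tensor([True, True, True, False, False, True, True, True])
    attn_mask = make_causal_mask([8, 8])
    return enable_full_attn_blocks(attn_mask, full_attn=full_attn)

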
# Matches the default ignore_index of torch.nn.CrossEntropyLoss; positions carrying this
# value in target_ids are excluded from supervision.
IGNORE_INDEX = -100


class PaliGemmaProcessor(VLMProcessor):
    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        **kwargs,
    ):
        del kwargs
        super().__init__(config)
        self.hf_processor = hf_processor
        self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json())
        self.hf_processor.image_seq_length = self.config.num_image_tokens["main"]
        self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"]
        self.bos_id: int = self.tokenizer.bos_token_id
        self.eos_id: int = self.tokenizer.eos_token_id
        self.sep_token = "\n"
        self.sep_id: int = self.tokenizer(
            self.sep_token,
            padding=False,
            add_special_tokens=False,
            return_attention_mask=False,
        )["input_ids"][0]
        self.image_token_id: int = self.tokenizer(
            self.config.image_token,
            padding=False,
            add_special_tokens=False,
            return_attention_mask=False,
        )["input_ids"][0]
        self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
        self.bbox_pattern = re.compile(
            "\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]"
        )

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Based on the PaliGemma paper https://arxiv.org/pdf/2407.07726 and the example code at
        https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs
        The chat must always consist of alternating user and model messages, starting with the user:

        <image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>

        Args:
            chat: List[str] of even length where each entry corresponds to one turn of the conversation
            images: Dict[str, List[PIL.Image.Image]] where each key corresponds to a different camera
                and the list holds that camera's history of images
        """
        for key, value in images.items():
            if not isinstance(value, list):
                raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list")
        input_ids, target_ids = [], []
        for i, text in enumerate(chat):
            text = text.replace(self.sep_token, " ").replace("<image>", "")
            text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
            turn_input_ids: List[int] = self.tokenizer(
                text,
                padding=False,
                add_special_tokens=False,
                return_attention_mask=False,
            )["input_ids"]
            if i % 2 == 0:
                # User turns are not supervised.
                turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids)
            else:
                turn_target_ids = turn_input_ids
            if i != len(chat) - 1:
                turn_input_ids = turn_input_ids + [self.sep_id]
                turn_target_ids = turn_target_ids + [IGNORE_INDEX]
            input_ids = input_ids + turn_input_ids
            target_ids = target_ids + turn_target_ids
        input_ids = [self.bos_id] + input_ids + [self.eos_id]
        target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id]
        image_tokens = self.image_tokens
        if self.config.max_language_tokens > 0:
            input_ids = input_ids[: self.config.max_language_tokens]
            target_ids = target_ids[: self.config.max_language_tokens]
        input_ids = image_tokens + input_ids
        target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids
        input_ids = torch.tensor(input_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        image_tensors: Dict[str, torch.Tensor] = {
            f"{camera_name}.siglip": self.hf_processor.image_processor(
                camera_images,
                size=self.config.image_sizes[camera_name].as_json(),
                return_tensors="pt",
            )["pixel_values"]
            for (camera_name, camera_images) in images.items()
        }
        attn_mask = make_causal_mask([len(input_ids), len(input_ids)])
        attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "images": image_tensors,
            "attn_mask": attn_mask,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        return self.hf_processor.tokenizer

    @staticmethod
    def _bbox_to_loc_tokens(match: re.Match) -> str:
        """
        Convert a matched bounding box of four normalized floats into <locNNNN> tokens, see
        https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
        """
        floats = list(map(float, match.groups()))
        transformed = [f"<loc{np.clip(round(num * 1024), 0, 1023):04d}>" for num in floats]
        return f"[{', '.join(transformed)}]"

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        return self.config.image_sizes
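

def _example_preprocess_inputs(processor: PaliGemmaProcessor) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
    # Usage sketch (illustrative helper, not part of the original module; it assumes the
    # processor was built with a config whose image_sizes / num_image_tokens contain a
    # "main" camera). The user turn is masked out of target_ids with IGNORE_INDEX and,
    # together with the image tokens, receives full prefix attention; the model turn is
    # supervised and stays causal.
    chat = [
        "what should the robot do next?",  # user turn
        "move the gripper towards the red block",  # model turn
    ]
    images = {"main": [PIL.Image.new("RGB", (224, 224))]}
    return processor.preprocess_inputs(chat=chat, images=images)
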
             | 
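def _example_bbox_to_loc_tokens() -> str:
    # Usage sketch (illustrative helper, not part of the original module): bounding boxes
    # written as four normalized floats are rewritten into PaliGemma <locNNNN> tokens on a
    # 1024-bin grid, e.g. "[0.25, 0.50, 0.75, 1.00]" -> "[<loc0256>, <loc0512>, <loc0768>, <loc1023>]".
    pattern = re.compile("\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]")
    return pattern.sub(PaliGemmaProcessor._bbox_to_loc_tokens, "pick up the object at [0.25, 0.50, 0.75, 1.00]")

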
class PaliGemmaDepthProcessor(PaliGemmaProcessor):
    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        depth_tokens: int,
    ):
        super().__init__(config, hf_processor)
        vocab_size = len(self.tokenizer)
        # The last `depth_tokens` entries of the vocabulary are reserved as depth token ids.
        self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size)
        self.depth_input_transforms = {
            camera_name: torchvision.transforms.v2.Compose(
                [
                    torchvision.transforms.v2.Resize(
                        size=(camera_image_size.height, camera_image_size.width),
                        interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC,
                        max_size=None,
                        antialias=True,
                    ),
                    torchvision.transforms.v2.ToTensor(),
                    torchvision.transforms.v2.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            for (camera_name, camera_image_size) in self.config.image_sizes.items()
        }

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        inputs = super().preprocess_inputs(chat=chat, images=images)
        depth_images: Dict[str, torch.Tensor] = {
            f"{camera_name}.depth": torch.stack(
                self.depth_input_transforms[camera_name](camera_images), dim=0
            )
            for (camera_name, camera_images) in images.items()
        }
        inputs["images"] = {**inputs["images"], **depth_images}
        return inputs

    @property
    def num_depth_tokens(self) -> int:
        return len(self.depth_token_ids)
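

def _example_depth_processor(processor: PaliGemmaDepthProcessor) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
    # Usage sketch (illustrative helper, not part of the original module; it assumes a
    # config with a "main" camera). In addition to the regular PaliGemmaProcessor outputs,
    # the returned "images" dict carries a "main.depth" tensor produced by the
    # resize / normalize transform defined in PaliGemmaDepthProcessor.__init__.
    chat = ["describe the scene", "a red block on the table"]
    images = {"main": [PIL.Image.new("RGB", (224, 224))]}
    inputs = processor.preprocess_inputs(chat=chat, images=images)
    assert "main.depth" in inputs["images"] and "main.siglip" in inputs["images"]
    return inputs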
