giu-alb committed
Commit a8bf2f3 · verified · 0 Parent(s)

Super-squash branch 'main' using huggingface_hub

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,100 @@
1
+ ---
2
+ license: gemma
3
+ library_name: transformers
4
+ pipeline_tag: visual-question-answering
5
+ ---
6
+ # SPEAR-1 model card
7
+
8
+ SPEAR-1 is a cutting-edge Vision-Language-Action (VLA) model capable of achieving performance __superior to or on par with state-of-the-art models such as pi0-FAST and pi0.5__
9
+ on multiple embodiments while being trained __on 20x less robot data__.
10
+
11
+ This model was developed by [INSAIT](https://insait.ai/), a special unit of Sofia University St. Kliment Ohridski, in Sofia, Bulgaria.
12
+
13
+ Code and model weights for SPEAR-1 models are free to use under the Gemma license.
14
+
15
+ This repo provides model weights fine-tuned for a Franka setup with one wrist camera and one external camera.
16
+
17
+ ## Model description
18
+
19
+ The key to SPEAR-1's data efficiency is SPEAR-VLM, a 3D-aware VLM. SPEAR-VLM extends PaliGemma with the MoGe depth encoder and is trained on 3D VQA tasks using
20
+ primarily non-robot data sources, such as EgoExo-4D.
21
+
22
+ SPEAR-1's architecture combines SPEAR-VLM with a DiT action expert. It is first pre-trained on a mixture of robot demonstration datasets from Open X Embodiment and
23
+ then fine-tuned for specific embodiments.
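+
+ As a quick, hedged illustration of how these two parts appear in the released checkpoint, the `config.json` in this repo exposes the PaliGemma-based VLM backbone under `vlm_config` and the flow-matching action decoder under `control_module_config`. A minimal sketch, assuming the custom config class resolves via `trust_remote_code`:
+
+ ```python
+ from transformers import AutoConfig
+
+ # Values printed below come from the config.json shipped with this repo.
+ cfg = AutoConfig.from_pretrained("INSAIT-Institute/spear1-franka", trust_remote_code=True)
+ print(cfg.vlm_config.model_id)                                       # "google/paligemma-3b-mix-224"
+ print(cfg.control_module_config.control_decoder_config.num_blocks)   # 18 decoder blocks in the action expert
+ ```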
24
+
25
+ ## Use with 🤗 Transformers
26
+
27
+ We provide a fully `AutoModel`-compatible implementation of SPEAR-1 that can be used via 🤗 Transformers.
28
+
29
+ ### Environment setup
30
+
31
+ The current implementation requires the following additional dependencies: `roma`, `timm`, `flash-attn`.
32
+
33
+ Here is a snippet to set up a working environment for inference via `uv`:
34
+
35
+ ```bash
36
+ uv venv --python 3.10.12
37
+ source .venv/bin/activate
38
+ uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
39
+ uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
40
+ ```
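+
+ After installation, an optional sanity check (not part of the official setup) is to confirm that the pinned dependencies import cleanly inside the virtual environment:
+
+ ```python
+ # Optional sanity check: all inference dependencies listed above should import without errors.
+ import flash_attn
+ import roma
+ import timm
+ import torch
+ import transformers
+
+ print(torch.__version__, transformers.__version__, torch.cuda.is_available())
+ ```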
41
+
42
+ ### Example usage
43
+
44
+
45
+ ```python
46
+ from typing import Dict
47
+
48
+ import numpy as np
49
+ import torch
50
+ from PIL import Image
51
+ from transformers import AutoModel
52
+
53
+ model = AutoModel.from_pretrained("INSAIT-Institute/spear1-franka", trust_remote_code=True)
54
+ model = model.to(dtype=torch.bfloat16, device="cuda").eval()
55
+
56
+ main_image = np.asarray(Image.open("path/to/main_image.png"))
57
+ wrist_image = np.asarray(Image.open("path/to/wrist_image.png"))
58
+
59
+ ee_translation = np.array([0.36, 0.0, 0.56])
60
+ ee_rotation = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
61
+ gripper = np.array(1.0)
62
+
63
+ model_input: Dict[str, np.ndarray | str | Dict[str, np.ndarray]] = {
64
+ "images": {
65
+ "main": main_image, # (H, W, C)
66
+ "wrist": wrist_image, # (H, W, C)
67
+ },
68
+ "ee_translation": ee_translation, # (3,)
69
+ "ee_rotation": ee_rotation, # (3, 3)
70
+ "gripper": gripper, # (1,)
71
+ "language_instruction": "put the carrot on the blue plate",
72
+ "dataset_name": "droid"
73
+ }
74
+
75
+ model_output: Dict[str, np.ndarray] = model.predict_action(model_input)
76
+
77
+ ctrl_translation: np.ndarray = model_output["translation"] # (S, 3)
78
+ ctrl_rotation: np.ndarray = model_output["rotation"] # (S, 3, 3)
79
+ ctrl_gripper: np.ndarray = model_output["gripper"] # (S, 1)
80
+
81
+ ```
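+
+ Here `S` is the length of the predicted action chunk; for this checkpoint it should match `future_controls_sequence_length: 5` in `config.json`, i.e. five future control steps sampled roughly 0.2 s apart.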
82
+
83
+ ## Action space
84
+
85
+ SPEAR-1 predicts action chunks of delta end-effector poses. Each step in the predicted action chunk is relative to the input state.
86
+
87
+ Given the current end-effector pose `[R, t]` and a model prediction `A_rel = [[R_1, t_1], ..., [R_n, t_n]]`, absolute end-effector pose commands can be computed as:
88
+ ```
89
+ A_abs = [[R * R_1, t + t_1], ..., [R * R_n, t + t_n]]
90
+ ```
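+
+ A minimal sketch of this composition in NumPy, reusing the `model_output` names from the example above (the helper below is illustrative and not part of the released API):
+
+ ```python
+ import numpy as np
+
+ def to_absolute_commands(R: np.ndarray, t: np.ndarray, model_output: dict) -> tuple[np.ndarray, np.ndarray]:
+     """Compose the current EE pose (R: (3, 3), t: (3,)) with a chunk of predicted relative poses."""
+     R_rel = model_output["rotation"]      # (S, 3, 3)
+     t_rel = model_output["translation"]   # (S, 3)
+     R_abs = R[None] @ R_rel               # R * R_i for every step in the chunk
+     t_abs = t[None] + t_rel               # t + t_i for every step in the chunk
+     return R_abs, t_abs
+ ```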
91
+
92
+ ## Community Feedback
93
+
94
+ We welcome feedback from the community to help improve SPEAR-1. If you have suggestions, encounter any issues, or have ideas for improvements, please contact us.
95
+
96
+ ## Summary
97
+
98
+ - __Model type__: Vision-Language-Action with flow-matching action decoding
99
+ - __Contact__: [email protected]
100
+ - __License__: Gemma Terms of Use
common_spear.py ADDED
@@ -0,0 +1,702 @@
1
+ import collections.abc
2
+ import dataclasses
3
+ import enum
4
+ import inspect
5
+ import types
6
+ from collections.abc import Mapping as MappingABC
7
+ from functools import cached_property
8
+ from typing import (
9
+ Any,
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ Mapping,
15
+ Optional,
16
+ Sequence,
17
+ Tuple,
18
+ Type,
19
+ Union,
20
+ )
21
+
22
+ import torch
23
+ import transformers
24
+
25
+
26
+ class StrEnum(str, enum.Enum):
27
+ """
28
+ A minimal drop-in replacement for backports.strenum.StrEnum
29
+ """
30
+
31
+ def __str__(self):
32
+ return str(self.value)
33
+
34
+ def __new__(cls, value):
35
+ # Create new instance that properly handles string initialization
36
+ if isinstance(value, str):
37
+ obj = str.__new__(cls, value)
38
+ obj._value_ = value
39
+ return obj
40
+ return super().__new__(cls, value)
41
+
42
+ @classmethod
43
+ def _missing_(cls, value):
44
+ # Enhanced lookup by string value with better error handling
45
+ if isinstance(value, str):
46
+ for member in cls:
47
+ if member.value == value:
48
+ return member
49
+ # Return None to let enum handle the KeyError
50
+ return None
51
+
52
+ def __eq__(self, other):
53
+ # Allow comparison with string values
54
+ if isinstance(other, str):
55
+ return self.value == other
56
+ return super().__eq__(other)
57
+
58
+ def __hash__(self):
59
+ # Ensure consistent hashing
60
+ return hash(self.value)
61
+
62
+
63
+ class _cached_classproperty:
64
+ def __init__(self, func):
65
+ self.func = func
66
+ self._values = {}
67
+
68
+ def __get__(self, obj, klass):
69
+ if klass not in self._values.keys():
70
+ self._values[klass] = self.func.__get__(obj, klass)()
71
+ return self._values[klass]
72
+
73
+
74
+ def cached_classproperty(func):
75
+ if not isinstance(func, (classmethod, staticmethod)):
76
+ func = classmethod(func)
77
+ return _cached_classproperty(func)
78
+
79
+
80
+ @dataclasses.dataclass
81
+ class Dataclass:
82
+ def __post_init__(self):
83
+ pass
84
+
85
+ @classmethod
86
+ def make_empty(cls) -> "Dataclass":
87
+ return cls(
88
+ **{
89
+ k: (v.make_empty() if inspect.isclass(v) and issubclass(v, Dataclass) else None)
90
+ for (k, v) in cls.types.items()
91
+ }
92
+ )
93
+
94
+ @cached_classproperty
95
+ def fields(cls) -> Tuple[dataclasses.Field, ...]:
96
+ """Returns a sorted list of the Field objects"""
97
+ return tuple(sorted(dataclasses.fields(cls), key=lambda x: x.name))
98
+
99
+ @cached_classproperty
100
+ def types(cls) -> Dict[str, type]:
101
+ return {f.name: f.type for f in cls.fields}
102
+
103
+ def as_json(self, recursive: bool = True) -> dict:
104
+ return {k: v.as_json() if isinstance(v, Dataclass) and recursive else v for (k, v) in self.items()}
105
+
106
+ @classmethod
107
+ def keys(cls) -> List[str]:
108
+ return [field.name for field in cls.fields]
109
+
110
+ def values(self):
111
+ return [getattr(self, field.name) for field in self.fields]
112
+
113
+ def items(self, recursive: bool = False):
114
+ for key, value in zip(self.keys(), self.values(), strict=True):
115
+ if recursive and isinstance(value, Dataclass):
116
+ for subkey, subvalue in value.items(recursive=True):
117
+ yield (f"{key}.{subkey}", subvalue)
118
+ else:
119
+ yield (key, value)
120
+
121
+ def replace(self, **kwargs):
122
+ """
123
+ Return a new instance of Dataclass with the kwargs overwritten.
124
+ """
125
+ kwargs = maybe_chained_keys_to_nested_dict(kwargs)
126
+ data = self.as_json(recursive=False)
127
+ for key, value in kwargs.items():
128
+ value_type = self.types.get(key, None)
129
+ if value_type is None:
130
+ raise KeyError(f"Dataclass {self.__class__} does not have a field {key}")
131
+ value_type = get_maybe_optional_type(value_type)
132
+ if inspect.isclass(value_type) and issubclass(value_type, Dataclass):
133
+ if isinstance(value, dict):
134
+ data[key] = data[key].replace(**value)
135
+ else:
136
+ data[key] = value
137
+ else:
138
+ data[key] = value
139
+ return self.__class__(**data)
140
+
141
+ def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass":
142
+ def fcn_wrapper(value: Any) -> Any:
143
+ if value is None and skip_nones:
144
+ return None
145
+ if isinstance(value, dict) and recursive:
146
+ return type(value)(**{k: fcn(v) for (k, v) in value.items()})
147
+ if isinstance(value, (list, tuple)) and recursive:
148
+ return type(value)([fcn(v) for v in value])
149
+ if isinstance(value, Dataclass) and recursive:
150
+ return value.apply(fcn, recursive=True, skip_nones=skip_nones)
151
+ return fcn(value)
152
+
153
+ return self.__class__(**{key: fcn_wrapper(value) for (key, value) in self.items()})
154
+
155
+ def __getitem__(self, index) -> "Dataclass":
156
+ def extract(obj):
157
+ if obj is None:
158
+ return None
159
+ if isinstance(obj, torch.Tensor):
160
+ return obj[index]
161
+ raise ValueError(f"Cannot slice {obj.__class__.__name__} object")
162
+
163
+ return self.apply(extract)
164
+
165
+
166
+ class Config:
167
+ def __init__(self, **kwargs):
168
+ self._apply_defaults()
169
+ self._set_attributes(**kwargs)
170
+ super().__init__()
171
+ self.__post_init__()
172
+
173
+ def _apply_defaults(self):
174
+ """
175
+ Initializes all annotated fields with defaults or sensible instances.
176
+ """
177
+ annotations = getattr(self, "__annotations__", {})
178
+ for key, type_hint in annotations.items():
179
+ # Skip if already set via class-level value or __init__ kwarg
180
+ if hasattr(self, key):
181
+ continue
182
+
183
+ # Case 1: class variable has a default (declared at class level)
184
+ if key in self.__class__.__dict__:
185
+ setattr(self, key, getattr(self.__class__, key))
186
+ continue
187
+
188
+ # Case 2: if the type is another Config subclass, default-construct it
189
+ if inspect.isclass(type_hint) and issubclass(type_hint, Config):
190
+ setattr(self, key, type_hint())
191
+ continue
192
+
193
+ # Case 3: fallback None (or empty dict for mappings)
194
+ if hasattr(type_hint, "__origin__") and type_hint.__origin__ in (
195
+ dict,
196
+ Dict,
197
+ MappingABC,
198
+ ):
199
+ setattr(self, key, {})
200
+ else:
201
+ setattr(self, key, None)
202
+
203
+ def _set_attributes(self, **kwargs):
204
+ subconfig_types = self._subconfig_types
205
+ for key, value in kwargs.items():
206
+ if key in subconfig_types:
207
+ if not isinstance(value, Mapping):
208
+ raise ValueError(
209
+ f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}"
210
+ )
211
+ setattr(self, key, subconfig_types[key](**value))
212
+ else:
213
+ setattr(self, key, value)
214
+
215
+ def keys(self) -> List[str]:
216
+ """Get all annotated keys including those from parent classes."""
217
+ all_keys = {}
218
+ # Walk through MRO in reverse to respect inheritance order
219
+ for cls in reversed(self.__class__.__mro__):
220
+ if cls is object:
221
+ continue
222
+ all_keys.update(getattr(cls, "__annotations__", {}))
223
+ return list(all_keys.keys())
224
+
225
+ def items(self) -> Iterable[Tuple[str, Any]]:
226
+ for key in self.keys():
227
+ yield (key, getattr(self, key))
228
+
229
+ @cached_classproperty
230
+ def _subconfig_types(cls) -> dict[str, Type]:
231
+ keys = {
232
+ key: value
233
+ for (key, value) in cls.__annotations__.items()
234
+ if inspect.isclass(value) and issubclass(value, Config)
235
+ }
236
+ for base in cls.__bases__:
237
+ if not issubclass(base, Config):
238
+ continue
239
+ keys = {**keys, **base._subconfig_types}
240
+ return keys
241
+
242
+ def __post_init__(self):
243
+ pass
244
+
245
+ def as_json(self) -> dict:
246
+ data = {}
247
+ for key, value in self.items():
248
+ if isinstance(value, Config):
249
+ data[key] = value.as_json()
250
+ elif (
251
+ isinstance(value, collections.abc.Sequence)
252
+ and len(value) > 0
253
+ and isinstance(value[0], Config)
254
+ ):
255
+ data[key] = [v.as_json() for v in value]
256
+ elif (
257
+ isinstance(value, collections.abc.Mapping)
258
+ and len(value) > 0
259
+ and isinstance(next(iter(value.values())), Config)
260
+ ):
261
+ data[key] = {k: v.as_json() for k, v in value.items()}
262
+ else:
263
+ data[key] = value
264
+
265
+ return data
266
+
267
+
268
+ class HFConfigMixin(transformers.PretrainedConfig):
269
+ """
270
+ Bridge between your Config system and HF PretrainedConfig.
271
+
272
+ Usage:
273
+ class SPEAR1Config(HFConfigMixin, Config):
274
+ model_type = "spear1"
275
+ processor_config: PaliGemmaProcessorConfig
276
+ ...
277
+ """
278
+
279
+ def __init__(self, **kwargs):
280
+ # Let HF's machinery initialize its own attributes / defaults first.
281
+ # PretrainedConfig.__init__ will set things like `model_type`,
282
+ # `_name_or_path`, `architectures`, and keep a `kwargs`->dict of extra items.
283
+ super().__init__(**kwargs)
284
+
285
+ # Now initialize your Config behavior: set defaults and construct nested configs.
286
+ # We call Config.__init__ explicitly because HFConfigMixin inherits from PretrainedConfig,
287
+ # and the user's concrete class will use multiple-inheritance with Config.
288
+ # (This approach mirrors the earlier MRO design: class Concrete(HFConfigMixin, Config).)
289
+ # We pass kwargs again so nested configs get overridden by user kwargs.
290
+ # Note: Config.__init__ itself calls super().__init__() — but because we are calling
291
+ # Config.__init__ directly (not via super()) the MRO won't re-call PretrainedConfig.__init__ here.
292
+ # (I.e., we are deliberately calling the concrete base initializer.)
293
+ Config.__init__(self, **kwargs) # type: ignore[name-defined]
294
+
295
+ def to_dict(self) -> Dict[str, Any]:
296
+ """
297
+ Merge HF PretrainedConfig serialization and Config.as_json().
298
+
299
+ Strategy:
300
+ 1. Take HF dict (super().to_dict()) so HF metadata/defaults are present.
301
+ 2. Take our nested config dict (Config.as_json(self)).
302
+ 3. Update the HF dict with our nested config dict so annotated fields
303
+ (nested configs, lists/dicts that should be recursively serialized)
304
+ take precedence.
305
+ """
306
+ # HF's representation (contains model_type, etc.). This is trusted HF serialization.
307
+ hf = super().to_dict()
308
+
309
+ # Our nested config representation (recursively serializes Config objects).
310
+ # Do not call self.to_dict() because that would recurse back here.
311
+ cfg_json = Config.as_json(self) # type: ignore[name-defined]
312
+
313
+ # Merge: prefer cfg_json values for keys present in our config (so nested configs
314
+ # are represented as dicts rather than raw objects or omitted).
315
+ merged: Dict[str, Any] = dict(hf)
316
+ merged.update(cfg_json)
317
+ return merged
318
+
319
+ @classmethod
320
+ def from_dict(
321
+ cls: Type["HFConfigMixin"],
322
+ config_dict: Dict[str, Any],
323
+ **kwargs,
324
+ ) -> "HFConfigMixin":
325
+ """
326
+ Construct by delegating to the class constructor — that will instantiate nested configs.
327
+ This is simple and consistent with PretrainedConfig.from_dict/from_pretrained behavior.
328
+ """
329
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
330
+
331
+ instance = cls(**config_dict)
332
+
333
+ if return_unused_kwargs:
334
+ # Return tuple of (instance, unused_kwargs) if requested
335
+ # Since we consume everything in __init__, unused is typically empty
336
+ return instance, {}
337
+ return instance
338
+
339
+
340
+ class Configurable:
341
+ def __init__(self, config: Config):
342
+ self._config = config
343
+
344
+ @property
345
+ def config(self) -> Config:
346
+ return self._config
347
+
348
+
349
+ class RotationFormat(StrEnum):
350
+ """Determines how rotations will be encoded in the loaded batch"""
351
+
352
+ EULER = "euler"
353
+ QUATERNION = "quaternion"
354
+ ROTMAT = "rotmat"
355
+
356
+
357
+ class ResizeMode(StrEnum):
358
+ """
359
+ Different modes for resizing images.
360
+ """
361
+
362
+ MATCH_WIDTH = "match_width"
363
+ MATCH_HEIGHT = "match_height"
364
+ MATCH_MAX = "match_max"
365
+ NAIVE = "naive"
366
+ SMART = "smart"
367
+ PAD = "pad"
368
+ CROP = "crop"
369
+
370
+
371
+ class Normalization(StrEnum):
372
+ """Action normalization types"""
373
+
374
+ NONE = "none"
375
+ BOUNDS = "bounds"
376
+ BOUNDS_Q99 = "bounds_q99"
377
+ MEAN = "mean"
378
+
379
+
380
+ def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor:
381
+ """
382
+ Expand the dimensions of `tensor` to `ndim` such that all new dimensions have size of 1
383
+ Args:
384
+ tensor: torch.Tensor of any shape
385
+ ndim: Number of output dimensions. Must be >= `tensor.ndim`
386
+ order: Sequence of size `tensor.ndim + 1`. Contains only values of 1 and a single value of -1,
387
+ indicating where the new `ndim - tensor.ndim` dimensions will be inserted
388
+ Returns:
389
+ torch.Tensor with dimensions `ndim`, a view of `tensor`
390
+
391
+ Ex:
392
+ expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4]
393
+ expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4]
394
+ expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1]
395
+ """
396
+ assert tensor.ndim <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}"
397
+ assert len(order) == tensor.ndim + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}"
398
+ order = list(order)
399
+ assert order.count(-1) == 1, "Order must have exactly one value of -1"
400
+ assert order.count(1) == len(order) - 1, "Order must have exactly len(order) - 1 values of 1"
401
+ if tensor.ndim == ndim:
402
+ return tensor
403
+ insert_index = order.index(-1)
404
+ view = list(tensor.shape[:insert_index]) + [1] * (ndim - tensor.ndim) + list(tensor.shape[insert_index:])
405
+ tensor = tensor.view(view)
406
+ return tensor
407
+
408
+
409
+ def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]:
410
+ """
411
+ Merges dict_1 with dict_2 recursively.
412
+ Handles clashing keys:
413
+ 1. If both values are dicts, merges them recursively
414
+ 2. If any value is not a dict, raises ValueError
415
+ """
416
+ merged = dict(dict_1)
417
+ for key, value in dict_2.items():
418
+ if key in merged:
419
+ if not type(merged[key]) is type(value) is dict:
420
+ raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}")
421
+ merged[key] = merge_dicts_recursive(merged[key], value)
422
+ else:
423
+ merged[key] = value
424
+ return merged
425
+
426
+
427
+ def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]:
428
+ """Converts a dict with keys of the form "key1.key2.key3" to a nested dict"""
429
+ unpacked_data: Dict[str, Any] = {}
430
+ for key, value in data.items():
431
+ if "." not in key:
432
+ unpacked_data = merge_dicts_recursive(unpacked_data, {key: value})
433
+ else:
434
+ (mainkey, subkey) = key.split(".", maxsplit=1)
435
+ nested_value = maybe_chained_keys_to_nested_dict({subkey: value})
436
+ unpacked_data = merge_dicts_recursive(unpacked_data, {mainkey: nested_value})
437
+ return unpacked_data
438
+
439
+
440
+ def annotation_is_union(type_value: Type) -> bool:
441
+ return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType
442
+
443
+
444
+ def annotation_is_optional(type_value: Type) -> bool:
445
+ if annotation_is_union(type_value):
446
+ union_args = set(type_value.__args__)
447
+ if len(union_args) == 2 and type(None) in union_args:
448
+ return True
449
+ return False
450
+
451
+
452
+ def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]:
453
+ if annotation_is_optional(type_value):
454
+ type_args = type_value.__args__
455
+ if type_args[1] is type(None):
456
+ return type_args[0]
457
+ return type_args[1]
458
+ return type_value
459
+
460
+
461
+ @dataclasses.dataclass
462
+ class RoboticsTarget(Dataclass):
463
+ control_tokens_ids: Optional[torch.Tensor]
464
+ text_tokens_ids: Optional[torch.Tensor]
465
+ translation: torch.Tensor
466
+ rotation: torch.Tensor
467
+ gripper: torch.Tensor
468
+ valid_mask: torch.Tensor
469
+
470
+
471
+ @dataclasses.dataclass
472
+ class RoboticsControlPlan(Dataclass):
473
+ translation_m: torch.Tensor
474
+ rotmat: torch.Tensor
475
+ gripper_prob: torch.Tensor
476
+ valid_mask: torch.Tensor
477
+
478
+ def __post_init__(self):
479
+ super().__post_init__()
480
+ assert self.translation_m.ndim == 3, self.translation_m.shape
481
+ assert self.rotmat.ndim == 3, self.rotmat.shape
482
+ assert self.gripper_prob.ndim == 3, self.gripper_prob.shape
483
+
484
+
485
+ @dataclasses.dataclass
486
+ class RoboticsInput(Dataclass):
487
+ images: Dict[str, torch.Tensor]
488
+ input_ids: torch.Tensor
489
+ attn_mask: torch.Tensor
490
+ ee_pose_translation: torch.Tensor
491
+ ee_pose_rotation: torch.Tensor
492
+ gripper: torch.Tensor
493
+ joints: torch.Tensor
494
+ control_tokens_ids: Optional[torch.Tensor]
495
+
496
+ @property
497
+ def inputs_embeds(self) -> Optional[torch.Tensor]:
498
+ return None
499
+
500
+ @property
501
+ def past_key_values(self) -> Optional[List[torch.Tensor]]:
502
+ return None
503
+
504
+ @cached_property
505
+ def multimodal_indices(self) -> torch.Tensor:
506
+ """
507
+ Returns a torch.Tensor containing only the indices of the batch examples which are multimodal.
508
+ Return shape is [B]
509
+ """
510
+ return torch.arange(self.input_ids.shape[0], dtype=torch.int64, device=self.input_ids.device)
511
+
512
+ @cached_property
513
+ def unimodal_indices(self) -> torch.Tensor:
514
+ """
515
+ Returns a torch.Tensor containing only the indices of the batch examples which are unimodal.
516
+ Return shape is [B]
517
+ """
518
+ return torch.tensor([], dtype=torch.int64, device=self.input_ids.device)
519
+
520
+
521
+ @dataclasses.dataclass
522
+ class FlowInput(Dataclass):
523
+ timestep: torch.Tensor
524
+ translation_t: torch.Tensor
525
+ rotation_t: torch.Tensor
526
+ gripper_t: torch.Tensor
527
+ translation_t0: torch.Tensor
528
+ rotation_t0: torch.Tensor
529
+ gripper_t0: torch.Tensor
530
+
531
+
532
+ @dataclasses.dataclass
533
+ class RoboticsFlowInput(RoboticsInput):
534
+ """Input to the entire Robotics VLM"""
535
+
536
+ flow_input: FlowInput
537
+
538
+
539
+ @dataclasses.dataclass
540
+ class DiffusionInput(Dataclass):
541
+ timestep: torch.Tensor
542
+ noised_translation: torch.Tensor
543
+ noised_rotation: torch.Tensor
544
+ noised_gripper: torch.Tensor
545
+
546
+
547
+ @dataclasses.dataclass
548
+ class LLMOutput(Dataclass):
549
+ """Fork of transformers.modeling_outputs.CausalLMOutputWithPast"""
550
+
551
+ input_ids: torch.Tensor
552
+ logits: Optional[torch.Tensor]
553
+ output_ids: Optional[torch.Tensor]
554
+ loss: Optional[torch.Tensor]
555
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
556
+ hidden_states: List[torch.Tensor]
557
+ text_indices: torch.Tensor
558
+ image_indices: torch.Tensor
559
+
560
+ @classmethod
561
+ def from_transformers(
562
+ cls,
563
+ input_ids: torch.Tensor,
564
+ llm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
565
+ text_indices: Optional[torch.Tensor],
566
+ image_indices: Optional[torch.Tensor],
567
+ ) -> "LLMOutput":
568
+ return LLMOutput(
569
+ input_ids=input_ids,
570
+ logits=llm_output.logits,
571
+ output_ids=None,
572
+ loss=llm_output.loss,
573
+ past_key_values=(
574
+ list(llm_output.past_key_values) if llm_output.past_key_values is not None else []
575
+ ),
576
+ hidden_states=(list(llm_output.hidden_states) if llm_output.hidden_states is not None else []),
577
+ text_indices=text_indices,
578
+ image_indices=image_indices,
579
+ )
580
+
581
+ def compress(self) -> "LLMOutput":
582
+ """
583
+ Compress the data contained in the class so it can be moved between CPU and GPU or concatenated
584
+ much faster:
585
+ - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat
586
+ - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat
587
+ - logits - huge last dimension; takes a lot of CPU time to move across devices or concat
588
+ """
589
+ replace: Dict[str, Any] = {
590
+ "hidden_states": [],
591
+ "past_key_values": [],
592
+ "loss": None,
593
+ "input_ids": None,
594
+ }
595
+ if self.logits is not None:
596
+ replace["logits"] = None
597
+ if self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]:
598
+ replace["output_ids"] = (
599
+ torch.index_select(self.logits, dim=1, index=self.text_indices)
600
+ .argmax(dim=-1)
601
+ .to(dtype=torch.int64)
602
+ )
603
+ return self.replace(**replace)
604
+
605
+
606
+ @dataclasses.dataclass
607
+ class RoboticsOutput(Dataclass):
608
+ translation: Optional[torch.Tensor]
609
+ rotation: Optional[torch.Tensor]
610
+ gripper: Optional[torch.Tensor]
611
+ token_logits: Optional[torch.Tensor]
612
+ token_ids: Optional[torch.Tensor]
613
+ llm_output: LLMOutput
614
+
615
+ def compress(self) -> "RoboticsOutput":
616
+ """
617
+ Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
618
+ Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
619
+ can reach millions or billions of values for large vocab_size
620
+ """
621
+ replace: Dict[str, Any] = {
622
+ "llm_output": self.llm_output.compress(),
623
+ "token_logits": None,
624
+ }
625
+ if self.token_logits is not None and self.token_ids is None:
626
+ replace["token_ids"] = torch.argmax(self.token_logits, dim=-1)
627
+ return self.replace(**replace)
628
+
629
+
630
+ @dataclasses.dataclass
631
+ class VLMOutput(Dataclass):
632
+ llm_output: LLMOutput
633
+ vit_tokens: Optional[torch.Tensor]
634
+ attn_mask: torch.Tensor
635
+
636
+ def compress(self) -> "VLMOutput":
637
+ """
638
+ Compress output and drop unnecessary components to speed up transfer GPU <-> CPU.
639
+ Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which
640
+ can reach millions or billions of values for large vocab_size
641
+ """
642
+ return self.replace(llm_output=self.llm_output.compress())
643
+
644
+
645
+ def is_quaternion(quaternion: torch.Tensor) -> bool:
646
+ return quaternion.shape[-1] == 4
647
+
648
+
649
+ def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
650
+ """
651
+ Flip quaternions so they cover only half of the space. If q_w is negative, flip the quaternion.
652
+ If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically,
653
+ this doesn't correspond to a single hemisphere of the unit sphere. Follows
654
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
655
+ """
656
+ assert is_quaternion(quaternion), quaternion.shape
657
+ with torch.no_grad():
658
+ is_zero = quaternion == 0
659
+ flip_condition = (
660
+ (quaternion[..., -1:] < 0)
661
+ | is_zero[..., -1:] & (quaternion[..., 0:1] < 0)
662
+ | is_zero[..., -1:] & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0)
663
+ | is_zero[..., -1:] & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0)
664
+ )
665
+ quaternion = torch.where(flip_condition, -quaternion, quaternion)
666
+ return quaternion
667
+
668
+
669
+ def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
670
+ return rotmat.shape[-2:] == torch.Size([3, 3])
671
+
672
+
673
+ def is_rotmat_9(rotmat: torch.Tensor) -> bool:
674
+ return rotmat.shape[-1] == 9
675
+
676
+
677
+ def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
678
+ """Convert any rotmat input to [..., 9] shape"""
679
+ if is_rotmat_9(rotmat):
680
+ return rotmat
681
+ if is_rotmat_3x3(rotmat):
682
+ return rotmat.reshape(*rotmat.shape[:-2], 9)
683
+ raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
684
+
685
+
686
+ def is_rotmat(rotmat: torch.Tensor) -> bool:
687
+ """
688
+ Checks if the tensor shape matches that of a rotmat. However, it's not guaranteed the data is a
689
+ valid rotmat. `is_orthonormal_rotmat` performs this additional check.
690
+ NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
691
+ `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
692
+ """
693
+ return is_rotmat_3x3(rotmat) or is_rotmat_9(rotmat)
694
+
695
+
696
+ def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
697
+ """Convert any rotmat input to [..., 3, 3] shape"""
698
+ if rotmat.shape[-1] == 9:
699
+ return rotmat.reshape(*rotmat.shape[:-1], 3, 3)
700
+ if rotmat.shape[-2:] == torch.Size([3, 3]):
701
+ return rotmat
702
+ raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
config.json ADDED
@@ -0,0 +1,167 @@
1
+ {
2
+ "_auto_class": null,
3
+ "_name_or_path": "/scratch/giuliano_albanese/spear-hf",
4
+ "architectures": [
5
+ "SPEAR1"
6
+ ],
7
+ "attribute_map": {},
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_spear.SPEAR1Config",
10
+ "AutoModel": "modeling_spear.SPEAR1"
11
+ },
12
+ "autoclass": "barrel.pipes.vlams.models.vlams.vlam.VLAM",
13
+ "base_config_key": "",
14
+ "control_module_config": {
15
+ "control_decoder_config": {
16
+ "block_config": {
17
+ "activation": "GELU",
18
+ "attn_implementation": "sdpa",
19
+ "dropout": 0.0,
20
+ "feature_size": 1024,
21
+ "head_dim": 256,
22
+ "hidden_size": 4096,
23
+ "norm": "RMSNorm",
24
+ "num_heads": 8,
25
+ "num_kv_heads": 1,
26
+ "position_embed_config": {
27
+ "base": 10000,
28
+ "cached": true,
29
+ "embedding_dim": 256,
30
+ "num_embeddings": 512
31
+ }
32
+ },
33
+ "num_blocks": 18
34
+ },
35
+ "noised_control_proj_config": {
36
+ "activation": "SiLU",
37
+ "layers": [
38
+ 8,
39
+ 2048,
40
+ 1024,
41
+ 1024
42
+ ],
43
+ "norm": null,
44
+ "time_embed": {
45
+ "activation": "SiLU",
46
+ "layers": [],
47
+ "learnable_features": false,
48
+ "max_period": 10000.0,
49
+ "norm": null,
50
+ "num_features": 1024
51
+ }
52
+ },
53
+ "robot_state_proj_config": {
54
+ "activation": "SiLU",
55
+ "fourier": false,
56
+ "layers": [
57
+ 8,
58
+ 1024
59
+ ],
60
+ "mode": "ee_pose_gripper"
61
+ },
62
+ "rotation_components": 4,
63
+ "token_size": 1024
64
+ },
65
+ "is_composition": false,
66
+ "model_type": "spear1",
67
+ "processor_config": {
68
+ "control_io_config": {
69
+ "future_control_offset_sec": 0.0,
70
+ "future_controls_sequence_length": 5,
71
+ "future_controls_sequence_stride_sec": 0.2,
72
+ "future_frames_sequence_length": 1,
73
+ "future_frames_sequence_stride_sec": null,
74
+ "past_frames_sequence_length": 1,
75
+ "past_frames_stride_sec": null,
76
+ "past_scalars_sequence_length": 1,
77
+ "past_scalars_stride_sec": null,
78
+ "sequence_frames": 1,
79
+ "sequence_frames_stride_sec": null
80
+ },
81
+ "control_stats_path": "barrel/pipes/vlams/types/control_stats.yaml",
82
+ "control_tokenizer_config": {},
83
+ "delta_controls": true,
84
+ "distribution_hyperparams": {
85
+ "alpha": 1.5,
86
+ "beta": 1.0
87
+ },
88
+ "eef_control_frame": false,
89
+ "image_resize": "smart",
90
+ "joints_norm": {
91
+ "high": [
92
+ 3.141592653589793,
93
+ 3.141592653589793,
94
+ 3.141592653589793,
95
+ 3.141592653589793,
96
+ 3.141592653589793,
97
+ 3.141592653589793,
98
+ 3.141592653589793
99
+ ],
100
+ "low": [
101
+ -3.141592653589793,
102
+ -3.141592653589793,
103
+ -3.141592653589793,
104
+ -3.141592653589793,
105
+ -3.141592653589793,
106
+ -3.141592653589793,
107
+ -3.141592653589793
108
+ ]
109
+ },
110
+ "num_inference_steps": 10,
111
+ "obs_rotation_norm": "none",
112
+ "obs_translation_norm": "bounds_q99",
113
+ "observation_stats_path": "barrel/pipes/vlams/types/observation_stats.yaml",
114
+ "r0_distribution": "uniform",
115
+ "rotation_format": "quaternion",
116
+ "rotation_norm": "none",
117
+ "sig_min": 0.001,
118
+ "timestep_distribution": "beta",
119
+ "translation_norm": {
120
+ "high": [
121
+ 0.04,
122
+ 0.04,
123
+ 0.04
124
+ ],
125
+ "low": [
126
+ -0.04,
127
+ -0.04,
128
+ -0.04
129
+ ]
130
+ }
131
+ },
132
+ "sub_configs": {},
133
+ "torch_dtype": "float32",
134
+ "transformers_version": "4.47.0",
135
+ "vlm_config": {
136
+ "attn_implementation": "flash_attention_2",
137
+ "depth_tokens": 1024,
138
+ "lm_head": false,
139
+ "mean_resizing": false,
140
+ "model_id": "google/paligemma-3b-mix-224",
141
+ "paligemma_3d_config": {
142
+ "depth_config": {
143
+ "hf_filename": "moge/moge-vit-large-patch-14-backbone.pt",
144
+ "hf_hub_repo": "nikonikolov/vlams"
145
+ },
146
+ "depth_layers": 4,
147
+ "depth_only": false,
148
+ "mask_prob": 0.0,
149
+ "projection": "features_add"
150
+ },
151
+ "processor_config": {
152
+ "image_sizes": {
153
+ "main": {
154
+ "height": 210,
155
+ "width": 280
156
+ },
157
+ "wrist": {
158
+ "height": 112,
159
+ "width": 112
160
+ }
161
+ },
162
+ "image_token": "<image>",
163
+ "max_language_tokens": 75
164
+ },
165
+ "train_only_depth_tokens": false
166
+ }
167
+ }
configuration_spear.py ADDED
@@ -0,0 +1,347 @@
1
+ import collections
2
+ import collections.abc
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from .common_spear import (
8
+ Config,
9
+ HFConfigMixin,
10
+ Normalization,
11
+ ResizeMode,
12
+ RotationFormat,
13
+ )
14
+
15
+
16
+ class InputSequencingConfig(Config):
17
+ """
18
+ past_frames_sequence_length: number of past images needed in a single robot state
19
+ past_scalars_sequence_length: number of past scalar state data, e.g. actions, poses, etc,
20
+ needed in a single robot state
21
+ past_frames_stride_sec: sampling rate, determines how far apart in time each point in the sequence
22
+ is. If None, ignored and takes the default data collection frequency from the dataset
23
+ past_scalars_stride_sec: similar to past_frames_stride_sec
24
+
25
+ sequence_frames: number of temporally-sequential points in a single example in the batch
26
+ sequence_frames_stride_sec: sampling rate
27
+
28
+ Understanding sequence_frames:
29
+ TODO: sequences are possibly useful in some rare cases, maybe sequence modeling problems,
30
+ but yet to be confirmed. Keeping for now, but could be removed if proved unnecessary
31
+
32
+ - past_scalars_sequence_length, past_frames_sequence_length, future_controls_sequence_length,
33
+ future_frames_sequence_length are hyperparameters refering to a SINGLE dataset example / 'state'.
34
+ It is assumed that `past_scalars_sequence_length` and `past_frames_sequence_length` are the min
35
+ number of observations that comprise a single 'state'
36
+ - sequence_frames is a hyperparameter referring to the entire learning process. It controls the size
37
+ of the sequence dimension in the batch. It's treated similarly to the batch dimension, with the
38
+ difference that points in the sequence dimensions are temporally aligned. Unlike `past_*`
39
+ attributes, in supervised learning a label is loaded for every point in the sequence dimension
40
+ and the loss usually computed over the entire sequence dimension.
41
+ """
42
+
43
+ past_scalars_sequence_length: int = 1
44
+ past_frames_sequence_length: int = 1
45
+ past_scalars_stride_sec: Optional[float] = None
46
+ past_frames_stride_sec: Optional[float] = None
47
+ sequence_frames: int = 1
48
+ sequence_frames_stride_sec: Optional[float] = None
49
+
50
+ def __post_init__(self):
51
+ super().__post_init__()
52
+ assert self.past_scalars_sequence_length >= 1, self.past_scalars_sequence_length
53
+ assert self.past_frames_sequence_length >= 1, self.past_frames_sequence_length
54
+ assert self.sequence_frames >= 1, self.sequence_frames
55
+ if self.past_frames_stride_sec is not None:
56
+ assert self.past_frames_stride_sec >= 0.0, self.past_frames_stride_sec
57
+ if self.past_scalars_stride_sec is not None:
58
+ assert self.past_scalars_stride_sec >= 0.0, self.past_scalars_stride_sec
59
+ if self.sequence_frames_stride_sec is not None:
60
+ assert self.sequence_frames_stride_sec >= 0.0, self.sequence_frames_stride_sec
61
+
62
+ def assert_same_past(self) -> None:
63
+ assert (
64
+ self.past_frames_stride_sec == self.past_scalars_stride_sec
65
+ ), f"{self.past_frames_stride_sec} != {self.past_scalars_stride_sec}"
66
+ assert (
67
+ self.past_frames_sequence_length == self.past_scalars_sequence_length
68
+ ), f"{self.past_frames_sequence_length} != {self.past_scalars_sequence_length}"
69
+
70
+
71
+ class OutputSequencingConfig(Config):
72
+ """
73
+ future_controls_sequence_length: number of control steps in the future the model predicts
74
+ future_frames_sequence_length: number of future frames the model predicts
75
+ (only relevant for neural networks that learn some sort of a world model)
76
+
77
+ future_controls_sequence_stride_sec / future_frames_sequence_stride_sec: sampling rate
78
+ that determines how far apart in time each point in the sequence is. If None,
79
+ ignored and takes the default data collection frequency from the dataset
80
+
81
+ future_control_offset_sec: time interval between the last observation and the first
82
+ point at which control is predicted. Serves as a 'causality hyperparameter', allowing
83
+ for predicting controls slightly further into the future in environments with dynamics
84
+ where the observed effects of an action appear slightly later
85
+ """
86
+
87
+ future_controls_sequence_length: int = 1
88
+ future_controls_sequence_stride_sec: Optional[float] = None
89
+ future_frames_sequence_length: int = 1
90
+ future_frames_sequence_stride_sec: Optional[float] = None
91
+ future_control_offset_sec: float = 0.0
92
+
93
+ def __post_init__(self):
94
+ super().__post_init__()
95
+ assert self.future_controls_sequence_length >= 1, self.future_controls_sequence_length
96
+ assert self.future_frames_sequence_length >= 1, self.future_frames_sequence_length
97
+ assert self.future_control_offset_sec >= 0.0, self.future_control_offset_sec
98
+ if self.future_controls_sequence_stride_sec is not None:
99
+ assert self.future_controls_sequence_stride_sec >= 0.0, self.future_controls_sequence_stride_sec
100
+ if self.future_frames_sequence_stride_sec is not None:
101
+ assert self.future_frames_sequence_stride_sec >= 0.0, self.future_frames_sequence_stride_sec
102
+
103
+
104
+ class ControlDataIOConfig(InputSequencingConfig, OutputSequencingConfig):
105
+ pass
106
+
107
+
108
+ class ControlTokenizerConfig(Config):
109
+ pass
110
+
111
+
112
+ class EmptyTokenizerConfig(ControlTokenizerConfig):
113
+ pass
114
+
115
+
116
+ class VLAMProcessorConfig(Config):
117
+ control_io_config: ControlDataIOConfig = ControlDataIOConfig()
118
+ obs_translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
119
+ obs_rotation_norm: Normalization = Normalization.NONE
120
+ translation_norm: Normalization | Dict[str, Tuple[float, float, float]] = Normalization.NONE
121
+ rotation_norm: Normalization = Normalization.NONE
122
+ joints_norm: Dict[str, Tuple[float, ...]] = {
123
+ "low": (-np.pi,) * 7,
124
+ "high": (np.pi,) * 7,
125
+ }
126
+ rotation_format: RotationFormat = RotationFormat.QUATERNION
127
+ eef_control_frame: bool = False
128
+ delta_controls: bool = False
129
+ image_resize: ResizeMode = ResizeMode.SMART
130
+ control_tokenizer_config: EmptyTokenizerConfig = EmptyTokenizerConfig()
131
+ control_stats_path: str = "barrel/pipes/vlams/types/control_stats.yaml"
132
+ observation_stats_path: str = "barrel/pipes/vlams/types/observation_stats.yaml"
133
+
134
+ def __post_init__(self):
135
+ super().__post_init__()
136
+ if isinstance(self.translation_norm, collections.abc.Mapping):
137
+ assert all((len(value) == 3 for value in self.translation_norm.values())), self.translation_norm
138
+ assert set(self.translation_norm.keys()) in (
139
+ {"low", "high"},
140
+ {"mean", "std"},
141
+ ), self.translation_norm
142
+ assert isinstance(self.joints_norm, collections.abc.Mapping), type(self.joints_norm)
143
+ assert all((len(value) == 7 for value in self.joints_norm.values())), self.joints_norm
144
+ assert set(self.joints_norm.keys()) in (
145
+ {"low", "high"},
146
+ {"mean", "std"},
147
+ ), self.joints_norm
148
+
149
+
150
+ class RegressionProcessorConfig(VLAMProcessorConfig):
151
+ pass
152
+
153
+
154
+ class PiZeroFlowProcessorConfig(RegressionProcessorConfig):
155
+ num_inference_steps: int
156
+ r0_distribution: str = "uniform"
157
+ timestep_distribution: str
158
+ distribution_hyperparams: Dict[str, Any] = {}
159
+ sig_min: float = 0.001
160
+
161
+ def __post_init__(self):
162
+ super().__post_init__()
163
+ assert self.r0_distribution in ["normal", "uniform"]
164
+
165
+
166
+ class VLMConfig(Config):
167
+ pass
168
+
169
+
170
+ class VLMProcessorConfig(Config):
171
+ pass
172
+
173
+
174
+ class ImageSizeConfig(Config):
175
+ width: int
176
+ height: int
177
+
178
+ def to_dict(self):
179
+ return {"width": self.width, "height": self.height}
180
+
181
+
182
+ class PaliGemmaProcessorConfig(Config):
183
+ image_token: str = "<image>"
184
+ image_sizes: Dict[str, ImageSizeConfig] = {"main": ImageSizeConfig(width=224, height=224)}
185
+ max_language_tokens: int = 75
186
+
187
+ def __post_init__(self):
188
+ super().__post_init__()
189
+ self.image_sizes = {
190
+ camera_name: (
191
+ ImageSizeConfig(**camera_image_size)
192
+ if not isinstance(camera_image_size, ImageSizeConfig)
193
+ else camera_image_size
194
+ )
195
+ for camera_name, camera_image_size in self.image_sizes.items()
196
+ }
197
+ for camera_name, camera_image_size in self.image_sizes.items():
198
+ assert camera_image_size.height % 14 == 0, f"{camera_name}: {camera_image_size}"
199
+ assert camera_image_size.width % 14 == 0, f"{camera_name}: {camera_image_size}"
200
+
201
+ @property
202
+ def num_image_tokens(self) -> Dict[str, int]:
203
+ return {
204
+ camera_name: camera_image_size.height // 14 * (camera_image_size.width // 14)
205
+ for (camera_name, camera_image_size) in self.image_sizes.items()
206
+ }
207
+
208
+ @property
209
+ def is_single_image_size(self) -> bool:
210
+ return (
211
+ len(self.image_sizes) == 1
212
+ or len(set(((image_size.height, image_size.width) for image_size in self.image_sizes.values())))
213
+ == 1
214
+ )
215
+
216
+ @property
217
+ def camera_names(self) -> List[str]:
218
+ return list(self.image_sizes.keys())
219
+
220
+ def to_dict(self) -> Dict[str, Any]:
221
+ base_dict = {
222
+ "image_token": self.image_token,
223
+ "max_language_tokens": self.max_language_tokens,
224
+ }
225
+ base_dict["image_sizes"] = {
226
+ camera_name: camera_image_size.to_dict()
227
+ for camera_name, camera_image_size in self.image_sizes.items()
228
+ }
229
+ return base_dict
230
+
231
+
232
+ class PaliGemmaVLMConfig(Config):
233
+ model_id: str = "google/paligemma-3b-mix-224"
234
+ attn_implementation: str = "flash_attention_2"
235
+ processor_config: PaliGemmaProcessorConfig
236
+ lm_head: bool = False
237
+ paligemma_3d_config: Dict[str, Any] = {}
238
+ depth_tokens: int = 0
239
+ train_only_depth_tokens: bool = False
240
+ mean_resizing: bool = False
241
+
242
+ def __post_init__(self):
243
+ super().__post_init__()
244
+ if self.train_only_depth_tokens:
245
+ assert self.depth_tokens > 0, self.depth_tokens
246
+ if self.paligemma_3d_config.get("mask_prob", 0.0) != 0.0:
247
+ raise NotImplementedError(
248
+ f"Masking is deprecated, but got mask_prob={self.paligemma_3d_config['mask_prob']}"
249
+ )
250
+
251
+ @property
252
+ def paligemma_3d_config_dict(self) -> Dict[str, Any]:
253
+ if len(self.paligemma_3d_config) == 0:
254
+ return {}
255
+ config = dict(self.paligemma_3d_config)
256
+ config["depth_config"] = dict(config["depth_config"])
257
+ config["depth_config"]["image_sizes"] = {
258
+ camera_name: camera_image_size.to_dict()
259
+ for camera_name, camera_image_size in self.processor_config.image_sizes.items()
260
+ }
261
+ return config
262
+
263
+ @property
264
+ def with_depth(self) -> bool:
265
+ return len(self.paligemma_3d_config) > 0
266
+
267
+
268
+ class FourierFeaturesConfig(Config):
269
+ num_features: int = 256
270
+ learnable_features: bool = False
271
+ max_period: float = 10000.0
272
+ layers: List[int] = [256, 512, 256]
273
+ activation: str = "SiLU"
274
+ norm: Optional[str] = None
275
+
276
+
277
+ class NoisedControlProjectorConfig(Config):
278
+ time_embed: FourierFeaturesConfig
279
+ layers: List[int] = []
280
+ activation: str = "SiLU"
281
+ norm: Optional[str] = None
282
+
283
+
284
+ class RobotStateProjectorConfig(Config):
285
+ layers: List[int] = []
286
+ mode: str = "none"
287
+ activation: str = "GELU"
288
+ fourier: bool = False
289
+
290
+ def __post_init__(self):
291
+ super().__post_init__()
292
+ assert self.mode in [
293
+ "ee_pose",
294
+ "ee_pose_gripper",
295
+ "ee_pose_joints",
296
+ "joints",
297
+ "all",
298
+ "none",
299
+ ], self.mode
300
+
301
+
302
+ class RotaryPositionalEncodingConfig(Config):
303
+ num_embeddings: int
304
+ embedding_dim: int
305
+ base: int = 10000
306
+ cached: bool = True
307
+
308
+
309
+ class PiZeroFlowMatchingDecoderBlockConfig(Config):
310
+ feature_size: int
311
+ head_dim: int = 128
312
+ num_heads: int = 32
313
+ num_kv_heads: int = 1
314
+ hidden_size: int
315
+ activation: str = "GELU"
316
+ norm: str = "RMSNorm"
317
+ dropout: float = 0.0
318
+ attn_implementation: str = "sdpa"
319
+ position_embed_config: RotaryPositionalEncodingConfig
320
+
321
+
322
+ class PiZeroFlowMatchingDecoderConfig(Config):
323
+ num_blocks: int
324
+ block_config: PiZeroFlowMatchingDecoderBlockConfig
325
+
326
+
327
+ class PiZeroFlowMatchingModuleConfig(Config):
328
+ token_size: int = 1024
329
+ noised_control_proj_config: NoisedControlProjectorConfig
330
+ robot_state_proj_config: RobotStateProjectorConfig
331
+ control_decoder_config: PiZeroFlowMatchingDecoderConfig
332
+ rotation_components: int = 3
333
+
334
+
335
+ class SPEAR1Config(HFConfigMixin, Config):
336
+ model_type: str = "spear1"
337
+ processor_config: PiZeroFlowProcessorConfig
338
+ vlm_config: PaliGemmaVLMConfig
339
+ control_module_config: PiZeroFlowMatchingModuleConfig
340
+
341
+ def __init__(self, **kwargs):
342
+ if "auto_map" not in kwargs:
343
+ kwargs["auto_map"] = {
344
+ "AutoConfig": "configuration_spear.SPEAR1Config",
345
+ "AutoModel": "modeling_spear.SPEAR1",
346
+ }
347
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "transformers_version": "4.47.0"
3
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0992d3b5ffdc8b896812ed19801bc9ebda65708237681ced90e642c90e0a0d2
3
+ size 4962008480
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db48d29ee9567705a81718181eac6c644d2d996f1e91c497e8c891702050c36e
3
+ size 4999821656
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c7e1d6dae46553546f53a3c9fa76a8a2d2e07664a575ce38962ae2930eb7562
3
+ size 4245980072
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_spear.py ADDED
The diff for this file is too large to render. See raw diff
 
processing_spear.py ADDED
@@ -0,0 +1,1897 @@
1
+ import collections
2
+ import collections.abc
3
+ import re
4
+ import warnings
5
+ from abc import abstractmethod
6
+ from functools import cached_property
7
+ from typing import Dict, List, Optional, Sequence, Tuple, TypeVar
8
+
9
+ import numpy as np
10
+ import PIL.Image
11
+ import roma
12
+ import torch
13
+ import torchvision.transforms.v2
14
+ import transformers
15
+ import yaml
16
+
17
+ from .common_spear import (
18
+ Configurable,
19
+ FlowInput,
20
+ Normalization,
21
+ ResizeMode,
22
+ RoboticsControlPlan,
23
+ RoboticsFlowInput,
24
+ RoboticsInput,
25
+ RoboticsOutput,
26
+ RoboticsTarget,
27
+ RotationFormat,
28
+ expand_dims,
29
+ is_quaternion,
30
+ is_rotmat,
31
+ is_rotmat_3x3,
32
+ is_rotmat_9,
33
+ quaternion_half_cover,
34
+ rotmat_as_3x3,
35
+ rotmat_as_9,
36
+ )
37
+ from .configuration_spear import (
38
+ ControlDataIOConfig,
39
+ ImageSizeConfig,
40
+ PaliGemmaProcessorConfig,
41
+ )
42
+
43
+
44
+ class VLMProcessor(Configurable):
45
+ @abstractmethod
46
+ def preprocess_inputs(
47
+ self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
48
+ ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: ...
49
+
50
+ @property
51
+ @abstractmethod
52
+ def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
53
+ pass
54
+
55
+ @property
56
+ @abstractmethod
57
+ def image_sizes(self) -> Dict[str, ImageSizeConfig]:
58
+ pass
59
+
60
+
61
+ class EmptyTokenizer(Configurable):
62
+ """
63
+ A placeholder control tokenizer: it produces no control tokens and simply returns an empty
64
+ string when called.
65
+ """
66
+
67
+ def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
68
+ super().__init__(config)
69
+ self.tokenizer = tokenizer
70
+
71
+ def __call__(self, *_) -> str:
72
+ return ""
73
+
74
+
75
+ def np_unique(
76
+ data: np.ndarray,
77
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
78
+ """
79
+ Compute unique elements in data and corresponding indices.
80
+
81
+ np.unique returns the values in sorted order even when the input is not sorted, so its indices do not
82
+ follow the order of first appearance. This helper returns the unique values ordered by their first occurrence in `data`.
83
+
84
+ """
85
+ (_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True)
86
+ (_, indices_of_first_occurence, inverse_indices, counts) = np.unique(
87
+ indices[inverse], return_index=True, return_inverse=True, return_counts=True
88
+ )
89
+ unique_ids = data[indices_of_first_occurence]
90
+ return unique_ids, indices_of_first_occurence, inverse_indices, counts
91
+
92
+
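# Editor's note: illustrative sketch, not part of the committed file above. It shows that
# np_unique keeps first-appearance order, unlike plain np.unique, which sorts the values.
example_data = np.array(["droid", "bridge", "droid", "roboset", "bridge"])
unique_ids, first_idx, inverse, counts = np_unique(example_data)
# unique_ids -> ['droid', 'bridge', 'roboset'] (first-appearance order), counts -> [2, 2, 1]
# unique_ids[inverse] reconstructs example_data, and example_data[first_idx] matches unique_ids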
93
+ def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
94
+ """
95
+ Args:
96
+ angles: Euler angles in radians in the format 'xyz', shape [..., 3]
97
+ Returns:
98
+ torch.Tensor of shape [..., 3, 3] containing rotation matrices
99
+ """
100
+ return roma.euler_to_rotmat(convention="xyz", angles=angles, degrees=False)
101
+
102
+
103
+ def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
104
+ """
105
+ Args:
106
+ angles: Euler angles in radians in the format 'xyz', shape [..., 3]
107
+ Returns:
108
+ torch.Tensor of shape [..., 4] containing unit quaternions
109
+ """
110
+ return roma.euler_to_unitquat(convention="xyz", angles=angles, degrees=False, normalize=True)
111
+
112
+
113
+ def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
114
+ """
115
+ Args:
116
+ quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4]
117
+ eps: Small constant to prevent division by zero
118
+ Returns:
119
+ torch.Tensor of shape [..., 4] of unit quaternions
120
+ """
121
+ return quaternion / (quaternion.norm(dim=-1, keepdim=True).detach() + eps)
122
+
123
+
124
+ def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
125
+ """
126
+ Args:
127
+ quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
128
+ Returns:
129
+ torch.Tensor of shape [..., 3] containing Euler angles in radians (convention 'xyz')
130
+ """
131
+ unit_quat = normalize_quaternion(quaternion)
132
+ euler = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False)
133
+ return euler
134
+
135
+
136
+ def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
137
+ """
138
+ Args:
139
+ quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
140
+ Returns:
141
+ torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
142
+ """
143
+ unit_quat = normalize_quaternion(quaternion)
144
+ rotmat = roma.unitquat_to_rotmat(unit_quat)
145
+ return rotmat
146
+
147
+
148
+ def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
149
+ """
150
+ Args:
151
+ rotmat: Batch of rotation matrices, shape [..., 3, 3]
152
+ Returns:
153
+ Batch of unit quaternions, shape [..., 4]
154
+ """
155
+ rotmat = rotmat_as_3x3(rotmat)
156
+ return roma.rotmat_to_unitquat(rotmat)
157
+
158
+
159
+ def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
160
+ """
161
+ Args:
162
+ rotmat: Batch of rotation matrices, shape [..., 3, 3]
163
+ Returns:
164
+ Batch of Euler angles in radians, shape [..., 3]
165
+ """
166
+ rotmat = rotmat_as_3x3(rotmat)
167
+ return roma.rotmat_to_euler(convention="xyz", rotmat=rotmat, as_tuple=False, degrees=False)
168
+
169
+
170
+ def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
171
+ """
172
+ Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
173
+ - Let SVD(M) = U \Sigma V^T
174
+ - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
175
+ - det(UV^T) ensures that det(SVD+(M)) = 1
176
+ - The return value is the rotation matrix (orthonormal) closest to M in the least-squares sense
177
+
178
+ Args:
179
+ x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
180
+ Returns:
181
+ torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
182
+ """
183
+ with warnings.catch_warnings():
184
+ warnings.filterwarnings(
185
+ "ignore",
186
+ message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
187
+ )
188
+ with torch.autocast(device_type=x.device.type, dtype=torch.float32):
189
+ matrices = x.view(-1, 3, 3)
190
+ matrices = matrices.to(dtype=torch.float32)
191
+ (u, s, v) = torch.svd(matrices)
192
+ vt = torch.transpose(v, 1, 2)
193
+ det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
194
+ diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
195
+ result = torch.matmul(u, diag_vt)
196
+ result = result.view(*x.shape)
197
+ result = result.to(dtype=x.dtype)
198
+ return result
199
+
200
+
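# Editor's note: minimal sketch (not part of the committed file) of how
# symmetric_orthogonalization projects a noisy matrix back onto SO(3).
noisy = torch.eye(3) + 0.05 * torch.randn(3, 3)      # not exactly orthonormal
projected = symmetric_orthogonalization(noisy)       # same shape, now in SO(3)
# projected @ projected.T is the identity up to float32 precision, and det(projected) is ~ +1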
201
+ def is_rotmat_orthonormal(
202
+ rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none"
203
+ ) -> torch.Tensor | bool:
204
+ """
205
+ Check if a rotation matrix is orthonormal or not.
206
+ Args:
207
+ rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
208
+ epsilon: Tolerance for numerical comparisons. Larger values are more permissive. Generally,
209
+ anything smaller than 1e-6 may incorrectly flag some orthonormal matrices as non-orthonormal
210
+ reduction:
211
+ 'none' - returns torch.Tensor of bools with the same batch shape
212
+ 'all' - returns a bool, True if ALL matrices in the batch are orthonormal
213
+ Returns:
214
+ torch.Tensor with the same batch shape or bool
215
+ """
216
+ assert is_rotmat(rotmat)
217
+ rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
218
+ is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
219
+ if reduction == "none":
220
+ return is_orthonormal
221
+ if reduction == "all":
222
+ return bool(torch.all(is_orthonormal).item())
223
+ raise ValueError(f"Unknown reduction mode {reduction}")
224
+
225
+
226
+ def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
227
+ """
228
+ Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3,
229
+ also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles
230
+ when accidentally `rotmat.shape[-2:] == [3, 3]`
231
+ """
232
+ return (
233
+ is_rotmat_9(rotmat)
234
+ or is_rotmat_3x3(rotmat)
235
+ and is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all")
236
+ )
237
+
238
+
239
+ def is_euler(euler: torch.Tensor) -> bool:
240
+ return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler)
241
+
242
+
243
+ def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor:
244
+ if is_quaternion(rotation):
245
+ return normalize_quaternion(rotation)
246
+ if is_euler(rotation):
247
+ return rotation
248
+ if is_rotmat(rotation):
249
+ is_flat = is_rotmat_9(rotation)
250
+ rotation = rotmat_as_3x3(rotation) if is_flat else rotation
251
+ rotmat = roma.special_gramschmidt(rotation)
252
+ rotmat = rotmat_as_9(rotmat) if is_flat else rotmat
253
+ return rotmat
254
+ raise ValueError(f"Unknown rotation format: {rotation.shape}")
255
+
256
+
257
+ def rotation_format_from_tensor(rotation) -> RotationFormat:
258
+ if is_quaternion(rotation):
259
+ return RotationFormat.QUATERNION
260
+ if is_orthonormal_rotmat(rotation):
261
+ return RotationFormat.ROTMAT
262
+ if is_euler(rotation):
263
+ return RotationFormat.EULER
264
+ raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format")
265
+
266
+
267
+ def is_unit_quaternion(
268
+ quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none"
269
+ ) -> torch.Tensor | bool:
270
+ """
271
+ Check whether a quaternion is normalized or not.
272
+ Args:
273
+ quaternion: torch.Tensor of shape [..., 4]
274
+ epsilon: Tolerance for numerical comparisons
275
+ reduction:
276
+ 'none' - returns torch.Tensor of bools with the same batch shape
277
+ 'all' - returns a bool, True if ALL quaternions in the batch are normalized
278
+ Returns:
279
+ torch.Tensor with the same batch shape or bool
280
+ """
281
+ assert is_quaternion(quaternion)
282
+ is_norm = torch.isclose(
283
+ quaternion.norm(dim=-1, keepdim=True),
284
+ torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
285
+ atol=epsilon,
286
+ )
287
+ if reduction == "none":
288
+ return is_norm
289
+ if reduction == "all":
290
+ return bool(torch.all(is_norm).item())
291
+ raise ValueError(f"Unknown reduction mode {reduction}")
292
+
293
+
294
+ def convert_rotation(
295
+ rotation: torch.Tensor | np.ndarray,
296
+ output_format: RotationFormat,
297
+ autonorm: bool = True,
298
+ half_cover: bool = True,
299
+ ) -> torch.Tensor | np.ndarray:
300
+ is_np = isinstance(rotation, np.ndarray)
301
+ if is_np:
302
+ rotation = torch.from_numpy(rotation)
303
+ if is_quaternion(rotation):
304
+ if autonorm and not is_unit_quaternion(rotation, reduction="all"):
305
+ rotation = normalize_quaternion(rotation)
306
+ if output_format == RotationFormat.QUATERNION:
307
+ output = rotation
308
+ elif output_format == RotationFormat.ROTMAT:
309
+ output = rotmat_as_9(quaternion_to_rotmat(rotation))
310
+ elif output_format == RotationFormat.EULER:
311
+ output = quaternion_to_euler(rotation)
312
+ else:
313
+ raise NotImplementedError(f"Unsupported rotation format: {output_format}")
314
+ elif is_orthonormal_rotmat(rotation):
315
+ if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"):
316
+ rotation = symmetric_orthogonalization(rotation)
317
+ if output_format == RotationFormat.QUATERNION:
318
+ output = rotmat_to_unit_quaternion(rotation)
319
+ elif output_format == RotationFormat.ROTMAT:
320
+ output = rotmat_as_9(rotation)
321
+ elif output_format == RotationFormat.EULER:
322
+ output = rotmat_to_euler(rotation)
323
+ else:
324
+ raise NotImplementedError(f"Unsupported rotation format: {output_format}")
325
+ elif is_euler(rotation):
326
+ if output_format == RotationFormat.QUATERNION:
327
+ output = euler_to_unit_quaternion(rotation)
328
+ elif output_format == RotationFormat.ROTMAT:
329
+ output = rotmat_as_9(euler_to_rotmat(rotation))
330
+ elif output_format == RotationFormat.EULER:
331
+ output = rotation
332
+ else:
333
+ raise NotImplementedError(f"Unsupported rotation format: {output_format}")
334
+ else:
335
+ raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}")
336
+ if output_format == RotationFormat.QUATERNION and half_cover:
337
+ output = quaternion_half_cover(output)
338
+ if is_np:
339
+ output = output.numpy()
340
+ return output
341
+
342
+
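# Editor's note: illustrative round trip through convert_rotation, not part of the committed
# file. The input format is inferred from the tensor shape, the output format is explicit.
euler_angles = torch.tensor([0.1, -0.2, 0.3])
flat_rotmat = convert_rotation(euler_angles, RotationFormat.ROTMAT)      # shape [9]
quat = convert_rotation(flat_rotmat, RotationFormat.QUATERNION)          # shape [4], half-covered
# convert_rotation(quat, RotationFormat.EULER) recovers euler_angles up to numerical precision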
343
+ def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
344
+ """
345
+ Transform a sequence of rotations, each encoded w.r.t. the PREVIOUS frame in the sequence, into
346
+ rotations encoded w.r.t. the 0-th (base) frame preceding the sequence
347
+
348
+ Ex:
349
+ `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame,
350
+ implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame
351
+ Output: R_01, R_02, R_03, R_04
352
+
353
+ Args:
354
+ rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
355
+ either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
356
+ Returns:
357
+ torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
358
+ (R_01, R_02, R_03, R_04, ...)
359
+
360
+ TODO: Vectorize this to avoid the Python for loop
361
+ """
362
+ assert rotation_sequence.ndim >= 3, rotation_sequence.shape
363
+ rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
364
+ rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
365
+ batch_dims = np.arange(rotation_sequence.ndim - 2)
366
+ delta_rotations = torch.cat(
367
+ [rotation_sequence[..., :1, :]]
368
+ + [
369
+ roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2))
370
+ for i in range(2, rotation_sequence.shape[-2] + 1)
371
+ ],
372
+ dim=-2,
373
+ )
374
+ delta_rotations = convert_rotation(delta_rotations, rotation_format)
375
+ return delta_rotations
376
+
377
+
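# Editor's note: small worked example, not part of the committed file; `math` is assumed
# importable. Three successive 30-degree yaw steps, each w.r.t. the previous frame, become
# cumulative yaws of 30, 60 and 90 degrees w.r.t. the base frame.
import math
step = euler_to_unit_quaternion(torch.tensor([0.0, 0.0, math.pi / 6]))
cumulative = delta_to_relative_rotations(step.repeat(1, 3, 1))           # [1, S=3, 4]
# quaternion_to_euler(cumulative)[..., 2] is approximately [pi/6, pi/3, pi/2]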
378
+ def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
379
+ """Make sure image is of type np.ndarray and HWC format"""
380
+ if isinstance(image, PIL.Image.Image):
381
+ image = np.asarray(image)
382
+ assert isinstance(image, np.ndarray), type(image)
383
+ assert image.ndim in [2, 3], image.shape
384
+ if image.ndim == 3:
385
+ assert image.shape[-1] <= 4, image.shape
386
+ return image
387
+
388
+
389
+ def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
390
+ if isinstance(image, np.ndarray):
391
+ (height, width) = image.shape[:2]
392
+ else:
393
+ (width, height) = image.size
394
+ return height, width
395
+
396
+
397
+ def pad_image(
398
+ image: PIL.Image.Image | np.ndarray,
399
+ target_size: dict[str, int],
400
+ pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
401
+ ) -> PIL.Image.Image | np.ndarray:
402
+ """Pad image adding a symmetric border around the height/width."""
403
+ assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
404
+ (height, width) = hw_from_image(image)
405
+ (target_width, target_height) = (target_size["width"], target_size["height"])
406
+ if width == target_width and height == target_height:
407
+ return image
408
+ assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
409
+ assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
410
+ (horizontal_pad, vertical_pad) = (
411
+ int((target_width - width) / 2),
412
+ int((target_height - height) / 2),
413
+ )
414
+ if isinstance(image, np.ndarray):
415
+ padding = ((vertical_pad, vertical_pad), (horizontal_pad, horizontal_pad)) + ((0, 0),) * (
416
+ image.ndim - 2
417
+ )
418
+ image = np.pad(image, padding, mode="constant", constant_values=pad_value)
419
+ else:
420
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
421
+ image = torchvision.transforms.v2.functional.pad(
422
+ image, padding=padding, fill=pad_value, padding_mode="constant"
423
+ )
424
+ return image
425
+
426
+
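# Editor's note: illustrative sketch, not part of the committed file. pad_image adds a
# symmetric border: a 4x6 array padded to 8x8 gets 2 rows top/bottom and 1 column left/right.
small = np.ones((4, 6), dtype=np.uint8)
padded = pad_image(small, target_size={"width": 8, "height": 8}, pad_value=0)
# padded.shape == (8, 8)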
427
+ def pad_image_to_ratio(
428
+ image: PIL.Image.Image | np.ndarray,
429
+ target_wh_ratio: float,
430
+ pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
431
+ ) -> PIL.Image.Image | np.ndarray:
432
+ """Pad image to a target aspect ratio."""
433
+ (height, width) = hw_from_image(image)
434
+ wh_ratio = width / height
435
+ if target_wh_ratio >= wh_ratio:
436
+ pad_size = {"width": round(height * target_wh_ratio), "height": height}
437
+ else:
438
+ pad_size = {"width": width, "height": round(width / target_wh_ratio)}
439
+ image = pad_image(image, target_size=pad_size, pad_value=pad_value)
440
+ return image
441
+
442
+
443
+ def crop_image(
444
+ image: np.ndarray | PIL.Image.Image,
445
+ start_height: int,
446
+ start_width: int,
447
+ target_height: int,
448
+ target_width: int,
449
+ ) -> np.ndarray | PIL.Image.Image:
450
+ np_image = assert_np_hwc_or_hw_image(image)
451
+ (height, width) = hw_from_image(image)
452
+ assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
453
+ assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
454
+ (start_height, start_width) = (round(start_height), round(start_width))
455
+ (target_height, target_width) = (round(target_height), round(target_width))
456
+ np_image = np_image[
457
+ start_height : start_height + target_height,
458
+ start_width : start_width + target_width,
459
+ ...,
460
+ ]
461
+ image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
462
+ return image
463
+
464
+
465
+ def crop_image_center(
466
+ image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
467
+ ) -> np.ndarray | PIL.Image.Image:
468
+ np_image = assert_np_hwc_or_hw_image(image)
469
+ (height, width) = np_image.shape[:2]
470
+ (target_height, target_width) = (target_size["height"], target_size["width"])
471
+ assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
472
+ assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
473
+ top = (height - target_height) // 2
474
+ left = (width - target_width) // 2
475
+ np_image = crop_image(np_image, top, left, target_height, target_width)
476
+ image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
477
+ return image
478
+
479
+
480
+ def crop_image_to_ratio(
481
+ image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
482
+ ) -> PIL.Image.Image | np.ndarray:
483
+ """Center-crop image to a target aspect ratio."""
484
+ (height, width) = hw_from_image(image)
485
+ wh_ratio = width / height
486
+ if target_wh_ratio >= wh_ratio:
487
+ crop_size = {"width": width, "height": round(width / target_wh_ratio)}
488
+ else:
489
+ crop_size = {"width": round(height * target_wh_ratio), "height": height}
490
+ image = crop_image_center(image, target_size=crop_size)
491
+ return image
492
+
493
+
494
+ def crop_and_pad_image_to_ratio(
495
+ image: PIL.Image.Image | np.ndarray,
496
+ target_wh_ratio: float,
497
+ mode: ResizeMode | str,
498
+ pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
499
+ ) -> PIL.Image.Image | np.ndarray:
500
+ """
501
+ Crop and pad an image to a target aspect ratio depending on the mode.
502
+ It's expected that the source image and target size have different aspect ratios.
503
+
504
+ Args:
505
+ image: The image to crop and pad.
506
+ target_wh_ratio: The target width/height aspect ratio to crop and pad the image to.
507
+ mode: The mode to use for cropping and padding.
508
+ """
509
+ (height, width) = hw_from_image(image)
510
+ wh_ratio = width / height
511
+ if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
512
+ return image
513
+ if mode == ResizeMode.SMART:
514
+ aspect_ratio = max(width, height) / min(width, height)
515
+ target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
516
+ if aspect_ratio == 1:
517
+ if target_ratio >= 4 / 3 - 0.01:
518
+ crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
519
+ image = crop_image_to_ratio(image, crop_wh_ratio)
520
+ else:
521
+ pass
522
+ elif aspect_ratio <= 4 / 3 + 0.01:
523
+ if (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):
524
+ image = crop_image_to_ratio(image, 1.0)
525
+ elif (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):
526
+ image = crop_image_to_ratio(image, 1.0)
527
+ elif target_ratio >= 4 / 3 + 0.01:
528
+ pass
529
+ else:
530
+ crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
531
+ image = crop_image_to_ratio(image, crop_wh_ratio)
532
+ image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
533
+ elif mode == ResizeMode.PAD:
534
+ image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
535
+ elif mode == ResizeMode.CROP:
536
+ image = crop_image_to_ratio(image, target_wh_ratio)
537
+ else:
538
+ raise ValueError(f"Mode {mode} not supported")
539
+ return image
540
+
541
+
542
+ def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
543
+ if isinstance(image, PIL.Image.Image):
544
+ return image.mode in [
545
+ "1",
546
+ "L",
547
+ "LA",
548
+ "La",
549
+ "P",
550
+ "PA",
551
+ "F",
552
+ "I",
553
+ "I;16",
554
+ "I;16L",
555
+ "I;16B",
556
+ "I;16N",
557
+ ]
558
+ if isinstance(image, np.ndarray):
559
+ return image.ndim == 2 or image.ndim == 3 and image.shape[2] == 1
560
+ raise ValueError(f"Unsupported image type: {type(image)}")
561
+
562
+
563
+ def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
564
+ image = np.asarray(image)
565
+ return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1
566
+
567
+
568
+ def resize_image(
569
+ image: PIL.Image.Image | np.ndarray,
570
+ target_size: dict[str, int],
571
+ mode: ResizeMode | str,
572
+ resample: PIL.Image.Resampling | str = "auto",
573
+ pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
574
+ ) -> PIL.Image.Image | np.ndarray:
575
+ (target_width, target_height) = (target_size["width"], target_size["height"])
576
+ (height, width) = hw_from_image(image)
577
+ if height == target_height and width == target_width:
578
+ return image
579
+ if resample == "auto":
580
+ if is_single_channel_image(image):
581
+ resample = PIL.Image.Resampling.BILINEAR
582
+ else:
583
+ resample = PIL.Image.Resampling.LANCZOS
584
+ else:
585
+ assert isinstance(resample, PIL.Image.Resampling), resample
586
+ if is_single_channel_image(image) and resample not in [
587
+ PIL.Image.Resampling.BILINEAR,
588
+ PIL.Image.Resampling.BICUBIC,
589
+ ]:
590
+ raise ValueError(
591
+ f"Single channel images must be resized with bilinear or bicubic, but got {resample}"
592
+ )
593
+ if is_bin_mask := is_binary_mask(image):
594
+ image = np.asarray(image).astype(np.uint8) * 255
595
+ if mode == ResizeMode.SMART:
596
+ image = crop_and_pad_image_to_ratio(
597
+ image,
598
+ target_wh_ratio=target_width / target_height,
599
+ mode=mode,
600
+ pad_value=pad_value,
601
+ )
602
+ pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
603
+ if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
604
+ pil_image = pil_image.resize((target_width, target_height), resample=resample)
605
+ else:
606
+ raise NotImplementedError(f"Mode {mode} not supported")
607
+ image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
608
+ if is_bin_mask:
609
+ image = image.astype(np.uint8) > 127
610
+ return image
611
+
612
+
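# Editor's note: minimal usage sketch, not part of the committed file; the 640x480 source and
# 224x224 target are arbitrary. SMART mode first crops/pads toward the target aspect ratio,
# then resizes to the exact size.
source = PIL.Image.new("RGB", (640, 480), color=(127, 127, 127))
resized = resize_image(source, target_size={"width": 224, "height": 224}, mode=ResizeMode.SMART)
# resized.size == (224, 224)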
613
+ def is_global_norm(
614
+ norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
615
+ ) -> bool:
616
+ """Return true if norm is NONE or global for all datasets"""
617
+ return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping)
618
+
619
+
620
+ def is_mean_norm(
621
+ norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
622
+ ) -> bool:
623
+ """Return true if norm is based on mean and std"""
624
+ return (
625
+ norm == Normalization.MEAN
626
+ or isinstance(norm, collections.abc.Mapping)
627
+ and set(norm.keys()) == {"mean", "std"}
628
+ )
629
+
630
+
631
+ def _broadcast_shapes(
632
+ value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
633
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
634
+ """
635
+ Broadcast shapes for normalization:
636
+ Args:
637
+ value: torch.Tensor of shape [..., num_components]. The entire shape might be:
638
+ - [num_components]: `value` has no batch dimension
639
+ - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
640
+ contained in `low` and `high`
641
+ - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
642
+ contained in `low` and `high`
643
+ - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
644
+ must be for a single dataset, i.e. `num_datasets = 1`
645
+
646
+ low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
647
+ contains normalization bounds for a single dataset
648
+ high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
649
+ contains normalization bounds for a single dataset
650
+ Returns:
651
+ Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
652
+ """
653
+ assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2"
654
+ assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}"
655
+ if value.ndim == low.ndim == high.ndim:
656
+ return low, high
657
+ if value.ndim < low.ndim:
658
+ assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}"
659
+ assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}"
660
+ (low, high) = (low.view(-1), high.view(-1))
661
+ return low, high
662
+ if low.shape[0] == high.shape[0] == 1:
663
+ low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
664
+ high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
665
+ else:
666
+ assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}"
667
+ low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
668
+ high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
669
+ return low, high
670
+
671
+
672
+ def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
673
+ (mean, std) = _broadcast_shapes(value, mean, std)
674
+ (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
675
+ return value * (std + 1e-08) + mean
676
+
677
+
678
+ def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
679
+ (low, high) = _broadcast_shapes(value, low, high)
680
+ (low, high) = (low.to(device=value.device), high.to(device=value.device))
681
+ return 0.5 * (value + 1) * (high - low) + low
682
+
683
+
684
+ def normalize_gripper_by_bounds(
685
+ value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
686
+ ) -> torch.Tensor:
687
+ """
688
+ If binary, normalize to [0, 1], otherwise normalize to [-1, 1]
689
+ """
690
+ (low, high) = _broadcast_shapes(value, low, high)
691
+ (low, high) = (low.to(device=value.device), high.to(device=value.device))
692
+ if binary:
693
+ return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0)
694
+ return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
695
+
696
+
697
+ def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
698
+ (mean, std) = _broadcast_shapes(value, mean, std)
699
+ (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
700
+ return (value - mean) / (std + 1e-08)
701
+
702
+
703
+ def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
704
+ (low, high) = _broadcast_shapes(value, low, high)
705
+ (low, high) = (low.to(device=value.device), high.to(device=value.device))
706
+ return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)
707
+
708
+
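# Editor's note: illustrative round trip, not part of the committed file. Bounds have shape
# [num_datasets, num_components]; normalize_by_bounds maps into [-1, 1] and
# unnormalize_by_bounds inverts it (away from the clamping limits).
value = torch.tensor([[0.3, 0.0, 0.1]])
low = torch.tensor([[0.0, -0.5, -0.2]])
high = torch.tensor([[0.6, 0.5, 0.4]])
normed = normalize_by_bounds(value, low, high)        # -> [[0.0, 0.0, 0.0]]
restored = unnormalize_by_bounds(normed, low, high)   # -> [[0.3, 0.0, 0.1]]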
709
+ def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
710
+ if low < 0.0:
711
+ return np.clip(-gripper, low, high)
712
+ return high - np.clip(gripper, low, high)
713
+
714
+
715
+ GRIPPER_BOUNDS = {
716
+ "bridge": (0.0, 1.0),
717
+ "bridge_orig": (0.0, 1.0),
718
+ "droid": (0.0, 1.0),
719
+ "roboset": (0.0, 1.0),
720
+ }
721
+
722
+
723
+ def preprocess_gripper_observation(
724
+ gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
725
+ ) -> np.ndarray:
726
+ """
727
+ Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
728
+ or from the robot and output is normalized continuous value.
729
+ - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
730
+ - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.
731
+
732
+ Dataset-specific gripper observations:
733
+ bridge: continuous; ~[0=closed; 1=open]
734
+ bridge_orig: continuous; ~[0=closed; 1=open]
735
+ droid: continuous; [0=open, 1=closed]
736
+ roboset: continuous; [0=open, 1=closed]
737
+ """
738
+ if isinstance(dataset_name, np.ndarray):
739
+ assert np.unique(dataset_name).size == 1, dataset_name
740
+ dataset_name = str(dataset_name[0])
741
+ if dataset_name in [
742
+ "droid",
743
+ "roboset",
744
+ ]:
745
+ (low, high) = GRIPPER_BOUNDS[dataset_name]
746
+ gripper = normalize_gripper_by_bounds(
747
+ torch.from_numpy(invert_gripper(gripper, low=low, high=high)),
748
+ low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32),
749
+ high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32),
750
+ binary=binary,
751
+ ).numpy()
752
+ elif dataset_name in [
753
+ "bridge",
754
+ "bridge_orig",
755
+ ]:
756
+ (low, high) = GRIPPER_BOUNDS[dataset_name]
757
+ gripper = normalize_gripper_by_bounds(
758
+ torch.from_numpy(gripper),
759
+ low=torch.full(gripper.shape, low, dtype=torch.float32),
760
+ high=torch.full(gripper.shape, high, dtype=torch.float32),
761
+ binary=binary,
762
+ ).numpy()
763
+ else:
764
+ raise NotImplementedError(f"Unknown dataset: {dataset_name}")
765
+ return gripper
766
+
767
+
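# Editor's note: illustrative sketch, not part of the committed file. For "droid" the raw
# gripper signal is inverted (0=open, 1=closed), so after preprocessing 1.0 -> 0.0 (closed)
# and 0.0 -> 1.0 (open). Gripper observations are expected with a trailing singleton axis.
raw = np.array([[1.0], [0.0]], dtype=np.float32)
processed = preprocess_gripper_observation(raw, np.array(["droid"]), binary=True)
# processed -> [[0.0], [1.0]]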
768
+ def rotation_norm_bounds(
769
+ rotation_norm: Normalization,
770
+ rotation_format: RotationFormat,
771
+ stats: Dict[str, Dict[str, Dict[str, List[float]]]],
772
+ dataset_names: List[str],
773
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
774
+ if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE:
775
+ if rotation_norm == Normalization.BOUNDS:
776
+ results = {
777
+ dataset_name: {
778
+ "low": torch.tensor(dataset_stats["euler"]["min"]),
779
+ "high": torch.tensor(dataset_stats["euler"]["max"]),
780
+ }
781
+ for (dataset_name, dataset_stats) in stats.items()
782
+ }
783
+ elif rotation_norm == Normalization.BOUNDS_Q99:
784
+ results = {
785
+ dataset_name: {
786
+ "low": torch.tensor(dataset_stats["euler"]["q01"]),
787
+ "high": torch.tensor(dataset_stats["euler"]["q99"]),
788
+ }
789
+ for (dataset_name, dataset_stats) in stats.items()
790
+ }
791
+ else:
792
+ raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented")
793
+ else:
794
+ assert rotation_norm == Normalization.NONE, rotation_norm
795
+ if rotation_format == RotationFormat.EULER:
796
+ rotation_size = 3
797
+ elif rotation_format == RotationFormat.QUATERNION:
798
+ rotation_size = 4
799
+ else:
800
+ rotation_size = 9
801
+ results = {
802
+ dataset_name: {
803
+ "low": -1 * torch.ones(rotation_size, dtype=torch.float32),
804
+ "high": 1 * torch.ones(rotation_size, dtype=torch.float32),
805
+ }
806
+ for dataset_name in dataset_names
807
+ }
808
+ return results
809
+
810
+
811
+ def translation_norm_bounds(
812
+ translation_norm: Normalization | tuple,
813
+ stats: Dict[str, Dict[str, Dict[str, List[float]]]],
814
+ dataset_names: List[str],
815
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
816
+ if isinstance(translation_norm, (Normalization, str)) and translation_norm != Normalization.NONE:
817
+ if translation_norm == Normalization.BOUNDS:
818
+ results = {
819
+ dataset_name: {
820
+ "low": torch.tensor(dataset_stats["translation"]["min"]),
821
+ "high": torch.tensor(dataset_stats["translation"]["max"]),
822
+ }
823
+ for (dataset_name, dataset_stats) in stats.items()
824
+ }
825
+ elif translation_norm == Normalization.BOUNDS_Q99:
826
+ results = {
827
+ dataset_name: {
828
+ "low": torch.tensor(dataset_stats["translation"]["q01"]),
829
+ "high": torch.tensor(dataset_stats["translation"]["q99"]),
830
+ }
831
+ for (dataset_name, dataset_stats) in stats.items()
832
+ }
833
+ elif translation_norm == Normalization.MEAN:
834
+ results = {
835
+ dataset_name: {
836
+ "mean": torch.tensor(dataset_stats["translation"]["mean"]),
837
+ "std": torch.tensor(dataset_stats["translation"]["std"]),
838
+ }
839
+ for (dataset_name, dataset_stats) in stats.items()
840
+ }
841
+ else:
842
+ raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented")
843
+ elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE:
844
+ results = {
845
+ dataset_name: {
846
+ "low": -1 * torch.ones(3, dtype=torch.float32),
847
+ "high": 1 * torch.ones(3, dtype=torch.float32),
848
+ }
849
+ for dataset_name in dataset_names
850
+ }
851
+ else:
852
+ assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
853
+ assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
854
+ assert set(translation_norm.keys()) in (
855
+ {"low", "high"},
856
+ {"mean", "std"},
857
+ ), translation_norm
858
+ results = {
859
+ dataset_name: {
860
+ key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
861
+ }
862
+ for dataset_name in dataset_names
863
+ }
864
+ return results
865
+
866
+
867
+ VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT")
868
+
869
+
870
+ class VLAMProcessor(Configurable):
871
+ def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
872
+ super().__init__(config)
873
+ self.vlm_processor = vlm_processor
874
+ self.control_tokenizer = EmptyTokenizer(
875
+ config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
876
+ )
877
+ self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
878
+ "obs_translation": self.obs_translation_norm_bounds,
879
+ "obs_rotation": self.obs_rotation_norm_bounds,
880
+ "translation": self.translation_norm_bounds,
881
+ "rotation": self.rotation_norm_bounds,
882
+ "joints": self.joints_norm_bounds,
883
+ }
884
+
885
+ @property
886
+ def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
887
+ return self.vlm_processor.tokenizer
888
+
889
+ @property
890
+ def image_sizes(self) -> Dict[str, ImageSizeConfig]:
891
+ return self.vlm_processor.image_sizes
892
+
893
+ @property
894
+ def camera_names(self) -> List[str]:
895
+ return list(self.vlm_processor.image_sizes.keys())
896
+
897
+ @property
898
+ def control_io_config(self) -> ControlDataIOConfig:
899
+ return self.config.control_io_config
900
+
901
+ @cached_property
902
+ def rotation_components(self) -> int:
903
+ if self.config.rotation_format == RotationFormat.EULER:
904
+ return 3
905
+ if self.config.rotation_format == RotationFormat.QUATERNION:
906
+ return 4
907
+ if self.config.rotation_format == RotationFormat.ROTMAT:
908
+ return 9
909
+ raise NotImplementedError(self.config.rotation_format)
910
+
911
+ @abstractmethod
912
+ def policy_control_plan_from_model_target(
913
+ self, target: RoboticsTarget, dataset_name: np.ndarray
914
+ ) -> RoboticsControlPlan:
915
+ pass
916
+
917
+ @abstractmethod
918
+ def policy_control_plan_from_model_output(
919
+ self,
920
+ model_output: RoboticsOutput,
921
+ dataset_name: np.ndarray,
922
+ valid_mask: torch.Tensor,
923
+ ) -> RoboticsControlPlan:
924
+ pass
925
+
926
+ def resize_image(
927
+ self, camera_name: str, image: PIL.Image.Image | np.ndarray
928
+ ) -> PIL.Image.Image | np.ndarray:
929
+ return resize_image(
930
+ image,
931
+ target_size={
932
+ "width": self.image_sizes[camera_name].width,
933
+ "height": self.image_sizes[camera_name].height,
934
+ },
935
+ mode=self.config.image_resize,
936
+ resample=PIL.Image.Resampling.LANCZOS,
937
+ )
938
+
939
+ def preprocess_inputs(
940
+ self,
941
+ chat: List[str],
942
+ images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
943
+ ee_pose_translation: np.ndarray,
944
+ ee_pose_rotation: np.ndarray,
945
+ gripper: np.ndarray,
946
+ joints: np.ndarray,
947
+ dataset_name: np.ndarray,
948
+ inference_mode: bool,
949
+ control_target: Optional[RoboticsTarget] = None,
950
+ ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
951
+ """
952
+ Preprocess the inputs for a single example
953
+ Args:
954
+ chat: Chat strings containing the language instruction
955
+ images: History of input images with increasing timestamps
956
+ ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
957
+ ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
958
+ joints: np.ndarray, shape [..., num_past_scalars, <= 7]
959
+ dataset_name: 1D np.ndarray
960
+ inference_mode: If True, prepare the input for inference (e.g. don't include any
961
+ target tokens in the input, if relevant). If control_target is available, it should
962
+ still be preprocessed for test dataset comparison
963
+ control_target: RoboticsTarget, each component of shape
964
+ [..., num_control_steps, num_control_components]. Provided only when available, usually
965
+ during training and dataset test
966
+ Returns:
967
+ Dict containing torch.Tensor with inputs
968
+ """
969
+ del control_target
970
+ del inference_mode
971
+ inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
972
+ images: Dict[str, torch.Tensor] = inputs["images"]
973
+ input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length]
974
+ target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length]
975
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
976
+ ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
977
+ ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
978
+ ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
979
+ gripper = preprocess_gripper_observation(gripper, dataset_name)
980
+ gripper = torch.tensor(gripper, dtype=torch.float32)
981
+ ee_pose_translation = self.normalize(
982
+ ee_pose_translation, dataset_name=dataset_name, key="obs_translation"
983
+ )
984
+ ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation")
985
+ joints = torch.tensor(joints, dtype=torch.float32)
986
+ if joints.shape[-1] < 7:
987
+ missing_size = 7 - joints.shape[-1]
988
+ joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
989
+ joints = self.normalize(joints, dataset_name=dataset_name, key="joints")
990
+ outputs = {
991
+ "images": images,
992
+ "input_ids": input_ids,
993
+ "target_text_tokens_ids": target_text_tokens_ids,
994
+ "attn_mask": attn_mask,
995
+ "ee_pose_translation": ee_pose_translation,
996
+ "ee_pose_rotation": ee_pose_rotation,
997
+ "gripper": gripper,
998
+ "joints": joints,
999
+ "control_tokens_ids": None,
1000
+ "target_control_tokens_ids": None,
1001
+ }
1002
+ return outputs
1003
+
1004
+ def create_input(
1005
+ self,
1006
+ chat: List[str],
1007
+ images: Dict[str, List[PIL.Image.Image]],
1008
+ ee_pose_translation: np.ndarray,
1009
+ ee_pose_rotation: np.ndarray,
1010
+ gripper: np.ndarray,
1011
+ joints: np.ndarray,
1012
+ dataset_name: np.ndarray,
1013
+ inference_mode: bool,
1014
+ control_target: Optional[RoboticsTarget] = None,
1015
+ ) -> RoboticsInput:
1016
+ inputs = self.preprocess_inputs(
1017
+ chat=chat,
1018
+ images=images,
1019
+ ee_pose_translation=ee_pose_translation,
1020
+ ee_pose_rotation=ee_pose_rotation,
1021
+ gripper=gripper,
1022
+ joints=joints,
1023
+ dataset_name=dataset_name,
1024
+ inference_mode=inference_mode,
1025
+ control_target=control_target,
1026
+ )
1027
+ inputs.pop("target_text_tokens_ids")
1028
+ inputs.pop("target_control_tokens_ids")
1029
+ return RoboticsInput(**inputs)
1030
+
1031
+ def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
1032
+ if is_mean_norm(getattr(self.config, f"{key}_norm")):
1033
+ (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
1034
+ output = normalize_by_moments(value, mean=mean, std=std)
1035
+ else:
1036
+ (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
1037
+ output = normalize_by_bounds(value, low=low, high=high)
1038
+ return output
1039
+
1040
+ def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
1041
+ if is_mean_norm(getattr(self.config, f"{key}_norm")):
1042
+ (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
1043
+ output = unnormalize_by_moments(value, mean=mean, std=std)
1044
+ else:
1045
+ (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
1046
+ output = unnormalize_by_bounds(value, low=low, high=high)
1047
+ return output
1048
+
1049
+ def _norm_bounds_from_dataset_name(
1050
+ self, dataset_name: np.ndarray, component_key: str
1051
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1052
+ """
1053
+ Create an array of normalization bounds corresponding to dataset names
1054
+ Args:
1055
+ dataset_name: Array of shape [B] of dataset names for which to fetch the low and high
1056
+ normalization bounds. Note the values can be repeating
1057
+ component_key: str. One of 'obs_translation', 'obs_rotation', 'translation', 'rotation' or 'joints'. Indicates for which control to
1058
+ compute the normalization bounds
1059
+ Returns:
1060
+ Tuple of low and high bounds or mean and std, each of shape [B, -1]
1061
+ """
1062
+ norm = getattr(self.config, f"{component_key}_norm")
1063
+ if is_mean_norm(norm):
1064
+ (stats_key_1, stats_key_2) = ("mean", "std")
1065
+ else:
1066
+ (stats_key_1, stats_key_2) = ("low", "high")
1067
+ if component_key == "joints":
1068
+ if not isinstance(norm, collections.abc.Mapping):
1069
+ raise NotImplementedError()
1070
+ stats = {
1071
+ key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
1072
+ for (key, value) in self.joints_norm_bounds["ANY"].items()
1073
+ }
1074
+ return tuple(stats.values())
1075
+ component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1]
1076
+ if self.dataset_names == ["ANY"]:
1077
+ stats_1 = self.norm_bounds[component_key]["ANY"][stats_key_1]
1078
+ stats_2 = self.norm_bounds[component_key]["ANY"][stats_key_2]
1079
+ stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
1080
+ stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
1081
+ else:
1082
+ (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
1083
+ stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
1084
+ stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
1085
+ for i, ds_name in enumerate(unique_names):
1086
+ stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1].numpy()
1087
+ stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2].numpy()
1088
+ stats_1 = stats_1[inverse_indices]
1089
+ stats_2 = stats_2[inverse_indices]
1090
+ return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
1091
+
1092
+ @cached_property
1093
+ def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
1094
+ return rotation_norm_bounds(
1095
+ rotation_norm=self.config.obs_rotation_norm,
1096
+ rotation_format=self.config.rotation_format,
1097
+ stats=self._observation_stats,
1098
+ dataset_names=self.dataset_names,
1099
+ )
1100
+
1101
+ @cached_property
1102
+ def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
1103
+ return translation_norm_bounds(
1104
+ translation_norm=self.config.obs_translation_norm,
1105
+ stats=self._observation_stats,
1106
+ dataset_names=self.dataset_names,
1107
+ )
1108
+
1109
+ @cached_property
1110
+ def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
1111
+ return rotation_norm_bounds(
1112
+ rotation_norm=self.config.rotation_norm,
1113
+ rotation_format=self.config.rotation_format,
1114
+ stats=self._control_stats,
1115
+ dataset_names=self.dataset_names,
1116
+ )
1117
+
1118
+ @cached_property
1119
+ def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
1120
+ return translation_norm_bounds(
1121
+ translation_norm=self.config.translation_norm,
1122
+ stats=self._control_stats,
1123
+ dataset_names=self.dataset_names,
1124
+ )
1125
+
1126
+ @cached_property
1127
+ def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
1128
+ """
1129
+ NOTE:
1130
+ - Joint values across all joints and all datasets vary in the range [-2pi; 2pi]
1131
+ - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi]
1132
+ - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint
1133
+ """
1134
+ low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32)
1135
+ high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32)
1136
+ results = {"ANY": {"low": low, "high": high}}
1137
+ return results
1138
+
1139
+ @cached_property
1140
+ def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
1141
+ return {
1142
+ "bridge": {
1143
+ "euler": {
1144
+ "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
1145
+ "mean": [
1146
+ -0.25754162314671525,
1147
+ -0.12370228389510128,
1148
+ 0.1620053749182691,
1149
+ ],
1150
+ "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
1151
+ "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
1152
+ "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
1153
+ "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
1154
+ },
1155
+ "gripper": {
1156
+ "max": [1.0370277166366577],
1157
+ "min": [0.04637829214334488],
1158
+ "q01": [0.05192930996417999],
1159
+ "q99": [1.0118417739868164],
1160
+ },
1161
+ "joints": {
1162
+ "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1163
+ "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1164
+ "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1165
+ "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1166
+ "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1167
+ "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1168
+ },
1169
+ "translation": {
1170
+ "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
1171
+ "mean": [
1172
+ 0.309032678604126,
1173
+ 0.03403777256608009,
1174
+ 0.061277542263269424,
1175
+ ],
1176
+ "min": [
1177
+ -0.04167502000927925,
1178
+ -0.2889411449432373,
1179
+ -0.13934996724128723,
1180
+ ],
1181
+ "q01": [
1182
+ 0.1711955964565277,
1183
+ -0.15639324486255646,
1184
+ -0.048255354166030884,
1185
+ ],
1186
+ "q99": [
1187
+ 0.4604376256465912,
1188
+ 0.24112474918365479,
1189
+ 0.18886254727840424,
1190
+ ],
1191
+ "std": [
1192
+ 0.0635896623134613,
1193
+ 0.09153717756271362,
1194
+ 0.049334850162267685,
1195
+ ],
1196
+ },
1197
+ },
1198
+ "bridge_orig": {
1199
+ "euler": {
1200
+ "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
1201
+ "mean": [
1202
+ -0.25754162314671525,
1203
+ -0.12370228389510128,
1204
+ 0.1620053749182691,
1205
+ ],
1206
+ "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
1207
+ "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
1208
+ "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
1209
+ "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
1210
+ },
1211
+ "gripper": {
1212
+ "max": [1.0370277166366577],
1213
+ "min": [0.04637829214334488],
1214
+ "q01": [0.05192930996417999],
1215
+ "q99": [1.0118417739868164],
1216
+ },
1217
+ "joints": {
1218
+ "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1219
+ "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1220
+ "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1221
+ "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1222
+ "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1223
+ "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1224
+ },
1225
+ "translation": {
1226
+ "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
1227
+ "mean": [
1228
+ 0.309032678604126,
1229
+ 0.03403777256608009,
1230
+ 0.061277542263269424,
1231
+ ],
1232
+ "min": [
1233
+ -0.04167502000927925,
1234
+ -0.2889411449432373,
1235
+ -0.13934996724128723,
1236
+ ],
1237
+ "q01": [
1238
+ 0.1711955964565277,
1239
+ -0.15639324486255646,
1240
+ -0.048255354166030884,
1241
+ ],
1242
+ "q99": [
1243
+ 0.4604376256465912,
1244
+ 0.24112474918365479,
1245
+ 0.18886254727840424,
1246
+ ],
1247
+ "std": [
1248
+ 0.0635896623134613,
1249
+ 0.09153717756271362,
1250
+ 0.049334850162267685,
1251
+ ],
1252
+ },
1253
+ },
1254
+ "droid": {
1255
+ "euler": {
1256
+ "max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957],
1257
+ "mean": [
1258
+ 0.3140628098409554,
1259
+ -0.09296274023036387,
1260
+ -0.07227215454779846,
1261
+ ],
1262
+ "min": [
1263
+ -3.141592502593994,
1264
+ -1.5691150426864624,
1265
+ -3.1415374279022217,
1266
+ ],
1267
+ "q01": [
1268
+ -3.1378602981567383,
1269
+ -1.2125312042236327,
1270
+ -2.1614069032669065,
1271
+ ],
1272
+ "q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364],
1273
+ "std": [2.926265757944871, 0.363273475703332, 0.7576065217938824],
1274
+ },
1275
+ "gripper": {
1276
+ "max": [1.0],
1277
+ "min": [0.0],
1278
+ "q01": [0.0],
1279
+ "q99": [0.9911894202232361],
1280
+ },
1281
+ "joints": {
1282
+ "max": [
1283
+ 2.668445110321045,
1284
+ 1.5691218376159668,
1285
+ 2.666306734085083,
1286
+ -0.3114914000034332,
1287
+ 2.6624162197113037,
1288
+ 4.28157901763916,
1289
+ 2.752457857131958,
1290
+ ],
1291
+ "mean": [
1292
+ 0.023137084334640106,
1293
+ 0.2704989977282293,
1294
+ -0.01451389357228282,
1295
+ -2.018709403792315,
1296
+ -0.042720520800030394,
1297
+ 2.350281188152209,
1298
+ 0.12424663946659845,
1299
+ ],
1300
+ "min": [
1301
+ -2.6536705493927,
1302
+ -1.547789216041565,
1303
+ -2.6781487464904785,
1304
+ -2.9409868717193604,
1305
+ -2.6705946922302246,
1306
+ 0.24893812835216522,
1307
+ -2.7615714073181152,
1308
+ ],
1309
+ "q01": [
1310
+ -0.9026106441020965,
1311
+ -0.8547340619564057,
1312
+ -0.9028875434398651,
1313
+ -2.7698556280136106,
1314
+ -1.6851656341552732,
1315
+ 1.2335169839859008,
1316
+ -1.9587260699272155,
1317
+ ],
1318
+ "q99": [
1319
+ 0.9569852340221403,
1320
+ 1.4148830294609054,
1321
+ 0.7693877756595566,
1322
+ -0.4545914208889008,
1323
+ 1.5623322343826267,
1324
+ 3.475611729621887,
1325
+ 2.263479118347167,
1326
+ ],
1327
+ "std": [
1328
+ 0.31695080251469465,
1329
+ 0.49522214687158767,
1330
+ 0.27993538230553827,
1331
+ 0.478161574676113,
1332
+ 0.4969961591445458,
1333
+ 0.45101008525403846,
1334
+ 0.7287264344068457,
1335
+ ],
1336
+ },
1337
+ "translation": {
1338
+ "max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553],
1339
+ "mean": [
1340
+ 0.5283099395864883,
1341
+ 0.005363794653877434,
1342
+ 0.3120132207021294,
1343
+ ],
1344
+ "min": [
1345
+ -0.15604186058044434,
1346
+ -0.827903687953949,
1347
+ -0.2347021996974945,
1348
+ ],
1349
+ "q01": [
1350
+ 0.26669957995414734,
1351
+ -0.43774398624897004,
1352
+ -0.048167889714241026,
1353
+ ],
1354
+ "q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619],
1355
+ "std": [
1356
+ 0.1148424841779685,
1357
+ 0.17489566608140428,
1358
+ 0.16541062032731538,
1359
+ ],
1360
+ },
1361
+ },
1362
+ "roboset": {
1363
+ "euler": {
1364
+ "max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582],
1365
+ "mean": [
1366
+ -0.0398455755412464,
1367
+ 1.0518070390619125,
1368
+ -0.015345692503002759,
1369
+ ],
1370
+ "min": [
1371
+ -3.1415813300509536,
1372
+ -1.5222832468962035,
1373
+ -3.141575300866071,
1374
+ ],
1375
+ "q01": [
1376
+ -2.9414386317311187,
1377
+ -0.24976770655101155,
1378
+ -2.985256521212579,
1379
+ ],
1380
+ "q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025],
1381
+ "std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616],
1382
+ },
1383
+ "gripper": {
1384
+ "max": [0.83056640625],
1385
+ "min": [0.0001499652862548828],
1386
+ "q01": [0.0001499652862548828],
1387
+ "q99": [0.82666015625],
1388
+ },
1389
+ "joints": {
1390
+ "max": [
1391
+ 0.96240234375,
1392
+ 1.1162109375,
1393
+ 1.1064453125,
1394
+ -0.98095703125,
1395
+ 2.30859375,
1396
+ 1.576171875,
1397
+ 1.7412109375,
1398
+ ],
1399
+ "mean": [
1400
+ 0.005913593806326389,
1401
+ 0.1877261847257614,
1402
+ 0.04653879255056381,
1403
+ -2.0529513359069824,
1404
+ -0.011298442259430885,
1405
+ 0.6185526251792908,
1406
+ -0.01701134257018566,
1407
+ ],
1408
+ "min": [
1409
+ -0.8330078125,
1410
+ -0.74658203125,
1411
+ -0.8642578125,
1412
+ -2.892578125,
1413
+ -1.390625,
1414
+ -0.24658203125,
1415
+ -2.953125,
1416
+ ],
1417
+ "q01": [
1418
+ -0.41015625,
1419
+ -0.5302734375,
1420
+ -0.6455078125,
1421
+ -2.57421875,
1422
+ -0.76416015625,
1423
+ -0.0386962890625,
1424
+ -1.435546875,
1425
+ ],
1426
+ "q99": [
1427
+ 0.66455078125,
1428
+ 0.9501953125,
1429
+ 0.7529296875,
1430
+ -1.251953125,
1431
+ 0.75244140625,
1432
+ 1.2314453125,
1433
+ 1.384765625,
1434
+ ],
1435
+ "std": [
1436
+ 0.17915399372577667,
1437
+ 0.32234326004981995,
1438
+ 0.26069700717926025,
1439
+ 0.31767210364341736,
1440
+ 0.205329030752182,
1441
+ 0.33385637402534485,
1442
+ 0.6263682842254639,
1443
+ ],
1444
+ },
1445
+ "translation": {
1446
+ "max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794],
1447
+ "mean": [
1448
+ 0.3331542909145355,
1449
+ 0.019357483834028244,
1450
+ 0.37330344319343567,
1451
+ ],
1452
+ "min": [
1453
+ 0.09978063404560089,
1454
+ -0.29593944549560547,
1455
+ 0.10065606236457825,
1456
+ ],
1457
+ "q01": [
1458
+ 0.18437016010284424,
1459
+ -0.25699371099472046,
1460
+ 0.15134164690971375,
1461
+ ],
1462
+ "q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
1463
+ "std": [
1464
+ 0.07849054038524628,
1465
+ 0.12241040915250778,
1466
+ 0.1460595279932022,
1467
+ ],
1468
+ },
1469
+ },
1470
+ }
1471
+
1472
+ @cached_property
1473
+ def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
1474
+ if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
1475
+ return {}
1476
+ with open(self.config.control_stats_path, "r") as file:
1477
+ stats = yaml.safe_load(file)
1478
+ if self.config.delta_controls:
1479
+ if self.control_io_config.future_controls_sequence_stride_sec is None:
1480
+ horizon = 0.0
1481
+ else:
1482
+ horizon = self.control_io_config.future_controls_sequence_stride_sec
1483
+ elif self.control_io_config.future_controls_sequence_stride_sec is None:
1484
+ if self.control_io_config.future_controls_sequence_length == 1:
1485
+ horizon = 0.0
1486
+ else:
1487
+ raise NotImplementedError()
1488
+ else:
1489
+ horizon = (
1490
+ self.control_io_config.future_controls_sequence_length
1491
+ * self.control_io_config.future_controls_sequence_stride_sec
1492
+ )
1493
+ key = f"horizon_{round(horizon, 2)}s"
1494
+ if key in stats:
1495
+ stats = stats[key]
1496
+ else:
1497
+ raise ValueError(
1498
+ f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]"
1499
+ )
1500
+ return stats
1501
+
1502
+ @cached_property
1503
+ def dataset_names(self) -> List[str]:
1504
+ if (
1505
+ is_global_norm(self.config.rotation_norm)
1506
+ and is_global_norm(self.config.obs_rotation_norm)
1507
+ and is_global_norm(self.config.translation_norm)
1508
+ and is_global_norm(self.config.obs_translation_norm)
1509
+ ):
1510
+ return ["ANY"]
1511
+ return list(set(self._control_stats.keys()) | set(self._observation_stats.keys()))
1512
+
1513
+
1514
+ def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
1515
+ """
1516
+ Transform a sequence of translation vectors, each encoded w.r.t. the PREVIOUS frame in the sequence,
1517
+ into vectors encoded w.r.t. the 0-th (base) frame preceding the sequence
1518
+ Ex:
1519
+ Sequence of points: T1, T2, T3, T4
1520
+ `translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame,
1521
+ implicitly encoded in T0T1
1522
+ Output: T0T1, T0T2, T0T3, T0T4
1523
+
1524
+ Args:
1525
+ translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S
1526
+ corresponds to the sequence dimension
1527
+ Returns:
1528
+ torch.Tensor of the same shape as translation_sequence, containing translations relative to the base frame
1529
+ """
1530
+ assert translation_sequence.ndim >= 3, translation_sequence.shape
1531
+ delta_translations = torch.cumsum(translation_sequence, dim=-2)
1532
+ return delta_translations
1533
+
1534
+
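# Editor's note: small worked example, not part of the committed file. Per-step deltas are
# turned into offsets from the base frame by a cumulative sum over the sequence axis.
deltas = torch.tensor([[[0.1, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0]]])   # [1, S=3, 3]
offsets = delta_to_relative_translations(deltas)
# offsets -> [[[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.2, 0.2, 0.0]]]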
1535
+ class RegressionProcessor(VLAMProcessor):
1536
+ def policy_control_plan_from_model_target(
1537
+ self, target: RoboticsTarget, dataset_name: np.ndarray
1538
+ ) -> RoboticsControlPlan:
1539
+ translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation")
1540
+ rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation")
1541
+ rotmat = convert_rotation(rotation, RotationFormat.ROTMAT)
1542
+ gripper_prob = target.gripper
1543
+ if self.config.delta_controls:
1544
+ translation_m = delta_to_relative_translations(translation_m)
1545
+ rotmat = delta_to_relative_rotations(rotmat)
1546
+ return RoboticsControlPlan(
1547
+ translation_m=translation_m,
1548
+ rotmat=rotmat,
1549
+ gripper_prob=gripper_prob,
1550
+ valid_mask=target.valid_mask,
1551
+ )
1552
+
1553
+     def policy_control_plan_from_model_output(
+         self,
+         model_output: RoboticsOutput,
+         dataset_name: np.ndarray,
+         valid_mask: torch.Tensor,
+     ) -> RoboticsControlPlan:
+         """Called during inference to create control plan from model output"""
+         translation_m = self.unnormalize(
+             model_output.translation, dataset_name=dataset_name, key="translation"
+         )
+         rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation")
+         rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True)
+         gripper_prob = torch.sigmoid(model_output.gripper)
+         if self.config.delta_controls:
+             translation_m = delta_to_relative_translations(translation_m)
+             rotmat = delta_to_relative_rotations(rotmat)
+         return RoboticsControlPlan(
+             translation_m=translation_m,
+             rotmat=rotmat,
+             gripper_prob=gripper_prob,
+             valid_mask=valid_mask,
+         )
+
+
+ class PiZeroFlowMatchingProcessor(RegressionProcessor):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.generator: torch.Generator = torch.Generator()
+
+     @cached_property
+     def beta_distribution(self) -> torch.distributions.Beta:
+         return torch.distributions.Beta(
+             self.config.distribution_hyperparams.get("alpha", 1.5),
+             self.config.distribution_hyperparams.get("beta", 1.0),
+         )
+
+     def create_input(self, *args, **kwargs) -> RoboticsFlowInput:
+         """In practice used only during inference"""
+         inputs = super().create_input(*args, **kwargs)
+         flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu"))
+         inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...])
+         return inputs
+
+     def sample_timestep(self, batch_size: int) -> torch.Tensor:
+         if self.config.timestep_distribution.lower() == "uniform":
+             eps = 1e-05
+             sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % (
+                 1 - eps
+             )
+         elif self.config.timestep_distribution.lower() == "beta":
+             sample = self.beta_distribution.sample([batch_size, 1, 1])
+             sample = (1 - self.config.sig_min) * (1 - sample)
+         else:
+             raise NotImplementedError(self.config.timestep_distribution)
+         sample = sample.view(batch_size, 1, 1)
+         return sample
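+     # With the "uniform" schedule a single base draw is strided across the batch: e.g. batch_size=4 and
+     # a base draw of 0.1 give timesteps [0.10, 0.35, 0.60, 0.85]. The "beta" schedule (defaults
+     # alpha=1.5, beta=1.0) is biased toward 1 and then mapped toward 0 by (1 - sig_min) * (1 - sample).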
+
+     def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
+         return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1
+
+     def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
+         return x_1 - (1 - self.config.sig_min) * x_0
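+     # _psi_t linearly interpolates between noise and data: it equals x_0 at timestep 0 and
+     # sig_min * x_0 + x_1 at timestep 1; _dpsi_dt is its (constant) time derivative, i.e. the target
+     # velocity of the flow.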
+
+     def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput:
+         if self.config.r0_distribution == "normal":
+             controls_t0 = torch.randn(
+                 [
+                     batch_size,
+                     self.config.control_io_config.future_controls_sequence_length,
+                     3 + self.rotation_components + 1,
+                 ],
+                 generator=self.generator,
+             ).to(device=device)
+             (translation_t0, rotation_t0, gripper_t0) = torch.split(
+                 controls_t0, [3, self.rotation_components, 1], dim=-1
+             )
+             rotation_t0 = normalize_rotation(rotation_t0)
+         elif self.config.r0_distribution == "uniform":
+             controls_t0 = torch.randn(
+                 [
+                     batch_size,
+                     self.config.control_io_config.future_controls_sequence_length,
+                     4,
+                 ],
+                 generator=self.generator,
+             ).to(device=device)
+             (translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1)
+             rotation_t0 = convert_rotation(
+                 roma.random_unitquat(
+                     (
+                         batch_size,
+                         self.config.control_io_config.future_controls_sequence_length,
+                     ),
+                     device=device,
+                 ),
+                 self.config.rotation_format,
+             )
+         else:
+             raise NotImplementedError(self.config.r0_distribution)
+         if self.config.rotation_format == RotationFormat.QUATERNION:
+             rotation_t0 = quaternion_half_cover(rotation_t0)
+         timestep = torch.zeros([batch_size, 1, 1], device=device)
+         return FlowInput(
+             timestep=timestep,
+             translation_t0=translation_t0,
+             rotation_t0=rotation_t0,
+             gripper_t0=gripper_t0,
+             translation_t=None,
+             rotation_t=None,
+             gripper_t=None,
+         )
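+     # sample_t0_input draws the flow state at timestep 0: Gaussian noise for translation and gripper,
+     # and either a re-normalized Gaussian ("normal") or a uniformly random unit quaternion ("uniform")
+     # for the rotation components.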
+
+     def policy_control_plan_from_model_output(
+         self,
+         model_output: RoboticsOutput,
+         dataset_name: np.ndarray,
+         valid_mask: torch.Tensor,
+     ) -> RoboticsControlPlan:
+         if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm):
+             model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1))
+         if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm):
+             model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1))
+         control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask)
+         control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1))
+         return control_plan
+
+
+ def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
+     """
+     Create a causal attention mask of shape `shape`
+     Args:
+         shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len]
+     Returns:
+         torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend
+         to the corresponding column (i.e. key)
+
+     Example:
+         shape = (3, 5) -> Mask the upper triangular part
+         [
+             [1, 0, 0, 0, 0],
+             [1, 1, 0, 0, 0],
+             [1, 1, 1, 0, 0]
+         ]
+     """
+     return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0)
+
+
+ def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
+     """
+     Enable full bi-directional attention in `attn_mask` inside specific blocks
+     Args:
+         attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool
+             where False values indicate disabled attention
+         full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
+             positions where full bi-directional attention should be enabled
+
+     Example (full_attn = [1, 1, 1, 0, 0, 1, 1, 1]):
+         1, 0, 0, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
+         1, 1, 0, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
+         1, 1, 1, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
+         1, 1, 1, 1, 0, 0, 0, 0,   ->   1, 1, 1, 1, 0, 0, 0, 0,
+         1, 1, 1, 1, 1, 0, 0, 0,        1, 1, 1, 1, 1, 0, 0, 0,
+         1, 1, 1, 1, 1, 1, 0, 0,        1, 1, 1, 1, 1, 1, 1, 1,
+         1, 1, 1, 1, 1, 1, 1, 0,        1, 1, 1, 1, 1, 1, 1, 1,
+         1, 1, 1, 1, 1, 1, 1, 1,        1, 1, 1, 1, 1, 1, 1, 1,
+     """
+     assert full_attn.dtype == torch.bool, full_attn.dtype
+     assert full_attn.ndim == 1, full_attn.shape
+     assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}"
+     if attn_mask.shape[-1] != attn_mask.shape[-2]:
+         raise NotImplementedError("Only self-attention supported right now.")
+     # Pairs of positions that are both marked for full attention, combined with the causal part
+     x = full_attn.view(-1, 1) & full_attn.view(1, -1)
+     x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
+     # Keep only each row's leading run of allowed entries so separate blocks are not linked together
+     x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
+     # Symmetrize so the surviving entries form square bi-directional blocks
+     x = x & x.permute(1, 0)
+     # Positions outside any block are left with only their diagonal entry; drop it so `x` adds nothing
+     # new for them when OR-ed onto `attn_mask`
+     mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
+     mask_indices = torch.where(mask_positions)[0]
+     x[mask_indices, mask_indices] = 0
+     attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
+     return attn_mask
+
+
+ IGNORE_INDEX = -100
+
+
+ class PaliGemmaProcessor(VLMProcessor):
+     def __init__(
+         self,
+         config: PaliGemmaProcessorConfig,
+         hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
+         **kwargs,
+     ):
+         del kwargs
+         super().__init__(config)
+         self.hf_processor = hf_processor
+         self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json())
+         self.hf_processor.image_seq_length = self.config.num_image_tokens["main"]
+         self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"]
+         self.bos_id: int = self.tokenizer.bos_token_id
+         self.eos_id: int = self.tokenizer.eos_token_id
+         self.sep_token = "\n"
+         self.sep_id: int = self.tokenizer(
+             self.sep_token,
+             padding=False,
+             add_special_tokens=False,
+             return_attention_mask=False,
+         )["input_ids"][0]
+         self.image_token_id: int = self.tokenizer(
+             self.config.image_token,
+             padding=False,
+             add_special_tokens=False,
+             return_attention_mask=False,
+         )["input_ids"][0]
+         self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
+         self.bbox_pattern = re.compile(
+             "\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]"
+         )
+
+     def preprocess_inputs(
+         self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
+     ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
+         """
+         Based on the PaliGemma paper https://arxiv.org/pdf/2407.07726 and the example code at
+         https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs
+         The chat must always consist of alternating messages from the user and the model, starting with the user:
+
+         <image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>
+
+         Args:
+             chat: List[str] of even size where each entry corresponds to a different turn in the conversation
+             images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys
+                 in the Dict and the List corresponds to the history of images
+         """
+         for key, value in images.items():
+             if not isinstance(value, list):
+                 raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list")
+         (input_ids, target_ids) = ([], [])
+         for i, text in enumerate(chat):
+             text = text.replace(self.sep_token, " ").replace("<image>", "")
+             text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
+             turn_input_ids: List[int] = self.tokenizer(
+                 text,
+                 padding=False,
+                 add_special_tokens=False,
+                 return_attention_mask=False,
+             )["input_ids"]
+             # Even turns come from the user and are excluded from the loss
+             if i % 2 == 0:
+                 turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids)
+             else:
+                 turn_target_ids = turn_input_ids
+             if i != len(chat) - 1:
+                 turn_input_ids = turn_input_ids + [self.sep_id]
+                 turn_target_ids = turn_target_ids + [IGNORE_INDEX]
+             input_ids = input_ids + turn_input_ids
+             target_ids = target_ids + turn_target_ids
+         input_ids = [self.bos_id] + input_ids + [self.eos_id]
+         target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id]
+         image_tokens = self.image_tokens
+         if self.config.max_language_tokens > 0:
+             input_ids = input_ids[: self.config.max_language_tokens]
+             target_ids = target_ids[: self.config.max_language_tokens]
+         input_ids = image_tokens + input_ids
+         target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids
+         input_ids = torch.tensor(input_ids, dtype=torch.int64)
+         target_ids = torch.tensor(target_ids, dtype=torch.int64)
+         image_tensors: Dict[str, torch.Tensor] = {
+             f"{camera_name}.siglip": self.hf_processor.image_processor(
+                 camera_images,
+                 size=self.config.image_sizes[camera_name].as_json(),
+                 return_tensors="pt",
+             )["pixel_values"]
+             for (camera_name, camera_images) in images.items()
+         }
+         # Causal mask with full bi-directional attention over the prefix (every position whose target is
+         # IGNORE_INDEX, i.e. image tokens and user turns), following PaliGemma's prefix-LM setup
+         attn_mask = make_causal_mask([len(input_ids), len(input_ids)])
+         attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
+         return {
+             "input_ids": input_ids,
+             "target_ids": target_ids,
+             "images": image_tensors,
+             "attn_mask": attn_mask,
+         }
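+     # For a hypothetical single-turn chat ["what is in the image?", "a red block"] this packs
+     # [<image> * N, <bos>, tokens(instruction), <sep>, tokens(answer), <eos>], with target_ids equal to
+     # IGNORE_INDEX everywhere except the answer tokens and the final <eos>.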
+
+     @property
+     def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
+         return self.hf_processor.tokenizer
+
+     @staticmethod
+     def _bbox_to_loc_tokens(match: re.Match) -> str:
+         """
+         Map bounding-box floats matched by `bbox_pattern` to PaliGemma <locXXXX> tokens, see
+         https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
+         """
+         floats = list(map(float, match.groups()))
+         transformed = [f"<loc{np.clip(round(num * 1024), 0, 1023):04d}>" for num in floats]
+         return f"[{', '.join(transformed)}]"
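+     # e.g. "[0.25, 0.5, 0.75, 1.0]" -> "[<loc0256>, <loc0512>, <loc0768>, <loc1023>]"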
+
+     @property
+     def image_sizes(self) -> Dict[str, ImageSizeConfig]:
+         return self.config.image_sizes
+
+
+ class PaliGemmaDepthProcessor(PaliGemmaProcessor):
+     def __init__(
+         self,
+         config: PaliGemmaProcessorConfig,
+         hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
+         depth_tokens: int,
+     ):
+         super().__init__(config, hf_processor)
+         vocab_size = len(self.tokenizer)
+         self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size)
+         self.depth_input_transforms = {
+             camera_name: torchvision.transforms.v2.Compose(
+                 [
+                     torchvision.transforms.v2.Resize(
+                         size=(camera_image_size.height, camera_image_size.width),
+                         interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC,
+                         max_size=None,
+                         antialias=True,
+                     ),
+                     torchvision.transforms.v2.ToTensor(),
+                     torchvision.transforms.v2.Normalize(
+                         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                     ),
+                 ]
+             )
+             for (camera_name, camera_image_size) in self.config.image_sizes.items()
+         }
+
+     def preprocess_inputs(
+         self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
+     ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
+         inputs = super().preprocess_inputs(chat=chat, images=images)
+         depth_images: Dict[str, torch.Tensor] = {
+             f"{camera_name}.depth": torch.stack(
+                 self.depth_input_transforms[camera_name](camera_images), dim=0
+             )
+             for (camera_name, camera_images) in images.items()
+         }
+         inputs["images"] = {**inputs["images"], **depth_images}
+         return inputs
+
+     @property
+     def num_depth_tokens(self) -> int:
+         return len(self.depth_token_ids)