HaomingSong committed on
Commit 93c2840 · verified · 1 Parent(s): 23cba9a

Upload modeling_pi0.py with huggingface_hub

Files changed (1)
  1. modeling_pi0.py +826 -0
modeling_pi0.py ADDED
@@ -0,0 +1,826 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
π0: A Vision-Language-Action Flow Model for General Robot Control

[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)

Designed by Physical Intelligence. Ported from Jax by Hugging Face.

Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```

Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```

Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
python lerobot/scripts/train.py \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```

Example of using the pi0 pretrained model outside the LeRobot training framework:
```python
policy = Pi0Policy.from_pretrained("lerobot/pi0")
```

"""

import math
from collections import deque

import torch
import torch.nn.functional as F  # noqa: N812
from configuration_pi0 import PI0Config
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_safe_dtype
from paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)
from torch import Tensor, nn
from transformers import AutoTokenizer


def create_sinusoidal_pos_embedding(
    time: torch.Tensor,
    dimension: int,
    min_period: float,
    max_period: float,
    device="cpu",
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
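    # Illustrative note (editorial addition): for `time` of shape (B,) and an even `dimension` D,
    # the result has shape (B, D): D/2 sine features followed by D/2 cosine features, with
    # periods spaced geometrically between `min_period` and `max_period`.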
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb


def sample_beta(alpha, beta, bsize, device):
    gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
    gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
    return gamma1 / (gamma1 + gamma2)


def make_att_2d_masks(pad_masks, att_masks):
    """Adapted from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if it's part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
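    # Worked example (illustration only): att_masks = [[0, 0, 1, 0]] gives cumsum [[0, 0, 1, 1]],
    # so tokens 0-1 (the prefix) attend only to each other while tokens 2-3 attend to the whole
    # sequence; padding positions are then masked out via pad_masks.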
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks


def resize_with_pad(img, width, height, pad_value=-1):
    # assume no-op when width height fits already
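    # Worked example (illustration only): a 480x640 (h x w) frame resized for a 224x224 target
    # gives ratio = max(640/224, 480/224) ≈ 2.86, so the image becomes 168x224 and 56 rows of
    # `pad_value` are added at the top (padding is applied on the left and top only).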
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_img = F.interpolate(
        img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
    )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img


def pad_vector(vector, new_dim):
    """Can be (batch_size x sequence_length x features_dimension)
    or (batch_size x features_dimension)
    """
    if vector.shape[-1] == new_dim:
        return vector
    shape = list(vector.shape)
    current_dim = shape[-1]
    shape[-1] = new_dim
    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
    new_vector[..., :current_dim] = vector
    return new_vector


def normalize(x, min_val, max_val):
    return (x - min_val) / (max_val - min_val)


def unnormalize(x, min_val, max_val):
    return x * (max_val - min_val) + min_val


def safe_arcsin(value):
    # This ensures that the input stays within
    # [−1,1] to avoid invalid values for arcsin
    return torch.arcsin(torch.clamp(value, -1.0, 1.0))


def aloha_gripper_to_angular(value):
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
            2 * horn_radius * linear_position
        )
        return safe_arcsin(value)

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(value, min_val=0.4, max_val=1.5)


def aloha_gripper_from_angular(value):
    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
    # Note that the units are still angular but the range is different.

    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    value = unnormalize(value, min_val=0.4, max_val=1.5)

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return normalize(value, min_val=-0.6213, max_val=1.4910)


def aloha_gripper_from_angular_inv(value):
    # Directly inverts the gripper_from_angular function.
    value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(value, min_val=0.4, max_val=1.5)


class PI0Policy(PreTrainedPolicy):
    """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""

    config_class = PI0Config
    name = "pi0"

    def __init__(
        self,
        config: PI0Config,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config

        # TODO: input / output features / normalizer for multiple datasets
        self.normalize_inputs = Normalize(
            config.input_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # self.language_tokenizer = AutoTokenizer.from_pretrained("/cpfs01/shared/optimal/vla_next/pretrained/pi0", local_files_only=True)
        self.language_tokenizer = None
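        # Note: the tokenizer is intentionally left unset here; `from_pretrained` loads it from the
        # checkpoint directory. If you instantiate the policy directly, assign a PaliGemma-compatible
        # tokenizer (e.g. via transformers.AutoTokenizer) before calling `prepare_language`.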
        self.model = PI0FlowMatching(config)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        return self.parameters()

    @torch.no_grad
    def select_action(
        self, batch: dict[str, Tensor], noise: Tensor | None = None
    ) -> Tensor:
        """Select a chunk of actions given environment observations.

        This port returns the full predicted chunk of `n_action_steps` actions in a single call,
        rather than popping one action at a time from an action queue.
        """
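        # Usage sketch (illustrative only; the actual keys come from the policy config):
        #   batch = {
        #       "observation.images.top": imgs,   # (B, C, H, W) float images in [0, 1]
        #       "observation.state": state,       # (B, state_dim)
        #       "task": ["pick up the cube"] * B,
        #   }
        #   actions = policy.select_action(batch)  # (B, n_action_steps, action_dim)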
        self.eval()

        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

        batch = self.normalize_inputs(batch)

        # Build the model inputs and query the policy for a chunk of n_action_steps actions.
        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)

        actions = self.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state, noise=noise
        )

        # Unpad actions
        original_action_dim = self.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        actions = self.unnormalize_outputs({"action": actions})["action"]

        if self.config.adapt_to_pi_aloha:
            actions = self._pi_aloha_encode_actions(actions)
        return actions

    def forward(
        self, batch: dict[str, Tensor], noise=None, time=None
    ) -> tuple[Tensor, dict[str, Tensor]]:
        """Do a full training forward pass to compute the loss"""
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])

        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        actions_is_pad = batch.get("action_is_pad")

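        # "action_is_pad" (when provided by the dataloader) flags action steps that fall beyond the
        # end of the episode; their loss contribution is zeroed out below.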
        loss_dict = {}
        losses = self.model.forward(
            images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
        )
        # loss_dict["losses_after_forward"] = losses.detach().mean().item()

        if actions_is_pad is not None:
            in_episode_bound = ~actions_is_pad
            losses = losses * in_episode_bound.unsqueeze(-1)
            # loss_dict["losses_after_in_ep_bound"] = losses.detach().mean().item()

        # Remove padding
        losses = losses[:, :, : self.config.max_action_dim]
        # loss_dict["losses_after_rm_padding"] = losses.detach().mean().item()

        # For backward pass
        loss = losses.mean()
        # For logging
        loss_dict["l2_loss"] = loss.item()

        return loss, loss_dict

    def prepare_images(self, batch):
        """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep the
        aspect ratio, and converting the pixel range from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
        """
        images = []
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [
            key for key in self.config.image_features if key not in batch
        ]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(
                    img, *self.config.resize_imgs_with_padding, pad_value=0
                )

            # Normalize from range [0, 1] to [-1, 1] as expected by SigLIP
            img = img * 2.0 - 1.0

            bsize = img.shape[0]
            device = img.device
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            images.append(img)
            img_masks.append(mask)

        # For image features missing from the batch, append placeholder images filled with -1
        # and a False mask, up to `config.empty_cameras` of them.
        for num_empty_cameras in range(len(missing_img_keys)):
            if num_empty_cameras >= self.config.empty_cameras:
                break
            img = torch.ones_like(img) * -1
            mask = torch.zeros_like(mask)
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

    def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
        """Tokenize the text input"""
        device = batch[OBS_ROBOT].device
        tasks = batch["task"]

        # PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
            truncation=True,
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(
            device=device, dtype=torch.bool
        )

        return lang_tokens, lang_masks

    def _pi_aloha_decode_state(self, state):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state

    def _pi_aloha_encode_actions(self, actions):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(
                actions[:, :, motor_idx]
            )
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
        # Flip the joints again.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
                actions[:, :, motor_idx]
            )
        return actions

    def prepare_state(self, batch):
        """Pad state"""
        state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
        return state

    def prepare_action(self, batch):
        """Pad action"""
        actions = pad_vector(batch[ACTION], self.config.max_action_dim)
        return actions

    def _save_pretrained(self, save_directory) -> None:
        super()._save_pretrained(save_directory)
        print(f"Saving the language tokenizer to {save_directory} ...")
        self.language_tokenizer.save_pretrained(save_directory)

        print(f"Copying config and model to {save_directory} ...")
        import shutil

        files = [
            "pi0/configuration_pi0.py",
            "pi0/flex_attention.py",
            "pi0/modeling_pi0.py",
            "pi0/paligemma_with_expert.py",
        ]
        try:
            for file in files:
                shutil.copy(file, save_directory)
        except Exception:
            print("Failed to copy files to save_directory")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_name_or_path,
        **kwargs,
    ):
        policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
        print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
        policy.language_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_name_or_path
        )
        return policy


class PI0FlowMatching(nn.Module):
    """
    π0: A Vision-Language-Action Flow Model for General Robot Control

    [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
    [Jax code](https://github.com/Physical-Intelligence/openpi)

    Designed by Physical Intelligence. Ported from Jax by Hugging Face.
    ┌──────────────────────────────┐
    │            actions ──────────►       noise
    │               ▲              │        │
    │              ┌┴─────┐        │       ┌┴─────┐
    │  kv cache    │Gemma │        │       │Gemma │
    │  ┌──────────►│Expert│        │       │Expert│  4
    │  │           │      │        │       │      │
    │ ┌┴─────▲───┐ │x 10  │        │       │x 10  │
    │ │          │ └▲──▲──┘        │       └▲──▲──┘
    │ │PaliGemma │  │  │           │        │  │
    │ │          │  │  robot state │        │  robot state
    │ │          │  noise          │        vision
    │ └▲──▲──▲───┘                 │
    │  │  │  │                     │
    │  │  image(s)                 │
    │  language tokens             │
    └──────────────────────────────┘
    """

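    # In short: the "prefix" (images + language) is encoded once by PaliGemma and cached, while the
    # "suffix" (robot state + noisy actions + timestep) is processed by the Gemma expert at every
    # denoising step to predict a velocity field over the action chunk.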
    def __init__(self, config):
        super().__init__()
        self.config = config

        paligemma_with_expert_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=self.config.freeze_vision_encoder,
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
            paligemma_config=self.config.paligemma_config,
            gemma_expert_config=self.config.gemma_expert_config,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_with_expert_config
        )

        # Projections are float32
        self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
        self.action_in_proj = nn.Linear(
            self.config.max_action_dim, self.config.proj_width
        )
        self.action_out_proj = nn.Linear(
            self.config.proj_width, self.config.max_action_dim
        )

        self.action_time_mlp_in = nn.Linear(
            self.config.proj_width * 2, self.config.proj_width
        )
        self.action_time_mlp_out = nn.Linear(
            self.config.proj_width, self.config.proj_width
        )

        self.set_requires_grad()

    def set_requires_grad(self):
        for params in self.state_proj.parameters():
            params.requires_grad = self.config.train_state_proj

    def sample_noise(self, shape, device):
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device):
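        # Flow-matching timesteps are drawn from a Beta(1.5, 1)-shaped distribution that puts more
        # mass near t = 1 (close to pure noise), then mapped into [0.001, 1.0] to avoid t = 0 exactly.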
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.
        """
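        # All prefix tokens get att_mask = 0, so images and language attend to each other
        # bidirectionally (prefix-LM style); only padding is excluded via pad_masks.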
        # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for (
            img,
            img_mask,
        ) in zip(images, img_masks, strict=False):
            img_emb = self.paligemma_with_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(
                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
            )

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)

        # Normalize language embeddings
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
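        # Attention layout (see make_att_2d_masks): the state token opens a new block (att 1), and the
        # first action token opens another (att 1) shared by the remaining action tokens (att 0). The
        # prefix therefore never attends to state/actions, while action tokens see the whole sequence.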
        embs = []
        pad_masks = []
        att_masks = []

        # Embed state
        state_emb = self.state_proj(state)
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb[:, None, :])
        bsize = state_emb.shape[0]
        dtype = state_emb.dtype
        device = state_emb.device

        state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
        pad_masks.append(state_mask)

        # Set attention masks so that image and language inputs do not attend to state or actions
        att_masks += [1]

        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.config.proj_width,
            min_period=4e-3,
            max_period=4.0,
            device=device,
        )
        time_emb = time_emb.type(dtype=dtype)

        # Fuse timestep + action information using an MLP
        action_emb = self.action_in_proj(noisy_actions)

        time_emb = time_emb[:, None, :].expand_as(action_emb)
        action_time_emb = torch.cat([action_emb, time_emb], dim=2)

        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        # Add to input tokens
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(
            bsize, action_time_dim, dtype=torch.bool, device=device
        )
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.n_action_steps - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def forward(
        self,
        images,
        img_masks,
        lang_tokens,
        lang_masks,
        state,
        actions,
        noise=None,
        time=None,
    ) -> Tensor:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)
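        # Flow matching: x_t linearly interpolates between the clean action chunk (t = 0) and pure
        # noise (t = 1); the network is trained to regress the constant velocity u_t = noise - actions.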
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
            state, x_t, time
        )

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        (_, suffix_out), _ = self.paligemma_with_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, suffix_embs],
            use_cache=False,
            fill_kv_cache=False,
        )
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        # Original openpi code, upcast attention output
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)

        losses = F.mse_loss(u_t, v_t, reduction="none")
        return losses

    def sample_actions(
        self, images, img_masks, lang_tokens, lang_masks, state, noise=None
    ) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
        bsize = state.shape[0]
        device = state.device

        if noise is None:
            actions_shape = (
                bsize,
                self.config.n_action_steps,
                self.config.max_action_dim,
            )
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Compute image and language key value cache
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=self.config.use_cache,
            fill_kv_cache=True,
        )

        dt = -1.0 / self.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

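        # Integrate the learned velocity field from t = 1 (pure noise) down to t = 0 with explicit
        # Euler steps of size 1 / num_steps, reusing the cached prefix keys/values at every step.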
        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler step
            x_t += dt * v_t
            time += dt
        return x_t

    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
            state, x_t, timestep
        )

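        # Build a (batch, suffix_len, prefix_len + suffix_len) attention mask: suffix tokens may attend
        # to every non-padded cached prefix token plus the suffix positions allowed by make_att_2d_masks.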
        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
            batch_size, suffix_len, prefix_len
        )

        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=self.config.use_cache,
            fill_kv_cache=False,
        )
        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)
        return v_t