""" |
|
π0: A Vision-Language-Action Flow Model for General Robot Control |
|
|
|
[Paper](https://www.physicalintelligence.company/download/pi0.pdf) |
|
[Jax code](https://github.com/Physical-Intelligence/openpi) |
|
|
|
Designed by Physical Intelligence. Ported from Jax by Hugging Face. |
|
|
|
Install pi0 extra dependencies: |
|
```bash |
|
pip install -e ".[pi0]" |
|
``` |
|
|
|
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): |
|
```bash |
|
python lerobot/scripts/train.py \ |
|
--policy.path=lerobot/pi0 \ |
|
--dataset.repo_id=danaaubakirova/koch_test |
|
``` |
|
|
|
Example of finetuning the pi0 neural network with PaliGemma and the expert Gemma
initialized from their default VLM-pretrained parameters (i.e., without the pi0 pretraining):
|
```bash |
|
python lerobot/scripts/train.py \ |
|
--policy.type=pi0 \ |
|
--dataset.repo_id=danaaubakirova/koch_test |
|
``` |
|
|
|
Example of using the pi0 pretrained model outside LeRobot training framework: |
|
```python |
|
policy = Pi0Policy.from_pretrained("lerobot/pi0") |
|
``` |
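
Example of running inference with the policy (a minimal sketch; the observation keys,
shapes and task string below are hypothetical and must match your policy config and robot):
```python
import torch

policy = Pi0Policy.from_pretrained("lerobot/pi0")
policy.reset()

batch = {
    "observation.state": torch.zeros(1, 14),                    # proprioceptive state
    "observation.images.cam_high": torch.rand(1, 3, 224, 224),  # pixels in [0, 1]
    "task": ["pick up the cube"],                                # language instruction
}
action = policy.select_action(batch)  # one action per call, refilled from the predicted chunk
```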
|
|
|
""" |
|
|
|
import math |
|
from collections import deque |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from configuration_pi0 import PI0Config |
|
from lerobot.common.constants import ACTION, OBS_ROBOT |
|
from lerobot.common.policies.normalize import Normalize, Unnormalize |
|
from lerobot.common.policies.pretrained import PreTrainedPolicy |
|
from lerobot.common.utils.utils import get_safe_dtype |
|
from paligemma_with_expert import ( |
|
PaliGemmaWithExpertConfig, |
|
PaliGemmaWithExpertModel, |
|
) |
|
from torch import Tensor, nn |
|
from transformers import AutoTokenizer |
|
|
|
|
|
def create_sinusoidal_pos_embedding(
    time: torch.Tensor,
    dimension: int,
    min_period: float,
    max_period: float,
    device="cpu",
) -> Tensor:
    """Compute sine-cosine positional embeddings of shape (batch_size, dimension) for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")

    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
|
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) |
|
period = min_period * (max_period / min_period) ** fraction |
|
|
|
|
|
scaling_factor = 1.0 / period * 2 * math.pi |
|
sin_input = scaling_factor[None, :] * time[:, None] |
|
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) |
|
return pos_emb |
|
|
|
|
|
def sample_beta(alpha, beta, bsize, device): |
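    # Sample (approximately) from a Beta(alpha, beta) distribution: draw U1^(1/alpha) and
    # U2^(1/beta) for uniform U1, U2 and return the normalized ratio (Jöhnk's construction
    # without the rejection step, so this is an approximation).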
|
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) |
|
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) |
|
return gamma1 / (gamma1 + gamma2) |
|
|
|
|
|
def make_att_2d_masks(pad_masks, att_masks): |
|
"""Copied from big_vision. |
|
|
|
    Tokens can attend to valid input tokens which have a cumulative att_masks
    value smaller than or equal to theirs. This way `att_masks` int[B, N] can be
    used to set up several types of attention, for example:
|
|
|
[[1 1 1 1 1 1]]: pure causal attention. |
|
|
|
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between |
|
themselves and the last 3 tokens have a causal attention. The first |
|
entry could also be a 1 without changing behaviour. |
|
|
|
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a |
|
block can attend all previous blocks and all tokens on the same block. |
|
|
|
    Args:
        pad_masks: bool[B, N], True if the token is part of the input, False if padding.
        att_masks: int32[B, N], 1 where previous tokens cannot depend on the token and
            0 where it shares the same attention mask as the previous token.
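
    A small worked example (toy values, for illustration only): with all tokens valid
    and att_masks = [0, 0, 1, 1], the first two tokens attend bidirectionally and the
    last two are causal:

        pad_masks = torch.ones(1, 4, dtype=torch.bool)
        att_masks = torch.tensor([[0, 0, 1, 1]])
        make_att_2d_masks(pad_masks, att_masks)
        # tensor([[[ True,  True, False, False],
        #          [ True,  True, False, False],
        #          [ True,  True,  True, False],
        #          [ True,  True,  True,  True]]])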
|
""" |
|
    if att_masks.ndim != 2:
        raise ValueError(f"att_masks must be 2D, got {att_masks.ndim} dims")
    if pad_masks.ndim != 2:
        raise ValueError(f"pad_masks must be 2D, got {pad_masks.ndim} dims")
|
|
|
cumsum = torch.cumsum(att_masks, dim=1) |
|
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] |
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] |
|
att_2d_masks = att_2d_masks & pad_2d_masks |
|
return att_2d_masks |
|
|
|
|
|
def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a batch of (b, c, h, w) images to fit within (height, width) while keeping
    aspect ratio, then pad on the left/top with `pad_value` to reach the exact target size."""
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
|
|
cur_height, cur_width = img.shape[2:] |
|
|
|
ratio = max(cur_width / width, cur_height / height) |
|
resized_height = int(cur_height / ratio) |
|
resized_width = int(cur_width / ratio) |
|
resized_img = F.interpolate( |
|
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False |
|
) |
|
|
|
pad_height = max(0, int(height - resized_height)) |
|
pad_width = max(0, int(width - resized_width)) |
|
|
|
|
|
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) |
|
return padded_img |
|
|
|
|
|
def pad_vector(vector, new_dim): |
|
"""Can be (batch_size x sequence_length x features_dimension) |
|
or (batch_size x features_dimension) |
|
""" |
|
if vector.shape[-1] == new_dim: |
|
return vector |
|
shape = list(vector.shape) |
|
current_dim = shape[-1] |
|
shape[-1] = new_dim |
|
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) |
|
new_vector[..., :current_dim] = vector |
|
return new_vector |
|
|
|
|
|
def normalize(x, min_val, max_val): |
|
return (x - min_val) / (max_val - min_val) |
|
|
|
|
|
def unnormalize(x, min_val, max_val): |
|
return x * (max_val - min_val) + min_val |
|
|
|
|
|
def safe_arcsin(value):
    """arcsin with its input clamped to [-1.0, 1.0] to avoid NaNs from numerical error."""
    return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
|
|
|
|
def aloha_gripper_to_angular(value):
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
            2 * horn_radius * linear_position
        )
        return safe_arcsin(value)

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(value, min_val=0.4, max_val=1.5)
|
|
|
|
|
def aloha_gripper_from_angular(value):
    # Convert from the gripper position used by pi0 to the gripper position used by Aloha.
    # Note that the units are still angular but the range is different.

    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    value = unnormalize(value, min_val=0.4, max_val=1.5)

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return normalize(value, min_val=-0.6213, max_val=1.4910)
|
|
|
|
|
def aloha_gripper_from_angular_inv(value):
    # Directly inverts aloha_gripper_from_angular.
    value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(value, min_val=0.4, max_val=1.5)
|
|
|
|
|
class PI0Policy(PreTrainedPolicy): |
|
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.""" |
|
|
|
config_class = PI0Config |
|
name = "pi0" |
|
|
|
def __init__( |
|
self, |
|
config: PI0Config, |
|
dataset_stats: dict[str, dict[str, Tensor]] | None = None, |
|
): |
|
""" |
|
Args: |
|
config: Policy configuration class instance or None, in which case the default instantiation of |
|
the configuration class is used. |
|
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected |
|
that they will be passed with a call to `load_state_dict` before the policy is used. |
|
""" |
|
|
|
super().__init__(config) |
|
config.validate_features() |
|
self.config = config |
|
|
|
|
|
self.normalize_inputs = Normalize( |
|
config.input_features, config.normalization_mapping, dataset_stats |
|
) |
|
self.normalize_targets = Normalize( |
|
config.output_features, config.normalization_mapping, dataset_stats |
|
) |
|
self.unnormalize_outputs = Unnormalize( |
|
config.output_features, config.normalization_mapping, dataset_stats |
|
) |
|
|
|
|
|
self.language_tokenizer = None |
|
self.model = PI0FlowMatching(config) |
|
|
|
self.reset() |
|
|
|
def reset(self): |
|
"""This should be called whenever the environment is reset.""" |
|
self._action_queue = deque([], maxlen=self.config.n_action_steps) |
|
|
|
def get_optim_params(self) -> dict: |
|
return self.parameters() |
|
|
|
@torch.no_grad |
|
def select_action( |
|
self, batch: dict[str, Tensor], noise: Tensor | None = None |
|
) -> Tensor: |
|
"""Select a single action given environment observations. |
|
|
|
This method wraps `select_actions` in order to return one action at a time for execution in the |
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the |
|
queue is empty. |
|
""" |
|
self.eval() |
|
|
|
if self.config.adapt_to_pi_aloha: |
|
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) |
|
|
|
batch = self.normalize_inputs(batch) |
|
|
|
|
|
|
|
images, img_masks = self.prepare_images(batch) |
|
state = self.prepare_state(batch) |
|
lang_tokens, lang_masks = self.prepare_language(batch) |
|
|
|
actions = self.model.sample_actions( |
|
images, img_masks, lang_tokens, lang_masks, state, noise=noise |
|
) |
|
|
|
|
|
original_action_dim = self.config.action_feature.shape[0] |
|
actions = actions[:, :, :original_action_dim] |
|
|
|
actions = self.unnormalize_outputs({"action": actions})["action"] |
|
|
|
if self.config.adapt_to_pi_aloha: |
|
actions = self._pi_aloha_encode_actions(actions) |
|
return actions |
|
|
|
def forward( |
|
self, batch: dict[str, Tensor], noise=None, time=None |
|
) -> tuple[Tensor, dict[str, Tensor]]: |
|
"""Do a full training forward pass to compute the loss""" |
|
if self.config.adapt_to_pi_aloha: |
|
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) |
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) |
|
|
|
batch = self.normalize_inputs(batch) |
|
batch = self.normalize_targets(batch) |
|
|
|
images, img_masks = self.prepare_images(batch) |
|
state = self.prepare_state(batch) |
|
lang_tokens, lang_masks = self.prepare_language(batch) |
|
actions = self.prepare_action(batch) |
|
actions_is_pad = batch.get("action_is_pad") |
|
|
|
loss_dict = {} |
|
losses = self.model.forward( |
|
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time |
|
) |
|
|
|
|
|
        # Zero out the loss on padded (out-of-episode) action steps.
        if actions_is_pad is not None:
            in_episode_bound = ~actions_is_pad
            losses = losses * in_episode_bound.unsqueeze(-1)

        losses = losses[:, :, : self.config.max_action_dim]

        # Mean over batch, action steps and action dimensions for the backward pass.
        loss = losses.mean()
        # Scalar copy for logging.
        loss_dict["l2_loss"] = loss.item()
|
|
|
return loss, loss_dict |
|
|
|
def prepare_images(self, batch): |
|
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and |
|
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. |
|
""" |
|
images = [] |
|
img_masks = [] |
|
|
|
present_img_keys = [key for key in self.config.image_features if key in batch] |
|
missing_img_keys = [ |
|
key for key in self.config.image_features if key not in batch |
|
] |
|
|
|
if len(present_img_keys) == 0: |
|
raise ValueError( |
|
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" |
|
) |
|
|
|
|
|
for key in present_img_keys: |
|
img = batch[key] |
|
|
|
if self.config.resize_imgs_with_padding is not None: |
|
img = resize_with_pad( |
|
img, *self.config.resize_imgs_with_padding, pad_value=0 |
|
) |
|
|
|
|
|
            # Rescale pixel values from [0, 1] to [-1, 1] as expected by SigLIP.
            img = img * 2.0 - 1.0
|
|
|
bsize = img.shape[0] |
|
device = img.device |
|
mask = torch.ones(bsize, dtype=torch.bool, device=device) |
|
images.append(img) |
|
img_masks.append(mask) |
|
|
|
|
|
|
|
        # Add placeholder "empty" cameras (constant -1 images with a zero mask) for missing
        # image features, up to `config.empty_cameras`.
        for num_empty_cameras in range(len(missing_img_keys)):
|
if num_empty_cameras >= self.config.empty_cameras: |
|
break |
|
img = torch.ones_like(img) * -1 |
|
mask = torch.zeros_like(mask) |
|
images.append(img) |
|
img_masks.append(mask) |
|
|
|
return images, img_masks |
|
|
|
def prepare_language(self, batch) -> tuple[Tensor, Tensor]: |
|
"""Tokenize the text input""" |
|
device = batch[OBS_ROBOT].device |
|
tasks = batch["task"] |
|
|
|
|
|
        # The PaliGemma prompt is expected to end with a newline character.
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
            truncation=True,
        )
|
lang_tokens = tokenized_prompt["input_ids"].to(device=device) |
|
lang_masks = tokenized_prompt["attention_mask"].to( |
|
device=device, dtype=torch.bool |
|
) |
|
|
|
return lang_tokens, lang_masks |
|
|
|
    def _pi_aloha_decode_state(self, state):
        # Flip the sign of the joints that Aloha inverts relative to pi0.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state
|
|
|
    def _pi_aloha_encode_actions(self, actions):
        # Flip the sign of the joints that Aloha inverts relative to pi0.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Convert gripper commands back to the range expected by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(
                actions[:, :, motor_idx]
            )
        return actions
|
|
|
    def _pi_aloha_encode_actions_inv(self, actions):
        # Inverse of `_pi_aloha_encode_actions`, used to map dataset (Aloha) actions into pi0 space.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
                actions[:, :, motor_idx]
            )
        return actions
|
|
|
def prepare_state(self, batch): |
|
"""Pad state""" |
|
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) |
|
return state |
|
|
|
def prepare_action(self, batch): |
|
"""Pad action""" |
|
actions = pad_vector(batch[ACTION], self.config.max_action_dim) |
|
return actions |
|
|
|
def _save_pretrained(self, save_directory) -> None: |
|
super()._save_pretrained(save_directory) |
|
print(f"Saving the language tokenizer to {save_directory} ...") |
|
self.language_tokenizer.save_pretrained(save_directory) |
|
|
|
print(f"Copying config and model to {save_directory} ...") |
|
import shutil |
|
|
|
files = [ |
|
"pi0/configuration_pi0.py", |
|
"pi0/flex_attention.py", |
|
"pi0/modeling_pi0.py", |
|
"pi0/paligemma_with_expert.py", |
|
] |
|
        try:
            for file in files:
                shutil.copy(file, save_directory)
        except Exception as e:
            print(f"Failed to copy policy source files to {save_directory}: {e}")
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_name_or_path, |
|
**kwargs, |
|
): |
|
policy = super().from_pretrained(pretrained_name_or_path, **kwargs) |
|
print(f"Loading the language tokenizer from {pretrained_name_or_path} ...") |
|
policy.language_tokenizer = AutoTokenizer.from_pretrained( |
|
pretrained_name_or_path |
|
) |
|
return policy |
|
|
|
|
|
class PI0FlowMatching(nn.Module): |
|
""" |
|
π0: A Vision-Language-Action Flow Model for General Robot Control |
|
|
|
[Paper](https://www.physicalintelligence.company/download/pi0.pdf) |
|
[Jax code](https://github.com/Physical-Intelligence/openpi) |
|
|
|
Designed by Physical Intelligence. Ported from Jax by Hugging Face. |
|
    ┌──────────────────────────────┐
    │               actions        │
    │               ▲              │
    │              ┌┴─────┐        │
    │  kv cache    │Gemma │        │
    │  ┌──────────►│Expert│        │
    │  │           │      │        │
    │ ┌┴────────┐  │x 10  │        │
    │ │         │  └▲──▲──┘        │
    │ │PaliGemma│   │  │           │
    │ │         │   │  robot state │
    │ │         │   noise          │
    │ └▲──▲─────┘                  │
    │  │  │                        │
    │  │  image(s)                 │
    │  language tokens             │
    └──────────────────────────────┘
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
|
|
        paligemma_with_expert_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=self.config.freeze_vision_encoder,
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
            paligemma_config=self.config.paligemma_config,
            gemma_expert_config=self.config.gemma_expert_config,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_with_expert_config
        )
|
|
|
|
|
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) |
|
self.action_in_proj = nn.Linear( |
|
self.config.max_action_dim, self.config.proj_width |
|
) |
|
self.action_out_proj = nn.Linear( |
|
self.config.proj_width, self.config.max_action_dim |
|
) |
|
|
|
self.action_time_mlp_in = nn.Linear( |
|
self.config.proj_width * 2, self.config.proj_width |
|
) |
|
self.action_time_mlp_out = nn.Linear( |
|
self.config.proj_width, self.config.proj_width |
|
) |
|
|
|
self.set_requires_grad() |
|
|
|
def set_requires_grad(self): |
|
for params in self.state_proj.parameters(): |
|
params.requires_grad = self.config.train_state_proj |
|
|
|
def sample_noise(self, shape, device): |
|
noise = torch.normal( |
|
mean=0.0, |
|
std=1.0, |
|
size=shape, |
|
dtype=torch.float32, |
|
device=device, |
|
) |
|
return noise |
|
|
|
    def sample_time(self, bsize, device):
        # Flow-matching timesteps are sampled from (approximately) Beta(1.5, 1.0) and
        # rescaled to [0.001, 1.0] so that t is never exactly 0.
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)
|
|
|
def embed_prefix( |
|
self, images, img_masks, lang_tokens, lang_masks |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""Embed images with SigLIP and language tokens with embedding layer to prepare |
|
for PaliGemma transformer processing. |
|
""" |
|
|
|
embs = [] |
|
pad_masks = [] |
|
att_masks = [] |
|
|
|
|
|
        for img, img_mask in zip(images, img_masks, strict=False):
|
img_emb = self.paligemma_with_expert.embed_image(img) |
|
img_emb = img_emb.to(dtype=torch.bfloat16) |
|
|
|
|
|
            # Scale image embeddings by sqrt(hidden dim), mirroring the language embedding scaling below.
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(
                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
            )
|
|
|
bsize, num_img_embs = img_emb.shape[:2] |
|
img_mask = img_mask[:, None].expand(bsize, num_img_embs) |
|
|
|
embs.append(img_emb) |
|
pad_masks.append(img_mask) |
|
|
|
|
|
            # Image tokens attend bidirectionally within the prefix.
            att_masks += [0] * num_img_embs
|
|
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) |
|
|
|
|
|
        # Scale language embeddings by sqrt(hidden dim).
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
|
|
embs.append(lang_emb) |
|
pad_masks.append(lang_masks) |
|
|
|
|
|
num_lang_embs = lang_emb.shape[1] |
|
        # Full bidirectional attention between image and language tokens within the prefix.
        att_masks += [0] * num_lang_embs
|
|
|
embs = torch.cat(embs, dim=1) |
|
pad_masks = torch.cat(pad_masks, dim=1) |
|
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) |
|
att_masks = att_masks[None, :].expand(bsize, len(att_masks)) |
|
|
|
return embs, pad_masks, att_masks |
|
|
|
def embed_suffix(self, state, noisy_actions, timestep): |
|
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" |
|
embs = [] |
|
pad_masks = [] |
|
att_masks = [] |
|
|
|
|
|
        # Embed the (padded) robot state as a single token.
        state_emb = self.state_proj(state)
|
state_emb = state_emb.to(dtype=torch.bfloat16) |
|
embs.append(state_emb[:, None, :]) |
|
bsize = state_emb.shape[0] |
|
dtype = state_emb.dtype |
|
device = state_emb.device |
|
|
|
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) |
|
pad_masks.append(state_mask) |
|
|
|
|
|
        # The state token starts a new attention block: prefix (image/language) tokens cannot attend to it.
        att_masks += [1]
|
|
|
|
|
        # Embed the flow-matching timestep with a sine-cosine encoding sensitive to the [0, 1] range.
        time_emb = create_sinusoidal_pos_embedding(
|
timestep, |
|
self.config.proj_width, |
|
min_period=4e-3, |
|
max_period=4.0, |
|
device=device, |
|
) |
|
time_emb = time_emb.type(dtype=dtype) |
|
|
|
|
|
        # Fuse the noisy action and timestep embeddings with a small MLP.
        action_emb = self.action_in_proj(noisy_actions)
|
|
|
time_emb = time_emb[:, None, :].expand_as(action_emb) |
|
action_time_emb = torch.cat([action_emb, time_emb], dim=2) |
|
|
|
action_time_emb = self.action_time_mlp_in(action_time_emb) |
|
action_time_emb = F.silu(action_time_emb) |
|
action_time_emb = self.action_time_mlp_out(action_time_emb) |
|
|
|
|
|
embs.append(action_time_emb) |
|
|
|
bsize, action_time_dim = action_time_emb.shape[:2] |
|
action_time_mask = torch.ones( |
|
bsize, action_time_dim, dtype=torch.bool, device=device |
|
) |
|
pad_masks.append(action_time_mask) |
|
|
|
|
|
        # The first action token starts a new block (prefix and state tokens cannot attend to actions);
        # action tokens attend to each other within the chunk.
        att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
|
|
embs = torch.cat(embs, dim=1) |
|
pad_masks = torch.cat(pad_masks, dim=1) |
|
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) |
|
att_masks = att_masks[None, :].expand(bsize, len(att_masks)) |
|
|
|
return embs, pad_masks, att_masks |
|
|
|
def forward( |
|
self, |
|
images, |
|
img_masks, |
|
lang_tokens, |
|
lang_masks, |
|
state, |
|
actions, |
|
noise=None, |
|
time=None, |
|
) -> Tensor: |
|
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" |
|
if noise is None: |
|
noise = self.sample_noise(actions.shape, actions.device) |
|
|
|
if time is None: |
|
time = self.sample_time(actions.shape[0], actions.device) |
|
        time_expanded = time[:, None, None]
        # Noisy interpolant x_t = t * noise + (1 - t) * actions and its flow-matching
        # velocity target u_t = d(x_t)/dt = noise - actions.
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions
|
|
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( |
|
images, img_masks, lang_tokens, lang_masks |
|
) |
|
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( |
|
state, x_t, time |
|
) |
|
|
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) |
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) |
|
|
|
att_2d_masks = make_att_2d_masks(pad_masks, att_masks) |
|
position_ids = torch.cumsum(pad_masks, dim=1) - 1 |
|
|
|
(_, suffix_out), _ = self.paligemma_with_expert.forward( |
|
attention_mask=att_2d_masks, |
|
position_ids=position_ids, |
|
past_key_values=None, |
|
inputs_embeds=[prefix_embs, suffix_embs], |
|
use_cache=False, |
|
fill_kv_cache=False, |
|
) |
|
suffix_out = suffix_out[:, -self.config.n_action_steps :] |
|
|
|
suffix_out = suffix_out.to(dtype=torch.float32) |
|
v_t = self.action_out_proj(suffix_out) |
|
|
|
losses = F.mse_loss(u_t, v_t, reduction="none") |
|
return losses |
|
|
|
def sample_actions( |
|
self, images, img_masks, lang_tokens, lang_masks, state, noise=None |
|
) -> Tensor: |
|
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" |
|
bsize = state.shape[0] |
|
device = state.device |
|
|
|
if noise is None: |
|
actions_shape = ( |
|
bsize, |
|
self.config.n_action_steps, |
|
self.config.max_action_dim, |
|
) |
|
noise = self.sample_noise(actions_shape, device) |
|
|
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( |
|
images, img_masks, lang_tokens, lang_masks |
|
) |
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) |
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 |
|
|
|
|
|
        # Compute the prefix (image + language) key/value cache once; it is reused at every denoising step.
        _, past_key_values = self.paligemma_with_expert.forward(
|
attention_mask=prefix_att_2d_masks, |
|
position_ids=prefix_position_ids, |
|
past_key_values=None, |
|
inputs_embeds=[prefix_embs, None], |
|
use_cache=self.config.use_cache, |
|
fill_kv_cache=True, |
|
) |
|
|
|
        # Euler integration of the learned velocity field from t=1 (pure noise) down to
        # t=0 (actions), in `num_steps` steps of size dt = -1 / num_steps.
        dt = -1.0 / self.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
|
expanded_time = time.expand(bsize) |
|
v_t = self.denoise_step( |
|
state, |
|
prefix_pad_masks, |
|
past_key_values, |
|
x_t, |
|
expanded_time, |
|
) |
|
|
|
|
|
x_t += dt * v_t |
|
time += dt |
|
return x_t |
|
|
|
def denoise_step( |
|
self, |
|
state, |
|
prefix_pad_masks, |
|
past_key_values, |
|
x_t, |
|
timestep, |
|
): |
|
"""Apply one denoising step of the noise `x_t` at a given timestep.""" |
|
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( |
|
state, x_t, timestep |
|
) |
|
|
|
suffix_len = suffix_pad_masks.shape[1] |
|
batch_size = prefix_pad_masks.shape[0] |
|
prefix_len = prefix_pad_masks.shape[1] |
|
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( |
|
batch_size, suffix_len, prefix_len |
|
) |
|
|
|
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) |
|
|
|
        # Suffix queries attend to all valid prefix tokens (served from the KV cache) plus the suffix pattern.
        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
|
|
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] |
|
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 |
|
|
|
outputs_embeds, _ = self.paligemma_with_expert.forward( |
|
attention_mask=full_att_2d_masks, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=[None, suffix_embs], |
|
use_cache=self.config.use_cache, |
|
fill_kv_cache=False, |
|
) |
|
suffix_out = outputs_embeds[1] |
|
suffix_out = suffix_out[:, -self.config.n_action_steps :] |
|
suffix_out = suffix_out.to(dtype=torch.float32) |
|
v_t = self.action_out_proj(suffix_out) |
|
return v_t |
|
|