|
--- |
|
license: mit |
|
base_model: |
|
- google/paligemma-3b-pt-224 |
|
tags: |
|
- openpi0 |
|
- jax |
|
datasets: |
|
- IPEC-COMMUNITY/bridge_orig_lerobot |
|
--- |
|
|
|
|
|
|
|
Download the model:
|
|
|
```bash |
|
# Set this to the Hugging Face repo id of this model before running.
model=ORG/REPO_NAME
huggingface-cli download --resume-download --local-dir-use-symlinks False "${model}" --local-dir "$(basename "${model}")"
|
``` |
|
|
|
Launch the openpi0 policy server. Create the [openpi](https://github.com/Physical-Intelligence/openpi/) environment first.
|
|
|
```bash |
|
export OPENPI_DATA_HOME=/PATH/TO/OPENPI_DATA_HOME
export LEROBOT_HOME=/PATH/TO/LEROBOT_HOME

# Directory the checkpoint was downloaded to (see the download step above).
THE_MODEL_PATH=/PATH/TO/DOWNLOADED/MODEL

uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=pi0_fast_bridge_fft_pt_tokenizer \
    --policy.dir="$THE_MODEL_PATH"
|
``` |
|
|
|
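### DataConfig

`LeRobotBridgeDataConfig` maps the fields of the LeRobot-format Bridge dataset onto the keys read by `bridge_policy.BridgeInputs`: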
|
```python |
|
@dataclasses.dataclass(frozen=True)
class LeRobotBridgeDataConfig(DataConfigFactory):
    """Data config factory for the Bridge dataset stored in LeRobot format.

    Renames the LeRobot feature names used by the dataset (for example
    ``observation.images.image_0``) to the keys consumed by
    ``bridge_policy.BridgeInputs``. It also attaches the Bridge input/output
    adapters and the model-specific transforms.
    """

    # Forwarded as-is to DataConfig.use_quantile_norm.
    use_quantile_norm: bool = True

    # Action keys that will be used to read the action sequence from the dataset.
    # The Bridge LeRobot dataset stores actions under "action" (singular).
    action_sequence_keys: Sequence[str] = ("action",)

    # Forwarded as-is to DataConfig.prompt_from_task.
    # Presumably this derives the "prompt" field from the dataset's task
    # annotation; confirm in DataConfig.
    prompt_from_task: bool = True

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        """Build the DataConfig used to train on the Bridge dataset.

        Args:
            assets_dirs: Assets directory passed to ``create_base_config``.
            model_config: Model config. It supplies ``action_dim`` and
                ``model_type`` to ``BridgeInputs`` and is passed to
                ``ModelTransformFactory``.

        Returns:
            The base config from ``create_base_config``. Its repack, data and
            model transforms and its dataset options are replaced by the
            values defined on this factory.
        """
        # Rename dataset features (dict values, LeRobot naming) to the keys read
        # by BridgeInputs (dict keys). Only the primary image (image_0) is used.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "observation/primary_image": "observation.images.image_0",
                        # Additional image streams in the dataset, currently unused.
                        # NOTE(review): "wirst" looks like a typo for "wrist"; check
                        # the key BridgeInputs expects before enabling this line.
                        # "observation/left_yellow_image": "observation.images.image_1",
                        # "observation/right_blue_image": "observation.images.image_2",
                        # "observation/wirst_image": "observation.images.image_3",
                        "observation/state": "observation.state",
                        "actions": "action",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # Bridge-specific adapters. BridgeInputs prepares samples for the model;
        # BridgeOutputs post-processes model outputs. See bridge_policy for the
        # exact conversions performed.
        data_transforms = _transforms.Group(
            inputs=[
                bridge_policy.BridgeInputs(
                    action_dim=model_config.action_dim,
                    model_type=model_config.model_type,
                )
            ],
            outputs=[bridge_policy.BridgeOutputs()],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        model_transforms = ModelTransformFactory()(model_config)

        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            use_quantile_norm=self.use_quantile_norm,
            action_sequence_keys=self.action_sequence_keys,
            prompt_from_task=self.prompt_from_task,
        )
|
``` |