# pi0_bnss / example_usage.py
# naungth's picture
# Upload Pi-0 bolt nut sort model (step 29999)
# e13cb78 verified
"""
Example usage of the Pi-0 Bolt Nut Sort model
"""
from openpi.policies import policy_config
from openpi.training import config
import numpy as np
def load_model(checkpoint_path: str):
    """Build and return the trained Pi-0 bolt/nut sorting policy.

    Looks up the ``"pi0_bns"`` training config and hands it, together with
    ``checkpoint_path``, to ``policy_config.create_trained_policy``. The
    sorting instruction is registered as the policy's default prompt.

    Args:
        checkpoint_path: Location of the checkpoint, forwarded unchanged to
            ``create_trained_policy``.

    Returns:
        Whatever ``policy_config.create_trained_policy`` returns.
    """
    return policy_config.create_trained_policy(
        config.get_config("pi0_bns"),
        checkpoint_path,
        default_prompt="sort the bolts and the nuts into separate baskets",
    )
def create_observation(
    images,
    joint_positions,
    prompt="sort the bolts and the nuts into separate baskets",
):
    """Create the observation dict for the model.

    Args:
        images: Mapping with the keys ``"high"``, ``"left_wrist"`` and
            ``"right_wrist"``. Each value is expected to be a
            [224, 224, 3] uint8 image; this function does not validate it.
        joint_positions: Robot state, expected to be a [14] float32 vector;
            stored as-is under ``"state"``.
        prompt: Language instruction stored under ``"prompt"``. Defaults to
            the bolt/nut sorting instruction so existing two-argument calls
            behave exactly as before.

    Returns:
        A dict with the keys ``"images"`` (camera name -> image),
        ``"state"`` and ``"prompt"``.

    Raises:
        KeyError: If ``images`` lacks one of the three camera keys.
    """
    return {
        "images": {
            "cam_high": images["high"],  # [224, 224, 3] uint8
            "cam_left_wrist": images["left_wrist"],  # [224, 224, 3] uint8
            "cam_right_wrist": images["right_wrist"],  # [224, 224, 3] uint8
        },
        "state": joint_positions,  # [14] float32
        "prompt": prompt,
    }
# Example usage
if __name__ == "__main__":
    # Load the policy from the ./checkpoint path.
    policy = load_model("./checkpoint")

    # Create a dummy observation: one random uint8 image per camera.
    # randint's upper bound is exclusive, so 256 is required for the
    # samples to cover the full 0..255 uint8 range.
    images = {
        camera: np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        for camera in ("high", "left_wrist", "right_wrist")
    }
    joint_positions = np.random.randn(14).astype(np.float32)
    obs = create_observation(images, joint_positions)

    # Get actions
    result = policy.infer(obs)
    # Expected shape [50, 14]: 50 timesteps of 14-DoF actions.
    actions = result["actions"]
    print(f"Generated actions shape: {actions.shape}")