|
""" |
|
Example usage of the Pi-0 bolt/nut sorting model: load a trained policy from a
checkpoint and run inference on a single observation.
|
""" |
|
|
|
from openpi.policies import policy_config |
|
from openpi.training import config |
|
import numpy as np |
|
|
|
def load_model(
    checkpoint_path: str,
    config_name: str = "pi0_bns",
    default_prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Load a trained Pi-0 policy for the bolt/nut sorting task.

    Args:
        checkpoint_path: Path to the trained checkpoint, passed unchanged to
            ``policy_config.create_trained_policy``.
        config_name: Name of the training config looked up with
            ``config.get_config``. Defaults to ``"pi0_bns"``, the config this
            example was written for.
        default_prompt: Forwarded to ``create_trained_policy`` as its
            ``default_prompt`` argument. NOTE(review): presumably used only
            when an observation carries no ``"prompt"`` key; confirm against
            openpi.

    Returns:
        The policy object returned by ``policy_config.create_trained_policy``.
    """
    train_config = config.get_config(config_name)
    return policy_config.create_trained_policy(
        train_config,
        checkpoint_path,
        default_prompt=default_prompt,
    )
|
|
|
def create_observation(
    images,
    joint_positions,
    prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Build the observation dict passed to ``policy.infer``.

    Args:
        images: Mapping that must contain the keys ``"high"``,
            ``"left_wrist"`` and ``"right_wrist"``. Each value is forwarded
            unchanged under the matching ``"cam_*"`` key.
        joint_positions: Robot state, forwarded unchanged as ``"state"``.
        prompt: Language instruction stored under ``"prompt"``. Defaults to
            the bolt/nut sorting instruction.

    Returns:
        A dict with ``"images"``, ``"state"`` and ``"prompt"`` entries.

    Raises:
        KeyError: If any required camera key is missing from ``images``.
    """
    # Source key in `images` -> camera name in the observation dict.
    camera_keys = {
        "high": "cam_high",
        "left_wrist": "cam_left_wrist",
        "right_wrist": "cam_right_wrist",
    }
    missing = [key for key in camera_keys if key not in images]
    if missing:
        # Report every missing camera at once instead of failing on the first.
        raise KeyError(f"images is missing required camera key(s): {missing}")
    return {
        "images": {cam: images[src] for src, cam in camera_keys.items()},
        "state": joint_positions,
        "prompt": prompt,
    }
|
|
|
|
|
def main() -> None:
    """Run one inference step on random inputs and print the action shape."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the Pi-0 bolt/nut sort model on random inputs."
    )
    parser.add_argument(
        "--checkpoint",
        default="./checkpoint",
        help="Path to the trained checkpoint (default: ./checkpoint).",
    )
    args = parser.parse_args()

    policy = load_model(args.checkpoint)

    rng = np.random.default_rng()
    # `high` is exclusive, so 256 is required to cover the full uint8 range
    # [0, 255]; the previous randint(0, 255, ...) never produced 255.
    # NOTE(review): openpi's ALOHA example inputs use channel-first
    # (3, 224, 224) images; confirm which layout the pi0_bns input transform
    # expects before relying on this (224, 224, 3) shape.
    images = {
        name: rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
        for name in ("high", "left_wrist", "right_wrist")
    }
    # Random placeholder state vector of length 14.
    joint_positions = rng.standard_normal(14).astype(np.float32)

    obs = create_observation(images, joint_positions)
    result = policy.infer(obs)
    actions = result["actions"]
    print(f"Generated actions shape: {actions.shape}")


if __name__ == "__main__":
    main()
|
|