TraceVLA-7B

TraceVLA-7B is a vision-language-action model obtained by finetuning the base OpenVLA model with the visual trace prompting technique.

Results on SimplerEnv (Fractal and Bridge):

Fractal:

| Policy/Settings | Pick up Coke | Move near | Open/Close Drawer | Put in Drawer | Average Success Rate |
|---|---|---|---|---|---|
| (Visual Matching) OpenVLA-7B | 23.7% | 65.0% | 57.4% | 0.0% | 36.5% |
| (Visual Matching) TraceVLA-7B | 45.0% | 63.8% | 63.1% | 11.1% | 45.8% |
| (Variant Aggregation) OpenVLA-7B | 61.3% | 55.8% | 24.9% | 1.0% | 35.8% |
| (Variant Aggregation) TraceVLA-7B | 64.3% | 60.6% | 61.6% | 12.5% | 49.8% |

Bridge:

| Policy/Settings | Put Spoon | Put Carrot | Stack Block | Put Eggplant | Average Success Rate |
|---|---|---|---|---|---|
| OpenVLA-7B | 8.3% | 8.3% | 4.2% | 45.8% | 16.7% |
| TraceVLA-7B | 12.5% | 16.6% | 16.6% | 65.0% | 27.7% |
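
The "Average Success Rate" column appears to be the unweighted mean of the four per-task success rates in each row. The sketch below recomputes it from the reported numbers and prints the gain of TraceVLA-7B over OpenVLA-7B in percentage points. The names `RESULTS`, `average`, and `summarize` are illustrative and not part of the TraceVLA codebase, and the unweighted mean is an assumption that matches the reported averages up to rounding.

```python
# Per-task success rates (%), copied from the Fractal and Bridge results.
RESULTS = {
    "Fractal (Visual Matching)": {
        "OpenVLA-7B": [23.7, 65.0, 57.4, 0.0],
        "TraceVLA-7B": [45.0, 63.8, 63.1, 11.1],
    },
    "Fractal (Variant Aggregation)": {
        "OpenVLA-7B": [61.3, 55.8, 24.9, 1.0],
        "TraceVLA-7B": [64.3, 60.6, 61.6, 12.5],
    },
    "Bridge": {
        "OpenVLA-7B": [8.3, 8.3, 4.2, 45.8],
        "TraceVLA-7B": [12.5, 16.6, 16.6, 65.0],
    },
}


def average(rates):
    """Unweighted mean of per-task success rates (assumed averaging scheme)."""
    return sum(rates) / len(rates)


def summarize(results):
    """Return {setting: (openvla_avg, tracevla_avg, gain_in_points)}."""
    summary = {}
    for setting, policies in results.items():
        base = average(policies["OpenVLA-7B"])
        trace = average(policies["TraceVLA-7B"])
        summary[setting] = (base, trace, trace - base)
    return summary


for setting, (base, trace, gain) in summarize(RESULTS).items():
    print(f"{setting}: OpenVLA {base:.1f}%, TraceVLA {trace:.1f}%, gain {gain:+.1f} points")
```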

Sample Inference Code

Sample inference code for the TraceVLA-7B model:

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "furonghuang-lab/tracevla_7b"

# Load Processor & VLA
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
    num_crops=1, 
)

vla = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    _attn_implementation='flash_attention_2',
    use_cache=True
).to(device='cuda')

# Load Visual Trace Processor
# cotracker_model_path is the path to your downloaded Co-Tracker checkpoint (scaled_offline.pth)
from prismatic.eval.trace_processor import TraceProcessor
trace_processor = TraceProcessor(cotracker_model_path)

# Grab image input & format prompt
# If the visual trace returned by Co-Tracker is not valid, fall back to the default OpenVLA prompt.
openvla_prompt_template = "In: What action should the robot take to {task_description}?\nOut:"
tracevla_prompt_template = "In: You are given two images: one with the original robot observation, and another one marked with historical traces of the robot end effector and moving objects, separated by a special separator token. What action should the robot take to {task_description}?\nOut:"

image: Image.Image = get_from_camera(...)
# Overlay the historical visual trace; has_trace is False when no valid trace is available
image_overlaid, has_trace = trace_processor.process_image(image)

if not has_trace:
    # No valid trace: use the OpenVLA prompt and pass the original image twice
    prompt = openvla_prompt_template.format(task_description=task_description)
    inputs = processor(prompt, [image, image]).to(device='cuda', dtype=torch.bfloat16)
else:
    # Valid trace: use the TraceVLA prompt with the original and trace-overlaid images
    prompt = tracevla_prompt_template.format(task_description=task_description)
    inputs = processor(prompt, [image, image_overlaid]).to(device='cuda', dtype=torch.bfloat16)

# Predict the action
with torch.inference_mode():
    action = vla.predict_action(**inputs)

# Execute the action
robot.act(action, ...)
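
The prompt fallback in the sample above can be isolated into a small helper so it can be checked without loading the model or Co-Tracker. This is a minimal sketch: the function name `build_prompt_and_images` is hypothetical and not part of the TraceVLA repository, and the images are represented by placeholder strings purely for illustration. The two templates are copied from the sample code.

```python
OPENVLA_PROMPT_TEMPLATE = (
    "In: What action should the robot take to {task_description}?\nOut:"
)
TRACEVLA_PROMPT_TEMPLATE = (
    "In: You are given two images: one with the original robot observation, "
    "and another one marked with historical traces of the robot end effector "
    "and moving objects, separated by a special separator token. "
    "What action should the robot take to {task_description}?\nOut:"
)


def build_prompt_and_images(task_description, image, image_overlaid, has_trace):
    """Hypothetical helper mirroring the if/else branch of the sample code.

    Returns the prompt string and the two-image list to pass to the processor.
    """
    if has_trace:
        # Valid trace: TraceVLA prompt, original image plus trace-overlaid image.
        prompt = TRACEVLA_PROMPT_TEMPLATE.format(task_description=task_description)
        return prompt, [image, image_overlaid]
    # No valid trace: default OpenVLA prompt, original image passed twice.
    prompt = OPENVLA_PROMPT_TEMPLATE.format(task_description=task_description)
    return prompt, [image, image]


# Usage with placeholder strings standing in for PIL images.
prompt, images = build_prompt_and_images(
    "pick up the coke can", "raw_image", "overlaid_image", has_trace=False
)
print(prompt)
print(images)
```

Keeping the two-image input in both branches means the processor always receives the same number of images, which is what the sample code does when it passes `[image, image]` in the fallback case.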

For more examples, including scripts for finetuning TraceVLA models on your own robot demonstration datasets, check out our repository.

Citation

If you find our code or models useful in your work, please cite our paper:

@misc{zheng2024tracevlavisualtraceprompting,
      title={TraceVLA: Visual Trace Prompting Enhances Spatial-Temporal Awareness for Generalist Robotic Policies}, 
      author={Ruijie Zheng and Yongyuan Liang and Shuaiyi Huang and Jianfeng Gao and Hal Daumé III and Andrey Kolobov and Furong Huang and Jianwei Yang},
      year={2024},
      eprint={2412.10345},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2412.10345}, 
}

Model size: 7.54B params (Safetensors)
Tensor type: BF16