OpenFly
OpenFly is a platform comprising a versatile toolchain and a large-scale benchmark for aerial vision-language navigation (VLN). The code is built purely on Hugging Face, concise, and efficient.
For full details, please read our paper and see our project page.
OpenFly relies solely on Hugging Face Transformers 🤗, making deployment extremely easy. If your environment supports transformers >= 4.47.0, you can load the model and run inference directly with the following code.
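To confirm the version floor first, here is a minimal sanity check (a sketch; it assumes the `packaging` library, which Transformers itself depends on, is installed):

```python
# Verify the transformers version floor mentioned above.
from packaging.version import Version
import transformers

assert Version(transformers.__version__) >= Version("4.47.0"), (
    f"transformers {transformers.__version__} found; >= 4.47.0 required"
)
```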
```python
import cv2
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

# Custom classes shipped with the OpenFly repository.
from extern.hf.configuration_prismatic import OpenFlyConfig
from extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Register the OpenFly classes so the Auto* factories can resolve the
# "openvla" model type declared in the checkpoint's config.
AutoConfig.register("openvla", OpenFlyConfig)
AutoImageProcessor.register(OpenFlyConfig, PrismaticImageProcessor)
AutoProcessor.register(OpenFlyConfig, PrismaticProcessor)
AutoModelForVision2Seq.register(OpenFlyConfig, OpenVLAForActionPrediction)
model_name_or_path = "IPEC-COMMUNITY/openfly-agent-7b"
processor = AutoProcessor.from_pretrained(model_name_or_path)
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to("cuda:0")
# cv2.imread returns BGR; convert to RGB before wrapping in a PIL image.
image = Image.fromarray(cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB))
prompt = "Take off, go straight pass the river"
inputs = processor(prompt, [image, image, image]).to("cuda:0", dtype=torch.bfloat16)
action = model.predict_action(**inputs, unnorm_key="vln_norm", do_sample=False)
print(action)
```
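`predict_action` returns the predicted action, un-normalized via `unnorm_key="vln_norm"`. In a closed-loop setting you would re-run it on every new observation. The sketch below is illustrative only: `get_camera_frame` and `execute` are hypothetical stand-ins for your simulator or drone API, and the three-frame history mirrors the single-shot example above (an assumption, not a documented requirement).

```python
import torch
from collections import deque
from PIL import Image

def get_camera_frame() -> Image.Image:
    """Hypothetical stand-in: fetch the latest RGB frame from your simulator/drone."""
    raise NotImplementedError

def execute(action) -> None:
    """Hypothetical stand-in: apply the predicted action to the vehicle."""
    raise NotImplementedError

def run_episode(model, processor, prompt, num_steps=50):
    # Keep the three most recent frames and query the model once per control step.
    frames = deque(maxlen=3)
    for _ in range(num_steps):
        frames.append(get_camera_frame())
        history = list(frames)
        while len(history) < 3:  # pad by repeating the earliest frame at the start
            history.insert(0, history[0])
        inputs = processor(prompt, history).to("cuda:0", dtype=torch.bfloat16)
        action = model.predict_action(**inputs, unnorm_key="vln_norm", do_sample=False)
        execute(action)
```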
Base model: openvla/openvla-7b-prismatic