#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""OpenCUA-7B EXL2 — Standalone Visual Inference (Streaming)
Tested On exllamav2 0.3.2, python3.12.9, torch 2.6.0+cu126
- Applies a minimal, safe monkey-patch so ExLlamaV2 knows how to wire the
OpenCUA EXL2 architecture (Qwen2.5-style vision tower + Llama-like LM).
- Keeps vision RoPE active (DO NOT neutralize positional embeddings).
- Chooses a valid 1-D RoPE style if available (LLAMA > HF > default).
- Loads model + vision tower, extracts EXL2 image embeddings.
- Builds a chat-style prompt with the image alias and user instruction.
- Streams tokens using ExLlamaV2DynamicGenerator / DynamicJob."""
# ------------------ CONFIG ------------------
MODEL_PATH = r"C:\Users\44741\Desktop\OpenCUA-7B-exl2"
IMAGE_URL = "http://images.cocodataset.org/val2017/000000001584.jpg"
INSTRUCTION = "Describe in detail everything you can see."
MAX_NEW_TOKENS = 600
# --------------------------------------------
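# The values above are the author's example setup. Point MODEL_PATH at your own
# local EXL2 export of OpenCUA-7B (the Windows path above and any alternative,
# e.g. a hypothetical "/models/OpenCUA-7B-exl2", are just placeholders), and
# swap IMAGE_URL / INSTRUCTION as needed.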
import sys
import traceback
import torch
from PIL import Image
import requests
# ====================================================================
# --- MONKEY-PATCH FOR OPENCUA ARCHITECTURE (EXL2) ---
from exllamav2.architecture import (
    ExLlamaV2ArchParams,
    RopeStyle,
    layer_keys_llama_norms,
    layer_keys_llama_attn,
    layer_keys_llama_mlp,
    expect_keys_llama,
)

print(" -- Applying OpenCUA architecture monkey-patch for inference...")

_original_arch_init = ExLlamaV2ArchParams.__init__

def _patched_arch_init(self, arch_string, read_config):
    # Always call original first
    _original_arch_init(self, arch_string, read_config)
    # Then apply OpenCUA wiring if we detect the architecture string
    if arch_string == "OpenCUAForConditionalGeneration":
        print(" -- Found OpenCUA architecture, applying keys & RoPE settings...")
        # Language model keys
        self.lm_prefix = "language_model."
        self.lm.layer_keys += (
            layer_keys_llama_norms + layer_keys_llama_attn + layer_keys_llama_mlp
        )
        self.lm.expect_keys += expect_keys_llama
        self.lm.attention_bias_qkv = True
        self.lm.supports_tp = True
        # Vision tower keys (Qwen2.5-style)
        self.vt_prefix = "vision_tower."
        read_config["vision_config"].update({"model_type": "qwen2.5"})
        self.vt.keys.update({
            "fused_qkv": ".attn.qkv",
            "attn_o": ".attn.proj",
            "mlp_gate": ".mlp.gate_proj",
            "mlp_up": ".mlp.up_proj",
            "mlp_down": ".mlp.down_proj",
            "norm_1": ".norm1",
            "norm_2": ".norm2",
            "layers": "blocks",
            "patch_conv": "patch_embed.proj",
        })
        self.vt.mlp_gate = True
        self.vt.mlp_act_func = "silu"
        self.vt.norm = "rmsnorm"
        self.vt.mlp_bias = True
        self.vt.attention_bias_qkv = True
        self.vt.attention_bias_o = True
        self.vt.vision_input_norm = False
        self.vt.vision_conv3d = True
        # IMPORTANT: Do NOT set RopeStyle.NONE; keep a valid 1-D RoPE if available
        try:
            if hasattr(RopeStyle, "LLAMA"):
                self.vt.rope_style = RopeStyle.LLAMA
            elif hasattr(RopeStyle, "HF"):
                self.vt.rope_style = RopeStyle.HF
            else:
                # leave library default (works for Qwen2.5 vision)
                pass
        except Exception:
            # In case some older exllamav2 builds behave differently
            pass
        # Vision merger/projection
        self.vt.mlp_merger = True
        self.mmp_prefix = "vision_tower.merger."
        self.mmp.keys.update({
            "mlp_gate": None,
            "mlp_up": "mlp.0",
            "mlp_down": "mlp.2",
            "norm_2": "ln_q",
        })
        self.mmp.mlp_gate = False
        self.mmp.mlp_act_func = "gelu"
        self.mmp.mlp_bias = True
        self.mmp.norm = "layernorm"

# Install patch
ExLlamaV2ArchParams.__init__ = _patched_arch_init
print(" -- Patch applied successfully.")
# ====================================================================
# Now we can import the rest of the library
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2VisionTower,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2Sampler,
    ExLlamaV2DynamicJob,  # <-- for streaming
)

def main():
    try:
        print(" -- Loading model/config...")
        config = ExLlamaV2Config(MODEL_PATH)  # Patch is applied during this call
        # Optionally increase context if your EXL2 export supports it
        # config.max_seq_len = 8192
        model = ExLlamaV2(config)
        cache = ExLlamaV2Cache(model, lazy=True)
        model.load_autosplit(cache)
        tokenizer = ExLlamaV2Tokenizer(config)

        print(" -- Loading vision tower...")
        vision_tower = ExLlamaV2VisionTower(config)
        vision_tower.load()
        try:
            print(f"[Debug] vt.rope_style = {getattr(vision_tower, 'rope_style', 'n/a')}")
        except Exception:
            pass

        generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)

        print(f" -- Downloading test image from: {IMAGE_URL}")
        image = Image.open(requests.get(IMAGE_URL, stream=True).raw).convert("RGB")

        print(" -- Processing image and building prompt...")
        image_embeddings = vision_tower.get_image_embeddings(model, tokenizer, image)
        # Newline-separated alias is fine; here we have a single image
        placeholders = image_embeddings.text_alias
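        # With several images, each embedding gets its own alias; a hypothetical
        # multi-image prompt (not exercised in this script) could join them like:
        # placeholders = "\n".join(e.text_alias for e in [emb_a, emb_b])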
        prompt = (
            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n{placeholders}\n{INSTRUCTION}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

        # Preview (mask the raw alias for readability)
        print("\n--- Prompt Sent to Model ---")
        print(prompt.replace(image_embeddings.text_alias, "<image>"))
        print("----------------------------\n")

        # ---------------- STREAMING OUTPUT ----------------
        print("--- Model Output (streaming) ---")
        gen_settings = ExLlamaV2Sampler.Settings.greedy()
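        # Greedy decoding keeps this demo deterministic. For sampled output, a
        # rough sketch (attribute names assume a recent exllamav2 release):
        # gen_settings = ExLlamaV2Sampler.Settings()
        # gen_settings.temperature = 0.7
        # gen_settings.top_p = 0.9
        # gen_settings.top_k = 50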
        # 1) Build input ids with image embeddings
        input_ids = tokenizer.encode(
            prompt,
            add_bos=True,
            encode_special_tokens=True,
            embeddings=[image_embeddings],  # ensure the alias binds correctly
        )

        # 2) Create a streaming job
        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            decode_special_tokens=False,  # don't emit special tokens as output text
            gen_settings=gen_settings,
            embeddings=[image_embeddings],  # pass embeddings here as well
        )

        # 3) Enqueue, then iterate results as they arrive
        generator.enqueue(job)
        final_text = []
        try:
            while generator.num_remaining_jobs():
                results = generator.iterate()
                for r in results:
                    chunk = r.get("text", "")
                    if chunk:
                        print(chunk, end="", flush=True)
                        final_text.append(chunk)
        finally:
            print("\n\n--- Test Complete ---")
            # If you want the full output string:
            full_output = "".join(final_text)
            # print("\n[DEBUG] Full output:\n", full_output)
        # ---------------------------------------------------

    except Exception as e:
        print(f"\nAn error occurred: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Small CUDA perf niceties (safe to skip on CPU-only setups)
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
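        # Assumption: enabling TF32 for cuDNN as well is a harmless extra on
        # Ampere+ GPUs and is covered by the same try/except guard.
        torch.backends.cudnn.allow_tf32 = True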
    except Exception:
        pass
    main()