Upload inference_example.py

Updated with better, fully working image embeddings.

inference_example.py  +157 -91  (CHANGED)
@@ -1,12 +1,27 @@
-
-
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""OpenCUA-7B EXL2 — Standalone Visual Inference (Streaming)
+Tested On exllamav2 0.3.2, python3.12.9, torch 2.6.0+cu126
+- Applies a minimal, safe monkey-patch so ExLlamaV2 knows how to wire the
+  OpenCUA EXL2 architecture (Qwen2.5-style vision tower + Llama-like LM).
+- Keeps vision RoPE active (DO NOT neutralize positional embeddings).
+- Chooses a valid 1-D RoPE style if available (LLAMA > HF > default).
+- Loads model + vision tower, extracts EXL2 image embeddings.
+- Builds a chat-style prompt with the image alias and user instruction.
+- Streams tokens using ExLlamaV2DynamicGenerator / DynamicJob."""
+# ------------------ CONFIG ------------------
+MODEL_PATH = r"C:\Users\44741\Desktop\OpenCUA-7B-exl2"
+IMAGE_URL = "http://images.cocodataset.org/val2017/000000001584.jpg"
+INSTRUCTION = "Describe in detail everything you can see."
+MAX_NEW_TOKENS = 600
+# --------------------------------------------
+import sys
+import traceback
+import torch
+from PIL import Image
+import requests
 # ====================================================================
-# --- MONKEY-PATCH FOR OPENCUA ARCHITECTURE ---
-# This block must come BEFORE any other exllamav2 imports.
-# It injects our custom model profile into the exllamav2 library at runtime.
-# make sure you have a version of exlllamav2 that
-# has exllamav2/vlm/vision_tower.py
-
+# --- MONKEY-PATCH FOR OPENCUA ARCHITECTURE (EXL2) ---
 from exllamav2.architecture import (
     ExLlamaV2ArchParams,
     RopeStyle,

@@ -18,30 +33,26 @@ from exllamav2.architecture import (

 print(" -- Applying OpenCUA architecture monkey-patch for inference...")

-
-
+_original_arch_init = ExLlamaV2ArchParams.__init__
+
+def _patched_arch_init(self, arch_string, read_config):
+    # Always call original first
+    _original_arch_init(self, arch_string, read_config)

-#
-def patched_init(self, arch_string, read_config):
-
-    # --- Our Custom Logic ---
+    # Then apply OpenCUA wiring if we detect the architecture string
     if arch_string == "OpenCUAForConditionalGeneration":
-
-
-
-
-        # --- Language Model settings ---
+        print(" -- Found OpenCUA architecture, applying keys & RoPE settings...")
+
+        # Language model keys
         self.lm_prefix = "language_model."
-        self.lm.layer_keys +=
-            layer_keys_llama_norms +
-
-
-        self.lm.expect_keys += \
-            expect_keys_llama
+        self.lm.layer_keys += (
+            layer_keys_llama_norms + layer_keys_llama_attn + layer_keys_llama_mlp
+        )
+        self.lm.expect_keys += expect_keys_llama
         self.lm.attention_bias_qkv = True
         self.lm.supports_tp = True

-        #
+        # Vision tower keys (Qwen2.5-style)
         self.vt_prefix = "vision_tower."
         read_config["vision_config"].update({"model_type": "qwen2.5"})
         self.vt.keys.update({

@@ -63,10 +74,22 @@ def patched_init(self, arch_string, read_config):
         self.vt.attention_bias_o = True
         self.vt.vision_input_norm = False
         self.vt.vision_conv3d = True
-        self.vt.rope_style = RopeStyle.NONE
-        self.vt.mlp_merger = True

-        #
+        # IMPORTANT: Do NOT set RopeStyle.NONE; keep a valid 1-D RoPE if available
+        try:
+            if hasattr(RopeStyle, "LLAMA"):
+                self.vt.rope_style = RopeStyle.LLAMA
+            elif hasattr(RopeStyle, "HF"):
+                self.vt.rope_style = RopeStyle.HF
+            else:
+                # leave library default (works for Qwen2.5 vision)
+                pass
+        except Exception:
+            # In case some older exllamav2 builds behave differently
+            pass
+
+        # Vision merger/projection
+        self.vt.mlp_merger = True
         self.mmp_prefix = "vision_tower.merger."
         self.mmp.keys.update({
             "mlp_gate": None,

@@ -78,21 +101,13 @@ def patched_init(self, arch_string, read_config):
         self.mmp.mlp_act_func = "gelu"
         self.mmp.mlp_bias = True
         self.mmp.norm = "layernorm"
-
-    # --- Fallback to Original ---
-    else:
-        # If it's not our model, call the original __init__ method
-        original_init(self, arch_string, read_config)
-
-# Overwrite the class's __init__ method with our patched version
-ExLlamaV2ArchParams.__init__ = patched_init
-print(" -- Patch applied successfully.")

-#
+# Install patch
+ExLlamaV2ArchParams.__init__ = _patched_arch_init
+print(" -- Patch applied successfully.")
 # ====================================================================

-
-# NOW we can import the rest of the library
+# Now we can import the rest of the library
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,

@@ -100,54 +115,105 @@ from exllamav2 import (
     ExLlamaV2Tokenizer,
     ExLlamaV2VisionTower,
 )
-from exllamav2.generator import
-... (rest of the old generator import; not rendered in this view)
+from exllamav2.generator import (
+    ExLlamaV2DynamicGenerator,
+    ExLlamaV2Sampler,
+    ExLlamaV2DynamicJob,  # <-- for streaming
+)

-... (old inference body, ~46 lines; not rendered in this view)
+def main():
+    try:
+        print(" -- Loading model/config...")
+        config = ExLlamaV2Config(MODEL_PATH)  # Patch is applied during this call
+        # Optionally increase context if your EXL2 export supports it
+        # config.max_seq_len = 8192
+
+        model = ExLlamaV2(config)
+        cache = ExLlamaV2Cache(model, lazy=True)
+        model.load_autosplit(cache)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
+
+        print(" -- Loading vision tower...")
+        vision_tower = ExLlamaV2VisionTower(config)
+        vision_tower.load()
+        try:
+            print(f"[Debug] vt.rope_style = {getattr(vision_tower, 'rope_style', 'n/a')}")
+        except Exception:
+            pass
+
+        generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
+
+        print(f" -- Downloading test image from: {IMAGE_URL}")
+        image = Image.open(requests.get(IMAGE_URL, stream=True).raw).convert("RGB")
+
+        print(" -- Processing image and building prompt...")
+        image_embeddings = vision_tower.get_image_embeddings(model, tokenizer, image)
+
+        # Newline-separated alias is fine; here we have a single image
+        placeholders = image_embeddings.text_alias
+        prompt = (
+            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+            f"<|im_start|>user\n{placeholders}\n{INSTRUCTION}<|im_end|>\n"
+            f"<|im_start|>assistant\n"
+        )
+
+        # Preview (mask the raw alias for readability)
+        print("\n--- Prompt Sent to Model ---")
+        print(prompt.replace(image_embeddings.text_alias, "<image>"))
+        print("----------------------------\n")
+
+        # ---------------- STREAMING OUTPUT ----------------
+        print("--- Model Output (streaming) ---")
+        gen_settings = ExLlamaV2Sampler.Settings.greedy()
+
+        # 1) Build input ids with image embeddings
+        input_ids = tokenizer.encode(
+            prompt,
+            add_bos=True,
+            encode_special_tokens=True,
+            embeddings=[image_embeddings]  # ensure the alias binds correctly
+        )
+
+        # 2) Create a streaming job
+        job = ExLlamaV2DynamicJob(
+            input_ids=input_ids,
+            max_new_tokens=MAX_NEW_TOKENS,
+            decode_special_tokens=False,  # keep consistent
+            gen_settings=gen_settings,
+            embeddings=[image_embeddings],  # pass embeddings here as well
+        )
+
+        # 3) Enqueue, then iterate results as they arrive
+        generator.enqueue(job)
+
+        final_text = []
+        try:
+            while generator.num_remaining_jobs():
+                results = generator.iterate()
+                for r in results:
+                    chunk = r.get("text", "")
+                    if chunk:
+                        print(chunk, end="", flush=True)
+                        final_text.append(chunk)
+        finally:
+            print("\n\n--- Test Complete ---")
+
+        # If you want the full output string:
+        full_output = "".join(final_text)
+        # print("\n[DEBUG] Full output:\n", full_output)
+        # ---------------------------------------------------
+
+    except Exception as e:
+        print(f"\nAn error occurred: {e}")
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    # Small CUDA perf niceties (safe to ignore if CPU)
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+    except Exception:
+        pass
+
+    main()
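
A note on multi-image prompts: the comment next to placeholders says newline-separated aliases are fine, but the script only uses one image. Below is a minimal sketch of the two-image case, reusing the vision_tower, model, tokenizer and IMAGE_URL names from the script; the second URL and the comparison instruction are placeholders, not part of this commit.

# Hypothetical two-image prompt, reusing objects loaded in main()
from PIL import Image
import requests

urls = [IMAGE_URL, IMAGE_URL]  # substitute two distinct image URLs here
images = [Image.open(requests.get(u, stream=True).raw).convert("RGB") for u in urls]

# One embedding object per image, each carrying its own text alias
embeds = [vision_tower.get_image_embeddings(model, tokenizer, img) for img in images]

# Newline-separate the aliases so each image gets its own placeholder slot
placeholders = "\n".join(e.text_alias for e in embeds)
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    f"<|im_start|>user\n{placeholders}\nCompare the two images.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Pass the full list wherever the single-image version passed [image_embeddings]
input_ids = tokenizer.encode(
    prompt, add_bos=True, encode_special_tokens=True, embeddings=embeds
)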
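The script decodes greedily via ExLlamaV2Sampler.Settings.greedy(). If sampled output or an explicit stop condition is wanted, a sketch along these lines should drop into the same spot; the Settings attributes are the standard exllamav2 sampler fields, but the concrete values are illustrative only.

# Illustrative sampler settings in place of greedy decoding
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.temperature = 0.7
gen_settings.top_k = 50
gen_settings.top_p = 0.9
gen_settings.token_repetition_penalty = 1.05

job = ExLlamaV2DynamicJob(
    input_ids=input_ids,
    max_new_tokens=MAX_NEW_TOKENS,
    gen_settings=gen_settings,
    stop_conditions=[tokenizer.eos_token_id, "<|im_end|>"],  # stop at EOS or the chat end tag
    embeddings=[image_embeddings],
)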
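If streaming is not required, the dynamic generator also offers a one-shot generate() call. A minimal sketch under the same setup; the argument names follow the exllamav2 dynamic-generator examples, so verify them against the installed 0.3.x version before relying on this.

# One-shot (non-streaming) generation, reusing prompt / image_embeddings from main()
output = generator.generate(
    prompt=prompt,
    max_new_tokens=MAX_NEW_TOKENS,
    add_bos=True,
    encode_special_tokens=True,
    decode_special_tokens=False,
    gen_settings=ExLlamaV2Sampler.Settings.greedy(),
    embeddings=[image_embeddings],
    completion_only=True,  # if supported by your version; returns only the new tokens
)
print(output)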