sujitvasanth committed
Commit eb78691 · verified · 1 Parent(s): 6c5010b

Upload inference_example.py


Updated with better, fully working image embeddings.
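In outline, the updated script binds the image embeddings to the prompt as sketched below (a condensed, illustrative excerpt of the flow in the new inference_example.py; the instruction string is a placeholder, and the already-loaded model, cache, tokenizer, vision_tower and generator objects are assumed rather than shown):

# Sketch only: condensed from the new inference_example.py in this commit.
# Assumes model, cache, tokenizer, vision_tower and generator are already loaded.
image_embeddings = vision_tower.get_image_embeddings(model, tokenizer, image)
prompt = (
    "<|im_start|>user\n"
    f"{image_embeddings.text_alias}\nDescribe the image.<|im_end|>\n"  # illustrative instruction
    "<|im_start|>assistant\n"
)
# The alias is resolved against the embeddings both when encoding the prompt...
input_ids = tokenizer.encode(prompt, add_bos=True, encode_special_tokens=True,
                             embeddings=[image_embeddings])
# ...and when the job runs, so the vision features actually reach the model.
job = ExLlamaV2DynamicJob(input_ids=input_ids, max_new_tokens=200,
                          gen_settings=ExLlamaV2Sampler.Settings.greedy(),
                          embeddings=[image_embeddings])
generator.enqueue(job)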

Files changed (1)
inference_example.py  +157 -91
inference_example.py CHANGED
@@ -1,12 +1,27 @@
-import sys, torch
-
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""OpenCUA-7B EXL2 — Standalone Visual Inference (Streaming)
+Tested On exllamav2 0.3.2, python3.12.9, torch 2.6.0+cu126
+- Applies a minimal, safe monkey-patch so ExLlamaV2 knows how to wire the
+  OpenCUA EXL2 architecture (Qwen2.5-style vision tower + Llama-like LM).
+- Keeps vision RoPE active (DO NOT neutralize positional embeddings).
+- Chooses a valid 1-D RoPE style if available (LLAMA > HF > default).
+- Loads model + vision tower, extracts EXL2 image embeddings.
+- Builds a chat-style prompt with the image alias and user instruction.
+- Streams tokens using ExLlamaV2DynamicGenerator / DynamicJob."""
+# ------------------ CONFIG ------------------
+MODEL_PATH = r"C:\Users\44741\Desktop\OpenCUA-7B-exl2"
+IMAGE_URL = "http://images.cocodataset.org/val2017/000000001584.jpg"
+INSTRUCTION = "Describe in detail everything you can see."
+MAX_NEW_TOKENS = 600
+# --------------------------------------------
+import sys
+import traceback
+import torch
+from PIL import Image
+import requests
 # ====================================================================
-# --- MONKEY-PATCH FOR OPENCUA ARCHITECTURE ---
-# This block must come BEFORE any other exllamav2 imports.
-# It injects our custom model profile into the exllamav2 library at runtime.
-# make sure you have a version of exlllamav2 that
-# has exllamav2/vlm/vision_tower.py
-
+# --- MONKEY-PATCH FOR OPENCUA ARCHITECTURE (EXL2) ---
 from exllamav2.architecture import (
     ExLlamaV2ArchParams,
     RopeStyle,
@@ -18,30 +33,26 @@ from exllamav2.architecture import (
 
 print(" -- Applying OpenCUA architecture monkey-patch for inference...")
 
-# Store a reference to the original __init__ method
-original_init = ExLlamaV2ArchParams.__init__
+_original_arch_init = ExLlamaV2ArchParams.__init__
+
+def _patched_arch_init(self, arch_string, read_config):
+    # Always call original first
+    _original_arch_init(self, arch_string, read_config)
 
-# Define our new, patched __init__ method
-def patched_init(self, arch_string, read_config):
-
-    # --- Our Custom Logic ---
+    # Then apply OpenCUA wiring if we detect the architecture string
     if arch_string == "OpenCUAForConditionalGeneration":
-
-        # This is our entire custom profile, verified from our debugging
-        arch_recognized = True
-
-        # --- Language Model settings ---
+        print(" -- Found OpenCUA architecture, applying keys & RoPE settings...")
+
+        # Language model keys
         self.lm_prefix = "language_model."
-        self.lm.layer_keys += \
-            layer_keys_llama_norms + \
-            layer_keys_llama_attn + \
-            layer_keys_llama_mlp
-        self.lm.expect_keys += \
-            expect_keys_llama
+        self.lm.layer_keys += (
+            layer_keys_llama_norms + layer_keys_llama_attn + layer_keys_llama_mlp
+        )
+        self.lm.expect_keys += expect_keys_llama
         self.lm.attention_bias_qkv = True
         self.lm.supports_tp = True
 
-        # --- Vision Tower settings ---
+        # Vision tower keys (Qwen2.5-style)
         self.vt_prefix = "vision_tower."
         read_config["vision_config"].update({"model_type": "qwen2.5"})
         self.vt.keys.update({
@@ -63,10 +74,22 @@ def patched_init(self, arch_string, read_config):
         self.vt.attention_bias_o = True
         self.vt.vision_input_norm = False
         self.vt.vision_conv3d = True
-        self.vt.rope_style = RopeStyle.NONE
-        self.vt.mlp_merger = True
 
-        # --- Multi-Modal Projector/Merger settings ---
+        # IMPORTANT: Do NOT set RopeStyle.NONE; keep a valid 1-D RoPE if available
+        try:
+            if hasattr(RopeStyle, "LLAMA"):
+                self.vt.rope_style = RopeStyle.LLAMA
+            elif hasattr(RopeStyle, "HF"):
+                self.vt.rope_style = RopeStyle.HF
+            else:
+                # leave library default (works for Qwen2.5 vision)
+                pass
+        except Exception:
+            # In case some older exllamav2 builds behave differently
+            pass
+
+        # Vision merger/projection
+        self.vt.mlp_merger = True
         self.mmp_prefix = "vision_tower.merger."
         self.mmp.keys.update({
             "mlp_gate": None,
@@ -78,21 +101,13 @@ def patched_init(self, arch_string, read_config):
         self.mmp.mlp_act_func = "gelu"
         self.mmp.mlp_bias = True
         self.mmp.norm = "layernorm"
-
-    # --- Fallback to Original ---
-    else:
-        # If it's not our model, call the original __init__ method
-        original_init(self, arch_string, read_config)
-
-# Overwrite the class's __init__ method with our patched version
-ExLlamaV2ArchParams.__init__ = patched_init
-print(" -- Patch applied successfully.")
 
-# --- END OF MONKEY-PATCH ---
+# Install patch
+ExLlamaV2ArchParams.__init__ = _patched_arch_init
+print(" -- Patch applied successfully.")
 # ====================================================================
 
-
-# NOW we can import the rest of the library
+# Now we can import the rest of the library
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
@@ -100,54 +115,105 @@ from exllamav2 import (
     ExLlamaV2Tokenizer,
     ExLlamaV2VisionTower,
 )
-from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
-from PIL import Image
-import requests
-import traceback
+from exllamav2.generator import (
+    ExLlamaV2DynamicGenerator,
+    ExLlamaV2Sampler,
+    ExLlamaV2DynamicJob, # <-- for streaming
+)
 
-MODEL_PATH = "/home/sujit/OpenCUA-7B-exl2"
-IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
-
-try:
-    print(" -- Loading model...")
-    config = ExLlamaV2Config(MODEL_PATH) # <-- The patch is active when this line runs
-    model = ExLlamaV2(config)
-    cache = ExLlamaV2Cache(model, lazy=True)
-    model.load_autosplit(cache)
-    tokenizer = ExLlamaV2Tokenizer(config)
-
-    print(" -- Loading vision tower...")
-    vision_tower = ExLlamaV2VisionTower(config)
-    vision_tower.load()
-
-    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
-
-    print(f" -- Downloading test image from: {IMAGE_URL}")
-    image = Image.open(requests.get(IMAGE_URL, stream=True).raw).convert("RGB")
-    instruction = "Describe what you see in this image in detail."
-
-    print(" -- Processing image and building prompt...")
-    image_embeddings = vision_tower.get_image_embeddings(model, tokenizer, image)
-
-    prompt = f"<|user|>\n{image_embeddings.text_alias}\n{instruction}<|end|>\n<|assistant|>"
-
-    print(f"\n--- Prompt Sent to Model ---\n{prompt.replace(image_embeddings.text_alias, '<image>')}\n----------------------------")
-    print("\n--- Model Output ---")
-
-    gen_settings = ExLlamaV2Sampler.Settings.greedy()
-
-    output = generator.generate(
-        prompt=prompt,
-        max_new_tokens=200,
-        add_bos=True,
-        embeddings=[image_embeddings],
-        gen_settings=gen_settings,
-        decode_special_tokens=True,
-    )
-
-    print(output)
-    print("\n--- Test Complete ---")
-
-except Exception as e:
-    print(f"\nAn error occurred: {e}")
-    traceback.print_exc()
+def main():
+    try:
+        print(" -- Loading model/config...")
+        config = ExLlamaV2Config(MODEL_PATH) # Patch is applied during this call
+        # Optionally increase context if your EXL2 export supports it
+        # config.max_seq_len = 8192
+
+        model = ExLlamaV2(config)
+        cache = ExLlamaV2Cache(model, lazy=True)
+        model.load_autosplit(cache)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
+
+        print(" -- Loading vision tower...")
+        vision_tower = ExLlamaV2VisionTower(config)
+        vision_tower.load()
+        try:
+            print(f"[Debug] vt.rope_style = {getattr(vision_tower, 'rope_style', 'n/a')}")
+        except Exception:
+            pass
+
+        generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
+
+        print(f" -- Downloading test image from: {IMAGE_URL}")
+        image = Image.open(requests.get(IMAGE_URL, stream=True).raw).convert("RGB")
+
+        print(" -- Processing image and building prompt...")
+        image_embeddings = vision_tower.get_image_embeddings(model, tokenizer, image)
+
+        # Newline-separated alias is fine; here we have a single image
+        placeholders = image_embeddings.text_alias
+        prompt = (
+            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+            f"<|im_start|>user\n{placeholders}\n{INSTRUCTION}<|im_end|>\n"
+            f"<|im_start|>assistant\n"
+        )
+
+        # Preview (mask the raw alias for readability)
+        print("\n--- Prompt Sent to Model ---")
+        print(prompt.replace(image_embeddings.text_alias, "<image>"))
+        print("----------------------------\n")
+
+        # ---------------- STREAMING OUTPUT ----------------
+        print("--- Model Output (streaming) ---")
+        gen_settings = ExLlamaV2Sampler.Settings.greedy()
+
+        # 1) Build input ids with image embeddings
+        input_ids = tokenizer.encode(
+            prompt,
+            add_bos=True,
+            encode_special_tokens=True,
+            embeddings=[image_embeddings] # ensure the alias binds correctly
+        )
+
+        # 2) Create a streaming job
+        job = ExLlamaV2DynamicJob(
+            input_ids=input_ids,
+            max_new_tokens=MAX_NEW_TOKENS,
+            decode_special_tokens=False, # keep consistent
+            gen_settings=gen_settings,
+            embeddings=[image_embeddings], # pass embeddings here as well
+        )
+
+        # 3) Enqueue, then iterate results as they arrive
+        generator.enqueue(job)
+
+        final_text = []
+        try:
+            while generator.num_remaining_jobs():
+                results = generator.iterate()
+                for r in results:
+                    chunk = r.get("text", "")
+                    if chunk:
+                        print(chunk, end="", flush=True)
+                        final_text.append(chunk)
+        finally:
+            print("\n\n--- Test Complete ---")
+
+        # If you want the full output string:
+        full_output = "".join(final_text)
+        # print("\n[DEBUG] Full output:\n", full_output)
+        # ---------------------------------------------------
+
+    except Exception as e:
+        print(f"\nAn error occurred: {e}")
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    # Small CUDA perf niceties (safe to ignore if CPU)
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+    except Exception:
+        pass
+
+    main()
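For reference, the patch installed at the top of the new script is a plain wrap-and-delegate monkey-patch: keep a handle on the original ExLlamaV2ArchParams.__init__, always call it first, then overlay the OpenCUA-specific wiring only when the architecture string matches. A minimal sketch of that pattern is shown below (the OpenCUA body is abbreviated; the patch has to be in place before ExLlamaV2Config(MODEL_PATH) is constructed, which is why the script applies it before the other exllamav2 imports):

# Sketch of the wrap-and-delegate pattern used above; body abbreviated.
from exllamav2.architecture import ExLlamaV2ArchParams

_original_init = ExLlamaV2ArchParams.__init__

def _patched_init(self, arch_string, read_config):
    _original_init(self, arch_string, read_config)       # delegate to the stock profiles first
    if arch_string == "OpenCUAForConditionalGeneration":
        self.lm_prefix = "language_model."                # then overlay the OpenCUA keys
        self.vt_prefix = "vision_tower."
        # ... remaining vision tower / merger settings as in the diff above ...

ExLlamaV2ArchParams.__init__ = _patched_init              # install the patch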