alexnasa committed on
Commit 285eb4b · verified · 1 Parent(s): 0087df7

Update inference_coz_single.py

Files changed (1)
  1. inference_coz_single.py +59 -75
inference_coz_single.py CHANGED
@@ -25,66 +25,77 @@ def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
  # Helper: Generate a single VLM prompt for recursive_multiscale
  # -------------------------------------------------------------------
  def _generate_vlm_prompt(
- vlm_model,
- vlm_processor,
- process_vision_info,
- prev_image_path: str,
- zoomed_image_path: str,
+ vlm_model: Qwen2_5_VLForConditionalGeneration,
+ vlm_processor: AutoProcessor,
+ process_vision_info, # this is your helper that turns “messages” → image_inputs / video_inputs
+ prev_pil: Image.Image, # <– pass PIL instead of path
+ zoomed_pil: Image.Image, # <– pass PIL instead of path
  device: str = "cuda"
  ) -> str:
  """
- Given two image file paths:
- - prev_image_path: the “full” image at the previous recursion.
- - zoomed_image_path: the cropped+resized (zoom) image for this step.
- This builds a single “recursive_multiscale” prompt via Qwen2.5-VL.
- Returns a string like “cat on sofa, pet, indoor, living room”, etc.
+ Given two PIL.Image inputs:
+ - prev_pil: the “full” image at the previous recursion.
+ - zoomed_pil: the cropped+resized (zoom) image for this step.
+ Returns a single “recursive_multiscale” prompt string.
  """
- # (1) Define the system message for recursive_multiscale:
+
+ # (1) System message
  message_text = (
  "The second image is a zoom-in of the first image. "
  "Based on this knowledge, what is in the second image? "
  "Give me a set of words."
  )

- # (2) Build the two-image “chat” payload:
+ # (2) Build the two-image “chat” payload
+ #
+ # Instead of passing a filename, we pass the actual PIL.Image.
+ # The processor’s `process_vision_info` should know how to turn
+ # a message of the form {"type":"image","image": PIL_IMAGE} into tensors.
  messages = [
  {"role": "system", "content": message_text},
  {
  "role": "user",
  "content": [
- {"type": "image", "image": prev_image_path},
- {"type": "image", "image": zoomed_image_path},
+ {"type": "image", "image": prev_pil},
+ {"type": "image", "image": zoomed_pil},
  ],
  },
  ]

- # (3) Wrap through the VL processor to get “inputs”:
+ # (3) Now run the “chat” through the VL processor
+ #
+ # - `apply_chat_template` will build the tokenized prompt (without running it yet).
+ # - `process_vision_info` should inspect the same `messages` list and return
+ # `image_inputs` and `video_inputs` (tensors) for any attached PIL images.
  text = vlm_processor.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
  )
  image_inputs, video_inputs = process_vision_info(messages)
+
  inputs = vlm_processor(
- text=[text],
- images=image_inputs,
- videos=video_inputs,
- padding=True,
+ text=[text],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
  return_tensors="pt",
  ).to(device)

- # (4) Generate tokens decode
+ # (4) Generate and decode
  generated = vlm_model.generate(**inputs, max_new_tokens=128)
- # strip off the prompt tokens from each generated sequence:
  trimmed = [
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated)
+ out_ids[len(in_ids):]
+ for in_ids, out_ids in zip(inputs.input_ids, generated)
  ]
  out_text = vlm_processor.batch_decode(
  trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
  )[0]

- # (5) Return exactly the bare words (no extra “,” if no additional user prompt)
  return out_text.strip()


+
  # -------------------------------------------------------------------
  # Main Function: recursive_multiscale_sr (with multiple centers)
  # -------------------------------------------------------------------
@@ -203,88 +214,61 @@ def recursive_multiscale_sr(
  ###############################
  # 6. Prepare the very first “full” image
  ###############################
- # 6.1 Load + center crop → first_image is (512×512) PIL on CPU
+ # (6.1) Load + center crop → first_image (512×512)
  img0 = Image.open(input_png_path).convert("RGB")
  img0 = resize_and_center_crop(img0, process_size)

- # 6.2 Save it once so VLM can read it as “prev.png”
- prev_path = os.path.join(td, "step0_prev.png")
- img0.save(prev_path)
+ # Note: we no longer need to write “prev.png” to disk. Just keep it in memory.
+ prev_pil = img0.copy()

- # We will maintain lists of PIL outputs and prompts:
  sr_pil_list: list[Image.Image] = []
- prompt_list: list[str] = []
+ prompt_list: list[str] = []

- ###############################
- # 7. Recursion loop (now up to rec_num times)
- ###############################
  for rec in range(rec_num):
- # (A) Load the previous SR output (or original) and compute crop window
- prev_pil = Image.open(prev_path).convert("RGB")
- w, h = prev_pil.size # should be (512×512) each time
+ # (A) Compute low-res crop window on prev_pil
+ w, h = prev_pil.size # (512×512)
+ new_w, new_h = w // upscale, h // upscale

- # (1) Compute the “low-res” window size:
- new_w, new_h = w // upscale, h // upscale # e.g. 128×128 for upscale=4
-
- # (2) Map normalized center → pixel center, then clamp so crop stays in bounds:
  cx_norm, cy_norm = centers[rec]
  cx = int(cx_norm * w)
  cy = int(cy_norm * h)
- half_w = new_w // 2
- half_h = new_h // 2
-
- # If center in pixels is too close to left/top, clamp so left=0 or top=0; same on right/bottom
- left = cx - half_w
- top = cy - half_h
- # clamp left ∈ [0, w - new_w], top ∈ [0, h - new_h]
- left = max(0, min(left, w - new_w))
- top = max(0, min(top, h - new_h))
- right = left + new_w
- bottom = top + new_h
+ half_w, half_h = new_w // 2, new_h // 2
+
+ left = max(0, min(cx - half_w, w - new_w))
+ top = max(0, min(cy - half_h, h - new_h))
+ right, bottom = left + new_w, top + new_h

  cropped = prev_pil.crop((left, top, right, bottom))

- # (B) Resize that crop back up to (512×512) via BICUBIC → zoomed
- zoomed = cropped.resize((w, h), Image.BICUBIC)
- zoom_path = os.path.join(td, f"step{rec+1}_zoom.png")
- zoomed.save(zoom_path)
+ # (B) Upsample that crop back to (512×512)
+ zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

- # (C) Generate a recursive_multiscale VLM “tag” prompt
+ # (C) Generate VLM prompt by passing PILs directly:
  prompt_tag = _generate_vlm_prompt(
  vlm_model=vlm_model,
  vlm_processor=vlm_processor,
  process_vision_info=process_vision_info,
- prev_image_path=prev_path,
- zoomed_image_path=zoom_path,
+ prev_pil=prev_pil, # <– PIL
+ zoomed_pil=zoomed_pil, # <– PIL
  device=device,
  )
- # (By default, no extra user prompt is appended.)

- # (D) Prepare the low-res tensor for SR: convert zoomed → Tensor → [0,1] → [−1,1]
+ # (D) Prepare “zoomed_pil” tensor in [−1, 1]
  to_tensor = transforms.ToTensor()
- lq = to_tensor(zoomed).unsqueeze(0).to(device) # shape (1,3,512,512)
+ lq = to_tensor(zoomed_pil).unsqueeze(0).to(device) # (1,3,512,512)
  lq = (lq * 2.0) - 1.0

- # (E) Do SR inference:
+ # (E) Run SR inference
  with torch.no_grad():
- out_tensor = model_test(lq, prompt=prompt_tag)[0] # (3,512,512) on CPU or GPU
+ out_tensor = model_test(lq, prompt=prompt_tag)[0]
  out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
- # back to PIL in [0,1]:
  out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

- # (F) Save this step’s SR output as “prev.png” for next iteration:
- out_path = os.path.join(td, f"step{rec+1}_sr.png")
- out_pil.save(out_path)
- prev_path = out_path
+ # (F) Bookkeeping: set prev_pil = out_pil for next iteration
+ prev_pil = out_pil

- # (G) Append the PIL to our list:
+ # (G) Append to results
  sr_pil_list.append(out_pil)
  prompt_list.append(prompt_tag)

- # end for(rec)
-
- ###############################
- # 8. Return the SR outputs & prompts
- ###############################
- # The list sr_pil_list = [ SR1, SR2, …, SR_rec_num ] in order.
  return sr_pil_list, prompt_list
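As a point of reference, here is a minimal, hypothetical sketch of how the refactored helper could be exercised with in-memory PIL images instead of file paths. The checkpoint name, loading flags, and the direct import of `_generate_vlm_prompt` are illustrative assumptions; only the helper's new signature comes from this diff.

# Hypothetical usage sketch; checkpoint name and loading options are assumptions.
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from inference_coz_single import _generate_vlm_prompt  # assumes the module imports cleanly

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed; use whatever checkpoint the repo actually loads
vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)
vlm_processor = AutoProcessor.from_pretrained(model_id)

# Build the two inputs the helper now expects: the full frame and a bicubic zoom of a 128×128 crop.
prev_pil = Image.open("input.png").convert("RGB").resize((512, 512))
zoomed_pil = prev_pil.crop((192, 192, 320, 320)).resize((512, 512), Image.BICUBIC)

tag = _generate_vlm_prompt(
    vlm_model=vlm_model,
    vlm_processor=vlm_processor,
    process_vision_info=process_vision_info,  # qwen_vl_utils accepts PIL objects in the "image" field
    prev_pil=prev_pil,
    zoomed_pil=zoomed_pil,
    device="cuda",
)
print(tag)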
 
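The crop-window arithmetic in the new loop is self-contained, so a small standalone check (the helper name below is made up for illustration; the expressions mirror the diff) shows both the centered case and how a near-edge center gets clamped back inside the frame:

# Standalone check of the clamped crop window (hypothetical helper name).
def crop_window(w: int, h: int, cx_norm: float, cy_norm: float, upscale: int):
    new_w, new_h = w // upscale, h // upscale
    cx, cy = int(cx_norm * w), int(cy_norm * h)
    half_w, half_h = new_w // 2, new_h // 2
    left = max(0, min(cx - half_w, w - new_w))
    top = max(0, min(cy - half_h, h - new_h))
    return left, top, left + new_w, top + new_h

print(crop_window(512, 512, 0.5, 0.5, 4))    # (192, 192, 320, 320): centered 128×128 window
print(crop_window(512, 512, 0.98, 0.02, 4))  # (384, 0, 512, 128): clamped to the top-right corner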
 
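Finally, the tensor scaling around `model_test` is a simple round trip: `ToTensor` yields [0, 1], the SR model is fed [−1, 1], and the output is mapped back before `ToPILImage`. A minimal sketch with a dummy image (no SR model involved; the identity stands in for `model_test`):

# Round trip of the [0,1] ↔ [−1,1] scaling used around the SR call (dummy data only).
from PIL import Image
from torchvision import transforms

zoomed_pil = Image.new("RGB", (512, 512), color=(128, 64, 200))   # stand-in for a real zoomed crop
lq = transforms.ToTensor()(zoomed_pil).unsqueeze(0)               # (1, 3, 512, 512) in [0, 1]
lq = (lq * 2.0) - 1.0                                             # rescale to [−1, 1]

out_tensor = lq[0]                                                # pretend SR output, same scaling
out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)       # back to [0, 1] → PIL
assert out_pil.size == (512, 512)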