# corgi-qwen3-vl-demo / corgi/qwen_client.py
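"""Qwen3-VL client wrapper used by the CoRGI demo pipeline."""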
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import logging
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
logger = logging.getLogger(__name__)
try:
import spaces # type: ignore
except ImportError: # pragma: no cover - only available on HF Spaces
spaces = None # type: ignore
from .parsers import parse_roi_evidence, parse_structured_reasoning
from .types import GroundedEvidence, PromptLog, ReasoningStep
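# Prompt templates for the three CoRGI stages: structured reasoning, ROI grounding,
# and answer synthesis.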
DEFAULT_REASONING_PROMPT = (
"You are a careful multimodal reasoner following the CoRGI protocol. "
"Given the question and the image, produce a JSON array of reasoning steps. "
"Each item must contain the keys: index (1-based integer), statement (concise sentence), "
"needs_vision (boolean true if the statement requires visual verification), and reason "
"(short phrase explaining why visual verification is or is not required). "
"Limit the number of steps to {max_steps}. Respond with JSON only; start the reply with '[' and end with ']'. "
"Do not add any commentary or prose outside of the JSON."
)
DEFAULT_GROUNDING_PROMPT = (
"You are validating the following reasoning step:\n"
"{step_statement}\n"
"Return a JSON array with up to {max_regions} region candidates that help verify the step. "
"Each object must include: step (integer), bbox (list of four numbers x1,y1,x2,y2, "
"either normalized 0-1 or scaled 0-1000), description (short textual evidence), "
"and confidence (0-1). Use [] if no relevant region exists. "
"Respond with JSON only; do not include explanations outside the JSON array."
)
DEFAULT_ANSWER_PROMPT = (
"You are finalizing the answer using verified evidence. "
"Question: {question}\n"
"Structured reasoning steps:\n"
"{steps}\n"
"Verified evidence items:\n"
"{evidence}\n"
"Respond with a concise final answer sentence grounded in the evidence. "
"If unsure, say you are uncertain. Do not include <think> tags or internal monologue."
)
def _format_steps_for_prompt(steps: List[ReasoningStep]) -> str:
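    """Render reasoning steps as numbered lines for inclusion in the answer prompt."""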
return "\n".join(
f"{step.index}. {step.statement} (needs vision: {step.needs_vision})"
for step in steps
)
def _format_evidence_for_prompt(evidences: List[GroundedEvidence]) -> str:
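    """Render grounded evidence (bbox, confidence, description) as one line per item."""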
if not evidences:
return "No evidence collected."
lines = []
for ev in evidences:
desc = ev.description or "No description"
bbox = ", ".join(f"{coord:.2f}" for coord in ev.bbox)
conf = f"{ev.confidence:.2f}" if ev.confidence is not None else "n/a"
lines.append(f"Step {ev.step_index}: bbox=({bbox}), conf={conf}, desc={desc}")
return "\n".join(lines)
def _strip_think_content(text: str) -> str:
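    """Drop any <think>...</think> preamble so only the final answer text remains."""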
if not text:
return ""
cleaned = text
if "</think>" in cleaned:
cleaned = cleaned.split("</think>", 1)[-1]
cleaned = cleaned.replace("<think>", "")
return cleaned.strip()
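# Module-level caches so repeated client construction reuses the loaded model/processor.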
_MODEL_CACHE: dict[str, AutoModelForImageTextToText] = {}
_PROCESSOR_CACHE: dict[str, AutoProcessor] = {}
def _gpu_decorator(duration: int = 120):
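    """Return spaces.GPU(duration) when running on HF Spaces, otherwise a no-op decorator."""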
if spaces is None:
return lambda fn: fn
return spaces.GPU(duration=duration)
def _ensure_cuda(model: AutoModelForImageTextToText) -> AutoModelForImageTextToText:
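    """Move the model to CUDA if a GPU is available; return it unchanged otherwise."""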
if torch.cuda.is_available():
target_device = torch.device("cuda")
current_device = next(model.parameters()).device
if current_device.type != target_device.type:
model.to(target_device)
return model
def _load_backend(model_id: str) -> tuple[AutoModelForImageTextToText, AutoProcessor]:
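    """Load and cache the model/processor pair, picking the best dtype for the hardware."""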
if model_id not in _MODEL_CACHE:
# Check if hardware supports bfloat16
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
torch_dtype = torch.bfloat16
logger.info("Using bfloat16 (hardware supported)")
elif torch.cuda.is_available():
torch_dtype = torch.float16 # Fallback to float16 if bfloat16 not supported
logger.info("Using float16 (bfloat16 not supported on this GPU)")
else:
torch_dtype = torch.float32
logger.info("Using float32 (CPU mode)")
# Use single GPU (cuda:0) instead of auto to avoid model sharding across multiple GPUs
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch_dtype,
device_map=device_map,
)
model = model.eval()
processor = AutoProcessor.from_pretrained(model_id)
_MODEL_CACHE[model_id] = model
_PROCESSOR_CACHE[model_id] = processor
return _MODEL_CACHE[model_id], _PROCESSOR_CACHE[model_id]
@dataclass
class QwenGenerationConfig:
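    """Generation settings for the Qwen3-VL backend."""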
model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
max_new_tokens: int = 512
temperature: float | None = None
do_sample: bool = False
class Qwen3VLClient:
"""Wrapper around transformers Qwen3-VL chat API for CoRGI pipeline."""
def __init__(
self,
config: Optional[QwenGenerationConfig] = None,
) -> None:
self.config = config or QwenGenerationConfig()
self._model, self._processor = _load_backend(self.config.model_id)
self.reset_logs()
def reset_logs(self) -> None:
self._reasoning_log: Optional[PromptLog] = None
self._grounding_logs: List[PromptLog] = []
self._answer_log: Optional[PromptLog] = None
@property
def reasoning_log(self) -> Optional[PromptLog]:
return self._reasoning_log
@property
def grounding_logs(self) -> List[PromptLog]:
return list(self._grounding_logs)
@property
def answer_log(self) -> Optional[PromptLog]:
return self._answer_log
@_gpu_decorator()
def _chat(
self,
image: Image.Image,
prompt: str,
max_new_tokens: Optional[int] = None,
) -> str:
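        """Run one single-image chat turn and return the decoded model response."""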
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
chat_prompt = self._processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
inputs = self._processor(
text=[chat_prompt],
images=[image],
return_tensors="pt",
).to(self._model.device)
gen_kwargs = {
"max_new_tokens": max_new_tokens or self.config.max_new_tokens,
"do_sample": self.config.do_sample,
}
if self.config.do_sample and self.config.temperature is not None:
gen_kwargs["temperature"] = self.config.temperature
output_ids = self._model.generate(**inputs, **gen_kwargs)
prompt_length = inputs.input_ids.shape[1]
generated_tokens = output_ids[:, prompt_length:]
response = self._processor.batch_decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
return response.strip()
def structured_reasoning(self, image: Image.Image, question: str, max_steps: int) -> List[ReasoningStep]:
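        """Stage 1: ask the model for JSON reasoning steps and parse them."""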
prompt = DEFAULT_REASONING_PROMPT.format(max_steps=max_steps) + f"\nQuestion: {question}"
response = self._chat(image=image, prompt=prompt)
self._reasoning_log = PromptLog(prompt=prompt, response=response, stage="reasoning")
return parse_structured_reasoning(response, max_steps=max_steps)
def extract_step_evidence(
self,
image: Image.Image,
question: str,
step: ReasoningStep,
max_regions: int,
) -> List[GroundedEvidence]:
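        """Stage 2: ground one reasoning step in image regions and parse the ROI JSON."""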
prompt = DEFAULT_GROUNDING_PROMPT.format(
step_statement=step.statement,
max_regions=max_regions,
)
response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
evidences = parse_roi_evidence(response, default_step_index=step.index)
self._grounding_logs.append(
PromptLog(prompt=prompt, response=response, step_index=step.index, stage="grounding")
)
return evidences[:max_regions]
def synthesize_answer(
self,
image: Image.Image,
question: str,
steps: List[ReasoningStep],
evidences: List[GroundedEvidence],
) -> str:
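        """Stage 3: synthesize the final answer from the steps and verified evidence."""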
prompt = DEFAULT_ANSWER_PROMPT.format(
question=question,
steps=_format_steps_for_prompt(steps),
evidence=_format_evidence_for_prompt(evidences),
)
response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
self._answer_log = PromptLog(prompt=prompt, response=response, stage="synthesis")
return _strip_think_content(response)
__all__ = ["Qwen3VLClient", "QwenGenerationConfig"]
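if __name__ == "__main__":  # pragma: no cover - illustrative sketch only
    # Minimal usage sketch of the three-stage CoRGI flow. The image path and the
    # question below are placeholders, not part of the deployed Space; adapt them
    # to your own data before running.
    import sys

    client = Qwen3VLClient(QwenGenerationConfig(max_new_tokens=256))
    demo_image = Image.open(sys.argv[1]) if len(sys.argv) > 1 else Image.new("RGB", (640, 480))
    question = "What color is the object on the left?"

    # Stage 1: structured reasoning steps.
    steps = client.structured_reasoning(demo_image, question, max_steps=4)
    # Stage 2: ground only the steps that require visual verification.
    evidences = [
        ev
        for step in steps
        if step.needs_vision
        for ev in client.extract_step_evidence(demo_image, question, step, max_regions=3)
    ]
    # Stage 3: synthesize the final answer from steps and verified evidence.
    print(client.synthesize_answer(demo_image, question, steps, evidences))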