from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional
import logging

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

logger = logging.getLogger(__name__)

try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - only available on HF Spaces
    spaces = None  # type: ignore

from .parsers import parse_roi_evidence, parse_structured_reasoning
from .types import GroundedEvidence, PromptLog, ReasoningStep

DEFAULT_REASONING_PROMPT = (
    "You are a careful multimodal reasoner following the CoRGI protocol. "
    "Given the question and the image, produce a JSON array of reasoning steps. "
    "Each item must contain the keys: index (1-based integer), statement (concise sentence), "
    "needs_vision (boolean true if the statement requires visual verification), and reason "
    "(short phrase explaining why visual verification is or is not required). "
    "Limit the number of steps to {max_steps}. Respond with JSON only; start the reply with '[' and end with ']'. "
    "Do not add any commentary or prose outside of the JSON."
)
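# Illustrative shape of the reply the reasoning prompt asks for (values are
# hypothetical, not actual model output):
#
#   [
#     {"index": 1, "statement": "The sign above the door shows the store name.",
#      "needs_vision": true, "reason": "Requires reading text from the image."}
#   ]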
DEFAULT_GROUNDING_PROMPT = (
    "You are validating the following reasoning step:\n"
    "{step_statement}\n"
    "Return a JSON array with up to {max_regions} region candidates that help verify the step. "
    "Each object must include: step (integer), bbox (list of four numbers x1,y1,x2,y2, "
    "either normalized 0-1 or scaled 0-1000), description (short textual evidence), "
    "and confidence (0-1). Use [] if no relevant region exists. "
    "Respond with JSON only; do not include explanations outside the JSON array."
)
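# Illustrative shape of the reply the grounding prompt asks for (values are
# hypothetical, not actual model output):
#
#   [
#     {"step": 1, "bbox": [0.12, 0.30, 0.45, 0.62],
#      "description": "red sign with white lettering", "confidence": 0.82}
#   ]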
DEFAULT_ANSWER_PROMPT = (
    "You are finalizing the answer using verified evidence. "
    "Question: {question}\n"
    "Structured reasoning steps:\n"
    "{steps}\n"
    "Verified evidence items:\n"
    "{evidence}\n"
    "Respond with a concise final answer sentence grounded in the evidence. "
    "If unsure, say you are uncertain. Do not include <think> tags or internal monologue."
)
def _format_steps_for_prompt(steps: List[ReasoningStep]) -> str:
    return "\n".join(
        f"{step.index}. {step.statement} (needs vision: {step.needs_vision})"
        for step in steps
    )


def _format_evidence_for_prompt(evidences: List[GroundedEvidence]) -> str:
    if not evidences:
        return "No evidence collected."
    lines = []
    for ev in evidences:
        desc = ev.description or "No description"
        bbox = ", ".join(f"{coord:.2f}" for coord in ev.bbox)
        conf = f"{ev.confidence:.2f}" if ev.confidence is not None else "n/a"
        lines.append(f"Step {ev.step_index}: bbox=({bbox}), conf={conf}, desc={desc}")
    return "\n".join(lines)


def _strip_think_content(text: str) -> str:
    if not text:
        return ""
    cleaned = text
    if "</think>" in cleaned:
        cleaned = cleaned.split("</think>", 1)[-1]
    cleaned = cleaned.replace("<think>", "")
    return cleaned.strip()
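# Example (illustrative input, not taken from a real model response):
#   _strip_think_content("<think>scratch work</think> Final answer.") -> "Final answer."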
_MODEL_CACHE: dict[str, AutoModelForImageTextToText] = {}
_PROCESSOR_CACHE: dict[str, AutoProcessor] = {}


def _gpu_decorator(duration: int = 120):
    if spaces is None:
        return lambda fn: fn
    return spaces.GPU(duration=duration)
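# Usage sketch (hypothetical; this module does not apply the decorator itself):
#
#   @_gpu_decorator(duration=120)
#   def _run_generation(...):
#       ...
#
# On HF Spaces, spaces.GPU(duration=...) schedules the wrapped call on GPU hardware;
# outside Spaces the function is returned unchanged.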
def _ensure_cuda(model: AutoModelForImageTextToText) -> AutoModelForImageTextToText:
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
        current_device = next(model.parameters()).device
        if current_device.type != target_device.type:
            model.to(target_device)
    return model


def _load_backend(model_id: str) -> tuple[AutoModelForImageTextToText, AutoProcessor]:
    if model_id not in _MODEL_CACHE:
        # Check if hardware supports bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            torch_dtype = torch.bfloat16
            logger.info("Using bfloat16 (hardware supported)")
        elif torch.cuda.is_available():
            torch_dtype = torch.float16  # Fallback to float16 if bfloat16 not supported
            logger.info("Using float16 (bfloat16 not supported on this GPU)")
        else:
            torch_dtype = torch.float32
            logger.info("Using float32 (CPU mode)")
        # Use single GPU (cuda:0) instead of auto to avoid model sharding across multiple GPUs
        device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_id)
        _MODEL_CACHE[model_id] = model
        _PROCESSOR_CACHE[model_id] = processor
    return _MODEL_CACHE[model_id], _PROCESSOR_CACHE[model_id]
@dataclass
class QwenGenerationConfig:
    model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
    max_new_tokens: int = 512
    temperature: float | None = None
    do_sample: bool = False
class Qwen3VLClient:
    """Wrapper around transformers Qwen3-VL chat API for CoRGI pipeline."""

    def __init__(
        self,
        config: Optional[QwenGenerationConfig] = None,
    ) -> None:
        self.config = config or QwenGenerationConfig()
        self._model, self._processor = _load_backend(self.config.model_id)
        self.reset_logs()

    def reset_logs(self) -> None:
        self._reasoning_log: Optional[PromptLog] = None
        self._grounding_logs: List[PromptLog] = []
        self._answer_log: Optional[PromptLog] = None

    def reasoning_log(self) -> Optional[PromptLog]:
        return self._reasoning_log

    def grounding_logs(self) -> List[PromptLog]:
        return list(self._grounding_logs)

    def answer_log(self) -> Optional[PromptLog]:
        return self._answer_log

    def _chat(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_prompt = self._processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = self._processor(
            text=[chat_prompt],
            images=[image],
            return_tensors="pt",
        ).to(self._model.device)
        gen_kwargs = {
            "max_new_tokens": max_new_tokens or self.config.max_new_tokens,
            "do_sample": self.config.do_sample,
        }
        if self.config.do_sample and self.config.temperature is not None:
            gen_kwargs["temperature"] = self.config.temperature
        output_ids = self._model.generate(**inputs, **gen_kwargs)
        prompt_length = inputs.input_ids.shape[1]
        generated_tokens = output_ids[:, prompt_length:]
        response = self._processor.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return response.strip()

    def structured_reasoning(self, image: Image.Image, question: str, max_steps: int) -> List[ReasoningStep]:
        prompt = DEFAULT_REASONING_PROMPT.format(max_steps=max_steps) + f"\nQuestion: {question}"
        response = self._chat(image=image, prompt=prompt)
        self._reasoning_log = PromptLog(prompt=prompt, response=response, stage="reasoning")
        return parse_structured_reasoning(response, max_steps=max_steps)

    def extract_step_evidence(
        self,
        image: Image.Image,
        question: str,
        step: ReasoningStep,
        max_regions: int,
    ) -> List[GroundedEvidence]:
        prompt = DEFAULT_GROUNDING_PROMPT.format(
            step_statement=step.statement,
            max_regions=max_regions,
        )
        response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
        evidences = parse_roi_evidence(response, default_step_index=step.index)
        self._grounding_logs.append(
            PromptLog(prompt=prompt, response=response, step_index=step.index, stage="grounding")
        )
        return evidences[:max_regions]

    def synthesize_answer(
        self,
        image: Image.Image,
        question: str,
        steps: List[ReasoningStep],
        evidences: List[GroundedEvidence],
    ) -> str:
        prompt = DEFAULT_ANSWER_PROMPT.format(
            question=question,
            steps=_format_steps_for_prompt(steps),
            evidence=_format_evidence_for_prompt(evidences),
        )
        response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
        self._answer_log = PromptLog(prompt=prompt, response=response, stage="synthesis")
        return _strip_think_content(response)


__all__ = ["Qwen3VLClient", "QwenGenerationConfig"]
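# Minimal usage sketch (illustrative only; "example.jpg" and the question text are
# placeholders, and the returned objects depend on the .parsers/.types modules):
#
#   from PIL import Image
#
#   client = Qwen3VLClient()
#   image = Image.open("example.jpg")
#   question = "What is on the table?"
#   steps = client.structured_reasoning(image, question, max_steps=4)
#   evidences = []
#   for step in steps:
#       if step.needs_vision:
#           evidences.extend(client.extract_step_evidence(image, question, step, max_regions=3))
#   answer = client.synthesize_answer(image, question, steps, evidences)
#   print(answer)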