SmartHeal committed on
Commit befc6dd · verified · 1 Parent(s): 984d739

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +250 -97
src/ai_processor.py CHANGED
@@ -16,7 +16,7 @@ import cv2
  import numpy as np
  from PIL import Image
  from PIL.ExifTags import TAGS
-
  # --- Logging config ---
  logging.basicConfig(
  level=getattr(logging, LOGLEVEL, logging.INFO),
@@ -26,12 +26,6 @@ logging.basicConfig(
  def _log_kv(prefix: str, kv: Dict):
  logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))

- # --- Spaces GPU decorator (REQUIRED) ---
- from spaces import GPU as _SPACES_GPU
-
- @_SPACES_GPU(enable_queue=True)
- def smartheal_gpu_stub(ping: int = 0) -> str:
- return "ready"

  # ---- Paths / constants ----
  UPLOADS_DIR = "uploads"
@@ -39,7 +33,7 @@ os.makedirs(UPLOADS_DIR, exist_ok=True)

  HF_TOKEN = os.getenv("HF_TOKEN", None)
  YOLO_MODEL_PATH = "src/best.pt"
- SEG_MODEL_PATH = "src/segmentation_model.h5" # optional; legacy .h5 supported
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
  DATASET_ID = "SmartHeal/wound-image-uploads"
  DEFAULT_PX_PER_CM = 38.0
@@ -123,8 +117,10 @@ SMARTHEAL_USER_PREFIX = """\
  Patient: {patient_info}
  Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
  detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
  Guideline context (snippets you can draw principles from; do not quote at length):
  {guideline_context}
  Write a structured answer with these headings exactly:
  1. Clinical Summary (max 4 bullet points)
  2. Likely Stage/Type (if uncertain, say 'uncertain')
@@ -132,53 +128,238 @@ Write a structured answer with these headings exactly:
  4. Red Flags (what to escalate and when)
  5. Follow-up Cadence (days)
  6. Notes (assumptions/uncertainties)
  Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
  """

- # ---------- MedGemma-only text generator ----------
- @_SPACES_GPU(enable_queue=True)
- def _medgemma_generate_gpu(prompt: str, model_id: str, max_new_tokens: int, token: Optional[str]):
  """
- Runs entirely inside a Spaces GPU worker. Uses Med-Gemma (text-only) to draft the report.
  """
  import torch
- from transformers import pipeline

- pipe = pipeline(
- "image-text-to-text",
- model="google/medgemma-4b-it",
- torch_dtype=torch.bfloat16,
- device="cuda",
  )
- out = pipe(
- prompt,
- max_new_tokens=max_new_tokens,
  do_sample=False,
- temperature=0.2,
- return_full_text=True,
  )
- text = (out[0].get("generated_text") if isinstance(out, list) else out).strip()
- # Remove the prompt echo if present
- if text.startswith(prompt):
- text = text[len(prompt):].lstrip()
- return text or "⚠️ Empty response"

- def generate_medgemma_report( # kept name so callers don't change
  patient_info: str,
  visual_results: Dict,
  guideline_context: str,
- image_pil: Image.Image, # kept for signature compatibility; not used by MedGemma
  max_new_tokens: Optional[int] = None,
  ) -> str:
  """
- MedGemma (text-only) report generation.
- The image is analyzed by the vision pipeline; MedGemma formats clinical guidance text.
  """
  if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
  return "⚠️ VLM disabled"

- # Default to a public Med-Gemma instruction-tuned model (update via env if you have access to another).
- model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "google/med-gemma-2-2b-it")
  max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))

  uprompt = SMARTHEAL_USER_PREFIX.format(
@@ -192,69 +373,27 @@ def generate_medgemma_report( # kept name so callers don't change
  guideline_context=(guideline_context or "")[:900],
  )

- # Compose a single text prompt
- prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"

  try:
- return _medgemma_generate_gpu(prompt, model_id, max_new_tokens, HF_TOKEN)
  except Exception as e:
- logging.error(f"MedGemma call failed: {e}")
- return "⚠️ VLM error"
-
- # ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
- def _shape_to_hw(shape) -> Tuple[Optional[int], Optional[int]]:
- try:
- if hasattr(shape, "as_list"):
- shape = shape.as_list()
- except Exception:
- pass
- if isinstance(shape, (tuple, list)):
- if len(shape) == 4: # (None, H, W, C)
- H, W = shape[1], shape[2]
- elif len(shape) == 3: # (H, W, C)
- H, W = shape[0], shape[1]
- else:
- return (None, None)
- try: H = int(H) if (H is not None and str(H).lower() != "none") else None
- except Exception: H = None
- try: W = int(W) if (W is not None and str(W).lower() != "none") else None
- except Exception: W = None
- return (H, W)
- return (None, None)
-
- def _get_model_input_hw(model, default_hw: Tuple[int, int] = (224, 224)) -> Tuple[int, int]:
- H, W = _shape_to_hw(getattr(model, "input_shape", None))
- if H and W:
- return H, W
- try:
- inputs = getattr(model, "inputs", None)
- if inputs:
- H, W = _shape_to_hw(inputs[0].shape)
- if H and W:
- return H, W
- except Exception:
- pass
- try:
- cfg = model.get_config() if hasattr(model, "get_config") else None
- if isinstance(cfg, dict):
- for layer in cfg.get("layers", []):
- conf = (layer or {}).get("config", {})
- cand = conf.get("batch_input_shape") or conf.get("batch_shape")
- H, W = _shape_to_hw(cand)
- if H and W:
- return H, W
- except Exception:
- pass
- logging.warning(f"Could not resolve model input shape; using default {default_hw}.")
- return default_hw
-
  # ---------- Initialize CPU models ----------
  def load_yolo_model():
  YOLO = _import_ultralytics()
  with _no_cuda_env():
  model = YOLO(YOLO_MODEL_PATH)
  return model
-
  def load_segmentation_model():
  import tensorflow as tf
  load_model = _import_tf_loader()
@@ -287,11 +426,11 @@ def initialize_cpu_models() -> None:
  if "seg" not in models_cache:
  try:
  if os.path.exists(SEG_MODEL_PATH):
- m = load_segmentation_model() # uses global path by default
- models_cache["seg"] = m
- th, tw = _get_model_input_hw(m, default_hw=(224, 224))
  oshape = getattr(m, "output_shape", None)
- logging.info(f"✅ Segmentation model loaded (CPU) | input_hw=({th},{tw}) output_shape={oshape}")
  else:
  models_cache["seg"] = None
  logging.warning("Segmentation model file missing; skipping.")
@@ -509,7 +648,11 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
  # --- Model path ---
  if seg_model is not None:
  try:
- th, tw = _get_model_input_hw(seg_model, default_hw=(224, 224))
  x = _preprocess_for_seg(image_bgr, (th, tw))
  roi_seen_path = None
  if SMARTHEAL_DEBUG:
@@ -600,7 +743,7 @@ def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float,
  cnt = max(contours, key=cv2.contourArea)
  rect = cv2.minAreaRect(cnt)
  (w_px, h_px) = rect[1]
- length_px, breadth_px = (max(w_px, h_px), min(h_px, w_px))
  length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
  breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
  box = cv2.boxPoints(rect).astype(int)
@@ -859,6 +1002,7 @@ class AIProcessor:
  if not vs:
  return "Knowledge base is not available."
  retriever = vs.as_retriever(search_kwargs={"k": 5})
  docs = retriever.invoke(query)
  lines: List[str] = []
  for d in docs:
@@ -872,30 +1016,38 @@

  def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
  return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
  ## 📋 Patient Information
  {patient_info}
  ## 🔍 Visual Analysis Results
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
  - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
  ## 📊 Analysis Images
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
  - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
  - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
  ## 🎯 Clinical Summary
  Automated analysis provides quantitative measurements; verify via clinical examination.
  ## 💊 Recommendations
  - Cleanse wound gently; select dressing per exudate/infection risk
  - Debride necrotic tissue if indicated (clinical decision)
  - Document with serial photos and measurements
  ## 📅 Monitoring
  - Daily in week 1, then every 2–3 days (or as indicated)
  - Weekly progress review
  ## 📚 Guideline Context
  {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
  **Disclaimer:** Automated, for decision support only. Verify clinically.
  """

@@ -949,7 +1101,8 @@ Automated analysis provides quantitative measurements; verify via clinical exami
  except Exception as e:
  logging.error(f"Failed to save/commit image: {e}")
  return ""
-
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
  try:
  saved_path = self.save_and_commit_image(image_pil)
@@ -995,7 +1148,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
  "saved_image_path": None,
  "guideline_context": "",
  }
-
  def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
  try:
  if isinstance(image, str):
@@ -1019,4 +1172,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
  "report": f"Analysis initialization failed: {str(e)}",
  "saved_image_path": None,
  "guideline_context": "",
- }
  import numpy as np
  from PIL import Image
  from PIL.ExifTags import TAGS
+ import spaces
  # --- Logging config ---
  logging.basicConfig(
  level=getattr(logging, LOGLEVEL, logging.INFO),
 
  def _log_kv(prefix: str, kv: Dict):
  logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))


  # ---- Paths / constants ----
  UPLOADS_DIR = "uploads"
 

  HF_TOKEN = os.getenv("HF_TOKEN", None)
  YOLO_MODEL_PATH = "src/best.pt"
+ SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
  DATASET_ID = "SmartHeal/wound-image-uploads"
  DEFAULT_PX_PER_CM = 38.0
 
  Patient: {patient_info}
  Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
  detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
+
  Guideline context (snippets you can draw principles from; do not quote at length):
  {guideline_context}
+
  Write a structured answer with these headings exactly:
  1. Clinical Summary (max 4 bullet points)
  2. Likely Stage/Type (if uncertain, say 'uncertain')
 
  4. Red Flags (what to escalate and when)
  5. Follow-up Cadence (days)
  6. Notes (assumptions/uncertainties)
+
  Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
  """

+
+ def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
  """
+ Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init.
+ Safe for:
+ - CUDA device selection (no 'Invalid device id')
+ - BF16/FP16 choice via compute capability
+ - LLaVA processors with patch_size=None
+ - Processors WITHOUT a chat template (fallback to plain/LLaVA-style prompt)
  """
+ import logging
  import torch
+ from typing import Optional, List
+ from transformers import (
+ AutoProcessor,
+ AutoModelForVision2Seq,
+ StoppingCriteria,
+ StoppingCriteriaList,
+ )

+ # -------- Device & dtype (robust) --------
+ def _pick_device_and_dtype():
+ if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
+ logging.warning("CUDA not available; using CPU.")
+ return "cpu", torch.float32
+ idx = 0
+ try:
+ torch.cuda.set_device(idx)
+ except Exception as e:
+ logging.warning(f"torch.cuda.set_device({idx}) failed: {e}; falling back to CPU.")
+ return "cpu", torch.float32
+ device = f"cuda:{idx}"
+ try:
+ props = torch.cuda.get_device_properties(idx)
+ cc = props.major * 10 + props.minor
+ dtype = torch.bfloat16 if cc >= 80 else torch.float16
+ except Exception as e:
+ logging.warning(f"Could not query CUDA props: {e}; defaulting to float16.")
+ dtype = torch.float16
+ return device, dtype
+
+ device, torch_dtype = _pick_device_and_dtype()
+
+ # -------- Load model & processor --------
+ model = AutoModelForVision2Seq.from_pretrained(
+ model_id,
+ torch_dtype=torch_dtype,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ token=token,
+ ).to(device)
+ model.eval()
+
+ processor = AutoProcessor.from_pretrained(
+ model_id, trust_remote_code=True, token=token
  )
+
+ # -------- Extract image & text --------
+ image_obj = None
+ text_prompt = ""
+ for m in messages:
+ if m.get("role") == "user":
+ for c in m.get("content", []):
+ if c.get("type") == "image":
+ image_obj = c.get("image")
+ elif c.get("type") == "text":
+ text_prompt = c.get("text", "")
+ break
+ if image_obj is None:
+ raise ValueError("No image found in messages for VLM inference.")
+
+ # -------- Normalize image to PIL --------
+ from PIL import Image
+ import numpy as np
+ def _to_pil(x):
+ if isinstance(x, Image.Image):
+ return x.convert("RGB")
+ if isinstance(x, str):
+ return Image.open(x).convert("RGB")
+ if isinstance(x, np.ndarray):
+ if x.ndim == 2:
+ x = np.stack([x]*3, axis=-1)
+ if x.dtype != np.uint8:
+ x = x.astype(np.uint8)
+ return Image.fromarray(x, "RGB")
+ if hasattr(x, "read"):
+ return Image.open(x).convert("RGB")
+ raise TypeError(f"Unsupported image type: {type(x)}")
+ image_pil = _to_pil(image_obj)
+
+ # -------- Ensure patch_size for LLaVA processors --------
+ def _ensure_patch_size(proc, mdl):
+ ps = getattr(proc, "patch_size", None)
+ if not ps:
+ candidates = [
+ getattr(getattr(mdl, "vision_tower", None), "config", None),
+ getattr(mdl.config, "vision_config", None),
+ getattr(proc, "image_processor", None),
+ getattr(getattr(proc, "image_processor", None), "config", None),
+ ]
+ for obj in candidates:
+ if obj is None:
+ continue
+ maybe = getattr(obj, "patch_size", None)
+ if maybe:
+ ps = int(maybe); break
+ if not ps:
+ ps = 14 # safe default for ViT-L/14-style
+ try:
+ setattr(proc, "patch_size", ps)
+ except Exception:
+ pass
+ return ps
+ _ensure_patch_size(processor, model)
+
+ # -------- Build text (chat-template only if it truly exists) --------
+ # Some processors expose apply_chat_template but tokenizer has no template → ValueError. Guard it.
+ tokenizer = getattr(processor, "tokenizer", None)
+ has_template = bool(getattr(tokenizer, "chat_template", None))
+ used_chat_template = False
+
+ def _looks_like_llava():
+ name = processor.__class__.__name__.lower()
+ mid = (model_id or "").lower()
+ return ("llava" in name) or ("llava" in mid)
+
+ if hasattr(processor, "apply_chat_template") and has_template:
+ try:
+ chat = [{
+ "role": "user",
+ "content": [
+ {"type": "image", "image": image_pil},
+ {"type": "text", "text": text_prompt or "Describe the image."},
+ ],
+ }]
+ text_for_model = processor.apply_chat_template(
+ chat, add_generation_prompt=True, tokenize=False
+ )
+ used_chat_template = True
+ except Exception as e:
+ logging.info(f"No usable chat template ({e}); falling back to plain prompt.")
+ text_for_model = (
+ f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
+ if _looks_like_llava() else (text_prompt or "Describe the image.")
+ )
+ else:
+ text_for_model = (
+ f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
+ if _looks_like_llava() else (text_prompt or "Describe the image.")
+ )
+
+ # -------- Tokenize --------
+ inputs = processor(
+ text=[text_for_model],
+ images=[image_pil],
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ # -------- Stopping criteria --------
+ class EosTokenCriteria(StoppingCriteria):
+ def __init__(self, eos_token_ids: List[int]):
+ import torch as _t
+ self.eos = _t.tensor(eos_token_ids, dtype=_t.long)
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
+ import torch as _t
+ last_tok = input_ids[:, -1]
+ return _t.isin(last_tok, self.eos.to(last_tok.device)).any().item()
+
+ eos_ids: List[int] = []
+ if tokenizer is not None:
+ for attr in ("eos_token_id", "eot_token_id"):
+ v = getattr(tokenizer, attr, None)
+ if v is None: continue
+ eos_ids.extend([v] if isinstance(v, int) else list(v))
+ if not eos_ids:
+ cfg = getattr(model, "generation_config", None)
+ if cfg and getattr(cfg, "eos_token_id", None) is not None:
+ eos_ids = [cfg.eos_token_id]
+ else:
+ eos_ids = [2]
+ stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_ids)])
+
+ if tokenizer is not None and getattr(tokenizer, "pad_token_id", None) is None:
+ try: tokenizer.pad_token_id = eos_ids[0]
+ except Exception: pass
+
+ # -------- Generate --------
+ gen_kwargs = dict(
+ max_new_tokens=int(max_new_tokens or 256),
  do_sample=False,
+ stopping_criteria=stopping_criteria,
+ eos_token_id=eos_ids[0] if eos_ids else None,
+ pad_token_id=getattr(tokenizer, "pad_token_id", None) if tokenizer else None,
  )
+ with torch.inference_mode():
+ out = model.generate(**inputs, **gen_kwargs)
+
+ # -------- Decode --------
+ seq = out[0]
+ if "input_ids" in inputs:
+ cut = inputs["input_ids"].shape[-1]
+ seq = seq[cut:]
+ if tokenizer is not None:
+ text_out = tokenizer.decode(seq, skip_special_tokens=True)
+ elif hasattr(processor, "batch_decode"):
+ text_out = processor.batch_decode(seq.unsqueeze(0), skip_special_tokens=True)[0]
+ else:
+ text_out = str(seq.tolist())
+
+ return text_out.strip()
+

+ def generate_medgemma_report(
  patient_info: str,
  visual_results: Dict,
  guideline_context: str,
+ image_pil: Image.Image,
  max_new_tokens: Optional[int] = None,
  ) -> str:
  """
+ MedGemma replacement using a vision-language model.
+ Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints.
  """
  if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
  return "⚠️ VLM disabled"

+ model_id = os.getenv("SMARTHEAL_VLM_MODEL", "bczhou/tiny-llava-v1-hf")
  max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))

  uprompt = SMARTHEAL_USER_PREFIX.format(
 
  guideline_context=(guideline_context or "")[:900],
  )

+ # The `messages` structure is passed to the verified `_vlm_infer_gpu` function
+ messages = [
+ {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
+ {"role": "user", "content": [
+ {"type": "image", "image": image_pil},
+ {"type": "text", "text": uprompt},
+ ]},
+ ]

  try:
+ return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
  except Exception as e:
+ logging.error(f"VLM call failed: {e}", exc_info=True)
+ return f"⚠️ VLM error: {e}"
  # ---------- Initialize CPU models ----------
  def load_yolo_model():
  YOLO = _import_ultralytics()
+ # Construct model with CUDA masked to avoid auto-selecting cuda:0
  with _no_cuda_env():
  model = YOLO(YOLO_MODEL_PATH)
  return model
  def load_segmentation_model():
  import tensorflow as tf
  load_model = _import_tf_loader()
 
  if "seg" not in models_cache:
  try:
  if os.path.exists(SEG_MODEL_PATH):
+ models_cache["seg"] = load_segmentation_model()
+ m = models_cache["seg"]
+ ishape = getattr(m, "input_shape", None)
  oshape = getattr(m, "output_shape", None)
+ logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
  else:
  models_cache["seg"] = None
  logging.warning("Segmentation model file missing; skipping.")
 
  # --- Model path ---
  if seg_model is not None:
  try:
+ ishape = getattr(seg_model, "input_shape", None)
+ if not ishape or len(ishape) < 4:
+ raise ValueError(f"Bad seg input_shape: {ishape}")
+ th, tw = int(ishape[1]), int(ishape[2])
+
  x = _preprocess_for_seg(image_bgr, (th, tw))
  roi_seen_path = None
  if SMARTHEAL_DEBUG:
 
  cnt = max(contours, key=cv2.contourArea)
  rect = cv2.minAreaRect(cnt)
  (w_px, h_px) = rect[1]
+ length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
  length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
  breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
  box = cv2.boxPoints(rect).astype(int)
 
  if not vs:
  return "Knowledge base is not available."
  retriever = vs.as_retriever(search_kwargs={"k": 5})
+ # Modern API (avoid get_relevant_documents deprecation)
  docs = retriever.invoke(query)
  lines: List[str] = []
  for d in docs:
 

  def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
  return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
+
  ## 📋 Patient Information
  {patient_info}
+
  ## 🔍 Visual Analysis Results
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
  - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
+
  ## 📊 Analysis Images
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
  - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
  - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
+
  ## 🎯 Clinical Summary
  Automated analysis provides quantitative measurements; verify via clinical examination.
+
  ## 💊 Recommendations
  - Cleanse wound gently; select dressing per exudate/infection risk
  - Debride necrotic tissue if indicated (clinical decision)
  - Document with serial photos and measurements
+
  ## 📅 Monitoring
  - Daily in week 1, then every 2–3 days (or as indicated)
  - Weekly progress review
+
  ## 📚 Guideline Context
  {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
+
  **Disclaimer:** Automated, for decision support only. Verify clinically.
  """

 
  except Exception as e:
  logging.error(f"Failed to save/commit image: {e}")
  return ""
+
+
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
  try:
  saved_path = self.save_and_commit_image(image_pil)
 
  "saved_image_path": None,
  "guideline_context": "",
  }
+
  def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
  try:
  if isinstance(image, str):
 
  "report": f"Analysis initialization failed: {str(e)}",
  "saved_image_path": None,
  "guideline_context": "",
+ }
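
Usage sketch for the new VLM-backed report path introduced by this commit — a minimal sketch, not part of the diff. It assumes the updated file is importable as src.ai_processor, that the environment variables behave as shown in the hunks above, and that the image path, patient text, and dictionary values are illustrative; the visual_results keys mirror the fields read by _generate_fallback_report, while the exact keys consumed by SMARTHEAL_USER_PREFIX.format(...) are not fully visible in this diff.

    # Hedged example of calling generate_medgemma_report after this change.
    # Paths, patient details, and measurement values below are illustrative only.
    import os
    from PIL import Image

    os.environ.setdefault("SMARTHEAL_ENABLE_VLM", "1")
    os.environ.setdefault("SMARTHEAL_VLM_MODEL", "bczhou/tiny-llava-v1-hf")  # default used in the diff
    os.environ.setdefault("SMARTHEAL_VLM_MAX_TOKENS", "400")

    from src.ai_processor import generate_medgemma_report  # name kept by the commit

    image = Image.open("uploads/example_wound.jpg").convert("RGB")  # illustrative path
    visual_results = {  # keys mirror the fallback report's visual_results.get(...) calls
        "wound_type": "pressure ulcer",
        "length_cm": 3.2,
        "breadth_cm": 1.8,
        "surface_area_cm2": 4.1,
        "detection_confidence": 0.87,
        "px_per_cm": 38.0,
    }
    report = generate_medgemma_report(
        patient_info="68-year-old with diabetes; sacral wound, 3 weeks",  # illustrative
        visual_results=visual_results,
        guideline_context="(retrieved guideline snippets)",
        image_pil=image,
        max_new_tokens=400,
    )
    print(report)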