"""
modal-medgemma.py
-----------------

This module defines a Modal app and FastAPI endpoint for running the
MedGemma agent, a multimodal LLM that can process text and images. It
provides a streaming API for inference, including Wikipedia tool-calling,
and handles model download, loading, and inference with GPU support.

NB: the app must be deployed with the following command:

`modal deploy modal-medgemma.py`

Key components:
- Modal app and volume setup for model weights
- MedGemmaAgent class for model loading and inference
- FastAPI endpoint for streaming responses
- Utility for processing base64-encoded images
"""

import base64
import io
import json
import os
from typing import Any, Dict, Generator, List, Optional

import modal
from fastapi import Depends, HTTPException, Security
from fastapi.responses import StreamingResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from PIL import Image

app = modal.App("example-medgemma-agent")
volume = modal.Volume.from_name("model-weights-vol", create_if_missing=True)
MODEL_DIR = "/models"

MODEL_ID = "unsloth/medgemma-4b-it-unsloth-bnb-4bit"
MINUTES = 60

API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)


async def get_api_key(api_key_header: str = Security(api_key_header)):
    """
    Validates the provided API key against the FASTAPI_KEY environment variable.

    Args:
        api_key_header (str): The API key provided in the request header.

    Raises:
        HTTPException: If the API key is invalid.

    Returns:
        str: The validated API key.
    """
    if api_key_header != os.environ["FASTAPI_KEY"]:
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return api_key_header


image = (
    modal.Image.debian_slim()
    .pip_install(
        "smolagents[vllm]",
        "fastapi[standard]",
        "wikipedia-api",
        "accelerate",
        "bitsandbytes",
        "huggingface-hub[hf_transfer]",
        "Pillow",
    )
    .env({
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
        "HF_HUB_CACHE": MODEL_DIR,
    })
)

with image.imports():
    from smolagents import VLLMModel, ToolCallingAgent, tool
    import wikipediaapi


@app.function(
    image=image,
    secrets=[modal.Secret.from_name("access_medgemma_hf")],
    volumes={MODEL_DIR: volume},
)
def download_model():
    """
    Downloads the model weights from the Hugging Face Hub into the shared
    volume using the provided token.

    Returns:
        dict: Status message indicating success.
    """
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=MODEL_ID,
        token=os.environ["HF_TOKEN"],
    )
    return {"status": "Model downloaded successfully"}
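
# Optional pre-warm: the weights can be fetched ahead of the first request
# (so container start-up does not pay the download cost) with:
#   modal run modal-medgemma.py::download_model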


@app.cls(
    image=image,
    gpu="L4:1",
    volumes={MODEL_DIR: volume},
    min_containers=1,
    max_containers=1,
    timeout=15 * MINUTES,
    secrets=[modal.Secret.from_name("access_medgemma_hf")],
)
class MedGemmaAgent:
    """
    Modal class for managing the MedGemma model and running inference with
    optional tool-calling.
    """

    @modal.enter()
    def load_models(self):
        """
        Loads the MedGemma model into memory and prepares it for inference.
        Downloads the model weights if they are not already present.
        """
        download_model.remote()
        model_kwargs = {
            "max_model_len": 8192,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.95,
            "tensor_parallel_size": 1,
            "trust_remote_code": True,
        }
        self.model = VLLMModel(
            model_id=MODEL_ID,
            model_kwargs=model_kwargs,
        )
        print(f"Model: {MODEL_ID} loaded successfully")

    @modal.method()
    def run(
        self, prompt: str, images: Optional[List[Image.Image]] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Runs the MedGemma agent on the provided prompt and optional images,
        yielding streaming responses.

        Args:
            prompt (str): The user prompt to process.
            images (Optional[List[Image.Image]]): PIL Images to provide as
                context (optional).

        Yields:
            Dict[str, Any]: Streaming response chunks, including 'thinking'
                and 'final' messages.
        """

        @tool
        def wiki(query: str) -> str:
            """
            Fetches a summary of a Wikipedia page for a given search query
            (a single word or group of words).

            Args:
                query: The search term for the Wikipedia page (a single word
                    or group of words).
            """
            wiki_client = wikipediaapi.Wikipedia(
                language="en", user_agent="MinimalAgent/1.0"
            )
            page = wiki_client.page(query)
            if not page.exists():
                return "No Wikipedia page found."
            return page.summary[:1000]

        self.agent = ToolCallingAgent(
            tools=[wiki],
            model=self.model,
            max_steps=3,
        )

        yield {
            "type": "thinking",
            "content": {"message": "Starting to process your request..."},
        }

        result = self.agent.run(
            task=prompt,
            stream=False,
            reset=True,
            images=images if images else None,
            additional_args={"flatten_messages_as_text": False},
            max_steps=3,
        )

        yield {
            "type": "final",
            "content": {"response": result},
        }

class PromptRequest(BaseModel):
    """
    Request model for the /run_medgemma endpoint.

    Attributes:
        prompt (str): The user prompt to process.
        image (Optional[str]): Base64-encoded image string (optional).
        history (Optional[list]): Conversation history (accepted but currently
            unused by the endpoint).
    """

    prompt: str
    image: Optional[str] = None
    history: Optional[list] = None
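
# Example request body (values are illustrative):
#   {"prompt": "Describe this X-ray.", "image": "<base64-encoded JPEG/PNG>"}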


class GenerationResponse(BaseModel):
    """
    Response model for non-streaming generation (not used in this file).

    Attributes:
        response (str): The generated response.
    """

    response: str

class StreamResponse(BaseModel):
    """
    Response model for streaming responses.

    Attributes:
        type (str): The type of message ('thinking', 'tool_call',
            'tool_result', 'final').
        content (Dict[str, Any]): The content of the message.
    """

    type: str
    content: Dict[str, Any]
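
# Chunks from MedGemmaAgent.run are serialized as SSE "data:" lines. Because
# the agent is invoked with stream=False, only 'thinking' and 'final' chunks
# are currently produced, e.g.:
#   data: {"type": "thinking", "content": {"message": "Starting to process your request..."}}
#   data: {"type": "final", "content": {"response": "..."}}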


def process_image(image_base64: Optional[str]) -> Optional[Image.Image]:
    """
    Decodes a base64-encoded image string into a PIL Image.

    Args:
        image_base64 (Optional[str]): Base64-encoded image string.

    Returns:
        Optional[Image.Image]: The decoded PIL Image, or None if decoding
            fails or the input is None.
    """
    if not image_base64:
        return None
    try:
        image_data = base64.b64decode(image_base64)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
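
# Client-side encoding sketch (the file name is illustrative):
#   image_b64 = base64.b64encode(open("scan.png", "rb").read()).decode("utf-8")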


@app.function(
    image=image,
    secrets=[
        modal.Secret.from_name("access_medgemma_hf"),
        modal.Secret.from_name("FASTAPI_KEY"),
    ],
)
@modal.fastapi_endpoint(method="POST")
async def run_medgemma(request: PromptRequest, api_key: str = Depends(get_api_key)):
    """
    FastAPI endpoint for running the MedGemma agent with streaming responses.

    Args:
        request (PromptRequest): The request payload containing the prompt and
            an optional image.
        api_key (str): The validated API key (injected by Depends).

    Returns:
        StreamingResponse: An event-stream response yielding model output
            chunks.
    """
    model_handler = MedGemmaAgent()

    pil_image = process_image(request.image)
    images = [pil_image] if pil_image else None

    async def generate():
        stream = model_handler.run.remote_gen.aio(request.prompt, images=images)
        async for chunk in stream:
            yield f"data: {json.dumps(chunk)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
    )
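
# Example consumer for the event stream (a sketch: the URL is a placeholder
# assigned by Modal at deploy time, and `requests` is assumed to be installed
# on the client):
#
#   import json, requests
#
#   url = "https://<workspace>--example-medgemma-agent-run-medgemma.modal.run"
#   with requests.post(
#       url,
#       json={"prompt": "What is sepsis?"},
#       headers={"X-API-Key": "<your key>"},
#       stream=True,
#   ) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               print(json.loads(line[len(b"data: "):]))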