# agentic-coach-advisor-medgemma / modal-medgemma.py
# David Tang
# add modal backend
# e0baf14
"""
modal-medgemma.py
-----------------
This module defines a Modal app and a FastAPI endpoint for running the MedGemma agent, a multimodal LLM that can process text and images. It exposes a streaming inference API with Wikipedia tool-calling, and handles model download, model loading, and GPU-backed inference.

N.B. This module must be deployed with the following command:
    `modal deploy modal-medgemma.py`

Key components:
- Modal app and volume setup for model weights
- MedGemmaAgent class for model loading and inference
- FastAPI endpoint for streaming responses
- Utility for processing base64-encoded images
"""
import modal
from typing import Optional, Generator, Dict, Any, List
import os
from fastapi import Security, HTTPException, Depends
from fastapi.security.api_key import APIKeyHeader
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import json
import base64
from PIL import Image
import io
# Modal application under which every function, class and endpoint below is registered.
app = modal.App("example-medgemma-agent")
# Named volume that stores the model weights; created on first use if it does not exist.
volume = modal.Volume.from_name("model-weights-vol", create_if_missing=True)
# Mount point of the volume inside containers; also used as the Hugging Face cache dir
# (see the HF_HUB_CACHE entry in the image environment).
MODEL_DIR = "/models"
# Alternative checkpoint, currently disabled:
# MODEL_ID = "google/medgemma-4b-it"
MODEL_ID = "unsloth/medgemma-4b-it-unsloth-bnb-4bit"
# Multiplier so timeouts can be written as `N * MINUTES`.
MINUTES = 60
# Name of the request header that carries the caller's API key.
API_KEY_NAME = "X-API-Key"
# Security scheme that reads the key from the header above; with auto_error=True
# FastAPI itself rejects requests that omit the header.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
async def get_api_key(api_key_header: str = Security(api_key_header)):
    """
    Validates the provided API key against the FASTAPI_KEY environment variable.

    Args:
        api_key_header (str): The API key provided in the request header.

    Raises:
        HTTPException: 403 if the API key does not match the configured key.
        KeyError: If the FASTAPI_KEY environment variable is not set
            (a deployment error, surfaced as a server-side failure).

    Returns:
        str: The validated API key.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    expected_key = os.environ["FASTAPI_KEY"]
    # Constant-time comparison so response timing does not leak how many leading
    # characters of a guessed key were correct. Comparing bytes (not str) avoids
    # the TypeError compare_digest raises for non-ASCII str input.
    keys_match = secrets.compare_digest(
        api_key_header.encode("utf-8"),
        expected_key.encode("utf-8"),
    )
    if not keys_match:
        raise HTTPException(
            status_code=403, detail="Invalid API Key"
        )
    return api_key_header
# Container image shared by the download function, the GPU class and the web endpoint.
image = (
    modal.Image.debian_slim()
    .pip_install(
        "smolagents[vllm]",
        "fastapi[standard]",
        "wikipedia-api",
        "accelerate",
        "bitsandbytes",
        "huggingface-hub[hf_transfer]",
        "Pillow",
    )
    .env(
        {
            # Enable the hf_transfer download backend installed above.
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            # Point the Hugging Face cache at the path where the weights volume is mounted.
            "HF_HUB_CACHE": MODEL_DIR,
        }
    )
)

# These packages are installed in the container image above, so they are imported
# inside the image.imports() context rather than unconditionally at module load.
with image.imports():
    from smolagents import VLLMModel, ToolCallingAgent, tool
    from pydantic import BaseModel
    import wikipediaapi
@app.function(
    image=image,
    secrets=[modal.Secret.from_name("access_medgemma_hf")],
    volumes={MODEL_DIR: volume},
)
def download_model():
    """
    Downloads the model weights from Hugging Face Hub using the provided token.

    The files land in the Hugging Face cache, which the image environment points
    at MODEL_DIR (the mounted weights volume). The volume is committed before
    returning so the weights are persisted by the time the caller resumes.

    Raises:
        KeyError: If the HF_TOKEN environment variable is not set.

    Returns:
        dict: Status message indicating success.
    """
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=MODEL_ID,
        token=os.environ["HF_TOKEN"],
    )
    # Persist the written files explicitly instead of relying on a later
    # background/exit commit: callers use the weights right after this returns.
    volume.commit()
    return {"status": "Model downloaded successfully"}
@app.cls(
    image=image,
    gpu="L4:1",
    volumes={MODEL_DIR: volume},
    min_containers=1,
    max_containers=1,
    timeout=15 * MINUTES,
    secrets=[modal.Secret.from_name("access_medgemma_hf")],
)
class MedGemmaAgent:
    """
    Modal class for managing the MedGemma model and running inference with optional tool-calling.

    The model is loaded once per container in `load_models`; each `run` call builds
    a fresh tool-calling agent around that model.
    """

    @modal.enter()
    def load_models(self):
        """
        Loads the MedGemma model into memory and prepares it for inference.
        Downloads the model weights if not already present.
        """
        download_model.remote()
        # The download ran in a different container. Refresh this container's view
        # of the volume so the weights it wrote are visible before the model loads.
        volume.reload()
        model_kwargs = {
            "max_model_len": 8192,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.95,
            "tensor_parallel_size": 1,
            "trust_remote_code": True,
        }
        self.model = VLLMModel(
            model_id=MODEL_ID,
            model_kwargs=model_kwargs,
        )
        print(f"Model: {MODEL_ID} loaded successfully")

    @modal.method()
    def run(self, prompt: str, images: Optional[List[Image.Image]] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Runs the MedGemma agent on the provided prompt and optional images, yielding streaming responses.

        The agent itself is run with stream=False, so exactly two chunks are
        yielded: one 'thinking' chunk before the run and one 'final' chunk after it.

        Args:
            prompt (str): The user prompt to process.
            images (Optional[List[Image.Image]]): List of PIL Images to provide as context (optional).

        Yields:
            Dict[str, Any]: Streaming response chunks, including 'thinking' and 'final' messages.
        """
        @tool
        def wiki(query: str) -> str:
            """
            Fetches a summary of a Wikipedia page based on a given search query (only one word or group of words).
            Args:
                query: The search term for the Wikipedia page (only one word or group of words).
            """
            # Named wiki_client so it does not shadow the enclosing tool function `wiki`.
            wiki_client = wikipediaapi.Wikipedia(language="en", user_agent="MinimalAgent/1.0")
            page = wiki_client.page(query)
            if not page.exists():
                return "No Wikipedia page found."
            # Cap the summary length to keep the tool output short.
            return page.summary[:1000]

        self.agent = ToolCallingAgent(
            tools=[wiki],
            model=self.model,
            max_steps=3,
        )
        # Yield thinking step
        yield {
            "type": "thinking",
            "content": {"message": "Starting to process your request..."},
        }
        # Run the agent and capture the result
        result = self.agent.run(
            task=prompt,
            stream=False,
            reset=True,
            # An empty list is normalised to None, same as "no images".
            images=images or None,
            additional_args={"flatten_messages_as_text": False},
            max_steps=3,
        )
        # Yield the final response
        yield {
            "type": "final",
            "content": {"response": result},
        }
class PromptRequest(BaseModel):
    """
    Request model for the /run_medgemma endpoint.

    Attributes:
        prompt (str): The user prompt to process.
        image (Optional[str]): Base64-encoded image string (optional).
        history (Optional[list]): Conversation history (optional). Accepted by the
            schema, but run_medgemma in this file does not forward it to the agent.
    """
    prompt: str
    image: Optional[str] = None  # Base64 encoded image
    history: Optional[list] = None
class GenerationResponse(BaseModel):
    """
    Response model for non-streaming generation (not used in this file).

    Attributes:
        response (str): The generated response.
    """
    response: str
class StreamResponse(BaseModel):
    """
    Response model describing the shape of one streamed chunk.

    The endpoint in this file serializes plain dicts of this shape rather than
    instantiating this model, and MedGemmaAgent.run only emits the 'thinking'
    and 'final' types.

    Attributes:
        type (str): The type of message ('thinking', 'tool_call', 'tool_result', 'final').
        content (Dict[str, Any]): The content of the message.
    """
    type: str  # 'thinking', 'tool_call', 'tool_result', 'final'
    content: Dict[str, Any]
def process_image(image_base64: Optional[str]) -> Optional[Image.Image]:
    """
    Decodes a base64-encoded image string into a PIL Image.

    Accepts either a bare base64 payload or a data URL of the form
    "data:<mime>;base64,<payload>". The image is fully decoded before being
    returned, so truncated or corrupt data is reported here rather than later
    when the image is first used.

    Args:
        image_base64 (Optional[str]): Base64-encoded image string.

    Returns:
        Optional[Image.Image]: The decoded PIL Image, or None if decoding fails or input is None.
    """
    if not image_base64:
        return None
    payload = image_base64.strip()
    # Strip a data-URL header if present. A comma never occurs inside base64
    # itself, so everything after the first comma is the payload.
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    try:
        image_data = base64.b64decode(payload)
        decoded = Image.open(io.BytesIO(image_data))
        # Image.open only reads the header; force the pixel data to be decoded
        # now so that bad data is caught by this handler.
        decoded.load()
        return decoded
    except Exception as e:
        # Deliberately best-effort: callers treat None as "no usable image".
        print(f"Error processing image: {e}")
        return None
@app.function(
    image=image,
    secrets=[
        modal.Secret.from_name("access_medgemma_hf"),
        modal.Secret.from_name("FASTAPI_KEY"),
    ],
)
@modal.fastapi_endpoint(method="POST")
async def run_medgemma(request: PromptRequest, api_key: str = Depends(get_api_key)):
    """
    FastAPI endpoint for running the MedGemma agent with streaming responses.

    Args:
        request (PromptRequest): The request payload containing prompt and optional image.
        api_key (str): The validated API key (injected by Depends).

    Raises:
        HTTPException: 400 if an image was supplied but could not be decoded.

    Returns:
        StreamingResponse: An event-stream response yielding model output chunks,
        each formatted as a `data: <json>` server-sent event.
    """
    model_handler = MedGemmaAgent()
    # Process image if provided. Named decoded_image so it does not shadow the
    # module-level Modal `image` definition.
    decoded_image = process_image(request.image)
    if request.image and decoded_image is None:
        # Do not silently answer without the image the caller attached:
        # the response would look authoritative while ignoring their input.
        raise HTTPException(
            status_code=400,
            detail="Invalid image: expected a base64-encoded image",
        )
    images = [decoded_image] if decoded_image is not None else None

    async def generate():
        stream = model_handler.run.remote_gen.aio(request.prompt, images=images)
        async for chunk in stream:
            # default=str: the agent's final answer is not guaranteed to be a
            # JSON-native type; stringify it rather than failing mid-stream.
            yield f"data: {json.dumps(chunk, default=str)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
    )