from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import os
import tempfile
import torch
import gradio as gr
import traceback
import sys
import logging
from PIL import Image
from models.llava import LLaVA
from typing import Dict, Any, Optional

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log')
    ]
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="LLaVA Web Interface")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state
model = None
model_status: Dict[str, Any] = {
    "initialized": False,
    "device": None,
    "error": None,
    "last_error": None
}


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler to catch and log all unhandled exceptions."""
    error_msg = f"Unhandled error: {str(exc)}\n{traceback.format_exc()}"
    logger.error(error_msg)
    model_status["last_error"] = error_msg
    return JSONResponse(
        status_code=500,
        content={"error": "Internal Server Error", "details": str(exc)}
    )


@app.get("/status")
async def get_status():
    """Endpoint to check model and application status."""
    return {
        "model_initialized": model is not None,
        "model_status": model_status,
        "memory_usage": {
            "cuda_available": torch.cuda.is_available(),
            "cuda_memory_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            "cuda_memory_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
        }
    }
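
# Example status check against a locally running instance (a sketch; assumes the
# /status route registered above and the default port 7860 configured further below):
#   curl http://localhost:7860/status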


def initialize_model():
    """Initialize the LLaVA model with proper error handling."""
    global model, model_status
    try:
        logger.info("Starting model initialization...")
        model_status["initialized"] = False
        model_status["error"] = None

        # Clear any existing model and memory
        if model is not None:
            del model
            torch.cuda.empty_cache()

        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Initialize new model with basic parameters
        model = LLaVA(
            vision_model_path="openai/clip-vit-base-patch32",
            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            projection_hidden_dim=2048,
            device=device
        )

        # Configure model for inference
        if hasattr(model, 'language_model'):
            # Set model to evaluation mode
            model.language_model.eval()
            if hasattr(model.language_model, 'config'):
                model.language_model.config.use_cache = False

        # Move model to device
        model = model.to(device)

        # Set generation config if available
        if hasattr(model, 'language_model') and hasattr(model.language_model, 'generation_config'):
            model.language_model.generation_config.do_sample = True
            model.language_model.generation_config.max_new_tokens = 256
            model.language_model.generation_config.temperature = 0.7
            model.language_model.generation_config.top_p = 0.9
            if hasattr(model.language_model.config, 'eos_token_id'):
                model.language_model.generation_config.pad_token_id = model.language_model.config.eos_token_id

        model_status.update({
            "initialized": True,
            "device": str(model.device),
            "error": None,
            "model_info": {
                "vision_model": "openai/clip-vit-base-patch32",
                "language_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "device": str(model.device)
            }
        })
        logger.info(f"Model successfully initialized on {model.device}")
        return True
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        model = None
        model_status.update({
            "initialized": False,
            "error": error_msg,
            "last_error": traceback.format_exc()
        })
        return False


def process_image(
    image: Optional[Image.Image],
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> str:
    """Process an image with the LLaVA model with comprehensive error handling."""
    global model_status
    logger.info("Starting image processing...")

    # Validate model state
    if model is None:
        logger.error("Model not initialized")
        if not initialize_model():
            model_status["last_error"] = "Model initialization failed during processing"
            return "Error: Model initialization failed. Please try again later."

    # Validate inputs
    if image is None:
        logger.error("No image provided")
        return "Error: Please upload an image first."
    if not isinstance(image, Image.Image):
        logger.error(f"Invalid image type: {type(image)}")
        return "Error: Invalid image format. Please upload a valid image."
    if not prompt or not isinstance(prompt, str) or not prompt.strip():
        logger.error("Invalid prompt")
        return "Error: Please enter a valid prompt."

    # Validate parameters
    try:
        max_new_tokens = int(max_new_tokens)
        temperature = float(temperature)
        top_p = float(top_p)
    except (ValueError, TypeError) as e:
        logger.error(f"Invalid parameters: {str(e)}")
        return "Error: Invalid generation parameters."

    temp_path = None
    try:
        logger.info(f"Processing image with prompt: {prompt[:100]}...")

        # Save image with explicit format
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            image.save(temp_file.name, format='PNG')
            temp_path = temp_file.name
        logger.info(f"Saved temporary image to {temp_path}")

        # Clear memory
        torch.cuda.empty_cache()

        # Run generation without gradient tracking
        with torch.inference_mode():
            try:
                logger.info("Generating response...")
                # Update generation config if available
                if hasattr(model, 'language_model') and hasattr(model.language_model, 'generation_config'):
                    model.language_model.generation_config.max_new_tokens = max_new_tokens
                    model.language_model.generation_config.temperature = temperature
                    model.language_model.generation_config.top_p = top_p

                response = model.generate_from_image(
                    image_path=temp_path,
                    prompt=prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_beams=1,
                    pad_token_id=model.language_model.config.eos_token_id if hasattr(model, 'language_model') else None
                )

                if not response:
                    raise ValueError("Empty response from model")
                if not isinstance(response, str):
                    raise ValueError(f"Invalid response type: {type(response)}")

                logger.info("Successfully generated response")
                model_status["last_error"] = None
                return response
            except Exception as model_error:
                error_msg = f"Model inference error: {str(model_error)}"
                logger.error(error_msg)
                logger.error(traceback.format_exc())
                model_status["last_error"] = error_msg
                return f"Error during model inference: {str(model_error)}"
    except Exception as e:
        error_msg = f"Processing error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        model_status["last_error"] = error_msg
        return f"Error processing image: {str(e)}"
    finally:
        # Cleanup
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
                logger.info("Cleaned up temporary file")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file: {str(e)}")
        try:
            torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Failed to clear CUDA cache: {str(e)}")


def create_interface():
    """Create a simplified Gradio interface."""
    try:
        with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
            # LLaVA Chat
            Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.

            ## Example Prompts
            Try these prompts to get started:
            - "What can you see in this image?"
            - "Describe this scene in detail"
            - "What emotions does this image convey?"
            - "What's happening in this picture?"
            - "Can you identify any objects or people in this image?"
            """)

            with gr.Row():
                with gr.Column():
                    # Input components
                    image_input = gr.Image(type="pil", label="Upload Image")
                    prompt_input = gr.Textbox(
                        label="Ask about the image",
                        placeholder="What can you see in this image?",
                        lines=3
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        max_tokens = gr.Slider(
                            minimum=32,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max New Tokens"
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.7,
                            step=0.1,
                            label="Temperature"
                        )
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.9,
                            step=0.1,
                            label="Top P"
                        )
                    submit_btn = gr.Button("Generate Response", variant="primary")
                with gr.Column():
                    output = gr.Textbox(
                        label="Model Response",
                        lines=10,
                        show_copy_button=True
                    )

            # Set up event handler
            submit_btn.click(
                fn=process_image,
                inputs=[
                    image_input,
                    prompt_input,
                    max_tokens,
                    temperature,
                    top_p
                ],
                outputs=output
            )

        logger.info("Successfully created Gradio interface")
        return demo
    except Exception as e:
        logger.error(f"Failed to create interface: {str(e)}")
        logger.error(traceback.format_exc())
        raise


# Create and mount Gradio app
try:
    logger.info("Creating Gradio interface...")
    demo = create_interface()
    app = gr.mount_gradio_app(app, demo, path="/")
    logger.info("Successfully mounted Gradio app")
except Exception as e:
    logger.error(f"Failed to mount Gradio app: {str(e)}")
    logger.error(traceback.format_exc())
    raise


if __name__ == "__main__":
    try:
        # Initialize model
        logger.info("Starting application...")
        if not initialize_model():
            logger.error("Model initialization failed. Exiting...")
            sys.exit(1)

        # Start server
        import uvicorn
        logger.info("Starting server...")
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=7860,
            log_level="info"
        )
    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        logger.error(traceback.format_exc())
        sys.exit(1)
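
# Running locally (a sketch; assumes the models.llava package and its dependencies
# are installed, and that this file is the entry point, e.g. app.py):
#   python app.py
# then open http://localhost:7860 for the Gradio UI mounted at "/".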