import os
import io
import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora

# Define paths
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Initialize the model once at startup
print("Loading model...")
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe.transformer = transformer
pipe.to("cuda")
print("Model loaded successfully!")

# Clear the key/value banks cached by EasyControl's attention processors
# so that state does not leak between requests
def clear_cache(transformer):
    for attn_processor in transformer.attn_processors.values():
        if hasattr(attn_processor, "bank_kv"):
            attn_processor.bank_kv.clear()

# Create FastAPI app
app = FastAPI(
    title="Ghibli Image Generator API",
    description="Convert images to Ghibli Studio style using EasyControl",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Root endpoint
@app.get("/")
async def root():
    return {"status": "running"}

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "loaded"}

# Main image conversion endpoint
@app.post("/generate-ghibli")
async def generate_ghibli(
    file: UploadFile = File(...),
    prompt: str = Form("Ghibli Studio style, Charming hand-drawn anime-style illustration"),
    height: int = Form(768),
    width: int = Form(768),
    seed: int = Form(42),
):
    try:
        # Validate input image (content_type may be None for some clients)
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and decode the image; convert to RGB since the pipeline
        # expects three-channel input
        image_data = await file.read()
        try:
            spatial_img = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}")

        # Validate dimensions
        if height < 256 or height > 1024 or width < 256 or width > 1024:
            raise HTTPException(status_code=400, detail="Dimensions must be between 256 and 1024")

        # Attach the Ghibli control LoRA (re-applied on each request)
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

        # Generate image under mixed precision
        with torch.autocast("cuda"):
            output = pipe(
                prompt,
                height=height,
                width=width,
                guidance_scale=3.5,
                num_inference_steps=25,
                max_sequence_length=512,
                generator=torch.Generator("cpu").manual_seed(seed),
                subject_images=[],
                spatial_images=[spatial_img],
                cond_size=512,
            ).images[0]

        # Clear the attention-processor cache before the next request
        clear_cache(pipe.transformer)

        # Encode the output as PNG bytes
        img_byte_arr = io.BytesIO()
        output.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)

        # Return the image directly
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    except HTTPException:
        # Re-raise validation errors unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

# Run the API with uvicorn (assumes this file is saved as app.py)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
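
# --- Example client (illustrative sketch, not part of the server) ---
# A minimal way to exercise the /generate-ghibli endpoint with the
# `requests` library, assuming the server is running locally on port
# 7860. The file names "input.jpg" and "output.png" are placeholders;
# prompt, height, width, and seed fall back to the Form defaults above
# if omitted.
#
#   import requests
#
#   with open("input.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/generate-ghibli",
#           files={"file": ("input.jpg", f, "image/jpeg")},
#           data={"height": 768, "width": 768, "seed": 42},
#       )
#   resp.raise_for_status()  # surfaces the API's 4xx/5xx error details
#   with open("output.png", "wb") as out:
#       out.write(resp.content)  # raw PNG bytes returned by the endpoint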