Spaces:
Runtime error
Runtime error
import base64
import io
import json
import logging
import os
import time
from typing import Optional

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image
from safetensors.torch import save_file

from src.lora_helper import set_multi_lora, set_single_lora, unset_lora
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
# Define paths
# Hub id of the FLUX.1-dev base checkpoint.
base_path = "black-forest-labs/FLUX.1-dev"
# Local directory holding the LoRA weight files (e.g. Ghibli.safetensors).
lora_base_path = "./models"

# Initialize the model once, at import time, so every request reuses it.
print("Loading model...")
# Load the project's FluxTransformer2DModel first and hand it to the pipeline
# as a component override. Otherwise from_pretrained also loads the stock
# transformer weights only to have them replaced immediately, which doubles
# peak memory and start-up time for the largest component of the model.
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    base_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
print("Model loaded successfully!")
# Function to clear cache
def clear_cache(transformer):
    """Empty the ``bank_kv`` store of every attention processor on *transformer*.

    Each processor in ``transformer.attn_processors`` that exposes a
    ``bank_kv`` attribute has it cleared in place via ``.clear()``. Processors
    without that attribute have nothing to clear and are skipped.
    (Presumably ``bank_kv`` caches condition key/values between pipeline
    steps -- confirm against src.lora_helper.)

    Args:
        transformer: Model exposing an ``attn_processors`` mapping of
            processor name -> processor object.
    """
    for processor in transformer.attn_processors.values():
        bank = getattr(processor, "bank_kv", None)
        # Skip processors that never had a KV bank (e.g. stock processors)
        # instead of failing the whole request with AttributeError.
        if bank is not None:
            bank.clear()
# Create FastAPI app
app = FastAPI(
    title="Ghibli Image Generator API",
    description="Convert images to Ghibli Studio style using EasyControl",
)

# Add CORS middleware: accept cross-origin requests from any origin, using any
# HTTP method and any request header.
# NOTE(review): allow_credentials=True together with wildcard origins is
# discouraged by the CORS spec and the FastAPI docs. Nothing visible in this
# file reads cookies or auth headers, so allow_credentials=False is probably
# enough -- confirm with the API's clients before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def myfunc():
    """Report that the service process is up.

    Returns:
        A new ``{"status": "running"}`` dict on every call.
    """
    # NOTE(review): no route decorator is visible on this function, so as shown
    # it is not registered with ``app`` -- confirm whether ``@app.get(...)``
    # was intended.
    status = "running"
    return {"status": status}
# Health check endpoint
async def health_check():
    """Return a constant health payload.

    The pipeline is loaded at module import time, before this function is
    defined. If the module imported successfully, the model is therefore
    available, which is why the payload is static.

    Returns:
        ``{"status": "healthy", "model": "loaded"}``.
    """
    # NOTE(review): no route decorator is visible here either -- confirm the
    # intended path registration (e.g. a GET health route).
    return dict(status="healthy", model="loaded")
# Main image conversion endpoint
async def generate_ghibli(
    file: UploadFile = File(...),
    prompt: str = Form("Ghibli Studio style, Charming hand-drawn anime-style illustration"),
    height: int = Form(768),
    width: int = Form(768),
    seed: int = Form(42),
):
    """Restyle an uploaded image with the Ghibli LoRA and return it as a PNG.

    The upload is passed to ``pipe`` as the single spatial condition image.
    Generation uses 25 inference steps at guidance scale 3.5, with a CPU
    ``torch.Generator`` seeded by *seed*.

    Args:
        file: Uploaded image; its Content-Type must start with ``image/``.
        prompt: Text prompt for the pipeline.
        height: Output height in pixels, inclusive range 256-1024.
        width: Output width in pixels, inclusive range 256-1024.
        seed: Seed for the generator, so identical inputs reproduce.

    Returns:
        A ``Response`` whose body is the generated image, ``image/png``.

    Raises:
        HTTPException: 400 if the upload is not an image, cannot be decoded,
            or the dimensions are out of range; 500 for any other failure.
    """
    # NOTE(review): the pipeline call below is synchronous inside an async
    # handler, so it blocks the event loop for the whole generation. That
    # serializes requests (incidentally protecting the shared pipe/LoRA state)
    # but stalls every other endpoint in the meantime.
    try:
        # content_type may be None when the client omits the part's
        # Content-Type; treat that as "not an image" (400), not a crash (500).
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        image_data = await file.read()
        try:
            # Image.open() is lazy and only parses the header. convert() forces
            # a full decode, so corrupt or truncated uploads are rejected here
            # with a 400. It also normalizes RGBA/palette/grayscale uploads to
            # 3-channel RGB (assumed to be what the conditioning path expects).
            spatial_img = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e

        if height < 256 or height > 1024 or width < 256 or width > 1024:
            raise HTTPException(status_code=400, detail="Dimensions must be between 256 and 1024")

        # NOTE(review): the LoRA is (re)applied on every request; presumably
        # idempotent -- consider applying it once at startup instead.
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

        try:
            # Autocast in bfloat16, the dtype the model is loaded with; CUDA
            # autocast otherwise defaults to float16. torch.autocast also
            # replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output = pipe(
                    prompt,
                    height=height,
                    width=width,
                    guidance_scale=3.5,
                    num_inference_steps=25,
                    max_sequence_length=512,
                    generator=torch.Generator("cpu").manual_seed(seed),
                    subject_images=[],
                    spatial_images=[spatial_img],
                    cond_size=512,
                ).images[0]
        finally:
            # Clear the processors' KV banks even when generation fails, so
            # whatever was cached for this request is not carried into the next.
            clear_cache(pipe.transformer)

        buffer = io.BytesIO()
        output.save(buffer, format="PNG")
        return Response(content=buffer.getvalue(), media_type="image/png")
    except HTTPException:
        raise
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets the message.
        logging.getLogger(__name__).exception("Error processing image")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# Run the API with uvicorn
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object, not the "app:app" import string. Under
    # `python <file>.py` this module runs as __main__, so the string would make
    # uvicorn import the file a second time as module "app". That re-runs the
    # module-level model loading (a second copy of the weights on the GPU) and
    # fails outright if the file is not named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)