"""FastAPI service that builds random image-generation prompts from word lists.

The word lists are read from ``categories.json`` (a JSON object mapping each
category name to a list of strings) and kept in memory for ``CACHE_TTL``
seconds. Requests are rate limited per route by slowapi and, additionally,
per client IP by the ``limit_request_frequency`` middleware.
"""

import json
import math
import pickle
import random
import time
from collections import defaultdict
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

MIN_PROMPTS = 1
MAX_PROMPTS = 1000
RATE_LIMIT = "100/minute"
CACHE_TTL = 300  # seconds before CATEGORIES_FILE is re-read
CATEGORIES_FILE = "categories.json"
CACHE_FILE = "cache.pkl"

# Settings for the per-IP middleware limiter (same budget as RATE_LIMIT).
RATE_WINDOW_SECONDS = 60
MAX_REQUESTS_PER_WINDOW = 100

LOCK = Lock()
# Client IP -> time.monotonic() timestamps of its requests inside the window.
IP_REQUESTS = defaultdict(list)
# time.monotonic() of the last sweep that removed inactive IPs from IP_REQUESTS.
_last_ip_sweep = 0.0

logger.add("app.log", rotation="500 MB", retention="2 days", level="ERROR")

# In-memory categories and the time.time() at which they were loaded.
categorias_cache = None
last_cache_update = 0

app = FastAPI(
    title="API de Generación de Prompts",
    version="1.0.0",
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes
# Starlette echo any Origin back for credentialed requests. This API reads no
# cookies/auth, so allow_credentials=False is probably intended — confirm
# with the API's consumers before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address, default_limits=[RATE_LIMIT])
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


def load_categories() -> Dict[str, List[str]]:
    """Read and return the category word lists from CATEGORIES_FILE.

    Raises:
        HTTPException: 500 if the file is missing/unreadable, is not valid
            UTF-8 JSON, or its top level is not a JSON object.
    """
    with LOCK:
        try:
            # Explicit encoding: the word lists may contain non-ASCII text and
            # the platform default encoding is not guaranteed to be UTF-8.
            with open(CATEGORIES_FILE, "r", encoding="utf-8") as file:
                data = json.load(file)
        except (OSError, ValueError) as e:
            # OSError covers FileNotFoundError/PermissionError; ValueError
            # covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error al cargar 'categories.json': {str(e)}")
            raise HTTPException(
                status_code=500, detail="Error al cargar categorías."
            ) from e
    if not isinstance(data, dict):
        logger.error("Error al cargar 'categories.json': se esperaba un objeto JSON.")
        raise HTTPException(status_code=500, detail="Error al cargar categorías.")
    return data


def save_cache(data):
    """Pickle *data* into CACHE_FILE, overwriting any previous content."""
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(data, f)


def load_cache():
    """Return the object pickled in CACHE_FILE, or None if it is absent/corrupt."""
    # SECURITY: pickle.load can execute arbitrary code. Only ever point this at
    # a file written by save_cache, never at user-supplied data.
    try:
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # EOFError is what an empty or truncated pickle file raises.
        return None


def get_cached_categories() -> Dict[str, List[str]]:
    """Return the categories, reloading CATEGORIES_FILE once CACHE_TTL expires.

    Raises:
        HTTPException: propagated from load_categories when a reload fails.
    """
    global categorias_cache, last_cache_update
    current_time = time.time()
    if categorias_cache is None or (current_time - last_cache_update) > CACHE_TTL:
        categorias_cache = load_categories()
        # Persisting the snapshot is best-effort: a write failure must not
        # turn every subsequent request into a reload-and-500 loop.
        try:
            save_cache(categorias_cache)
        except (OSError, pickle.PicklingError) as e:
            logger.error(f"No se pudo guardar la caché en '{CACHE_FILE}': {e}")
        last_cache_update = current_time
    return categorias_cache


def calcular_combinaciones(record_count: Dict[str, int]) -> int:
    """Return how many distinct prompts exist: the product of all category sizes.

    An empty mapping yields 1 (the empty product).
    """
    return math.prod(record_count.values())


@app.get("/")
@limiter.limit(RATE_LIMIT)
async def read_root(request: Request):
    """Return a welcome message."""
    logger.info("Endpoint raíz consultado.")
    return {"message": "Bienvenido a la API de generación de prompts"}


@app.get("/generate")
@limiter.limit(RATE_LIMIT)
async def detail_generate_prompts(request: Request):
    """Reject /generate without a count, explaining the expected URL form."""
    logger.warning("Intento de generar prompts sin cantidad especificada.")
    raise HTTPException(
        status_code=400,
        detail="Debe especificar la cantidad de prompts a generar, por ejemplo: /generate/10",
    )


@app.get("/generate/{cantidad}")
@limiter.limit(RATE_LIMIT)
async def generate_prompts(request: Request, cantidad: int):
    """Return ``cantidad`` random prompts.

    Raises:
        HTTPException: 400 if cantidad is outside [MIN_PROMPTS, MAX_PROMPTS];
            500 if the categories cannot be loaded, or a category the prompt
            template needs is missing or empty.
    """
    if cantidad < MIN_PROMPTS or cantidad > MAX_PROMPTS:
        logger.warning(
            f"Intento de generar {cantidad} prompts: fuera del rango permitido."
        )
        raise HTTPException(
            status_code=400,
            detail=f"La cantidad debe estar entre {MIN_PROMPTS} y {MAX_PROMPTS}.",
        )
    # Fetch once so the whole batch uses one consistent snapshot, even if the
    # TTL would expire partway through.
    categorias = get_cached_categories()
    try:
        prompts = [generar_base_prompt(categorias) for _ in range(cantidad)]
    except (KeyError, IndexError) as e:
        # KeyError: a category is missing; IndexError: a category list is empty.
        logger.error(f"Categorías incompletas o vacías en '{CATEGORIES_FILE}': {e!r}")
        raise HTTPException(
            status_code=500, detail="Error al cargar categorías."
        ) from e
    logger.info(f"Generados {cantidad} prompts exitosamente.")
    return {"prompts": prompts}


@app.get("/count_records")
@limiter.limit(RATE_LIMIT)
async def count_records(request: Request):
    """Return the size of each category and the total number of combinations."""
    categorias = get_cached_categories()
    record_count = {key: len(value) for key, value in categorias.items()}
    logger.info(f"Número de registros por etiqueta: {record_count}")
    total_combinations = calcular_combinaciones(record_count)
    logger.info(f"Total de combinaciones posibles: {total_combinations}")
    return {"record_count": record_count, "total_combinations": total_combinations}


def _purge_stale_ips(now: float) -> None:
    """Remove IPs from IP_REQUESTS whose latest request is outside the window."""
    stale = [
        ip
        for ip, stamps in IP_REQUESTS.items()
        if not stamps or now - stamps[-1] >= RATE_WINDOW_SECONDS
    ]
    for ip in stale:
        del IP_REQUESTS[ip]


@app.middleware("http")
async def limit_request_frequency(request: Request, call_next):
    """Allow at most MAX_REQUESTS_PER_WINDOW requests per IP per sliding window.

    Returns a 429 JSON response when the limit is exceeded. The response is
    returned directly rather than raised: HTTP middleware runs outside
    FastAPI's exception handlers, so a raised HTTPException would surface to
    the client as a 500.
    """
    global _last_ip_sweep
    ip = request.client.host if request.client else "unknown"
    # Monotonic clock: immune to wall-clock adjustments (NTP, DST).
    now = time.monotonic()

    # Periodically drop inactive IPs so IP_REQUESTS does not grow without bound.
    if now - _last_ip_sweep >= RATE_WINDOW_SECONDS:
        _purge_stale_ips(now)
        _last_ip_sweep = now

    recent = [stamp for stamp in IP_REQUESTS[ip] if now - stamp < RATE_WINDOW_SECONDS]
    IP_REQUESTS[ip] = recent
    if len(recent) >= MAX_REQUESTS_PER_WINDOW:
        logger.warning(f"Bloqueo temporal para la IP {ip}, demasiadas solicitudes.")
        return JSONResponse(
            status_code=429,
            content={"detail": "Demasiadas solicitudes. Espere 1 minuto."},
        )
    recent.append(now)
    return await call_next(request)


def generar_base_prompt(categorias: Optional[Dict[str, List[str]]] = None) -> str:
    """Build one prompt by picking a random entry from each category.

    Args:
        categorias: word lists to draw from; defaults to get_cached_categories().

    Raises:
        KeyError: if a category used by the template is missing.
        IndexError: if a category list is empty.
    """
    if categorias is None:
        categorias = get_cached_categories()
    return (
        f"A {random.choice(categorias['edad'])} {random.choice(categorias['sexo'])} "
        f"{random.choice(categorias['tipo'])} with {random.choice(categorias['peinado'])} "
        f"({random.choice(categorias['color_cabello'])}) and {random.choice(categorias['ojos'])}, "
        f"having {random.choice(categorias['piel'])}, wearing {random.choice(categorias['ropa'])}, "
        f"in a {random.choice(categorias['escenario'])}, {random.choice(categorias['pose'])} while feeling "
        f"{random.choice(categorias['emocion'])}, adorned with {random.choice(categorias['extras'])}."
    )