import os
import json
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Initialize FastAPI app
app = FastAPI(title="GainEnergy/OGAI-24B API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load configuration from environment variables
MODEL_ID = os.environ.get("MODEL_ID", "GainEnergy/OGAI-24B")
DEFAULT_MAX_LENGTH = int(os.environ.get("DEFAULT_MAX_LENGTH", "2048"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.7"))

# Initialize the model and tokenizer
try:
    model = LLM(
        model=MODEL_ID,
        trust_remote_code=True,
        # Use all visible GPUs; fall back to 1 if CUDA reports no devices.
        tensor_parallel_size=torch.cuda.device_count() or 1,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    print(f"Model {MODEL_ID} loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise


# Pydantic models for request/response
class Message(BaseModel):
    role: str = Field(..., description="The role of the message sender (system, user, assistant)")
    content: str = Field(..., description="The content of the message")


class GenerationRequest(BaseModel):
    messages: List[Message] = Field(..., description="List of messages in the conversation")
    temperature: Optional[float] = Field(DEFAULT_TEMPERATURE, description="Temperature for sampling")
    max_tokens: Optional[int] = Field(DEFAULT_MAX_LENGTH, description="Maximum number of tokens to generate")
    top_p: Optional[float] = Field(0.95, description="Top-p sampling parameter")
    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")


class GenerationResponse(BaseModel):
    generated_text: str = Field(..., description="Generated text from the model")


# Helper function to format messages for the model
def format_messages(messages: List[Message]) -> str:
    """Format a list of messages into a prompt string the model can understand."""
    # If the model ships a chat template, tokenizer.apply_chat_template would be the
    # preferred way to build the prompt; the manual tags below follow the original code.
    formatted_prompt = ""
    for message in messages:
        if message.role == "system":
            formatted_prompt += f"<|system|>\n{message.content}\n"
        elif message.role == "user":
            formatted_prompt += f"<|user|>\n{message.content}\n"
        elif message.role == "assistant":
            formatted_prompt += f"<|assistant|>\n{message.content}\n"
    # Add the final assistant tag to prompt the model to generate a response
    formatted_prompt += "<|assistant|>\n"
    return formatted_prompt


# API endpoints
@app.get("/")
async def root():
    """Root endpoint with basic information."""
    return {
        "status": "running",
        "model": MODEL_ID,
        "version": "1.0.0",
    }


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    """Generate text based on the conversation history."""
    try:
        prompt = format_messages(request.messages)
        sampling_params = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
        )
        outputs = model.generate(prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


@app.post("/generate_stream")
async def generate_stream(request: GenerationRequest):
    """Stream generated text based on the conversation history."""
    if not request.stream:
        return await generate(request)
    try:
        prompt = format_messages(request.messages)
        sampling_params = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
        )

        async def stream_generator():
            # vLLM's offline LLM.generate() is blocking and has no stream flag;
            # token-level streaming would require AsyncLLMEngine. As a simple
            # fallback, generate the full completion and emit it as a single
            # server-sent event.
            outputs = model.generate(prompt, sampling_params)
            text = outputs[0].outputs[0].text
            yield f"data: {json.dumps({'text': text})}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Streaming generation failed: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}


# Run the FastAPI app with uvicorn when this script is executed directly
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
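# Example client call (illustrative sketch, not part of the server): assuming the
# server above is running locally on port 7860, a non-streaming request to /generate
# could look like the following. The prompt text and parameter values are hypothetical.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={
#           "messages": [
#               {"role": "system", "content": "You are an oil and gas engineering assistant."},
#               {"role": "user", "content": "Explain the purpose of drilling mud."},
#           ],
#           "temperature": 0.7,
#           "max_tokens": 256,
#       },
#       timeout=120,
#   )
#   print(resp.json()["generated_text"])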