import os
import json
import torch
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Initialize FastAPI app
app = FastAPI(title="GainEnergy/OGAI-24B API")
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load environment variables
MODEL_ID = os.environ.get("MODEL_ID", "GainEnergy/OGAI-24B")
DEFAULT_MAX_LENGTH = int(os.environ.get("DEFAULT_MAX_LENGTH", "2048"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.7"))
# Initialize the model and tokenizer
try:
    model = LLM(
        model=MODEL_ID,
        trust_remote_code=True,
        # device_count() can be 0 on CPU-only hosts; vLLM requires at least 1
        tensor_parallel_size=max(torch.cuda.device_count(), 1),
    )
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
print(f"Model {MODEL_ID} loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
raise
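
# Optional sizing knobs (not used above; values illustrative): LLM() also
# accepts gpu_memory_utilization and max_model_len, which often matter when
# fitting a 24B model onto limited VRAM, e.g.:
#
#     model = LLM(
#         model=MODEL_ID,
#         trust_remote_code=True,
#         tensor_parallel_size=max(torch.cuda.device_count(), 1),
#         gpu_memory_utilization=0.90,  # fraction of GPU memory vLLM may claim
#         max_model_len=8192,           # cap context length to shrink the KV cache
#     )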
# Pydantic models for request/response
class Message(BaseModel):
role: str = Field(..., description="The role of the message sender (system, user, assistant)")
content: str = Field(..., description="The content of the message")
class GenerationRequest(BaseModel):
messages: List[Message] = Field(..., description="List of messages in the conversation")
temperature: Optional[float] = Field(DEFAULT_TEMPERATURE, description="Temperature for sampling")
max_tokens: Optional[int] = Field(DEFAULT_MAX_LENGTH, description="Maximum number of tokens to generate")
top_p: Optional[float] = Field(0.95, description="Top-p sampling parameter")
top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
stream: Optional[bool] = Field(False, description="Whether to stream the response")
class GenerationResponse(BaseModel):
generated_text: str = Field(..., description="Generated text from the model")
# Helper function to format messages for the model
def format_messages(messages: List[Message]) -> str:
"""Format a list of messages into a prompt string the model can understand."""
formatted_prompt = ""
for message in messages:
if message.role == "system":
formatted_prompt += f"<|system|>\n{message.content}</s>\n"
elif message.role == "user":
formatted_prompt += f"<|user|>\n{message.content}</s>\n"
elif message.role == "assistant":
formatted_prompt += f"<|assistant|>\n{message.content}</s>\n"
# Add the final assistant token to prompt the model to generate a response
formatted_prompt += "<|assistant|>\n"
return formatted_prompt
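
# Note: the <|system|>/<|user|>/<|assistant|> + </s> tags above are an assumed
# chat format for this model. If the tokenizer ships its own chat template, a
# more robust alternative is to let transformers build the prompt:
#
#     prompt = tokenizer.apply_chat_template(
#         [{"role": m.role, "content": m.content} for m in messages],
#         tokenize=False,
#         add_generation_prompt=True,
#     )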
# API endpoints
@app.get("/")
async def root():
"""Root endpoint with basic information."""
return {
"status": "running",
"model": MODEL_ID,
"version": "1.0.0"
}
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
"""Generate text based on the conversation history."""
try:
prompt = format_messages(request.messages)
sampling_params = SamplingParams(
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k
)
outputs = model.generate(prompt, sampling_params)
generated_text = outputs[0].outputs[0].text
return {"generated_text": generated_text}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
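
# Example call (illustrative values; any HTTP client works):
#
#     curl -X POST http://localhost:7860/generate \
#       -H "Content-Type: application/json" \
#       -d '{"messages": [{"role": "user", "content": "Hello!"}],
#            "temperature": 0.7, "max_tokens": 256}'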
@app.post("/generate_stream")
async def generate_stream(request: GenerationRequest):
"""Stream generated text based on the conversation history."""
if not request.stream:
return await generate(request)
try:
prompt = format_messages(request.messages)
sampling_params = SamplingParams(
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k
)
        # vLLM's offline LLM.generate() has no stream kwarg; generate the full
        # completion up front, then emit it in SSE chunks. (True token-by-token
        # streaming would require vLLM's AsyncLLMEngine instead.)
        text = model.generate(prompt, sampling_params)[0].outputs[0].text

        def stream_generator():
            for i in range(0, len(text), 32):
                yield f"data: {json.dumps({'text': text[i:i + 32]})}\n\n"
            yield "data: [DONE]\n\n"
return StreamingResponse(
stream_generator(),
media_type="text/event-stream"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Streaming generation failed: {str(e)}")
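
# Consuming the stream (a minimal sketch using the third-party requests library):
#
#     import json, requests
#     payload = {"messages": [{"role": "user", "content": "Hello!"}], "stream": True}
#     with requests.post("http://localhost:7860/generate_stream",
#                        json=payload, stream=True) as r:
#         for line in r.iter_lines():
#             if line and line != b"data: [DONE]":
#                 print(json.loads(line[len(b"data: "):])["text"], end="", flush=True)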
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
# Run the FastAPI app with uvicorn when this script is executed directly
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
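
# Equivalently, from the shell (assuming this file is saved as app.py):
#     uvicorn app:app --host 0.0.0.0 --port 7860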