tommytracx committed
Commit cbfa8a5 · verified · 1 Parent(s): cb69fc7

Create app.py

Files changed (1)
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
+import os
+import json
+import torch
+from threading import Thread
+from typing import List, Dict, Any, Optional, Union
+
+from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+from transformers import AutoTokenizer, TextIteratorStreamer
+from vllm import LLM, SamplingParams
+
+
+# Initialize FastAPI app
+app = FastAPI(title="GainEnergy/OGAI-24B API")
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Load environment variables
+MODEL_ID = os.environ.get("MODEL_ID", "GainEnergy/OGAI-24B")
+DEFAULT_MAX_LENGTH = int(os.environ.get("DEFAULT_MAX_LENGTH", "2048"))
+DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.7"))
+
+# Initialize the model and tokenizer
+try:
+    model = LLM(
+        model=MODEL_ID,
+        trust_remote_code=True,
+        tensor_parallel_size=torch.cuda.device_count(),
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+    print(f"Model {MODEL_ID} loaded successfully!")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    raise
+
+# Pydantic models for request/response
+class Message(BaseModel):
+    role: str = Field(..., description="The role of the message sender (system, user, assistant)")
+    content: str = Field(..., description="The content of the message")
+
+class GenerationRequest(BaseModel):
+    messages: List[Message] = Field(..., description="List of messages in the conversation")
+    temperature: Optional[float] = Field(DEFAULT_TEMPERATURE, description="Temperature for sampling")
+    max_tokens: Optional[int] = Field(DEFAULT_MAX_LENGTH, description="Maximum number of tokens to generate")
+    top_p: Optional[float] = Field(0.95, description="Top-p sampling parameter")
+    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
+    stream: Optional[bool] = Field(False, description="Whether to stream the response")
+
+class GenerationResponse(BaseModel):
+    generated_text: str = Field(..., description="Generated text from the model")
+
+# Helper function to format messages for the model
+def format_messages(messages: List[Message]) -> str:
+    """Format a list of messages into a prompt string the model can understand."""
+    formatted_prompt = ""
+
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<|system|>\n{message.content}</s>\n"
+        elif message.role == "user":
+            formatted_prompt += f"<|user|>\n{message.content}</s>\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|assistant|>\n{message.content}</s>\n"
+
+    # Add the final assistant token to prompt the model to generate a response
+    formatted_prompt += "<|assistant|>\n"
+
+    return formatted_prompt
+
+# API endpoints
+@app.get("/")
+async def root():
+    """Root endpoint with basic information."""
+    return {
+        "status": "running",
+        "model": MODEL_ID,
+        "version": "1.0.0"
+    }
+
+@app.post("/generate", response_model=GenerationResponse)
+async def generate(request: GenerationRequest):
+    """Generate text based on the conversation history."""
+    try:
+        prompt = format_messages(request.messages)
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p,
+            top_k=request.top_k
+        )
+
+        outputs = model.generate(prompt, sampling_params)
+        generated_text = outputs[0].outputs[0].text
+
+        return {"generated_text": generated_text}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+@app.post("/generate_stream")
+async def generate_stream(request: GenerationRequest):
+    """Stream generated text based on the conversation history."""
+    if not request.stream:
+        return await generate(request)
+
+    try:
+        prompt = format_messages(request.messages)
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p,
+            top_k=request.top_k
+        )
+
+        async def stream_generator():
+            # vLLM's offline LLM.generate does not take a stream argument, so
+            # generate the full completion and emit it as server-sent events
+            # in fixed-size chunks. True token-by-token streaming would require
+            # vLLM's AsyncLLMEngine instead of the LLM class used here.
+            outputs = model.generate(prompt, sampling_params)
+            text = outputs[0].outputs[0].text
+            chunk_size = 64
+            for i in range(0, len(text), chunk_size):
+                chunk = text[i:i + chunk_size]
+                yield f"data: {json.dumps({'text': chunk})}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_generator(),
+            media_type="text/event-stream"
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Streaming generation failed: {str(e)}")
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {"status": "healthy"}
+
+# Run the FastAPI app with uvicorn when this script is executed directly
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
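For reference, here is a minimal client sketch for the two POST endpoints defined in app.py above. The payload fields mirror the GenerationRequest model; the base URL, the example messages, and the use of the `requests` package are assumptions for illustration, not part of the commit.

# client_example.py -- hypothetical client for the API in app.py
# Assumes the server is running locally on port 7860 (the uvicorn default used
# in app.py) and that the `requests` package is installed.
import json
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Introduce yourself in one sentence."},
    ],
    "temperature": 0.7,
    "max_tokens": 256,
}

# Non-streaming call: POST /generate returns {"generated_text": "..."}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["generated_text"])

# Streaming call: POST /generate_stream with stream=True returns server-sent
# events of the form `data: {"text": "..."}`, terminated by `data: [DONE]`.
stream_payload = dict(payload, stream=True)
with requests.post(f"{BASE_URL}/generate_stream", json=stream_payload,
                   stream=True, timeout=300) as r:
    r.raise_for_status()
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        print(json.loads(data)["text"], end="", flush=True)
print()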