# akane-ai/app.py
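# FastAPI service that serves TinyLlama-1.1B-Chat as a simple chat-completion API.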
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import gc
import logging
from typing import List, Dict, Any, Optional
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
app = FastAPI(title="TinyLlama API", description="API for the TinyLlama-1.1B-Chat model")
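# Model to serve and the local directory used to cache the downloaded weights.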
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_dir = "model_cache"
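# Globals holding the loaded tokenizer/model and a flag guarding against concurrent loads.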
tokenizer = None
model = None
is_loading = False
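# Load the tokenizer and model once, caching the weights in model_dir and
# picking GPU or CPU placement automatically.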
def load_model():
global tokenizer, model, is_loading
if is_loading:
logger.info("Model sedang dimuat oleh proses lain")
return
if tokenizer is None or model is None:
try:
is_loading = True
logger.info(f"Memuat model {model_id}...")
os.makedirs(model_dir, exist_ok=True)
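            # Free any previously loaded model before reloading to avoid keeping two copies in memory.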
if model is not None:
del model
torch.cuda.empty_cache()
gc.collect()
tokenizer = AutoTokenizer.from_pretrained(
model_id,
cache_dir=model_dir,
use_fast=True,
)
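            # Use half precision and automatic device placement on GPU; fall back to fp32 on CPU.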
device_map = "auto" if torch.cuda.is_available() else None
model = AutoModelForCausalLM.from_pretrained(
model_id,
cache_dir=model_dir,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True,
device_map=device_map
)
logger.info("Model berhasil dimuat!")
except Exception as e:
logger.error(f"Gagal memuat model: {str(e)}")
raise e
finally:
is_loading = False
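# Request/response schemas for the /chat endpoint.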
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: List[Message]
max_tokens: Optional[int] = 500
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.9
class ChatResponse(BaseModel):
response: str
usage: Dict[str, Any]
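# Load the model at startup so the first request does not have to wait for it.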
@app.on_event("startup")
async def startup_event():
load_model()
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
if model is None:
raise HTTPException(status_code=500, detail="Gagal memuat model")
try:
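        # Extract the optional system message so it can lead the prompt.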
system_content = ""
for msg in req.messages:
if msg.role.lower() == "system":
system_content = msg.content
break
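        # Assemble the prompt with TinyLlama's <|system|>/<|user|>/<|assistant|> markers,
        # ending with an open assistant turn for the model to complete.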
messages_text = []
if system_content:
messages_text.append(f"<|system|>\n{system_content}")
for msg in req.messages:
role = msg.role.lower()
content = msg.content
if role == "system":
continue
if role == "user":
messages_text.append(f"<|user|>\n{content}")
elif role == "assistant":
messages_text.append(f"<|assistant|>\n{content}")
messages_text.append("<|assistant|>")
prompt = "\n".join(messages_text)
inputs = tokenizer(prompt, return_tensors="pt")
input_length = len(inputs.input_ids[0])
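        # Move the input tensors onto the model's device (GPU when available).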
if hasattr(model, 'device'):
inputs = {key: value.to(model.device) for key, value in inputs.items()}
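        # Sample with temperature/top_p when temperature > 0; otherwise decode greedily.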
generation_config = {
'max_new_tokens': req.max_tokens,
'temperature': req.temperature,
'top_p': req.top_p,
            'do_sample': req.temperature > 0,
'pad_token_id': tokenizer.eos_token_id
}
with torch.no_grad():
            output = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                **generation_config
            )
result = tokenizer.decode(output[0], skip_special_tokens=True)
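        # The decoded output still contains the prompt; keep only the text after the
        # final <|assistant|> marker, with fallbacks in case the markers were stripped.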
assistants = result.split("<|assistant|>")
if len(assistants) > 1:
response = assistants[-1].strip()
else:
user_tokens = result.split("<|user|>")
if len(user_tokens) > 1:
last_part = user_tokens[-1]
if "\n" in last_part:
response = "\n".join(last_part.split("\n")[1:]).strip()
else:
response = last_part.strip()
else:
                prompt_length = len(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True))
response = result[prompt_length:].strip()
if not response:
response = "Maaf, tidak dapat menghasilkan respons yang valid."
output_length = len(output[0])
new_tokens = output_length - input_length
usage_info = {
"prompt_tokens": input_length,
"completion_tokens": new_tokens,
"total_tokens": output_length
}
return ChatResponse(response=response, usage=usage_info)
except Exception as e:
logger.error(f"Error saat melakukan chat: {str(e)}")
raise HTTPException(status_code=500, detail=f"Gagal menghasilkan respons: {str(e)}")
@app.get("/model-status")
async def model_status():
status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
return {
"status": status,
"model_id": model_id,
"device": str(model.device) if model is not None and hasattr(model, 'device') else "tidak tersedia"
}
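# Kick off loading the model in a background task unless it is already loaded or loading.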
@app.post("/load-model")
async def force_load_model(background_tasks: BackgroundTasks):
global is_loading
if is_loading:
return {"status": "loading", "message": f"Model {model_id} sedang dimuat"}
if model is not None:
return {"status": "already_loaded", "message": f"Model {model_id} sudah dimuat"}
background_tasks.add_task(load_model)
return {"status": "loading_started", "message": f"Proses memuat model {model_id} telah dimulai"}
@app.get("/")
async def root():
status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
return {
"message": "API TinyLlama berjalan",
"model": model_id,
"status": status,
"endpoints": [
{"path": "/", "method": "GET", "description": "Informasi API"},
{"path": "/chat", "method": "POST", "description": "Endpoint untuk chat dengan model"},
{"path": "/model-status", "method": "GET", "description": "Cek status model"},
{"path": "/load-model", "method": "POST", "description": "Muat model jika belum dimuat"}
]
}
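# Example client call for the /chat endpoint (a minimal sketch, assuming the server is
# running locally on the port configured below and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 128},
#   )
#   print(resp.json()["response"])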
if __name__ == "__main__":
import uvicorn
logger.info(f"Memulai server API untuk model {model_id}")
uvicorn.run(app, host="0.0.0.0", port=7860)