import gc
import logging
import os
from typing import Any, Dict, List, Optional

import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="TinyLlama API", description="API for the TinyLlama-1.1B-Chat model")

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_dir = "model_cache"

tokenizer = None
model = None
is_loading = False


def load_model():
    global tokenizer, model, is_loading

    if is_loading:
        logger.info("Model is already being loaded by another process")
        return

    if tokenizer is None or model is None:
        try:
            is_loading = True
            logger.info(f"Loading model {model_id}...")

            os.makedirs(model_dir, exist_ok=True)

            # Free any previously loaded model before reloading
            if model is not None:
                del model
                torch.cuda.empty_cache()
                gc.collect()

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=model_dir,
                use_fast=True,
            )

            device_map = "auto" if torch.cuda.is_available() else None

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                cache_dir=model_dir,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                device_map=device_map,
            )

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise
        finally:
            is_loading = False
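
# Optional and illustrative only (not required by the code above): the weights can be
# pre-fetched into the same cache directory ahead of time, e.g. with huggingface_hub,
# so the first request does not have to wait for the download:
#
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id=model_id, cache_dir=model_dir)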


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 500
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class ChatResponse(BaseModel):
    response: str
    usage: Dict[str, Any]
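
# Illustrative example (values are arbitrary): a JSON body matching ChatRequest
# for the /chat endpoint could look like this:
#
#   {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello, who are you?"}
#       ],
#       "max_tokens": 200,
#       "temperature": 0.7,
#       "top_p": 0.9
#   }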


@app.on_event("startup")
async def startup_event():
    load_model()


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    if model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")

    try:
        # Pick up the first system message, if any
        system_content = ""
        for msg in req.messages:
            if msg.role.lower() == "system":
                system_content = msg.content
                break

        # Build the prompt in the TinyLlama chat format
        messages_text = []
        if system_content:
            messages_text.append(f"<|system|>\n{system_content}")

        for msg in req.messages:
            role = msg.role.lower()
            content = msg.content

            if role == "system":
                continue

            if role == "user":
                messages_text.append(f"<|user|>\n{content}")
            elif role == "assistant":
                messages_text.append(f"<|assistant|>\n{content}")

        # Trailing assistant tag so the model continues as the assistant
        messages_text.append("<|assistant|>")

        prompt = "\n".join(messages_text)
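
        # For example, one system message followed by one user turn yields a prompt
        # like the following (illustrative values):
        #
        #   <|system|>
        #   You are a helpful assistant.
        #   <|user|>
        #   Hello, who are you?
        #   <|assistant|>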

        # Tokenize and move tensors to the model's device
        inputs = tokenizer(prompt, return_tensors="pt")
        input_length = len(inputs["input_ids"][0])

        if hasattr(model, "device"):
            inputs = {key: value.to(model.device) for key, value in inputs.items()}

        generation_config = {
            "max_new_tokens": req.max_tokens,
            "temperature": req.temperature,
            "top_p": req.top_p,
            "do_sample": req.temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
        }

        with torch.no_grad():
            # Pass the attention mask along with the input ids
            output = model.generate(
                **inputs,
                **generation_config,
            )

        result = tokenizer.decode(output[0], skip_special_tokens=True)

        # Extract the assistant reply from the decoded output
        assistants = result.split("<|assistant|>")
        if len(assistants) > 1:
            response = assistants[-1].strip()
        else:
            # Fallback: strip everything up to and including the last user turn
            user_tokens = result.split("<|user|>")
            if len(user_tokens) > 1:
                last_part = user_tokens[-1]
                if "\n" in last_part:
                    response = "\n".join(last_part.split("\n")[1:]).strip()
                else:
                    response = last_part.strip()
            else:
                # Last resort: drop the decoded prompt from the front of the output
                prompt_length = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True))
                response = result[prompt_length:].strip()

        if not response:
            response = "Sorry, a valid response could not be generated."

        output_length = len(output[0])
        new_tokens = output_length - input_length

        usage_info = {
            "prompt_tokens": input_length,
            "completion_tokens": new_tokens,
            "total_tokens": output_length,
        }

        return ChatResponse(response=response, usage=usage_info)

    except Exception as e:
        logger.error(f"Error while generating chat response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to generate a response: {str(e)}")


@app.get("/model-status")
async def model_status():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "status": status,
        "model_id": model_id,
        "device": str(model.device) if model is not None and hasattr(model, "device") else "not available",
    }


@app.post("/load-model")
async def force_load_model(background_tasks: BackgroundTasks):
    global is_loading

    if is_loading:
        return {"status": "loading", "message": f"Model {model_id} is currently being loaded"}

    if model is not None:
        return {"status": "already_loaded", "message": f"Model {model_id} is already loaded"}

    background_tasks.add_task(load_model)
    return {"status": "loading_started", "message": f"Loading of model {model_id} has started"}


@app.get("/")
async def root():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "message": "TinyLlama API is running",
        "model": model_id,
        "status": status,
        "endpoints": [
            {"path": "/", "method": "GET", "description": "API information"},
            {"path": "/chat", "method": "POST", "description": "Chat with the model"},
            {"path": "/model-status", "method": "GET", "description": "Check model status"},
            {"path": "/load-model", "method": "POST", "description": "Load the model if it is not loaded yet"},
        ],
    }


if __name__ == "__main__":
    import uvicorn

    logger.info(f"Starting API server for model {model_id}")
    uvicorn.run(app, host="0.0.0.0", port=7860)
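
# Example usage (illustrative), assuming the server is running locally on port 7860:
#
#   curl -X POST http://localhost:7860/chat \
#       -H "Content-Type: application/json" \
#       -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 100}'
#
# The response is a JSON object with "response" (the generated text) and "usage"
# (prompt, completion, and total token counts).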