from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
app = FastAPI()
MODEL_NAME = "BlackGoku7/deepseek-ai-DeepSeek-R1-Distill-Qwen-14B"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",  # shards the model across available devices; requires `accelerate`
    torch_dtype=torch.bfloat16,  # fall back to torch.float16 on GPUs without bf16 support
    trust_remote_code=True,
)
model.eval()
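
# Rough sizing note: 14B parameters in bfloat16 are ~28 GB of weights alone,
# so this assumes an A100-class GPU (or multi-GPU sharding via device_map).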

class Prompt(BaseModel):
    text: str                   # prompt to complete
    max_new_tokens: int = 200   # cap on the number of generated tokens
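
# Example body the Prompt schema accepts (max_new_tokens is optional and
# defaults to 200); the prompt text here is purely illustrative:
#   {"text": "Explain KV caching in one paragraph.", "max_new_tokens": 150}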

@app.get("/")
def root():
    return {"message": "POST to /generate with {'text': 'your prompt'}"}

@app.post("/generate")
def generate(prompt: Prompt):
    inputs = tokenizer(prompt.text, return_tensors="pt").to(model.device)
    with torch.inference_mode():  # no gradients needed at serving time
        output = model.generate(
            **inputs,
            max_new_tokens=prompt.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
        )
    # Drop the prompt tokens so the response contains only the completion.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {"response": decoded}
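
# Optional local entry point. A minimal sketch assuming the server is not
# already launched elsewhere (e.g. by the Space's Dockerfile via `CMD uvicorn ...`);
# 7860 is the default port Hugging Face Spaces expects.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call once the server is up (URL and prompt are illustrative):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Write a haiku about GPUs.", "max_new_tokens": 64}'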