BlackGoku7 committed (verified)
Commit 5505c5b · Parent: d0b6609

Update app.py

Files changed (1):
  app.py +34 -3
app.py CHANGED
@@ -1,8 +1,39 @@
 from fastapi import FastAPI
+from pydantic import BaseModel
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 app = FastAPI()
 
+MODEL_NAME = "BlackGoku7/deepseek-ai-DeepSeek-R1-Distill-Qwen-14B"
+
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,  # Or torch.float16 if your Space supports it
+    trust_remote_code=True
+)
+model.eval()
+
+class Prompt(BaseModel):
+    text: str
+    max_new_tokens: int = 200
+
 @app.get("/")
-def greet_json():
-    return {"Hello": "World!"}
-
+def root():
+    return {"message": "POST to /generate with {'text': 'your prompt'}"}
+
+@app.post("/generate")
+def generate(prompt: Prompt):
+    inputs = tokenizer(prompt.text, return_tensors="pt").to(model.device)
+    output = model.generate(
+        **inputs,
+        max_new_tokens=prompt.max_new_tokens,
+        do_sample=True,
+        temperature=0.7,
+        top_p=0.9,
+    )
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    return {"response": decoded}