rabiyulfahim committed
Commit c8ad3e4 · verified · 1 Parent(s): 39e29db

Create app.py

Files changed (1)
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ from fastapi import FastAPI
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import os
+ from pydantic import BaseModel
+
+ # ✅ Force the Hugging Face cache to /tmp (writable in Spaces)
+ os.environ["HF_HOME"] = "/tmp"
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
+
+ model_id = "rabiyulfahim/qa_python_gpt2"
+
+ # Load the tokenizer and model once at startup
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
+ model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="/tmp")
+ model.eval()  # inference only: disables dropout
+
+ app = FastAPI(title="QA GPT2 API", description="Serving a Hugging Face model with FastAPI")
+
+ # Request schema
+ class QueryRequest(BaseModel):
+     question: str
+     max_new_tokens: int = 50
+     temperature: float = 0.7
+     top_p: float = 0.9
+
+ @app.get("/")
+ def home():
+     return {"message": "Welcome to QA GPT2 API 🚀"}
+
+ # Quick GET endpoint: greedy decoding, no sampling
+ @app.get("/ask")
+ def ask(question: str, max_new_tokens: int = 50):
+     inputs = tokenizer(question, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return {"question": question, "answer": answer}
+
+ # Health check endpoint
+ @app.get("/health")
+ def health():
+     return {"status": "ok"}
+
+ # Inference endpoint: sampling parameters are taken from the request body
+ @app.post("/predict")
+ def predict(request: QueryRequest):
+     inputs = tokenizer(request.question, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=request.max_new_tokens,
+             do_sample=True,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             pad_token_id=tokenizer.eos_token_id,
+             return_dict_in_generate=True,
+         )
+     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+     return {
+         "question": request.question,
+         "answer": answer
+     }
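
For a quick smoke test once the Space is up, a minimal client sketch in Python (assumptions: the `requests` package is installed, and the app is served on port 7860, the usual port for a Docker Space; swap in your Space's public URL):

import requests

BASE_URL = "http://localhost:7860"  # assumption: replace with the deployed Space URL

# GET /ask: the question travels as a query parameter (greedy decoding)
resp = requests.get(f"{BASE_URL}/ask",
                    params={"question": "What is a Python list?", "max_new_tokens": 40})
print(resp.json()["answer"])

# POST /predict: sampling parameters go in the QueryRequest JSON body
payload = {"question": "What does a decorator do?",
           "max_new_tokens": 60, "temperature": 0.7, "top_p": 0.9}
resp = requests.post(f"{BASE_URL}/predict", json=payload)
print(resp.json()["answer"])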