# deberta_api / main.py
# Author: AISimplyExplained
# Last commit: f43f094 "Update main.py" (verified)
import asyncio
import logging
from typing import List, Optional

import torch
from detoxify import Detoxify
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class Guardrail:
    """Prompt-injection detector wrapping ProtectAI's DeBERTa-v3 classifier.

    The Hugging Face tokenizer and model are loaded once in ``__init__`` and
    wrapped in a ``text-classification`` pipeline that truncates inputs to
    512 tokens and runs on CUDA when available, otherwise on CPU.
    """

    def __init__(self):
        model_id = "ProtectAI/deberta-v3-base-prompt-injection"
        tok = AutoTokenizer.from_pretrained(model_id)
        net = AutoModelForSequenceClassification.from_pretrained(model_id)
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "text-classification",
            model=net,
            tokenizer=tok,
            truncation=True,
            max_length=512,
            device=target,
        )

    async def guard(self, prompt):
        """Classify *prompt* in a worker thread so the event loop stays free.

        Returns whatever the pipeline returns for a single input.
        """
        classify = self.classifier
        return await run_in_threadpool(classify, prompt)

    def determine_level(self, label, score):
        """Map a classifier ``(label, score)`` pair to ``(level, severity_label)``.

        ``"SAFE"`` is always ``(0, "safe")``. Any other label is graded by
        *score* with strict ``>`` comparisons: above 0.9 is high (4), above
        0.75 is medium (3), above 0.5 is low (2), and anything else is very
        low (1).
        """
        if label == "SAFE":
            return 0, "safe"
        # Checked from the strictest threshold down; first match wins.
        bands = ((0.9, 4, "high"), (0.75, 3, "medium"), (0.5, 2, "low"))
        for threshold, level, name in bands:
            if score > threshold:
                return level, name
        return 1, "very low"
class TextPrompt(BaseModel):
    """Request body carrying a single text to classify."""

    # Text passed unchanged to the classifier.
    prompt: str
class ClassificationResult(BaseModel):
    """Response body of the prompt-injection endpoint."""

    # Label returned by the classifier pipeline; "SAFE" maps to level 0.
    label: str
    # Score the pipeline reported for ``label``.
    score: float
    # Severity rank from Guardrail.determine_level: 0 (safe) up to 4 (high).
    level: int
    # Name of ``level``: "safe", "very low", "low", "medium" or "high".
    severity_label: str
class ToxicityResult(BaseModel):
    """Per-category toxicity scores for one text, as produced by Detoxify."""

    # Field order is kept as-is: it fixes the JSON key and schema order.
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float

    @classmethod
    def from_dict(cls, data: dict):
        """Build a result from a category -> score mapping.

        Every value is coerced with ``float()`` before validation, so numeric
        scalar types returned by the model are accepted.
        """
        as_floats = {}
        for key, value in data.items():
            as_floats[key] = float(value)
        return cls(**as_floats)
class TopicBannerClassifier:
    """Zero-shot topic classifier built on a DeBERTa-v3 NLI checkpoint.

    Each candidate label is inserted into ``hypothesis_template`` and scored
    against the input text. Labels are treated as mutually exclusive
    (``multi_label=False``).
    """

    def __init__(self):
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=target,
        )
        self.hypothesis_template = "This text is about {}"

    async def classify(self, text, labels):
        """Score *text* against *labels* in a worker thread and return the pipeline output."""
        options = {
            "hypothesis_template": self.hypothesis_template,
            "multi_label": False,
        }
        return await run_in_threadpool(self.classifier, text, labels, **options)
class TopicBannerRequest(BaseModel):
    """Request body for zero-shot topic classification."""

    # Text to classify.
    prompt: str
    # Candidate topics; each is inserted into the hypothesis template
    # ("This text is about {}") used by TopicBannerClassifier.
    labels: List[str]
class TopicBannerResult(BaseModel):
    """Zero-shot classification output for one input text.

    Element types are declared (instead of bare ``list``) so responses are
    validated and the OpenAPI schema documents them.
    """

    # The input text, echoed back by the pipeline.
    sequence: str
    # Candidate labels in the order the pipeline returns them
    # (documented by transformers as sorted by likelihood).
    labels: List[str]
    # Scores aligned index-by-index with ``labels``.
    scores: List[float]
class GuardrailsRequest(BaseModel):
    """Request body for the combined /api/guardrails endpoint."""

    # Text that every selected guardrail is run against.
    prompt: str
    # Guardrails to run: "pi" (prompt injection), "tox" (toxicity),
    # "top" (topic banner). Other values do not select a guardrail.
    guardrails: List[str]
    # Candidate topics; required when "top" is requested, unused otherwise.
    labels: Optional[List[str]] = None
class GuardrailsResponse(BaseModel):
    """Combined guardrail results.

    A field stays ``None`` when its guardrail was not requested or raised.
    """

    # Set when "pi" was requested and succeeded.
    prompt_injection: Optional[ClassificationResult] = None
    # Set when "tox" was requested and succeeded.
    toxicity: Optional[ToxicityResult] = None
    # Set when "top" was requested and succeeded.
    topic_banner: Optional[TopicBannerResult] = None
# FastAPI application plus the model instances shared by every request handler.
# Guardrail and TopicBannerClassifier load their Hugging Face models in __init__,
# so this runs at import time and importing the module is slow; Detoxify('original')
# presumably loads its checkpoint the same way — confirm.
app = FastAPI()
guardrail = Guardrail()  # prompt-injection classifier
toxicity_classifier = Detoxify('original')  # Detoxify "original" model
topic_banner_classifier = TopicBannerClassifier()  # zero-shot topic classifier
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    """Score ``text_prompt.prompt`` with the Detoxify model.

    The blocking model call runs in a worker thread. Any failure is logged
    with its traceback and reported to the client as HTTP 500, with the
    exception message as ``detail``.
    """
    try:
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        return ToxicityResult.from_dict(result)
    except Exception as e:
        # Log server-side so the traceback is not lost; the HTTP response only
        # carries str(e).
        logging.getLogger(__name__).exception("Toxicity classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    """Detect prompt injection in ``text_prompt.prompt``.

    Returns the label and score the classifier pipeline produced for the
    prompt, plus the severity computed by ``Guardrail.determine_level``.
    Any failure is logged with its traceback and reported as HTTP 500.
    """
    try:
        result = await guardrail.guard(text_prompt.prompt)
        # A single string input yields a one-element list of {label, score}.
        label = result[0]['label']
        score = result[0]['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        # Log server-side so the traceback is not lost; the HTTP response only
        # carries str(e).
        logging.getLogger(__name__).exception("Prompt-injection classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(request: TopicBannerRequest):
    """Classify ``request.prompt`` against the candidate ``request.labels``.

    Returns the pipeline's ``sequence``, ``labels`` and ``scores``. Any
    failure is logged with its traceback and reported as HTTP 500.
    """
    try:
        result = await topic_banner_classifier.classify(request.prompt, request.labels)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"]
        }
    except Exception as e:
        # Log server-side so the traceback is not lost; the HTTP response only
        # carries str(e).
        logging.getLogger(__name__).exception("Topic banner classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/guardrails", response_model=GuardrailsResponse)
async def evaluate_guardrails(request: GuardrailsRequest):
    """Run the requested guardrails concurrently and merge their results.

    ``request.guardrails`` selects the checks: ``"pi"`` (prompt injection),
    ``"tox"`` (toxicity) and ``"top"`` (topic banner, which also needs
    ``request.labels``). Unknown names are ignored, and each check runs at
    most once even if it is listed repeatedly.

    A check that raises is logged and its response field stays ``None``, so
    the other checks are still reported.

    Raises:
        HTTPException: 400 when ``"top"`` is requested without ``labels``.
    """
    requested = set(request.guardrails)

    # Validate before creating any coroutine; raising afterwards would leave
    # already-created coroutines never awaited.
    if "top" in requested and not request.labels:
        raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")

    # Pair each coroutine with the response field it fills. gather() returns
    # results in task order, which is NOT the order of request.guardrails
    # (that list may be reordered, repeated or contain unknown names), so the
    # results must be matched through these pairs, not through the request.
    jobs = []
    if "pi" in requested:
        jobs.append(("prompt_injection", classify_text(TextPrompt(prompt=request.prompt))))
    if "tox" in requested:
        jobs.append(("toxicity", classify_toxicity(TextPrompt(prompt=request.prompt))))
    if "top" in requested:
        jobs.append((
            "topic_banner",
            classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels)),
        ))

    results = await asyncio.gather(*(coro for _, coro in jobs), return_exceptions=True)

    fields = {}
    for (field_name, _), result in zip(jobs, results):
        # BaseException, not Exception: gather() can also return CancelledError.
        if isinstance(result, BaseException):
            logging.getLogger(__name__).warning(
                "Guardrail %s failed; leaving it unset", field_name, exc_info=result
            )
            continue
        fields[field_name] = result

    # Building through the constructor validates plain-dict handler results
    # into their nested result models.
    return GuardrailsResponse(**fields)
if __name__ == "__main__":
    # Serve the app directly when executed as a script; uvicorn is only
    # needed on this path, hence the local import.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)