# Hugging Face Spaces status: Running
import os
from typing import List

import joblib
import requests
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
# Application object for this service, created with FastAPI's default settings.
app: FastAPI = FastAPI()
# Utility to download files if not present locally
def download_file(url, dest, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The body is streamed to a temporary ``<dest>.part`` file and moved onto
    ``dest`` only after the transfer completed. An interrupted download
    therefore never leaves a truncated file at ``dest``, which later runs
    would otherwise mistake for a valid cached copy.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Local destination path (str or os.PathLike).
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang startup forever.
        chunk_size: Number of bytes read per chunk while streaming to disk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(dest):
        print(f"File {dest} already exists.")
        return

    print(f"Downloading {url} to {dest}")
    tmp_path = os.fspath(dest) + ".part"
    try:
        # stream=True avoids holding a whole model checkpoint in memory.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)  # atomic move once the file is complete
    finally:
        # Only still present if something above failed; don't leave junk.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
# ----------- Setup for BERT QA model (Virtual Consultation) ------------
# Local directory that receives the QA checkpoint files; the tokenizer and the
# model are both loaded from it once every file is present.
qa_model_dir: str = "./bert_mini_squadv2_finetuned"
os.makedirs(qa_model_dir, exist_ok=True)

# File name inside qa_model_dir -> URL it is fetched from.
qa_files: dict = {
    "pytorch_model.bin": "https://huggingface.co/spaces/isana25/DoctorTwin/resolve/main/pytorch_model.bin",
    "config.json": "https://huggingface.co/spaces/isana25/DoctorTwin/resolve/main/config.json",
    "tokenizer_config.json": "https://huggingface.co/spaces/isana25/DoctorTwin/resolve/main/tokenizer_config.json",
    "vocab.txt": "https://huggingface.co/spaces/isana25/DoctorTwin/resolve/main/vocab.txt",
}
for fname in qa_files:
    furl = qa_files[fname]
    download_file(furl, os.path.join(qa_model_dir, fname))

tokenizer_qa = AutoTokenizer.from_pretrained(qa_model_dir)
model_qa = AutoModelForQuestionAnswering.from_pretrained(qa_model_dir)
# ----------- Setup for Diabetes XGBoost Model (Risk Prediction) ------------
# Local cache path and remote source of the serialized risk model.
diabetes_pkl_path: str = "./diabetes_xgboost_model.pkl"
diabetes_pkl_url: str = "https://huggingface.co/spaces/isana25/DoctorTwin/resolve/main/diabetes_xgboost_model.pkl"

download_file(diabetes_pkl_url, diabetes_pkl_path)

# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code. The URL tracks the mutable "main" revision, so only load it
# if that source is trusted (consider pinning a commit hash).
diabetes_model = joblib.load(diabetes_pkl_path)
# ----------- Setup for other features: load pretrained models directly ------------
from transformers import pipeline

# Monitoring & Alerts - abstractive summarization.
# The "summarization" pipeline loads the checkpoint with AutoModelForSeq2SeqLM,
# so it needs an encoder-decoder model. Encoder-only checkpoints such as
# "prajjwal1/bert-mini" cannot be loaded for this task. The default below is
# the PyTorch checkpoint transformers itself uses for summarization; set
# MONITORING_MODEL_ID to use a different seq2seq model.
monitoring_model_id = os.environ.get(
    "MONITORING_MODEL_ID", "sshleifer/distilbart-cnn-12-6"
)
summarizer = pipeline("summarization", model=monitoring_model_id)
# Personalized Simulation - Bio_ClinicalBERT sequence classifier
# NOTE(review): emilyalsentzer/Bio_ClinicalBERT is a pre-trained clinical BERT
# encoder, not a fine-tuned classifier. AutoModelForSequenceClassification
# therefore attaches a newly (randomly) initialized classification head, and
# the predicted labels carry no meaning until a fine-tuned checkpoint is
# supplied through the PERSONALIZED_MODEL_ID environment variable.
personalized_model_id = os.environ.get(
    "PERSONALIZED_MODEL_ID", "emilyalsentzer/Bio_ClinicalBERT"
)
personalized_tokenizer = AutoTokenizer.from_pretrained(personalized_model_id)
personalized_model = AutoModelForSequenceClassification.from_pretrained(personalized_model_id)
# --- Pydantic models for request validation ---
class QARequest(BaseModel):
    """Request body for extractive question answering."""

    question: str  # the question to answer
    context: str  # text passage the answer span is extracted from


class RiskPredictionRequest(BaseModel):
    """Request body for diabetes risk prediction."""

    # Numeric feature vector, e.g. [age, bmi, blood_pressure, ...]; presumably
    # in the model's training-feature order (confirm). Typed as floats so that
    # non-numeric entries are rejected during request validation instead of
    # failing inside the model.
    features: List[float]
# --- API endpoints ---
def _best_answer_span(start_logits, end_logits, allowed, max_answer_len=30):
    """Return the best-scoring ``(start, end, score)`` span, or ``None``.

    A span's score is ``start_logits[start] + end_logits[end]``. Only spans
    with ``start <= end``, fewer than ``max_answer_len`` tokens, and both
    endpoints marked in ``allowed`` are considered. ``end`` is inclusive.

    Args:
        start_logits: 1-D tensor of per-token start scores.
        end_logits: 1-D tensor of per-token end scores (same length).
        allowed: 1-D bool tensor, True where a token may bound the answer.
        max_answer_len: Maximum answer length in tokens.
    """
    n = start_logits.shape[0]
    pos = torch.arange(n)
    span_len = pos[None, :] - pos[:, None]  # entry [start, end] = end - start
    valid = (span_len >= 0) & (span_len < max_answer_len)
    valid &= allowed[:, None] & allowed[None, :]
    if not bool(valid.any()):
        return None
    scores = start_logits[:, None] + end_logits[None, :]
    scores = scores.masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(scores)), n)
    return start, end, float(scores[start, end])


@app.post("/virtual_consultation")
def virtual_consultation(data: QARequest):
    """Extract an answer to ``data.question`` from ``data.context``.

    Returns:
        ``{"answer": str}``: the decoded answer span. The answer is ``""``
        when the model prefers "no answer" or no valid span lies in the
        context.
    """
    inputs = tokenizer_qa(
        data.question,
        data.context,
        return_tensors="pt",
        # Clip the context (never the question) so the sequence fits the
        # model's position embeddings instead of crashing on long inputs.
        truncation="only_second",
        max_length=model_qa.config.max_position_embeddings,
    )
    with torch.no_grad():
        outputs = model_qa(**inputs)
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Only context tokens (sequence id 1) may bound the answer. A raw argmax
    # could pick [CLS], question or separator tokens, or an end before the start.
    # NOTE(review): sequence_ids() needs a fast tokenizer, which AutoTokenizer
    # returns for BERT by default; confirm if the tokenizer is changed.
    context_mask = torch.tensor(
        [sid == 1 for sid in inputs.sequence_ids(0)], dtype=torch.bool
    )
    span = _best_answer_span(start_logits, end_logits, context_mask)

    # SQuAD v2-style checkpoints (as the model directory name suggests) score
    # "no answer" on the [CLS] token at index 0.
    null_score = float(start_logits[0] + end_logits[0])
    if span is None or span[2] < null_score:
        return {"answer": ""}

    start, end, _ = span
    answer_ids = inputs["input_ids"][0][start : end + 1].tolist()
    answer = tokenizer_qa.convert_tokens_to_string(
        tokenizer_qa.convert_ids_to_tokens(answer_ids)
    )
    return {"answer": answer}
@app.post("/risk_prediction")
def risk_prediction(data: RiskPredictionRequest):
    """Predict the diabetes risk class for a single feature vector.

    Returns:
        ``{"risk_prediction": int}``: the class predicted by ``diabetes_model``.

    Raises:
        HTTPException: 422 if ``data.features`` is non-numeric or empty, or if
            its length differs from the feature count the model reports.
    """
    import numpy as np

    try:
        features = np.asarray(data.features, dtype=float).reshape(1, -1)
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=422, detail="features must be a list of numbers"
        ) from err

    n_given = features.shape[1]
    if n_given == 0:
        raise HTTPException(status_code=422, detail="features must not be empty")

    # Assumes a scikit-learn-API estimator (e.g. xgboost.XGBClassifier) that
    # exposes n_features_in_. When it does, reject wrong-length vectors with a
    # 422 rather than letting predict() fail with an unhandled 500.
    n_expected = getattr(diabetes_model, "n_features_in_", None)
    if n_expected is not None and n_given != n_expected:
        raise HTTPException(
            status_code=422,
            detail=f"expected {n_expected} features, got {n_given}",
        )

    pred = diabetes_model.predict(features)
    return {"risk_prediction": int(pred[0])}
@app.post("/monitoring_alerts")
def monitoring_alerts(text: str):
    """Summarize ``text`` with the monitoring summarization pipeline.

    Returns:
        ``{"summary": str}``: the generated summary.

    Raises:
        HTTPException: 422 if ``text`` is empty or only whitespace.
    """
    if not text.strip():
        raise HTTPException(status_code=422, detail="text must not be empty")
    result = summarizer(
        text,
        max_length=50,
        min_length=20,
        do_sample=False,
        # Clip inputs longer than the model accepts instead of erroring.
        truncation=True,
    )
    return {"summary": result[0]["summary_text"]}
@app.post("/personalized_simulation")
def personalized_simulation(text: str):
    """Classify ``text`` with ``personalized_model``.

    Returns:
        ``{"predicted_label": int}``: the index of the highest-scoring class.

    Raises:
        HTTPException: 422 if ``text`` is empty or only whitespace.
    """
    if not text.strip():
        raise HTTPException(status_code=422, detail="text must not be empty")
    inputs = personalized_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        # Explicit limit: if the tokenizer defines no model_max_length,
        # truncation=True alone would not clip over-long inputs.
        max_length=personalized_model.config.max_position_embeddings,
    )
    # Inference only: no autograd graph needs to be built.
    with torch.no_grad():
        logits = personalized_model(**inputs).logits
    return {"predicted_label": int(logits[0].argmax())}