Spaces:
Sleeping
Sleeping
token store
Browse files- custom_auth.py +78 -0
- routes/auth.py +42 -10
- routes/health.py +2 -2
- routes/predict.py +7 -8
- token_store.py +159 -0
custom_auth.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Depends, HTTPException, status, Header, Query
|
2 |
+
from typing import Optional
|
3 |
+
from database import get_users
|
4 |
+
from models import User, UserInDB
|
5 |
+
from token_store import token_store
|
6 |
+
|
7 |
+
|
8 |
+
async def get_token(
|
9 |
+
authorization: Optional[str] = Header(None),
|
10 |
+
token: Optional[str] = Query(
|
11 |
+
None, description="Access token (alternative to Authorization header)"
|
12 |
+
),
|
13 |
+
) -> str:
|
14 |
+
"""
|
15 |
+
Extract token from Authorization header or query parameter
|
16 |
+
Supports both methods for better compatibility with various clients
|
17 |
+
"""
|
18 |
+
# First try to get token from Authorization header
|
19 |
+
if authorization:
|
20 |
+
if authorization.startswith("Bearer "):
|
21 |
+
return authorization.replace("Bearer ", "")
|
22 |
+
else:
|
23 |
+
# If it doesn't have Bearer prefix, use as is
|
24 |
+
return authorization
|
25 |
+
|
26 |
+
# Then try to get token from query parameter
|
27 |
+
if token:
|
28 |
+
return token
|
29 |
+
|
30 |
+
# If no token is provided, raise an error
|
31 |
+
raise HTTPException(
|
32 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
33 |
+
detail="Authorization header missing",
|
34 |
+
headers={"WWW-Authenticate": "Bearer"},
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
async def get_current_user_from_token(token: str = Depends(get_token)):
|
39 |
+
"""
|
40 |
+
Validate token and return user if valid
|
41 |
+
"""
|
42 |
+
credentials_exception = HTTPException(
|
43 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
44 |
+
detail="Could not validate credentials",
|
45 |
+
headers={"WWW-Authenticate": "Bearer"},
|
46 |
+
)
|
47 |
+
|
48 |
+
# Validate token
|
49 |
+
username = token_store.validate_token(token)
|
50 |
+
if not username:
|
51 |
+
print(f"Invalid or expired token")
|
52 |
+
raise credentials_exception
|
53 |
+
|
54 |
+
# Get user from database
|
55 |
+
users = get_users()
|
56 |
+
if username not in users:
|
57 |
+
print(f"User not found: {username}")
|
58 |
+
raise credentials_exception
|
59 |
+
|
60 |
+
user_dict = users[username]
|
61 |
+
user = UserInDB(**user_dict)
|
62 |
+
print(f"User authenticated: {user.username}")
|
63 |
+
|
64 |
+
return user
|
65 |
+
|
66 |
+
|
67 |
+
def create_token_for_user(username: str) -> str:
|
68 |
+
"""
|
69 |
+
Create a new token for a user
|
70 |
+
"""
|
71 |
+
return token_store.create_token(username)
|
72 |
+
|
73 |
+
|
74 |
+
def remove_token(token: str) -> bool:
|
75 |
+
"""
|
76 |
+
Remove a token from the store
|
77 |
+
"""
|
78 |
+
return token_store.remove_token(token)
|
routes/auth.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
from fastapi import APIRouter, Depends, HTTPException, status
|
2 |
from fastapi.security import OAuth2PasswordRequestForm
|
3 |
-
from
|
4 |
-
from
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
7 |
from database import get_users, create_account
|
8 |
|
9 |
router = APIRouter()
|
@@ -11,7 +15,7 @@ router = APIRouter()
|
|
11 |
@router.post("/token", response_model=Token)
|
12 |
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
13 |
"""
|
14 |
-
Endpoint
|
15 |
"""
|
16 |
user = authenticate_user(get_users(), form_data.username, form_data.password)
|
17 |
if not user:
|
@@ -21,19 +25,47 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
|
|
21 |
headers={"WWW-Authenticate": "Bearer"},
|
22 |
)
|
23 |
|
24 |
-
|
25 |
-
access_token =
|
26 |
-
|
27 |
-
)
|
28 |
return Token(access_token=access_token, token_type="bearer")
|
29 |
|
30 |
|
31 |
@router.post("/register")
|
32 |
async def register_user(user_data: UserCreate):
|
33 |
"""
|
34 |
-
Endpoint
|
35 |
"""
|
36 |
success, message = create_account(user_data.username, user_data.password)
|
37 |
if not success:
|
38 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
39 |
return {"message": message}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import APIRouter, Depends, HTTPException, status
|
2 |
from fastapi.security import OAuth2PasswordRequestForm
|
3 |
+
from auth import authenticate_user, get_current_user
|
4 |
+
from custom_auth import (
|
5 |
+
create_token_for_user,
|
6 |
+
get_current_user_from_token,
|
7 |
+
get_token,
|
8 |
+
remove_token,
|
9 |
+
)
|
10 |
+
from models import Token, UserCreate, User
|
11 |
from database import get_users, create_account
|
12 |
|
13 |
router = APIRouter()
|
|
|
15 |
@router.post("/token", response_model=Token)
|
16 |
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
17 |
"""
|
18 |
+
Endpoint to get an access token using username and password
|
19 |
"""
|
20 |
user = authenticate_user(get_users(), form_data.username, form_data.password)
|
21 |
if not user:
|
|
|
25 |
headers={"WWW-Authenticate": "Bearer"},
|
26 |
)
|
27 |
|
28 |
+
# Create a new token for the user (this will remove any existing token)
|
29 |
+
access_token = create_token_for_user(user.username)
|
30 |
+
|
|
|
31 |
return Token(access_token=access_token, token_type="bearer")
|
32 |
|
33 |
|
34 |
@router.post("/register")
|
35 |
async def register_user(user_data: UserCreate):
|
36 |
"""
|
37 |
+
Endpoint to register a new account
|
38 |
"""
|
39 |
success, message = create_account(user_data.username, user_data.password)
|
40 |
if not success:
|
41 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
42 |
return {"message": message}
|
43 |
+
|
44 |
+
|
45 |
+
@router.post("/logout")
|
46 |
+
async def logout(
|
47 |
+
current_user: User = Depends(get_current_user_from_token),
|
48 |
+
token: str = Depends(get_token),
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Endpoint to logout (invalidate the current token)
|
52 |
+
"""
|
53 |
+
success = remove_token(token)
|
54 |
+
if success:
|
55 |
+
return {"message": "Logout successful"}
|
56 |
+
else:
|
57 |
+
return {"message": "Token already invalid or expired"}
|
58 |
+
|
59 |
+
|
60 |
+
@router.get("/me")
|
61 |
+
async def get_current_user_info(
|
62 |
+
current_user: User = Depends(get_current_user_from_token),
|
63 |
+
):
|
64 |
+
"""
|
65 |
+
Get the current user's information
|
66 |
+
"""
|
67 |
+
return {
|
68 |
+
"username": current_user.username,
|
69 |
+
"email": current_user.email,
|
70 |
+
"full_name": current_user.full_name,
|
71 |
+
}
|
routes/health.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from fastapi import APIRouter, Depends
|
2 |
-
from
|
3 |
|
4 |
router = APIRouter()
|
5 |
|
@@ -12,7 +12,7 @@ async def health_check():
|
|
12 |
|
13 |
|
14 |
@router.get("/auth-check")
|
15 |
-
async def auth_check(current_user=Depends(
|
16 |
"""
|
17 |
Debug endpoint to verify authentication is working
|
18 |
"""
|
|
|
1 |
from fastapi import APIRouter, Depends
|
2 |
+
from custom_auth import get_current_user_from_token
|
3 |
|
4 |
router = APIRouter()
|
5 |
|
|
|
12 |
|
13 |
|
14 |
@router.get("/auth-check")
|
15 |
+
async def auth_check(current_user=Depends(get_current_user_from_token)):
|
16 |
"""
|
17 |
Debug endpoint to verify authentication is working
|
18 |
"""
|
routes/predict.py
CHANGED
@@ -4,8 +4,7 @@ import shutil
|
|
4 |
from pathlib import Path
|
5 |
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
|
6 |
from fastapi.responses import FileResponse
|
7 |
-
from
|
8 |
-
from simple_auth import get_current_user_from_api_key
|
9 |
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
|
10 |
from data_lib.input_name_data import InputNameData
|
11 |
from data_lib.base_name_data import COL_NAME_SENTENCE
|
@@ -27,14 +26,14 @@ router = APIRouter()
|
|
27 |
|
28 |
@router.post("/predict")
|
29 |
async def predict(
|
30 |
-
current_user=Depends(
|
31 |
file: UploadFile = File(...),
|
32 |
sentence_service: SentenceTransformerService = Depends(
|
33 |
lambda: sentence_transformer_service
|
34 |
),
|
35 |
):
|
36 |
"""
|
37 |
-
Process an input CSV file and return standardized names (requires
|
38 |
"""
|
39 |
if not file.filename.endswith(".csv"):
|
40 |
raise HTTPException(status_code=400, detail="Only CSV files are supported")
|
@@ -120,13 +119,13 @@ async def predict(
|
|
120 |
@router.post("/embeddings")
|
121 |
async def create_embeddings(
|
122 |
request: EmbeddingRequest,
|
123 |
-
current_user=Depends(
|
124 |
sentence_service: SentenceTransformerService = Depends(
|
125 |
lambda: sentence_transformer_service
|
126 |
),
|
127 |
):
|
128 |
"""
|
129 |
-
Create embeddings for a list of input sentences (requires
|
130 |
"""
|
131 |
try:
|
132 |
start_time = time.time()
|
@@ -147,13 +146,13 @@ async def create_embeddings(
|
|
147 |
@router.post("/predict-raw", response_model=PredictRawResponse)
|
148 |
async def predict_raw(
|
149 |
request: PredictRawRequest,
|
150 |
-
current_user=Depends(
|
151 |
sentence_service: SentenceTransformerService = Depends(
|
152 |
lambda: sentence_transformer_service
|
153 |
),
|
154 |
):
|
155 |
"""
|
156 |
-
Process raw input records and return standardized names (requires
|
157 |
"""
|
158 |
try:
|
159 |
# Convert input records to DataFrame
|
|
|
4 |
from pathlib import Path
|
5 |
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
|
6 |
from fastapi.responses import FileResponse
|
7 |
+
from custom_auth import get_current_user_from_token
|
|
|
8 |
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
|
9 |
from data_lib.input_name_data import InputNameData
|
10 |
from data_lib.base_name_data import COL_NAME_SENTENCE
|
|
|
26 |
|
27 |
@router.post("/predict")
|
28 |
async def predict(
|
29 |
+
current_user=Depends(get_current_user_from_token),
|
30 |
file: UploadFile = File(...),
|
31 |
sentence_service: SentenceTransformerService = Depends(
|
32 |
lambda: sentence_transformer_service
|
33 |
),
|
34 |
):
|
35 |
"""
|
36 |
+
Process an input CSV file and return standardized names (requires authentication)
|
37 |
"""
|
38 |
if not file.filename.endswith(".csv"):
|
39 |
raise HTTPException(status_code=400, detail="Only CSV files are supported")
|
|
|
119 |
@router.post("/embeddings")
|
120 |
async def create_embeddings(
|
121 |
request: EmbeddingRequest,
|
122 |
+
current_user=Depends(get_current_user_from_token),
|
123 |
sentence_service: SentenceTransformerService = Depends(
|
124 |
lambda: sentence_transformer_service
|
125 |
),
|
126 |
):
|
127 |
"""
|
128 |
+
Create embeddings for a list of input sentences (requires authentication)
|
129 |
"""
|
130 |
try:
|
131 |
start_time = time.time()
|
|
|
146 |
@router.post("/predict-raw", response_model=PredictRawResponse)
|
147 |
async def predict_raw(
|
148 |
request: PredictRawRequest,
|
149 |
+
current_user=Depends(get_current_user_from_token),
|
150 |
sentence_service: SentenceTransformerService = Depends(
|
151 |
lambda: sentence_transformer_service
|
152 |
),
|
153 |
):
|
154 |
"""
|
155 |
+
Process raw input records and return standardized names (requires authentication)
|
156 |
"""
|
157 |
try:
|
158 |
# Convert input records to DataFrame
|
token_store.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import secrets
|
5 |
+
import threading
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
from config import ACCESS_TOKEN_EXPIRE_HOURS
|
8 |
+
|
9 |
+
|
10 |
+
# Singleton to store tokens
|
11 |
+
class TokenStore:
|
12 |
+
_instance = None
|
13 |
+
_lock = threading.Lock()
|
14 |
+
|
15 |
+
def __new__(cls):
|
16 |
+
with cls._lock:
|
17 |
+
if cls._instance is None:
|
18 |
+
cls._instance = super(TokenStore, cls).__new__(cls)
|
19 |
+
# Initialize here to make instance attributes
|
20 |
+
cls._instance.tokens = {} # username -> {token, created_at}
|
21 |
+
cls._instance.token_to_user = {} # token -> username
|
22 |
+
cls._instance.tokens_file = "data/tokens.json"
|
23 |
+
return cls._instance
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
# Re-initialize in __init__ to help linters recognize these attributes
|
27 |
+
if not hasattr(self, "tokens"):
|
28 |
+
self.tokens = {}
|
29 |
+
if not hasattr(self, "token_to_user"):
|
30 |
+
self.token_to_user = {}
|
31 |
+
if not hasattr(self, "tokens_file"):
|
32 |
+
self.tokens_file = "data/tokens.json"
|
33 |
+
|
34 |
+
# Load tokens when instance is created
|
35 |
+
if not hasattr(self, "_loaded"):
|
36 |
+
self._load_tokens()
|
37 |
+
self._loaded = True
|
38 |
+
|
39 |
+
def _load_tokens(self):
|
40 |
+
"""Load tokens from file if it exists"""
|
41 |
+
os.makedirs("data", exist_ok=True)
|
42 |
+
if os.path.exists(self.tokens_file):
|
43 |
+
try:
|
44 |
+
with open(self.tokens_file, "r") as f:
|
45 |
+
data = json.load(f)
|
46 |
+
self.tokens = data.get("tokens", {})
|
47 |
+
self.token_to_user = data.get("token_to_user", {})
|
48 |
+
|
49 |
+
# Clean expired tokens on load
|
50 |
+
self._clean_expired_tokens()
|
51 |
+
except Exception as e:
|
52 |
+
print(f"Error loading tokens: {e}")
|
53 |
+
self.tokens = {}
|
54 |
+
self.token_to_user = {}
|
55 |
+
|
56 |
+
def _save_tokens(self):
|
57 |
+
"""Save tokens to file"""
|
58 |
+
try:
|
59 |
+
with open(self.tokens_file, "w") as f:
|
60 |
+
json.dump(
|
61 |
+
{"tokens": self.tokens, "token_to_user": self.token_to_user},
|
62 |
+
f,
|
63 |
+
indent=4,
|
64 |
+
)
|
65 |
+
except Exception as e:
|
66 |
+
print(f"Error saving tokens: {e}")
|
67 |
+
|
68 |
+
def _clean_expired_tokens(self):
|
69 |
+
"""Remove expired tokens"""
|
70 |
+
current_time = time.time()
|
71 |
+
expired_usernames = []
|
72 |
+
expired_tokens = []
|
73 |
+
|
74 |
+
# Find expired tokens
|
75 |
+
for username, token_data in self.tokens.items():
|
76 |
+
created_at = token_data.get("created_at", 0)
|
77 |
+
expiry_seconds = ACCESS_TOKEN_EXPIRE_HOURS * 3600
|
78 |
+
|
79 |
+
if current_time - created_at > expiry_seconds:
|
80 |
+
expired_usernames.append(username)
|
81 |
+
expired_tokens.append(token_data.get("token"))
|
82 |
+
|
83 |
+
# Remove expired tokens
|
84 |
+
for username in expired_usernames:
|
85 |
+
if username in self.tokens:
|
86 |
+
del self.tokens[username]
|
87 |
+
|
88 |
+
for token in expired_tokens:
|
89 |
+
if token in self.token_to_user:
|
90 |
+
del self.token_to_user[token]
|
91 |
+
|
92 |
+
# Save changes if any tokens were removed
|
93 |
+
if expired_tokens:
|
94 |
+
self._save_tokens()
|
95 |
+
|
96 |
+
def create_token(self, username):
|
97 |
+
"""Create a new token for a user, removing any existing token"""
|
98 |
+
with self._lock:
|
99 |
+
# Clean expired tokens first
|
100 |
+
self._clean_expired_tokens()
|
101 |
+
|
102 |
+
# Remove old token if it exists
|
103 |
+
if username in self.tokens:
|
104 |
+
old_token = self.tokens[username].get("token")
|
105 |
+
if old_token in self.token_to_user:
|
106 |
+
del self.token_to_user[old_token]
|
107 |
+
|
108 |
+
# Create new token
|
109 |
+
token = secrets.token_hex(32) # 64 character random hex string
|
110 |
+
self.tokens[username] = {"token": token, "created_at": time.time()}
|
111 |
+
self.token_to_user[token] = username
|
112 |
+
|
113 |
+
# Save changes
|
114 |
+
self._save_tokens()
|
115 |
+
|
116 |
+
return token
|
117 |
+
|
118 |
+
def validate_token(self, token):
|
119 |
+
"""Validate a token and return the username if valid"""
|
120 |
+
with self._lock:
|
121 |
+
# Clean expired tokens first
|
122 |
+
self._clean_expired_tokens()
|
123 |
+
|
124 |
+
# Check if token exists
|
125 |
+
if token not in self.token_to_user:
|
126 |
+
return None
|
127 |
+
|
128 |
+
username = self.token_to_user[token]
|
129 |
+
|
130 |
+
# Check if token is not expired
|
131 |
+
if username in self.tokens:
|
132 |
+
token_data = self.tokens[username]
|
133 |
+
created_at = token_data.get("created_at", 0)
|
134 |
+
current_time = time.time()
|
135 |
+
expiry_seconds = ACCESS_TOKEN_EXPIRE_HOURS * 3600
|
136 |
+
|
137 |
+
if current_time - created_at <= expiry_seconds:
|
138 |
+
return username
|
139 |
+
|
140 |
+
# Token is expired or invalid
|
141 |
+
return None
|
142 |
+
|
143 |
+
def remove_token(self, token):
|
144 |
+
"""Remove a token"""
|
145 |
+
with self._lock:
|
146 |
+
if token in self.token_to_user:
|
147 |
+
username = self.token_to_user[token]
|
148 |
+
del self.token_to_user[token]
|
149 |
+
|
150 |
+
if username in self.tokens:
|
151 |
+
del self.tokens[username]
|
152 |
+
|
153 |
+
self._save_tokens()
|
154 |
+
return True
|
155 |
+
return False
|
156 |
+
|
157 |
+
|
158 |
+
# Get the singleton instance
|
159 |
+
token_store = TokenStore()
|