vumichien commited on
Commit
6830bc7
·
1 Parent(s): 460e51f

token store

Browse files
Files changed (5) hide show
  1. custom_auth.py +78 -0
  2. routes/auth.py +42 -10
  3. routes/health.py +2 -2
  4. routes/predict.py +7 -8
  5. token_store.py +159 -0
custom_auth.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, HTTPException, status, Header, Query
2
+ from typing import Optional
3
+ from database import get_users
4
+ from models import User, UserInDB
5
+ from token_store import token_store
6
+
7
+
8
+ async def get_token(
9
+ authorization: Optional[str] = Header(None),
10
+ token: Optional[str] = Query(
11
+ None, description="Access token (alternative to Authorization header)"
12
+ ),
13
+ ) -> str:
14
+ """
15
+ Extract token from Authorization header or query parameter
16
+ Supports both methods for better compatibility with various clients
17
+ """
18
+ # First try to get token from Authorization header
19
+ if authorization:
20
+ if authorization.startswith("Bearer "):
21
+ return authorization.replace("Bearer ", "")
22
+ else:
23
+ # If it doesn't have Bearer prefix, use as is
24
+ return authorization
25
+
26
+ # Then try to get token from query parameter
27
+ if token:
28
+ return token
29
+
30
+ # If no token is provided, raise an error
31
+ raise HTTPException(
32
+ status_code=status.HTTP_401_UNAUTHORIZED,
33
+ detail="Authorization header missing",
34
+ headers={"WWW-Authenticate": "Bearer"},
35
+ )
36
+
37
+
38
+ async def get_current_user_from_token(token: str = Depends(get_token)):
39
+ """
40
+ Validate token and return user if valid
41
+ """
42
+ credentials_exception = HTTPException(
43
+ status_code=status.HTTP_401_UNAUTHORIZED,
44
+ detail="Could not validate credentials",
45
+ headers={"WWW-Authenticate": "Bearer"},
46
+ )
47
+
48
+ # Validate token
49
+ username = token_store.validate_token(token)
50
+ if not username:
51
+ print(f"Invalid or expired token")
52
+ raise credentials_exception
53
+
54
+ # Get user from database
55
+ users = get_users()
56
+ if username not in users:
57
+ print(f"User not found: {username}")
58
+ raise credentials_exception
59
+
60
+ user_dict = users[username]
61
+ user = UserInDB(**user_dict)
62
+ print(f"User authenticated: {user.username}")
63
+
64
+ return user
65
+
66
+
67
+ def create_token_for_user(username: str) -> str:
68
+ """
69
+ Create a new token for a user
70
+ """
71
+ return token_store.create_token(username)
72
+
73
+
74
+ def remove_token(token: str) -> bool:
75
+ """
76
+ Remove a token from the store
77
+ """
78
+ return token_store.remove_token(token)
routes/auth.py CHANGED
@@ -1,9 +1,13 @@
1
  from fastapi import APIRouter, Depends, HTTPException, status
2
  from fastapi.security import OAuth2PasswordRequestForm
3
- from datetime import timedelta
4
- from auth import authenticate_user, create_access_token
5
- from models import Token, UserCreate
6
- from config import ACCESS_TOKEN_EXPIRE_HOURS
 
 
 
 
7
  from database import get_users, create_account
8
 
9
  router = APIRouter()
@@ -11,7 +15,7 @@ router = APIRouter()
11
  @router.post("/token", response_model=Token)
12
  async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
13
  """
14
- Endpoint để lấy access token bằng username password
15
  """
16
  user = authenticate_user(get_users(), form_data.username, form_data.password)
17
  if not user:
@@ -21,19 +25,47 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
21
  headers={"WWW-Authenticate": "Bearer"},
22
  )
23
 
24
- access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
25
- access_token = create_access_token(
26
- data={"sub": user.username}, expires_delta=access_token_expires
27
- )
28
  return Token(access_token=access_token, token_type="bearer")
29
 
30
 
31
  @router.post("/register")
32
  async def register_user(user_data: UserCreate):
33
  """
34
- Endpoint để đăng tài khoản mới
35
  """
36
  success, message = create_account(user_data.username, user_data.password)
37
  if not success:
38
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
39
  return {"message": message}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, Depends, HTTPException, status
2
  from fastapi.security import OAuth2PasswordRequestForm
3
+ from auth import authenticate_user, get_current_user
4
+ from custom_auth import (
5
+ create_token_for_user,
6
+ get_current_user_from_token,
7
+ get_token,
8
+ remove_token,
9
+ )
10
+ from models import Token, UserCreate, User
11
  from database import get_users, create_account
12
 
13
  router = APIRouter()
 
15
  @router.post("/token", response_model=Token)
16
  async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
17
  """
18
+ Endpoint to get an access token using username and password
19
  """
20
  user = authenticate_user(get_users(), form_data.username, form_data.password)
21
  if not user:
 
25
  headers={"WWW-Authenticate": "Bearer"},
26
  )
27
 
28
+ # Create a new token for the user (this will remove any existing token)
29
+ access_token = create_token_for_user(user.username)
30
+
 
31
  return Token(access_token=access_token, token_type="bearer")
32
 
33
 
34
  @router.post("/register")
35
  async def register_user(user_data: UserCreate):
36
  """
37
+ Endpoint to register a new account
38
  """
39
  success, message = create_account(user_data.username, user_data.password)
40
  if not success:
41
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
42
  return {"message": message}
43
+
44
+
45
+ @router.post("/logout")
46
+ async def logout(
47
+ current_user: User = Depends(get_current_user_from_token),
48
+ token: str = Depends(get_token),
49
+ ):
50
+ """
51
+ Endpoint to logout (invalidate the current token)
52
+ """
53
+ success = remove_token(token)
54
+ if success:
55
+ return {"message": "Logout successful"}
56
+ else:
57
+ return {"message": "Token already invalid or expired"}
58
+
59
+
60
+ @router.get("/me")
61
+ async def get_current_user_info(
62
+ current_user: User = Depends(get_current_user_from_token),
63
+ ):
64
+ """
65
+ Get the current user's information
66
+ """
67
+ return {
68
+ "username": current_user.username,
69
+ "email": current_user.email,
70
+ "full_name": current_user.full_name,
71
+ }
routes/health.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import APIRouter, Depends
2
- from simple_auth import get_current_user_from_api_key
3
 
4
  router = APIRouter()
5
 
@@ -12,7 +12,7 @@ async def health_check():
12
 
13
 
14
  @router.get("/auth-check")
15
- async def auth_check(current_user=Depends(get_current_user_from_api_key)):
16
  """
17
  Debug endpoint to verify authentication is working
18
  """
 
1
  from fastapi import APIRouter, Depends
2
+ from custom_auth import get_current_user_from_token
3
 
4
  router = APIRouter()
5
 
 
12
 
13
 
14
  @router.get("/auth-check")
15
+ async def auth_check(current_user=Depends(get_current_user_from_token)):
16
  """
17
  Debug endpoint to verify authentication is working
18
  """
routes/predict.py CHANGED
@@ -4,8 +4,7 @@ import shutil
4
  from pathlib import Path
5
  from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
6
  from fastapi.responses import FileResponse
7
- from auth import get_current_user
8
- from simple_auth import get_current_user_from_api_key
9
  from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
10
  from data_lib.input_name_data import InputNameData
11
  from data_lib.base_name_data import COL_NAME_SENTENCE
@@ -27,14 +26,14 @@ router = APIRouter()
27
 
28
  @router.post("/predict")
29
  async def predict(
30
- current_user=Depends(get_current_user_from_api_key),
31
  file: UploadFile = File(...),
32
  sentence_service: SentenceTransformerService = Depends(
33
  lambda: sentence_transformer_service
34
  ),
35
  ):
36
  """
37
- Process an input CSV file and return standardized names (requires API Key authentication)
38
  """
39
  if not file.filename.endswith(".csv"):
40
  raise HTTPException(status_code=400, detail="Only CSV files are supported")
@@ -120,13 +119,13 @@ async def predict(
120
  @router.post("/embeddings")
121
  async def create_embeddings(
122
  request: EmbeddingRequest,
123
- current_user=Depends(get_current_user_from_api_key),
124
  sentence_service: SentenceTransformerService = Depends(
125
  lambda: sentence_transformer_service
126
  ),
127
  ):
128
  """
129
- Create embeddings for a list of input sentences (requires API Key authentication)
130
  """
131
  try:
132
  start_time = time.time()
@@ -147,13 +146,13 @@ async def create_embeddings(
147
  @router.post("/predict-raw", response_model=PredictRawResponse)
148
  async def predict_raw(
149
  request: PredictRawRequest,
150
- current_user=Depends(get_current_user_from_api_key),
151
  sentence_service: SentenceTransformerService = Depends(
152
  lambda: sentence_transformer_service
153
  ),
154
  ):
155
  """
156
- Process raw input records and return standardized names (requires API Key authentication)
157
  """
158
  try:
159
  # Convert input records to DataFrame
 
4
  from pathlib import Path
5
  from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
6
  from fastapi.responses import FileResponse
7
+ from custom_auth import get_current_user_from_token
 
8
  from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
9
  from data_lib.input_name_data import InputNameData
10
  from data_lib.base_name_data import COL_NAME_SENTENCE
 
26
 
27
  @router.post("/predict")
28
  async def predict(
29
+ current_user=Depends(get_current_user_from_token),
30
  file: UploadFile = File(...),
31
  sentence_service: SentenceTransformerService = Depends(
32
  lambda: sentence_transformer_service
33
  ),
34
  ):
35
  """
36
+ Process an input CSV file and return standardized names (requires authentication)
37
  """
38
  if not file.filename.endswith(".csv"):
39
  raise HTTPException(status_code=400, detail="Only CSV files are supported")
 
119
  @router.post("/embeddings")
120
  async def create_embeddings(
121
  request: EmbeddingRequest,
122
+ current_user=Depends(get_current_user_from_token),
123
  sentence_service: SentenceTransformerService = Depends(
124
  lambda: sentence_transformer_service
125
  ),
126
  ):
127
  """
128
+ Create embeddings for a list of input sentences (requires authentication)
129
  """
130
  try:
131
  start_time = time.time()
 
146
  @router.post("/predict-raw", response_model=PredictRawResponse)
147
  async def predict_raw(
148
  request: PredictRawRequest,
149
+ current_user=Depends(get_current_user_from_token),
150
  sentence_service: SentenceTransformerService = Depends(
151
  lambda: sentence_transformer_service
152
  ),
153
  ):
154
  """
155
+ Process raw input records and return standardized names (requires authentication)
156
  """
157
  try:
158
  # Convert input records to DataFrame
token_store.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ import secrets
5
+ import threading
6
+ from datetime import datetime, timedelta
7
+ from config import ACCESS_TOKEN_EXPIRE_HOURS
8
+
9
+
10
+ # Singleton to store tokens
11
+ class TokenStore:
12
+ _instance = None
13
+ _lock = threading.Lock()
14
+
15
+ def __new__(cls):
16
+ with cls._lock:
17
+ if cls._instance is None:
18
+ cls._instance = super(TokenStore, cls).__new__(cls)
19
+ # Initialize here to make instance attributes
20
+ cls._instance.tokens = {} # username -> {token, created_at}
21
+ cls._instance.token_to_user = {} # token -> username
22
+ cls._instance.tokens_file = "data/tokens.json"
23
+ return cls._instance
24
+
25
+ def __init__(self):
26
+ # Re-initialize in __init__ to help linters recognize these attributes
27
+ if not hasattr(self, "tokens"):
28
+ self.tokens = {}
29
+ if not hasattr(self, "token_to_user"):
30
+ self.token_to_user = {}
31
+ if not hasattr(self, "tokens_file"):
32
+ self.tokens_file = "data/tokens.json"
33
+
34
+ # Load tokens when instance is created
35
+ if not hasattr(self, "_loaded"):
36
+ self._load_tokens()
37
+ self._loaded = True
38
+
39
+ def _load_tokens(self):
40
+ """Load tokens from file if it exists"""
41
+ os.makedirs("data", exist_ok=True)
42
+ if os.path.exists(self.tokens_file):
43
+ try:
44
+ with open(self.tokens_file, "r") as f:
45
+ data = json.load(f)
46
+ self.tokens = data.get("tokens", {})
47
+ self.token_to_user = data.get("token_to_user", {})
48
+
49
+ # Clean expired tokens on load
50
+ self._clean_expired_tokens()
51
+ except Exception as e:
52
+ print(f"Error loading tokens: {e}")
53
+ self.tokens = {}
54
+ self.token_to_user = {}
55
+
56
+ def _save_tokens(self):
57
+ """Save tokens to file"""
58
+ try:
59
+ with open(self.tokens_file, "w") as f:
60
+ json.dump(
61
+ {"tokens": self.tokens, "token_to_user": self.token_to_user},
62
+ f,
63
+ indent=4,
64
+ )
65
+ except Exception as e:
66
+ print(f"Error saving tokens: {e}")
67
+
68
+ def _clean_expired_tokens(self):
69
+ """Remove expired tokens"""
70
+ current_time = time.time()
71
+ expired_usernames = []
72
+ expired_tokens = []
73
+
74
+ # Find expired tokens
75
+ for username, token_data in self.tokens.items():
76
+ created_at = token_data.get("created_at", 0)
77
+ expiry_seconds = ACCESS_TOKEN_EXPIRE_HOURS * 3600
78
+
79
+ if current_time - created_at > expiry_seconds:
80
+ expired_usernames.append(username)
81
+ expired_tokens.append(token_data.get("token"))
82
+
83
+ # Remove expired tokens
84
+ for username in expired_usernames:
85
+ if username in self.tokens:
86
+ del self.tokens[username]
87
+
88
+ for token in expired_tokens:
89
+ if token in self.token_to_user:
90
+ del self.token_to_user[token]
91
+
92
+ # Save changes if any tokens were removed
93
+ if expired_tokens:
94
+ self._save_tokens()
95
+
96
+ def create_token(self, username):
97
+ """Create a new token for a user, removing any existing token"""
98
+ with self._lock:
99
+ # Clean expired tokens first
100
+ self._clean_expired_tokens()
101
+
102
+ # Remove old token if it exists
103
+ if username in self.tokens:
104
+ old_token = self.tokens[username].get("token")
105
+ if old_token in self.token_to_user:
106
+ del self.token_to_user[old_token]
107
+
108
+ # Create new token
109
+ token = secrets.token_hex(32) # 64 character random hex string
110
+ self.tokens[username] = {"token": token, "created_at": time.time()}
111
+ self.token_to_user[token] = username
112
+
113
+ # Save changes
114
+ self._save_tokens()
115
+
116
+ return token
117
+
118
+ def validate_token(self, token):
119
+ """Validate a token and return the username if valid"""
120
+ with self._lock:
121
+ # Clean expired tokens first
122
+ self._clean_expired_tokens()
123
+
124
+ # Check if token exists
125
+ if token not in self.token_to_user:
126
+ return None
127
+
128
+ username = self.token_to_user[token]
129
+
130
+ # Check if token is not expired
131
+ if username in self.tokens:
132
+ token_data = self.tokens[username]
133
+ created_at = token_data.get("created_at", 0)
134
+ current_time = time.time()
135
+ expiry_seconds = ACCESS_TOKEN_EXPIRE_HOURS * 3600
136
+
137
+ if current_time - created_at <= expiry_seconds:
138
+ return username
139
+
140
+ # Token is expired or invalid
141
+ return None
142
+
143
+ def remove_token(self, token):
144
+ """Remove a token"""
145
+ with self._lock:
146
+ if token in self.token_to_user:
147
+ username = self.token_to_user[token]
148
+ del self.token_to_user[token]
149
+
150
+ if username in self.tokens:
151
+ del self.tokens[username]
152
+
153
+ self._save_tokens()
154
+ return True
155
+ return False
156
+
157
+
158
+ # Get the singleton instance
159
+ token_store = TokenStore()