Spaces:
Runtime error
Runtime error
Commit
·
640efa7
1
Parent(s):
f1a3e74
Update to print disallowed endpoints in-place in the model list
Browse files
Update configuration handling to put all clients in `clients` with backwards-compatible parsing
Troubleshoot radio button rendering
Refactor permissions configuration to support other oauth methods
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import json
|
|
|
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
|
| 5 |
import uvicorn
|
|
@@ -11,7 +13,7 @@ from starlette.responses import RedirectResponse
|
|
| 11 |
from authlib.integrations.starlette_client import OAuth, OAuthError
|
| 12 |
from fastapi import FastAPI, Request
|
| 13 |
|
| 14 |
-
from shared import Client
|
| 15 |
|
| 16 |
app = FastAPI()
|
| 17 |
config = {}
|
|
@@ -51,41 +53,44 @@ def init_config():
|
|
| 51 |
global clients
|
| 52 |
global llm_host_names
|
| 53 |
config = json.loads(os.environ['CONFIG'])
|
| 54 |
-
|
| 55 |
-
for name in
|
| 56 |
-
|
| 57 |
-
continue
|
| 58 |
-
model_personas = config[name].get("personas", {})
|
| 59 |
client = Client(
|
| 60 |
-
api_url=os.environ.get(
|
| 61 |
-
|
| 62 |
-
api_key=os.environ.get(
|
| 63 |
-
|
| 64 |
personas=model_personas
|
| 65 |
)
|
| 66 |
clients[name] = client
|
| 67 |
-
llm_host_names = list(
|
| 68 |
|
| 69 |
|
| 70 |
-
def get_allowed_models(
|
| 71 |
"""
|
| 72 |
Get a list of allowed endpoints for a specified user domain. Allowed domains
|
| 73 |
are configured in each model's configuration and may optionally be overridden
|
| 74 |
in the Gradio demo configuration.
|
| 75 |
-
:param
|
| 76 |
-
:return: List of allowed endpoints from configuration
|
|
|
|
| 77 |
"""
|
| 78 |
-
overrides = config.get("
|
| 79 |
allowed_endpoints = []
|
| 80 |
for client in clients:
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if
|
| 84 |
-
#
|
| 85 |
allowed_endpoints.append(client)
|
| 86 |
-
elif
|
| 87 |
-
|
|
|
|
| 88 |
allowed_endpoints.append(client)
|
|
|
|
|
|
|
|
|
|
| 89 |
return allowed_endpoints
|
| 90 |
|
| 91 |
|
|
@@ -107,7 +112,7 @@ def get_login_button(request: gr.Request) -> gr.Button:
|
|
| 107 |
:param request: Gradio request to evaluate
|
| 108 |
:return: Button for either login or logout action
|
| 109 |
"""
|
| 110 |
-
user = get_user(request)
|
| 111 |
print(f"Getting login button for {user}")
|
| 112 |
|
| 113 |
if user == "guest":
|
|
@@ -116,15 +121,39 @@ def get_login_button(request: gr.Request) -> gr.Button:
|
|
| 116 |
return gr.Button(f"Logout {user}", link="/logout")
|
| 117 |
|
| 118 |
|
| 119 |
-
def get_user(request: Request) ->
|
| 120 |
"""
|
| 121 |
Get a unique user email address for the specified request
|
| 122 |
:param request: FastAPI Request object with user session data
|
| 123 |
:return: String user email address or "guest"
|
| 124 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
if not request:
|
| 126 |
-
return "guest"
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
return user
|
| 129 |
|
| 130 |
|
|
@@ -232,25 +261,25 @@ def get_model_options(request: gr.Request) -> List[gr.Radio]:
|
|
| 232 |
# `user` is a valid Google email address or 'guest'
|
| 233 |
user = get_user(request.request)
|
| 234 |
else:
|
| 235 |
-
user = "guest"
|
| 236 |
-
print(f"Getting models for {user}")
|
| 237 |
|
| 238 |
-
|
| 239 |
-
allowed_llm_host_names = get_allowed_models(domain)
|
| 240 |
|
| 241 |
radio_infos = [f"{name} ({clients[name].vllm_model_name})"
|
|
|
|
| 242 |
for name in allowed_llm_host_names]
|
| 243 |
# Components
|
| 244 |
-
radios = [gr.Radio(choices=clients[name].personas.keys(),
|
| 245 |
value=None, label=info) for name, info
|
| 246 |
in zip(allowed_llm_host_names, radio_infos)]
|
| 247 |
|
| 248 |
# Select the first available option by default
|
| 249 |
radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
|
| 250 |
print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
|
| 251 |
-
# Ensure we always have the same number of rows
|
| 252 |
-
while len(radios) < len(llm_host_names):
|
| 253 |
-
|
| 254 |
return radios
|
| 255 |
|
| 256 |
|
|
@@ -271,6 +300,17 @@ def init_gradio() -> gr.Blocks:
|
|
| 271 |
@gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
|
| 272 |
inputs=[radio_state, *radios], outputs=[radio_state, *radios])
|
| 273 |
def radio_click(state, *new_state):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
try:
|
| 275 |
changed_index = next(i for i in range(len(state))
|
| 276 |
if state[i] != new_state[i])
|
|
@@ -326,5 +366,5 @@ if __name__ == "__main__":
|
|
| 326 |
init_config()
|
| 327 |
init_oauth()
|
| 328 |
blocks = init_gradio()
|
| 329 |
-
app = gr.mount_gradio_app(app, blocks, '/'
|
| 330 |
uvicorn.run(app, host='0.0.0.0', port=7860)
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
+
from time import sleep
|
| 4 |
+
|
| 5 |
import gradio as gr
|
| 6 |
|
| 7 |
import uvicorn
|
|
|
|
| 13 |
from authlib.integrations.starlette_client import OAuth, OAuthError
|
| 14 |
from fastapi import FastAPI, Request
|
| 15 |
|
| 16 |
+
from shared import Client, User, OAuthProvider
|
| 17 |
|
| 18 |
app = FastAPI()
|
| 19 |
config = {}
|
|
|
|
| 53 |
global clients
|
| 54 |
global llm_host_names
|
| 55 |
config = json.loads(os.environ['CONFIG'])
|
| 56 |
+
client_config = config.get("clients") or config
|
| 57 |
+
for name in client_config:
|
| 58 |
+
model_personas = client_config[name].get("personas", {})
|
|
|
|
|
|
|
| 59 |
client = Client(
|
| 60 |
+
api_url=os.environ.get(client_config[name]['api_url'],
|
| 61 |
+
client_config[name]['api_url']),
|
| 62 |
+
api_key=os.environ.get(client_config[name]['api_key'],
|
| 63 |
+
client_config[name]['api_key']),
|
| 64 |
personas=model_personas
|
| 65 |
)
|
| 66 |
clients[name] = client
|
| 67 |
+
llm_host_names = list(client_config.keys())
|
| 68 |
|
| 69 |
|
| 70 |
+
def get_allowed_models(user: User) -> List[str]:
|
| 71 |
"""
|
| 72 |
Get a list of allowed endpoints for a specified user domain. Allowed domains
|
| 73 |
are configured in each model's configuration and may optionally be overridden
|
| 74 |
in the Gradio demo configuration.
|
| 75 |
+
:param user: User to get permissions for
|
| 76 |
+
:return: List of allowed endpoints from configuration (including empty
|
| 77 |
+
strings for disallowed endpoints)
|
| 78 |
"""
|
| 79 |
+
overrides = config.get("permissions_override", {})
|
| 80 |
allowed_endpoints = []
|
| 81 |
for client in clients:
|
| 82 |
+
permission = overrides.get(client,
|
| 83 |
+
clients[client].config.inference.permissions)
|
| 84 |
+
if not permission:
|
| 85 |
+
# Permissions not specified (None or empty dict); model is public
|
| 86 |
allowed_endpoints.append(client)
|
| 87 |
+
elif user.oauth == OAuthProvider.GOOGLE and user.permissions_id in \
|
| 88 |
+
permission.get("google_domains", []):
|
| 89 |
+
# Google oauth domain is in the allowed domain list
|
| 90 |
allowed_endpoints.append(client)
|
| 91 |
+
else:
|
| 92 |
+
allowed_endpoints.append("")
|
| 93 |
+
print(f"No permission to access {client}")
|
| 94 |
return allowed_endpoints
|
| 95 |
|
| 96 |
|
|
|
|
| 112 |
:param request: Gradio request to evaluate
|
| 113 |
:return: Button for either login or logout action
|
| 114 |
"""
|
| 115 |
+
user = get_user(request).username
|
| 116 |
print(f"Getting login button for {user}")
|
| 117 |
|
| 118 |
if user == "guest":
|
|
|
|
| 121 |
return gr.Button(f"Logout {user}", link="/logout")
|
| 122 |
|
| 123 |
|
| 124 |
+
def get_user(request: Request) -> User:
|
| 125 |
"""
|
| 126 |
Get a unique user email address for the specified request
|
| 127 |
:param request: FastAPI Request object with user session data
|
| 128 |
:return: String user email address or "guest"
|
| 129 |
"""
|
| 130 |
+
# {'iss': 'https://accounts.google.com',
|
| 131 |
+
# 'azp': '***.apps.googleusercontent.com',
|
| 132 |
+
# 'aud': '***.apps.googleusercontent.com',
|
| 133 |
+
# 'sub': '###',
|
| 134 |
+
# 'hd': 'neon.ai',
|
| 135 |
+
# 'email': '***@neon.ai',
|
| 136 |
+
# 'email_verified': True,
|
| 137 |
+
# 'at_hash': '***',
|
| 138 |
+
# 'nonce': '***',
|
| 139 |
+
# 'name': 'Daniel McKnight',
|
| 140 |
+
# 'picture': 'https://lh3.googleusercontent.com/a/***',
|
| 141 |
+
# 'given_name': '***',
|
| 142 |
+
# 'family_name': '***',
|
| 143 |
+
# 'iat': ###,
|
| 144 |
+
# 'exp': ###}
|
| 145 |
if not request:
|
| 146 |
+
return User(OAuthProvider.NONE, "guest", "")
|
| 147 |
+
|
| 148 |
+
user_dict = request.session.get("user", {})
|
| 149 |
+
if user_dict.get("iss") == "https://accounts.google.com":
|
| 150 |
+
user = User(OAuthProvider.GOOGLE, user_dict["email"], user_dict["hd"])
|
| 151 |
+
elif user_dict:
|
| 152 |
+
print(f"Unknown user session data: {user_dict}")
|
| 153 |
+
user = User(OAuthProvider.NONE, "guest", "")
|
| 154 |
+
else:
|
| 155 |
+
user = User(OAuthProvider.NONE, "guest", "")
|
| 156 |
+
print(user)
|
| 157 |
return user
|
| 158 |
|
| 159 |
|
|
|
|
| 261 |
# `user` is a valid Google email address or 'guest'
|
| 262 |
user = get_user(request.request)
|
| 263 |
else:
|
| 264 |
+
user = User(OAuthProvider.NONE, "guest", "")
|
| 265 |
+
print(f"Getting models for {user.username}")
|
| 266 |
|
| 267 |
+
allowed_llm_host_names = get_allowed_models(user)
|
|
|
|
| 268 |
|
| 269 |
radio_infos = [f"{name} ({clients[name].vllm_model_name})"
|
| 270 |
+
if name in clients else "Not Authorized"
|
| 271 |
for name in allowed_llm_host_names]
|
| 272 |
# Components
|
| 273 |
+
radios = [gr.Radio(choices=clients[name].personas.keys() if name in clients else [],
|
| 274 |
value=None, label=info) for name, info
|
| 275 |
in zip(allowed_llm_host_names, radio_infos)]
|
| 276 |
|
| 277 |
# Select the first available option by default
|
| 278 |
radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
|
| 279 |
print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
|
| 280 |
+
# # Ensure we always have the same number of rows
|
| 281 |
+
# while len(radios) < len(llm_host_names):
|
| 282 |
+
# radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
|
| 283 |
return radios
|
| 284 |
|
| 285 |
|
|
|
|
| 300 |
@gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
|
| 301 |
inputs=[radio_state, *radios], outputs=[radio_state, *radios])
|
| 302 |
def radio_click(state, *new_state):
|
| 303 |
+
"""
|
| 304 |
+
Handle any state changes that require re-rendering radio buttons
|
| 305 |
+
:param state: Previous radio state representation (before selection)
|
| 306 |
+
:param new_state: Current radio state (including selection)
|
| 307 |
+
:return: Desired new state (current option selected, previous option
|
| 308 |
+
deselected)
|
| 309 |
+
"""
|
| 310 |
+
# Login and model options are triggered on load. This sleep is just
|
| 311 |
+
# a hack to make sure those events run before this logic to select
|
| 312 |
+
# the default model
|
| 313 |
+
sleep(0.1)
|
| 314 |
try:
|
| 315 |
changed_index = next(i for i in range(len(state))
|
| 316 |
if state[i] != new_state[i])
|
|
|
|
| 366 |
init_config()
|
| 367 |
init_oauth()
|
| 368 |
blocks = init_gradio()
|
| 369 |
+
app = gr.mount_gradio_app(app, blocks, '/')
|
| 370 |
uvicorn.run(app, host='0.0.0.0', port=7860)
|
shared.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import yaml
|
| 2 |
|
| 3 |
from typing import Dict, Optional, List
|
|
@@ -8,6 +11,18 @@ from huggingface_hub.utils import EntryNotFoundError
|
|
| 8 |
from openai import OpenAI
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
class PileConfig(BaseModel):
|
| 12 |
file2persona: Dict[str, str]
|
| 13 |
file2prefix: Dict[str, str]
|
|
@@ -17,7 +32,7 @@ class PileConfig(BaseModel):
|
|
| 17 |
|
| 18 |
class InferenceConfig(BaseModel):
|
| 19 |
chat_template: str
|
| 20 |
-
|
| 21 |
|
| 22 |
|
| 23 |
class RepoConfig(BaseModel):
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from enum import IntEnum
|
| 3 |
+
|
| 4 |
import yaml
|
| 5 |
|
| 6 |
from typing import Dict, Optional, List
|
|
|
|
| 11 |
from openai import OpenAI
|
| 12 |
|
| 13 |
|
| 14 |
+
class OAuthProvider(IntEnum):
|
| 15 |
+
NONE = 0
|
| 16 |
+
GOOGLE = 1
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class User:
|
| 21 |
+
oauth: OAuthProvider
|
| 22 |
+
username: str
|
| 23 |
+
permissions_id: str
|
| 24 |
+
|
| 25 |
+
|
| 26 |
class PileConfig(BaseModel):
|
| 27 |
file2persona: Dict[str, str]
|
| 28 |
file2prefix: Dict[str, str]
|
|
|
|
| 32 |
|
| 33 |
class InferenceConfig(BaseModel):
|
| 34 |
chat_template: str
|
| 35 |
+
permissions: Dict[str, list] = {}
|
| 36 |
|
| 37 |
|
| 38 |
class RepoConfig(BaseModel):
|