# Apriel-Chat / utils.py
# Last change by bradnow: "Improve logging for normal operation" (commit 2ce979c)
import os
import sys
import time
from functools import wraps
from typing import Any, Literal
from gradio import ChatMessage
from gradio.components.chatbot import Message
# Path suffix for a model's community discussion page; presumably appended to a
# Hugging Face model URL elsewhere in the app — confirm with callers.
COMMUNITY_POSTFIX_URL = "/discussions"

# Debug switches: each is on only when its environment variable is exactly the
# string "True" (any other value, or unset, leaves it off). To force one on
# during local development, prefix the expression with `True or`.
DEBUG_MODE = os.getenv("DEBUG_MODE", "") == "True"
DEBUG_MODEL = os.getenv("DEBUG_MODEL", "") == "True"
# Registry of the chat models this module knows about, keyed by display name.
# Endpoint settings are read from environment variables at import time and are
# None when the variable is unset; get_model_config() rejects entries whose
# MODEL_NAME or VLLM_API_URL is missing.
models_config = {
    "Apriel-Nemotron-15b-Thinker": dict(
        MODEL_DISPLAY_NAME="Apriel-Nemotron-15b-Thinker",
        MODEL_HF_URL="https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
        MODEL_NAME=os.getenv("MODEL_NAME_NEMO_15B"),
        VLLM_API_URL=os.getenv("VLLM_API_URL_NEMO_15B"),
        AUTH_TOKEN=os.getenv("AUTH_TOKEN"),
        # NOTE(review): presumably flags models that produce a reasoning phase — confirm with callers.
        REASONING=True,
    ),
    "Apriel-5b": dict(
        MODEL_DISPLAY_NAME="Apriel-5b",
        MODEL_HF_URL="https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
        MODEL_NAME=os.getenv("MODEL_NAME_5B"),
        VLLM_API_URL=os.getenv("VLLM_API_URL_5B"),
        # Both models share the same AUTH_TOKEN environment variable.
        AUTH_TOKEN=os.getenv("AUTH_TOKEN"),
        REASONING=False,
    ),
}
def get_model_config(model_name: str) -> dict:
    """Look up a model's entry in ``models_config`` and check it is usable.

    Args:
        model_name: Key of the model in ``models_config``.

    Returns:
        The matching ``models_config`` entry itself (not a copy).

    Raises:
        ValueError: If the model is unknown (or its entry is empty), or if its
            ``MODEL_NAME`` or ``VLLM_API_URL`` value is missing/falsy.
    """
    config = models_config.get(model_name)
    if not config:
        raise ValueError(f"Model {model_name} not found in models_config")

    # Check the required endpoint settings in a fixed order; the first gap wins.
    for key, label in (("MODEL_NAME", "Model name"), ("VLLM_API_URL", "VLLM API URL")):
        if not config.get(key):
            raise ValueError(f"{label} not found in config for {model_name}")

    return config
def _log_message(prefix, message, icon=""):
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
if len(icon) > 0:
icon = f"{icon} "
print(f"{timestamp}: {prefix} {icon}{message}")
def log_debug(message):
    """Log ``message`` with the "DEBUG" tag, but only while DEBUG_MODE is True."""
    if DEBUG_MODE is not True:
        return
    _log_message("DEBUG", message)
def log_info(message):
    """Log ``message`` with the "INFO" tag and no icon."""
    # Trailing space pads the tag to five characters, matching "DEBUG"/"ERROR".
    _log_message(prefix="INFO ", message=message)
def log_warning(message):
    """Log ``message`` with the "WARN" tag, prefixed by a warning-sign icon."""
    # Trailing space pads the tag to five characters, matching "DEBUG"/"ERROR".
    _log_message(prefix="WARN ", message=message, icon="⚠️")
def log_error(message):
    """Log ``message`` with the "ERROR" tag, prefixed by a double-exclamation icon."""
    _log_message(prefix="ERROR", message=message, icon="‼️")
# Gradio 5.0.1 did not validate chat message formats itself (5.29.0 does), so this
# check is kept as a debug-only safety net.
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate chat history against a Gradio chatbot data format.

    Does nothing unless DEBUG_MODE is truthy.

    Args:
        messages: The chat history to check. For the "messages" format it may be
            iterated twice (once to test, once to find the first bad entry), so a
            one-shot iterator is only partially re-scanned.
        type: "messages" accepts dicts with "role" and "content" keys, or
            ChatMessage/Message objects. Any other value is checked as the
            "tuples" format, which requires every entry to be a 2-element tuple or list.

    Raises:
        Exception: If any entry does not match the format. For "messages", the
            first offending entry is also printed to stderr.
    """
    if not DEBUG_MODE:
        return

    def _is_valid(entry) -> bool:
        # A message is either a role/content dict or a Gradio message object.
        if isinstance(entry, dict) and "role" in entry and "content" in entry:
            return True
        return isinstance(entry, ChatMessage | Message)

    if type != "messages":
        pairs_ok = all(
            isinstance(pair, (tuple, list)) and len(pair) == 2
            for pair in messages
        )
        if not pairs_ok:
            raise Exception(
                "Data incompatible with tuples format. Each message should be a list of length 2."
            )
        return

    if all(_is_valid(entry) for entry in messages):
        return

    # Point at the first offending message before failing.
    for index, entry in enumerate(messages):
        if not _is_valid(entry):
            print(f"_check_format() --> Invalid message at index {index}: {entry}\n", file=sys.stderr)
            break
    raise Exception(
        "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
    )
# Adds debug logging and optional timing steps around a Gradio event handler
# (non-generator functions only).
def logged_event_handler(log_msg='', event_handler=None, log_timer=None, clear_timer=False):
    """Wrap ``event_handler`` with before/after debug logs and optional timer steps.

    Only plain (non-generator) handlers are supported. A generator handler would
    return its generator immediately, so the "Completed" step would be recorded
    before any of its body ran.

    Args:
        log_msg: Label used in the debug lines and in the timer step names.
        event_handler: The callable to wrap. Its metadata (name, docstring, ...)
            is copied onto the wrapper with ``functools.wraps``.
        log_timer: Optional timer object providing ``clear()`` and
            ``add_step(str)``. NOTE(review): presumably a project timer class —
            confirm. It is tested for truthiness, not against None.
        clear_timer: If True and ``log_timer`` is set, clear the timer before
            recording the start step.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``event_handler`` and returns its result unchanged. If the handler
        raises, the "after" step and log are skipped.
    """
    @wraps(event_handler)
    def wrapped_event_handler(*args, **kwargs):
        # Log before
        if log_timer:
            if clear_timer:
                log_timer.clear()
            # Label the start step with log_msg (this previously interpolated the
            # log_debug function object, producing a function repr).
            log_timer.add_step(f"Start: {log_msg}")
        log_debug(f"::: Before event: {log_msg}")

        # Call the original event handler
        result = event_handler(*args, **kwargs)

        # Log after
        if log_timer:
            log_timer.add_step(f"Completed: {log_msg}")
        log_debug(f"::: After event: {log_msg}")
        return result

    return wrapped_event_handler