# datajoi-sql-agent/src/client.py
import json
import logging
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from pydantic import BaseModel
load_dotenv()
logger = logging.getLogger(__name__)
MAX_RESPONSE_TOKENS = 2048
TEMPERATURE = 0.9
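
# Expected .env entries (illustrative values only; the model names and providers
# actually configured for this project are assumptions, not read from this file).
# MODEL_NAMES and PROVIDERS must be JSON arrays, since they are parsed with json.loads:
#   MODEL_NAMES=["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"]
#   PROVIDERS=["hf-inference", "together"]
#   EMB_MODEL=sentence-transformers/all-MiniLM-L6-v2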
models = json.loads(os.getenv("MODEL_NAMES"))
providers = json.loads(os.getenv("PROVIDERS"))
EMB_MODEL = os.getenv("EMB_MODEL")
def _engine_working(engine: InferenceClient) -> bool:
    try:
        engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
        logger.info("Engine is working.")
        return True
    except Exception as e:
        logger.exception(f"Engine is not working: {e}")
        return False

def _load_llm_client() -> InferenceClient:
"""
Attempts to load the provided model from the huggingface endpoint.
Returns InferenceClient if successful.
Raises Exception if no model is available.
"""
logger.warning("Loading Model...")
errors = []
for model in models:
for provider in providers:
if isinstance(model, str):
try:
logger.info(f"Checking model: {model} provider: {provider}")
client = InferenceClient(
model=model,
timeout=15,
provider=provider,
)
if _engine_working(client):
logger.info(
f"The model is loaded : {model} , provider: {provider}"
)
return client
except Exception as e:
logger.error(
f"Error loading model {model} provider {provider}: {e}"
)
errors.append(str(e))
raise Exception(f"Unable to load any provided model: {errors}.")
def _load_embedding_client() -> InferenceClient:
    logger.warning("Loading Embedding Model...")
    try:
        emb_client = InferenceClient(timeout=15, model=EMB_MODEL)
        return emb_client
    except Exception as e:
        logger.error(f"Error loading model {EMB_MODEL}: {e}")
        raise Exception("Unable to load the embedding model.")
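
# Both clients are created once at import time, so importing this module runs
# the health checks above (including their network calls).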
_default_client = _load_llm_client()
embed_client = _load_embedding_client()
class LLMChain:
    def __init__(self, client: InferenceClient = _default_client):
        self.client = client
        self.total_tokens = 0

    def run(
        self,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        messages: list[dict] | None = None,
        format_name: str | None = None,
        response_format: type[BaseModel] | None = None,
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """
        Send a chat completion request and return the model's response.

        When format_name and response_format are both given, a JSON-schema
        response format is requested from the model. When response_format is
        given, the content is parsed as JSON: a single-field schema returns
        that field's value, otherwise a dict of all schema fields is returned.
        Returns None if the call fails.
        """
        try:
            if system_prompt and user_prompt:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            elif not messages:
                raise ValueError(
                    "Either system_prompt and user_prompt or messages must be provided."
                )
            llm_response = self.client.chat_completion(
                messages=messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                response_format=(
                    {
                        "type": "json_schema",
                        "json_schema": {
                            "name": format_name,
                            "schema": response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                    if format_name and response_format
                    else None
                ),
            )
            self.total_tokens += llm_response.usage.total_tokens
            analysis = llm_response.choices[0].message.content
            if response_format:
                analysis = json.loads(analysis)
                fields = list(response_format.model_fields.keys())
                if len(fields) == 1:
                    return analysis.get(fields[0])
                return {field: analysis.get(field) for field in fields}
            return analysis
        except Exception as e:
            logger.error(f"Error during LLM calls: {e}")
            return None
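
if __name__ == "__main__":
    # Minimal usage sketch. The QueryOutput schema and the prompts below are
    # illustrative assumptions, not part of the project's actual prompts.
    class QueryOutput(BaseModel):
        sql: str

    chain = LLMChain()
    result = chain.run(
        system_prompt="You write a single SQL query for the user's request.",
        user_prompt="List all customers created in 2024.",
        format_name="query_output",
        response_format=QueryOutput,
    )
    # With a single-field schema, run() returns just that field's value
    # (or None if the call failed).
    print(result)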