import json
import logging
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from pydantic import BaseModel

load_dotenv()

logger = logging.getLogger(__name__)

MAX_RESPONSE_TOKENS = 2048
TEMPERATURE = 0.9

# MODEL_NAMES and PROVIDERS are expected to hold JSON arrays,
# e.g. MODEL_NAMES='["model-a", "model-b"]' and PROVIDERS='["provider-x"]'.
models = json.loads(os.getenv("MODEL_NAMES"))
providers = json.loads(os.getenv("PROVIDERS"))
EMB_MODEL = os.getenv("EMB_MODEL")


def _engine_working(engine: InferenceClient) -> bool:
    """Send a minimal chat request to check that the client can reach its endpoint."""
    try:
        engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
        logger.info("Engine is working.")
        return True
    except Exception:
        logger.exception("Engine is not working.")
        return False


def _load_llm_client() -> InferenceClient:
    """
    Try each configured model/provider pair against the Hugging Face Inference API.

    Returns the first InferenceClient that responds to a test request.
    Raises an Exception if none of the configured combinations work.
    """
    logger.warning("Loading model...")
    errors = []
    for model in models:
        if not isinstance(model, str):
            continue
        for provider in providers:
            try:
                logger.info(f"Checking model: {model}, provider: {provider}")
                client = InferenceClient(
                    model=model,
                    timeout=15,
                    provider=provider,
                )
                if _engine_working(client):
                    logger.info(f"Model loaded: {model}, provider: {provider}")
                    return client
            except Exception as e:
                logger.error(f"Error loading model {model}, provider {provider}: {e}")
                errors.append(str(e))
    raise Exception(f"Unable to load any of the configured models: {errors}.")


def _load_embedding_client() -> InferenceClient:
    """Create an InferenceClient for the embedding model named by EMB_MODEL."""
    logger.warning("Loading embedding model...")
    try:
        emb_client = InferenceClient(timeout=15, model=EMB_MODEL)
        return emb_client
    except Exception as e:
        logger.error(f"Error loading model {EMB_MODEL}: {e}")
        raise Exception("Unable to load the embedding model.")


# Shared clients, created once at import time.
_default_client = _load_llm_client()
embed_client = _load_embedding_client()
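# Illustrative use of the embedding client (not executed here), assuming EMB_MODEL
# is a feature-extraction model; InferenceClient.feature_extraction returns the
# embedding vector(s) for the given text:
#
#     vector = embed_client.feature_extraction("text to embed")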


class LLMChain:
    """Thin wrapper around an InferenceClient that tracks total token usage."""

    def __init__(self, client: InferenceClient = _default_client):
        self.client = client
        self.total_tokens = 0

    def run(
        self,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        messages: list[dict] | None = None,
        format_name: str | None = None,
        response_format: type[BaseModel] | None = None,
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """
        Run a chat completion and return the response content.

        Either both system_prompt and user_prompt, or a full messages list, must be
        provided. If format_name and response_format (a Pydantic model) are given,
        the response is requested as strict JSON matching the model's schema and
        returned as a dict of its fields (or a single value if the model has exactly
        one field). Returns None on any error.
        """
        try:
            if system_prompt and user_prompt:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            elif not messages:
                raise ValueError(
                    "Either system_prompt and user_prompt or messages must be provided."
                )
            llm_response = self.client.chat_completion(
                messages=messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                response_format=(
                    {
                        "type": "json_schema",
                        "json_schema": {
                            "name": format_name,
                            "schema": response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                    if format_name and response_format
                    else None
                ),
            )
            self.total_tokens += llm_response.usage.total_tokens
            analysis = llm_response.choices[0].message.content
            if response_format:
                # Structured output: parse the JSON and map it onto the model's fields.
                analysis = json.loads(analysis)
                fields = list(response_format.model_fields.keys())
                if len(fields) == 1:
                    return analysis.get(fields[0])
                return {field: analysis.get(field) for field in fields}
            return analysis
        except Exception as e:
            logger.error(f"Error during LLM call: {e}")
            return None
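

if __name__ == "__main__":
    # Illustrative usage only: a minimal sketch assuming the environment variables
    # above are set and the selected model supports strict JSON-schema output.
    # The "Sentiment" model and the prompts below are hypothetical examples,
    # not part of the original module.
    class Sentiment(BaseModel):
        label: str
        score: float

    chain = LLMChain()
    result = chain.run(
        system_prompt="You are a sentiment classifier. Return a label and a score.",
        user_prompt="I loved this movie!",
        format_name="sentiment",
        response_format=Sentiment,
    )
    print(result)              # e.g. {"label": "positive", "score": 0.98}
    print(chain.total_tokens)  # cumulative token usage across calls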