import json
import logging
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from pydantic import BaseModel

load_dotenv()

logger = logging.getLogger(__name__)

MAX_RESPONSE_TOKENS = 2048
TEMPERATURE = 0.9

models = json.loads(os.getenv("MODEL_NAMES"))
providers = json.loads(os.getenv("PROVIDERS"))
EMB_MODEL = os.getenv("EMB_MODEL")
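
# Example .env layout this module expects (the values below are illustrative
# assumptions, not from the original source): MODEL_NAMES and PROVIDERS are
# JSON arrays, EMB_MODEL is a plain model id.
#   MODEL_NAMES=["meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-7B-Instruct"]
#   PROVIDERS=["hf-inference"]
#   EMB_MODEL=sentence-transformers/all-MiniLM-L6-v2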


def _engine_working(engine: InferenceClient) -> bool:
    """Return True if the client answers a minimal one-token chat completion."""
    try:
        engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
        logger.info("Engine is working.")
        return True
    except Exception as e:
        logger.exception(f"Engine is not working: {e}")
        return False


def _load_llm_client() -> InferenceClient:
    """
    Attempts to load each configured model/provider pair from the Hugging Face
    inference endpoint.
    Returns an InferenceClient for the first combination that responds.
    Raises an Exception if no model is available.
    """
    logger.warning("Loading model...")
    errors = []
    for model in models:
        for provider in providers:
            if isinstance(model, str):
                try:
                    logger.info(f"Checking model: {model}, provider: {provider}")
                    client = InferenceClient(
                        model=model,
                        timeout=15,
                        provider=provider,
                    )
                    if _engine_working(client):
                        logger.info(f"Model loaded: {model}, provider: {provider}")
                        return client
                except Exception as e:
                    logger.error(
                        f"Error loading model {model}, provider {provider}: {e}"
                    )
                    errors.append(str(e))
    raise Exception(f"Unable to load any provided model: {errors}.")


def _load_embedding_client() -> InferenceClient:
    """Return an InferenceClient configured for the embedding model."""
    logger.warning("Loading embedding model...")
    try:
        emb_client = InferenceClient(timeout=15, model=EMB_MODEL)
        return emb_client
    except Exception as e:
        logger.error(f"Error loading model {EMB_MODEL}: {e}")
        raise Exception("Unable to load the embedding model.")


_default_client = _load_llm_client()
embed_client = _load_embedding_client()


class LLMChain:
    def __init__(self, client: InferenceClient = _default_client):
        self.client = client
        self.total_tokens = 0

    def run(
        self,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        messages: list[dict] | None = None,
        format_name: str | None = None,
        response_format: type[BaseModel] | None = None,
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """
        Run a chat completion from either a system/user prompt pair or a full
        message list. When format_name and response_format (a pydantic model)
        are given, the reply is requested as strict JSON and returned as a dict
        of the model's fields (or the single field's value). Returns None on
        any error.
        """
        try:
            if system_prompt and user_prompt:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            elif not messages:
                raise ValueError(
                    "Either system_prompt and user_prompt or messages must be provided."
                )
            llm_response = self.client.chat_completion(
                messages=messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                response_format=(
                    {
                        "type": "json_schema",
                        "json_schema": {
                            "name": format_name,
                            "schema": response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                    if format_name and response_format
                    else None
                ),
            )
            self.total_tokens += llm_response.usage.total_tokens
            analysis = llm_response.choices[0].message.content
            if response_format:
                analysis = json.loads(analysis)
                fields = list(response_format.model_fields.keys())
                if len(fields) == 1:
                    return analysis.get(fields[0])
                return {field: analysis.get(field) for field in fields}
            return analysis
        except Exception as e:
            logger.error(f"Error during LLM calls: {e}")
            return None
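

# --- Illustrative usage sketch -----------------------------------------------
# This block is an assumption about how the module is meant to be used, not part
# of the original source. The WeatherReport schema and the prompts are
# hypothetical; it only runs when MODEL_NAMES, PROVIDERS, and EMB_MODEL are set
# in the environment (.env).
if __name__ == "__main__":
    class WeatherReport(BaseModel):
        summary: str
        temperature_c: float

    chain = LLMChain()

    # Plain-text completion from a system/user prompt pair.
    print(chain.run(system_prompt="You are concise.", user_prompt="Say hello."))

    # Structured completion: the reply is parsed against the pydantic schema and
    # returned as a dict keyed by the model's fields.
    print(
        chain.run(
            system_prompt="Answer as structured data.",
            user_prompt="Describe today's weather in Paris.",
            format_name="weather_report",
            response_format=WeatherReport,
        )
    )
    print(f"Total tokens used so far: {chain.total_tokens}")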