import json
import logging
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from pydantic import BaseModel

load_dotenv()

logger = logging.getLogger(__name__)

MAX_RESPONSE_TOKENS = 2048
TEMPERATURE = 0.9

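# The env vars MODEL_NAMES and PROVIDERS are expected to hold JSON-encoded lists,
# e.g. (illustrative values only, not from this repo's .env):
#   MODEL_NAMES='["meta-llama/Llama-3.1-8B-Instruct"]'
#   PROVIDERS='["hf-inference", "together"]'
#   EMB_MODEL='sentence-transformers/all-MiniLM-L6-v2'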
models = json.loads(os.getenv("MODEL_NAMES", "[]"))
providers = json.loads(os.getenv("PROVIDERS", "[]"))
EMB_MODEL = os.getenv("EMB_MODEL")


def _engine_working(engine: InferenceClient) -> bool:
    """Send a minimal test request to check that the model endpoint responds."""
    try:
        engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
        logger.info("Engine is working.")
        return True
    except Exception:
        # logger.exception already records the exception message and traceback.
        logger.exception("Engine is not working.")
        return False


def _load_llm_client() -> InferenceClient:
    """
    Try each configured model/provider pair on the Hugging Face Inference API.

    Returns the first InferenceClient whose endpoint answers a test request.
    Raises RuntimeError if no model/provider combination is available.
    """
    logger.warning("Loading Model...")
    errors = []
    for model in models:
        if not isinstance(model, str):
            logger.error(f"Skipping invalid model entry (expected str): {model!r}")
            continue
        for provider in providers:
            try:
                logger.info(f"Checking model: {model} provider: {provider}")
                client = InferenceClient(
                    model=model,
                    timeout=15,
                    provider=provider,
                )
                if _engine_working(client):
                    logger.info(f"The model is loaded: {model}, provider: {provider}")
                    return client
            except Exception as e:
                logger.error(
                    f"Error loading model {model} provider {provider}: {e}"
                )
                errors.append(str(e))
    raise RuntimeError(f"Unable to load any provided model: {errors}.")


def _load_embedding_client() -> InferenceClient:
    """Create an InferenceClient for the embedding model configured via EMB_MODEL."""
    logger.warning("Loading Embedding Model...")
    try:
        return InferenceClient(timeout=15, model=EMB_MODEL)
    except Exception as e:
        logger.error(f"Error loading model {EMB_MODEL}: {e}")
        raise RuntimeError("Unable to load the embedding model.") from e


_default_client = _load_llm_client()
embed_client = _load_embedding_client()
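# Both clients are created once at import time. embed_client is consumed by other
# modules in the project; a typical call (assumption, not shown in this file) would
# be embed_client.feature_extraction("some text").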


class LLMChain:
    """Thin wrapper around an InferenceClient that tracks total token usage."""

    def __init__(self, client: InferenceClient = _default_client):
        self.client = client
        self.total_tokens = 0

    def run(
        self,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        messages: list[dict] | None = None,
        format_name: str | None = None,
        response_format: type[BaseModel] | None = None,
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """
        Run a single chat completion.

        Provide either system_prompt and user_prompt, or a prebuilt messages list.
        If format_name and response_format (a Pydantic model) are given, the model
        is asked for strict JSON matching the schema; the parsed result is returned
        as a dict of the schema's fields, or as a single value if the schema has
        exactly one field. Returns None if the call fails.
        """
        try:
            if system_prompt and user_prompt:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            elif not messages:
                raise ValueError(
                    "Either system_prompt and user_prompt or messages must be provided."
                )

            llm_response = self.client.chat_completion(
                messages=messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                response_format=(
                    {
                        "type": "json_schema",
                        "json_schema": {
                            "name": format_name,
                            "schema": response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                    if format_name and response_format
                    else None
                ),
            )
            self.total_tokens += llm_response.usage.total_tokens
            analysis = llm_response.choices[0].message.content
            if response_format:
                # Structured output: parse the JSON and return a single value for
                # one-field schemas, otherwise a dict keyed by the schema's fields.
                analysis = json.loads(analysis)
                fields = list(response_format.model_fields.keys())
                if len(fields) == 1:
                    return analysis.get(fields[0])
                return {field: analysis.get(field) for field in fields}

            return analysis

        except Exception as e:
            logger.error(f"Error during LLM calls: {e}")
            return None
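

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the Pydantic
    # model, prompts, and field names below are examples only.
    class Verdict(BaseModel):
        label: str
        confidence: float

    chain = LLMChain()

    # Plain-text completion: returns the raw message content (or None on failure).
    answer = chain.run(
        system_prompt="You are a concise assistant.",
        user_prompt="Explain in one sentence what an inference provider is.",
    )
    print(answer)

    # Structured completion: returns a dict with the Verdict fields (or None on failure).
    verdict = chain.run(
        system_prompt="Classify the sentiment of the user's message.",
        user_prompt="I absolutely love this library!",
        format_name="verdict",
        response_format=Verdict,
    )
    print(verdict, "| total tokens used:", chain.total_tokens)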