import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional

from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
    ModelCredentialSchema,
    ProviderCredentialSchema,
)
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.provider import LoadBalancingModelConfig

logger = logging.getLogger(__name__)


class ModelLoadBalancingService:
    """
    Manage per-model load balancing of a workspace: the on/off switch of a model
    and its load balancing credential configs (``LoadBalancingModelConfig`` rows).

    A config named ``__inherit__`` represents the provider or model custom
    credentials. When the provider has custom credentials it is created on demand
    while listing configs and is always listed first.
    """

    # Reserved name of the config that inherits the provider/model custom credentials.
    _INHERIT_CONFIG_NAME = "__inherit__"

    def __init__(self) -> None:
        self.provider_manager = ProviderManager()

    def _get_provider_configuration(self, tenant_id: str, provider: str) -> ProviderConfiguration:
        """
        Return the workspace's configuration of ``provider``.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: provider configuration
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        return provider_configuration

    @staticmethod
    def _load_balancing_configs_query(tenant_id: str, provider_name: str, model_type: ModelType, model: str):
        """
        Build the base query selecting the load balancing configs of one provider model.

        :param tenant_id: workspace id
        :param provider_name: provider name as stored in ``LoadBalancingModelConfig.provider_name``
        :param model_type: model type
        :param model: model name
        :return: SQLAlchemy query over ``LoadBalancingModelConfig``
        """
        return db.session.query(LoadBalancingModelConfig).filter(
            LoadBalancingModelConfig.tenant_id == tenant_id,
            LoadBalancingModelConfig.provider_name == provider_name,
            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
            LoadBalancingModelConfig.model_name == model,
        )

    @staticmethod
    def _load_encrypted_config(load_balancing_config: LoadBalancingModelConfig) -> dict:
        """
        Parse the stored ``encrypted_config`` JSON of a config.

        Secret values inside are still encrypted.

        :param load_balancing_config: load balancing config row
        :return: the parsed dict, or ``{}`` when it is empty, malformed, or not a JSON object
        """
        if not load_balancing_config.encrypted_config:
            return {}
        try:
            credentials = json.loads(load_balancing_config.encrypted_config)
        except JSONDecodeError:
            return {}
        return credentials if isinstance(credentials, dict) else {}

    @staticmethod
    def _decrypt_secret_credentials(credentials: dict, secret_variables, decoding_rsa_key, decoding_cipher_rsa) -> dict:
        """
        Decrypt, in place, the secret variables present in ``credentials``.

        This is best effort: a value that fails to decrypt is kept as stored.

        :param credentials: credentials parsed from ``encrypted_config``
        :param secret_variables: names of secret credential variables
        :param decoding_rsa_key: key returned by ``encrypter.get_decrypt_decoding``
        :param decoding_cipher_rsa: cipher returned by ``encrypter.get_decrypt_decoding``
        :return: the same ``credentials`` dict
        """
        for variable in secret_variables:
            if variable in credentials:
                try:
                    credentials[variable] = encrypter.decrypt_token_with_decoding(
                        credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
                    )
                except ValueError:
                    # Keep the stored value, but do not hide the failure completely.
                    logger.warning("Failed to decrypt load balancing credential variable %s", variable)
        return credentials

    def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        Enable load balancing for a model.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        Disable load balancing for a model.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def get_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str
    ) -> tuple[bool, list[dict]]:
        """
        Get the load balancing configurations of a model, oldest first.

        When the provider has custom credentials, the ``__inherit__`` config is
        created if it is missing and is always returned first.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: tuple of (whether load balancing is enabled for the model, list of
            dicts with keys ``id``, ``name``, ``credentials`` (secrets decrypted, then
            obfuscated), ``enabled``, ``in_cooldown`` and ``ttl``)
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        provider_name = provider_configuration.provider.provider
        model_type_enum = ModelType.value_of(model_type)

        provider_model_setting = provider_configuration.get_provider_model_setting(
            model_type=model_type_enum,
            model=model,
        )
        is_load_balancing_enabled = bool(provider_model_setting and provider_model_setting.load_balancing_enabled)

        load_balancing_configs = (
            self._load_balancing_configs_query(tenant_id, provider_name, model_type_enum, model)
            .order_by(LoadBalancingModelConfig.created_at)
            .all()
        )

        if provider_configuration.custom_configuration.provider:
            # The inherit config stands for the provider or model custom credentials;
            # make sure it exists and move it to the front, keeping the rest in order.
            inherit_configs = [c for c in load_balancing_configs if c.name == self._INHERIT_CONFIG_NAME]
            other_configs = [c for c in load_balancing_configs if c.name != self._INHERIT_CONFIG_NAME]
            if not inherit_configs:
                # Store it under the same provider name the query above filters on,
                # otherwise it would never be found and would be recreated on every call.
                inherit_configs = [self._init_inherit_config(tenant_id, provider_name, model, model_type_enum)]
            load_balancing_configs = inherit_configs + other_configs

        # Get credential form schemas from model credential schema or provider credential schema
        credential_form_schemas = self._get_credential_schema(provider_configuration).credential_form_schemas

        # The secret variable names depend only on the schema, so compute them once.
        secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)

        # Get decoding rsa key and cipher for decrypting credentials
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

        datas = []
        for load_balancing_config in load_balancing_configs:
            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type_enum,
                config_id=load_balancing_config.id,
            )

            credentials = self._decrypt_secret_credentials(
                self._load_encrypted_config(load_balancing_config),
                secret_variables,
                decoding_rsa_key,
                decoding_cipher_rsa,
            )
            credentials = provider_configuration.obfuscated_credentials(
                credentials=credentials, credential_form_schemas=credential_form_schemas
            )

            datas.append(
                {
                    "id": load_balancing_config.id,
                    "name": load_balancing_config.name,
                    "credentials": credentials,
                    "enabled": load_balancing_config.enabled,
                    "in_cooldown": in_cooldown,
                    "ttl": ttl,
                }
            )

        return is_load_balancing_enabled, datas

    def get_load_balancing_config(
        self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
    ) -> Optional[dict]:
        """
        Get one load balancing configuration of a model.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param config_id: load balancing config id
        :return: dict with keys ``id``, ``name``, ``credentials`` (secrets decrypted,
            then obfuscated) and ``enabled``; None if no such config exists
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        model_type_enum = ModelType.value_of(model_type)

        load_balancing_model_config = (
            self._load_balancing_configs_query(
                tenant_id, provider_configuration.provider.provider, model_type_enum, model
            )
            .filter(LoadBalancingModelConfig.id == config_id)
            .first()
        )

        if not load_balancing_model_config:
            return None

        # Get credential form schemas from model credential schema or provider credential schema
        credential_form_schemas = self._get_credential_schema(provider_configuration).credential_form_schemas
        secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

        # Decrypt before obfuscating, exactly as get_load_balancing_configs does, so both
        # endpoints mask the plaintext secret rather than its ciphertext.
        credentials = self._decrypt_secret_credentials(
            self._load_encrypted_config(load_balancing_model_config),
            secret_variables,
            decoding_rsa_key,
            decoding_cipher_rsa,
        )
        credentials = provider_configuration.obfuscated_credentials(
            credentials=credentials, credential_form_schemas=credential_form_schemas
        )

        return {
            "id": load_balancing_model_config.id,
            "name": load_balancing_model_config.name,
            "credentials": credentials,
            "enabled": load_balancing_model_config.enabled,
        }

    def _init_inherit_config(
        self, tenant_id: str, provider: str, model: str, model_type: ModelType
    ) -> LoadBalancingModelConfig:
        """
        Create and commit the ``__inherit__`` configuration of a model.

        :param tenant_id: workspace id
        :param provider: provider name as stored in ``LoadBalancingModelConfig.provider_name``
        :param model: model name
        :param model_type: model type
        :return: the committed inherit config
        """
        inherit_config = LoadBalancingModelConfig(
            tenant_id=tenant_id,
            provider_name=provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            name=self._INHERIT_CONFIG_NAME,
        )
        db.session.add(inherit_config)
        db.session.commit()

        return inherit_config

    def update_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
    ) -> None:
        """
        Replace the load balancing configurations of a model with ``configs``.

        Handling of each entry:
        - An entry with an ``id`` updates that existing config.
        - An entry without one creates a new config.
        - Existing configs absent from ``configs`` are deleted.

        All changes are committed in a single transaction: if any entry is
        invalid, the session is rolled back and nothing is persisted.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param configs: load balancing configs. Each is a dict with ``name``,
            ``enabled``, an optional ``id``, and ``credentials`` (required when
            creating a config)
        :return:
        :raises ValueError: if the provider does not exist or any config is invalid
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        provider_name = provider_configuration.provider.provider
        model_type_enum = ModelType.value_of(model_type)

        if not isinstance(configs, list):
            raise ValueError("Invalid load balancing configs")

        current_load_balancing_configs = self._load_balancing_configs_query(
            tenant_id, provider_name, model_type_enum, model
        ).all()

        # id as key, config as value
        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
        updated_config_ids: set[str] = set()
        # Names used by entries of this request, to reject duplicates within the payload itself.
        requested_names: set[str] = set()

        try:
            for config in configs:
                if not isinstance(config, dict):
                    raise ValueError("Invalid load balancing config")

                config_id = config.get("id")
                name = config.get("name")
                credentials = config.get("credentials")
                enabled = config.get("enabled")

                if not name or not isinstance(name, str):
                    raise ValueError("Invalid load balancing config name")

                if enabled is None:
                    raise ValueError("Invalid load balancing config enabled")

                if name in requested_names:
                    raise ValueError("Load balancing config name {} already exists".format(name))
                requested_names.add(name)

                if config_id:
                    # update an existing config
                    config_id = str(config_id)

                    if config_id not in current_load_balancing_configs_dict or config_id in updated_config_ids:
                        raise ValueError("Invalid load balancing config id: {}".format(config_id))

                    updated_config_ids.add(config_id)

                    load_balancing_config = current_load_balancing_configs_dict[config_id]

                    # The reserved inherit name may not be taken over by another config
                    # (same rule as on the create path below).
                    if name == self._INHERIT_CONFIG_NAME and load_balancing_config.name != self._INHERIT_CONFIG_NAME:
                        raise ValueError("Invalid load balancing config name")

                    # check duplicate name
                    for current_load_balancing_config in current_load_balancing_configs:
                        if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
                            raise ValueError("Load balancing config name {} already exists".format(name))

                    if credentials:
                        if not isinstance(credentials, dict):
                            raise ValueError("Invalid load balancing config credentials")

                        # Restore hidden secrets and encrypt; no remote validation here.
                        credentials = self._custom_credentials_validate(
                            tenant_id=tenant_id,
                            provider_configuration=provider_configuration,
                            model_type=model_type_enum,
                            model=model,
                            credentials=credentials,
                            load_balancing_model_config=load_balancing_config,
                            validate=False,
                        )

                        load_balancing_config.encrypted_config = json.dumps(credentials)

                    load_balancing_config.name = name
                    load_balancing_config.enabled = enabled
                    load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(
                        tzinfo=None
                    )
                else:
                    # create a new config
                    if name == self._INHERIT_CONFIG_NAME:
                        raise ValueError("Invalid load balancing config name")

                    # check duplicate name
                    for current_load_balancing_config in current_load_balancing_configs:
                        if current_load_balancing_config.name == name:
                            raise ValueError("Load balancing config name {} already exists".format(name))

                    if not credentials or not isinstance(credentials, dict):
                        raise ValueError("Invalid load balancing config credentials")

                    # Encrypt secrets; no remote validation here.
                    credentials = self._custom_credentials_validate(
                        tenant_id=tenant_id,
                        provider_configuration=provider_configuration,
                        model_type=model_type_enum,
                        model=model,
                        credentials=credentials,
                        validate=False,
                    )

                    db.session.add(
                        LoadBalancingModelConfig(
                            tenant_id=tenant_id,
                            provider_name=provider_name,
                            model_type=model_type_enum.to_origin_model_type(),
                            model_name=model,
                            name=name,
                            encrypted_config=json.dumps(credentials),
                            # The payload's flag was validated above but previously never stored.
                            enabled=enabled,
                        )
                    )

            # Existing configs not mentioned in the payload are deleted.
            deleted_config_ids = set(current_load_balancing_configs_dict) - updated_config_ids
            for config_id in deleted_config_ids:
                db.session.delete(current_load_balancing_configs_dict[config_id])

            db.session.commit()
        except Exception:
            # Discard all pending changes so an invalid entry cannot leave a half-applied update.
            db.session.rollback()
            raise

        # Cached credentials of updated and deleted configs are stale now that the commit succeeded.
        for config_id in updated_config_ids | deleted_config_ids:
            self._clear_credentials_cache(tenant_id, config_id)

    def validate_load_balancing_credentials(
        self,
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        credentials: dict,
        config_id: Optional[str] = None,
    ) -> None:
        """
        Validate load balancing credentials against the provider.

        With ``config_id``, hidden secret values are first restored from that stored config.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param credentials: credentials to validate (not modified)
        :param config_id: load balancing config id
        :return:
        :raises ValueError: if the provider or the referenced config does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)
        model_type_enum = ModelType.value_of(model_type)

        load_balancing_model_config = None
        if config_id:
            # Filter on the canonical provider name, like every other query in this service.
            load_balancing_model_config = (
                self._load_balancing_configs_query(
                    tenant_id, provider_configuration.provider.provider, model_type_enum, model
                )
                .filter(LoadBalancingModelConfig.id == config_id)
                .first()
            )

            if not load_balancing_model_config:
                raise ValueError(f"Load balancing config {config_id} does not exist.")

        self._custom_credentials_validate(
            tenant_id=tenant_id,
            provider_configuration=provider_configuration,
            model_type=model_type_enum,
            model=model,
            credentials=credentials,
            load_balancing_model_config=load_balancing_model_config,
        )

    def _custom_credentials_validate(
        self,
        tenant_id: str,
        provider_configuration: ProviderConfiguration,
        model_type: ModelType,
        model: str,
        credentials: dict,
        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
        validate: bool = True,
    ) -> dict:
        """
        Prepare custom credentials for storage, optionally validating them with the provider.

        Steps:
        1. Secret values equal to ``HIDDEN_VALUE`` are replaced by the decrypted
           value stored in ``load_balancing_model_config``, when one is given.
        2. With ``validate``, the credentials go through the model or provider
           credential validation.
        3. Secret values are encrypted in the returned copy.

        :param tenant_id: workspace id
        :param provider_configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: credentials (not modified)
        :param load_balancing_model_config: load balancing model config
        :param validate: validate credentials
        :return: new credentials dict with secret values encrypted
        """
        # Work on a copy so the caller's payload is never mutated.
        credentials = dict(credentials)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get provider credential secret variables
        provider_credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )

        if load_balancing_model_config:
            original_credentials = self._load_encrypted_config(load_balancing_model_config)

            for key, value in credentials.items():
                # A secret sent as [__HIDDEN__] means "unchanged": reuse the stored value.
                if key in provider_credential_secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                    credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])

        if validate:
            if isinstance(credential_schemas, ModelCredentialSchema):
                credentials = model_provider_factory.model_credentials_validate(
                    provider=provider_configuration.provider.provider,
                    model_type=model_type,
                    model=model,
                    credentials=credentials,
                )
            else:
                credentials = model_provider_factory.provider_credentials_validate(
                    provider=provider_configuration.provider.provider, credentials=credentials
                )

        return {
            key: encrypter.encrypt_token(tenant_id, value) if key in provider_credential_secret_variables else value
            for key, value in credentials.items()
        }

    def _get_credential_schema(
        self, provider_configuration: ProviderConfiguration
    ) -> ModelCredentialSchema | ProviderCredentialSchema:
        """
        Get the credential schema used for load balancing configs.

        :param provider_configuration: provider configuration
        :return: the model credential schema when the provider defines one,
            otherwise the provider credential schema
        """
        provider_entity = provider_configuration.provider
        if provider_entity.model_credential_schema:
            return provider_entity.model_credential_schema
        return provider_entity.provider_credential_schema

    def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
        """
        Delete the cached credentials of one load balancing config.

        :param tenant_id: workspace id
        :param config_id: load balancing config id
        :return:
        """
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
        )

        provider_model_credentials_cache.delete()