import copy
from typing import Any, Dict, Optional

import torch
from transformers import AutoModel, PreTrainedModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PoolerEndLogits

from .configuration_relik import RelikReaderConfig


class RelikReaderSample:
    """Attribute-style container for the fields of a single reader sample.

    Keyword arguments given to the constructor are stored in the private
    ``_d`` dict and read back as attributes. Reading a field that was never
    set returns ``None`` instead of raising (dunder names excepted).
    Assigning to a field supplied at construction updates ``_d``; assigning
    any other name creates an ordinary instance attribute.
    """

    def __init__(self, **kwargs):
        # ``_d`` has to exist before the custom ``__setattr__`` below runs,
        # because that method looks keys up in it.
        super().__setattr__("_d", {})
        self._d = kwargs

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Dunder probes made
        # by library protocols (e.g. ``__deepcopy__`` from ``copy``) must
        # raise so those libraries fall back to their default behaviour.
        is_dunder = item.startswith("__") and item.endswith("__")
        if is_dunder:
            raise AttributeError(item)
        return self._d.get(item)

    def __setattr__(self, key, value):
        if key not in self._d:
            super().__setattr__(key, value)
            return
        self._d[key] = value


# Maps the activation name stored in the reader config (``config.activation``)
# to the module placed inside every projection head. The same instance is
# reused by every head built from this table.
activation2functions = dict(
    relu=torch.nn.ReLU(),
    gelu=GELUActivation(),
    gelu_10=ClippedGELUActivation(-10, 10),
)


class PoolerEndLogitsBi(PoolerEndLogits):
    """``PoolerEndLogits`` variant that emits two logits per token.

    ``dense_1`` is re-created as a ``hidden_size -> 2`` linear layer, replacing
    whatever the parent constructor assigned to that name. The RE reader
    argmaxes over these two logits at every position to make a binary end
    decision.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Re-assigned after the parent constructor so this layer wins.
        self.dense_1 = torch.nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Delegate to the parent pooler with ``p_mask`` given a trailing axis.

        The extra axis is presumably what lets the parent's masking broadcast
        over the two logits of each token.
        """
        broadcast_mask = None if p_mask is None else p_mask.unsqueeze(-1)
        return super().forward(
            hidden_states,
            start_states,
            start_positions,
            broadcast_mask,
        )


class RelikReaderSpanModel(PreTrainedModel):
    """Span reader: named entity detection (NED) plus entity disambiguation (ED).

    A transformer encoder produces token features. NED marks span-start tokens
    with a two-class head and, for every start, picks an end token with a
    :class:`PoolerEndLogits` head. ED scores token representations against the
    representations of the special-symbol tokens found in the same input
    (``special_symbols_mask``).
    """

    config_class = RelikReaderConfig

    def __init__(self, config: RelikReaderConfig, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(self.config.transformer_model)
            if self.config.num_layers is None
            else AutoModel.from_pretrained(
                self.config.transformer_model, num_hidden_layers=self.config.num_layers
            )
        )
        # Grow the embedding matrix to cover the additional special symbols.
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + self.config.additional_special_symbols
        )

        self.activation = self.config.activation
        self.linears_hidden_size = self.config.linears_hidden_size
        self.use_last_k_layers = self.config.use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self._end_pooler_config())

        # entity disambiguation layers
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # NOTE: overrides the nn.Module ``training`` flag with the config value.
        self.training = self.config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _end_pooler_config(self) -> PretrainedConfig:
        """Return the config used to size the end-boundary pooler.

        The pooler consumes ``_get_model_features`` output, whose width is
        ``hidden_size * use_last_k_layers``. With a single layer the encoder
        config is returned unchanged; otherwise a copy with the widened
        ``hidden_size`` is returned so the encoder's own config is untouched.
        """
        encoder_config = self.transformer_model.config
        if self.use_last_k_layers <= 1:
            return encoder_config
        pooler_config = copy.deepcopy(encoder_config)
        pooler_config.hidden_size = encoder_config.hidden_size * self.use_last_k_layers
        return pooler_config

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a dropout/linear/activation/dropout/linear head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; defaults to ``linears_hidden_size``.
            input_hidden: input width; defaults to the encoder feature width
                (``hidden_size * use_last_k_layers``).
            layer_norm: append a ``LayerNorm`` over the output when True.
        """
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size * self.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push positions where ``mask`` is 1 to a large negative value.

        ``mask`` gains a trailing axis so it broadcasts over the last logits
        dimension. 65500 is used for float16 parameters because 1e30 does not
        fit in that dtype.
        """
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Run the encoder and return the token features used by every head.

        With ``use_last_k_layers > 1`` the last k hidden states are concatenated
        along the feature axis; otherwise the last hidden state is returned.
        """
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers <= 1:
            return model_output[0]

        # Read the named field: outputs that also carry a pooled vector
        # (e.g. BaseModelOutputWithPooling) have ``pooler_output`` at index 1,
        # ahead of the hidden states. Plain tuples have no such attribute and
        # keep the positional lookup.
        hidden_states = getattr(model_output, "hidden_states", None)
        if hidden_states is None:
            hidden_states = model_output[1]
        return torch.cat(hidden_states[-self.use_last_k_layers :], dim=-1)

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        """Compute end-boundary logits, one row per span start.

        In training the gold starts (``start_labels``) are used when they are
        available; otherwise the predicted starts. Each batch row is repeated
        once per start it contains, in the same row-major order as the start
        indices, so row ``i`` of the result scores the end of the ``i``-th
        start.

        Returns:
            The pooler output, or ``None`` when there is no start at all.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        use_gold = self.training and start_labels is not None
        start_positions = start_labels if use_gold else start_predictions
        start_mask = start_positions > 0

        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_mask]
        )
        if start_positions_indices.numel() == 0:
            return None

        device = model_features.device
        starts_per_row = start_mask.sum(dim=-1).to(device)
        expanded_features = model_features.repeat_interleave(starts_per_row, dim=0)
        expanded_prediction_mask = prediction_mask.to(device).repeat_interleave(
            starts_per_row, dim=0
        )

        return self.ned_end_classifier(
            hidden_states=expanded_features,
            start_positions=start_positions_indices.to(device),
            p_mask=expanded_prediction_mask,
        )

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        """Score every token against the special-symbol (candidate) tokens.

        Each token is represented by the concatenation of its ED start and end
        projections. At span starts the end half is replaced by the end
        projection of the span's end token; the k-th start in a row-major scan
        is paired with the k-th end. The representations found at
        ``special_symbols_mask`` act as class embeddings. The logits are their
        dot products with every token, masked by ``prediction_mask``.

        Every batch row is assumed to contain the same number of special
        symbols (the count is read from row 0).

        Returns:
            Tensor of shape (batch, seq_len, num_special_symbols).
        """
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    @staticmethod
    def _binarize_labels(labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Map entity-id labels to NED targets: -100 kept, >0 -> 1, rest -> 0.

        Returns ``None`` when ``labels`` is ``None``.
        """
        if labels is None:
            return None
        binary_labels = torch.zeros_like(labels)
        binary_labels[labels == -100] = -100
        binary_labels[labels > 0] = 1
        return binary_labels

    @staticmethod
    def _flatten_end_predictions(
        ned_start_predictions: torch.Tensor,
        ned_end_predictions: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """Turn per-start end indices into a (batch, seq_len) 0/1 end mask.

        ``ned_end_predictions`` holds one end index per predicted start, in
        row-major order of the starts. An end before its start is clamped to
        the start. A start whose end token is already claimed by an earlier
        span is dropped by zeroing it in ``ned_start_predictions`` (in place).
        """
        flattened_end_predictions = torch.zeros_like(ned_start_predictions)

        batch_start_predictions = [
            torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
            for elem_idx in range(batch_size)
        ]

        # check that the total number of start predictions
        # is equal to the end predictions
        total_start_predictions = sum(map(len, batch_start_predictions))
        total_end_predictions = len(ned_end_predictions)
        assert (
            total_start_predictions == 0
            or total_start_predictions == total_end_predictions
        ), (
            f"Total number of start predictions = {total_start_predictions}. "
            f"Total number of end predictions = {total_end_predictions}"
        )

        # One host transfer instead of an ``.item()`` call per span.
        end_indices = ned_end_predictions.tolist()
        curr_end_pred_num = 0
        for elem_idx, bsp in enumerate(batch_start_predictions):
            for sp in bsp:
                ep = max(end_indices[curr_end_pred_num], sp)

                # if we already set this span throw it (no overlap)
                if flattened_end_predictions[elem_idx, ep] == 1:
                    ned_start_predictions[elem_idx, sp] = 0
                else:
                    flattened_end_predictions[elem_idx, ep] = 1

                curr_end_pred_num += 1

        return flattened_end_predictions

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Run detection and disambiguation; add losses when training with labels.

        Args:
            input_ids, attention_mask, token_type_ids: encoder inputs
                (``input_ids`` is unpacked as (batch, seq_len)).
            prediction_mask: 1 at positions whose logits must be suppressed.
            special_symbols_mask: marks the special-symbol tokens whose
                representations act as ED classes.
            start_labels / end_labels: per-token labels. A positive value is
                the gold entity id at a span start / end; -100 is ignored.
            use_predefined_spans: skip NED and take the spans from the labels
                (or none, when labels are missing).

        Returns:
            Dict with ``batch_size``, the NED start/end logits, probabilities
            and predictions, and the ED logits, probabilities and predictions.
            When training with both label tensors it also holds
            ``ned_start_loss``, ``ned_end_loss``, ``ed_loss`` and their sum
            ``loss``.
        """
        batch_size, _ = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # Binary start targets. They are needed by the ED loss even when spans
        # are predefined and no start classifier runs.
        ned_start_labels = self._binarize_labels(start_labels)

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities = None, None
            ned_end_logits, ned_end_probabilities = None, None
            ned_start_predictions = (
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids)
            )
            ned_end_predictions = (
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids)
            )

            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1

        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_probabilities = None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            #   (flattening can happen only if the
            #   end boundaries were not predicted using the gold labels)
            if not self.training:
                ned_end_predictions = self._flatten_end_predictions(
                    ned_start_predictions, ned_end_predictions, batch_size
                )

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss: start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # named entity detection loss: end. Targets are the gold end
            # indices, one per gold start, in the row-major order used for the
            # rows of ``ned_end_logits``.
            if ned_end_logits is not None:
                ned_end_labels = self._binarize_labels(end_labels)
                ned_end_loss = self.criterion(
                    ned_end_logits,
                    torch.arange(ned_end_labels.size(1), device=ned_end_labels.device)
                    .unsqueeze(0)
                    .expand(batch_size, -1)[ned_end_labels > 0],
                )
            else:
                ned_end_loss = 0

            # entity disambiguation loss: only span start and end tokens are
            # supervised. ``masked_fill`` returns a new tensor, so the caller's
            # ``start_labels`` is not modified.
            ed_labels = start_labels.masked_fill(ned_start_labels != 1, -100)
            is_end = end_labels > 0
            ed_labels[is_end] = end_labels[is_end]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict


class RelikReaderREModel(PreTrainedModel):
    config_class = RelikReaderConfig

    def __init__(self, config, *args, **kwargs):
        """Build the encoder, the detection heads and the relation-extraction heads.

        Args:
            config: reader configuration. Fields read here are
                ``transformer_model``, ``num_layers``,
                ``additional_special_symbols``, ``activation``,
                ``use_last_k_layers``, ``linears_hidden_size``,
                ``entity_type_loss``, ``add_entity_embedding``, ``training``
                and, if present, ``relation_disambiguation_loss``.
        """
        super().__init__(config)
        # Transformer model declaration
        # self.transformer_model_name = transformer_model
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(config.transformer_model)
            if config.num_layers is None
            else AutoModel.from_pretrained(
                config.transformer_model, num_hidden_layers=config.num_layers
            )
        )
        # Grow the embedding matrix to cover the additional special symbols.
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + config.additional_special_symbols
        )

        # ``_get_model_features`` concatenates the last ``use_last_k_layers``
        # hidden states, so every head fed with those features must accept
        # this width. It equals ``hidden_size`` when a single layer is used.
        feature_size = (
            self.transformer_model.config.hidden_size * config.use_last_k_layers
        )

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            config.activation, last_hidden=2, layer_norm=False
        )

        end_pooler_config = self.transformer_model.config
        if config.use_last_k_layers > 1:
            # Size the pooler for the concatenated features without mutating
            # the encoder's own config.
            end_pooler_config = copy.deepcopy(end_pooler_config)
            end_pooler_config.hidden_size = feature_size
        self.ned_end_classifier = PoolerEndLogitsBi(end_pooler_config)

        self.relation_disambiguation_loss = getattr(
            config, "relation_disambiguation_loss", False
        )

        # Entity features are [start; end] token features, plus an entity-type
        # representation when ``add_entity_embedding`` is enabled.
        if self.config.entity_type_loss and self.config.add_entity_embedding:
            input_hidden_ents = 3 * feature_size
        else:
            input_hidden_ents = 2 * feature_size

        self.re_subject_projector = self._get_projection_layer(
            config.activation, input_hidden=input_hidden_ents
        )
        self.re_object_projector = self._get_projection_layer(
            config.activation, input_hidden=input_hidden_ents
        )
        self.re_relation_projector = self._get_projection_layer(config.activation)

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
            self.re_entities_projector = self._get_projection_layer(
                config.activation,
                input_hidden=2 * feature_size,
            )
            self.re_definition_projector = self._get_projection_layer(
                config.activation,
            )

        self.re_classifier = self._get_projection_layer(
            config.activation,
            input_hidden=config.linears_hidden_size,
            last_hidden=2,
            layer_norm=False,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
            self.re_ed_classifier = self._get_projection_layer(
                config.activation,
                input_hidden=config.linears_hidden_size,
                last_hidden=2,
                layer_norm=False,
            )

        # NOTE: overrides the nn.Module ``training`` flag with the config value.
        self.training = config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a dropout -> linear -> activation -> dropout -> linear head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; ``config.linears_hidden_size`` if None.
            input_hidden: input width; the encoder feature width
                (``hidden_size * use_last_k_layers``) if None.
            layer_norm: when True, a ``LayerNorm`` over the output is appended.
        """
        inner_width = self.config.linears_hidden_size
        if input_hidden is None:
            in_width = (
                self.transformer_model.config.hidden_size
                * self.config.use_last_k_layers
            )
        else:
            in_width = input_hidden
        out_width = inner_width if last_hidden is None else last_hidden

        # Layers are instantiated in a fixed order; the Linear layers draw
        # their initial weights from the RNG in this order.
        layers = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_width, inner_width),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(inner_width, out_width),
        ]
        if layer_norm:
            layers.append(
                torch.nn.LayerNorm(
                    out_width, self.transformer_model.config.layer_norm_eps
                )
            )
        return torch.nn.Sequential(*layers)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push positions where ``mask`` is 1 to a large negative value.

        ``mask`` gains a trailing axis so it broadcasts over the last logits
        dimension; unmasked positions are kept unchanged.
        """
        broadcast_mask = mask.unsqueeze(-1)
        # 1e30 does not fit in float16 (max 65504), so a smaller fill is used
        # when the model parameters are half precision.
        half_precision = next(self.parameters()).dtype == torch.float16
        fill_value = 65500 if half_precision else 1e30
        return logits * (1 - broadcast_mask) - fill_value * broadcast_mask

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Run the encoder and return the token features used by every head.

        With ``use_last_k_layers > 1`` the last k hidden states are concatenated
        along the feature axis; otherwise the last hidden state is returned.
        """
        use_last_k_layers = self.config.use_last_k_layers
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if use_last_k_layers <= 1:
            return model_output[0]

        # Read the named field: outputs that also carry a pooled vector
        # (e.g. BaseModelOutputWithPooling) have ``pooler_output`` at index 1,
        # ahead of the hidden states. Plain tuples have no such attribute and
        # keep the positional lookup.
        hidden_states = getattr(model_output, "hidden_states", None)
        if hidden_states is None:
            hidden_states = model_output[1]
        return torch.cat(hidden_states[-use_last_k_layers:], dim=-1)

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
        mask_preceding: bool = False,
    ) -> Optional[torch.Tensor]:
        """Compute end-boundary logits, one row per span start.

        In training mode the gold starts (``start_labels``) are used, otherwise
        the predicted starts. Each batch row is repeated once per start it
        contains, matching the row-major order of the start indices.

        Args:
            mask_preceding: when True, every token before a row's start is set
                to 1 in that row's prediction mask, so the end cannot precede
                the start.

        Returns:
            The ``ned_end_classifier`` output, or ``None`` if there is no start.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        is_start = start_positions > 0
        starts_per_row = torch.sum(is_start, dim=-1)
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[is_start]
        )

        if len(start_positions_indices) == 0:
            return None

        expanded_features = model_features.repeat_interleave(starts_per_row, dim=0)
        expanded_prediction_mask = prediction_mask.repeat_interleave(
            starts_per_row, dim=0
        )
        if mask_preceding:
            token_index = torch.arange(
                expanded_prediction_mask.shape[1],
                device=expanded_prediction_mask.device,
            )
            before_start = token_index < start_positions_indices.unsqueeze(1)
            expanded_prediction_mask[before_start] = 1

        return self.ned_end_classifier(
            hidden_states=expanded_features,
            start_positions=start_positions_indices,
            p_mask=expanded_prediction_mask,
        )

    def compute_relation_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every (subject, object, relation) triple.

        The subject and object projections of the entity features are
        multiplied element-wise with the relation projections of the special
        symbols, giving one vector per triple. That vector is then classified
        by ``re_classifier``.

        Returns:
            Tensor of shape (batch, entities, entities, relations, 2).
        """
        # Projection order is kept (subject, object, relation): dropout draws
        # from the RNG in call order.
        subjects = self.re_subject_projector(model_entity_features)
        objects = self.re_object_projector(model_entity_features)
        relations = self.re_relation_projector(special_symbols_features)
        triple_features = torch.einsum(
            "bse,bde,bfe->bsdfe", subjects, objects, relations
        )
        return self.re_classifier(triple_features)

    def compute_entity_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every entity against every special symbol, masking padded entities.

        Args:
            model_entity_features: rank-3 entity features; padded entities are
                rows filled with -100.
            special_symbols_features: rank-3 special-symbol features; axis 1
                indexes the symbols.

        Returns:
            ``re_ed_classifier`` logits of shape (batch, entities, symbols, 2).
            Rows of padded entities are pushed to a large negative value.
        """
        model_ed_features = self.re_entities_projector(model_entity_features)
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )
        logits = torch.einsum(
            "bce,bde->bcde",
            model_ed_features,
            special_symbols_ed_representation,
        )
        logits = self.re_ed_classifier(logits)

        # (batch, entities) padding flags repeated across the symbol axis.
        # Previously the repeat count came from ``.item()`` on a feature
        # vector, which raises, and the masked result was discarded.
        num_symbols = special_symbols_features.shape[1]
        padding_mask = (
            (model_entity_features == -100)
            .all(2)
            .long()
            .unsqueeze(2)
            .repeat(1, 1, num_symbols)
        )
        return self._mask_logits(logits, padding_mask)

    def compute_loss(self, logits, labels, mask=None):
        """Cross-entropy over the last dimension of ``logits``.

        Args:
            logits: tensor whose last axis holds the class scores.
            labels: class indices with one entry per logits row (any shape
                that flattens to that count); cast to ``long``.
            mask: optional boolean selector over the flattened rows. It is
                flattened too, so a mask shaped like ``labels`` also works.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs,
        such as ``logits[:, :, :k]`` slices, are accepted.
        """
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1).long()
        if mask is not None:
            mask = mask.reshape(-1)
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_end_loss(self, ned_end_logits, end_labels):
        """Binary end-boundary loss; 0 when no end logits were produced.

        ``end_labels`` is mapped to 1 where positive and -100 where -100
        (ignored by the criterion); every other entry becomes 0.
        """
        if ned_end_logits is None:
            return 0
        binary_end_labels = (end_labels > 0).to(end_labels.dtype)
        binary_end_labels[end_labels == -100] = -100
        return self.compute_loss(ned_end_logits, binary_end_labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
        mask=None,
    ):
        """Entity-type / entity-disambiguation loss, chosen by the config flags.

        Args:
            disambiguation_labels: per (entity, symbol) targets.
            re_ned_entities_logits: logits over all special symbols.
            ned_type_logits: logits restricted to the first ``entity_types``
                symbols.
            re_entities_logits: logits over the remaining symbols.
            entity_types: number of entity-type symbols.
            mask: optional row selector forwarded to ``compute_loss``.

        Returns:
            The loss tensor, or 0 when neither loss is enabled.
        """
        # ``compute_loss`` takes (logits, labels, mask); the arguments used to
        # be passed as (labels, logits). Sliced tensors are made contiguous
        # before they are flattened.
        if self.config.entity_type_loss and self.relation_disambiguation_loss:
            return self.compute_loss(
                re_ned_entities_logits.contiguous(),
                disambiguation_labels.contiguous(),
                mask,
            )
        if self.config.entity_type_loss:
            return self.compute_loss(
                ned_type_logits.contiguous(),
                disambiguation_labels[:, :, :entity_types].contiguous(),
                mask,
            )
        if self.relation_disambiguation_loss:
            return self.compute_loss(
                re_entities_logits.contiguous(),
                disambiguation_labels.contiguous(),
                mask,
            )
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        """Relation cross-entropy over the entries whose label is not -100."""
        labelled = relation_labels.view(-1) != -100
        return self.compute_loss(re_logits, relation_labels, labelled)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if is_prediction and start_labels is not None:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(
                ned_start_logits, prediction_mask
            )  # why?
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            # start_labels contain entity id at their position, we just need 1 for start of entity
            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            # compute end logits only if there are any start predictions.
            # For each start prediction, n end predictions are made
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
                True,
            )
            # For each start prediction, n end predictions are made based on
            # binary classification ie. argmax at each position.
            ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
            ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
            if is_prediction or is_validation:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                ned_start_predictions[ned_start_predictions == 1] = (
                    end_preds_count != 0
                ).long()
                ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # We can only predict relations if we have start and end predictions
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),  # start position features
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],  # end position features
                ],
                dim=-1,
            )
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            if is_validation or is_prediction:
                model_subject_features = model_subject_features[:, :30, :]

            # entity disambiguation. Here relation_disambiguation_loss would only be useful to
            # reduce the number of candidate relations for the next step, but currently unused.
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                (re_ned_entities_logits) = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.config.entity_type_loss:
                    ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
                    ned_type_predictions = ned_type_predictions.argmax(dim=-1)
                    if self.config.add_entity_embedding:
                        special_symbols_representation = model_features[
                            special_symbols_mask
                        ].view(batch_size, entity_types, -1)

                        entities_representation = torch.einsum(
                            "bsp,bpe->bse",
                            ned_type_probabilities,
                            special_symbols_representation,
                        )
                        model_subject_features = torch.cat(
                            [model_subject_features, entities_representation], dim=-1
                        )
                re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
                re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
            else:
                (
                    ned_type_logits,
                    ned_type_probabilities,
                    re_entities_logits,
                    re_entities_probabilities,
                ) = (None, None, None, None)
                ned_type_predictions, re_entities_predictions = (
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                )

            # Compute relation logits
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # we set a thresshold instead of argmax in cause it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
            # re_predictions = re_probabilities.argmax(dim=-1)
            re_probabilities = re_probabilities[:, :, :, :, 1]

        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
        ):
            ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
            ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                ned_type_loss = self.compute_ned_type_loss(
                    disambiguation_labels,
                    re_ned_entities_logits,
                    ned_type_logits,
                    re_entities_logits,
                    entity_types,
                )
            relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # compute loss. We can skip the relation loss if we are in the first epochs (optional)
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss
                ) / 3

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict