opensearch-semantic-highlighter

Overview

The OpenSearch semantic highlighter is a trained classifier that takes a document and query as input and returns a binary score for each sentence in the document indicating its relevance to the query.

Usage

This model is intended to run inside an OpenSearch cluster. For production workloads you should deploy the traced version via the ML Commons plugin—see the OpenSearch documentation on semantic sentence highlighting models.

If you simply want to experiment outside a cluster, you can run the source model locally. First install the dependencies (Python ≥ 3.8):

pip install torch transformers datasets nltk
python -m nltk.downloader punkt punkt_tab

Then run the example below:

import nltk
import torch
import numpy as np
from datasets import Dataset
from functools import partial
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
import torch.nn as nn

class BertTaggerForSentenceExtractionWithBackoff(BertPreTrainedModel):
    """Sentence-level BERT classifier with a confidence-backoff rule.

    Token vectors produced by BERT are mean-pooled per sentence (sentences
    are identified by ``sentence_ids``), every pooled vector is classified,
    and the indices of the sentences whose class-1 probability reaches
    ``threshold`` are returned. If no sentence reaches the threshold, the
    single most probable sentence is returned instead, provided its
    probability is at least ``alpha``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    @staticmethod
    def _pool_sentences(sentence_ids, sequence_output):
        """Mean-pool token vectors into one vector per sentence, row by row.

        Returns the zero-padded pooled tensor of shape
        ``(batch, max_sentences_in_batch, hidden)``, the per-row id of the
        first sentence (offset), and the per-row sentence count.
        """
        d_model = sequence_output.size(-1)
        pooled, offsets, counts = [], [], []
        for row_ids, row_out in zip(sentence_ids, sequence_output):
            valid = row_ids != -100
            if not bool(valid.any()):
                # Row without any sentence token: nothing to pool. Without
                # this guard ``.min()`` on an empty tensor would raise.
                pooled.append(row_out.new_zeros((0, d_model)))
                offsets.append(0)
                counts.append(0)
                continue

            # Shift ids so the first sentence present in this row is 0; the
            # offset is added back to the predictions later.
            offset = row_ids[valid].min()
            local_ids = row_ids.clone()
            local_ids[valid] -= offset
            n_sent = int(local_ids.max()) + 1

            pooled.append(
                torch.cat(
                    [
                        row_out[local_ids == j].mean(dim=-2, keepdim=True)
                        for j in range(n_sent)
                    ],
                    dim=0,
                )
            )
            offsets.append(offset)
            counts.append(n_sent)

        # Pad every row to the largest sentence count so rows can be stacked.
        # Padding rows are sliced away again before predictions are made.
        max_sentences = max(counts)
        padded = []
        for rows, n_sent in zip(pooled, counts):
            if n_sent < max_sentences:
                filler = rows.new_zeros((max_sentences - n_sent, d_model))
                rows = torch.cat([rows, filler], dim=0)
            padded.append(rows)
        return torch.stack(padded), offsets, counts

    @staticmethod
    def _select_sentences(probs, offsets, counts, threshold, alpha):
        """Turn per-sentence probabilities into lists of sentence indices."""
        preds = []
        for row_probs, offset, n_sent in zip(probs, offsets, counts):
            rel_probs = row_probs[:n_sent]
            hits = (rel_probs >= threshold).int()
            # Backoff: when nothing clears the threshold, keep the best
            # sentence as long as it is at least ``alpha`` confident.
            if n_sent and hits.sum() == 0 and rel_probs.max().item() >= alpha:
                hits[rel_probs.argmax()] = 1
            preds.append(torch.where(hits == 1)[0] + offset)
        return preds

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        sentence_ids=None,
        threshold=0.5,
        alpha=0.05,
    ):
        """Return the indices of the highlighted sentences for each row.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT.
            token_type_ids: Segment ids passed to BERT.
            sentence_ids: Per-token sentence index, aligned with the first
                two dimensions of the BERT output; ``-100`` marks tokens that
                belong to no sentence.
            threshold: Minimum class-1 probability for a sentence to be
                selected.
            alpha: Minimum probability required for the backoff rule to
                select the single best sentence when none reaches
                ``threshold``.

        Returns:
            A tuple with one 1-D tensor per batch row holding the selected
            sentence indices, expressed in the numbering used by
            ``sentence_ids``. A row without sentence tokens yields an empty
            tensor.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        sequence_output = self.dropout(outputs[0])

        agg_output, offsets, counts = self._pool_sentences(
            sentence_ids, sequence_output
        )
        logits = self.classifier(agg_output)
        # Probability of class 1 for every (row, sentence) pair.
        probs = torch.softmax(logits, dim=-1)[:, :, 1]

        return tuple(
            self._select_sentences(probs, offsets, counts, threshold, alpha)
        )


# Dataclass for padding collator
@dataclass
class DataCollatorWithPadding:
    """Collate a list of feature dicts into a single batch dict.

    Keys listed in ``pad_kvs`` are converted to tensors and right-padded to
    a common length with the configured value. Any other key is kept only if
    its value in the first feature is a ``torch.Tensor``, in which case the
    values are stacked; all remaining keys are dropped.
    """

    # maps a feature key to the value used to pad that feature
    pad_kvs: Dict[str, Union[int, float]] = field(default_factory=dict)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # The first feature decides which keys are present and how each one
        # is handled.
        reference = features[0]
        collated: Dict[str, Any] = {}

        # variable-length keys: tensorize, then pad to the longest sequence
        for name, fill in self.pad_kvs.items():
            if reference.get(name) is None:
                continue
            sequences = [torch.tensor(item[name]) for item in features]
            collated[name] = pad_sequence(
                sequences, batch_first=True, padding_value=fill
            )

        # everything else: stack, but only values that are already tensors
        for name, value in reference.items():
            if name in self.pad_kvs or not isinstance(value, torch.Tensor):
                continue
            collated[name] = torch.stack([item[name] for item in features])

        return collated


def prepare_input_features(
    tokenizer, examples, max_seq_length=510, stride=128, padding=False
):
    """Tokenize question/context pairs and attach per-token sentence ids.

    Args:
        tokenizer: Callable tokenizer; its result must support ``pop``,
            item access, ``word_ids(i)`` and ``sequence_ids(i)``.
        examples: Batch with the keys ``"question"``, ``"context"``,
            ``"word_level_sentence_ids"`` and ``"id"``.
        max_seq_length: Maximum length requested from the tokenizer and
            applied again to the returned sequences.
        stride: Overlap between consecutive overflow windows.
        padding: Padding strategy forwarded to the tokenizer.

    Returns:
        The tokenizer output with ``"overflow_to_sample_mapping"`` removed
        and ``"example_id"``, ``"word_ids"`` and ``"sentence_ids"`` added,
        one entry per produced window.
    """
    # question and context are encoded together; only the context may be
    # truncated, and truncated text spills into extra overlapping windows
    encoded = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_seq_length,
        stride=stride,
        return_overflowing_tokens=True,
        padding=padding,
        is_split_into_words=True,
    )
    window_to_sample = encoded.pop("overflow_to_sample_mapping")

    example_ids, all_word_ids, all_sentence_ids = [], [], []
    for window, sample in enumerate(window_to_sample):
        token_words = encoded.word_ids(window)
        word_to_sentence = examples["word_level_sentence_ids"][sample]

        # locate the first token that belongs to the second sequence
        segment_ids = encoded.sequence_ids(window)
        context_start = 0
        while segment_ids[context_start] != 1:
            context_start += 1

        # tokens before the context, and tokens that map to no word, get
        # -100; every other token inherits the sentence id of its word
        window_sentence_ids = [-100] * context_start + [
            -100 if word is None else word_to_sentence[word]
            for word in token_words[context_start:]
        ]

        all_sentence_ids.append(window_sentence_ids)
        example_ids.append(examples["id"][sample])
        all_word_ids.append(token_words)

    encoded["example_id"] = example_ids
    encoded["word_ids"] = all_word_ids
    encoded["sentence_ids"] = all_sentence_ids

    # ensure we don't exceed the model's max position embeddings (512 for BERT)
    for name in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
        encoded[name] = [row[:max_seq_length] for row in encoded[name]]

    return encoded


# single example: a query and the document whose sentences are to be highlighted
query = "When does OpenSearch use text reanalysis for highlighting?"
document = "To highlight the search terms, the highlighter needs the start and end character offsets of each term. The offsets mark the term's position in the original text. The highlighter can obtain the offsets from the following sources: Postings: When documents are indexed, OpenSearch creates an inverted search index—a core data structure used to search for documents. Postings represent the inverted search index and store the mapping of each analyzed term to the list of documents in which it occurs. If you set the index_options parameter to offsets when mapping a text field, OpenSearch adds each term's start and end character offsets to the inverted index. During highlighting, the highlighter reruns the original query directly on the postings to locate each term. Thus, storing offsets makes highlighting more efficient for large fields because it does not require reanalyzing the text. Storing term offsets requires additional disk space, but uses less disk space than storing term vectors. Text reanalysis: In the absence of both postings and term vectors, the highlighter reanalyzes text in order to highlight it. For every document and every field that needs highlighting, the highlighter creates a small in-memory index and reruns the original query through Lucene's query execution planner to access low-level match information for the current document. Reanalyzing the text works well in most use cases. However, this method is more memory and time intensive for large fields."

# split the document into sentences and record, for every whitespace-separated
# word, the index of the sentence it came from
doc_sents = nltk.sent_tokenize(document)
sentence_ids, context = [], []
for sid, sent in enumerate(doc_sents):
    words = sent.split()
    context.extend(words)
    sentence_ids.extend([sid] * len(words))

# the query is passed as a one-element word list because the tokenizer is
# called with is_split_into_words=True
example_dataset = Dataset.from_dict(
    {
        "question": [[query]],
        "context": [context],
        "word_level_sentence_ids": [sentence_ids],
        "id": [0],
    }
)

# prepare to featurize the raw text data
base_model_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
collator = DataCollatorWithPadding(
    pad_kvs={
        "input_ids": 0,
        "token_type_ids": 0,
        "attention_mask": 0,
        "sentence_ids": -100,
        "sentence_labels": -100,
    }
)
preprocess_fn = partial(prepare_input_features, tokenizer)

# featurize
example_dataset = example_dataset.map(
    preprocess_fn,
    batched=True,
    remove_columns=example_dataset.column_names,
    desc="Preparing model inputs",
)
loader = DataLoader(example_dataset, batch_size=1, collate_fn=collator)

# get single batch; with batch_size=1 this is the first tokenized window only,
# so a document long enough to overflow into further windows is not fully scored
batch = next(iter(loader))

# load model and get sentence highlights
model = BertTaggerForSentenceExtractionWithBackoff.from_pretrained(
    "opensearch-project/opensearch-semantic-highlighter-v1"
)
# make sure dropout is disabled for inference
model.eval()

# clamp tensors to model max length
max_len = model.config.max_position_embeddings
for key in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
    batch[key] = batch[key][:, :max_len]

# inference only: skip autograd bookkeeping
with torch.no_grad():
    highlights = model(
        batch["input_ids"],
        batch["attention_mask"],
        batch["token_type_ids"],
        batch["sentence_ids"],
    )

# highlights[0] holds sentence indices as tensor elements; convert to plain ints
highlighted_sentences = [doc_sents[int(i)] for i in highlights[0]]
print(highlighted_sentences)

License

This project is licensed under the Apache v2.0 License.

Copyright

Copyright OpenSearch Contributors. See NOTICE for details.

Downloads last month
2,956
Safetensors
Model size
0.1B params
Tensor type
F32
·
Inference Providers NEW