Update README.md to have inference usage script (#2)
Browse files- Update README.md to have inference usage script (59d7eca6bd0a01eeda43076b463a3a39fc7b55e2)
Co-authored-by: Junqiu Lei <[email protected]>
README.md
CHANGED
@@ -1,29 +1,266 @@
|
|
1 |
-
---
|
2 |
-
language: en
|
3 |
-
license: apache-2.0
|
4 |
-
library_name: transformers
|
5 |
-
tags:
|
6 |
-
- opensearch
|
7 |
-
- semantic-search
|
8 |
-
- highlighting
|
9 |
-
- sentence-highlighter
|
10 |
-
- bert
|
11 |
-
- text-classification
|
12 |
-
- pytorch
|
13 |
-
pipeline_tag: text-classification
|
14 |
-
---
|
15 |
-
|
16 |
-
# opensearch-semantic-highlighter
|
17 |
-
|
18 |
-
## Overview
|
19 |
-
|
20 |
-
The OpenSearch semantic highlighter is a trained classifier that takes a document and query as input and returns a binary score for each sentence in the document indicating its relevance to the query.
|
21 |
-
|
22 |
-
##
|
23 |
-
|
24 |
-
This
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: en
|
3 |
+
license: apache-2.0
|
4 |
+
library_name: transformers
|
5 |
+
tags:
|
6 |
+
- opensearch
|
7 |
+
- semantic-search
|
8 |
+
- highlighting
|
9 |
+
- sentence-highlighter
|
10 |
+
- bert
|
11 |
+
- text-classification
|
12 |
+
- pytorch
|
13 |
+
pipeline_tag: text-classification
|
14 |
+
---
|
15 |
+
|
16 |
+
# opensearch-semantic-highlighter
|
17 |
+
|
18 |
+
## Overview
|
19 |
+
|
20 |
+
The OpenSearch semantic highlighter is a trained classifier that takes a document and query as input and returns a binary score for each sentence in the document indicating its relevance to the query.
|
21 |
+
|
22 |
+
## Usage
|
23 |
+
|
24 |
+
This model is intended to run **inside an OpenSearch cluster**. For production workloads you should deploy the traced version via the ML Commons plugin—see the OpenSearch documentation on [semantic sentence highlighting models](https://docs.opensearch.org/docs/latest/ml-commons-plugin/pretrained-models/#semantic-sentence-highlighting-models).
|
25 |
+
|
26 |
+
If you simply want to experiment outside a cluster you can run the source model locally. First install the dependencies (Python ≥ 3.8):
|
27 |
+
|
28 |
+
```bash
|
29 |
+
pip install torch transformers datasets nltk
|
30 |
+
python -m nltk.downloader punkt
|
31 |
+
```
|
32 |
+
|
33 |
+
Then run the example below:
|
34 |
+
|
35 |
+
```python
|
36 |
+
import nltk
|
37 |
+
import torch
|
38 |
+
import numpy as np
|
39 |
+
from datasets import Dataset
|
40 |
+
from functools import partial
|
41 |
+
from torch.utils.data import DataLoader
|
42 |
+
from dataclasses import dataclass, field
|
43 |
+
from typing import Any, Dict, List, Union
|
44 |
+
from torch.nn.utils.rnn import pad_sequence
|
45 |
+
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
|
46 |
+
import torch.nn as nn
|
47 |
+
|
48 |
+
class BertTaggerForSentenceExtractionWithBackoff(BertPreTrainedModel):
|
49 |
+
"""Sentence-level BERT classifier with a confidence-backoff rule."""
|
50 |
+
|
51 |
+
def __init__(self, config):
|
52 |
+
super().__init__(config)
|
53 |
+
self.num_labels = config.num_labels
|
54 |
+
|
55 |
+
self.bert = BertModel(config)
|
56 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
57 |
+
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
|
58 |
+
self.init_weights()
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids=None,
|
63 |
+
attention_mask=None,
|
64 |
+
token_type_ids=None,
|
65 |
+
sentence_ids=None,
|
66 |
+
):
|
67 |
+
outputs = self.bert(
|
68 |
+
input_ids,
|
69 |
+
attention_mask=attention_mask,
|
70 |
+
token_type_ids=token_type_ids,
|
71 |
+
)
|
72 |
+
|
73 |
+
sequence_output = self.dropout(outputs[0])
|
74 |
+
|
75 |
+
def _get_agg_output(ids, seq_out):
|
76 |
+
max_sentences = torch.max(ids) + 1
|
77 |
+
d_model = seq_out.size(-1)
|
78 |
+
|
79 |
+
agg_out, global_offsets, num_sents = [], [], []
|
80 |
+
for i, sen_ids in enumerate(ids):
|
81 |
+
out, local_ids = [], sen_ids.clone()
|
82 |
+
mask = local_ids != -100
|
83 |
+
offset = local_ids[mask].min()
|
84 |
+
global_offsets.append(offset)
|
85 |
+
local_ids[mask] -= offset
|
86 |
+
n_sent = local_ids.max() + 1
|
87 |
+
num_sents.append(n_sent)
|
88 |
+
|
89 |
+
for j in range(int(n_sent)):
|
90 |
+
out.append(seq_out[i, local_ids == j].mean(dim=-2, keepdim=True))
|
91 |
+
|
92 |
+
if max_sentences - n_sent:
|
93 |
+
padding = torch.zeros(
|
94 |
+
(int(max_sentences - n_sent), d_model), device=seq_out.device
|
95 |
+
)
|
96 |
+
out.append(padding)
|
97 |
+
agg_out.append(torch.cat(out, dim=0))
|
98 |
+
return torch.stack(agg_out), global_offsets, num_sents
|
99 |
+
|
100 |
+
agg_output, offsets, num_sents_item = _get_agg_output(sentence_ids, sequence_output)
|
101 |
+
logits = self.classifier(agg_output)
|
102 |
+
probs = torch.softmax(logits, dim=-1)[:, :, 1]
|
103 |
+
|
104 |
+
def _get_preds(pp, offs, num_s, threshold=0.5, alpha=0.05):
|
105 |
+
preds = []
|
106 |
+
for p, off, ns in zip(pp, offs, num_s):
|
107 |
+
rel_probs = p[:ns]
|
108 |
+
hits = (rel_probs >= threshold).int()
|
109 |
+
if hits.sum() == 0 and rel_probs.max().item() >= alpha:
|
110 |
+
hits[rel_probs.argmax()] = 1
|
111 |
+
preds.append(torch.where(hits == 1)[0] + off)
|
112 |
+
return preds
|
113 |
+
|
114 |
+
return tuple(_get_preds(probs, offsets, num_sents_item))
|
115 |
+
|
116 |
+
|
117 |
+
# Dataclass for padding collator
|
118 |
+
@dataclass
|
119 |
+
class DataCollatorWithPadding:
|
120 |
+
pad_kvs: Dict[str, Union[int, float]] = field(default_factory=dict)
|
121 |
+
|
122 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
123 |
+
first = features[0]
|
124 |
+
batch = {}
|
125 |
+
|
126 |
+
# pad and collate keys in self.pad_kvs
|
127 |
+
for key, pad_value in self.pad_kvs.items():
|
128 |
+
if key in first and first[key] is not None:
|
129 |
+
batch[key] = pad_sequence(
|
130 |
+
[torch.tensor(f[key]) for f in features],
|
131 |
+
batch_first=True,
|
132 |
+
padding_value=pad_value,
|
133 |
+
)
|
134 |
+
|
135 |
+
# collate remaining keys assuming that the values can be stacked
|
136 |
+
for k, v in first.items():
|
137 |
+
if k not in self.pad_kvs and v is not None and isinstance(v, torch.Tensor):
|
138 |
+
batch[k] = torch.stack([f[k] for f in features])
|
139 |
+
|
140 |
+
return batch
|
141 |
+
|
142 |
+
|
143 |
+
def prepare_input_features(
|
144 |
+
tokenizer, examples, max_seq_length=510, stride=128, padding=False
|
145 |
+
):
|
146 |
+
|
147 |
+
# jointly tokenize questions and context
|
148 |
+
tokenized_examples = tokenizer(
|
149 |
+
examples["question"],
|
150 |
+
examples["context"],
|
151 |
+
truncation="only_second",
|
152 |
+
max_length=max_seq_length,
|
153 |
+
stride=stride,
|
154 |
+
return_overflowing_tokens=True,
|
155 |
+
padding=padding,
|
156 |
+
is_split_into_words=True,
|
157 |
+
)
|
158 |
+
|
159 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
160 |
+
tokenized_examples["example_id"] = []
|
161 |
+
tokenized_examples["word_ids"] = []
|
162 |
+
tokenized_examples["sentence_ids"] = []
|
163 |
+
|
164 |
+
# process model inputs
|
165 |
+
for i, sample_index in enumerate(sample_mapping):
|
166 |
+
word_ids = tokenized_examples.word_ids(i)
|
167 |
+
word_level_sentence_ids = examples["word_level_sentence_ids"][sample_index]
|
168 |
+
|
169 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
170 |
+
token_start_index = 0
|
171 |
+
while sequence_ids[token_start_index] != 1:
|
172 |
+
token_start_index += 1
|
173 |
+
|
174 |
+
sentences_ids = [-100] * token_start_index
|
175 |
+
for word_idx in word_ids[token_start_index:]:
|
176 |
+
if word_idx is not None:
|
177 |
+
sentences_ids.append(word_level_sentence_ids[word_idx])
|
178 |
+
else:
|
179 |
+
sentences_ids.append(-100)
|
180 |
+
|
181 |
+
tokenized_examples["sentence_ids"].append(sentences_ids)
|
182 |
+
tokenized_examples["example_id"].append(examples["id"][sample_index])
|
183 |
+
tokenized_examples["word_ids"].append(word_ids)
|
184 |
+
|
185 |
+
# ensure we don't exceed the model's max position embeddings (512 for BERT)
|
186 |
+
for key in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
|
187 |
+
tokenized_examples[key] = [seq[:max_seq_length] for seq in tokenized_examples[key]]
|
188 |
+
|
189 |
+
return tokenized_examples
|
190 |
+
|
191 |
+
|
192 |
+
# single example (same as README)
|
193 |
+
query = "When does OpenSearch use text reanalysis for highlighting?"
|
194 |
+
document = "To highlight the search terms, the highlighter needs the start and end character offsets of each term. The offsets mark the term's position in the original text. The highlighter can obtain the offsets from the following sources: Postings: When documents are indexed, OpenSearch creates an inverted search index—a core data structure used to search for documents. Postings represent the inverted search index and store the mapping of each analyzed term to the list of documents in which it occurs. If you set the index_options parameter to offsets when mapping a text field, OpenSearch adds each term's start and end character offsets to the inverted index. During highlighting, the highlighter reruns the original query directly on the postings to locate each term. Thus, storing offsets makes highlighting more efficient for large fields because it does not require reanalyzing the text. Storing term offsets requires additional disk space, but uses less disk space than storing term vectors. Text reanalysis: In the absence of both postings and term vectors, the highlighter reanalyzes text in order to highlight it. For every document and every field that needs highlighting, the highlighter creates a small in-memory index and reruns the original query through Lucene's query execution planner to access low-level match information for the current document. Reanalyzing the text works well in most use cases. However, this method is more memory and time intensive for large fields."
|
195 |
+
|
196 |
+
doc_sents = nltk.sent_tokenize(document)
|
197 |
+
sentence_ids, context = [], []
|
198 |
+
for sid, sent in enumerate(doc_sents):
|
199 |
+
words = sent.split()
|
200 |
+
context.extend(words)
|
201 |
+
sentence_ids.extend([sid] * len(words))
|
202 |
+
|
203 |
+
example_dataset = Dataset.from_dict(
|
204 |
+
{
|
205 |
+
"question": [[query]],
|
206 |
+
"context": [context],
|
207 |
+
"word_level_sentence_ids": [sentence_ids],
|
208 |
+
"id": [0],
|
209 |
+
}
|
210 |
+
)
|
211 |
+
|
212 |
+
# prepare to featurize the raw text data
|
213 |
+
base_model_id = "bert-base-uncased"
|
214 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
|
215 |
+
collator = DataCollatorWithPadding(
|
216 |
+
pad_kvs={
|
217 |
+
"input_ids": 0,
|
218 |
+
"token_type_ids": 0,
|
219 |
+
"attention_mask": 0,
|
220 |
+
"sentence_ids": -100,
|
221 |
+
"sentence_labels": -100,
|
222 |
+
}
|
223 |
+
)
|
224 |
+
preprocess_fn = partial(prepare_input_features, tokenizer)
|
225 |
+
|
226 |
+
# featurize
|
227 |
+
example_dataset = example_dataset.map(
|
228 |
+
preprocess_fn,
|
229 |
+
batched=True,
|
230 |
+
remove_columns=example_dataset.column_names,
|
231 |
+
desc="Preparing model inputs",
|
232 |
+
)
|
233 |
+
loader = DataLoader(example_dataset, batch_size=1, collate_fn=collator)
|
234 |
+
|
235 |
+
# get single batch
|
236 |
+
batch = next(iter(loader))
|
237 |
+
|
238 |
+
# load model and get sentence highlights
|
239 |
+
model = BertTaggerForSentenceExtractionWithBackoff.from_pretrained(
|
240 |
+
"opensearch-project/opensearch-semantic-highlighter-v1"
|
241 |
+
)
|
242 |
+
|
243 |
+
# clamp tensors to model max length
|
244 |
+
max_len = model.config.max_position_embeddings
|
245 |
+
for key in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
|
246 |
+
batch[key] = batch[key][:, :max_len]
|
247 |
+
|
248 |
+
highlights = model(
|
249 |
+
batch["input_ids"],
|
250 |
+
batch["attention_mask"],
|
251 |
+
batch["token_type_ids"],
|
252 |
+
batch["sentence_ids"],
|
253 |
+
)
|
254 |
+
|
255 |
+
highlighted_sentences = [doc_sents[i] for i in highlights[0]]
|
256 |
+
print(highlighted_sentences)
|
257 |
+
```
|
258 |
+
|
259 |
+
## License
|
260 |
+
|
261 |
+
This project is licensed under the [Apache v2.0 License](https://github.com/opensearch-project/neural-search/blob/main/LICENSE).
|
262 |
+
|
263 |
+
|
264 |
+
## Copyright
|
265 |
+
|
266 |
+
Copyright OpenSearch Contributors. See [NOTICE](https://github.com/opensearch-project/neural-search/blob/main/NOTICE) for details.
|