flan-ul2-text-encoder
The encoder model extracted from flan-ul2 via a new class, `AutoModelForTextEncoding`, added in a recent `transformers` release.
⚠️ This model is 17.44 GB in bfloat16 precision ⚠️
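As a rough sanity check on that figure, the sketch below converts the stated size into an approximate parameter count and the corresponding float32 footprint. It assumes "GB" means 10^9 bytes and that the checkpoint is essentially all weights; actual memory use during inference is higher because of activations. The float32 number matters because, depending on your `transformers` version, loading without a `torch_dtype` (as in the basic usage) may give float32 weights. The helper names are hypothetical, not part of any library.

```python
# Hypothetical helpers: rough weight-memory arithmetic, not part of transformers.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def params_from_size(size_gb: float, dtype: str) -> float:
    """Approximate parameter count from a checkpoint size, assuming 1 GB = 1e9 bytes."""
    return size_gb * 1e9 / BYTES_PER_PARAM[dtype]


def size_from_params(n_params: float, dtype: str) -> float:
    """Approximate weight memory in GB for a parameter count stored in `dtype`."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


n_params = params_from_size(17.44, "bfloat16")
print(f"~{n_params / 1e9:.2f}B parameters")  # ~8.72B parameters
print(f"~{size_from_params(n_params, 'float32'):.2f} GB of weights in float32")  # ~34.88 GB
```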
basic usage
```python
from transformers import AutoTokenizer, AutoModelForTextEncoding

tokenizer = AutoTokenizer.from_pretrained("pszemraj/flan-ul2-text-encoder")
model = AutoModelForTextEncoding.from_pretrained("pszemraj/flan-ul2-text-encoder")

inputs = tokenizer("Hello, my dog loves memes", return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
```
usage: semantic similarity
Note: this is one way to use the encoder, not the only way. Suggestions and ideas are welcome.
Below are a set of helper functions and an example that uses them to compute the cosine similarity between this model's embeddings of different texts.
Functions
load_model_and_tokenizer
Loads the model and tokenizer based on `model_name`, returning a tuple containing the loaded model and tokenizer.
```python
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoModelForTextEncoding, AutoTokenizer


def load_model_and_tokenizer(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
    """
    Load the model and tokenizer based on the given model name.

    Args:
        model_name (str): The name of the model to be loaded.

    Returns:
        Tuple[AutoModel, AutoTokenizer]: The loaded encoder model and tokenizer.
    """
    # bfloat16 halves memory vs. float32; device_map="auto" requires `accelerate`
    model = AutoModelForTextEncoding.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
```
get_embeddings
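Computes embeddings for the given texts with the model and tokenizer, using position-weighted mean pooling across `seq_len` (as in SGPT): each token's hidden state is weighted by its position in the sequence, and padding tokens are excluded via the attention mask.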
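Written out for a single sequence of padded length L, with final-layer hidden states h_i, attention-mask values m_i (1 for real tokens, 0 for padding) and position weights w_i = i, the pooled embedding that `get_embeddings` computes is the following; padding positions drop out of both sums because m_i = 0.

```latex
\mathbf{e} = \frac{\sum_{i=1}^{L} w_i \, m_i \, \mathbf{h}_i}{\sum_{i=1}^{L} w_i \, m_i},
\qquad w_i = i, \quad m_i \in \{0, 1\}
```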
```python
def get_embeddings(
    model: AutoModel, tokenizer: AutoTokenizer, texts: List[str]
) -> torch.Tensor:
    """
    Compute text embeddings via position-weighted mean pooling across seq_len.

    Args:
        model (AutoModel): The model to be used for getting embeddings.
        tokenizer (AutoTokenizer): The tokenizer to be used for tokenizing the texts.
        texts (List[str]): The texts for which embeddings are to be calculated.

    Returns:
        torch.Tensor: The calculated embeddings, shape (batch_size, hidden_dim).
    """
    # Tokenize input texts and move them to the model's device
    batch_tokens = tokenizer(
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(model.device)

    # Get the final-layer token embeddings (only last_hidden_state is needed)
    with torch.no_grad():
        last_hidden_state = model(**batch_tokens, return_dict=True).last_hidden_state

    # Position weights 1..seq_len, broadcast to (bs, seq_len, hidden_dim)
    weights = (
        torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )

    # Attention mask, broadcast to hidden_dim so padding tokens contribute zero
    input_mask_expanded = (
        batch_tokens["attention_mask"]
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )

    # Weighted mean pooling across seq_len: (bs, seq_len, hidden_dim) -> (bs, hidden_dim)
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    embeddings = sum_embeddings / sum_mask
    return embeddings
```
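To see what the pooling step does without downloading the model, the sketch below applies the same arithmetic to a tiny hand-made batch. `weighted_mean_pool` is a hypothetical name that just factors out the pooling lines of `get_embeddings`; the toy batch assumes right padding, so padded positions sit at the end with attention mask 0.

```python
import torch


def weighted_mean_pool(
    last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted mean pooling: (bs, seq_len, hidden_dim) -> (bs, hidden_dim)."""
    seq_len = last_hidden_state.shape[1]
    # Weights 1..seq_len, shaped (1, seq_len, 1) to broadcast over batch and hidden_dim
    weights = torch.arange(
        1, seq_len + 1, dtype=torch.float32, device=last_hidden_state.device
    ).view(1, seq_len, 1)
    # Mask shaped (bs, seq_len, 1): padding positions get weight 0
    mask = attention_mask.unsqueeze(-1).to(torch.float32)
    hidden = last_hidden_state.to(torch.float32)
    return (hidden * mask * weights).sum(dim=1) / (mask * weights).sum(dim=1)


# Toy batch: 2 sequences, seq_len 3, hidden_dim 2; the second sequence has one padding token
hidden = torch.tensor(
    [
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 2.0], [4.0, 0.0], [9.0, 9.0]],  # last position is padding
    ]
)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

pooled = weighted_mean_pool(hidden, mask)
print(pooled)  # expected values: [[4/6, 5/6], [10/3, 2/3]]
```

In the second row the padded `[9.0, 9.0]` has no effect, and in the first row the last token (weight 3) pulls the result further than a plain mean would.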
calculate_cosine_similarity
Helper function to compute and print the cosine similarity between the first text and each of the other texts.
```python
from scipy.spatial.distance import cosine


def calculate_cosine_similarity(embeddings: torch.Tensor, texts: List[str]) -> None:
    """Compute and print the cosine similarity between the first text and all others."""
    # scipy needs CPU float arrays: move off the GPU and out of bfloat16 first
    embeddings = embeddings.float().cpu().numpy()
    for i in range(1, len(embeddings)):
        cosine_sim = 1 - cosine(embeddings[0], embeddings[i])
        print(
            'Cosine similarity between "%s" and "%s" is: %.3f'
            % (texts[0], texts[i], cosine_sim)
        )
```
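If you prefer to stay in PyTorch (for example, to drop the scipy dependency and the GPU-to-CPU copy), the same first-versus-rest comparison can be sketched with `torch.nn.functional.cosine_similarity`. This is an alternative to the scipy helper above, not what it does internally; `cosine_similarities_to_first` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def cosine_similarities_to_first(embeddings: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between embeddings[0] and each row of embeddings[1:]."""
    embeddings = embeddings.float()  # avoid low-precision bfloat16 arithmetic
    # embeddings[:1] has shape (1, dim) and broadcasts against (n - 1, dim)
    return F.cosine_similarity(embeddings[:1], embeddings[1:], dim=-1)


# Toy 2-D vectors relative to the first one: same direction, orthogonal, opposite
toy = torch.tensor([[1.0, 0.0], [2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]])
print(cosine_similarities_to_first(toy))  # values 1.0, 0.0, -1.0
```

With real embeddings, you would call it on the output of `get_embeddings` and pair each score with `texts[1:]`.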
Usage
Install packages:
```bash
pip install torch transformers accelerate sentencepiece scipy
```
Then, you can use the functions to compute embeddings and similarity scores:
```python
model_name = "pszemraj/flan-ul2-text-encoder"
model, tokenizer = load_model_and_tokenizer(model_name)

texts = [
    "deep learning",
    "artificial intelligence",
    "deep diving",
    "artificial snow",
]

embeddings = get_embeddings(model, tokenizer, texts)
calculate_cosine_similarity(embeddings, texts)
```
This prints the cosine similarity between the first text and each of the other texts in the `texts` list.
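`get_embeddings` encodes all texts in one padded batch, which can run out of memory for long lists. One workaround is to embed fixed-size chunks and concatenate the results. The sketch below is hypothetical (`embed_in_batches` and `batch_size` are not part of the original code); it accepts any embedding function, so with the functions above you could pass `lambda chunk: get_embeddings(model, tokenizer, chunk)`. Since padding is masked out of the pooling, chunking should in principle leave each text's embedding unchanged apart from small numerical differences.

```python
from typing import Callable, List

import torch


def embed_in_batches(
    embed_fn: Callable[[List[str]], torch.Tensor],
    texts: List[str],
    batch_size: int = 8,
) -> torch.Tensor:
    """Embed `texts` in chunks of `batch_size` and concatenate the results in input order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if not texts:
        raise ValueError("texts must be non-empty")
    chunks = [
        embed_fn(texts[start : start + batch_size])
        for start in range(0, len(texts), batch_size)
    ]
    return torch.cat(chunks, dim=0)


# Stand-in embedding function for a quick check: each text becomes [len(text), 1.0]
def fake_embed(chunk: List[str]) -> torch.Tensor:
    return torch.tensor([[float(len(t)), 1.0] for t in chunk])


sample_texts = ["deep learning", "artificial intelligence", "deep diving", "artificial snow", "x"]
print(embed_in_batches(fake_embed, sample_texts, batch_size=2).shape)  # torch.Size([5, 2])
```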
References
The inference approach and the example above are based on the ideas and examples in the SGPT repository.
```bibtex
@article{muennighoff2022sgpt,
  title={SGPT: GPT Sentence Embeddings for Semantic Search},
  author={Muennighoff, Niklas},
  journal={arXiv preprint arXiv:2202.08904},
  year={2022}
}
```