# transformers-github-bot / retrieval.py
# Author: Amy Roberts
# Status: Draft (commit 9b744c5), 2.21 kB
import argparse
import json
import pprint
import numpy as np
from sentence_transformers import SentenceTransformer
def cosine_similarity(a, b):
    """
    Compute pairwise cosine similarity between the rows of ``a`` and ``b``.

    1-D inputs are treated as a single row vector.

    Args:
        a: Array-like of shape (m, d) or (d,).
        b: Array-like of shape (n, d) or (d,).

    Returns:
        Array of shape (m, n) whose entry [i, j] is the cosine similarity of
        row a[i] and row b[j]. Pairs involving an all-zero row are 0.0 rather
        than NaN.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1:
        a = a.reshape(1, -1)
    if b.ndim == 1:
        b = b.reshape(1, -1)
    dots = np.dot(a, b.T)
    # Outer product so every (i, j) pair is divided by ||a[i]|| * ||b[j]||.
    # A plain elementwise product of the two norm vectors only broadcasts
    # correctly when ``a`` has a single row.
    norms = np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))
    # Without this guard a zero-norm row yields NaN. argsort puts NaN last,
    # so a reversed (descending) ranking would put it first.
    out = np.zeros(dots.shape, dtype=np.result_type(dots, norms))
    return np.divide(dots, norms, out=out, where=norms != 0)
def retrieve_issue_rankings(
    query: str,
    model_id: str,
    input_embedding_filename: str,
):
    """
    Rank every stored embedding row by cosine similarity to ``query``.

    Args:
        query: Free-text search string to embed.
        model_id: Identifier passed to ``SentenceTransformer`` to load the encoder.
        input_embedding_filename: Path of the array loaded with ``np.load``;
            presumably one embedding row per issue (confirm against the writer).

    Returns:
        1-D array of row indices into the loaded embeddings, ordered from most
        to least similar to the query.
    """
    encoder = SentenceTransformer(model_id)
    issue_embeddings = np.load(input_embedding_filename)
    query_vector = encoder.encode(query)

    # cosine_similarity reshapes a 1-D query to a single row, so the scores
    # have one row. This assumes ``encode`` returns one vector for a str query.
    scores = cosine_similarity(query_vector, issue_embeddings)
    ascending_order = np.argsort(scores)

    # argsort is ascending. Keep the single row and flip it so the best match
    # comes first.
    return np.flip(ascending_order[0])
def print_issue(issues, issue_id):
    """
    Print an issue's header line ("#<id> <title>") followed by its body.

    Args:
        issues: Mapping keyed by issue id whose values have "title" and "body" keys.
        issue_id: Key to look up in ``issues``; also shown after "#" in the header.
    """
    record = issues[issue_id]
    header = f"#{issue_id}"
    print(header, record["title"])
    print(record["body"])
def main():
    """
    Command-line entry point: rank stored issues against a query and print the best matches.

    Inputs (all configurable via CLI options):
        - the embedding array (``--input_embedding_filename``),
        - a JSON object mapping embedding row index (as a string key) to an
          issue id (``--input_index_filename``),
        - a JSON object mapping issue id to issue details
          (``--input_issues_filename``).

    Prints the ``--top_k`` most similar issues, most similar first.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("query", type=str)
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--input_embedding_filename", type=str, default="issue_embeddings.npy")
    parser.add_argument("--input_index_filename", type=str, default="embedding_index_to_issue.json")
    # Previously hard-coded; exposed like the other input paths, same default.
    parser.add_argument("--input_issues_filename", type=str, default="issues_dict.json")
    # Previously hard-coded to 3.
    parser.add_argument("--top_k", type=int, default=3, help="Number of matching issues to print.")
    args = parser.parse_args()
    if args.top_k < 1:
        parser.error("--top_k must be a positive integer")

    issue_rankings = retrieve_issue_rankings(
        query=args.query,
        model_id=args.model_id,
        input_embedding_filename=args.input_embedding_filename,
    )

    with open(args.input_issues_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)
    with open(args.input_index_filename, "r", encoding="utf-8") as f:
        embedding_index_to_issue = json.load(f)

    # Only the first top_k ranked rows are printed, so map just that prefix.
    # JSON object keys are always strings, hence str(i).
    issue_ids = [embedding_index_to_issue[str(i)] for i in issue_rankings[: args.top_k]]
    for issue_id in issue_ids:
        print(issue_id)
        print_issue(issues, issue_id)
        print("\n\n\n")


if __name__ == "__main__":
    main()