Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import pprint | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
def cosine_similarity(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    A 1-D input is treated as a single row vector, so the result is always
    2-D.

    Args:
        a: Array-like of shape ``(d,)`` or ``(m, d)``.
        b: Array-like of shape ``(d,)`` or ``(n, d)``.

    Returns:
        Array of shape ``(m, n)`` (``m``/``n`` are 1 for 1-D inputs) whose
        entry ``[i, j]`` is the cosine similarity of row ``a[i]`` and row
        ``b[j]``. Zero-norm rows yield ``nan``/``inf`` entries; no special
        handling is done for them.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1:
        a = a.reshape(1, -1)
    if b.ndim == 1:
        b = b.reshape(1, -1)
    # Outer product of the row norms: (m, 1) * (n,) -> (m, n), matching the
    # (m, n) shape of the dot product. Multiplying two 1-D norm vectors
    # directly is only correct when m == 1; otherwise it pairs norms
    # elementwise, broadcasts to the wrong shape, or raises.
    norms = np.linalg.norm(a, axis=1, keepdims=True) * np.linalg.norm(b, axis=1)
    return np.dot(a, b.T) / norms
def retrieve_issue_rankings(
    query: str,
    model_id: str,
    input_embedding_filename: str,
):
    """Rank all embedded issues by cosine similarity to ``query``.

    Loads the SentenceTransformer model ``model_id``, loads the embedding
    array stored at ``input_embedding_filename`` with ``np.load``, encodes
    ``query``, and scores it against that array with ``cosine_similarity``.

    Returns:
        A NumPy array of indices into the loaded embedding array, ordered
        from most to least similar to ``query``. These are embedding
        positions, not issue ids; the caller maps them to ids separately.
    """
    encoder = SentenceTransformer(model_id)
    issue_vectors = np.load(input_embedding_filename)
    query_vector = encoder.encode(query)

    # For a single query, cosine_similarity yields one row of scores (one
    # per issue); keep just that row.
    scores = cosine_similarity(query_vector, issue_vectors)[0]

    # argsort sorts ascending, so reverse it to put the best match first.
    return np.argsort(scores)[::-1]
def print_issue(issues, issue_id):
    """Print the issue stored under ``issue_id`` in ``issues``.

    Writes a ``#<issue_id> <title>`` header line, then the issue body.
    Each ``issues[issue_id]`` entry is read via its ``"title"`` and
    ``"body"`` keys; a missing id raises ``KeyError``.
    """
    record = issues[issue_id]
    header = f"#{issue_id}"
    print(header, record["title"])
    print(record["body"])
if __name__ == "__main__":
    # Command-line entry point: rank issues against a query and print the
    # best matches.
    parser = argparse.ArgumentParser(
        description="Print the issues whose embeddings are most similar to a query."
    )
    parser.add_argument("query", type=str)
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--input_embedding_filename", type=str, default="issue_embeddings.npy")
    parser.add_argument("--input_index_filename", type=str, default="embedding_index_to_issue.json")
    # Was hard-coded; exposed like the other inputs, with the same default.
    parser.add_argument("--input_issues_filename", type=str, default="issues_dict.json")
    parser.add_argument("--top_k", type=int, default=3, help="Number of issues to print.")
    args = parser.parse_args()
    if args.top_k < 0:
        # A negative slice bound would silently mean "all but the last N".
        parser.error("--top_k must be non-negative")

    issue_rankings = retrieve_issue_rankings(
        query=args.query,
        model_id=args.model_id,
        input_embedding_filename=args.input_embedding_filename,
    )

    # JSON text is UTF-8; don't rely on the platform's locale encoding.
    with open(args.input_issues_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)
    with open(args.input_index_filename, "r", encoding="utf-8") as f:
        embedding_index_to_issue = json.load(f)

    # JSON object keys are always strings, so convert each embedding index
    # before lookup. Only the top_k matches are printed, so only those are
    # mapped to issue ids.
    issue_ids = [embedding_index_to_issue[str(i)] for i in issue_rankings[: args.top_k]]
    for issue_id in issue_ids:
        print(issue_id)
        print_issue(issues, issue_id)
        print("\n\n\n")