philtoms commited on
Commit
a8c1cc6
·
verified ·
1 Parent(s): 4a81ba0

Create App.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import time
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import os
6
+ import json
7
+
8
+ # Determine model path based on environment
9
+ if "HF_SPACE_ID" in os.environ:
10
+ # Running on Hugging Face Spaces
11
+ # Assumes the model is in a repository with the same name as the space
12
+ space_name = os.environ["HF_SPACE_ID"].split("/")[-1]
13
+ model_path = f"{os.environ['HF_USER_NAME']}/{space_name}"
14
+ print(f"Running on HF Spaces. Using model: {model_path}")
15
+ else:
16
+ # Running locally
17
+ model_path = "../models/minilm-alice-base-rsft-v1/final"
18
+ print(f"Running locally. Using model: {model_path}")
19
+
20
+ # Load the model
21
+ model = SentenceTransformer(model_path)
22
+
23
+ # Load the dataset
24
+ # Adjust the data path for local vs. HF environment
25
+ data_path = "data/alice_pairs.jsonl" if "HF_SPACE_ID" in os.environ else "../data/alice_pairs.jsonl"
26
+
27
+ dataset = []
28
+ with open(data_path, "r") as f:
29
+ for line in f:
30
+ dataset.append(json.loads(line))
31
+
32
+ corpus = [item["passage"] for item in dataset]
33
+ corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
34
+
35
+ def find_similar(prompt, top_k):
36
+ start_time = time.time()
37
+
38
+ prompt_embedding = model.encode(prompt, convert_to_tensor=True)
39
+ cos_scores = util.cos_sim(prompt_embedding, corpus_embeddings)[0]
40
+ top_results = cos_scores.topk(k=int(top_k))
41
+
42
+ end_time = time.time()
43
+
44
+ results = []
45
+ for score, idx in zip(top_results[0], top_results[1]):
46
+ results.append((corpus[idx], score.item()))
47
+
48
+ return results, f"{(end_time - start_time) * 1000:.2f} ms"
49
+
50
+ iface = gr.Interface(
51
+ fn=find_similar,
52
+ inputs=[
53
+ gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
54
+ gr.Slider(1, 20, value=5, step=1, label="Top K")
55
+ ],
56
+ outputs=[
57
+ gr.Dataframe(headers=["Response", "Score"]),
58
+ gr.Textbox(label="Time Taken")
59
+ ],
60
+ title="RSFT Alice embeddings",
61
+ description="Enter a prompt and get the most similar sentences from the corpus."
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ iface.launch()