Update README.md
Browse files
README.md
CHANGED
@@ -102,27 +102,62 @@ pip install -U sentence-transformers
|
|
102 |
Then you can load this model and run inference.
|
103 |
```python
|
104 |
from sentence_transformers import SparseEncoder
|
105 |
-
|
106 |
-
|
107 |
-
model =
|
108 |
-
|
109 |
-
queries
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
#
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
```
|
127 |
|
128 |
<!--
|
|
|
102 |
Then you can load this model and run inference.
|
103 |
```python
|
104 |
from sentence_transformers import SparseEncoder
|
105 |
+
import numpy as np
|
106 |
+
|
107 |
+
def retrieve_top_k(model, queries, documents, top_k=3):
    """
    Rank *documents* against each query with a SparseEncoder model.

    Parameters
    ----------
    model : SparseEncoder
        Sparse model exposing ``encode_query``, ``encode_document`` and
        ``similarity`` (the latter returning a tensor-like with
        ``.cpu().numpy()``).
    queries : list[str]
        Search queries.
    documents : list[str]
        Candidate documents to rank.
    top_k : int, optional
        Number of hits kept per query (default 3).

    Returns
    -------
    list
        One ``(query, hits)`` pair per query, where each hit is a
        ``(doc_index, score, doc_text)`` tuple sorted by descending score.
    """
    # Encode both sides once; each embedding row spans the model vocabulary.
    q_emb = model.encode_query(queries)          # [n_queries, vocab_size]
    d_emb = model.encode_document(documents)     # [n_docs, vocab_size]

    # Pairwise similarity matrix of shape [n_queries, n_docs].
    score_matrix = model.similarity(q_emb, d_emb).cpu().numpy()

    ranked = []
    for row, query in zip(score_matrix, queries):
        # Indices of the top_k highest scores, best first.
        best = np.argsort(-row)[:top_k]
        hits = [(i, float(row[i]), documents[i]) for i in best]
        ranked.append((query, hits))
    return ranked
|
128 |
+
|
129 |
+
if __name__ == "__main__":
    # Load the SPLADE-DistilBERT Arabic model (downloaded from the Hub on
    # first use).
    model_name = "Omartificial-Intelligence-Space/inference-free-splade-distilbert-base-Arabic-cased-nq"
    print(f"Loading sparse model {model_name} …")
    model = SparseEncoder(model_name)

    # Example documents (could be paragraphs from your corpus).
    documents = [
        "ليونيل ميسي ولد وترعرع في وسط الأرجنتين، وتم تشخيصه بضعف هرمون النمو في طفولته.",
        "علم روسيا هناك تفسيرات مختلفة لما تعنيه الألوان: الأبيض للنبلاء، الأزرق للصدق، الأحمر للشجاعة.",
        # Fixed mojibake: the original text had replacement characters
        # garbling the word "جمهورية" (republic).
        "كانت جمهورية تكساس دولة مستقلة في أمريكا الشمالية من 1836 إلى 1846.",
        "تقع مكة المكرمة في غرب المملكة العربية السعودية، وهي أقدس مدن الإسلام.",
        "برج خليفة في دبي هو أطول بناء من صنع الإنسان في العالم بارتفاع 828 متراً."
    ]

    # Example queries, one per document topic above.
    queries = [
        "من هو ليونيل ميسي؟",
        "ما معنى ألوان علم روسيا؟",
        "ما هي جمهورية تكساس؟",
        "أين تقع مكة المكرمة؟",
        "ما هو أطول مبنى في العالم؟"
    ]

    # Retrieve the top-2 documents per query.  (The original comment said
    # "top-3" while the call used top_k=2 — the call is kept, the comment
    # corrected.)
    results = retrieve_top_k(model, queries, documents, top_k=2)

    # Pretty-print the ranked hits, best first.
    for query, hits in results:
        print(f"\nQuery: {query}")
        for rank, (doc_idx, score, doc_text) in enumerate(hits, start=1):
            print(f"  {rank}. (score={score:.4f}) {doc_text}")
|
161 |
```
|
162 |
|
163 |
<!--
|