Upload 28 files
- .gitattributes +5 -0
- __pycache__/spark_setup.cpython-311.pyc +0 -0
- app.py +273 -0
- data/abstracts.csv +3 -0
- data/affiliations.csv +0 -0
- data/author_affiliations.csv +0 -0
- data/authors.csv +3 -0
- data/clustering.csv +3 -0
- data/embeddings.csv +3 -0
- data/geospatial_clustering_data.csv +3 -0
- data/geospatial_data_by_publication.csv +0 -0
- data/keywords.csv +0 -0
- data/publications.csv +0 -0
- data/scopus_affiliation_data.csv +0 -0
- data/subject_areas.csv +0 -0
- embedding_model/1_Pooling/config.json +10 -0
- embedding_model/README.md +177 -0
- embedding_model/config.json +24 -0
- embedding_model/config_sentence_transformers.json +10 -0
- embedding_model/model.safetensors +3 -0
- embedding_model/modules.json +20 -0
- embedding_model/sentence_bert_config.json +4 -0
- embedding_model/special_tokens_map.json +51 -0
- embedding_model/tokenizer.json +0 -0
- embedding_model/tokenizer_config.json +73 -0
- embedding_model/vocab.txt +0 -0
- embeddings.npy +3 -0
- search.ipynb +0 -0
- spark_setup.py +40 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/abstracts.csv filter=lfs diff=lfs merge=lfs -text
+data/authors.csv filter=lfs diff=lfs merge=lfs -text
+data/clustering.csv filter=lfs diff=lfs merge=lfs -text
+data/embeddings.csv filter=lfs diff=lfs merge=lfs -text
+data/geospatial_clustering_data.csv filter=lfs diff=lfs merge=lfs -text
```
__pycache__/spark_setup.cpython-311.pyc
ADDED
Binary file (1.97 kB).
app.py
ADDED
@@ -0,0 +1,273 @@
```python
import os
import numpy as np
import pandas as pd
import streamlit as st
import plotly.express as px
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from spark_setup import create_spark_session, load_data, file_paths

############################
# Caching and Setup
############################

@st.cache_data(show_spinner=False)
def get_model(model_path="embedding_model"):
    # Download the model on the first run, then reload the local copy afterwards
    if not os.path.exists(model_path):
        model = SentenceTransformer('all-mpnet-base-v2')
        model.save(model_path)
    return SentenceTransformer(model_path)

@st.cache_data(show_spinner=False)
def get_embeddings(embedding_path="embeddings.npy"):
    return np.load(embedding_path)

@st.cache_resource(show_spinner=False)
def get_spark_session():
    return create_spark_session()

@st.cache_resource(show_spinner=False)
def get_data(_spark):
    return load_data(_spark, file_paths)

@st.cache_data(show_spinner=False)
def get_index_to_pub_id(_publications_spark_df):
    pub_ids = _publications_spark_df.select("publication_id").rdd.map(lambda x: x.publication_id).collect()
    return {idx: pub_id for idx, pub_id in enumerate(pub_ids)}

############################
# Main Code
############################

embeddings = get_embeddings()
model = get_model()

spark = get_spark_session()
dataframes = get_data(spark)

data = dataframes["geospatial_clustering_data"]
publications_df = dataframes["clustering"]

# Create the mapping from embedding index to publication_id
index_to_pub_id = get_index_to_pub_id(publications_df)

# Rename clusters as fields of study
field_topics = {
    0: "Phylogenetics and Species Diversity",
    1: "Advanced Materials and Nanotechnology",
    2: "Bioactive Compounds and Antioxidant Studies",
    3: "Catalysis and Energy Conversion",
    4: "Machine Learning and Image Processing",
    5: "Clinical and Epidemiological Studies",
    6: "Social and Behavioral Research",
    7: "Environmental Risk and Water Management",
    8: "Microbiology and Antibiotic Resistance",
    9: "Systems Engineering and Optimization",
    10: "Virology and Infectious Diseases",
    11: "Oral and Dental Research",
    12: "Surgery and Clinical Outcomes",
    13: "Composite Materials and Structural Engineering",
    14: "Cancer Research and Cellular Mechanisms",
    15: "Particle Physics and Cosmology",
    16: "Psychiatry and Cognitive Disorders"
}

# Page configuration
st.set_page_config(
    page_title="🌏 Chulalongkorn University Global Collaboration Explorer",
    layout="wide",
    page_icon="🌏"
)

# Initialize variables
field_id, field_name = -1, "All Fields"
keyword = None

# Sidebar
with st.sidebar:
    st.title("🌟 Global Collaboration Explorer")
    st.markdown("""
    **Explore Chulalongkorn University's global academic collaborations**
    Use the options below to choose a field of study or explore by keyword.
    """)

    # Add a search mode radio button
    search_mode = st.radio(
        "Exploration Mode:",
        options=["Explore by Field of Study", "Explore by Keyword"],
        index=0
    )

    if search_mode == "Explore by Field of Study":
        st.markdown("#### 🎓 Select a Field of Study")
        if "selected_field" not in st.session_state:
            st.session_state.selected_field = -1

        search_query = st.selectbox(
            "Field of Study:",
            options=[(-1, "All Fields")] + list(field_topics.items()),
            format_func=lambda x: "All Fields" if x[0] == -1 else f"Field {x[0]}: {x[1]}",
            index=st.session_state.selected_field + 1
        )
        field_id, field_name = search_query
        st.session_state.selected_field = field_id

        # Filter data based on selected field
        if field_id == -1:
            filtered_map_data_spark = data
        else:
            filtered_map_data_spark = data.filter(F.col("cluster") == field_id)

    elif search_mode == "Explore by Keyword":
        st.markdown("#### 🔍 Enter a Keyword")
        keyword = st.text_input("Keyword:")
        if keyword:
            input_embedding = model.encode(keyword)
            cos_similarities = cosine_similarity([input_embedding], embeddings)[0]

            # Create similarity DataFrame
            similarity_df = pd.DataFrame({
                "publication_id": [index_to_pub_id[i] for i in range(len(embeddings))],
                "similarity": cos_similarities
            })

            # Threshold filtering: minimum cosine similarity for a publication to count as a match
            similarity_threshold = 0.38
            similarity_df = similarity_df[similarity_df["similarity"] >= similarity_threshold]

            if similarity_df.empty:
                filtered_map_data_spark = data.limit(0)
            else:
                # Convert to Spark DF and join all matched publications
                similarity_spark_df = spark.createDataFrame(similarity_df)
                joined_df = data.join(similarity_spark_df, on="publication_id", how="inner")

                # Sort by similarity descending
                filtered_map_data_spark = joined_df.orderBy(F.col("similarity").desc())
        else:
            filtered_map_data_spark = data.limit(0)

    # Function to get unique affiliation count as points
    def get_country_points(_filtered_spark_df):
        return (
            _filtered_spark_df.groupBy("country")
            .agg(F.countDistinct("affiliation_id").alias("points"))
            .orderBy(F.col("points").desc())
        )

    country_points_spark = get_country_points(filtered_map_data_spark)
    filtered_map_data_pd = country_points_spark.toPandas()

    def get_dynamic_country_options(pdf):
        return [("All Countries", 0)] + [(row["country"], row["points"]) for _, row in pdf.iterrows()]

    if "selected_country" not in st.session_state:
        st.session_state.selected_country = "All Countries"

    country_options = get_dynamic_country_options(filtered_map_data_pd)

    selected_country = st.selectbox(
        "Select a Country:",
        options=country_options,
        format_func=lambda x: f"{x[0]} ({x[1]} unique affiliations)" if x[0] != "All Countries" else "All Countries",
        index=0
    )

    selected_country_name = selected_country[0]
    st.session_state.selected_country = selected_country_name

    # Statistics Table Section
    st.markdown("#### 📊 Show Country Statistics")
    show_stats = st.checkbox("Show Table", value=True)

# Main Title and Description
st.title("🌏 Chulalongkorn University's Global Research Collaborations")

if search_mode == "Explore by Field of Study":
    st.markdown(
        f"**Exploring collaborations in:** {'All Fields' if field_id == -1 else field_name} "
        f"**|** {'All Countries' if selected_country_name == 'All Countries' else selected_country_name}"
    )
else:
    st.markdown(
        f"**Exploring collaborations by keyword:** {'None' if not keyword else keyword} "
        f"**|** {'All Countries' if selected_country_name == 'All Countries' else selected_country_name}"
    )

# Filter by selected country if needed
if selected_country_name != "All Countries":
    filtered_map_data_spark = filtered_map_data_spark.filter(F.col("country") == selected_country_name)
    filtered_map_data_pd = (
        filtered_map_data_spark.groupBy("country")
        .agg(F.countDistinct("affiliation_id").alias("points"))
        .orderBy(F.col("points").desc())
        .toPandas()
    )

if search_mode == "Explore by Field of Study":
    title_text = f"Chulalongkorn University's Global Collaborations by {'All Fields' if field_id == -1 else field_name}"
else:
    title_text = "Chulalongkorn University's Global Collaborations by Keyword"

fig = px.choropleth(
    filtered_map_data_pd,
    locations="country",
    locationmode="country names",
    color="points",
    color_continuous_scale="Greens",
    title=title_text,
    labels={'points': 'Unique Affiliations'},
)

fig.update_geos(
    showcountries=True,
    countrycolor="Black",
    showcoastlines=True,
    coastlinecolor="Gray",
    showland=True,
    landcolor="white",
    showocean=True,
    oceancolor="lightblue",
    projection_type="natural earth"
)

fig.update_layout(
    title_font=dict(size=24, family="Arial"),
    margin={"r": 10, "t": 50, "l": 10, "b": 10},
    coloraxis_colorbar=dict(
        title="Unique Affiliations",
        title_font=dict(size=16, family="Arial"),
        tickfont=dict(size=12, family="Arial"),
    )
)

# Display the map
st.plotly_chart(fig, use_container_width=True)

# Show top 10 rows
top_10_pd = filtered_map_data_spark.limit(10).toPandas()

# Select only needed columns for preview and rename them
# Original columns: header -> affiliation, city -> city, country -> country, title_x -> title
display_df = top_10_pd[["header", "city", "country", "title_x"]].copy()
display_df.rename(columns={
    "header": "affiliation",
    "city": "city",
    "country": "country",
    "title_x": "title"
}, inplace=True)

st.markdown("### 📜 Example Papers from Chulalongkorn University and Its Partners")
st.dataframe(display_df)

# Display Country Statistics if enabled
if show_stats:
    st.markdown("---")
    st.subheader("🌐 Country Statistics (Unique Affiliations)")
    if filtered_map_data_pd.empty:
        st.write("No data available for the selected filters.")
    else:
        st.dataframe(filtered_map_data_pd.style.format(precision=0).set_properties(**{'text-align': 'left'}))
```
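A note on the keyword search above: it assumes that row `i` of `embeddings.npy` corresponds to row `i` of the `clustering` dataset, since `get_index_to_pub_id` enumerates publication IDs positionally. A minimal sanity check for that assumption (a sketch reusing app.py's variables, not part of the uploaded file):

```python
# embeddings.npy must align 1:1 with data/clustering.csv, or
# index_to_pub_id will silently map similarities to the wrong papers.
assert embeddings.shape[0] == publications_df.count(), \
    "embeddings.npy rows must match the clustering dataset row count"
```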
data/abstracts.csv
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:87f092b08f6a0997fb7e7dc8a69bb2a0bfb50f1f97a74e766650816d7fc67937
size 27109000
```
data/affiliations.csv
ADDED
The diff for this file is too large to render.
data/author_affiliations.csv
ADDED
The diff for this file is too large to render.
data/authors.csv
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:fcf90cce3f96116cf903568cc968c046f04fe0b2fdc94d8823035483820f3d9a
size 87265279
```
data/clustering.csv
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:3e2fe64e97eac26714a67cf25c339310e5c1d0d28c7d66af44e918a3d7d7adf4
size 62658453
```
data/embeddings.csv
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:b92d220f5a9a6bf5a6bc19845024f0d66708fb64b69a79a7f2782170f36bad96
size 87835864
```
data/geospatial_clustering_data.csv
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:086e0645f51d81341aef6d5981ae862b048df69a4dc5b211d1a07ce003dc97c1
size 90943859
```
data/geospatial_data_by_publication.csv
ADDED
The diff for this file is too large to render.
data/keywords.csv
ADDED
The diff for this file is too large to render.
data/publications.csv
ADDED
The diff for this file is too large to render.
data/scopus_affiliation_data.csv
ADDED
The diff for this file is too large to render.
data/subject_areas.csv
ADDED
The diff for this file is too large to render.
embedding_model/1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
```json
{
    "word_embedding_dimension": 768,
    "pooling_mode_cls_token": false,
    "pooling_mode_mean_tokens": true,
    "pooling_mode_max_tokens": false,
    "pooling_mode_mean_sqrt_len_tokens": false,
    "pooling_mode_weightedmean_tokens": false,
    "pooling_mode_lasttoken": false,
    "include_prompt": true
}
```
embedding_model/README.md
ADDED
@@ -0,0 +1,177 @@
---
language: en
license: apache-2.0
library_name: sentence-transformers
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
- transformers
datasets:
- s2orc
- flax-sentence-embeddings/stackexchange_xml
- ms_marco
- gooaq
- yahoo_answers_topics
- code_search_net
- search_qa
- eli5
- snli
- multi_nli
- wikihow
- natural_questions
- trivia_qa
- embedding-data/sentence-compression
- embedding-data/flickr30k-captions
- embedding-data/altlex
- embedding-data/simple-wiki
- embedding-data/QQP
- embedding-data/SPECTER
- embedding-data/PAQ_pairs
- embedding-data/WikiAnswers
pipeline_tag: sentence-similarity
---

# all-mpnet-base-v2
This is a [sentence-transformers](https://www.SBERT.net) model: it maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for tasks like clustering or semantic search.

## Usage (Sentence-Transformers)
Using this model is easy when you have [sentence-transformers](https://www.SBERT.net) installed:

```
pip install -U sentence-transformers
```

Then you can use the model like this:
```python
from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
embeddings = model.encode(sentences)
print(embeddings)
```

## Usage (HuggingFace Transformers)
Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: first, pass your input through the transformer model, then apply the right pooling operation on top of the contextualized word embeddings.

```python
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']

# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

# Perform pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

print("Sentence embeddings:")
print(sentence_embeddings)
```

## Evaluation Results

For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/all-mpnet-base-v2)

------

## Background

The project aims to train sentence embedding models on very large sentence-level datasets using a self-supervised contrastive learning objective. We used the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model and fine-tuned it on a dataset of 1B sentence pairs. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences was actually paired with it in our dataset.

We developed this model during the [Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by Hugging Face, as part of the project [Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project (7 TPU v3-8s), as well as guidance from Google's Flax, JAX, and Cloud team members on efficient deep learning frameworks.

## Intended uses

Our model is intended to be used as a sentence and short paragraph encoder. Given an input text, it outputs a vector which captures the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.

By default, input text longer than 384 word pieces is truncated.
+
|
119 |
+
|
120 |
+
## Training procedure
|
121 |
+
|
122 |
+
### Pre-training
|
123 |
+
|
124 |
+
We use the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model. Please refer to the model card for more detailed information about the pre-training procedure.
|
125 |
+
|
126 |
+
### Fine-tuning
|
127 |
+
|
128 |
+
We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
|
129 |
+
We then apply the cross entropy loss by comparing with true pairs.
|
130 |
+
|
131 |
+
#### Hyper parameters
|
132 |
+
|
133 |
+
We trained ou model on a TPU v3-8. We train the model during 100k steps using a batch size of 1024 (128 per TPU core).
|
134 |
+
We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
|
135 |
+
a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
|
136 |
+
|
137 |
+
#### Training data
|
138 |
+
|
139 |
+
We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
|
140 |
+
We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
|
141 |
+
|
142 |
+
|
143 |
+
| Dataset | Paper | Number of training tuples |
|
144 |
+
|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
|
145 |
+
| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
|
146 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
|
147 |
+
| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
|
148 |
+
| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
|
149 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
|
150 |
+
| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
|
151 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
|
152 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
|
153 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
|
154 |
+
| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
|
155 |
+
| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
|
156 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
|
157 |
+
| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
|
158 |
+
| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
|
159 |
+
| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
|
160 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
|
161 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
|
162 |
+
| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
|
163 |
+
| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
|
164 |
+
| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
|
165 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
|
166 |
+
| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
|
167 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
|
168 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
|
169 |
+
| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
|
170 |
+
| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
|
171 |
+
| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
|
172 |
+
| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
|
173 |
+
| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
|
174 |
+
| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
|
175 |
+
| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
|
176 |
+
| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
|
177 |
+
| **Total** | | **1,170,060,424** |
|
embedding_model/config.json
ADDED
@@ -0,0 +1,24 @@
```json
{
  "_name_or_path": "sentence-transformers/all-mpnet-base-v2",
  "architectures": [
    "MPNetModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "mpnet",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "relative_attention_num_buckets": 32,
  "torch_dtype": "float32",
  "transformers_version": "4.47.0",
  "vocab_size": 30527
}
```
embedding_model/config_sentence_transformers.json
ADDED
@@ -0,0 +1,10 @@
```json
{
  "__version__": {
    "sentence_transformers": "3.3.1",
    "transformers": "4.47.0",
    "pytorch": "2.5.1+cpu"
  },
  "prompts": {},
  "default_prompt_name": null,
  "similarity_fn_name": "cosine"
}
```
embedding_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:0b3c8c717335c801abb15983036a6f1df4b6943fd6b93717969efd96d22eeec6
size 437967672
```
embedding_model/modules.json
ADDED
@@ -0,0 +1,20 @@
```json
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]
```
embedding_model/sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
```json
{
  "max_seq_length": 384,
  "do_lower_case": false
}
```
embedding_model/special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@
```json
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
```
embedding_model/tokenizer.json
ADDED
The diff for this file is too large to render.
embedding_model/tokenizer_config.json
ADDED
@@ -0,0 +1,73 @@
```json
{
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "104": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "30526": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "cls_token": "<s>",
  "do_lower_case": true,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "mask_token": "<mask>",
  "max_length": 128,
  "model_max_length": 384,
  "pad_to_multiple_of": null,
  "pad_token": "<pad>",
  "pad_token_type_id": 0,
  "padding_side": "right",
  "sep_token": "</s>",
  "stride": 0,
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "MPNetTokenizer",
  "truncation_side": "right",
  "truncation_strategy": "longest_first",
  "unk_token": "[UNK]"
}
```
embedding_model/vocab.txt
ADDED
The diff for this file is too large to render.
embeddings.npy
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:18b6bf845c4c24135a9f4637328e2e738717e7862034f8a48261c76f534a24ce
size 61437056
```
search.ipynb
ADDED
The diff for this file is too large to render.
spark_setup.py
ADDED
@@ -0,0 +1,40 @@
```python
# spark_setup.py
from pyspark.sql import SparkSession

# Initialize Spark session
def create_spark_session(app_name="University Research Analysis", master="local"):
    spark = SparkSession.builder \
        .appName(app_name) \
        .master(master) \
        .getOrCreate()
    return spark

# Load data into Spark DataFrames and return a dictionary of DataFrames
def load_data(spark, file_paths):
    dataframes = {}
    for name, path in file_paths.items():
        dataframes[name] = spark.read.csv(path, header=True, inferSchema=True)
        # Register as a temporary view
        dataframes[name].createOrReplaceTempView(name)
    return dataframes

# File paths for each dataset
file_paths = {
    "author_affiliations": "data/author_affiliations.csv",
    "affiliations": "data/affiliations.csv",
    "subject_areas": "data/subject_areas.csv",
    "keywords": "data/keywords.csv",
    "publications": "data/publications.csv",
    "authors": "data/authors.csv",
    "embeddings": "data/embeddings.csv",
    "clustering": "data/clustering.csv",
    "abstracts": "data/abstracts.csv",
    "geospatial_clustering_data": "data/geospatial_clustering_data.csv",
    "geospatial_data_by_publication": "data/geospatial_data_by_publication.csv",
    "scopus_affiliation_data": "data/scopus_affiliation_data.csv"
}

if __name__ == "__main__":
    spark = create_spark_session()
    dataframes = load_data(spark, file_paths)
    print("Spark session initialized and data loaded.")
```