alexandergagliano committed
Commit 37b78a0 · 1 Parent(s): 8ff867d

update examples in nbs and example dir

examples/advanced_usage.py CHANGED
@@ -8,6 +8,8 @@ This script demonstrates advanced features of reLAISS including:
  - Setting maximum neighbor distances
  - Tweaking ANNOY parameters
  - Making corner plots
+ - Advanced anomaly detection with parameter tuning
+ - Host swapping in anomaly detection
  """
 
  import os
@@ -38,6 +40,8 @@ def main():
      # Create output directories
      os.makedirs('./figures', exist_ok=True)
      os.makedirs('./sfddata-master', exist_ok=True)
+     os.makedirs('./models', exist_ok=True)
+     os.makedirs('./timeseries', exist_ok=True)
 
      # Initialize the client
      client = rl.ReLAISS()
@@ -150,6 +154,90 @@ def main():
          path_to_figure_directory='./figures',
          save_plots=True
      )
+
+     # Example 7: Advanced anomaly detection with parameter tuning
+     print("\nExample 7: Advanced anomaly detection with parameter tuning")
+     from relaiss.anomaly import train_AD_model, anomaly_detection
+
+     # Train models with different parameters to compare
+     print("Training anomaly detection model with default parameters...")
+     default_model_path = train_AD_model(
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_dataset_bank=client.bank_csv,
+         path_to_sfd_folder='./sfddata-master',
+         path_to_models_directory="./models",
+         n_estimators=100,
+         contamination=0.02,
+         max_samples=256,
+         force_retrain=True
+     )
+
+     print("Training anomaly detection model with more trees...")
+     model_more_trees_path = train_AD_model(
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_dataset_bank=client.bank_csv,
+         path_to_sfd_folder='./sfddata-master',
+         path_to_models_directory="./models",
+         n_estimators=200,  # More trees
+         contamination=0.02,
+         max_samples=256,
+         force_retrain=True
+     )
+
+     print("Training anomaly detection model with higher contamination...")
+     model_higher_contam_path = train_AD_model(
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_dataset_bank=client.bank_csv,
+         path_to_sfd_folder='./sfddata-master',
+         path_to_models_directory="./models",
+         n_estimators=100,
+         contamination=0.05,  # Higher contamination
+         max_samples=256,
+         force_retrain=True
+     )
+
+     # Run anomaly detection with each model
+     print("\nRunning anomaly detection with default model...")
+     anomaly_detection(
+         transient_ztf_id="ZTF21abbzjeq",
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_timeseries_folder="./timeseries",
+         path_to_sfd_folder='./sfddata-master',
+         path_to_dataset_bank=client.bank_csv,
+         path_to_models_directory="./models",
+         path_to_figure_directory="./figures/AD_default",
+         save_figures=True,
+         n_estimators=100,
+         contamination=0.02,
+         max_samples=256,
+         force_retrain=False
+     )
+
+     # Example 8: Anomaly detection with host swapping
+     print("\nExample 8: Anomaly detection with host swapping")
+     # Use the default model but swap in a different host galaxy
+     anomaly_detection(
+         transient_ztf_id="ZTF21abbzjeq",
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_timeseries_folder="./timeseries",
+         path_to_sfd_folder='./sfddata-master',
+         path_to_dataset_bank=client.bank_csv,
+         host_ztf_id_to_swap_in="ZTF21aakswqr",  # Swap in a different host
+         path_to_models_directory="./models",
+         path_to_figure_directory="./figures/AD_host_swap",
+         save_figures=True,
+         n_estimators=100,
+         contamination=0.02,
+         max_samples=256,
+         force_retrain=False
+     )
+
+     print("Anomaly detection figures saved to ./figures/AD_default/ and ./figures/AD_host_swap/")
 
  if __name__ == "__main__":
      main()
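Note: `train_AD_model`'s `n_estimators`, `contamination`, and `max_samples` knobs are the standard isolation-forest parameters, so their effect can be previewed without reLAISS at all. A minimal sketch using scikit-learn's `IsolationForest` directly, on synthetic stand-in features (the real model is trained on the scaled dataset-bank features):

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(42)
X = rng.normal(size=(5000, 12))  # stand-in for scaled lightcurve + host features

clf = IsolationForest(
    n_estimators=100,    # trees in the ensemble; more trees -> more stable scores
    contamination=0.02,  # expected fraction of anomalies; sets the score threshold
    max_samples=256,     # rows subsampled per tree
    random_state=42,
).fit(X)

scores = clf.decision_function(X)            # lower = more anomalous
print((clf.predict(X) == -1).mean())         # fraction flagged (~contamination)
```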
examples/basic_usage.py CHANGED
@@ -6,6 +6,7 @@ This script demonstrates the basic functionality of reLAISS including:
  - Running nearest neighbor search
  - Using Monte Carlo simulations
  - Adjusting feature weights
+ - Basic anomaly detection
  """
 
  import os
@@ -15,6 +16,8 @@ def main():
      # Create output directories
      os.makedirs('./figures', exist_ok=True)
      os.makedirs('./sfddata-master', exist_ok=True)
+     os.makedirs('./models', exist_ok=True)
+     os.makedirs('./timeseries', exist_ok=True)
 
      # Initialize the client
      client = rl.ReLAISS()
@@ -55,7 +58,7 @@ def main():
      neighbors_df = client.find_neighbors(
          ztf_object_id='ZTF21abbzjeq',  # Using the test transient
          n=5,
-         num_mc_simulations=20,  # Number of Monte Carlo simulations
+         num_sims=20,  # Number of Monte Carlo simulations
          weight_lc_feats_factor=3.0,  # Up-weight lightcurve features
          plot=True,
          save_figures=True,
@@ -63,6 +66,44 @@ def main():
      )
      print("\nNearest neighbors with MC simulations:")
      print(neighbors_df)
+
+     # Example 4: Basic anomaly detection
+     print("\nExample 4: Basic anomaly detection")
+     from relaiss.anomaly import train_AD_model, anomaly_detection
+
+     # First, train an anomaly detection model
+     print("Training anomaly detection model...")
+     model_path = train_AD_model(
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_dataset_bank=client.bank_csv,
+         path_to_sfd_folder='./sfddata-master',
+         path_to_models_directory="./models",
+         n_estimators=100,    # Using smaller value for faster execution
+         contamination=0.02,  # Expected proportion of anomalies
+         max_samples=256,     # Max samples per tree
+         force_retrain=False  # Only retrain if model doesn't exist
+     )
+     print(f"Anomaly detection model saved to: {model_path}")
+
+     # Now, run anomaly detection on a specific transient
+     print("\nRunning anomaly detection...")
+     anomaly_detection(
+         transient_ztf_id="ZTF21abbzjeq",  # Same test transient
+         lc_features=client.lc_features,
+         host_features=client.host_features,
+         path_to_timeseries_folder="./timeseries",
+         path_to_sfd_folder='./sfddata-master',
+         path_to_dataset_bank=client.bank_csv,
+         path_to_models_directory="./models",
+         path_to_figure_directory="./figures",
+         save_figures=True,
+         n_estimators=100,
+         contamination=0.02,
+         max_samples=256,
+         force_retrain=False
+     )
+     print("Anomaly detection figures saved to ./figures/AD/")
 
  if __name__ == "__main__":
      main()
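The rename to `num_sims` above enables Monte Carlo resampling of the query's features. The implementation is not part of this diff, but the idea behind such an option is to perturb the query by its feature uncertainties and keep the neighbors that persist across draws. A rough sketch of that idea (all names here are hypothetical, not reLAISS API):

```python
import numpy as np
from collections import Counter

def mc_neighbor_vote(query, query_err, bank, num_sims=20, n=5, seed=0):
    """Tally which bank rows stay in the top-n (L1) neighbors across draws."""
    rng = np.random.default_rng(seed)
    votes = Counter()
    for _ in range(num_sims):
        q = query + rng.normal(scale=query_err)   # perturb by uncertainties
        d = np.abs(bank - q).sum(axis=1)          # L1 distance to every row
        votes.update(np.argsort(d)[:n].tolist())
    return votes.most_common(n)                   # the most stable neighbors

bank = np.random.default_rng(1).normal(size=(1000, 8))
print(mc_neighbor_vote(bank[0], query_err=0.1 * np.ones(8), bank=bank))
```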
examples/build_databank.py CHANGED
@@ -6,12 +6,14 @@ This script demonstrates how to build a new dataset bank for reLAISS, including:
  2. Joining new lightcurve features
  3. Handling missing values
  4. Building the final dataset bank
+ 5. Using different feature combinations for nearest neighbor search
 
  The process involves several steps:
  1. Add extinction corrections to the large dataset bank
  2. Join new lightcurve features to the small dataset bank
  3. Handle missing values using KNN imputation
  4. Build the final dataset bank with all features
+ 5. Demonstrate custom feature selection for nearest neighbor search
  """
 
  import os
@@ -21,6 +23,7 @@ from sfdmap2 import sfdmap
  from sklearn.impute import KNNImputer
  from relaiss.features import build_dataset_bank
  from relaiss import constants
+ import relaiss as rl
 
  def add_extinction_corrections(df, path_to_sfd_folder):
      """Add extinction corrections (A_V) to the dataset.
@@ -95,6 +98,8 @@ def main():
      # Create necessary directories
      os.makedirs('../data', exist_ok=True)
      os.makedirs('../data/sfddata-master', exist_ok=True)
+     os.makedirs('./figures', exist_ok=True)
+     os.makedirs('./sfddata-master', exist_ok=True)
 
      # Step 1: Add extinction corrections to large dataset bank
      print("\nStep 1: Adding extinction corrections")
@@ -146,6 +151,109 @@
      print("Shape of final dataset bank:", final_dataset_bank.shape)
      final_dataset_bank.to_csv('../data/large_final_df_bank_new_lc_feats.csv', index=False)
      print("Successfully saved final dataset bank!")
+
+     # Step 5: Demonstrate different feature combinations for search
+     print("\nStep 5: Using different feature combinations")
+
+     # Define default feature sets from constants
+     default_lc_features = constants.lc_features_const.copy()
+     default_host_features = constants.host_features_const.copy()
+
+     # Initialize client
+     client = rl.ReLAISS()
+     client.load_reference(
+         path_to_sfd_folder='./sfddata-master'
+     )
+
+     # Example 1: Using only lightcurve features (no host features)
+     print("\nExample 1: Using only lightcurve features")
+     lc_only_client = rl.ReLAISS()
+     lc_only_client.load_reference(
+         path_to_sfd_folder='./sfddata-master',
+         lc_features=default_lc_features,  # Use default lightcurve features
+         host_features=[],                 # Empty list means no host features
+     )
+
+     # Find neighbors using only lightcurve features
+     neighbors_df_lc_only = lc_only_client.find_neighbors(
+         ztf_object_id='ZTF21abbzjeq',
+         n=5,
+         plot=True,
+         save_figures=True,
+         path_to_figure_directory='./figures/lc_only'
+     )
+     print("\nNearest neighbors using only lightcurve features:")
+     print(neighbors_df_lc_only)
+
+     # Example 2: Using only host features (no lightcurve features)
+     print("\nExample 2: Using only host features")
+     host_only_client = rl.ReLAISS()
+     host_only_client.load_reference(
+         path_to_sfd_folder='./sfddata-master',
+         lc_features=[],                       # Empty list means no lightcurve features
+         host_features=default_host_features,  # Use default host features
+     )
+
+     # Find neighbors using only host features
+     neighbors_df_host_only = host_only_client.find_neighbors(
+         ztf_object_id='ZTF21abbzjeq',
+         n=5,
+         plot=True,
+         save_figures=True,
+         path_to_figure_directory='./figures/host_only'
+     )
+     print("\nNearest neighbors using only host features:")
+     print(neighbors_df_host_only)
+
+     # Example 3: Using custom feature subset
+     print("\nExample 3: Using custom feature subset")
+     # Select specific lightcurve and host features
+     custom_lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']
+     custom_host_features = ['host_ra', 'host_dec', 'gKronMag', 'rKronMag']
+
+     custom_client = rl.ReLAISS()
+     custom_client.load_reference(
+         path_to_sfd_folder='./sfddata-master',
+         lc_features=custom_lc_features,      # Custom subset of lightcurve features
+         host_features=custom_host_features,  # Custom subset of host features
+     )
+
+     # Find neighbors with custom feature subset
+     neighbors_df_custom = custom_client.find_neighbors(
+         ztf_object_id='ZTF21abbzjeq',
+         n=5,
+         plot=True,
+         save_figures=True,
+         path_to_figure_directory='./figures/custom_features'
+     )
+     print("\nNearest neighbors using custom feature subset:")
+     print(neighbors_df_custom)
+
+     # Example 4: Setting feature importance with feature weighting
+     print("\nExample 4: Using feature weighting")
+     # Regular search prioritizing lightcurve features
+     neighbors_df_lc_weighted = client.find_neighbors(
+         ztf_object_id='ZTF21abbzjeq',
+         n=5,
+         weight_lc_feats_factor=3.0,  # Strongly prioritize lightcurve features
+         plot=True,
+         save_figures=True,
+         path_to_figure_directory='./figures/lc_weighted'
+     )
+     print("\nNearest neighbors with lightcurve features weighted 3x:")
+     print(neighbors_df_lc_weighted)
+
+     # Now prioritize host features by using a factor < 1
+     neighbors_df_host_weighted = client.find_neighbors(
+         ztf_object_id='ZTF21abbzjeq',
+         n=5,
+         weight_lc_feats_factor=0.3,  # Prioritize host features
+         plot=True,
+         save_figures=True,
+         path_to_figure_directory='./figures/host_weighted'
+     )
+     print("\nNearest neighbors with host features given higher weight:")
+     print(neighbors_df_host_weighted)
 
  if __name__ == "__main__":
      main()
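A note on Example 4's `weight_lc_feats_factor`: this diff doesn't show how reLAISS applies it internally, but the intended effect of such a factor is to rescale the lightcurve block of the standardized feature matrix so those columns dominate (factor > 1) or recede (factor < 1) in the distance metric. A minimal sketch of that idea (function name hypothetical):

```python
import numpy as np

def apply_lc_weight(X_scaled, n_lc, factor):
    """Rescale the first n_lc (lightcurve) columns of a standardized feature
    matrix; any L1/L2 distance on the result weights them by `factor`."""
    X = X_scaled.copy()
    X[:, :n_lc] *= factor
    return X

X = np.random.default_rng(0).normal(size=(100, 10))     # e.g. 6 lc + 4 host features
X_lc_heavy   = apply_lc_weight(X, n_lc=6, factor=3.0)   # ~ weight_lc_feats_factor=3.0
X_host_heavy = apply_lc_weight(X, n_lc=6, factor=0.3)   # ~ weight_lc_feats_factor=0.3
```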
notebooks/01_basic_usage.ipynb ADDED
@@ -0,0 +1,321 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "648c2b16",
   "metadata": {},
   "source": [
    "# Basic Usage of reLAISS\n",
    "### Authors: Evan Reynolds and Alex Gagliano\n",
    "\n",
    "## Introduction\n",
    "\n",
    "reLAISS is the second version of LAISS (Lightcurve Anomaly Identification & Similarity Search): a tool to find similar supernovae & identify anomalous supernovae (and the galaxies that host them) using their photometric features.\n",
    "\n",
    "The similarity search takes advantage of [Approximate Nearest Neighbors Oh Yeah (ANNOY)](https://github.com/spotify/annoy), the approximate nearest neighbors algorithm developed by Spotify that allows you to come up with a relevant song to listen to before your current one ends. The anomaly detection classifier is an isolation forest model trained on a dataset bank of over 22,000 transients.\n",
    "\n",
    "This notebook demonstrates the basic features of the reLAISS library for finding similar astronomical transients.\n",
    "\n",
    "## Topics Covered\n",
    "1. Initializing the ReLAISS client\n",
    "2. Loading reference data\n",
    "3. Finding optimal number of neighbors\n",
    "4. Basic nearest neighbor search\n",
    "5. Using Monte Carlo simulations and feature weighting\n",
    "6. Basic anomaly detection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "869e17e0",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, let's import the required packages and create the necessary output directories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bf22151",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import relaiss\n",
    "\n",
    "# Create output directories\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "os.makedirs('./sfddata-master', exist_ok=True)\n",
    "os.makedirs('./models', exist_ok=True)\n",
    "os.makedirs('./timeseries', exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e82fe89",
   "metadata": {},
   "source": [
    "## 1. Initialize the ReLAISS Client\n",
    "\n",
    "First, we create an instance of the ReLAISS client that we'll use to find similar transients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc144944",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create ReLAISS client\n",
    "client = relaiss.ReLAISS()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5483eb7b",
   "metadata": {},
   "source": [
    "## 2. Load Reference Data\n",
    "\n",
    "Next, we load the reference dataset bank. This contains the features of known transients that we'll use for comparison.\n",
    "\n",
    "The `load_reference` function will automatically download the SFD dust map files if they don't exist in the specified directory. These files are required for extinction corrections in the reLAISS pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "808bf098",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load reference data\n",
    "client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master', # Directory for SFD dust maps\n",
    "    use_pca=False, # Don't use PCA for this example\n",
    "    host_features=[] # Empty list for this example\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ae5d46f",
   "metadata": {},
   "source": [
    "## 3. Finding the Optimal Number of Neighbors\n",
    "\n",
    "Before doing a full neighbor search, we can use reLAISS to suggest an optimal number of neighbors based on the distance distribution. This helps avoid arbitrary choices for the number of neighbors to return.\n",
    "\n",
    "First, let's run a search with a larger number of neighbors and set `suggest_neighbor_num=True`. This will show us a distance plot that helps identify a reasonable cutoff point for similar objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91904c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find optimal number of neighbors\n",
    "client.find_neighbors(\n",
    "    ztf_object_id='ZTF21aaublej', # ZTF ID to find neighbors for\n",
    "    n=40, # Search in a larger pool\n",
    "    suggest_neighbor_num=True, # Only suggest optimal number, don't return neighbors\n",
    "    plot=True, # Show the distance elbow plot\n",
    "    save_figures=True, # Save plots to disk\n",
    "    path_to_figure_directory='./figures'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83324aa1",
   "metadata": {},
   "source": [
    "## 4. Basic Nearest Neighbor Search\n",
    "\n",
    "Now we can find the most similar transients to a given ZTF object. Let's use ZTF21aaublej as an example.\n",
    "\n",
    "The `find_neighbors` function allows you to:\n",
    "- Specify the number of neighbors to return\n",
    "- Set a maximum distance threshold\n",
    "- Adjust the weight of lightcurve features relative to host features\n",
    "- Generate diagnostic plots\n",
    "\n",
    "Based on the distance curve we saw earlier, we'll choose to return 5 neighbors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "550f1c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find nearest neighbors\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21aaublej', # ZTF ID to find neighbors for\n",
    "    n=5, # Number of neighbors to return\n",
    "    suggest_neighbor_num=False, # Return actual neighbors\n",
    "    plot=True, # Generate diagnostic plots\n",
    "    save_figures=True, # Save plots to disk\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "\n",
    "# Display the results\n",
    "print(\"\\nNearest Neighbors:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de3a82fb",
   "metadata": {},
   "source": [
    "## 5. Using Monte Carlo Simulations and Feature Weighting\n",
    "\n",
    "reLAISS allows you to adjust the relative importance of lightcurve features compared to host galaxy features using the `weight_lc_feats_factor` parameter. A value greater than 1.0 will make lightcurve features more important in the similarity search.\n",
    "\n",
    "The Monte Carlo simulation functionality (`num_sims` parameter) helps account for measurement uncertainties by running multiple simulations with perturbed feature values.\n",
    "\n",
    "If you find that your matches aren't quite what you're looking for, you can try:\n",
    "- Using Monte Carlo simulations to account for feature measurement uncertainties\n",
    "- Upweighting lightcurve features to focus more on the transient's photometric properties than its host\n",
    "- Removing host features entirely for a \"lightcurve-only\" search\n",
    "- Removing lightcurve features for a \"host-only\" search\n",
    "\n",
    "Let's try using Monte Carlo simulations with upweighted lightcurve features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba5d5748",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using Monte Carlo simulations and feature weighting\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21aaublej', # Using the test transient\n",
    "    n=5,\n",
    "    num_sims=20, # Number of Monte Carlo simulations\n",
    "    weight_lc_feats_factor=3.0, # Up-weight lightcurve features\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "\n",
    "print(\"\\nNearest neighbors with Monte Carlo simulations:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a78f7d",
   "metadata": {},
   "source": [
    "## 6. Basic Anomaly Detection\n",
    "\n",
    "reLAISS also includes tools for anomaly detection that can help identify unusual transients. The anomaly detection module uses an Isolation Forest algorithm to identify outliers in the feature space.\n",
    "\n",
    "The anomaly detection process will produce plots showing the lightcurve of the input transient and a graph of the probability (in time) that the transient is anomalous. If the probability exceeds 50% at any epoch, the transient is flagged as anomalous.\n",
    "\n",
    "### Training an Anomaly Detection Model\n",
    "\n",
    "First, let's train an anomaly detection model on our dataset bank:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e19948ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from relaiss.anomaly import train_AD_model\n",
    "\n",
    "# Train the anomaly detection model\n",
    "model_path = train_AD_model(\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_models_directory=\"./models\",\n",
    "    n_estimators=100, # Using a smaller value for faster execution\n",
    "    contamination=0.02, # Expected proportion of anomalies\n",
    "    max_samples=256, # Maximum samples used for each tree\n",
    "    force_retrain=False # Only retrain if model doesn't exist\n",
    ")\n",
    "\n",
    "print(f\"Model saved to: {model_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a104e058",
   "metadata": {},
   "source": [
    "### Running Anomaly Detection on a Transient\n",
    "\n",
    "Now we can run anomaly detection on a specific transient to see if it's considered anomalous:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53cbc236",
   "metadata": {},
   "outputs": [],
   "source": [
    "from relaiss.anomaly import anomaly_detection\n",
    "\n",
    "# Run anomaly detection on a transient\n",
    "anomaly_detection(\n",
    "    transient_ztf_id=\"ZTF21aaublej\", # Use the same transient for this example\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_timeseries_folder=\"./timeseries\",\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_models_directory=\"./models\",\n",
    "    path_to_figure_directory=\"./figures\",\n",
    "    save_figures=True,\n",
    "    n_estimators=100,\n",
    "    contamination=0.02,\n",
    "    max_samples=256,\n",
    "    force_retrain=False\n",
    ")\n",
    "\n",
    "print(\"Anomaly detection figures saved to ./figures/AD/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c23a1b0a",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "To explore more advanced features, check out the `advanced_usage.ipynb` notebook which covers:\n",
    "- Using PCA for dimensionality reduction\n",
    "- Creating theorized lightcurves\n",
    "- Swapping host galaxies\n",
    "- Setting maximum neighbor distances\n",
    "- Tweaking ANNOY parameters\n",
    "- Making corner plots\n",
    "- Advanced anomaly detection techniques"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
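The notebook's introduction leans on ANNOY for the search itself. For readers unfamiliar with the library, here is a self-contained sketch of the underlying calls; the metric choice and sizes are illustrative assumptions, not reLAISS's actual configuration:

```python
import numpy as np
from annoy import AnnoyIndex

dim = 10
bank = np.random.default_rng(0).normal(size=(1000, dim))  # stand-in feature bank

index = AnnoyIndex(dim, "manhattan")  # metric is an assumption
for i, vec in enumerate(bank):
    index.add_item(i, vec.tolist())   # add every bank row to the index
index.build(100)                      # n_trees: more trees -> better recall

# Higher search_k inspects more nodes: slower but more accurate queries
ids, dists = index.get_nns_by_vector(bank[0].tolist(), 5, 2000, True)
print(ids, dists)
```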
notebooks/02_advanced_usage.ipynb ADDED
@@ -0,0 +1,488 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b362a55",
   "metadata": {},
   "source": [
    "# Advanced Usage of reLAISS\n",
    "### Authors: Evan Reynolds and Alex Gagliano\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This notebook demonstrates advanced features of the reLAISS library for finding similar astronomical transients. While the basic_usage.ipynb notebook covered the fundamental functionality, here we'll explore more sophisticated techniques that give you greater flexibility and power in your analysis.\n",
    "\n",
    "These advanced features allow you to customize how reLAISS processes and analyzes data, including dimensionality reduction, theorized lightcurves, host galaxy swapping, fine-tuning of algorithm parameters, visualization tools, and advanced anomaly detection.\n",
    "\n",
    "## Topics Covered\n",
    "1. Using PCA for dimensionality reduction\n",
    "2. Creating and using theorized lightcurves\n",
    "3. Swapping host galaxies\n",
    "4. Setting maximum neighbor distances\n",
    "5. Tweaking ANNOY parameters\n",
    "6. Making corner plots\n",
    "7. Advanced anomaly detection with parameter tuning\n",
    "8. Host swapping in anomaly detection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3c3f836",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, let's import the necessary packages and create the required directories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd152f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import relaiss as rl\n",
    "import astropy.units as u\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Create output directories\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "os.makedirs('./sfddata-master', exist_ok=True)\n",
    "os.makedirs('./models', exist_ok=True)\n",
    "os.makedirs('./timeseries', exist_ok=True)\n",
    "\n",
    "def create_theorized_lightcurve():\n",
    "    \"\"\"Create a simple theorized lightcurve for demonstration.\"\"\"\n",
    "    # Create time points\n",
    "    times = np.linspace(0, 100, 50) * u.day\n",
    "    # Create magnitudes (simple gaussian)\n",
    "    mags = 20 + 2 * np.exp(-(times.value - 50)**2 / 100)\n",
    "    # Create errors\n",
    "    errors = np.ones_like(mags) * 0.1\n",
    "\n",
    "    # Create DataFrame in ANTARES format\n",
    "    df = pd.DataFrame({\n",
    "        'ant_mjd': times.to(u.day).value,\n",
    "        'ant_mag': mags,\n",
    "        'ant_magerr': errors,\n",
    "        'ant_passband': ['g' if i % 2 == 0 else 'R' for i in range(len(times))]\n",
    "    })\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60730aa2",
   "metadata": {},
   "source": [
    "## Initialize the ReLAISS Client\n",
    "\n",
    "We'll start by creating a ReLAISS client instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3adc953",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the client\n",
    "client = rl.ReLAISS()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67b42e90",
   "metadata": {},
   "source": [
    "## 1. Using PCA for Dimensionality Reduction\n",
    "\n",
    "PCA (Principal Component Analysis) can be used to reduce the dimensionality of the feature space while preserving most of the variance. This has several benefits:\n",
    "\n",
    "- Improves search speed by reducing the computational complexity\n",
    "- Potentially reduces noise in the feature space\n",
    "- Helps mitigate the \"curse of dimensionality\" for high-dimensional data\n",
    "\n",
    "To use PCA, we set `use_pca=True` in the `load_reference` method and specify the number of components to keep:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a136889a",
   "metadata": {},
   "outputs": [],
   "source": [
    "client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    use_pca=True,\n",
    "    num_pca_components=20, # Keep 20 PCA components\n",
    ")\n",
    "\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq', # Using the test transient\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "print(\"\\nNearest neighbors using PCA:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ac7c6e",
   "metadata": {},
   "source": [
    "## 2. Creating and Using Theorized Lightcurves\n",
    "\n",
    "One powerful feature of reLAISS is the ability to use theorized (synthetic) lightcurves in the neighbor search. This allows you to:\n",
    "\n",
    "- Test theoretical models against observed data\n",
    "- Explore \"what-if\" scenarios by creating custom lightcurves\n",
    "- Find real transients that match your theoretical predictions\n",
    "\n",
    "Below, we create a simple Gaussian-shaped lightcurve and find its nearest neighbors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16bb77d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a theorized lightcurve\n",
    "theorized_lc = create_theorized_lightcurve()\n",
    "\n",
    "# Find neighbors for the theorized lightcurve\n",
    "# Need to provide a host galaxy when using theorized lightcurve\n",
    "neighbors_df = client.find_neighbors(\n",
    "    theorized_lightcurve_df=theorized_lc,\n",
    "    host_ztf_id='ZTF21abbzjeq', # Use this transient's host\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "print(\"\\nNearest neighbors for theorized lightcurve:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bab57bd3",
   "metadata": {},
   "source": [
    "## 3. Swapping Host Galaxies\n",
    "\n",
    "reLAISS allows you to swap the host galaxy of a transient, which is useful for:\n",
    "\n",
    "- Exploring how host properties affect the similarity search results\n",
    "- Investigating the effects of different environments on transient characteristics\n",
    "- Testing hypotheses about host galaxy influences\n",
    "\n",
    "Here's how to swap in a different host galaxy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efdf2937",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find neighbors with a swapped host galaxy\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq', # Source transient\n",
    "    host_ztf_id='ZTF21aakswqr', # Host to swap in\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "print(\"\\nNearest neighbors with swapped host galaxy:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bd99f9c",
   "metadata": {},
   "source": [
    "## 4. Setting Maximum Neighbor Distances\n",
    "\n",
    "Sometimes you're only interested in neighbors that are truly similar to your target. By setting a maximum distance threshold, you can:\n",
    "\n",
    "- Filter out neighbors that are too dissimilar\n",
    "- Focus only on highly confident matches\n",
    "- Avoid including poor matches just to reach a specific number of neighbors\n",
    "\n",
    "Note that you might get fewer neighbors than requested if the distance threshold is applied:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a859853",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find neighbors with maximum distance constraint\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    max_neighbor_dist=0.5, # Only return neighbors within this distance\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "print(\"\\nNearest neighbors with maximum distance constraint:\")\n",
    "print(neighbors_df)\n",
    "print(f\"Number of neighbors found: {len(neighbors_df)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a71acf61",
   "metadata": {},
   "source": [
    "## 5. Tweaking ANNOY Parameters\n",
    "\n",
    "ANNOY (Approximate Nearest Neighbors Oh Yeah) is the algorithm used for fast nearest neighbor search. You can tune its parameters to balance search accuracy and speed:\n",
    "\n",
    "- `search_k`: Controls the number of nodes to explore during search (higher = more accurate but slower)\n",
    "- `n_trees`: Controls the number of random projection trees built (set during client initialization)\n",
    "\n",
    "Here's how to adjust the search_k parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa03a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find neighbors with tweaked ANNOY parameters\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    search_k=2000, # Increase search_k for more accurate results\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "print(\"\\nNearest neighbors with tweaked ANNOY parameters:\")\n",
    "print(neighbors_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d7522b8",
   "metadata": {},
   "source": [
    "## 6. Making Corner Plots\n",
    "\n",
    "Corner plots are a powerful visualization tool that show the distribution of features for the input transient and its neighbors. They can help you:\n",
    "\n",
    "- Understand which features are driving the similarity matching\n",
    "- Identify potential correlations between different features\n",
    "- Visualize the feature space and where your transient sits within it\n",
    "\n",
    "To create corner plots, we need to first get the primer_dict containing information about the input transient:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7ff16d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get neighbors from a new search\n",
    "neighbors_df = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures'\n",
    ")\n",
    "\n",
    "# Get primer_dict separately\n",
    "from relaiss.search import primer\n",
    "primer_dict = primer(\n",
    "    lc_ztf_id='ZTF21abbzjeq',\n",
    "    theorized_lightcurve_df=None,\n",
    "    host_ztf_id=None,\n",
    "    dataset_bank_path=client.bank_csv,\n",
    "    path_to_timeseries_folder='./',\n",
    "    path_to_sfd_folder=client.path_to_sfd_folder,\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    num_sims=0,\n",
    "    save_timeseries=False,\n",
    ")\n",
    "\n",
    "# Create corner plots using the primer_dict\n",
    "from relaiss.plotting import corner_plot\n",
    "corner_plot(\n",
    "    neighbors_df=neighbors_df,\n",
    "    primer_dict=primer_dict,\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_figure_directory='./figures',\n",
    "    save_plots=True\n",
    ")\n",
    "print(\"Corner plots saved to ./figures/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7079194a",
   "metadata": {},
   "source": [
    "## 7. Advanced Anomaly Detection with Parameter Tuning\n",
    "\n",
    "The anomaly detection module in reLAISS uses an Isolation Forest algorithm that can be tuned for different scenarios. Key parameters include:\n",
    "\n",
    "- `n_estimators`: Number of base estimators (trees) in the ensemble\n",
    "- `contamination`: Expected proportion of outliers in the dataset\n",
    "- `max_samples`: Number of samples drawn to train each base estimator\n",
    "\n",
    "Let's explore how different parameters affect the model's performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "726834d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from relaiss.anomaly import train_AD_model, anomaly_detection\n",
    "\n",
    "# Train models with different parameters to compare\n",
    "print(\"Training anomaly detection model with default parameters...\")\n",
    "default_model_path = train_AD_model(\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_models_directory=\"./models\",\n",
    "    n_estimators=100,\n",
    "    contamination=0.02,\n",
    "    max_samples=256,\n",
    "    force_retrain=True\n",
    ")\n",
    "\n",
    "print(\"Training anomaly detection model with more trees...\")\n",
    "model_more_trees_path = train_AD_model(\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_models_directory=\"./models\",\n",
    "    n_estimators=200, # More trees\n",
    "    contamination=0.02,\n",
    "    max_samples=256,\n",
    "    force_retrain=True\n",
    ")\n",
    "\n",
    "print(\"Training anomaly detection model with higher contamination...\")\n",
    "model_higher_contam_path = train_AD_model(\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_models_directory=\"./models\",\n",
    "    n_estimators=100,\n",
    "    contamination=0.05, # Higher contamination\n",
    "    max_samples=256,\n",
    "    force_retrain=True\n",
    ")\n",
    "\n",
    "# Run anomaly detection with default model\n",
    "print(\"\\nRunning anomaly detection with default model...\")\n",
    "anomaly_detection(\n",
    "    transient_ztf_id=\"ZTF21abbzjeq\",\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_timeseries_folder=\"./timeseries\",\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    path_to_models_directory=\"./models\",\n",
    "    path_to_figure_directory=\"./figures/AD_default\",\n",
    "    save_figures=True,\n",
    "    n_estimators=100,\n",
    "    contamination=0.02,\n",
    "    max_samples=256,\n",
    "    force_retrain=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4886eb0e",
   "metadata": {},
   "source": [
    "## 8. Anomaly Detection with Host Swapping\n",
    "\n",
    "Just as with neighbor searches, you can swap host galaxies for anomaly detection. This helps you understand how host properties contribute to a transient's anomaly score.\n",
    "\n",
    "This feature is particularly useful for:\n",
    "- Testing if the anomalous nature of a transient is due to its host galaxy\n",
    "- Exploring the \"what if\" scenario of a transient occurring in a different environment\n",
    "- Separating intrinsic transient anomalies from host-related factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c65e14e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the default model but swap in a different host galaxy\n",
    "anomaly_detection(\n",
    "    transient_ztf_id=\"ZTF21abbzjeq\",\n",
    "    lc_features=client.lc_features,\n",
    "    host_features=client.host_features,\n",
    "    path_to_timeseries_folder=\"./timeseries\",\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    path_to_dataset_bank=client.bank_csv,\n",
    "    host_ztf_id_to_swap_in=\"ZTF21aakswqr\", # Swap in a different host\n",
    "    path_to_models_directory=\"./models\",\n",
    "    path_to_figure_directory=\"./figures/AD_host_swap\",\n",
    "    save_figures=True,\n",
    "    n_estimators=100,\n",
    "    contamination=0.02,\n",
    "    max_samples=256,\n",
    "    force_retrain=False\n",
    ")\n",
    "\n",
    "print(\"Anomaly detection figures saved to ./figures/AD_default/ and ./figures/AD_host_swap/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e08b6ef6",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "By combining these features, you can create highly customized searches tailored to your specific research questions.\n",
    "\n",
    "For information on how to build your own dataset bank for reLAISS, see the `build_databank.ipynb` notebook."
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
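Section 1 of the advanced notebook enables dimensionality reduction via `load_reference(use_pca=True, num_pca_components=20)`. The equivalent standalone transformation, sketched with scikit-learn on stand-in data (reLAISS's exact preprocessing may differ):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X = np.random.default_rng(0).normal(size=(2000, 40))  # stand-in feature bank

# Standardize, then keep 20 components (cf. num_pca_components=20)
pipe = make_pipeline(StandardScaler(), PCA(n_components=20, random_state=0))
Z = pipe.fit_transform(X)

pca = pipe.named_steps["pca"]
print(Z.shape, f"variance retained: {pca.explained_variance_ratio_.sum():.2%}")
```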
notebooks/03_build_databank.ipynb ADDED
@@ -0,0 +1,384 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ede68f0",
   "metadata": {},
   "source": [
    "# Building a New Dataset Bank for reLAISS\n",
    "### Authors: Evan Reynolds and Alex Gagliano\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This notebook demonstrates how to build a new dataset bank for reLAISS and use different feature combinations for nearest neighbor searches. The dataset bank is the foundation of reLAISS, containing all the features of transients that are used for similarity searches and anomaly detection.\n",
    "\n",
    "Building your own dataset bank allows you to incorporate new data, apply custom preprocessing steps, and tailor the feature set to your specific research needs.\n",
    "\n",
    "## Topics Covered\n",
    "1. Adding extinction corrections (A_V)\n",
    "2. Joining new lightcurve features\n",
    "3. Handling missing values\n",
    "4. Building the final dataset bank\n",
    "5. Using different feature combinations for nearest neighbor search:\n",
    "   - Lightcurve-only features\n",
    "   - Host-only features\n",
    "   - Custom feature subsets\n",
    "   - Feature weighting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9212bf17",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, let's import the necessary libraries and create the required directories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff694547",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from relaiss import constants\n",
    "import relaiss as rl\n",
    "\n",
    "# Create necessary directories\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "os.makedirs('./sfddata-master', exist_ok=True)\n",
    "\n",
    "# Define default feature sets from constants\n",
    "default_lc_features = constants.lc_features_const.copy()\n",
    "default_host_features = constants.host_features_const.copy()\n",
    "\n",
    "# Initialize client\n",
    "client = rl.ReLAISS()\n",
    "client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7229e7f",
   "metadata": {},
   "source": [
    "## 1. Adding Extinction Corrections (A_V)\n",
    "\n",
    "The first step in building a dataset bank is to add extinction corrections to account for interstellar dust. The Schlegel, Finkbeiner & Davis (SFD) dust maps are used to estimate the amount of extinction.\n",
    "\n",
    "```python\n",
    "# Example code for adding extinction corrections\n",
    "from sfdmap2 import sfdmap\n",
    "\n",
    "df = pd.read_csv(\"../data/large_df_bank.csv\")\n",
    "m = sfdmap.SFDMap('../data/sfddata-master')\n",
    "RV = 3.1 # Standard value for Milky Way\n",
    "ebv = m.ebv(df['ra'].values, df['dec'].values)\n",
    "df['A_V'] = RV * ebv\n",
    "df.to_csv(\"../data/large_df_bank_wAV.csv\", index=False)\n",
    "```\n",
    "\n",
    "This adds the A_V (extinction in V-band) column to your dataset, which will be used later in the feature processing pipeline."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac7ecbca",
   "metadata": {},
   "source": [
    "## 2. Joining New Lightcurve Features\n",
    "\n",
    "If you have additional features in a separate dataset, you can merge them with your existing bank:\n",
    "\n",
    "```python\n",
    "# Example code for joining features\n",
    "df_large = pd.read_csv(\"../data/large_df_bank_wAV.csv\")\n",
    "df_small = pd.read_csv(\"../data/small_df_bank_re_laiss.csv\")\n",
    "\n",
    "key = 'ztf_object_id'\n",
    "extra_features = [col for col in df_large.columns if col not in df_small.columns]\n",
    "\n",
    "merged_df = df_small.merge(df_large[[key] + extra_features], on=key, how='left')\n",
    "\n",
    "lc_feature_names = constants.lc_features_const.copy()\n",
    "host_feature_names = constants.host_features_const.copy()\n",
    "\n",
    "small_final_df = merged_df.replace([np.inf, -np.inf, -999], np.nan).dropna(subset=lc_feature_names + host_feature_names)\n",
    "\n",
    "small_final_df.to_csv(\"../data/small_hydrated_df_bank_re_laiss.csv\", index=False)\n",
    "```\n",
    "\n",
    "This merges additional features from a larger dataset into your working dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0fb3a25",
   "metadata": {},
   "source": [
    "## 3. Handling Missing Values\n",
    "\n",
    "Missing values in the dataset can cause problems during analysis. reLAISS uses KNN imputation to fill in missing values:\n",
    "\n",
    "```python\n",
    "# Example code for handling missing values\n",
    "from sklearn.impute import KNNImputer\n",
    "\n",
    "raw_host_feature_names = constants.raw_host_features_const.copy()\n",
    "raw_dataset_bank = pd.read_csv('../data/large_df_bank_wAV.csv')\n",
    "\n",
    "X = raw_dataset_bank[lc_feature_names + raw_host_feature_names]\n",
    "feat_imputer = KNNImputer(weights='distance').fit(X)\n",
    "imputed_filt_arr = feat_imputer.transform(X)\n",
    "\n",
    "imputed_df = pd.DataFrame(imputed_filt_arr, columns=lc_feature_names + raw_host_feature_names)\n",
    "imputed_df.index = raw_dataset_bank.index\n",
    "raw_dataset_bank[lc_feature_names + raw_host_feature_names] = imputed_df\n",
    "\n",
    "imputed_df_bank = raw_dataset_bank\n",
    "```\n",
    "\n",
    "KNN imputation works by finding the k-nearest neighbors in feature space for samples with missing values and using their values to fill in the gaps."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "390d9e9d",
   "metadata": {},
   "source": [
    "## 4. Building the Final Dataset Bank\n",
    "\n",
    "With all the preprocessing done, we can now build the final dataset bank using the `build_dataset_bank` function from reLAISS:\n",
    "\n",
    "```python\n",
    "# Example code for building the final dataset bank\n",
    "from relaiss.features import build_dataset_bank\n",
    "\n",
    "dataset_bank = build_dataset_bank(\n",
    "    raw_df_bank=imputed_df_bank,\n",
    "    av_in_raw_df_bank=True,\n",
    "    path_to_sfd_folder=\"../data/sfddata-master\",\n",
    "    building_entire_df_bank=True\n",
    ")\n",
    "\n",
    "# Clean and save final dataset\n",
    "final_dataset_bank = dataset_bank.replace(\n",
    "    [np.inf, -np.inf, -999], np.nan\n",
    ").dropna(subset=lc_feature_names + host_feature_names)\n",
    "\n",
    "final_dataset_bank.to_csv('../data/large_final_df_bank_new_lc_feats.csv', index=False)\n",
    "```\n",
    "\n",
    "This function applies additional processing to prepare the features for reLAISS, including normalization and other transformations."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "597b7b19",
   "metadata": {},
   "source": [
    "## 5. Using Different Feature Combinations\n",
    "\n",
    "reLAISS allows you to customize which features are used for similarity search. This can be useful for studying the importance of different features and for tailoring the search to specific scientific questions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8eb99200",
   "metadata": {},
   "source": [
    "### 5.1 Using Only Lightcurve Features\n",
    "\n",
    "You can perform a search using only lightcurve features, ignoring host galaxy properties. This is useful when:\n",
    "- You want to focus solely on the temporal evolution of the transient\n",
    "- Host data might be unreliable or missing\n",
    "- You're testing hypotheses about lightcurve-based classification\n",
    "\n",
    "Here's how to set up a lightcurve-only search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bde7ec63",
   "metadata": {},
   "outputs": [],
   "source": [
    "lc_only_client = rl.ReLAISS()\n",
    "lc_only_client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    lc_features=default_lc_features, # Use default lightcurve features\n",
    "    host_features=[], # Empty list means no host features\n",
    ")\n",
    "\n",
    "# Find neighbors using only lightcurve features\n",
    "neighbors_df_lc_only = lc_only_client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures/lc_only'\n",
    ")\n",
    "print(\"\\nNearest neighbors using only lightcurve features:\")\n",
    "print(neighbors_df_lc_only)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "865b24e5",
   "metadata": {},
   "source": [
    "### 5.2 Using Only Host Features\n",
    "\n",
    "Alternatively, you can perform a search using only host galaxy features, ignoring the lightcurve properties. This approach is valuable when:\n",
    "- You're more interested in environmental effects on transients\n",
    "- You want to find transients in similar host galaxies\n",
    "- You're studying correlations between host properties and transient types\n",
    "\n",
    "Here's how to set up a host-only search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "305d8042",
   "metadata": {},
   "outputs": [],
   "source": [
    "host_only_client = rl.ReLAISS()\n",
    "host_only_client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    lc_features=[], # Empty list means no lightcurve features\n",
    "    host_features=default_host_features, # Use default host features\n",
    ")\n",
    "\n",
    "# Find neighbors using only host features\n",
    "neighbors_df_host_only = host_only_client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures/host_only'\n",
    ")\n",
    "print(\"\\nNearest neighbors using only host features:\")\n",
    "print(neighbors_df_host_only)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbe91d24",
   "metadata": {},
   "source": [
    "### 5.3 Using Custom Feature Subset\n",
    "\n",
    "You can also select specific features from both categories for a more targeted search. This allows you to:\n",
    "- Focus on the features most relevant to your research question\n",
    "- Reduce noise by excluding less useful features\n",
    "- Test hypotheses about which features drive similarity\n",
    "\n",
    "Here's how to create a custom feature subset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7c721c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select specific lightcurve and host features\n",
    "custom_lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']\n",
    "custom_host_features = ['host_ra', 'host_dec', 'gKronMag', 'rKronMag']\n",
    "\n",
    "custom_client = rl.ReLAISS()\n",
    "custom_client.load_reference(\n",
    "    path_to_sfd_folder='./sfddata-master',\n",
    "    lc_features=custom_lc_features, # Custom subset of lightcurve features\n",
    "    host_features=custom_host_features, # Custom subset of host features\n",
    ")\n",
    "\n",
    "# Find neighbors with custom feature subset\n",
    "neighbors_df_custom = custom_client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    plot=True,\n",
    "    save_figures=True,\n",
    "    path_to_figure_directory='./figures/custom_features'\n",
    ")\n",
    "print(\"\\nNearest neighbors using custom feature subset:\")\n",
    "print(neighbors_df_custom)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da9ee34",
   "metadata": {},
   "source": [
    "### 5.4 Using Feature Weighting\n",
    "\n",
    "You can also adjust the relative importance of lightcurve features versus host galaxy features using the `weight_lc_feats_factor` parameter:\n",
    "- Values > 1: Emphasize lightcurve features\n",
    "- Values < 1: Emphasize host features\n",
    "- Value = 1: Equal weighting (default)\n",
    "\n",
    "This allows you to fine-tune the balance between photometric and host properties:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1209d70b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular search prioritizing lightcurve features\n",
    "neighbors_df_lc_weighted = client.find_neighbors(\n",
    "    ztf_object_id='ZTF21abbzjeq',\n",
    "    n=5,\n",
    "    weight_lc_feats_factor=3.0, # Strongly prioritize lightcurve features\n",
346
+ " plot=True,\n",
347
+ " save_figures=True,\n",
348
+ " path_to_figure_directory='./figures/lc_weighted'\n",
349
+ ")\n",
350
+ "print(\"\\nNearest neighbors with lightcurve features weighted 3x:\")\n",
351
+ "print(neighbors_df_lc_weighted)\n",
352
+ "\n",
353
+ "# Now prioritize host features by using a factor < 1\n",
354
+ "neighbors_df_host_weighted = client.find_neighbors(\n",
355
+ " ztf_object_id='ZTF21abbzjeq',\n",
356
+ " n=5,\n",
357
+ " weight_lc_feats_factor=0.3, # Prioritize host features\n",
358
+ " plot=True,\n",
359
+ " save_figures=True,\n",
360
+ " path_to_figure_directory='./figures/host_weighted'\n",
361
+ ")\n",
362
+ "print(\"\\nNearest neighbors with host features given higher weight:\")\n",
363
+ "print(neighbors_df_host_weighted)"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "markdown",
368
+ "id": "d243999b",
369
+ "metadata": {},
370
+ "source": [
371
+ "## Conclusion\n",
372
+ "\n",
373
+ "Building your own dataset bank and customizing feature combinations provides powerful flexibility for tailoring reLAISS to your specific research questions. By selecting different feature combinations and adjusting feature weights, you can explore various aspects of transient similarity and discover new insights about the transient population."
374
+ ]
375
+ }
376
+ ],
377
+ "metadata": {
378
+ "language_info": {
379
+ "name": "python"
380
+ }
381
+ },
382
+ "nbformat": 4,
383
+ "nbformat_minor": 5
384
+ }
src/relaiss/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
  __version_tuple__: VERSION_TUPLE
18
  version_tuple: VERSION_TUPLE
19
 
20
- __version__ = version = '0.0.post1.dev44'
21
- __version_tuple__ = version_tuple = (0, 0, 'post1', 'dev44')
 
17
  __version_tuple__: VERSION_TUPLE
18
  version_tuple: VERSION_TUPLE
19
 
20
+ __version__ = version = '1.0.0'
21
+ __version_tuple__ = version_tuple = (1, 0, 0)
src/relaiss/anomaly.py CHANGED
@@ -5,6 +5,7 @@ import antares_client
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import pandas as pd
 
8
  from pyod.models.iforest import IForest
9
  from sklearn.pipeline import Pipeline
10
  from sklearn.preprocessing import StandardScaler
@@ -17,6 +18,7 @@ def train_AD_model(
17
  host_features,
18
  path_to_dataset_bank=None,
19
  preprocessed_df=None,
 
20
  path_to_models_directory="../models",
21
  n_estimators=500,
22
  contamination=0.02,
@@ -36,6 +38,8 @@ def train_AD_model(
36
  preprocessed_df : pandas.DataFrame | None, optional
37
  Pre-processed dataframe with imputed features. If provided, this is used
38
  instead of loading and processing the raw dataset bank.
 
 
39
  path_to_models_directory : str | Path, default "../models"
40
  Directory to save trained models.
41
  n_estimators : int, default 500
@@ -88,6 +92,7 @@ def train_AD_model(
88
  raw_df = pd.read_csv(path_to_dataset_bank)
89
  df = build_dataset_bank(
90
  raw_df,
 
91
  building_entire_df_bank=True,
92
  building_for_AD=True
93
  )
@@ -118,7 +123,7 @@ def anomaly_detection(
118
  lc_features,
119
  host_features,
120
  path_to_timeseries_folder,
121
- path_to_sfd_data_folder,
122
  path_to_dataset_bank,
123
  host_ztf_id_to_swap_in=None,
124
  path_to_models_directory="../models",
@@ -159,6 +164,7 @@ def anomaly_detection(
159
  lc_features,
160
  host_features,
161
  path_to_dataset_bank,
 
162
  path_to_models_directory=path_to_models_directory,
163
  n_estimators=n_estimators,
164
  contamination=contamination,
@@ -167,8 +173,7 @@ def anomaly_detection(
167
  )
168
 
169
  # Load the model
170
- with open(path_to_trained_model, "rb") as f:
171
- clf = pickle.load(f)
172
 
173
  # Load the timeseries dataframe
174
  print("\nRebuilding timeseries dataframe(s) for AD...")
@@ -176,7 +181,7 @@ def anomaly_detection(
176
  ztf_id=transient_ztf_id,
177
  theorized_lightcurve_df=None,
178
  path_to_timeseries_folder=path_to_timeseries_folder,
179
- path_to_sfd_data_folder=path_to_sfd_data_folder,
180
  path_to_dataset_bank=path_to_dataset_bank,
181
  save_timeseries=False,
182
  building_for_AD=True,
@@ -188,7 +193,7 @@ def anomaly_detection(
188
  ztf_id=host_ztf_id_to_swap_in,
189
  theorized_lightcurve_df=None,
190
  path_to_timeseries_folder=path_to_timeseries_folder,
191
- path_to_sfd_data_folder=path_to_sfd_data_folder,
192
  path_to_dataset_bank=path_to_dataset_bank,
193
  save_timeseries=False,
194
  building_for_AD=True,
@@ -271,9 +276,20 @@ def check_anom_and_plot(
271
  """
272
  anom_obj_df = timeseries_df_features_only
273
 
274
- pred_prob_anom = 100 * clf.predict_proba(anom_obj_df)
275
- pred_prob_anom[:, 0] = [round(a, 1) for a in pred_prob_anom[:, 0]]
276
- pred_prob_anom[:, 1] = [round(b, 1) for b in pred_prob_anom[:, 1]]
 
 
 
 
 
 
 
 
 
 
 
277
  num_anom_epochs = len(np.where(pred_prob_anom[:, 1] >= anom_thresh)[0])
278
 
279
  try:
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import pandas as pd
8
+ import joblib
9
  from pyod.models.iforest import IForest
10
  from sklearn.pipeline import Pipeline
11
  from sklearn.preprocessing import StandardScaler
 
18
  host_features,
19
  path_to_dataset_bank=None,
20
  preprocessed_df=None,
21
+ path_to_sfd_folder=None,
22
  path_to_models_directory="../models",
23
  n_estimators=500,
24
  contamination=0.02,
 
38
  preprocessed_df : pandas.DataFrame | None, optional
39
  Pre-processed dataframe with imputed features. If provided, this is used
40
  instead of loading and processing the raw dataset bank.
41
+ path_to_sfd_folder : str | Path | None, optional
42
+ Path to SFD dust maps.
43
  path_to_models_directory : str | Path, default "../models"
44
  Directory to save trained models.
45
  n_estimators : int, default 500
 
92
  raw_df = pd.read_csv(path_to_dataset_bank)
93
  df = build_dataset_bank(
94
  raw_df,
95
+ path_to_sfd_folder=path_to_sfd_folder,
96
  building_entire_df_bank=True,
97
  building_for_AD=True
98
  )
 
123
  lc_features,
124
  host_features,
125
  path_to_timeseries_folder,
126
+ path_to_sfd_folder,
127
  path_to_dataset_bank,
128
  host_ztf_id_to_swap_in=None,
129
  path_to_models_directory="../models",
 
164
  lc_features,
165
  host_features,
166
  path_to_dataset_bank,
167
+ path_to_sfd_folder=path_to_sfd_folder,
168
  path_to_models_directory=path_to_models_directory,
169
  n_estimators=n_estimators,
170
  contamination=contamination,
 
173
  )
174
 
175
  # Load the model
176
+ clf = joblib.load(path_to_trained_model)
 
177
 
178
  # Load the timeseries dataframe
179
  print("\nRebuilding timeseries dataframe(s) for AD...")
 
181
  ztf_id=transient_ztf_id,
182
  theorized_lightcurve_df=None,
183
  path_to_timeseries_folder=path_to_timeseries_folder,
184
+ path_to_sfd_folder=path_to_sfd_folder,
185
  path_to_dataset_bank=path_to_dataset_bank,
186
  save_timeseries=False,
187
  building_for_AD=True,
 
193
  ztf_id=host_ztf_id_to_swap_in,
194
  theorized_lightcurve_df=None,
195
  path_to_timeseries_folder=path_to_timeseries_folder,
196
+ path_to_sfd_folder=path_to_sfd_folder,
197
  path_to_dataset_bank=path_to_dataset_bank,
198
  save_timeseries=False,
199
  building_for_AD=True,
 
276
  """
277
  anom_obj_df = timeseries_df_features_only
278
 
279
+ # Get anomaly scores from decision_function (negative = anomalous, positive = normal)
280
+ # Convert to probabilities (0-100 scale)
281
+ scores = clf.decision_function(anom_obj_df)
282
+ # Map each score to a 0-100 pseudo-probability; more negative means more anomalous
283
+ # Convert to a format compatible with the rest of the function: [[normal_prob, anomaly_prob], ...]
284
+ pred_prob_anom = np.zeros((len(scores), 2))
285
+ for i, score in enumerate(scores):
286
+ # Convert decision scores to probability-like values (0-100 scale)
287
+ # Lower scores = more anomalous
288
+ anomaly_prob = 100 * (1 / (1 + np.exp(score)))  # sigmoid maps the score into (0,1), scaled to 0-100
289
+ normal_prob = 100 - anomaly_prob
290
+ pred_prob_anom[i, 0] = round(normal_prob, 1) # normal probability
291
+ pred_prob_anom[i, 1] = round(anomaly_prob, 1) # anomaly probability
292
+
293
  num_anom_epochs = len(np.where(pred_prob_anom[:, 1] >= anom_thresh)[0])
294
 
295
  try:
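The `check_anom_and_plot` hunk above replaces pyod's `predict_proba` with a manual mapping from raw decision scores to 0-100 pseudo-probabilities. Below is a minimal vectorized sketch of the same mapping, assuming scikit-learn's sign convention (more negative = more anomalous). The function name and the `negative_is_anomalous` flag are illustrative, not part of reLAISS; note that pyod's `IForest.decision_function` returns higher scores for more anomalous points, so a pipeline ending in a pyod estimator would need the sign flipped.

```python
import numpy as np

def scores_to_pseudo_probs(scores, negative_is_anomalous=True):
    """Illustrative sketch: map raw decision scores to an (n, 2) array of
    [normal, anomaly] pseudo-probabilities on a 0-100 scale, mirroring the
    per-row loop in check_anom_and_plot.

    negative_is_anomalous=True matches scikit-learn's IsolationForest
    convention; pass False for pyod's IForest, whose decision_function
    returns higher scores for more anomalous points.
    """
    scores = np.asarray(scores, dtype=float)
    if not negative_is_anomalous:
        scores = -scores  # flip so that more negative = more anomalous
    anomaly = np.round(100.0 / (1.0 + np.exp(scores)), 1)  # vectorized sigmoid
    return np.column_stack([100.0 - anomaly, anomaly])

# Example: scores_to_pseudo_probs([-3.0, 0.0, 2.5])
# -> [[ 4.7, 95.3], [50.0, 50.0], [92.4, 7.6]]
```

The hunk also swaps `pickle.load` for `joblib.load` when reading the trained model; the matching save side is `joblib.dump(model, path)`.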
src/relaiss/features.py CHANGED
@@ -243,6 +243,8 @@ def build_dataset_bank(
243
  # if "ztf_object_id" is the index, move it to the first column
244
  if raw_df_bank.index.name == "ztf_object_id":
245
  raw_df_bank = raw_df_bank.reset_index()
 
 
246
 
247
  if theorized:
248
  raw_features = raw_lc_features
 
243
  # if "ztf_object_id" is the index, move it to the first column
244
  if raw_df_bank.index.name == "ztf_object_id":
245
  raw_df_bank = raw_df_bank.reset_index()
246
+ elif 'ZTFID' in raw_df_bank.columns.values:
247
+ raw_df_bank['ztf_object_id'] = raw_df_bank['ZTFID']
248
 
249
  if theorized:
250
  raw_features = raw_lc_features
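The new `elif` branch above lets `build_dataset_bank` accept legacy banks whose identifier column is named `ZTFID`. A toy sketch of its effect follows (the data values are illustrative); note that the branch copies the column rather than renaming it, so both `ZTFID` and `ztf_object_id` remain present:

```python
import pandas as pd

# Legacy-style bank with a 'ZTFID' identifier column (toy data)
raw_df_bank = pd.DataFrame({"ZTFID": ["ZTF21abbzjeq"], "g_peak_mag": [18.3]})

# Same aliasing logic as the diff above
if raw_df_bank.index.name == "ztf_object_id":
    raw_df_bank = raw_df_bank.reset_index()
elif "ZTFID" in raw_df_bank.columns.values:
    raw_df_bank["ztf_object_id"] = raw_df_bank["ZTFID"]

print(raw_df_bank.columns.tolist())
# ['ZTFID', 'g_peak_mag', 'ztf_object_id']
```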
tests/test_ad.py CHANGED
@@ -210,7 +210,7 @@ def test_anomaly_detection_basic(sample_preprocessed_df, tmp_path, setup_sfd_dat
210
  lc_features=client.lc_features,
211
  host_features=client.host_features,
212
  path_to_timeseries_folder=str(timeseries_dir),
213
- path_to_sfd_data_folder=str(setup_sfd_data),
214
  path_to_dataset_bank=client.bank_csv,
215
  path_to_models_directory=str(tmp_path),
216
  path_to_figure_directory=str(tmp_path / "figures"),
@@ -240,7 +240,7 @@ def test_anomaly_detection_with_host_swap(sample_preprocessed_df, tmp_path, setu
240
  lc_features=client.lc_features,
241
  host_features=client.host_features,
242
  path_to_timeseries_folder=str(timeseries_dir),
243
- path_to_sfd_data_folder=str(setup_sfd_data),
244
  path_to_dataset_bank=client.bank_csv,
245
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
246
  path_to_models_directory=str(tmp_path),
@@ -355,7 +355,7 @@ def test_anomaly_detection_basic(sample_preprocessed_df, tmp_path):
355
  lc_features=lc_features,
356
  host_features=host_features,
357
  path_to_timeseries_folder=str(timeseries_dir),
358
- path_to_sfd_data_folder=str(sfd_dir),
359
  path_to_dataset_bank=str(dataset_bank),
360
  path_to_models_directory=str(tmp_path),
361
  path_to_figure_directory=str(figures_dir),
@@ -536,7 +536,7 @@ def test_anomaly_detection_with_host_swap(sample_preprocessed_df, tmp_path):
536
  lc_features=lc_features,
537
  host_features=host_features,
538
  path_to_timeseries_folder=str(timeseries_dir),
539
- path_to_sfd_data_folder=str(sfd_dir),
540
  path_to_dataset_bank=str(dataset_bank),
541
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
542
  path_to_models_directory=str(tmp_path),
 
210
  lc_features=client.lc_features,
211
  host_features=client.host_features,
212
  path_to_timeseries_folder=str(timeseries_dir),
213
+ path_to_sfd_folder=str(setup_sfd_data),
214
  path_to_dataset_bank=client.bank_csv,
215
  path_to_models_directory=str(tmp_path),
216
  path_to_figure_directory=str(tmp_path / "figures"),
 
240
  lc_features=client.lc_features,
241
  host_features=client.host_features,
242
  path_to_timeseries_folder=str(timeseries_dir),
243
+ path_to_sfd_folder=str(setup_sfd_data),
244
  path_to_dataset_bank=client.bank_csv,
245
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
246
  path_to_models_directory=str(tmp_path),
 
355
  lc_features=lc_features,
356
  host_features=host_features,
357
  path_to_timeseries_folder=str(timeseries_dir),
358
+ path_to_sfd_folder=str(sfd_dir),
359
  path_to_dataset_bank=str(dataset_bank),
360
  path_to_models_directory=str(tmp_path),
361
  path_to_figure_directory=str(figures_dir),
 
536
  lc_features=lc_features,
537
  host_features=host_features,
538
  path_to_timeseries_folder=str(timeseries_dir),
539
+ path_to_sfd_folder=str(sfd_dir),
540
  path_to_dataset_bank=str(dataset_bank),
541
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
542
  path_to_models_directory=str(tmp_path),
tests/test_ad_host_swap_simple.py CHANGED
@@ -80,7 +80,7 @@ def test_ad_host_swap_simple(tmp_path):
80
  lc_features=lc_features,
81
  host_features=host_features,
82
  path_to_timeseries_folder=str(tmp_path),
83
- path_to_sfd_data_folder=None, # Ignored due to mocking
84
  path_to_dataset_bank=None, # Ignored due to mocking
85
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
86
  path_to_models_directory=str(model_dir),
 
80
  lc_features=lc_features,
81
  host_features=host_features,
82
  path_to_timeseries_folder=str(tmp_path),
83
+ path_to_sfd_folder=None, # Ignored due to mocking
84
  path_to_dataset_bank=None, # Ignored due to mocking
85
  host_ztf_id_to_swap_in="ZTF19aaaaaaa", # Swap in this host
86
  path_to_models_directory=str(model_dir),
tests/test_ad_mock.py CHANGED
@@ -137,7 +137,7 @@ def test_anomaly_detection_simplified(tmp_path):
137
  lc_features=lc_features,
138
  host_features=host_features,
139
  path_to_timeseries_folder=str(tmp_path),
140
- path_to_sfd_data_folder=None, # Not needed with our mocks
141
  path_to_dataset_bank=None, # Not needed with our mocks
142
  path_to_models_directory=str(model_dir),
143
  path_to_figure_directory=str(figure_dir),
 
137
  lc_features=lc_features,
138
  host_features=host_features,
139
  path_to_timeseries_folder=str(tmp_path),
140
+ path_to_sfd_folder=None, # Not needed with our mocks
141
  path_to_dataset_bank=None, # Not needed with our mocks
142
  path_to_models_directory=str(model_dir),
143
  path_to_figure_directory=str(figure_dir),
tests/test_ad_simple.py CHANGED
@@ -122,7 +122,7 @@ def test_anomaly_detection_simplified(tmp_path):
122
  lc_features=lc_features,
123
  host_features=host_features,
124
  path_to_timeseries_folder=str(tmp_path),
125
- path_to_sfd_data_folder=None, # This will be ignored due to our mocking
126
  path_to_dataset_bank=None, # This will be ignored due to our mocking
127
  path_to_models_directory=str(model_dir),
128
  path_to_figure_directory=str(figure_dir),
 
122
  lc_features=lc_features,
123
  host_features=host_features,
124
  path_to_timeseries_folder=str(tmp_path),
125
+ path_to_sfd_folder=None, # This will be ignored due to our mocking
126
  path_to_dataset_bank=None, # This will be ignored due to our mocking
127
  path_to_models_directory=str(model_dir),
128
  path_to_figure_directory=str(figure_dir),
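Every call site above reflects the rename of `path_to_sfd_data_folder` to `path_to_sfd_folder`, so downstream code pinned to the old keyword will now raise a `TypeError`. A hypothetical compatibility shim (not part of reLAISS) that forwards the old keyword to the new name might look like this:

```python
def forward_renamed_kwarg(func, /, **kwargs):
    # Hypothetical helper for downstream callers: translate the pre-1.0.0
    # keyword 'path_to_sfd_data_folder' to the new 'path_to_sfd_folder'.
    if "path_to_sfd_data_folder" in kwargs:
        kwargs["path_to_sfd_folder"] = kwargs.pop("path_to_sfd_data_folder")
    return func(**kwargs)

# Usage sketch:
# forward_renamed_kwarg(anomaly_detection,
#                       transient_ztf_id="ZTF21abbzjeq",
#                       path_to_sfd_data_folder="./sfddata-master")
```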