alexandergagliano committed on
Commit e41deef · 1 Parent(s): 271a1e6

restructure and prepare for pypi release

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,4 +1,6 @@
- # reLAISS
+ <p align="center">
+   <img src="https://github.com/evan-reynolds/re-laiss/blob/main/static/reLAISS_logo.png" style="width: 50%;" alt="reLAISS Logo">
+ </p>

  _A flexible library for similarity searches of supernovae and their host galaxies._

notebooks/.ipynb_checkpoints/01_relaiss_basics-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,124 @@
+
+ [project]
+ name = "relaiss"
+ license = {file = "LICENSE"}
+ readme = "README.md"
+ authors = [
+     { name = "Evan Reynolds", email = "" },
+     { name = "Alex Gagliano", email = "[email protected]" },
+     { name = "Ashley Villar", email = "[email protected]" },
+ ]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "License :: OSI Approved :: MIT License",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "Operating System :: OS Independent",
+     "Programming Language :: Python",
+ ]
+ dynamic = ["version"]
+ requires-python = ">=3.8"
+ dependencies = [
+     "numpy",
+     "astropy",
+     "matplotlib",
+     "pandas",
+     "scikit-learn",
+     "scipy",
+     "requests",
+     "sfdmap; python_version < '3.9'",
+     "sfdmap2; python_version >= '3.9'",
+ ]
+
+ [project.urls]
+ "Source Code" = "https://github.com/evan-reynolds/re-laiss/"
+
+ [project.optional-dependencies]
+ dev = [
+     "asv==0.6.4",  # Used to compute performance benchmarks
+     "jupyter",  # Clears output from Jupyter notebooks
+     "pre-commit",  # Used to run checks before finalizing a git commit
+     "pytest",
+     "pytest-cov",  # Used to report total code coverage
+     "ruff",  # Used for static linting of files
+ ]
+
+ [build-system]
+ requires = [
+     "setuptools>=62",  # Used to build and package the Python project
+     "setuptools_scm>=6.2",  # Gets release version from git; makes it available programmatically
+ ]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools]
+ include-package-data = true
+
+ [tool.setuptools.package-data]
+ "relaiss" = ["data/*csv*"]
+
+ [tool.setuptools_scm]
+ write_to = "src/relaiss/_version.py"
+ local_scheme = "no-local-version"
+ version_scheme = "no-guess-dev"
+
+ [tool.pytest.ini_options]
+ testpaths = [
+     "tests",
+ ]
+
+ [tool.black]
+ line-length = 110
+ target-version = ["py39"]
+
+ [tool.isort]
+ profile = "black"
+ line_length = 110
+
+ [tool.ruff]
+ line-length = 110
+ target-version = "py39"
+
+ [tool.ruff.lint]
+ select = [
+     # pycodestyle
+     "E",
+     "W",
+     # Pyflakes
+     "F",
+     # pep8-naming
+     "N",
+     # pyupgrade
+     "UP",
+     # flake8-bugbear
+     "B",
+     # flake8-simplify
+     "SIM",
+     # isort
+     "I",
+     # docstrings
+     "D101",
+     "D102",
+     "D103",
+     "D106",
+     "D206",
+     "D207",
+     "D208",
+     "D300",
+     "D417",
+     "D419",
+     # NumPy v2.0 compatibility
+     "NPY201",
+ ]
+
+ ignore = [
+     "UP006",  # Allow non-standard-library generics in type hints
+     "UP007",  # Allow Union in type hints
+     "SIM114",  # Allow if with same arms
+     "B028",  # Allow default warning level
+     "SIM117",  # Allow nested with
+     "UP015",  # Allow redundant open parameters
+     "UP028",  # Allow yield in for loop
+ ]
+
+ [tool.coverage.run]
+ omit = ["src/relaiss/_version.py"]
{code β†’ src/relaiss}/constants.py RENAMED
File without changes
{code → src/relaiss}/helper_func.py RENAMED
@@ -28,10 +28,24 @@ from scipy.stats import gamma, uniform
  from dust_extinction.parameter_averages import G23
  from astro_prost.associate import associate_sample

-
  @contextmanager
  def re_suppress_output():
-     """Temporarily silence stdout, stderr, warnings *and* all logging messages < CRITICAL."""
+     """Context-manager that silences *everything* except CRITICAL logs.
+
+     Temporarily redirects ``stdout``/``stderr`` to ``os.devnull``, ignores
+     warnings, and disables the root logger for messages < ``logging.CRITICAL``.
+     Restores all streams and the logger state on exit.
+
+     Yields
+     ------
+     None
+         Used only for the ``with`` context block.
+
+     Examples
+     --------
+     >>> with re_suppress_output():
+     ...     noisy_function()
+     """
      with open(os.devnull, "w") as devnull:
          old_stdout, old_stderr = sys.stdout, sys.stderr
          sys.stdout, sys.stderr = devnull, devnull
@@ -47,6 +61,19 @@ def re_suppress_output():


  def re_getTnsData(ztf_id):
+     """Fetch the TNS cross-match for a given ZTF object.
+
+     Parameters
+     ----------
+     ztf_id : str
+         ZTF object ID, e.g. ``"ZTF23abcxyz"``.
+
+     Returns
+     -------
+     tuple[str, str, float]
+         *(tns_name, tns_type, tns_redshift)*. Values default to
+         ``("No TNS", "---", -99)`` when no match or metadata are present.
+     """
      locus = antares_client.search.get_by_ztf_object_id(ztf_object_id=ztf_id)
      try:
          tns = locus.catalog_objects["tns_public_objects"][0]
@@ -64,6 +91,25 @@ def re_getExtinctionCorrectedMag(
      av_in_raw_df_bank,
      path_to_sfd_folder=None,
  ):
+     """Milky-Way extinction-corrected Kron magnitude for one passband.
+
+     Parameters
+     ----------
+     transient_row : pandas.Series
+         Row from the raw host-feature DataFrame.
+     band : {'g', 'r', 'i', 'z'}
+         Photometric filter to correct.
+     av_in_raw_df_bank : bool
+         If *True* use ``transient_row["A_V"]`` directly; otherwise compute
+         E(B−V) from the SFD dust map in *path_to_sfd_folder*.
+     path_to_sfd_folder : str | pathlib.Path | None, optional
+         Folder containing *SFDMap* dust files when A_V is not pre-computed.
+
+     Returns
+     -------
+     float
+         Extinction-corrected Kron magnitude.
+     """
      central_wv_filters = {"g": 4849.11, "r": 6201.20, "i": 7534.96, "z": 8674.20}
      MW_RV = 3.1
      ext = G23(Rv=MW_RV)
@@ -90,7 +136,36 @@ def re_build_dataset_bank(
      building_entire_df_bank=False,
      building_for_AD=False,
  ):
-
+     """Clean, impute, dust-correct, and engineer features for reLAISS.
+
+     Handles both archival and *theorized* light-curve inputs, performs KNN or
+     mean imputation, builds colour indices, propagates uncertainties, and
+     returns a ready-to-index DataFrame.
+
+     Parameters
+     ----------
+     raw_df_bank : pandas.DataFrame
+         Input light-curve + host-galaxy features (one or many rows).
+     av_in_raw_df_bank : bool
+         Whether A_V is already present in *raw_df_bank*.
+     path_to_sfd_folder : str | Path | None, optional
+         Directory with SFD dust maps (required if ``av_in_raw_df_bank=False``).
+     theorized : bool, default False
+         Set *True* when the input is a simulated/theoretical light curve that
+         lacks host features.
+     path_to_dataset_bank : str | Path | None, optional
+         Existing bank used to fit the imputer when not building the entire set.
+     building_entire_df_bank : bool, default False
+         If *True*, fit the imputer on *raw_df_bank* itself.
+     building_for_AD : bool, default False
+         Use simpler mean imputation and suppress verbose prints for
+         anomaly-detection pipelines.
+
+     Returns
+     -------
+     pandas.DataFrame
+         Fully hydrated feature table indexed by ``ztf_object_id``.
+     """
      raw_lc_features = constants.lc_features_const.copy()
      raw_host_features = constants.raw_host_features_const.copy()

@@ -232,6 +307,41 @@ def re_extract_lc_and_host_features(
      building_for_AD=False,
      swapped_host=False,
  ):
+     """End-to-end extraction of light-curve **and** host-galaxy features.
+
+     1. Pulls ZTF photometry from ANTARES (or uses a supplied theoretical LC).
+     2. Computes time-series features with *lightcurve_engineer*.
+     3. Associates the most probable PS1 host with PROST and appends raw host
+        features.
+     4. Dust-corrects, builds colours, imputes gaps, and writes an optional CSV.
+
+     Parameters
+     ----------
+     ztf_id : str
+         ZTF object identifier (ignored when *theorized_lightcurve_df* is given).
+     path_to_timeseries_folder : str | Path
+         Folder to cache per-object time-series CSVs.
+     path_to_sfd_data_folder : str | Path
+         Location of SFD dust maps.
+     theorized_lightcurve_df : pandas.DataFrame | None, optional
+         Pre-simulated LC in ANTARES column format (``ant_passband``, ``ant_mjd``,
+         ``ant_mag``, ``ant_magerr``).
+     show_lc : bool, default False
+         Plot the g/r light curves.
+     show_host : bool, default True
+         Print PS1 cut-out URL on successful host association.
+     store_csv : bool, default False
+         Write a timeseries CSV next to *path_to_timeseries_folder*.
+     building_for_AD : bool, default False
+         Quieter prints + mean imputation only.
+     swapped_host : bool, default False
+         Indicator used when re-running with an alternate host galaxy.
+
+     Returns
+     -------
+     pandas.DataFrame
+         Hydrated feature rows for every increasing-epoch subset of the LC.
+     """
      start_time = time.time()
      df_path = path_to_timeseries_folder

@@ -492,8 +602,19 @@ def re_extract_lc_and_host_features(


  def _ps1_list_filenames(ra_deg, dec_deg, flt):
-     """
-     Return the first stack FITS filename for (ra,dec) and *flt* or None.
+     """Return the first PS1 stacked-image FITS filename at (RA, Dec).
+
+     Parameters
+     ----------
+     ra_deg, dec_deg : float
+         ICRS coordinates in degrees.
+     flt : str
+         PS1 filter letter (``'g' 'r' 'i' 'z' 'y'``).
+
+     Returns
+     -------
+     str | None
+         Filename, e.g. ``'tess-skycell1001.012-i.fits'``, or *None* when absent.
      """
      url = (
          "https://ps1images.stsci.edu/cgi-bin/ps1filenames.py"
@@ -509,8 +630,26 @@ def _ps1_list_filenames(ra_deg, dec_deg, flt):


  def fetch_ps1_cutout(ra_deg, dec_deg, *, size_pix=100, flt="r"):
-     """
-     Grayscale cut-out (2-D float) in a single PS1 filter.
+     """Download a single-filter PS1 FITS cut-out around *(RA, Dec)*.
+
+     Parameters
+     ----------
+     ra_deg, dec_deg : float
+         ICRS coordinates (degrees).
+     size_pix : int, default 100
+         Width/height of the square cut-out in PS1 pixels.
+     flt : str, default 'r'
+         PS1 filter.
+
+     Returns
+     -------
+     numpy.ndarray
+         2-D float array (grayscale image).
+
+     Raises
+     ------
+     RuntimeError
+         When the target lies outside the PS1 footprint or no data exist.
      """
      fits_name = _ps1_list_filenames(ra_deg, dec_deg, flt)
      if fits_name is None:
@@ -537,9 +676,21 @@ def fetch_ps1_cutout(ra_deg, dec_deg, *, size_pix=100, flt="r"):


  def fetch_ps1_rgb_jpeg(ra_deg, dec_deg, *, size_pix=100):
-     """
-     Colour JPEG (H,W,3 uint8) using PS1 g/r/i stacks.
-     Falls back by *raising* RuntimeError when the server lacks colour data.
+     """Fetch an RGB JPEG cut-out (g/r/i) from PS1.
+
+     Raises ``RuntimeError`` when PS1 lacks colour data, so callers can fall
+     back to grayscale.
+
+     Parameters
+     ----------
+     ra_deg, dec_deg : float
+         ICRS coordinates (degrees).
+     size_pix : int, default 100
+         Square cut-out size in pixels.
+
+     Returns
+     -------
+     numpy.ndarray
+         ``(H, W, 3)`` uint8 array in RGB order.
      """
      url = (
          "https://ps1images.stsci.edu/cgi-bin/fitscut.cgi"
@@ -567,6 +718,33 @@ def re_plot_lightcurves(
      figure_path,
      save_figures=True,
  ):
+     """Stack reference + neighbour light curves in a single figure.
+
+     Parameters
+     ----------
+     primer_dict : dict
+         Metadata for the reference transient (e.g., TNS name/class/redshift).
+     plot_label : str
+         Text used for figure title and filename.
+     theorized_lightcurve_df : pandas.DataFrame | None
+         Optional simulated LC to plot as the reference.
+     neighbor_ztfids : list[str]
+         ZTF IDs of retrieved neighbours (<= 8 plotted).
+     ann_locus_l : list[antares_client.objects.Locus]
+         Corresponding ANTARES loci holding photometry.
+     ann_dists : list[float]
+         ANN distances for labeling.
+     tns_ann_names, tns_ann_classes, tns_ann_zs : list
+         TNS metadata for neighbours.
+     figure_path : str | Path
+         Root folder to save PNGs in ``lightcurves/``.
+     save_figures : bool, default True
+         Write the PNG to disk.
+
+     Returns
+     -------
+     None
+     """
      print("Making a plot of stacked lightcurves...")

      if primer_dict["lc_tns_z"] is None:
@@ -737,12 +915,35 @@ def re_plot_hosts(
      change_contrast=False,
      prefer_color=True,
  ):
-     """
-     Build 3×3 grids of PS1 thumbnails for each row in *df* and write a PDF.
-
-     Set *prefer_color=False* for r-band grayscale only. With *prefer_color=True*
-     (default) the code *tries* colour first and quietly falls back to grayscale
-     when colour isn't available.
+     """Create 3×3 PS1 thumbnail grids for candidate host galaxies.
+
+     Saves each page to a multi-page PDF and optionally shows colour cut-outs
+     when available.
+
+     Parameters
+     ----------
+     ztfid_ref : str
+         Reference transient ID (title use only).
+     plot_label : str
+         Basename for the output PDF.
+     df : pandas.DataFrame
+         Table with ``ZTFID``, ``HOST_RA``, ``HOST_DEC`` columns.
+     figure_path : str | Path
+         Destination directory for ``host_grids/*.pdf``.
+     ann_num : int
+         ANN neighbour index (used in filename).
+     save_pdf : bool, default True
+         Whether to write the PDF.
+     imsizepix : int, default 100
+         PS1 cut-out size in pixels.
+     change_contrast : bool, default False
+         Use a shallower stretch (93 %) for grayscale images.
+     prefer_color : bool, default True
+         Try RGB first, fall back to r-band grayscale.
+
+     Returns
+     -------
+     None
      """

      host_grid_path = figure_path + "/host_grids"
@@ -833,6 +1034,40 @@ def re_check_anom_and_plot(
      savefig,
      figure_path,
  ):
+     """Run anomaly-detector probabilities over a time-series and plot results.
+
+     Produces a two-panel figure: light curve with anomaly epoch marked, and
+     rolling anomaly/normal probabilities.
+
+     Parameters
+     ----------
+     clf : sklearn.base.ClassifierMixin
+         Trained binary classifier with ``predict_proba``.
+     input_ztf_id : str
+         ID of the object evaluated.
+     swapped_host_ztf_id : str | None
+         Alternate host ID (annotated in title).
+     input_spec_cls : str | None
+         Spectroscopic class label for title.
+     input_spec_z : float | str | None
+         Redshift for title.
+     anom_thresh : float
+         Probability (%) above which an epoch is flagged anomalous.
+     timeseries_df_full : pandas.DataFrame
+         Hydrated LC + host features, including ``obs_num`` and ``mjd_cutoff``.
+     timeseries_df_features_only : pandas.DataFrame
+         Same rows but feature columns only (classifier input).
+     ref_info : antares_client.objects.Locus
+         ANTARES locus for retrieving original photometry.
+     savefig : bool
+         Save the plot as ``AD/*.pdf`` inside *figure_path*.
+     figure_path : str | Path
+         Output directory.
+
+     Returns
+     -------
+     None
+     """
      anom_obj_df = timeseries_df_features_only

      pred_prob_anom = 100 * clf.predict_proba(anom_obj_df)
@@ -1000,6 +1235,30 @@ def re_get_timeseries_df(
      building_for_AD=False,
      swapped_host=False,
  ):
+     """Retrieve or build a fully-hydrated time-series feature DataFrame.
+
+     Checks disk cache; otherwise calls
+     ``re_extract_lc_and_host_features`` and optionally writes the CSV.
+
+     Parameters
+     ----------
+     ztf_id : str
+     path_to_timeseries_folder : str | Path
+     path_to_sfd_data_folder : str | Path
+     theorized_lightcurve_df : pandas.DataFrame | None
+         If provided, builds features for a simulated LC.
+     save_timeseries : bool, default False
+         Persist CSV to disk.
+     path_to_dataset_bank : str | Path | None
+         Reference bank for imputers.
+     building_for_AD : bool, default False
+     swapped_host : bool, default False
+
+     Returns
+     -------
+     pandas.DataFrame
+         Feature rows ready for indexing or AD.
+     """
      if theorized_lightcurve_df is not None:
          print("Extracting full lightcurve features for theorized lightcurve...")
          timeseries_df = re_extract_lc_and_host_features(
@@ -1048,6 +1307,24 @@ def re_get_timeseries_df(
  def create_re_laiss_features_dict(
      lc_feature_names, host_feature_names, lc_groups=4, host_groups=4
  ):
+     """Partition feature names into evenly-sized groups for weighting.
+
+     Parameters
+     ----------
+     lc_feature_names : list[str]
+         Names of light-curve features.
+     host_feature_names : list[str]
+         Names of host-galaxy features.
+     lc_groups : int, default 4
+         Number of LC groups in the output dict.
+     host_groups : int, default 4
+         Number of host groups in the output dict.
+
+     Returns
+     -------
+     dict[str, list[str]]
+         ``{'lc_group_1': [...], 'host_group_1': [...], ...}``
+     """
      re_laiss_features_dict = {}

      # Split light curve features into evenly sized chunks
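
A minimal usage sketch of the helpers documented above (hedged: the import path `relaiss.helper_func` follows the new `src/relaiss` layout, and the calls require network access to ANTARES and the PS1 cut-out service):

```python
from relaiss.helper_func import (
    re_getTnsData,
    re_suppress_output,
    fetch_ps1_cutout,
)

# Silence noisy third-party logging while querying ANTARES for the
# TNS cross-match; returns ("No TNS", "---", -99) when unmatched.
with re_suppress_output():
    tns_name, tns_type, tns_z = re_getTnsData("ZTF21abbzjeq")
print(tns_name, tns_type, tns_z)

# Grayscale r-band PS1 cut-out at an (RA, Dec) inside the PS1 footprint;
# raises RuntimeError outside the footprint.
img = fetch_ps1_cutout(150.0, 20.0, size_pix=100, flt="r")
print(img.shape)  # (100, 100)
```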
{code → src/relaiss}/lightcurve_engineer.py RENAMED
@@ -6,11 +6,25 @@ from sfdmap2 import sfdmap
  from dust_extinction.parameter_averages import G23
  from numpy.lib.stride_tricks import sliding_window_view
  import warnings
+ from sklearn.cluster import DBSCAN

  warnings.filterwarnings("ignore", category=RuntimeWarning)

-
  def local_curvature(times, mags):
+     """Median second derivative (curvature) of a light-curve segment.
+
+     Parameters
+     ----------
+     times : array-like
+         Strictly increasing observation times (days).
+     mags : array-like
+         Corresponding magnitudes.
+
+     Returns
+     -------
+     float
+         Median curvature in mag day⁻²; ``np.nan`` if fewer than three points.
+     """
      if len(times) < 3:
          return np.nan
      curvatures = []
@@ -31,6 +45,14 @@ m = sfdmap.SFDMap()
  class SupernovaFeatureExtractor:
      @staticmethod
      def describe_features():
+         """Dictionary mapping feature names → human-readable descriptions.
+
+         Returns
+         -------
+         dict[str, str]
+             Keys follow the column names produced by
+             :meth:`SupernovaFeatureExtractor.extract_features`.
+         """
          return {
              "t0": "Time zero-point for light curve normalization",
              "g_peak_mag": "Minimum magnitude (brightest point) in g band",
@@ -77,6 +99,22 @@ class SupernovaFeatureExtractor:
      def __init__(
          self, time_g, mag_g, err_g, time_r, mag_r, err_r, ZTFID=None, ra=None, dec=None
      ):
+         """Create a feature extractor for g/r light curves.
+
+         Times are zero-pointed to the earliest observation; optional Milky-Way
+         extinction is applied when *ra/dec* are supplied.
+
+         Parameters
+         ----------
+         time_g, mag_g, err_g : array-like
+             g-band MJD, magnitude and 1-σ uncertainty.
+         time_r, mag_r, err_r : array-like
+             r-band MJD, magnitude and 1-σ uncertainty.
+         ZTFID : str | None, optional
+             Identifier used in warnings and output tables.
+         ra, dec : float | None, optional
+             ICRS coordinates (deg) for dust-extinction correction.
+         """
          if ZTFID:
              self.ZTFID = ZTFID
          else:
@@ -107,6 +145,12 @@ class SupernovaFeatureExtractor:
          self._preprocess()

      def _preprocess(self, min_cluster_size=2):
+         """Sort, de-duplicate, and DBSCAN-filter out isolated epochs.
+
+         Removes cluster labels with fewer than *min_cluster_size* points and
+         re-normalises times so that ``t=0`` corresponds to the earliest good
+         observation in either band.
+         """
          for band_name in ["g", "r"]:
              band = getattr(self, band_name)
              idx = np.argsort(band["time"])
@@ -146,8 +190,10 @@ class SupernovaFeatureExtractor:
          self.time_offset += new_time_offset

      def _select_main_cluster(self, time, mag, min_samples=3, eps=20):
-         from sklearn.cluster import DBSCAN
+         """Return a boolean mask selecting the dominant DBSCAN time cluster.

+         The cluster with the brightest peak and tightest span wins the tie-break.
+         """
          if len(time) < min_samples:
              return np.ones_like(time, dtype=bool)
          time_reshaped = np.array(time).reshape(-1, 1)
@@ -171,6 +217,13 @@ class SupernovaFeatureExtractor:
          return labels == best_label

      def _flag_isolated_points(time, max_gap_factor=5):
+         """Identify photometric points that are isolated by large temporal gaps.
+
+         Returns
+         -------
+         numpy.ndarray[bool]
+             True for epochs flanked by gaps > *max_gap_factor* × median cadence.
+         """
          time = np.sort(time)
          dt = np.diff(time)

@@ -188,6 +241,22 @@ class SupernovaFeatureExtractor:
          return isolated

      def _core_stats(self, band):
+         """Peak, rise/decline and half-flux duration for one band.
+
+         Parameters
+         ----------
+         band : dict
+             ``{'time','mag'}`` arrays for a single filter.
+
+         Returns
+         -------
+         tuple
+             *(peak_mag, peak_time, rise_time, decline_time, duration_above_half)*
+
+         Notes
+         -----
+         All values are ``np.nan`` if <3 points or total peak-to-peak amplitude <0.2 mag.
+         """
          t, m = band["time"], band["mag"]
          mask = np.isfinite(t) & np.isfinite(m) & ~np.isnan(m)
          t, m = t[mask], m[mask]
@@ -216,6 +285,13 @@ class SupernovaFeatureExtractor:
          return peak_mag, peak_time, rise_time, decline_time, duration

      def _variability_stats(self, band):
+         """Amplitude, skewness, and 2-σ outlier rate of a magnitude series.
+
+         Returns
+         -------
+         tuple
+             *(amplitude, skewness, fraction_beyond_2σ)*
+         """
          mag = band["mag"]
          amp = np.max(mag) - np.min(mag)
          std = np.std(mag)
@@ -225,6 +301,14 @@ class SupernovaFeatureExtractor:
          return amp, skew, beyond_2

      def _color_features(self):
+         """Compute mean g–r colour, g–r at g-band peak, and average colour slope.
+
+         Returns
+         -------
+         tuple
+             ``(mean_colour, colour_at_g_peak, mean_dcolour_dt)``
+             or ``None`` when bands lack overlap.
+         """
          if len(self.g["time"]) < 2 or len(self.r["time"]) < 2:
              # print("Warning: Not enough data in g or r band to compute color features.")
              return None
@@ -261,6 +345,18 @@ class SupernovaFeatureExtractor:
          return np.mean(color), gr_at_gpeak, mean_rate

      def _rolling_variance(self, band, window_size=5):
+         """Max & mean variance in sliding windows over an interpolated LC.
+
+         Parameters
+         ----------
+         window_size : int, default 5
+             Number of interpolated samples per window.
+
+         Returns
+         -------
+         tuple
+             *(max_var, mean_var)*
+         """
          def dedup(t, m):
              _, idx = np.unique(t, return_index=True)
              return t[idx], m[idx]
@@ -275,6 +371,13 @@ class SupernovaFeatureExtractor:
          return np.max(rolling_vars), np.mean(rolling_vars)

      def _peak_structure(self, band):
+         """Secondary-peak diagnostics using SciPy ``find_peaks``.
+
+         Returns
+         -------
+         tuple
+             *(n_peaks, Δt, Δmag, prominence₂, width₂)* with NaNs when <2 peaks.
+         """
          if np.ptp(band["mag"]) < 0.5:
              # print("Warning: Insufficient variability to identify peak structure.")
              return 0, np.nan, np.nan, np.nan, np.nan
@@ -299,6 +402,13 @@ class SupernovaFeatureExtractor:
          return n_peaks, dt, dmag, prominence_second, width_second

      def _local_curvature_features(self, band, window_days=20):
+         """Median curvature on the rise and decline within ±*window_days* of peak.
+
+         Returns
+         -------
+         tuple
+             ``(rise_curvature, decline_curvature)``
+         """
          t, m = band["time"], band["mag"]
          mask = np.isfinite(t) & np.isfinite(m)
          t, m = t[mask], m[mask]
@@ -330,6 +440,22 @@ class SupernovaFeatureExtractor:
          return rise_curv, decline_curv

      def extract_features(self, return_uncertainty=False, n_trials=20):
+         """Generate the full reLAISS feature vector for the supplied LC.
+
+         Parameters
+         ----------
+         return_uncertainty : bool, default False
+             If True, performs *n_trials* MC perturbations and appends 1-σ errors
+             (columns with ``_err`` suffix).
+         n_trials : int, default 20
+             Number of Monte-Carlo resamples when *return_uncertainty* is True.
+
+         Returns
+         -------
+         pandas.DataFrame | None
+             Single-row feature table (with optional error columns) or *None* when
+             either band lacks data after pre-processing.
+         """
          if len(self.g["time"]) == 0 or len(self.r["time"]) == 0:
              # print(
              #     f"Warning: No data left in g or r band after filtering for object: {self.ZTFID}. Skipping."
{code → src/relaiss}/relaiss_func.py RENAMED
@@ -13,7 +13,6 @@ from kneed import KneeLocator
  from pyod.models.iforest import IForest
  from statsmodels import robust

-
  def re_build_indexed_sample(
      dataset_bank_path,
      lc_features=[],
@@ -26,6 +25,42 @@ def re_build_indexed_sample(
      force_recreation_of_index=False,
      weight_lc_feats_factor=1,
  ):
+     """Create (or load) an ANNOY index over a reference feature bank.
+
+     Parameters
+     ----------
+     dataset_bank_path : str | Path
+         CSV produced by ``re_build_dataset_bank``; first column must be
+         ``ztf_object_id``.
+     lc_features, host_features : list[str]
+         Feature columns to include in the index.
+         Provide one or both lists.
+     use_pca : bool, default False
+         Apply PCA before indexing.
+     n_components : int | None
+         Dimensionality of PCA space; ignored if *use_pca=False*.
+     num_trees : int, default 1000
+         Number of random projection trees for ANNOY.
+     path_to_index_directory : str | Path, default ""
+         Folder for ``*.ann`` plus ``*.npy`` support files.
+     save : bool, default True
+         Persist index and numpy arrays.
+     force_recreation_of_index : bool, default False
+         Rebuild even when an index file already exists.
+     weight_lc_feats_factor : float, default 1
+         Scalar >1 up-weights LC columns relative to host features
+         (ignored if *use_pca=True*).
+
+     Returns
+     -------
+     str
+         Stem path (without ``.ann`` extension) of the built/loaded index.
+
+     Raises
+     ------
+     ValueError
+         When feature inputs are invalid or required columns are missing.
+     """
      df_bank = pd.read_csv(dataset_bank_path)

      # Confirm that the first column is the ZTF ID, and index by ZTF ID
@@ -138,7 +173,38 @@ def re_LAISS_primer(
      host_features=[],
      num_sims=10,
  ):
-
+     """Assemble input feature vectors (and MC replicas) for a query object.
+
+     Combines LC + host features (optionally swapping in a different host) and
+     returns a dict used later by NN and AD stages.
+
+     Parameters
+     ----------
+     lc_ztf_id : str | None
+         ZTF ID of the transient to query. Mutually exclusive with
+         *theorized_lightcurve_df*.
+     theorized_lightcurve_df : pandas.DataFrame | None
+         Pre-computed ANTARES-style LC for a theoretical model.
+     host_ztf_id : str | None
+         If given, replace the query object's host features with those of this
+         transient.
+     dataset_bank_path, path_to_timeseries_folder, path_to_sfd_data_folder : str | Path
+         Locations for cached data.
+     lc_features, host_features : list[str]
+         Names of columns to extract.
+     num_sims : int, default 10
+         Number of Monte-Carlo perturbations for uncertainty propagation.
+
+     Returns
+     -------
+     dict
+         Primer dictionary containing feature arrays, metadata, and MC sims.
+
+     Raises
+     ------
+     ValueError
+         On inconsistent inputs or missing data.
+     """
      feature_names = lc_features + host_features
      if lc_ztf_id is not None and theorized_lightcurve_df is not None:
          print(
@@ -349,6 +415,34 @@ def re_LAISS_nearest_neighbors(
      save_figures=True,
      path_to_figure_directory="../figures",
  ):
+     """Query the ANNOY index and plot nearest-neighbor diagnostics.
+
+     Parameters
+     ----------
+     primer_dict : dict
+         Output from :func:`re_LAISS_primer`.
+     annoy_index_file_stem : str
+         Stem path returned by :func:`re_build_indexed_sample`.
+     use_pca, num_pca_components : see above
+     n : int, default 8
+         Number of neighbours to return.
+     suggest_neighbor_num : bool, default False
+         If True, plots the distance elbow and exits early.
+     max_neighbor_dist : float | None
+         Optional cut on L1 distance.
+     search_k : int, default 1000
+         ANNOY *search_k* parameter.
+     weight_lc_feats_factor : float, default 1
+         Same interpretation as in ``re_build_indexed_sample``.
+     save_figures : bool, default True
+         Write LC + host plots and distance-elbow PNGs.
+     path_to_figure_directory : str | Path
+
+     Returns
+     -------
+     pandas.DataFrame | None
+         Table summarising neighbours (or *None* if *suggest_neighbor_num=True*).
+     """
      start_time = time.time()
      index_file = annoy_index_file_stem + ".ann"

@@ -676,6 +770,23 @@ def re_train_AD_model(
      max_samples=1024,
      force_retrain=False,
  ):
+     """Train or load an Isolation-Forest anomaly-detection model.
+
+     Parameters
+     ----------
+     lc_features, host_features : list[str]
+         Feature columns used by the model.
+     path_to_dataset_bank : str | Path
+     path_to_models_directory : str | Path
+     n_estimators, contamination, max_samples : see *pyod.models.IForest*
+     force_retrain : bool, default False
+         Ignore cached model and retrain.
+
+     Returns
+     -------
+     str
+         Filesystem path to the saved ``.pkl`` pipeline.
+     """
      feature_names = lc_features + host_features
      df_bank_path = path_to_dataset_bank
      model_dir = path_to_models_directory
@@ -742,6 +853,28 @@ def re_anomaly_detection(
      max_samples=1024,
      force_retrain=False,
  ):
+     """Run anomaly detection for a single transient (with optional host swap).
+
+     Generates an AD probability plot and calls
+     :func:`re_check_anom_and_plot`.
+
+     Parameters
+     ----------
+     transient_ztf_id : str
+         Target object ID.
+     host_ztf_id_to_swap_in : str | None
+         Replace host features before scoring.
+     lc_features, host_features : list[str]
+     path_* : folders for intermediates, models, and figures.
+     save_figures : bool, default True
+     n_estimators, contamination, max_samples : Isolation-Forest params.
+     force_retrain : bool, default False
+         Pass-through to :func:`re_train_AD_model`.
+
+     Returns
+     -------
+     None
+     """
      print("Running Anomaly Detection:\n")

      # Train the model (if necessary)
@@ -842,7 +975,32 @@ def re_LAISS(
      force_AD_retrain=False,  # Retrains and saves AD model even if it already exists
      save_figures=True,  # Saves all figures while running LAISS
  ):
-
+     """High-level convenience wrapper: build index → NN search → AD.
+
+     Combines the *primer*, *nearest-neighbours*, and *anomaly-detection*
+     pipelines with many toggles for experimentation.
+
+     Parameters
+     ----------
+     transient_ztf_id : str | None
+     theorized_lightcurve_df : pandas.DataFrame | None
+     host_ztf_id_to_swap_in : str | None
+     lc_feature_names, host_feature_names : list[str]
+     neighbors : int
+         Target neighbour count.
+     suggest_neighbor_num : bool
+         Show elbow plot instead of full NN run.
+     run_NN, run_AD : bool
+         Enable/disable each pipeline stage.
+     *Other params*
+         See lower-level helpers for details.
+
+     Returns
+     -------
+     (pandas.DataFrame | None, dict | None)
+         Neighbours table and primer dict when NN stage executed; otherwise
+         *None*.
+     """
      if run_NN or suggest_neighbor_num:
          # build ANNOY indexed sample from dataset bank
          index_stem_name_with_path = re_build_indexed_sample(
@@ -915,7 +1073,6 @@ def re_LAISS(
          return


- # Note: old corner plots in the figure directory will be overwritten!
  def re_corner_plot(
      neighbors_df,  # from reLAISS nearest neighbors
      primer_dict,  # from reLAISS nearest neighbors
@@ -924,6 +1081,24 @@ def re_corner_plot(
      path_to_figure_directory="../figures",
      save_plots=True,
  ):
+     """Corner-plot visualisation of feature distributions vs. neighbours.
+
+     Parameters
+     ----------
+     neighbors_df : pandas.DataFrame
+         Output from :func:`re_LAISS_nearest_neighbors`.
+     primer_dict : dict
+         Output from :func:`re_LAISS_primer`.
+     path_to_dataset_bank : str | Path
+     remove_outliers_bool : bool, default True
+         Apply robust MAD clipping before plotting.
+     save_plots : bool, default True
+         Write PNGs to ``corner_plots/``.
+
+     Returns
+     -------
+     None
+     """
      if primer_dict is None:
          raise ValueError(
              "primer_dict is None. Try running NN search with reLAISS again."
static/reLAISS_logo.png ADDED

Git LFS Details

  • SHA256: a24af358ae92e2f8950d424fafa441ffbdc337b9c1b0818c6b0f3b3a717b5fd1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
tests/test_search.py ADDED
@@ -0,0 +1,30 @@
+ import pytest
+ import pandas as pd
+ import numpy as np
+
+ import relaiss as rl
+
+
+ @pytest.fixture(scope="session")
+ def relaiss_client():
+     """Load the cached reference client once for the whole test session."""
+     try:
+         client = rl.load_reference()
+     except FileNotFoundError as err:
+         pytest.skip(f"Reference index unavailable – {err}")
+     return client
+
+
+ def test_load_reference_singleton(relaiss_client):
+     c1 = rl.load_reference()
+     c2 = rl.load_reference()
+     assert c1 is c2, "load_reference should cache the client instance"
+
+
+ def test_find_neighbors_dataframe(relaiss_client):
+     df = rl.find_neighbors("ZTF21abbzjeq", k=5)  # arbitrary real ZTF ID
+     assert isinstance(df, pd.DataFrame)
+     assert list(df.columns) == ["ztfid", "distance"]
+     assert len(df) == 5
+     # Distances should be non-decreasing
+     assert np.all(df["distance"].values[:-1] <= df["distance"].values[1:])