alexandergagliano commited on
Commit
0a317b5
·
1 Parent(s): b622cf4

update unit tests

Browse files
.github/workflows/ci.yml CHANGED
@@ -28,6 +28,11 @@ jobs:
28
  python -m pip install --upgrade pip
29
  python -m pip install pytest
30
  if [ -f pyproject.toml ]; then pip install -e .; fi
 
 
 
 
 
31
  - name: Test with pytest
32
  run: |
33
  python -m pytest tests/test_utils.py -v --ci
 
28
  python -m pip install --upgrade pip
29
  python -m pip install pytest
30
  if [ -f pyproject.toml ]; then pip install -e .; fi
31
+ - name: Install python bindings for ngt
32
+ run: |
33
+ git clone https://github.com/yahoojapan/NGT.git
34
+ cd NGT/python
35
+ pip3 install .
36
  - name: Test with pytest
37
  run: |
38
  python -m pytest tests/test_utils.py -v --ci
.github/workflows/smoke-test.yml CHANGED
@@ -35,11 +35,15 @@ jobs:
35
  - name: Install dependencies
36
  run: |
37
  sudo apt-get update
38
- python -m pip install --upgrade pip
39
  pip install -e .[dev]
40
  - name: List dependencies
41
  run: |
42
  pip list
 
 
 
 
 
43
  - name: Run unit tests with pytest / pytest-copie
44
  run: |
45
  python -m pytest
 
35
  - name: Install dependencies
36
  run: |
37
  sudo apt-get update
 
38
  pip install -e .[dev]
39
  - name: List dependencies
40
  run: |
41
  pip list
42
+ - name: Install python bindings for ngt
43
+ run: |
44
+ git clone https://github.com/yahoojapan/NGT.git
45
+ cd NGT/python
46
+ pip3 install .
47
  - name: Run unit tests with pytest / pytest-copie
48
  run: |
49
  python -m pytest
README.md CHANGED
@@ -12,8 +12,13 @@ reLAISS lets you retrieve nearest‑neighbour supernovae (or spot outliers) by c
12
 
13
  # Install
14
 
15
- In a fresh conda environment, run `pip install relaiss`. Then, install the python bindings for ngt from source via the instructions [here](https://github.com/yahoojapan/NGT/blob/main/python/README.md).
16
 
 
 
 
 
 
17
 
18
  # Code Demo
19
  ```
 
12
 
13
  # Install
14
 
15
+ In a fresh conda environment, run `pip install relaiss`. After installing, you must install the python bindings for `ngt` from source:
16
 
17
+ ```
18
+ git clone https://github.com/yahoojapan/NGT.git
19
+ cd NGT/python
20
+ pip3 install .
21
+ ```
22
 
23
  # Code Demo
24
  ```
examples/advanced_usage.py CHANGED
@@ -6,7 +6,7 @@ This script demonstrates advanced features of reLAISS including:
6
  - Swapping host galaxies
7
  - Using PCA for dimensionality reduction
8
  - Setting maximum neighbor distances
9
- - Tweaking ANNOY parameters
10
  - Making corner plots
11
  - Advanced anomaly detection with parameter tuning
12
  - Host swapping in anomaly detection
@@ -106,8 +106,8 @@ def main():
106
  print("\nNearest neighbors within distance threshold:")
107
  print(neighbors_df)
108
 
109
- # Example 5: Tweaking ANNOY parameters
110
- print("\nExample 5: Tweaking ANNOY parameters")
111
  neighbors_df = client.find_neighbors(
112
  ztf_object_id='ZTF21abbzjeq',
113
  n=5,
@@ -116,7 +116,7 @@ def main():
116
  save_figures=True,
117
  path_to_figure_directory='./figures'
118
  )
119
- print("\nNearest neighbors with tweaked ANNOY parameters:")
120
  print(neighbors_df)
121
 
122
  # Example 6: Making corner plots
 
6
  - Swapping host galaxies
7
  - Using PCA for dimensionality reduction
8
  - Setting maximum neighbor distances
9
+ - Tweaking NGT parameters
10
  - Making corner plots
11
  - Advanced anomaly detection with parameter tuning
12
  - Host swapping in anomaly detection
 
106
  print("\nNearest neighbors within distance threshold:")
107
  print(neighbors_df)
108
 
109
+ # Example 5: Tweaking NGT parameters
110
+ print("\nExample 5: Tweaking NGT parameters")
111
  neighbors_df = client.find_neighbors(
112
  ztf_object_id='ZTF21abbzjeq',
113
  n=5,
 
116
  save_figures=True,
117
  path_to_figure_directory='./figures'
118
  )
119
+ print("\nNearest neighbors with tweaked NGT parameters:")
120
  print(neighbors_df)
121
 
122
  # Example 6: Making corner plots
src/relaiss/anomaly_config.py CHANGED
@@ -11,3 +11,5 @@ TRAINING_SAMPLE_SIZE = 5000 # Maximum training features to store
11
  # Training Data Quality Parameters
12
  HIGH_NAN_WARNING_THRESHOLD = 20 # Percentage of NaN samples to trigger warning
13
  VERY_HIGH_NAN_ERROR_THRESHOLD = 50 # Percentage of NaN samples to trigger error
 
 
 
11
  # Training Data Quality Parameters
12
  HIGH_NAN_WARNING_THRESHOLD = 20 # Percentage of NaN samples to trigger warning
13
  VERY_HIGH_NAN_ERROR_THRESHOLD = 50 # Percentage of NaN samples to trigger error
14
+
15
+ MODEL_FILENAME_TEMPLATE = "kNN_scaler_lc={num_lc}_host={num_host}.pkl"
src/relaiss/relaiss.py CHANGED
@@ -12,10 +12,7 @@ from .fetch import get_TNS_data
12
  from .plotting import plot_lightcurves, plot_hosts
13
  import os
14
  from kneed import KneeLocator
15
- try:
16
- import ngtpy as ngt
17
- except:
18
- import ngt
19
  import numpy as np
20
  import pandas as pd
21
  from sklearn.decomposition import PCA
 
12
  from .plotting import plot_lightcurves, plot_hosts
13
  import os
14
  from kneed import KneeLocator
15
+ import ngtpy as ngt
 
 
 
16
  import numpy as np
17
  import pandas as pd
18
  from sklearn.decomposition import PCA
tests/conftest.py CHANGED
@@ -7,7 +7,7 @@ from unittest.mock import patch, MagicMock
7
  import tempfile
8
  import shutil
9
  import astropy.units as u
10
- from .fixtures.search import build_test_annoy_index
11
 
12
  # Get the path to the test fixtures
13
  FIXTURES_DIR = Path(__file__).parent / "fixtures"
@@ -192,15 +192,15 @@ def mock_extinction_all():
192
  yield mock_map, mock_g23
193
 
194
  @pytest.fixture
195
- def test_annoy_index(dataset_bank_path):
196
- """Create a test Annoy index for testing."""
197
  # Get the dataset bank file
198
  df = pd.read_csv(dataset_bank_path)
199
 
200
- # Build a temporary Annoy index
201
  with tempfile.TemporaryDirectory() as tmpdir:
202
- index_path = Path(tmpdir) / "annoy_index"
203
- index, index_path, object_ids = build_test_annoy_index(
204
  test_databank_path=dataset_bank_path,
205
  lc_features=['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time'],
206
  host_features=['host_ra', 'host_dec']
@@ -308,4 +308,4 @@ def pytest_addoption(parser):
308
  """Add command line options."""
309
  parser.addoption(
310
  "--ci", action="store_true", default=False, help="Run in CI mode (skip tests requiring real data)"
311
- )
 
7
  import tempfile
8
  import shutil
9
  import astropy.units as u
10
+ from .fixtures.search import build_test_ngt_index
11
 
12
  # Get the path to the test fixtures
13
  FIXTURES_DIR = Path(__file__).parent / "fixtures"
 
192
  yield mock_map, mock_g23
193
 
194
  @pytest.fixture
195
+ def test_ngt_index(dataset_bank_path):
196
+ """Create a test NGT index for testing."""
197
  # Get the dataset bank file
198
  df = pd.read_csv(dataset_bank_path)
199
 
200
+ # Build a temporary NGT index
201
  with tempfile.TemporaryDirectory() as tmpdir:
202
+ index_path = Path(tmpdir) / "ngt_index"
203
+ index, index_path, object_ids = build_test_ngt_index(
204
  test_databank_path=dataset_bank_path,
205
  lc_features=['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time'],
206
  host_features=['host_ra', 'host_dec']
 
308
  """Add command line options."""
309
  parser.addoption(
310
  "--ci", action="store_true", default=False, help="Run in CI mode (skip tests requiring real data)"
311
+ )
tests/fixtures/search.py CHANGED
@@ -1,13 +1,13 @@
1
  import numpy as np
2
  import pandas as pd
3
- import annoy
4
  import os
5
  from pathlib import Path
6
  import tempfile
7
 
8
- # Test fixture for building an Annoy index for testing
9
- def build_test_annoy_index(test_databank_path, lc_features=None, host_features=None):
10
- """Build an Annoy index from the test dataset bank for testing.
11
 
12
  Parameters
13
  ----------
@@ -21,7 +21,7 @@ def build_test_annoy_index(test_databank_path, lc_features=None, host_features=N
21
  Returns
22
  -------
23
  tuple
24
- (index, index_path, object_ids) - the annoy index, path to temp file, and array of object ids
25
  """
26
  # Default features if none provided
27
  if lc_features is None:
@@ -42,31 +42,31 @@ def build_test_annoy_index(test_databank_path, lc_features=None, host_features=N
42
  feat_arr = np.array(df_features)
43
  feat_arr_scaled = (feat_arr - np.mean(feat_arr, axis=0)) / np.std(feat_arr, axis=0)
44
 
45
- # Create annoy index
 
 
 
 
46
  index_dim = feat_arr.shape[1]
47
- index = annoy.AnnoyIndex(index_dim, "manhattan")
 
48
 
49
  # Add items to index
50
  for i, obj_id in enumerate(df_bank.index):
51
- index.add_item(i, feat_arr_scaled[i])
52
 
53
- # Build index with 10 trees (fewer for tests to be faster)
54
- index.build(10)
55
-
56
- # Create temp file to save index
57
- temp_dir = tempfile.mkdtemp()
58
- index_path = os.path.join(temp_dir, "test_index.ann")
59
- index.save(index_path)
60
 
61
  return index, index_path, np.array(df_bank.index)
62
 
63
  def find_neighbors(index, idx_arr, query_vector, n=5):
64
- """Find neighbors using the test Annoy index.
65
 
66
  Parameters
67
  ----------
68
- index : annoy.AnnoyIndex
69
- The Annoy index to query
70
  idx_arr : numpy.ndarray
71
  Array of object IDs
72
  query_vector : numpy.ndarray
@@ -80,11 +80,10 @@ def find_neighbors(index, idx_arr, query_vector, n=5):
80
  (ids, distances) - arrays of neighbor IDs and distances
81
  """
82
  # Query the index
83
- neighbor_indices, distances = index.get_nns_by_vector(
84
- query_vector, n, include_distances=True
85
- )
86
 
87
  # Get ZTF IDs of neighbors
88
  neighbor_ids = idx_arr[neighbor_indices]
89
 
90
- return neighbor_ids, distances
 
1
  import numpy as np
2
  import pandas as pd
3
+ import ngtpy as ngt
4
  import os
5
  from pathlib import Path
6
  import tempfile
7
 
8
+ # Test fixture for building an NGT index for testing
9
+ def build_test_ngt_index(test_databank_path, lc_features=None, host_features=None):
10
+ """Build an NGT index from the test dataset bank for testing.
11
 
12
  Parameters
13
  ----------
 
21
  Returns
22
  -------
23
  tuple
24
+ (index, index_path, object_ids) - the ngt index, path to temp file, and array of object ids
25
  """
26
  # Default features if none provided
27
  if lc_features is None:
 
42
  feat_arr = np.array(df_features)
43
  feat_arr_scaled = (feat_arr - np.mean(feat_arr, axis=0)) / np.std(feat_arr, axis=0)
44
 
45
+ # Create temp directory for NGT index
46
+ temp_dir = tempfile.mkdtemp()
47
+ index_path = os.path.join(temp_dir, "test_index.ngt")
48
+
49
+ # Create NGT index
50
  index_dim = feat_arr.shape[1]
51
+ ngt.create(index_path.encode(), index_dim, distance_type="L2")
52
+ index = ngt.Index(index_path.encode())
53
 
54
  # Add items to index
55
  for i, obj_id in enumerate(df_bank.index):
56
+ index.insert(feat_arr_scaled[i].astype(np.float32))
57
 
58
+ # Build index
59
+ index.build_index()
 
 
 
 
 
60
 
61
  return index, index_path, np.array(df_bank.index)
62
 
63
  def find_neighbors(index, idx_arr, query_vector, n=5):
64
+ """Find neighbors using the test NGT index.
65
 
66
  Parameters
67
  ----------
68
+ index : ngt.Index
69
+ The NGT index to query
70
  idx_arr : numpy.ndarray
71
  Array of object IDs
72
  query_vector : numpy.ndarray
 
80
  (ids, distances) - arrays of neighbor IDs and distances
81
  """
82
  # Query the index
83
+ res = index.search(query_vector.astype(np.float32), n)
84
+ neighbor_indices, distances = zip(*res)
 
85
 
86
  # Get ZTF IDs of neighbors
87
  neighbor_ids = idx_arr[neighbor_indices]
88
 
89
+ return neighbor_ids, distances
tests/test_search.py CHANGED
@@ -52,9 +52,9 @@ def test_find_neighbors_invalid_input():
52
  with pytest.raises(ValueError):
53
  client.find_neighbors(ztf_object_id="ZTF21abbzjeq", n=-1)
54
 
55
- def test_annoy_search(test_annoy_index, dataset_bank_path):
56
- """Test that the Annoy index works as expected for neighbor search."""
57
- index, index_path, object_ids = test_annoy_index
58
 
59
  # Use a predefined set of features to ensure dimensions match
60
  lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']
@@ -64,12 +64,10 @@ def test_annoy_search(test_annoy_index, dataset_bank_path):
64
  # Create a random test vector with the correct dimension
65
  test_vector = np.random.rand(vector_dim)
66
 
67
- # Increase the search_k parameter for better accuracy
68
- search_k = 1000
69
-
70
  # Get nearest neighbors
71
  n_items = min(5, len(object_ids))
72
- nearest_indices = index.get_nns_by_vector(test_vector, n_items, search_k=search_k)
 
73
 
74
  # Verify results
75
  assert len(nearest_indices) >= 1 # Accept at least 1 neighbor
 
52
  with pytest.raises(ValueError):
53
  client.find_neighbors(ztf_object_id="ZTF21abbzjeq", n=-1)
54
 
55
+ def test_ngt_search(test_ngt_index, dataset_bank_path):
56
+ """Test that the NGT index works as expected for neighbor search."""
57
+ index, index_path, object_ids = test_ngt_index
58
 
59
  # Use a predefined set of features to ensure dimensions match
60
  lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']
 
64
  # Create a random test vector with the correct dimension
65
  test_vector = np.random.rand(vector_dim)
66
 
 
 
 
67
  # Get nearest neighbors
68
  n_items = min(5, len(object_ids))
69
+ res = index.search(test_vector.astype(np.float32), n_items)
70
+ nearest_indices, distances = zip(*res)
71
 
72
  # Verify results
73
  assert len(nearest_indices) >= 1 # Accept at least 1 neighbor