Spaces:
Sleeping
Sleeping
alexandergagliano
commited on
Commit
·
0a317b5
1
Parent(s):
b622cf4
update unit tests
Browse files- .github/workflows/ci.yml +5 -0
- .github/workflows/smoke-test.yml +5 -1
- README.md +6 -1
- examples/advanced_usage.py +4 -4
- src/relaiss/anomaly_config.py +2 -0
- src/relaiss/relaiss.py +1 -4
- tests/conftest.py +7 -7
- tests/fixtures/search.py +21 -22
- tests/test_search.py +5 -7
.github/workflows/ci.yml
CHANGED
|
@@ -28,6 +28,11 @@ jobs:
|
|
| 28 |
python -m pip install --upgrade pip
|
| 29 |
python -m pip install pytest
|
| 30 |
if [ -f pyproject.toml ]; then pip install -e .; fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
- name: Test with pytest
|
| 32 |
run: |
|
| 33 |
python -m pytest tests/test_utils.py -v --ci
|
|
|
|
| 28 |
python -m pip install --upgrade pip
|
| 29 |
python -m pip install pytest
|
| 30 |
if [ -f pyproject.toml ]; then pip install -e .; fi
|
| 31 |
+
- name: Install python bindings for ngt
|
| 32 |
+
run: |
|
| 33 |
+
git clone https://github.com/yahoojapan/NGT.git
|
| 34 |
+
cd NGT/python
|
| 35 |
+
pip3 install .
|
| 36 |
- name: Test with pytest
|
| 37 |
run: |
|
| 38 |
python -m pytest tests/test_utils.py -v --ci
|
.github/workflows/smoke-test.yml
CHANGED
|
@@ -35,11 +35,15 @@ jobs:
|
|
| 35 |
- name: Install dependencies
|
| 36 |
run: |
|
| 37 |
sudo apt-get update
|
| 38 |
-
python -m pip install --upgrade pip
|
| 39 |
pip install -e .[dev]
|
| 40 |
- name: List dependencies
|
| 41 |
run: |
|
| 42 |
pip list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
- name: Run unit tests with pytest / pytest-copie
|
| 44 |
run: |
|
| 45 |
python -m pytest
|
|
|
|
| 35 |
- name: Install dependencies
|
| 36 |
run: |
|
| 37 |
sudo apt-get update
|
|
|
|
| 38 |
pip install -e .[dev]
|
| 39 |
- name: List dependencies
|
| 40 |
run: |
|
| 41 |
pip list
|
| 42 |
+
- name: Install python bindings for ngt
|
| 43 |
+
run: |
|
| 44 |
+
git clone https://github.com/yahoojapan/NGT.git
|
| 45 |
+
cd NGT/python
|
| 46 |
+
pip3 install .
|
| 47 |
- name: Run unit tests with pytest / pytest-copie
|
| 48 |
run: |
|
| 49 |
python -m pytest
|
README.md
CHANGED
|
@@ -12,8 +12,13 @@ reLAISS lets you retrieve nearest‑neighbour supernovae (or spot outliers) by c
|
|
| 12 |
|
| 13 |
# Install
|
| 14 |
|
| 15 |
-
In a fresh conda environment, run `pip install relaiss`.
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Code Demo
|
| 19 |
```
|
|
|
|
| 12 |
|
| 13 |
# Install
|
| 14 |
|
| 15 |
+
In a fresh conda environment, run `pip install relaiss`. After installing, you must install the python bindings for `ngt` from source:
|
| 16 |
|
| 17 |
+
```
|
| 18 |
+
git clone https://github.com/yahoojapan/NGT.git
|
| 19 |
+
cd NGT/python
|
| 20 |
+
pip3 install .
|
| 21 |
+
```
|
| 22 |
|
| 23 |
# Code Demo
|
| 24 |
```
|
examples/advanced_usage.py
CHANGED
|
@@ -6,7 +6,7 @@ This script demonstrates advanced features of reLAISS including:
|
|
| 6 |
- Swapping host galaxies
|
| 7 |
- Using PCA for dimensionality reduction
|
| 8 |
- Setting maximum neighbor distances
|
| 9 |
-
- Tweaking
|
| 10 |
- Making corner plots
|
| 11 |
- Advanced anomaly detection with parameter tuning
|
| 12 |
- Host swapping in anomaly detection
|
|
@@ -106,8 +106,8 @@ def main():
|
|
| 106 |
print("\nNearest neighbors within distance threshold:")
|
| 107 |
print(neighbors_df)
|
| 108 |
|
| 109 |
-
# Example 5: Tweaking
|
| 110 |
-
print("\nExample 5: Tweaking
|
| 111 |
neighbors_df = client.find_neighbors(
|
| 112 |
ztf_object_id='ZTF21abbzjeq',
|
| 113 |
n=5,
|
|
@@ -116,7 +116,7 @@ def main():
|
|
| 116 |
save_figures=True,
|
| 117 |
path_to_figure_directory='./figures'
|
| 118 |
)
|
| 119 |
-
print("\nNearest neighbors with tweaked
|
| 120 |
print(neighbors_df)
|
| 121 |
|
| 122 |
# Example 6: Making corner plots
|
|
|
|
| 6 |
- Swapping host galaxies
|
| 7 |
- Using PCA for dimensionality reduction
|
| 8 |
- Setting maximum neighbor distances
|
| 9 |
+
- Tweaking NGT parameters
|
| 10 |
- Making corner plots
|
| 11 |
- Advanced anomaly detection with parameter tuning
|
| 12 |
- Host swapping in anomaly detection
|
|
|
|
| 106 |
print("\nNearest neighbors within distance threshold:")
|
| 107 |
print(neighbors_df)
|
| 108 |
|
| 109 |
+
# Example 5: Tweaking NGT parameters
|
| 110 |
+
print("\nExample 5: Tweaking NGT parameters")
|
| 111 |
neighbors_df = client.find_neighbors(
|
| 112 |
ztf_object_id='ZTF21abbzjeq',
|
| 113 |
n=5,
|
|
|
|
| 116 |
save_figures=True,
|
| 117 |
path_to_figure_directory='./figures'
|
| 118 |
)
|
| 119 |
+
print("\nNearest neighbors with tweaked NGT parameters:")
|
| 120 |
print(neighbors_df)
|
| 121 |
|
| 122 |
# Example 6: Making corner plots
|
src/relaiss/anomaly_config.py
CHANGED
|
@@ -11,3 +11,5 @@ TRAINING_SAMPLE_SIZE = 5000 # Maximum training features to store
|
|
| 11 |
# Training Data Quality Parameters
|
| 12 |
HIGH_NAN_WARNING_THRESHOLD = 20 # Percentage of NaN samples to trigger warning
|
| 13 |
VERY_HIGH_NAN_ERROR_THRESHOLD = 50 # Percentage of NaN samples to trigger error
|
|
|
|
|
|
|
|
|
| 11 |
# Training Data Quality Parameters
|
| 12 |
HIGH_NAN_WARNING_THRESHOLD = 20 # Percentage of NaN samples to trigger warning
|
| 13 |
VERY_HIGH_NAN_ERROR_THRESHOLD = 50 # Percentage of NaN samples to trigger error
|
| 14 |
+
|
| 15 |
+
MODEL_FILENAME_TEMPLATE = "kNN_scaler_lc={num_lc}_host={num_host}.pkl"
|
src/relaiss/relaiss.py
CHANGED
|
@@ -12,10 +12,7 @@ from .fetch import get_TNS_data
|
|
| 12 |
from .plotting import plot_lightcurves, plot_hosts
|
| 13 |
import os
|
| 14 |
from kneed import KneeLocator
|
| 15 |
-
|
| 16 |
-
import ngtpy as ngt
|
| 17 |
-
except:
|
| 18 |
-
import ngt
|
| 19 |
import numpy as np
|
| 20 |
import pandas as pd
|
| 21 |
from sklearn.decomposition import PCA
|
|
|
|
| 12 |
from .plotting import plot_lightcurves, plot_hosts
|
| 13 |
import os
|
| 14 |
from kneed import KneeLocator
|
| 15 |
+
import ngtpy as ngt
|
|
|
|
|
|
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
import pandas as pd
|
| 18 |
from sklearn.decomposition import PCA
|
tests/conftest.py
CHANGED
|
@@ -7,7 +7,7 @@ from unittest.mock import patch, MagicMock
|
|
| 7 |
import tempfile
|
| 8 |
import shutil
|
| 9 |
import astropy.units as u
|
| 10 |
-
from .fixtures.search import
|
| 11 |
|
| 12 |
# Get the path to the test fixtures
|
| 13 |
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
|
@@ -192,15 +192,15 @@ def mock_extinction_all():
|
|
| 192 |
yield mock_map, mock_g23
|
| 193 |
|
| 194 |
@pytest.fixture
|
| 195 |
-
def
|
| 196 |
-
"""Create a test
|
| 197 |
# Get the dataset bank file
|
| 198 |
df = pd.read_csv(dataset_bank_path)
|
| 199 |
|
| 200 |
-
# Build a temporary
|
| 201 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 202 |
-
index_path = Path(tmpdir) / "
|
| 203 |
-
index, index_path, object_ids =
|
| 204 |
test_databank_path=dataset_bank_path,
|
| 205 |
lc_features=['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time'],
|
| 206 |
host_features=['host_ra', 'host_dec']
|
|
@@ -308,4 +308,4 @@ def pytest_addoption(parser):
|
|
| 308 |
"""Add command line options."""
|
| 309 |
parser.addoption(
|
| 310 |
"--ci", action="store_true", default=False, help="Run in CI mode (skip tests requiring real data)"
|
| 311 |
-
)
|
|
|
|
| 7 |
import tempfile
|
| 8 |
import shutil
|
| 9 |
import astropy.units as u
|
| 10 |
+
from .fixtures.search import build_test_ngt_index
|
| 11 |
|
| 12 |
# Get the path to the test fixtures
|
| 13 |
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
|
|
|
| 192 |
yield mock_map, mock_g23
|
| 193 |
|
| 194 |
@pytest.fixture
|
| 195 |
+
def test_ngt_index(dataset_bank_path):
|
| 196 |
+
"""Create a test NGT index for testing."""
|
| 197 |
# Get the dataset bank file
|
| 198 |
df = pd.read_csv(dataset_bank_path)
|
| 199 |
|
| 200 |
+
# Build a temporary NGT index
|
| 201 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 202 |
+
index_path = Path(tmpdir) / "ngt_index"
|
| 203 |
+
index, index_path, object_ids = build_test_ngt_index(
|
| 204 |
test_databank_path=dataset_bank_path,
|
| 205 |
lc_features=['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time'],
|
| 206 |
host_features=['host_ra', 'host_dec']
|
|
|
|
| 308 |
"""Add command line options."""
|
| 309 |
parser.addoption(
|
| 310 |
"--ci", action="store_true", default=False, help="Run in CI mode (skip tests requiring real data)"
|
| 311 |
+
)
|
tests/fixtures/search.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
-
import
|
| 4 |
import os
|
| 5 |
from pathlib import Path
|
| 6 |
import tempfile
|
| 7 |
|
| 8 |
-
# Test fixture for building an
|
| 9 |
-
def
|
| 10 |
-
"""Build an
|
| 11 |
|
| 12 |
Parameters
|
| 13 |
----------
|
|
@@ -21,7 +21,7 @@ def build_test_annoy_index(test_databank_path, lc_features=None, host_features=N
|
|
| 21 |
Returns
|
| 22 |
-------
|
| 23 |
tuple
|
| 24 |
-
(index, index_path, object_ids) - the
|
| 25 |
"""
|
| 26 |
# Default features if none provided
|
| 27 |
if lc_features is None:
|
|
@@ -42,31 +42,31 @@ def build_test_annoy_index(test_databank_path, lc_features=None, host_features=N
|
|
| 42 |
feat_arr = np.array(df_features)
|
| 43 |
feat_arr_scaled = (feat_arr - np.mean(feat_arr, axis=0)) / np.std(feat_arr, axis=0)
|
| 44 |
|
| 45 |
-
# Create
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
index_dim = feat_arr.shape[1]
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
# Add items to index
|
| 50 |
for i, obj_id in enumerate(df_bank.index):
|
| 51 |
-
index.
|
| 52 |
|
| 53 |
-
# Build index
|
| 54 |
-
index.
|
| 55 |
-
|
| 56 |
-
# Create temp file to save index
|
| 57 |
-
temp_dir = tempfile.mkdtemp()
|
| 58 |
-
index_path = os.path.join(temp_dir, "test_index.ann")
|
| 59 |
-
index.save(index_path)
|
| 60 |
|
| 61 |
return index, index_path, np.array(df_bank.index)
|
| 62 |
|
| 63 |
def find_neighbors(index, idx_arr, query_vector, n=5):
|
| 64 |
-
"""Find neighbors using the test
|
| 65 |
|
| 66 |
Parameters
|
| 67 |
----------
|
| 68 |
-
index :
|
| 69 |
-
The
|
| 70 |
idx_arr : numpy.ndarray
|
| 71 |
Array of object IDs
|
| 72 |
query_vector : numpy.ndarray
|
|
@@ -80,11 +80,10 @@ def find_neighbors(index, idx_arr, query_vector, n=5):
|
|
| 80 |
(ids, distances) - arrays of neighbor IDs and distances
|
| 81 |
"""
|
| 82 |
# Query the index
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
)
|
| 86 |
|
| 87 |
# Get ZTF IDs of neighbors
|
| 88 |
neighbor_ids = idx_arr[neighbor_indices]
|
| 89 |
|
| 90 |
-
return neighbor_ids, distances
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
+
import ngtpy as ngt
|
| 4 |
import os
|
| 5 |
from pathlib import Path
|
| 6 |
import tempfile
|
| 7 |
|
| 8 |
+
# Test fixture for building an NGT index for testing
|
| 9 |
+
def build_test_ngt_index(test_databank_path, lc_features=None, host_features=None):
|
| 10 |
+
"""Build an NGT index from the test dataset bank for testing.
|
| 11 |
|
| 12 |
Parameters
|
| 13 |
----------
|
|
|
|
| 21 |
Returns
|
| 22 |
-------
|
| 23 |
tuple
|
| 24 |
+
(index, index_path, object_ids) - the ngt index, path to temp file, and array of object ids
|
| 25 |
"""
|
| 26 |
# Default features if none provided
|
| 27 |
if lc_features is None:
|
|
|
|
| 42 |
feat_arr = np.array(df_features)
|
| 43 |
feat_arr_scaled = (feat_arr - np.mean(feat_arr, axis=0)) / np.std(feat_arr, axis=0)
|
| 44 |
|
| 45 |
+
# Create temp directory for NGT index
|
| 46 |
+
temp_dir = tempfile.mkdtemp()
|
| 47 |
+
index_path = os.path.join(temp_dir, "test_index.ngt")
|
| 48 |
+
|
| 49 |
+
# Create NGT index
|
| 50 |
index_dim = feat_arr.shape[1]
|
| 51 |
+
ngt.create(index_path.encode(), index_dim, distance_type="L2")
|
| 52 |
+
index = ngt.Index(index_path.encode())
|
| 53 |
|
| 54 |
# Add items to index
|
| 55 |
for i, obj_id in enumerate(df_bank.index):
|
| 56 |
+
index.insert(feat_arr_scaled[i].astype(np.float32))
|
| 57 |
|
| 58 |
+
# Build index
|
| 59 |
+
index.build_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
return index, index_path, np.array(df_bank.index)
|
| 62 |
|
| 63 |
def find_neighbors(index, idx_arr, query_vector, n=5):
|
| 64 |
+
"""Find neighbors using the test NGT index.
|
| 65 |
|
| 66 |
Parameters
|
| 67 |
----------
|
| 68 |
+
index : ngt.Index
|
| 69 |
+
The NGT index to query
|
| 70 |
idx_arr : numpy.ndarray
|
| 71 |
Array of object IDs
|
| 72 |
query_vector : numpy.ndarray
|
|
|
|
| 80 |
(ids, distances) - arrays of neighbor IDs and distances
|
| 81 |
"""
|
| 82 |
# Query the index
|
| 83 |
+
res = index.search(query_vector.astype(np.float32), n)
|
| 84 |
+
neighbor_indices, distances = zip(*res)
|
|
|
|
| 85 |
|
| 86 |
# Get ZTF IDs of neighbors
|
| 87 |
neighbor_ids = idx_arr[neighbor_indices]
|
| 88 |
|
| 89 |
+
return neighbor_ids, distances
|
tests/test_search.py
CHANGED
|
@@ -52,9 +52,9 @@ def test_find_neighbors_invalid_input():
|
|
| 52 |
with pytest.raises(ValueError):
|
| 53 |
client.find_neighbors(ztf_object_id="ZTF21abbzjeq", n=-1)
|
| 54 |
|
| 55 |
-
def
|
| 56 |
-
"""Test that the
|
| 57 |
-
index, index_path, object_ids =
|
| 58 |
|
| 59 |
# Use a predefined set of features to ensure dimensions match
|
| 60 |
lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']
|
|
@@ -64,12 +64,10 @@ def test_annoy_search(test_annoy_index, dataset_bank_path):
|
|
| 64 |
# Create a random test vector with the correct dimension
|
| 65 |
test_vector = np.random.rand(vector_dim)
|
| 66 |
|
| 67 |
-
# Increase the search_k parameter for better accuracy
|
| 68 |
-
search_k = 1000
|
| 69 |
-
|
| 70 |
# Get nearest neighbors
|
| 71 |
n_items = min(5, len(object_ids))
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
# Verify results
|
| 75 |
assert len(nearest_indices) >= 1 # Accept at least 1 neighbor
|
|
|
|
| 52 |
with pytest.raises(ValueError):
|
| 53 |
client.find_neighbors(ztf_object_id="ZTF21abbzjeq", n=-1)
|
| 54 |
|
| 55 |
+
def test_ngt_search(test_ngt_index, dataset_bank_path):
|
| 56 |
+
"""Test that the NGT index works as expected for neighbor search."""
|
| 57 |
+
index, index_path, object_ids = test_ngt_index
|
| 58 |
|
| 59 |
# Use a predefined set of features to ensure dimensions match
|
| 60 |
lc_features = ['g_peak_mag', 'r_peak_mag', 'g_peak_time', 'r_peak_time']
|
|
|
|
| 64 |
# Create a random test vector with the correct dimension
|
| 65 |
test_vector = np.random.rand(vector_dim)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
| 67 |
# Get nearest neighbors
|
| 68 |
n_items = min(5, len(object_ids))
|
| 69 |
+
res = index.search(test_vector.astype(np.float32), n_items)
|
| 70 |
+
nearest_indices, distances = zip(*res)
|
| 71 |
|
| 72 |
# Verify results
|
| 73 |
assert len(nearest_indices) >= 1 # Accept at least 1 neighbor
|