Spaces:
Sleeping
Sleeping
alexandergagliano
commited on
Commit
·
8ff867d
1
Parent(s):
7402f92
Update tests
Browse files- tests/test_ad.py +32 -64
tests/test_ad.py
CHANGED
|
@@ -5,8 +5,9 @@ import relaiss as rl
|
|
| 5 |
from relaiss.anomaly import anomaly_detection, train_AD_model
|
| 6 |
import os
|
| 7 |
import joblib
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
-
from unittest.mock import patch, MagicMock
|
| 10 |
|
| 11 |
@pytest.fixture
|
| 12 |
def sample_preprocessed_df():
|
|
@@ -138,8 +139,9 @@ def test_train_AD_model_with_raw_data(tmp_path, sample_preprocessed_df):
|
|
| 138 |
expected_filename = f"IForest_n=100_c=0.02_m=256.pkl"
|
| 139 |
expected_model_path = str(tmp_path / expected_filename)
|
| 140 |
|
| 141 |
-
# Mock the ReLAISS client
|
| 142 |
with patch('relaiss.relaiss.ReLAISS') as mock_client_class, \
|
|
|
|
| 143 |
patch('joblib.dump') as mock_dump:
|
| 144 |
|
| 145 |
# Configure mock client
|
|
@@ -300,41 +302,24 @@ def test_anomaly_detection_basic(sample_preprocessed_df, tmp_path):
|
|
| 300 |
'mjd_cutoff': np.linspace(58000, 58050, 20),
|
| 301 |
})
|
| 302 |
|
| 303 |
-
#
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
def predict(self, X):
|
| 311 |
-
return np.array([1 if np.random.random() > 0.1 else -1 for _ in range(len(X))])
|
| 312 |
-
|
| 313 |
-
def decision_function(self, X):
|
| 314 |
-
return np.random.uniform(-0.5, 0.5, len(X))
|
| 315 |
-
|
| 316 |
-
def predict_proba(self, X):
|
| 317 |
-
# Add predict_proba method that isolation forest doesn't normally have
|
| 318 |
-
n_samples = X.shape[0]
|
| 319 |
-
probas = np.zeros((n_samples, 2))
|
| 320 |
-
# Random probabilities that sum to 1 for each sample
|
| 321 |
-
probas[:, 0] = np.random.uniform(0.6, 0.9, n_samples)
|
| 322 |
-
probas[:, 1] = 1 - probas[:, 0]
|
| 323 |
-
return probas
|
| 324 |
-
|
| 325 |
-
def fit(self, X):
|
| 326 |
-
return self
|
| 327 |
-
|
| 328 |
-
# Save a real model file
|
| 329 |
-
real_forest = MockIsolationForest()
|
| 330 |
-
with open(model_path, 'wb') as f:
|
| 331 |
-
joblib.dump(real_forest, f)
|
| 332 |
|
| 333 |
# Apply comprehensive mocking
|
| 334 |
-
with patch('relaiss.
|
|
|
|
| 335 |
patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
|
| 336 |
-
patch('sklearn.ensemble.IsolationForest', return_value=
|
| 337 |
patch('joblib.dump'), \
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
|
| 339 |
patch('matplotlib.pyplot.savefig'), \
|
| 340 |
patch('matplotlib.pyplot.show'), \
|
|
@@ -487,44 +472,27 @@ def test_anomaly_detection_with_host_swap(sample_preprocessed_df, tmp_path):
|
|
| 487 |
mock_swapped_host_df['gKronMag'] = [20.0] * 20
|
| 488 |
mock_swapped_host_df['rKronMag'] = [19.5] * 20
|
| 489 |
|
| 490 |
-
#
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
def predict(self, X):
|
| 498 |
-
return np.array([1 if np.random.random() > 0.1 else -1 for _ in range(len(X))])
|
| 499 |
-
|
| 500 |
-
def decision_function(self, X):
|
| 501 |
-
return np.random.uniform(-0.5, 0.5, len(X))
|
| 502 |
-
|
| 503 |
-
def predict_proba(self, X):
|
| 504 |
-
# Add predict_proba method that isolation forest doesn't normally have
|
| 505 |
-
n_samples = X.shape[0]
|
| 506 |
-
probas = np.zeros((n_samples, 2))
|
| 507 |
-
# Random probabilities that sum to 1 for each sample
|
| 508 |
-
probas[:, 0] = np.random.uniform(0.6, 0.9, n_samples)
|
| 509 |
-
probas[:, 1] = 1 - probas[:, 0]
|
| 510 |
-
return probas
|
| 511 |
-
|
| 512 |
-
def fit(self, X):
|
| 513 |
-
return self
|
| 514 |
-
|
| 515 |
-
# Save a real model file
|
| 516 |
-
real_forest = MockIsolationForest()
|
| 517 |
-
with open(model_path, 'wb') as f:
|
| 518 |
-
joblib.dump(real_forest, f)
|
| 519 |
|
| 520 |
# Create a PDF figure file to satisfy the existence check
|
| 521 |
(ad_dir / "ZTF21abbzjeq_w_host_ZTF19aaaaaaa_AD.pdf").touch()
|
| 522 |
|
| 523 |
# Apply comprehensive mocking
|
| 524 |
-
with patch('relaiss.
|
|
|
|
| 525 |
patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
|
| 526 |
-
patch('sklearn.ensemble.IsolationForest', return_value=
|
| 527 |
patch('joblib.dump'), \
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
|
| 529 |
patch('matplotlib.pyplot.savefig'), \
|
| 530 |
patch('matplotlib.pyplot.show'), \
|
|
|
|
| 5 |
from relaiss.anomaly import anomaly_detection, train_AD_model
|
| 6 |
import os
|
| 7 |
import joblib
|
| 8 |
+
import pickle
|
| 9 |
from pathlib import Path
|
| 10 |
+
from unittest.mock import patch, MagicMock, mock_open
|
| 11 |
|
| 12 |
@pytest.fixture
|
| 13 |
def sample_preprocessed_df():
|
|
|
|
| 139 |
expected_filename = f"IForest_n=100_c=0.02_m=256.pkl"
|
| 140 |
expected_model_path = str(tmp_path / expected_filename)
|
| 141 |
|
| 142 |
+
# Mock the ReLAISS client and build_dataset_bank to avoid SFD map initialization
|
| 143 |
with patch('relaiss.relaiss.ReLAISS') as mock_client_class, \
|
| 144 |
+
patch('relaiss.features.build_dataset_bank', return_value=sample_preprocessed_df), \
|
| 145 |
patch('joblib.dump') as mock_dump:
|
| 146 |
|
| 147 |
# Configure mock client
|
|
|
|
| 302 |
'mjd_cutoff': np.linspace(58000, 58050, 20),
|
| 303 |
})
|
| 304 |
|
| 305 |
+
# Create a mock IsolationForest object
|
| 306 |
+
mock_forest = MagicMock()
|
| 307 |
+
mock_forest.n_estimators = 100
|
| 308 |
+
mock_forest.contamination = 0.02
|
| 309 |
+
mock_forest.max_samples = 256
|
| 310 |
+
mock_forest.predict.return_value = np.array([1 if np.random.random() > 0.1 else -1 for _ in range(20)])
|
| 311 |
+
mock_forest.decision_function.return_value = np.random.uniform(-0.5, 0.5, 20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
# Apply comprehensive mocking
|
| 314 |
+
with patch('relaiss.features.build_dataset_bank', return_value=sample_preprocessed_df), \
|
| 315 |
+
patch('relaiss.anomaly.get_timeseries_df', return_value=mock_timeseries_df), \
|
| 316 |
patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
|
| 317 |
+
patch('sklearn.ensemble.IsolationForest', return_value=mock_forest), \
|
| 318 |
patch('joblib.dump'), \
|
| 319 |
+
patch('joblib.load', return_value=mock_forest), \
|
| 320 |
+
patch('pickle.load', return_value=mock_forest), \
|
| 321 |
+
patch('builtins.open', mock_open()), \
|
| 322 |
+
patch('os.path.exists', return_value=True), \
|
| 323 |
patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
|
| 324 |
patch('matplotlib.pyplot.savefig'), \
|
| 325 |
patch('matplotlib.pyplot.show'), \
|
|
|
|
| 472 |
mock_swapped_host_df['gKronMag'] = [20.0] * 20
|
| 473 |
mock_swapped_host_df['rKronMag'] = [19.5] * 20
|
| 474 |
|
| 475 |
+
# Create a mock IsolationForest object
|
| 476 |
+
mock_forest = MagicMock()
|
| 477 |
+
mock_forest.n_estimators = 100
|
| 478 |
+
mock_forest.contamination = 0.02
|
| 479 |
+
mock_forest.max_samples = 256
|
| 480 |
+
mock_forest.predict.return_value = np.array([1 if np.random.random() > 0.1 else -1 for _ in range(20)])
|
| 481 |
+
mock_forest.decision_function.return_value = np.random.uniform(-0.5, 0.5, 20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
|
| 483 |
# Create a PDF figure file to satisfy the existence check
|
| 484 |
(ad_dir / "ZTF21abbzjeq_w_host_ZTF19aaaaaaa_AD.pdf").touch()
|
| 485 |
|
| 486 |
# Apply comprehensive mocking
|
| 487 |
+
with patch('relaiss.features.build_dataset_bank', return_value=combined_df), \
|
| 488 |
+
patch('relaiss.anomaly.get_timeseries_df') as mock_get_ts, \
|
| 489 |
patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
|
| 490 |
+
patch('sklearn.ensemble.IsolationForest', return_value=mock_forest), \
|
| 491 |
patch('joblib.dump'), \
|
| 492 |
+
patch('joblib.load', return_value=mock_forest), \
|
| 493 |
+
patch('pickle.load', return_value=mock_forest), \
|
| 494 |
+
patch('builtins.open', mock_open()), \
|
| 495 |
+
patch('os.path.exists', return_value=True), \
|
| 496 |
patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
|
| 497 |
patch('matplotlib.pyplot.savefig'), \
|
| 498 |
patch('matplotlib.pyplot.show'), \
|