alexandergagliano commited on
Commit
8ff867d
·
1 Parent(s): 7402f92

Update tests

Browse files
Files changed (1) hide show
  1. tests/test_ad.py +32 -64
tests/test_ad.py CHANGED
@@ -5,8 +5,9 @@ import relaiss as rl
5
  from relaiss.anomaly import anomaly_detection, train_AD_model
6
  import os
7
  import joblib
 
8
  from pathlib import Path
9
- from unittest.mock import patch, MagicMock
10
 
11
  @pytest.fixture
12
  def sample_preprocessed_df():
@@ -138,8 +139,9 @@ def test_train_AD_model_with_raw_data(tmp_path, sample_preprocessed_df):
138
  expected_filename = f"IForest_n=100_c=0.02_m=256.pkl"
139
  expected_model_path = str(tmp_path / expected_filename)
140
 
141
- # Mock the ReLAISS client
142
  with patch('relaiss.relaiss.ReLAISS') as mock_client_class, \
 
143
  patch('joblib.dump') as mock_dump:
144
 
145
  # Configure mock client
@@ -300,41 +302,24 @@ def test_anomaly_detection_basic(sample_preprocessed_df, tmp_path):
300
  'mjd_cutoff': np.linspace(58000, 58050, 20),
301
  })
302
 
303
- # Mock isolation forest with predict_proba
304
- class MockIsolationForest:
305
- def __init__(self, n_estimators=100, contamination=0.02, max_samples=256):
306
- self.n_estimators = n_estimators
307
- self.contamination = contamination
308
- self.max_samples = max_samples
309
-
310
- def predict(self, X):
311
- return np.array([1 if np.random.random() > 0.1 else -1 for _ in range(len(X))])
312
-
313
- def decision_function(self, X):
314
- return np.random.uniform(-0.5, 0.5, len(X))
315
-
316
- def predict_proba(self, X):
317
- # Add predict_proba method that isolation forest doesn't normally have
318
- n_samples = X.shape[0]
319
- probas = np.zeros((n_samples, 2))
320
- # Random probabilities that sum to 1 for each sample
321
- probas[:, 0] = np.random.uniform(0.6, 0.9, n_samples)
322
- probas[:, 1] = 1 - probas[:, 0]
323
- return probas
324
-
325
- def fit(self, X):
326
- return self
327
-
328
- # Save a real model file
329
- real_forest = MockIsolationForest()
330
- with open(model_path, 'wb') as f:
331
- joblib.dump(real_forest, f)
332
 
333
  # Apply comprehensive mocking
334
- with patch('relaiss.anomaly.get_timeseries_df', return_value=mock_timeseries_df), \
 
335
  patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
336
- patch('sklearn.ensemble.IsolationForest', return_value=MockIsolationForest()), \
337
  patch('joblib.dump'), \
 
 
 
 
338
  patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
339
  patch('matplotlib.pyplot.savefig'), \
340
  patch('matplotlib.pyplot.show'), \
@@ -487,44 +472,27 @@ def test_anomaly_detection_with_host_swap(sample_preprocessed_df, tmp_path):
487
  mock_swapped_host_df['gKronMag'] = [20.0] * 20
488
  mock_swapped_host_df['rKronMag'] = [19.5] * 20
489
 
490
- # Mock the isolation forest model with predict_proba method
491
- class MockIsolationForest:
492
- def __init__(self, n_estimators=100, contamination=0.02, max_samples=256):
493
- self.n_estimators = n_estimators
494
- self.contamination = contamination
495
- self.max_samples = max_samples
496
-
497
- def predict(self, X):
498
- return np.array([1 if np.random.random() > 0.1 else -1 for _ in range(len(X))])
499
-
500
- def decision_function(self, X):
501
- return np.random.uniform(-0.5, 0.5, len(X))
502
-
503
- def predict_proba(self, X):
504
- # Add predict_proba method that isolation forest doesn't normally have
505
- n_samples = X.shape[0]
506
- probas = np.zeros((n_samples, 2))
507
- # Random probabilities that sum to 1 for each sample
508
- probas[:, 0] = np.random.uniform(0.6, 0.9, n_samples)
509
- probas[:, 1] = 1 - probas[:, 0]
510
- return probas
511
-
512
- def fit(self, X):
513
- return self
514
-
515
- # Save a real model file
516
- real_forest = MockIsolationForest()
517
- with open(model_path, 'wb') as f:
518
- joblib.dump(real_forest, f)
519
 
520
  # Create a PDF figure file to satisfy the existence check
521
  (ad_dir / "ZTF21abbzjeq_w_host_ZTF19aaaaaaa_AD.pdf").touch()
522
 
523
  # Apply comprehensive mocking
524
- with patch('relaiss.anomaly.get_timeseries_df') as mock_get_ts, \
 
525
  patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
526
- patch('sklearn.ensemble.IsolationForest', return_value=MockIsolationForest()), \
527
  patch('joblib.dump'), \
 
 
 
 
528
  patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
529
  patch('matplotlib.pyplot.savefig'), \
530
  patch('matplotlib.pyplot.show'), \
 
5
  from relaiss.anomaly import anomaly_detection, train_AD_model
6
  import os
7
  import joblib
8
+ import pickle
9
  from pathlib import Path
10
+ from unittest.mock import patch, MagicMock, mock_open
11
 
12
  @pytest.fixture
13
  def sample_preprocessed_df():
 
139
  expected_filename = f"IForest_n=100_c=0.02_m=256.pkl"
140
  expected_model_path = str(tmp_path / expected_filename)
141
 
142
+ # Mock the ReLAISS client and build_dataset_bank to avoid SFD map initialization
143
  with patch('relaiss.relaiss.ReLAISS') as mock_client_class, \
144
+ patch('relaiss.features.build_dataset_bank', return_value=sample_preprocessed_df), \
145
  patch('joblib.dump') as mock_dump:
146
 
147
  # Configure mock client
 
302
  'mjd_cutoff': np.linspace(58000, 58050, 20),
303
  })
304
 
305
+ # Create a mock IsolationForest object
306
+ mock_forest = MagicMock()
307
+ mock_forest.n_estimators = 100
308
+ mock_forest.contamination = 0.02
309
+ mock_forest.max_samples = 256
310
+ mock_forest.predict.return_value = np.array([1 if np.random.random() > 0.1 else -1 for _ in range(20)])
311
+ mock_forest.decision_function.return_value = np.random.uniform(-0.5, 0.5, 20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  # Apply comprehensive mocking
314
+ with patch('relaiss.features.build_dataset_bank', return_value=sample_preprocessed_df), \
315
+ patch('relaiss.anomaly.get_timeseries_df', return_value=mock_timeseries_df), \
316
  patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
317
+ patch('sklearn.ensemble.IsolationForest', return_value=mock_forest), \
318
  patch('joblib.dump'), \
319
+ patch('joblib.load', return_value=mock_forest), \
320
+ patch('pickle.load', return_value=mock_forest), \
321
+ patch('builtins.open', mock_open()), \
322
+ patch('os.path.exists', return_value=True), \
323
  patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
324
  patch('matplotlib.pyplot.savefig'), \
325
  patch('matplotlib.pyplot.show'), \
 
472
  mock_swapped_host_df['gKronMag'] = [20.0] * 20
473
  mock_swapped_host_df['rKronMag'] = [19.5] * 20
474
 
475
+ # Create a mock IsolationForest object
476
+ mock_forest = MagicMock()
477
+ mock_forest.n_estimators = 100
478
+ mock_forest.contamination = 0.02
479
+ mock_forest.max_samples = 256
480
+ mock_forest.predict.return_value = np.array([1 if np.random.random() > 0.1 else -1 for _ in range(20)])
481
+ mock_forest.decision_function.return_value = np.random.uniform(-0.5, 0.5, 20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
  # Create a PDF figure file to satisfy the existence check
484
  (ad_dir / "ZTF21abbzjeq_w_host_ZTF19aaaaaaa_AD.pdf").touch()
485
 
486
  # Apply comprehensive mocking
487
+ with patch('relaiss.features.build_dataset_bank', return_value=combined_df), \
488
+ patch('relaiss.anomaly.get_timeseries_df') as mock_get_ts, \
489
  patch('relaiss.anomaly.get_TNS_data', return_value=("MockSN", "Ia", 0.1)), \
490
+ patch('sklearn.ensemble.IsolationForest', return_value=mock_forest), \
491
  patch('joblib.dump'), \
492
+ patch('joblib.load', return_value=mock_forest), \
493
+ patch('pickle.load', return_value=mock_forest), \
494
+ patch('builtins.open', mock_open()), \
495
+ patch('os.path.exists', return_value=True), \
496
  patch('matplotlib.pyplot.figure', return_value=MagicMock()), \
497
  patch('matplotlib.pyplot.savefig'), \
498
  patch('matplotlib.pyplot.show'), \