Maxwell-Jia committed on
Commit
fce4e7c
·
verified ·
1 Parent(s): 2f7b8e0

Upload FlareDetectionPipeline

Browse files
Files changed (2) hide show
  1. config.json +10 -1
  2. flare_detection.py +233 -0
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "results/fcn4flare-kepler_flare/checkpoint-2085",
3
  "architectures": [
4
  "FCN4FlareModel"
5
  ],
@@ -7,6 +7,15 @@
7
  "AutoConfig": "Maxwell-Jia/fcn4flare--configuration_fcn4flare.FCN4FlareConfig",
8
  "AutoModel": "Maxwell-Jia/fcn4flare--modeling_fcn4flare.FCN4FlareModel"
9
  },
 
 
 
 
 
 
 
 
 
10
  "depth": 10,
11
  "dilation": [
12
  1,
 
1
  {
2
+ "_name_or_path": "Maxwell-Jia/fcn4flare",
3
  "architectures": [
4
  "FCN4FlareModel"
5
  ],
 
7
  "AutoConfig": "Maxwell-Jia/fcn4flare--configuration_fcn4flare.FCN4FlareConfig",
8
  "AutoModel": "Maxwell-Jia/fcn4flare--modeling_fcn4flare.FCN4FlareModel"
9
  },
10
+ "custom_pipelines": {
11
+ "flare_detection": {
12
+ "impl": "flare_detection.FlareDetectionPipeline",
13
+ "pt": [
14
+ "AutoModel"
15
+ ],
16
+ "tf": []
17
+ }
18
+ },
19
  "depth": 10,
20
  "dilation": [
21
  1,
flare_detection.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Union
2
+ import numpy as np
3
+ import torch
4
+ from transformers import Pipeline
5
+ from astropy.io import fits
6
+
7
class FlareDetectionPipeline(Pipeline):
    """Pipeline that detects flare events in light curves read from FITS files.

    The three pipeline stages are:

    * ``preprocess``: read ``TIME`` and ``PDCSAP_FLUX`` from HDU 1 of each
      FITS file, normalize the flux by its NaN-aware median, pad all curves
      to a common length and build a padding mask.
    * ``_forward``: run the model on the padded flux and mask.
    * ``postprocess``: threshold the sigmoid of the logits at 0.5, group
      consecutive positive cadences into segments and keep the segments that
      pass the flare-shape checks in ``_is_valid_flare``.

    NOTE(review): the time axis and the input flux are stashed on the
    instance (``self.time_series`` / ``self.input_features``) between stages,
    so results are only coherent when the three stages run back-to-back for
    the same input. This presumably breaks under batching or concurrent use;
    confirm before relying on either.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.call_count = 0

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        No call-time parameters are currently routed to any stage, so three
        empty dictionaries are returned.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}

        # Add parameters that need to be passed to specific steps
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, light_curve: Union[np.ndarray, str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Load, normalize and pad light curves from FITS files.

        Args:
            light_curve: A single FITS file path or a non-empty list of FITS
                file paths. The annotation also lists ``np.ndarray``, but raw
                array input is not implemented and is rejected explicitly.

        Returns:
            A dictionary with ``"input_features"`` (float32 tensor of shape
            ``(n_files, max_length, 1)``, padded with NaN) and
            ``"sequence_mask"`` (float32 tensor of shape
            ``(n_files, max_length)``, 1 for real samples and 0 for padding).

        Raises:
            ValueError: If ``light_curve`` is not a path or a non-empty list
                whose first element is a path.
        """
        # Convert single path to list
        if isinstance(light_curve, str):
            light_curve = [light_curve]

        # Only FITS paths are implemented. Fail loudly here instead of letting
        # an unsupported or empty input fall through without model inputs.
        is_path_list = (
            isinstance(light_curve, list)
            and len(light_curve) > 0
            and isinstance(light_curve[0], str)
        )
        if not is_path_list:
            raise ValueError(
                "light_curve must be a FITS file path or a non-empty list of FITS file paths; "
                f"got {type(light_curve).__name__}."
            )

        # First pass: read every file and normalize its flux
        flux_data = []
        times_data = []
        for fits_path in light_curve:
            with fits.open(fits_path) as hdul:
                time = hdul[1].data['TIME'].astype(np.float32)
                flux = hdul[1].data['PDCSAP_FLUX'].astype(np.float32)
            # Normalize flux by the median, ignoring NaN samples
            flux = flux / np.nanmedian(flux)

            flux_data.append(flux)
            times_data.append(time)

        lengths = [len(flux) for flux in flux_data]
        max_length = max(lengths)

        # Second pass: pad every sequence to the longest one
        padded_flux = []
        padded_times = []
        sequence_mask = []
        for flux, time, length in zip(flux_data, times_data, lengths):
            pad_length = max_length - length

            # Padding is NaN so it cannot be mistaken for a real measurement
            padded_flux.append(
                np.pad(flux, (0, pad_length), mode='constant', constant_values=np.nan)
            )
            padded_times.append(
                np.pad(time, (0, pad_length), mode='constant', constant_values=np.nan)
            )

            # Mask: 1 for real values, 0 for padding
            mask = np.pad(np.ones(length), (0, pad_length), mode='constant', constant_values=0)
            sequence_mask.append(mask)

        # Store time data as attribute for use in postprocessing
        self.time_series = np.array(padded_times)

        flux_array = np.array(padded_flux)
        sequence_mask = np.array(sequence_mask)

        # Add channel dimension: (n_files, max_length) -> (n_files, max_length, 1)
        flux_array = flux_array.reshape(flux_array.shape[0], flux_array.shape[1], 1)

        return {
            "input_features": torch.tensor(flux_array, dtype=torch.float32),
            "sequence_mask": torch.tensor(sequence_mask, dtype=torch.float32),
        }

    def _forward(self, model_inputs, **forward_params):
        """Run the model on the preprocessed inputs.

        Args:
            model_inputs: Dictionary containing ``"input_features"`` and,
                optionally, ``"sequence_mask"``.
            forward_params: Additional parameters for the forward pass
                (currently unused).

        Returns:
            The model output, requested with ``return_dict=True``.

        Raises:
            ValueError: If ``model_inputs`` is ``None``.
            KeyError: If ``"input_features"`` is missing.
        """
        if model_inputs is None:
            raise ValueError("model_inputs cannot be None. Check if preprocess method is returning correct dictionary.")

        if "input_features" not in model_inputs:
            raise KeyError("model_inputs must contain 'input_features' key.")

        # Save input_features for use in postprocessing
        self.input_features = model_inputs["input_features"]

        return self.model(
            input_features=model_inputs["input_features"],
            sequence_mask=model_inputs.get("sequence_mask", None),
            return_dict=True
        )

    @staticmethod
    def _is_valid_flare(flux, start_idx, end_idx, peak_idx):
        """Check whether a candidate segment has a plausible flare shape.

        Args:
            flux: 1-D array of flux values for one light curve.
            start_idx: Index of the first cadence of the candidate.
            end_idx: Index of the last cadence of the candidate.
            peak_idx: Index of the maximum flux inside the candidate.

        Returns:
            bool: True if the candidate passes every check, False otherwise.
        """
        # A flare must span at least 3 cadences
        if end_idx - start_idx < 2:
            return False

        # The checks below read two cadences before the start and two after
        # the end. Bounds are tested explicitly because a negative NumPy index
        # wraps around to the end of the array instead of raising IndexError.
        if start_idx < 2 or end_idx + 2 >= len(flux):
            return False

        # If start time is the peak time, flux[start] must be greater than flux[start-1]
        if peak_idx == start_idx and flux[peak_idx] <= flux[peak_idx - 1]:
            return False

        # Time for flux to decrease should be longer than that to increase
        if end_idx - peak_idx <= peak_idx - start_idx:
            return False

        # Flux level should be similar before and after the flare: compare the
        # rise above the pre-flare level with the drop to the post-flare level.
        # The small constant guards against division by zero.
        alter = (flux[peak_idx] - flux[start_idx - 2]) / (flux[peak_idx] - flux[end_idx + 2] + 1e-8)
        if np.isnan(alter) or alter < 0.5 or alter > 2:
            return False

        return True

    def postprocess(self, model_outputs, **kwargs):
        """Turn model outputs into a list of detected flare events.

        Args:
            model_outputs: Model output exposing a ``logits`` attribute.

        Returns:
            A flat list of dictionaries, one per accepted flare, with keys
            ``start_idx``, ``peak_idx``, ``end_idx``, ``start_time``,
            ``peak_time``, ``end_time``, ``duration`` and ``confidence``
            (mean sigmoid probability over the segment). Events from all
            light curves are appended to the same list.
        """
        logits = model_outputs.logits
        predictions = torch.sigmoid(logits).squeeze(-1)
        binary_predictions = (predictions > 0.5).long()

        # Convert to numpy for processing
        predictions_np = binary_predictions.cpu().numpy()
        flux_data = self.input_features.cpu().numpy()

        flare_events = []

        for i in range(predictions_np.shape[0]):
            flux = flux_data[i, :, 0]  # drop the channel dimension
            flare_idx = np.where(predictions_np[i] == 1)[0]

            if len(flare_idx) == 0:
                continue

            # Split the positive cadences wherever consecutive indices are not adjacent
            splits = np.where(np.diff(flare_idx) > 1)[0] + 1
            segments = np.split(flare_idx, splits)

            for segment in segments:
                # Skip short segments early
                if len(segment) < 3:
                    continue

                start_idx = segment[0]
                end_idx = segment[-1]

                # Find peak within segment
                segment_flux = flux[start_idx:end_idx + 1]
                peak_idx = np.argmax(segment_flux) + start_idx

                # Validate flare characteristics
                if not self._is_valid_flare(flux, start_idx, end_idx, peak_idx):
                    continue

                # Valid flare event found
                start_time = float(self.time_series[i][start_idx])
                end_time = float(self.time_series[i][end_idx])
                flare_events.append({
                    "start_idx": int(start_idx),
                    "peak_idx": int(peak_idx),
                    "end_idx": int(end_idx),
                    "start_time": start_time,
                    "peak_time": float(self.time_series[i][peak_idx]),
                    "end_time": end_time,
                    "duration": end_time - start_time,
                    "confidence": float(predictions[i, segment].mean()),
                })

        return flare_events
212
+
213
def load_flare_detection_pipeline(
    model_name: str = "Maxwell-Jia/fcn4flare",
    device: int = -1,
    **kwargs
) -> FlareDetectionPipeline:
    """
    Load a flare detection pipeline.

    The pipeline constructor is given a loaded model object, so a model name
    or path is first resolved with ``AutoModel.from_pretrained``.

    Args:
        model_name (str): The model name or path to load. An already loaded
            model object is passed through unchanged.
        device (int): Device to use (-1 for CPU, GPU number otherwise)
        **kwargs: Additional arguments to pass to the pipeline. An optional
            ``trust_remote_code`` entry (default ``True``) is consumed here
            and forwarded to ``AutoModel.from_pretrained`` instead.

    Returns:
        FlareDetectionPipeline: A pipeline for flare detection
    """
    # Function-scope import keeps this block self-contained; the file already
    # depends on the transformers package.
    from transformers import AutoModel

    model = model_name
    if isinstance(model_name, str):
        # NOTE: the model is resolved to custom modeling code hosted with the
        # checkpoint, which only loads when remote code is trusted. This runs
        # code from the model repository; pass trust_remote_code=False to opt out.
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)

    return FlareDetectionPipeline(
        model=model,
        device=device,
        **kwargs
    )