Upload FlareDetectionPipeline
- config.json +10 -1
- flare_detection.py +233 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "Maxwell-Jia/fcn4flare",
   "architectures": [
     "FCN4FlareModel"
   ],
@@ -7,6 +7,15 @@
     "AutoConfig": "Maxwell-Jia/fcn4flare--configuration_fcn4flare.FCN4FlareConfig",
     "AutoModel": "Maxwell-Jia/fcn4flare--modeling_fcn4flare.FCN4FlareModel"
   },
+  "custom_pipelines": {
+    "flare_detection": {
+      "impl": "flare_detection.FlareDetectionPipeline",
+      "pt": [
+        "AutoModel"
+      ],
+      "tf": []
+    }
+  },
   "depth": 10,
   "dilation": [
     1,
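With the custom_pipelines entry added above, the custom task should be loadable through the standard transformers pipeline() factory. A minimal usage sketch, assuming the usual custom-pipeline mechanism (the task name and repo id come from this config; trust_remote_code=True is needed because the pipeline code lives in the model repository):

from transformers import pipeline

# "flare_detection" resolves to flare_detection.FlareDetectionPipeline via the
# custom_pipelines entry above; remote code must be trusted for repo-hosted code.
flare_pipe = pipeline(
    "flare_detection",
    model="Maxwell-Jia/fcn4flare",
    trust_remote_code=True,
)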
flare_detection.py
ADDED
@@ -0,0 +1,233 @@
from typing import Dict, List, Union
import numpy as np
import torch
from transformers import Pipeline
from astropy.io import fits

class FlareDetectionPipeline(Pipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.call_count = 0

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}

        # Add parameters that need to be passed to specific steps
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, light_curve: Union[np.ndarray, str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Preprocess the input light curve from FITS files.

        Args:
            light_curve: Single FITS file path, list of FITS file paths, or numpy array
        """
        # Convert single path to list
        if isinstance(light_curve, str):
            light_curve = [light_curve]

        # Handle list of FITS file paths
        if isinstance(light_curve, list) and isinstance(light_curve[0], str):
            # Read data from all FITS files
            flux_data = []
            times_data = []
            lengths = []  # Store lengths of each light curve

            # First pass: get max length and collect data
            max_length = 0
            for fits_path in light_curve:
                with fits.open(fits_path) as hdul:
                    time = hdul[1].data['TIME'].astype(np.float32)
                    flux = hdul[1].data['PDCSAP_FLUX'].astype(np.float32)
                    # Normalize flux
                    flux = flux / np.nanmedian(flux)

                    max_length = max(max_length, len(flux))
                    lengths.append(len(flux))
                    flux_data.append(flux)
                    times_data.append(time)

            # Second pass: pad sequences
            padded_flux = []
            padded_times = []
            sequence_mask = []

            for flux, time, length in zip(flux_data, times_data, lengths):
                # Create padding
                pad_length = max_length - length

                # Pad flux and time arrays
                padded_f = np.pad(flux, (0, pad_length), mode='constant', constant_values=np.nan)
                padded_t = np.pad(time, (0, pad_length), mode='constant', constant_values=np.nan)

                # Create mask (1 for real values, 0 for padding)
                mask = np.ones(length)
                mask = np.pad(mask, (0, pad_length), mode='constant', constant_values=0)

                padded_flux.append(padded_f)
                padded_times.append(padded_t)
                sequence_mask.append(mask)

            # Store time data as attribute for use in postprocessing
            self.time_series = np.array(padded_times)
            # Convert to arrays
            flux_array = np.array(padded_flux)
            sequence_mask = np.array(sequence_mask)

            # Add channel dimension
            flux_array = flux_array.reshape(flux_array.shape[0], flux_array.shape[1], 1)

            # Convert to torch tensors
            inputs = torch.tensor(flux_array, dtype=torch.float32)
            mask = torch.tensor(sequence_mask, dtype=torch.float32)

            return {
                "input_features": inputs,
                "sequence_mask": mask
            }
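    # Shape note (illustrative): for two light curves of 1000 and 1200 cadences,
    # preprocess returns input_features of shape (2, 1200, 1) holding NaN-padded,
    # median-normalized flux, and sequence_mask of shape (2, 1200) with 1 for
    # real cadences and 0 for padding.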
    def _forward(self, model_inputs, **forward_params):
        """Forward pass through the model.

        Args:
            model_inputs: Dictionary containing input tensors
            forward_params: Additional parameters for the forward pass
        """
        if model_inputs is None:
            raise ValueError("model_inputs cannot be None. Check that the preprocess method returns a correct dictionary.")

        if "input_features" not in model_inputs:
            raise KeyError("model_inputs must contain 'input_features' key.")

        # Save input_features for use in postprocessing
        self.input_features = model_inputs["input_features"]

        # Ensure input_features is properly passed to the model
        return self.model(
            input_features=model_inputs["input_features"],
            sequence_mask=model_inputs.get("sequence_mask", None),
            return_dict=True
        )

    def postprocess(self, model_outputs, **kwargs):
        """
        Postprocess the model outputs to detect flare events.
        Returns a list of dictionaries containing flare event information.
        """
        logits = model_outputs.logits
        predictions = torch.sigmoid(logits).squeeze(-1)
        binary_predictions = (predictions > 0.5).long()

        # Convert to numpy for processing
        predictions_np = binary_predictions.cpu().numpy()
        flux_data = self.input_features.cpu().numpy()

        flare_events = []

        def is_valid_flare(flux, start_idx, end_idx, peak_idx):
            """Helper function to validate flare events

            Args:
                flux: Array of flux values
                start_idx: Start index of potential flare
                end_idx: End index of potential flare
                peak_idx: Peak index of potential flare

            Returns:
                bool: True if the event is a valid flare, False otherwise
            """
            # A flare should span at least 3 cadences
            if end_idx - start_idx < 2:
                return False

            try:
                # If the start time is the peak time, flux[start] must be greater than flux[start-1]
                if peak_idx == start_idx and flux[peak_idx] <= flux[peak_idx - 1]:
                    return False

                # Time for flux to decrease should be longer than the time to increase
                if end_idx - peak_idx <= peak_idx - start_idx:
                    return False

                # Check flux level consistency before and after the flare
                alter = (flux[peak_idx] - flux[start_idx - 2]) / (flux[peak_idx] - flux[end_idx + 2] + 1e-8)
                # Flux level should be similar before and after the flare
                if alter < 0.5 or alter > 2 or np.isnan(alter):
                    return False

                # Check if the slope before peak is too steep
                # if np.abs(flux[peak_idx] - flux[peak_idx-1]) < 1.2 * np.abs(flux[peak_idx-1] - flux[peak_idx-2]):
                #     return False

            except (IndexError, ValueError):
                return False

            return True

        for i in range(predictions_np.shape[0]):
            pred = predictions_np[i]
            flux = flux_data[i, :, 0]  # Get flux data
            flare_idx = np.where(pred == 1)[0]

            if len(flare_idx) == 0:
                continue

            # Find continuous segments
            splits = np.where(np.diff(flare_idx) > 1)[0] + 1
            segments = np.split(flare_idx, splits)
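            # Worked example: flare_idx = [10, 11, 12, 30, 31] gives
            # np.diff(flare_idx) = [1, 1, 18, 1], so splits = [3] and
            # segments = [[10, 11, 12], [30, 31]], i.e. one candidate per
            # contiguous run of positive predictions.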

            for segment in segments:
                # Skip short segments early
                if len(segment) < 3:
                    continue

                start_idx = segment[0]
                end_idx = segment[-1]

                # Find peak within segment
                segment_flux = flux[start_idx:end_idx+1]
                peak_idx = np.argmax(segment_flux) + start_idx

                # Validate flare characteristics
                if not is_valid_flare(flux, start_idx, end_idx, peak_idx):
                    continue

                # Valid flare event found
                start_time = float(self.time_series[i][start_idx])
                end_time = float(self.time_series[i][end_idx])
                duration = end_time - start_time
                event = {
                    "start_idx": int(start_idx),
                    "peak_idx": int(peak_idx),
                    "end_idx": int(end_idx),
                    "start_time": start_time,
                    "peak_time": float(self.time_series[i][peak_idx]),
                    "end_time": end_time,
                    "duration": duration,
                    "confidence": float(predictions[i, segment].mean()),
                }
                flare_events.append(event)

        return flare_events

def load_flare_detection_pipeline(
    model_name: str = "Maxwell-Jia/fcn4flare",
    device: int = -1,
    **kwargs
) -> FlareDetectionPipeline:
    """
    Load a flare detection pipeline.

    Args:
        model_name (str): The model name or path to load
        device (int): Device to use (-1 for CPU, GPU number otherwise)
        **kwargs: Additional arguments to pass to the pipeline

    Returns:
        FlareDetectionPipeline: A pipeline for flare detection
    """
    return FlareDetectionPipeline(
        model=model_name,
        device=device,
        **kwargs
    )
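Example of running the uploaded pipeline end to end, as a sketch: the FITS path is a placeholder, and the file is expected to carry TIME and PDCSAP_FLUX columns in HDU 1, as read in preprocess above.

from transformers import pipeline

flare_pipe = pipeline(
    "flare_detection",
    model="Maxwell-Jia/fcn4flare",
    trust_remote_code=True,
)

# Placeholder path to a TESS/Kepler-style PDCSAP light curve file.
events = flare_pipe("tess_lightcurve.fits")

# Each detected flare is a dict with start/peak/end indices and times,
# duration, and a mean sigmoid confidence over the flagged cadences.
for event in events:
    print(event["start_time"], event["peak_time"], event["end_time"], event["confidence"])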