JoseRFJunior committed on
Commit 6a77297 · verified · 1 Parent(s): 3e017bb

Upload ZetaNet-v2.py

Files changed (1)
  1. ZetaNet-v2.py +383 -0
ZetaNet-v2.py ADDED
@@ -0,0 +1,383 @@
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from time import time
import warnings
warnings.filterwarnings('ignore')

class ImprovedZetaNet(nn.Module):
    def __init__(self, input_size=2, hidden_sizes=[128, 256, 128, 64], output_size=2, dropout_rate=0.1):
        super(ImprovedZetaNet, self).__init__()

        # Build the hidden layers dynamically
        layers = []
        prev_size = input_size

        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            prev_size = hidden_size

        # Output layer without activation
        layers.append(nn.Linear(prev_size, output_size))

        self.network = nn.Sequential(*layers)

        # Xavier/Glorot initialization
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        return self.network(x)

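
# Illustrative shape check (a sketch, not part of the original script): the network maps a
# batch of (Re s, Im s) pairs to a batch of (Re zeta(s), Im zeta(s)) pairs. The helper name
# and dummy data are assumptions; eval mode is used so BatchNorm1d needs no batch statistics.
def _smoke_test_shapes():
    net = ImprovedZetaNet()
    net.eval()
    dummy_s = torch.randn(4, 2)   # four sample points s = sigma + i*t
    out = net(dummy_s)
    assert out.shape == (4, 2), out.shape
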
class ZetaTrainer:
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.train_losses = []
        self.val_losses = []

    def train_epoch(self, train_loader, optimizer, criterion):
        self.model.train()
        total_loss = 0
        num_batches = 0

        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)

            optimizer.zero_grad()
            predictions = self.model(batch_x)
            loss = criterion(predictions, batch_y)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches

    def validate(self, val_loader, criterion):
        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                predictions = self.model(batch_x)
                loss = criterion(predictions, batch_y)
                total_loss += loss.item()
                num_batches += 1

        return total_loss / num_batches

    def train(self, train_loader, val_loader, epochs=200, learning_rate=0.001, patience=20):
        # Use AdamW with weight decay
        optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)

        # Learning rate scheduler
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10, verbose=True
        )

        criterion = nn.MSELoss()
        best_val_loss = float('inf')
        patience_counter = 0

        print(f"Starting training for {epochs} epochs...")
        print("-" * 60)

        for epoch in range(epochs):
            # Train
            train_loss = self.train_epoch(train_loader, optimizer, criterion)

            # Validate
            val_loss = self.validate(val_loader, criterion)

            # Update scheduler
            scheduler.step(val_loss)

            # Record history
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            # Early stopping bookkeeping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save the best model
                torch.save(self.model.state_dict(), 'best_zetanet.pth')
            else:
                patience_counter += 1

            # Print progress
            if (epoch + 1) % 20 == 0 or epoch == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch+1:3d}/{epochs} | "
                      f"Train Loss: {train_loss:.6f} | "
                      f"Val Loss: {val_loss:.6f} | "
                      f"LR: {current_lr:.2e}")

            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

        # Load the best model
        self.model.load_state_dict(torch.load('best_zetanet.pth'))
        print(f"\nTraining finished! Best validation loss: {best_val_loss:.6f}")

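
# Illustrative sketch (not part of the original script): the trainer checkpoints its best
# weights to 'best_zetanet.pth', so a trained network can be restored later. The helper name
# and the default hidden_sizes (matching the configuration used in main()) are assumptions.
def load_best_zetanet(device='cpu', hidden_sizes=[128, 256, 256, 128, 64]):
    model = ImprovedZetaNet(input_size=2, hidden_sizes=hidden_sizes, output_size=2)
    state = torch.load('best_zetanet.pth', map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
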
def parse_complex_improved(value):
    """Improved parsing of complex numbers from strings."""
    if pd.isna(value):
        return np.nan

    value = str(value).strip()

    # Remove parentheses
    value = value.replace('(', '').replace(')', '')

    # Replace decimal commas with dots
    value = value.replace(',', '.')

    # Special cases
    if value == '' or value.lower() == 'nan':
        return np.nan

    try:
        # If there is no 'j' or 'i', append 'j'
        if 'j' not in value.lower() and 'i' not in value.lower():
            if '+' in value or '-' in value[1:]:  # has real and imaginary parts
                value += 'j'
            else:  # real part only
                return complex(float(value), 0)

        # Replace 'i' with 'j'
        value = value.replace('i', 'j')

        return complex(value)
    except (ValueError, TypeError):
        return np.nan

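
# Illustrative inputs the parser accepts (assumed formats, inferred from the
# transformations above; not part of the original script):
#   parse_complex_improved("(0.5+14.134725j)") -> (0.5+14.134725j)
#   parse_complex_improved("0,5+14,134725i")   -> (0.5+14.134725j)   # comma decimals, 'i' notation
#   parse_complex_improved("2.0")              -> (2+0j)             # purely real input
#   parse_complex_improved("")                 -> nan
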
def load_and_preprocess_data(filepath, test_size=0.2, random_state=42):
    """Loads and preprocesses the data with improved error handling."""
    print("Loading data...")

    try:
        data = pd.read_csv(filepath)
        print(f"Data loaded: {len(data)} samples")
    except FileNotFoundError:
        print(f"File {filepath} not found!")
        return None

    # Clean and convert the complex-valued columns
    print("Parsing complex numbers...")
    data['s'] = data['s'].apply(parse_complex_improved)
    data['zeta(s)'] = data['zeta(s)'].apply(parse_complex_improved)

    # Drop invalid values
    initial_len = len(data)
    data = data.dropna()
    final_len = len(data)

    if final_len < initial_len:
        print(f"Removed {initial_len - final_len} invalid samples")

    if len(data) == 0:
        print("No valid data found!")
        return None

    # Split into real and imaginary parts
    data['s_real'] = data['s'].apply(lambda x: x.real)
    data['s_imag'] = data['s'].apply(lambda x: x.imag)
    data['zeta_real'] = data['zeta(s)'].apply(lambda x: x.real)
    data['zeta_imag'] = data['zeta(s)'].apply(lambda x: x.imag)

    # Prepare features and targets
    X = data[['s_real', 's_imag']].values
    y = data[['zeta_real', 'zeta_imag']].values

    # Train/validation split
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Standardize features and targets
    scaler_X = StandardScaler()
    scaler_y = StandardScaler()

    X_train_scaled = scaler_X.fit_transform(X_train)
    X_val_scaled = scaler_X.transform(X_val)
    y_train_scaled = scaler_y.fit_transform(y_train)
    y_val_scaled = scaler_y.transform(y_val)

    # Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train_scaled)
    X_val_tensor = torch.FloatTensor(X_val_scaled)
    y_train_tensor = torch.FloatTensor(y_train_scaled)
    y_val_tensor = torch.FloatTensor(y_val_scaled)

    print("Preprocessed data:")
    print(f"  Train: {len(X_train_tensor)} samples")
    print(f"  Validation: {len(X_val_tensor)} samples")

    return {
        'train': (X_train_tensor, y_train_tensor),
        'val': (X_val_tensor, y_val_tensor),
        'scalers': (scaler_X, scaler_y),
        'raw_data': data
    }

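
# Expected input layout (an assumption inferred from the column names used above): a CSV
# with columns 's' and 'zeta(s)', each holding a complex number in a format that
# parse_complex_improved understands, e.g.
#
#   s,zeta(s)
#   (2+0j),(1.6449340668+0j)
#   (0.5+14.134725j),(0+0j)
#
# The sample rows are illustrative only, not taken from combined_zeta_data.csv.
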
def create_data_loaders(data_dict, batch_size=64):
    """Creates PyTorch DataLoaders for the train and validation splits."""
    train_dataset = torch.utils.data.TensorDataset(
        data_dict['train'][0], data_dict['train'][1]
    )
    val_dataset = torch.utils.data.TensorDataset(
        data_dict['val'][0], data_dict['val'][1]
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False
    )

    return train_loader, val_loader

def plot_results(trainer, data_dict, model):
    """Plots the training curves and the validation-set predictions."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Loss curves
    axes[0, 0].plot(trainer.train_losses, label='Train', alpha=0.8)
    axes[0, 0].plot(trainer.val_losses, label='Validation', alpha=0.8)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('MSE Loss')
    axes[0, 0].set_title('Learning Curves')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # 2. Predicted vs. actual (real part)
    model.eval()
    device = next(model.parameters()).device  # keep inputs on the same device as the model
    with torch.no_grad():
        X_val, y_val = data_dict['val']
        predictions = model(X_val.to(device)).cpu()

    # Denormalize
    scaler_y = data_dict['scalers'][1]
    y_val_denorm = scaler_y.inverse_transform(y_val.numpy())
    pred_denorm = scaler_y.inverse_transform(predictions.numpy())

    axes[0, 1].scatter(y_val_denorm[:, 0], pred_denorm[:, 0], alpha=0.6, s=1)
    axes[0, 1].plot([y_val_denorm[:, 0].min(), y_val_denorm[:, 0].max()],
                    [y_val_denorm[:, 0].min(), y_val_denorm[:, 0].max()], 'r--')
    axes[0, 1].set_xlabel('Re ζ(s) - actual')
    axes[0, 1].set_ylabel('Re ζ(s) - predicted')
    axes[0, 1].set_title('Real Part: Predicted vs. Actual')
    axes[0, 1].grid(True, alpha=0.3)

    # 3. Predicted vs. actual (imaginary part)
    axes[1, 0].scatter(y_val_denorm[:, 1], pred_denorm[:, 1], alpha=0.6, s=1)
    axes[1, 0].plot([y_val_denorm[:, 1].min(), y_val_denorm[:, 1].max()],
                    [y_val_denorm[:, 1].min(), y_val_denorm[:, 1].max()], 'r--')
    axes[1, 0].set_xlabel('Im ζ(s) - actual')
    axes[1, 0].set_ylabel('Im ζ(s) - predicted')
    axes[1, 0].set_title('Imaginary Part: Predicted vs. Actual')
    axes[1, 0].grid(True, alpha=0.3)

    # 4. Error distribution
    errors_real = np.abs(y_val_denorm[:, 0] - pred_denorm[:, 0])
    errors_imag = np.abs(y_val_denorm[:, 1] - pred_denorm[:, 1])

    axes[1, 1].hist(errors_real, bins=50, alpha=0.7, label='Real-part error')
    axes[1, 1].hist(errors_imag, bins=50, alpha=0.7, label='Imag-part error')
    axes[1, 1].set_xlabel('Absolute Error')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_title('Error Distribution')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].set_yscale('log')

    plt.tight_layout()
    plt.savefig('zetanet_results.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Error statistics
    print("\nError statistics:")
    print(f"Mean error (real part): {errors_real.mean():.6f}")
    print(f"Mean error (imag part): {errors_imag.mean():.6f}")
    print(f"Max error (real part): {errors_real.max():.6f}")
    print(f"Max error (imag part): {errors_imag.max():.6f}")

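
# Illustrative inference sketch (not part of the original script): to evaluate the trained
# network at a new point s, the same scaler_X/scaler_y pipeline used during training has to
# be applied on the way in and out. The helper name predict_zeta is hypothetical.
def predict_zeta(model, data_dict, s):
    scaler_X, scaler_y = data_dict['scalers']
    device = next(model.parameters()).device
    features = scaler_X.transform(np.array([[s.real, s.imag]]))
    model.eval()
    with torch.no_grad():
        out = model(torch.FloatTensor(features).to(device)).cpu().numpy()
    zeta_re, zeta_im = scaler_y.inverse_transform(out)[0]
    return complex(zeta_re, zeta_im)

# Example usage (hypothetical): predict_zeta(model, data_dict, 0.5 + 14.134725j)
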
def main():
    start_time = time()

    # Configuration
    FILEPATH = "/content/combined_zeta_data.csv"  # adjust the path
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    # Load and preprocess the data
    data_dict = load_and_preprocess_data(FILEPATH)
    if data_dict is None:
        return

    # Create data loaders
    train_loader, val_loader = create_data_loaders(data_dict, batch_size=128)

    # Create the improved model
    model = ImprovedZetaNet(
        input_size=2,
        hidden_sizes=[128, 256, 256, 128, 64],
        output_size=2,
        dropout_rate=0.1
    )

    print("\nModel architecture:")
    print(model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train the model
    trainer = ZetaTrainer(model, DEVICE)
    trainer.train(
        train_loader, val_loader,
        epochs=300,
        learning_rate=0.001,
        patience=30
    )

    # Plot results
    plot_results(trainer, data_dict, model)

    end_time = time()
    print(f"\nTotal run time: {(end_time - start_time):.2f} seconds")

if __name__ == "__main__":
    main()