Upload ZetaNet-v2.py
Browse files- ZetaNet-v2.py +383 -0
ZetaNet-v2.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from sklearn.preprocessing import StandardScaler
|
9 |
+
from time import time
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings('ignore')
|
12 |
+
|
13 |
+
class ImprovedZetaNet(nn.Module):
|
14 |
+
def __init__(self, input_size=2, hidden_sizes=[128, 256, 128, 64], output_size=2, dropout_rate=0.1):
|
15 |
+
super(ImprovedZetaNet, self).__init__()
|
16 |
+
|
17 |
+
# Construir camadas dinamicamente
|
18 |
+
layers = []
|
19 |
+
prev_size = input_size
|
20 |
+
|
21 |
+
for hidden_size in hidden_sizes:
|
22 |
+
layers.extend([
|
23 |
+
nn.Linear(prev_size, hidden_size),
|
24 |
+
nn.BatchNorm1d(hidden_size),
|
25 |
+
nn.ReLU(),
|
26 |
+
nn.Dropout(dropout_rate)
|
27 |
+
])
|
28 |
+
prev_size = hidden_size
|
29 |
+
|
30 |
+
# Camada de saída sem ativação
|
31 |
+
layers.append(nn.Linear(prev_size, output_size))
|
32 |
+
|
33 |
+
self.network = nn.Sequential(*layers)
|
34 |
+
|
35 |
+
# Inicialização Xavier/Glorot
|
36 |
+
self._initialize_weights()
|
37 |
+
|
38 |
+
def _initialize_weights(self):
|
39 |
+
for module in self.modules():
|
40 |
+
if isinstance(module, nn.Linear):
|
41 |
+
nn.init.xavier_normal_(module.weight)
|
42 |
+
if module.bias is not None:
|
43 |
+
nn.init.constant_(module.bias, 0)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return self.network(x)
|
47 |
+
|
48 |
+
class ZetaTrainer:
|
49 |
+
def __init__(self, model, device='cpu'):
|
50 |
+
self.model = model.to(device)
|
51 |
+
self.device = device
|
52 |
+
self.train_losses = []
|
53 |
+
self.val_losses = []
|
54 |
+
|
55 |
+
def train_epoch(self, train_loader, optimizer, criterion):
|
56 |
+
self.model.train()
|
57 |
+
total_loss = 0
|
58 |
+
num_batches = 0
|
59 |
+
|
60 |
+
for batch_x, batch_y in train_loader:
|
61 |
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
|
62 |
+
|
63 |
+
optimizer.zero_grad()
|
64 |
+
predictions = self.model(batch_x)
|
65 |
+
loss = criterion(predictions, batch_y)
|
66 |
+
loss.backward()
|
67 |
+
|
68 |
+
# Gradient clipping para estabilidade
|
69 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
70 |
+
|
71 |
+
optimizer.step()
|
72 |
+
|
73 |
+
total_loss += loss.item()
|
74 |
+
num_batches += 1
|
75 |
+
|
76 |
+
return total_loss / num_batches
|
77 |
+
|
78 |
+
def validate(self, val_loader, criterion):
|
79 |
+
self.model.eval()
|
80 |
+
total_loss = 0
|
81 |
+
num_batches = 0
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
for batch_x, batch_y in val_loader:
|
85 |
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
|
86 |
+
predictions = self.model(batch_x)
|
87 |
+
loss = criterion(predictions, batch_y)
|
88 |
+
total_loss += loss.item()
|
89 |
+
num_batches += 1
|
90 |
+
|
91 |
+
return total_loss / num_batches
|
92 |
+
|
93 |
+
def train(self, train_loader, val_loader, epochs=200, learning_rate=0.001, patience=20):
|
94 |
+
# Usar Adam com weight decay
|
95 |
+
optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
|
96 |
+
|
97 |
+
# Learning rate scheduler
|
98 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
99 |
+
optimizer, mode='min', factor=0.5, patience=10, verbose=True
|
100 |
+
)
|
101 |
+
|
102 |
+
criterion = nn.MSELoss()
|
103 |
+
best_val_loss = float('inf')
|
104 |
+
patience_counter = 0
|
105 |
+
|
106 |
+
print(f"Iniciando treinamento por {epochs} épocas...")
|
107 |
+
print("-" * 60)
|
108 |
+
|
109 |
+
for epoch in range(epochs):
|
110 |
+
# Treinar
|
111 |
+
train_loss = self.train_epoch(train_loader, optimizer, criterion)
|
112 |
+
|
113 |
+
# Validar
|
114 |
+
val_loss = self.validate(val_loader, criterion)
|
115 |
+
|
116 |
+
# Atualizar scheduler
|
117 |
+
scheduler.step(val_loss)
|
118 |
+
|
119 |
+
# Salvar histórico
|
120 |
+
self.train_losses.append(train_loss)
|
121 |
+
self.val_losses.append(val_loss)
|
122 |
+
|
123 |
+
# Early stopping
|
124 |
+
if val_loss < best_val_loss:
|
125 |
+
best_val_loss = val_loss
|
126 |
+
patience_counter = 0
|
127 |
+
# Salvar melhor modelo
|
128 |
+
torch.save(self.model.state_dict(), 'best_zetanet.pth')
|
129 |
+
else:
|
130 |
+
patience_counter += 1
|
131 |
+
|
132 |
+
# Print progress
|
133 |
+
if (epoch + 1) % 20 == 0 or epoch == 0:
|
134 |
+
current_lr = optimizer.param_groups[0]['lr']
|
135 |
+
print(f"Época {epoch+1:3d}/{epochs} | "
|
136 |
+
f"Train Loss: {train_loss:.6f} | "
|
137 |
+
f"Val Loss: {val_loss:.6f} | "
|
138 |
+
f"LR: {current_lr:.2e}")
|
139 |
+
|
140 |
+
# Early stopping
|
141 |
+
if patience_counter >= patience:
|
142 |
+
print(f"\nEarly stopping na época {epoch+1}")
|
143 |
+
break
|
144 |
+
|
145 |
+
# Carregar melhor modelo
|
146 |
+
self.model.load_state_dict(torch.load('best_zetanet.pth'))
|
147 |
+
print(f"\nTreinamento concluído! Melhor perda de validação: {best_val_loss:.6f}")
|
148 |
+
|
149 |
+
def parse_complex_improved(value):
|
150 |
+
"""Função melhorada para parsing de números complexos"""
|
151 |
+
if pd.isna(value):
|
152 |
+
return np.nan
|
153 |
+
|
154 |
+
value = str(value).strip()
|
155 |
+
|
156 |
+
# Remover parênteses
|
157 |
+
value = value.replace('(', '').replace(')', '')
|
158 |
+
|
159 |
+
# Substituir vírgulas por pontos
|
160 |
+
value = value.replace(',', '.')
|
161 |
+
|
162 |
+
# Casos especiais
|
163 |
+
if value == '' or value.lower() == 'nan':
|
164 |
+
return np.nan
|
165 |
+
|
166 |
+
try:
|
167 |
+
# Se não tem 'j' ou 'i', adicionar 'j' no final
|
168 |
+
if 'j' not in value.lower() and 'i' not in value.lower():
|
169 |
+
if '+' in value or '-' in value[1:]: # Tem parte real e imaginária
|
170 |
+
value += 'j'
|
171 |
+
else: # Só parte real
|
172 |
+
return complex(float(value), 0)
|
173 |
+
|
174 |
+
# Substituir 'i' por 'j'
|
175 |
+
value = value.replace('i', 'j')
|
176 |
+
|
177 |
+
return complex(value)
|
178 |
+
except (ValueError, TypeError):
|
179 |
+
return np.nan
|
180 |
+
|
181 |
+
def load_and_preprocess_data(filepath, test_size=0.2, random_state=42):
|
182 |
+
"""Carrega e preprocessa os dados com melhor tratamento de erros"""
|
183 |
+
print("Carregando dados...")
|
184 |
+
|
185 |
+
try:
|
186 |
+
data = pd.read_csv(filepath)
|
187 |
+
print(f"Dados carregados: {len(data)} amostras")
|
188 |
+
except FileNotFoundError:
|
189 |
+
print(f"Arquivo {filepath} não encontrado!")
|
190 |
+
return None
|
191 |
+
|
192 |
+
# Limpar e converter dados complexos
|
193 |
+
print("Processando números complexos...")
|
194 |
+
data['s'] = data['s'].apply(parse_complex_improved)
|
195 |
+
data['zeta(s)'] = data['zeta(s)'].apply(parse_complex_improved)
|
196 |
+
|
197 |
+
# Remover valores inválidos
|
198 |
+
initial_len = len(data)
|
199 |
+
data = data.dropna()
|
200 |
+
final_len = len(data)
|
201 |
+
|
202 |
+
if final_len < initial_len:
|
203 |
+
print(f"Removidas {initial_len - final_len} amostras inválidas")
|
204 |
+
|
205 |
+
if len(data) == 0:
|
206 |
+
print("Nenhum dado válido encontrado!")
|
207 |
+
return None
|
208 |
+
|
209 |
+
# Separar partes real e imaginária
|
210 |
+
data['s_real'] = data['s'].apply(lambda x: x.real)
|
211 |
+
data['s_imag'] = data['s'].apply(lambda x: x.imag)
|
212 |
+
data['zeta_real'] = data['zeta(s)'].apply(lambda x: x.real)
|
213 |
+
data['zeta_imag'] = data['zeta(s)'].apply(lambda x: x.imag)
|
214 |
+
|
215 |
+
# Preparar features e targets
|
216 |
+
X = data[['s_real', 's_imag']].values
|
217 |
+
y = data[['zeta_real', 'zeta_imag']].values
|
218 |
+
|
219 |
+
# Split treino/validação
|
220 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
221 |
+
X, y, test_size=test_size, random_state=random_state
|
222 |
+
)
|
223 |
+
|
224 |
+
# Normalização robusta
|
225 |
+
scaler_X = StandardScaler()
|
226 |
+
scaler_y = StandardScaler()
|
227 |
+
|
228 |
+
X_train_scaled = scaler_X.fit_transform(X_train)
|
229 |
+
X_val_scaled = scaler_X.transform(X_val)
|
230 |
+
y_train_scaled = scaler_y.fit_transform(y_train)
|
231 |
+
y_val_scaled = scaler_y.transform(y_val)
|
232 |
+
|
233 |
+
# Converter para tensores
|
234 |
+
X_train_tensor = torch.FloatTensor(X_train_scaled)
|
235 |
+
X_val_tensor = torch.FloatTensor(X_val_scaled)
|
236 |
+
y_train_tensor = torch.FloatTensor(y_train_scaled)
|
237 |
+
y_val_tensor = torch.FloatTensor(y_val_scaled)
|
238 |
+
|
239 |
+
print(f"Dados preprocessados:")
|
240 |
+
print(f" Treino: {len(X_train_tensor)} amostras")
|
241 |
+
print(f" Validação: {len(X_val_tensor)} amostras")
|
242 |
+
|
243 |
+
return {
|
244 |
+
'train': (X_train_tensor, y_train_tensor),
|
245 |
+
'val': (X_val_tensor, y_val_tensor),
|
246 |
+
'scalers': (scaler_X, scaler_y),
|
247 |
+
'raw_data': data
|
248 |
+
}
|
249 |
+
|
250 |
+
def create_data_loaders(data_dict, batch_size=64):
|
251 |
+
"""Cria DataLoaders do PyTorch"""
|
252 |
+
train_dataset = torch.utils.data.TensorDataset(
|
253 |
+
data_dict['train'][0], data_dict['train'][1]
|
254 |
+
)
|
255 |
+
val_dataset = torch.utils.data.TensorDataset(
|
256 |
+
data_dict['val'][0], data_dict['val'][1]
|
257 |
+
)
|
258 |
+
|
259 |
+
train_loader = torch.utils.data.DataLoader(
|
260 |
+
train_dataset, batch_size=batch_size, shuffle=True
|
261 |
+
)
|
262 |
+
val_loader = torch.utils.data.DataLoader(
|
263 |
+
val_dataset, batch_size=batch_size, shuffle=False
|
264 |
+
)
|
265 |
+
|
266 |
+
return train_loader, val_loader
|
267 |
+
|
268 |
+
def plot_results(trainer, data_dict, model):
|
269 |
+
"""Plota resultados do treinamento e predições"""
|
270 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
|
271 |
+
|
272 |
+
# 1. Curvas de perda
|
273 |
+
axes[0,0].plot(trainer.train_losses, label='Treino', alpha=0.8)
|
274 |
+
axes[0,0].plot(trainer.val_losses, label='Validação', alpha=0.8)
|
275 |
+
axes[0,0].set_xlabel('Época')
|
276 |
+
axes[0,0].set_ylabel('MSE Loss')
|
277 |
+
axes[0,0].set_title('Curvas de Aprendizado')
|
278 |
+
axes[0,0].legend()
|
279 |
+
axes[0,0].grid(True, alpha=0.3)
|
280 |
+
axes[0,0].set_yscale('log')
|
281 |
+
|
282 |
+
# 2. Predições vs Real (parte real)
|
283 |
+
model.eval()
|
284 |
+
with torch.no_grad():
|
285 |
+
X_val, y_val = data_dict['val']
|
286 |
+
predictions = model(X_val)
|
287 |
+
|
288 |
+
# Denormalizar
|
289 |
+
scaler_y = data_dict['scalers'][1]
|
290 |
+
y_val_denorm = scaler_y.inverse_transform(y_val.numpy())
|
291 |
+
pred_denorm = scaler_y.inverse_transform(predictions.numpy())
|
292 |
+
|
293 |
+
axes[0,1].scatter(y_val_denorm[:, 0], pred_denorm[:, 0], alpha=0.6, s=1)
|
294 |
+
axes[0,1].plot([y_val_denorm[:, 0].min(), y_val_denorm[:, 0].max()],
|
295 |
+
[y_val_denorm[:, 0].min(), y_val_denorm[:, 0].max()], 'r--')
|
296 |
+
axes[0,1].set_xlabel('ζ(s) Real - Valor Real')
|
297 |
+
axes[0,1].set_ylabel('ζ(s) Real - Predição')
|
298 |
+
axes[0,1].set_title('Parte Real: Predição vs Real')
|
299 |
+
axes[0,1].grid(True, alpha=0.3)
|
300 |
+
|
301 |
+
# 3. Predições vs Real (parte imaginária)
|
302 |
+
axes[1,0].scatter(y_val_denorm[:, 1], pred_denorm[:, 1], alpha=0.6, s=1)
|
303 |
+
axes[1,0].plot([y_val_denorm[:, 1].min(), y_val_denorm[:, 1].max()],
|
304 |
+
[y_val_denorm[:, 1].min(), y_val_denorm[:, 1].max()], 'r--')
|
305 |
+
axes[1,0].set_xlabel('ζ(s) Imag - Valor Real')
|
306 |
+
axes[1,0].set_ylabel('ζ(s) Imag - Predição')
|
307 |
+
axes[1,0].set_title('Parte Imaginária: Predição vs Real')
|
308 |
+
axes[1,0].grid(True, alpha=0.3)
|
309 |
+
|
310 |
+
# 4. Distribuição dos erros
|
311 |
+
errors_real = np.abs(y_val_denorm[:, 0] - pred_denorm[:, 0])
|
312 |
+
errors_imag = np.abs(y_val_denorm[:, 1] - pred_denorm[:, 1])
|
313 |
+
|
314 |
+
axes[1,1].hist(errors_real, bins=50, alpha=0.7, label='Erro Parte Real')
|
315 |
+
axes[1,1].hist(errors_imag, bins=50, alpha=0.7, label='Erro Parte Imag')
|
316 |
+
axes[1,1].set_xlabel('Erro Absoluto')
|
317 |
+
axes[1,1].set_ylabel('Frequência')
|
318 |
+
axes[1,1].set_title('Distribuição dos Erros')
|
319 |
+
axes[1,1].legend()
|
320 |
+
axes[1,1].grid(True, alpha=0.3)
|
321 |
+
axes[1,1].set_yscale('log')
|
322 |
+
|
323 |
+
plt.tight_layout()
|
324 |
+
plt.savefig('zetanet_results.png', dpi=300, bbox_inches='tight')
|
325 |
+
plt.show()
|
326 |
+
|
327 |
+
# Estatísticas
|
328 |
+
print(f"\nEstatísticas de Erro:")
|
329 |
+
print(f"Erro médio (parte real): {errors_real.mean():.6f}")
|
330 |
+
print(f"Erro médio (parte imag): {errors_imag.mean():.6f}")
|
331 |
+
print(f"Erro máximo (parte real): {errors_real.max():.6f}")
|
332 |
+
print(f"Erro máximo (parte imag): {errors_imag.max():.6f}")
|
333 |
+
|
334 |
+
def main():
|
335 |
+
start_time = time()
|
336 |
+
|
337 |
+
# Configurações
|
338 |
+
FILEPATH = "/content/combined_zeta_data.csv" # Ajuste o caminho
|
339 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
340 |
+
print(f"Usando dispositivo: {DEVICE}")
|
341 |
+
|
342 |
+
# Carregar e preprocessar dados
|
343 |
+
data_dict = load_and_preprocess_data(FILEPATH)
|
344 |
+
if data_dict is None:
|
345 |
+
return
|
346 |
+
|
347 |
+
# Criar data loaders
|
348 |
+
train_loader, val_loader = create_data_loaders(data_dict, batch_size=128)
|
349 |
+
|
350 |
+
# Criar modelo melhorado
|
351 |
+
model = ImprovedZetaNet(
|
352 |
+
input_size=2,
|
353 |
+
hidden_sizes=[128, 256, 256, 128, 64],
|
354 |
+
output_size=2,
|
355 |
+
dropout_rate=0.1
|
356 |
+
)
|
357 |
+
|
358 |
+
print(f"\nArquitetura do modelo:")
|
359 |
+
print(model)
|
360 |
+
|
361 |
+
# Contar parâmetros
|
362 |
+
total_params = sum(p.numel() for p in model.parameters())
|
363 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
364 |
+
print(f"\nParâmetros totais: {total_params:,}")
|
365 |
+
print(f"Parâmetros treináveis: {trainable_params:,}")
|
366 |
+
|
367 |
+
# Treinar modelo
|
368 |
+
trainer = ZetaTrainer(model, DEVICE)
|
369 |
+
trainer.train(
|
370 |
+
train_loader, val_loader,
|
371 |
+
epochs=300,
|
372 |
+
learning_rate=0.001,
|
373 |
+
patience=30
|
374 |
+
)
|
375 |
+
|
376 |
+
# Plotar resultados
|
377 |
+
plot_results(trainer, data_dict, model)
|
378 |
+
|
379 |
+
end_time = time()
|
380 |
+
print(f"\nTempo total de execução: {(end_time - start_time):.2f} segundos")
|
381 |
+
|
382 |
+
if __name__ == "__main__":
|
383 |
+
main()
|