"""
Simplified training script for the Universal Image Classifier.

Works with a basic PyTorch installation (no torchvision required).
"""

import os
import json
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
from PIL import Image

from models import create_model
from config import ModelConfig, TrainingConfig


class SimpleImageDataset(Dataset):
    """Simple image dataset that resizes and normalizes images without torchvision transforms."""

    def __init__(self, image_paths: List[str], labels: List[int], input_size: Tuple[int, int] = (64, 64)):
        self.image_paths = image_paths
        self.labels = labels
        self.input_size = input_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image and force a 3-channel RGB layout.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = image.resize(self.input_size)

        # Scale pixels to [0, 1] and reorder HWC -> CHW for PyTorch.
        image_array = np.array(image).astype(np.float32) / 255.0
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)

        return image_tensor, self.labels[idx]
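
# Usage sketch (illustrative only; file names and labels below are hypothetical):
#
#     dataset = SimpleImageDataset(['a.png', 'b.png'], [0, 1])
#     x, y = dataset[0]  # x: float32 tensor of shape (3, 64, 64) in [0, 1], y: int label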


def load_dataset_from_folder(data_dir: str) -> Tuple[List[str], List[int], List[str]]:
    """Load a dataset from a folder structure with one sub-folder per class.

    Only PNG files are collected; labels follow the sorted order of folder names.
    """
    data_path = Path(data_dir)

    if not data_path.exists():
        raise FileNotFoundError(f"Data directory '{data_dir}' not found")

    image_paths = []
    labels = []
    class_names = []

    # Sort class folders so label indices are deterministic across runs.
    class_dirs = sorted(d for d in data_path.iterdir() if d.is_dir())

    for class_idx, class_dir in enumerate(class_dirs):
        class_name = class_dir.name
        class_names.append(class_name)

        for img_path in class_dir.glob('*.png'):
            image_paths.append(str(img_path))
            labels.append(class_idx)

    print(f"Found {len(image_paths)} images across {len(class_names)} classes")
    print(f"Classes: {class_names}")

    return image_paths, labels, class_names
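
# Expected on-disk layout (sketch; folder names are illustrative):
#
#     sample_dataset/
#         class_a/  *.png
#         class_b/  *.png
#
# Class indices are assigned from the sorted order of the class folder names.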


def train_model():
    """Train the Universal Image Classifier on the sample dataset."""

    # Model hyperparameters (num_classes is overwritten once the data is loaded).
    model_config = ModelConfig(
        input_height=64,
        input_width=64,
        num_classes=4,
        hidden_dim=256,
        num_layers=6,
        dropout_rate=0.1,
        use_batch_norm=True
    )

    training_config = TrainingConfig(
        batch_size=32,
        learning_rate=0.001,
        num_epochs=20,
        weight_decay=1e-4,
        early_stopping_patience=10,
        validation_split=0.2
    )

    # Load the dataset produced by generate_sample_data.py.
    data_dir = "sample_dataset"

    try:
        image_paths, labels, class_names = load_dataset_from_folder(data_dir)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please run 'python generate_sample_data.py' first to create sample data")
        return

    # Match the model's output size to the number of discovered classes.
    model_config.num_classes = len(class_names)

    # Shuffle once with a fixed seed, then split according to validation_split.
    total_samples = len(image_paths)
    train_size = int((1.0 - training_config.validation_split) * total_samples)

    indices = list(range(total_samples))
    np.random.seed(42)
    np.random.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    train_paths = [image_paths[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    val_paths = [image_paths[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]

    train_dataset = SimpleImageDataset(train_paths, train_labels)
    val_dataset = SimpleImageDataset(val_paths, val_labels)

    train_loader = DataLoader(train_dataset, batch_size=training_config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size, shuffle=False)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Build the model and report its size.
    model_name = "mlp_deep_residual"
    model = create_model(model_name, model_config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nModel: {model_name}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=training_config.learning_rate,
                           weight_decay=training_config.weight_decay)
    # Halve the learning rate when the validation loss plateaus for 5 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    os.makedirs('outputs', exist_ok=True)

    best_val_acc = 0.0
    patience_counter = 0

    print(f"\n🚀 Starting training for {training_config.num_epochs} epochs...")

    for epoch in range(training_config.num_epochs):
        # ---- Training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Accumulate running loss and accuracy statistics.
            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}/{training_config.num_epochs} [{batch_idx * len(data)}/{len(train_dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)] Loss: {loss.item():.6f}')

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()

                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        # Average losses over batches and convert counts to percentages.
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f'Epoch {epoch+1}/{training_config.num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # ReduceLROnPlateau steps on the monitored metric, not the epoch.
        scheduler.step(val_loss)

        # ---- Checkpointing and early stopping ----
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0

            # Save everything needed to rebuild the model for inference.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'model_config': model_config.__dict__,
                'class_names': class_names
            }, 'outputs/best_model.pth')

            print(f'  ✓ New best model saved! Val Acc: {val_acc:.2f}%')
        else:
            patience_counter += 1

        if patience_counter >= training_config.early_stopping_patience:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

        print()

    print('🎉 Training completed!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
    print('Model saved to: outputs/best_model.pth')
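
    # The checkpoint above can later be restored for inference, roughly like this
    # (sketch; assumes ModelConfig accepts its saved fields as keyword arguments):
    #
    #     ckpt = torch.load('outputs/best_model.pth', map_location='cpu')
    #     model = create_model('mlp_deep_residual', ModelConfig(**ckpt['model_config']))
    #     model.load_state_dict(ckpt['model_state_dict'])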

    # Persist label names and model config so inference can rebuild the model.
    with open('outputs/class_names.json', 'w') as f:
        json.dump(class_names, f)

    with open('outputs/model_config.json', 'w') as f:
        json.dump(model_config.__dict__, f, indent=2)

    print('Class names and config saved to outputs/')


if __name__ == "__main__":
    train_model()