ranjeetjha's picture
Upload 19 files (#1)
ab80e91 verified
"""
Simplified training script for Universal Image Classifier
Works with basic PyTorch installation
"""
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from pathlib import Path
import json
from typing import List, Tuple
from dataclasses import dataclass
# Import our models
from models import create_model, MODEL_REGISTRY
from config import ModelConfig, TrainingConfig
class SimpleImageDataset(Dataset):
    """Image dataset that loads, resizes and scales images without torchvision.

    Each item is an ``(image_tensor, label)`` pair. ``image_tensor`` is a
    float32 tensor in channel-first (C, H, W) layout with values scaled to
    [0, 1]; ``label`` is the entry of ``labels`` at the same index.

    Args:
        image_paths: Paths of the image files, one per sample.
        labels: Labels aligned index-for-index with ``image_paths``.
        input_size: Target size handed unchanged to ``PIL.Image.resize``,
            which interprets the pair as ``(width, height)``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths: List[str], labels: List[int], input_size: Tuple[int, int] = (64, 64)):
        # Fail fast: a length mismatch would otherwise only surface as an
        # IndexError in the middle of an epoch (or silently drop labels).
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.input_size = input_size

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``."""
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns an independent in-memory copy,
        # so the result stays usable after the file is closed.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')

        # Resize to the configured (width, height)
        image = image.resize(self.input_size)

        # Scale 8-bit channel values to [0, 1] and move channels first (HWC -> CHW)
        image_array = np.array(image).astype(np.float32) / 255.0
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)

        return image_tensor, self.labels[idx]
def load_dataset_from_folder(data_dir: str) -> Tuple[List[str], List[int], List[str]]:
    """Collect PNG image paths and labels from a one-folder-per-class layout.

    Every immediate subdirectory of ``data_dir`` is treated as one class.
    Classes are numbered in sorted order of their directory names, and the
    ``*.png`` files directly inside each class directory become its samples.
    A subdirectory without PNG files still counts as a class.

    Args:
        data_dir: Root directory containing one subdirectory per class.

    Returns:
        A tuple ``(image_paths, labels, class_names)`` where ``image_paths[i]``
        is the path string of sample ``i``, ``labels[i]`` is its class index,
        and ``class_names[k]`` is the directory name of class ``k``.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
    """
    data_path = Path(data_dir)
    if not data_path.exists():
        raise FileNotFoundError(f"Data directory '{data_dir}' not found")

    image_paths: List[str] = []
    labels: List[int] = []
    class_names: List[str] = []

    # Sort class directories so class indices are stable between runs
    class_dirs = sorted(d for d in data_path.iterdir() if d.is_dir())

    for class_idx, class_dir in enumerate(class_dirs):
        class_names.append(class_dir.name)

        # glob() yields entries in filesystem order, which differs between
        # machines; sorting keeps the sample order (and therefore any seeded
        # shuffle applied to it later) reproducible.
        class_images = sorted(class_dir.glob('*.png'))
        image_paths.extend(str(img_path) for img_path in class_images)
        labels.extend([class_idx] * len(class_images))

    print(f"Found {len(image_paths)} images across {len(class_names)} classes")
    print(f"Classes: {class_names}")

    return image_paths, labels, class_names
def train_model(data_dir: str = "sample_dataset", model_name: str = "mlp_deep_residual"):
    """Train the Universal Image Classifier and save the best checkpoint.

    Pipeline: build the model/training configs, load a folder-per-class
    dataset, split it into train/validation with a fixed seed, train with
    Adam + ReduceLROnPlateau, checkpoint whenever validation accuracy
    improves, and stop early when it has not improved for
    ``early_stopping_patience`` consecutive epochs.

    Files written (all under ``outputs/``):
        * ``best_model.pth`` - checkpoint of the best epoch (only when
          validation accuracy rose above 0%).
        * ``class_names.json`` - list of class names.
        * ``model_config.json`` - the model configuration attributes.

    Args:
        data_dir: Root directory of the dataset (one subdirectory per class).
        model_name: Name passed to ``create_model`` to pick the architecture.

    Returns:
        None. If the dataset directory is missing, or the dataset is too
        small to yield non-empty train and validation splits, an error is
        printed and the function returns without training.
    """
    # Configuration
    model_config = ModelConfig(
        input_height=64,
        input_width=64,
        num_classes=4,
        hidden_dim=256,
        num_layers=6,
        dropout_rate=0.1,
        use_batch_norm=True
    )
    training_config = TrainingConfig(
        batch_size=32,
        learning_rate=0.001,
        num_epochs=20,  # Reduced for faster training
        weight_decay=1e-4,
        early_stopping_patience=10,
        validation_split=0.2
    )

    # Load dataset
    try:
        image_paths, labels, class_names = load_dataset_from_folder(data_dir)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please run 'python generate_sample_data.py' first to create sample data")
        return

    # Update model config with actual number of classes
    model_config.num_classes = len(class_names)

    # Split dataset using the configured validation fraction
    total_samples = len(image_paths)
    train_size = int((1.0 - training_config.validation_split) * total_samples)

    # An empty split would crash later (DataLoader shuffle on an empty set,
    # division by zero in the metrics), so bail out with a clear message.
    if train_size == 0 or train_size == total_samples:
        print(f"Error: found {total_samples} images in '{data_dir}', which is not enough "
              f"for non-empty training and validation splits")
        return

    # Simple random split
    indices = list(range(total_samples))
    np.random.seed(42)  # For reproducibility
    np.random.shuffle(indices)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    train_paths = [image_paths[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    val_paths = [image_paths[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]

    # Create datasets; the resize target follows the model config so the two
    # cannot drift apart. PIL expects the size as (width, height).
    input_size = (model_config.input_width, model_config.input_height)
    train_dataset = SimpleImageDataset(train_paths, train_labels, input_size=input_size)
    val_dataset = SimpleImageDataset(val_paths, val_labels, input_size=input_size)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=training_config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=training_config.batch_size, shuffle=False)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Create model
    model = create_model(model_name, model_config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nModel: {model_name}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Setup training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=training_config.learning_rate, weight_decay=training_config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    # Create outputs directory
    os.makedirs('outputs', exist_ok=True)

    # Training loop
    best_val_acc = 0.0
    patience_counter = 0
    checkpoint_saved = False  # whether best_model.pth was written during this run

    print(f"\n🚀 Starting training for {training_config.num_epochs} epochs...")

    for epoch in range(training_config.num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}/{training_config.num_epochs} [{batch_idx * len(data)}/{len(train_dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)] Loss: {loss.item():.6f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        # Calculate metrics (losses are averaged over batches, not samples)
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f'Epoch {epoch+1}/{training_config.num_epochs}:')
        print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'model_config': model_config.__dict__,
                'class_names': class_names
            }, 'outputs/best_model.pth')
            checkpoint_saved = True
            print(f' ✓ New best model saved! Val Acc: {val_acc:.2f}%')
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= training_config.early_stopping_patience:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

        print()

    print('🎉 Training completed!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
    # Only claim a checkpoint exists if this run actually wrote one
    if checkpoint_saved:
        print('Model saved to: outputs/best_model.pth')
    else:
        print('No model checkpoint was saved: validation accuracy never rose above 0%')

    # Save class names and config
    with open('outputs/class_names.json', 'w', encoding='utf-8') as f:
        json.dump(class_names, f)

    with open('outputs/model_config.json', 'w', encoding='utf-8') as f:
        json.dump(model_config.__dict__, f, indent=2)

    print('Class names and config saved to outputs/')
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_model()