import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    default_data_collator
)
from datasets import Dataset
from huggingface_hub import HfApi
# ============ CONFIG ============ #
MODEL_NAME = "wambugu71/crop_leaf_diseases_vit"
CSV_PATH = "dataset/labels.csv"
IMAGE_DIR = "dataset/images"
OUTPUT_DIR = "./vit_leaf_disease_model"
NUM_EPOCHS = 10
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
SEED = 42
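# The script assumes labels.csv has (at least) two columns:
#   image - image filename relative to IMAGE_DIR (e.g. "leaf_0001.jpg", an illustrative name)
#   label - class name as a string (e.g. "healthy")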
# Set random seed for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
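# Also seed CUDA RNGs when a GPU is available (optional, for fuller reproducibility)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)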
# ============ LOAD DATA ============ #
df = pd.read_csv(CSV_PATH)
labels = sorted(df['label'].unique())
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
df['label_id'] = df['label'].map(label2id)
# ============ FEATURE EXTRACTOR & MODEL ============ #
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True  # reinitialize the classifier head if this label set differs from the checkpoint's
)
# ============ IMAGE TRANSFORM ============ #
def preprocess(example):
    # Load the image, run it through the ViT feature extractor, and attach the integer label
    image_path = os.path.join(IMAGE_DIR, example['image'])
    image = Image.open(image_path).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    example['pixel_values'] = inputs['pixel_values'][0]
    example['label'] = example['label_id']
    return example
# Convert to HF dataset
dataset = Dataset.from_pandas(df)
# 'label' is listed in remove_columns but re-added by preprocess() as the integer class id
dataset = dataset.map(preprocess, remove_columns=['image', 'label', 'label_id'])
dataset = dataset.train_test_split(test_size=0.2, seed=SEED)
train_ds = dataset['train']
eval_ds = dataset['test']
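# Note: the split above is random, not stratified. For strongly imbalanced classes a
# stratified split is an option (sketch; requires casting 'label' to a ClassLabel first):
# dataset = dataset.class_encode_column('label')
# dataset = dataset.train_test_split(test_size=0.2, seed=SEED, stratify_by_column='label')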
# ============ METRICS ============ #
from evaluate import load
accuracy = load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)
# ============ TRAINING ARGS ============ #
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    seed=SEED,
    report_to="none"
)
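# Optional: on a CUDA GPU, fp16=True can be added to TrainingArguments above for
# mixed-precision training (faster, lower memory); omitted here as hardware-dependent.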
# ============ TRAINER ============ #
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=feature_extractor,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
# ============ TRAIN ============ #
trainer.train()
# ============ SAVE MODEL ============ #
model.save_pretrained(OUTPUT_DIR)
feature_extractor.save_pretrained(OUTPUT_DIR)
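# ============ OPTIONAL: QUICK INFERENCE CHECK ============ #
# Minimal sketch of reloading the saved model for single-image prediction
# ("path/to/leaf.jpg" is a placeholder):
# from transformers import pipeline
# classifier = pipeline("image-classification", model=OUTPUT_DIR)
# print(classifier("path/to/leaf.jpg"))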
# ============ EVALUATE ============ #
outputs = trainer.predict(eval_ds)
preds = np.argmax(outputs.predictions, axis=-1)
true_labels = outputs.label_ids
print("\nClassification Report:\n")
print(classification_report(true_labels, preds, target_names=labels))
# ============ CONFUSION MATRIX ============ #
cm = confusion_matrix(true_labels, preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.show()
# ============ OPTIONAL: UPLOAD TO HF HUB ============ #
# api = HfApi()
# api.upload_folder(
#     folder_path=OUTPUT_DIR,
#     repo_id="your-username/crop_leaf_disease_vit_finetuned",
#     repo_type="model"
# )