File size: 2,358 Bytes
f0c1a1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
'''
    This file is used for inference on new test samples.
    One sample or multiple samples can be fed to the predict() method
'''
import os
import sys
from typing import List
import numpy as np
from gcg import config
from gcg.utils import logging, CustomException, load_from_checkpoint, load_object
from gcg.components import build_model, preprocess_image, grad_cam_plus, show_GradCAM

def predict(img_paths: List):
    '''
    Run inference on one or more images and save a GradCAM++ heatmap for each.

    Inputs:
        img_paths: List of image paths. A single path passed as a plain
                   string (or os.PathLike) is also accepted and treated as a
                   one-element list.
    Output:
        List of predicted class labels, decoded through the saved label
        encoder, in the same order as img_paths. As a side effect, one heatmap
        per input is written to config.heatmaps_save_path as
        "heatmap_<file name>".
    Raises:
        CustomException wrapping any error raised while building/loading the
        model, preprocessing, running inference or generating heatmaps.
    '''
    # Accept a single path as well as a list: iterating a bare string would
    # otherwise feed the pipeline one character at a time.
    if isinstance(img_paths, (str, os.PathLike)):
        img_paths = [img_paths]

    try:
        predictions_list = []
        if not img_paths:
            # Nothing to score: skip the model build and checkpoint load.
            return predictions_list

        # Step 1: Building the model
        model = build_model(input_shape=config.image_size, num_classes=config.num_classes)

        # Step 2: Loading the model from checkpoint
        logging.info("Loading the model from checkpoint...")
        model = load_from_checkpoint(model, config.model_path)

        # Step 3: Loading the label encoder to decode indices
        le = load_object(config.labelencoder_save_path)

        # The output directory only needs to exist once, not be re-checked per image.
        os.makedirs(config.heatmaps_save_path, exist_ok=True)

        for img_path in img_paths:
            img_path = os.fspath(img_path)
            # os.path.basename honours the platform's separator (e.g. '\\' on
            # Windows), unlike splitting on '/'.
            img_name = os.path.basename(img_path)

            # Step 4: Read and preprocess the image, then add a leading batch axis of size 1
            resized_img = preprocess_image(img_path, config.image_size)
            img_array = np.expand_dims(resized_img, axis=0)

            # Step 5: Inference on the model
            logging.info("Getting your prediction...")
            pred = model.predict(img_array)
            predicted_class = np.argmax(pred, axis=1)[0]  # index of the highest-scoring class for the single sample
            label = le.classes_[predicted_class]

            logging.info(f"Prediction: {label}, Path: {img_name}")
            predictions_list.append(label)

            # Step 6: Generating heatmap for the image
            logging.info("Generating the heatmap using GradCAM++")
            heatmap_plus = grad_cam_plus(model, resized_img, config.gcg_layer_name, label_name=config.labels, category_id=predicted_class)

            heatmap_img = os.path.join(config.heatmaps_save_path, f'heatmap_{img_name}')
            show_GradCAM(resized_img, heatmap_plus, save_path=heatmap_img)

        return predictions_list

    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise CustomException(e, sys) from e
    
if __name__ == '__main__':
    # Script entry point: score the test images listed in gcg.config and
    # print the decoded labels.
    results = predict(config.test_images)
    print(results)