|
"""classify.py |
|
|
|
This module classifies football club logos in input images using a pre-trained Keras model.
|
|
|
""" |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing import image |
|
|
|
|
|
import os |
|
|
|
# Directory containing this script. The model path below is resolved relative
# to it, so loading does not depend on the current working directory.
script_directory = os.path.split(os.path.abspath(__file__))[0]

# Trained Keras model file, located two directories up under ``models/``.
model_path = os.path.join(
    script_directory,
    '../../models/football_logo_model.h5',
)

# Loaded once at import time; classify_logo() reuses this module-level instance.
model = tf.keras.models.load_model(model_path)
|
|
|
def preprocess_image(img) -> np.ndarray:
    """
    Preprocess the input image for model prediction.

    Converts the image to an array, adds a leading batch axis, and divides
    pixel values by 255 (mapping the 8-bit range [0, 255] onto [0, 1]).
    The caller's input is never modified.

    Args:
        img: Input image (a PIL image or an array accepted by
            ``image.img_to_array``).

    Returns:
        img: Preprocessed image array with a leading batch dimension of 1.
    """
    arr = image.img_to_array(img)
    batch = np.expand_dims(arr, axis=0)
    # Divide out-of-place. img_to_array can hand back the caller's own array
    # (np.asarray does not copy when the dtype already matches), and
    # expand_dims returns a view, so an in-place ``/=`` would silently
    # rescale the caller's data.
    return batch / 255.0
|
|
|
|
|
# Human-readable labels for the model's outputs: index i of a prediction
# vector maps to class_names[i]. NOTE(review): presumably this order must match
# the label order used when the model was trained; confirm.
class_names = [
    'Arsenal',
    'Chelsea',
    'Liverpool',
    'Manchester City',
    'Manchester United',
]
|
|
|
def classify_logo(img):
    """
    Classify the football logo in the input image.

    Loads the image from ``img``, resizes it to 224x224, preprocesses it with
    ``preprocess_image``, runs the module-level ``model`` on it and returns
    the label with the highest score.

    Args:
        img: Path to the input image file.

    Returns:
        str: The predicted class of the football logo, one of ``class_names``.
    """
    # Despite its name, the parameter is a file path; the name is kept so
    # existing keyword callers (classify_logo(img=...)) keep working.
    pil_image = image.load_img(img, target_size=(224, 224))
    # preprocess_image() performs the array conversion itself, so no
    # separate img_to_array call is needed here.
    batch = preprocess_image(pil_image)
    prediction = model.predict(batch)
    # The batch holds a single image: take the argmax over the class axis
    # and keep row 0.
    predicted_class_index = int(prediction.argmax(axis=1)[0])
    return class_names[predicted_class_index]
|
|