from inference import UniversalImageClassifier | |
import json | |
def load_model(
    model_path="pytorch_model.pth",
    config_path="config.json",
    class_names_path="class_names.json",
):
    """Build a ``UniversalImageClassifier`` from files on disk.

    The defaults are the file names this function originally hard-coded, so
    calling ``load_model()`` with no arguments behaves as before.

    Args:
        model_path: Model weights path, passed unchanged to
            ``UniversalImageClassifier(model_path=...)``.
        config_path: Config path, passed unchanged to
            ``UniversalImageClassifier(config_path=...)``.
        class_names_path: Path to a JSON file whose decoded contents are
            passed as ``class_names``. It is opened with ``open()``, so a
            relative path is resolved against the current working directory.

    Returns:
        A newly constructed ``UniversalImageClassifier``.

    Raises:
        OSError: If the class-names file cannot be opened.
        ValueError: If the class-names file is not valid UTF-8 JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    # JSON text is UTF-8 by specification (RFC 8259). Be explicit rather than
    # relying on the platform's locale-dependent default encoding.
    with open(class_names_path, "r", encoding="utf-8") as f:
        # NOTE(review): presumably a list of labels ordered by class index;
        # confirm against what UniversalImageClassifier expects.
        class_names = json.load(f)
    return UniversalImageClassifier(
        model_path=model_path,
        config_path=config_path,
        class_names=class_names,
    )
# Classifier shared by all predict() calls. It is built lazily on the first
# call so that importing this module stays cheap.
_classifier = None


def predict(image_path):
    """Predict the class of the image at ``image_path``.

    The classifier is built once, on the first call, via ``load_model()`` and
    reused afterwards. Rebuilding it per call would re-read the class-names
    file and reconstruct the model every time.

    Because of this cache, changes to the model/config/class-names files made
    after the first call are not picked up until the process restarts.

    Args:
        image_path: Image location, passed unchanged to the classifier's
            ``predict`` method.

    Returns:
        Whatever ``UniversalImageClassifier.predict`` returns. The format is
        not visible from this module; check that class.
    """
    global _classifier
    # Not synchronized: concurrent first calls may each build a classifier,
    # but the last assignment simply wins and later calls reuse it.
    if _classifier is None:
        _classifier = load_model()
    return _classifier.predict(image_path)