# Provenance (copied from the hosting page): uploaded by ranjeetjha,
# commit ab80e91 "Upload 19 files (#1)" (verified).
import functools
import json

from inference import UniversalImageClassifier
def load_model(model_dir=None):
    """Load the model for inference.

    Reads the class names from ``class_names.json`` and builds a
    ``UniversalImageClassifier`` from ``pytorch_model.pth`` and
    ``config.json``.

    Args:
        model_dir: Optional directory holding the three model files. When
            ``None`` (the default), the bare file names are used, so they are
            resolved against the current working directory.

    Returns:
        A newly constructed ``UniversalImageClassifier``.

    Raises:
        FileNotFoundError: If ``class_names.json`` does not exist.
        json.JSONDecodeError: If ``class_names.json`` is not valid JSON.
        Any exception raised by ``UniversalImageClassifier`` while loading the
        weights/config propagates unchanged.
    """
    import os

    def _resolve(filename):
        # Keep the exact original relative name when no directory is given,
        # so existing callers see identical paths.
        return filename if model_dir is None else os.path.join(model_dir, filename)

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
    # (class names may contain non-ASCII characters).
    with open(_resolve("class_names.json"), "r", encoding="utf-8") as f:
        # NOTE(review): presumably the label list indexed by class id — confirm
        # against UniversalImageClassifier.
        class_names = json.load(f)

    classifier = UniversalImageClassifier(
        model_path=_resolve("pytorch_model.pth"),
        config_path=_resolve("config.json"),
        class_names=class_names,
    )
    return classifier
@functools.lru_cache(maxsize=1)
def _cached_classifier():
    """Return the classifier from ``load_model()``, built once per process."""
    # Loading reads weights/config from disk; do it once instead of on every
    # prediction. A failed load raises and is not cached, so the next call
    # retries.
    return load_model()


def predict(image_path):
    """Predict the class of the image at ``image_path``.

    The classifier is loaded on the first call and reused afterwards. Later
    changes to the model files on disk are therefore not picked up until
    ``_cached_classifier.cache_clear()`` is called.

    Args:
        image_path: Path to the image, passed straight to the classifier's
            ``predict`` method.

    Returns:
        Whatever ``UniversalImageClassifier.predict`` returns.
        NOTE(review): exact shape (label, scores, ...) not visible here — confirm.
    """
    return _cached_classifier().predict(image_path)