|
import gradio as gr |
|
import torch |
|
from huggingface_hub import from_pretrained_fastai |
|
from pathlib import Path |
|
|
|
# Example images offered under the input widget; the paths are relative to the
# current working directory ("./examples/image_1.png" ... "image_5.png").
examples = [f"./examples/image_{n}.png" for n in range(1, 6)]

# Hugging Face Hub repository from which the fastai learner is loaded.
repo_id = "hugginglearners/rice_image_classification"

# Base directory that get_x joins file names onto.
path = Path("./")
|
|
|
def get_y(r):
    """Return the value stored under the ``"label"`` key of record *r*.

    NOTE(review): this looks like a fastai DataBlock label getter that the
    exported learner presumably references by name; *r* is presumably a
    DataFrame row — confirm before renaming or removing.
    """
    label = r["label"]
    return label
|
|
|
def get_x(r):
    """Return the module-level ``path`` joined with the ``"fname"`` entry of *r*.

    NOTE(review): this looks like a fastai DataBlock input getter that the
    exported learner presumably references by name; *r* is presumably a
    DataFrame row — confirm before renaming or removing.
    """
    file_name = r["fname"]
    return path / file_name
|
|
|
# Load the fastai learner published at `repo_id` on the Hugging Face Hub.
# NOTE(review): the pickled learner presumably refers to get_x/get_y, so they
# likely must already be defined at module level when this runs — confirm.
learner = from_pretrained_fastai(repo_id=repo_id)
|
|
|
def inference(image):
    """Classify a rice image and describe the prediction.

    Args:
        image: The value supplied by the Gradio ``"image"`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A sentence naming the predicted rice type and the model's confidence
        as a percentage with two decimals, or a prompt asking for an image
        when none was provided.
    """
    # Gradio hands the function None for an empty image input; answer with a
    # prompt instead of passing None on to learner.predict.
    if image is None:
        return "Please upload a rice image."
    label_predict, _, probs = learner.predict(image)
    # Highest class probability; same value as probs[torch.argmax(probs)].
    confidence = probs.max().item()
    return f"This rice image is {label_predict} with {100 * confidence:.2f}% probability"
|
|
|
# Build the web UI around `inference` and start the server.
# Queuing is enabled through `.queue()`: the `enable_queue` argument of
# `launch()` was deprecated in Gradio 3.x and removed in Gradio 4.
demo = gr.Interface(
    fn=inference,
    title="Rice image classification",
    description=(
        "Predict which type of rice an image shows: "
        "Arborio, Basmati, Ipsala, Jasmine, or Karacadag"
    ),
    inputs="image",
    examples=examples,
    outputs=gr.Textbox(label="Prediction"),
    cache_examples=False,
    article='Author: <a href="https://www.linkedin.com/in/vumichien/">Vu Minh Chien</a>',
)
demo.queue().launch(debug=True)