# anime-image-labeller / anime_image_label_inference.py
# Author: Curt Tigges ("added main files", commit 65a728c, 1.82 kB)
# -*- coding: utf-8 -*-
"""Anime Image Label Inference.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1BxPfM2uV54LeiQGEk43xcyNwMgqwCR80
"""
!pip install -Uqq gradio
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
import gradio as gr
from fastbook import *
path = Path('gdrive/MyDrive/anime-image-labeller/safebooru')
"""
Get the prediction labels and their accuracies, then return the results as a dictionary.
[obj] - tensor matrix containing the predicted accuracy given from the model
[learn] - fastai learner needed to get the labels
[thresh] - minimum accuracy threshold to returning results
"""
def get_pred_classes(obj, learn, thresh):
labels = []
# get list of classes from csv--replace
with open(path/'classes.txt', 'r') as f:
for line in f:
labels.append(line.strip('\n'))
predictions = {}
x=0
for item in obj:
acc= round(item.item(), 3)
if acc > thresh:
predictions[labels[x]] = round(acc, 3)
x+=1
predictions =sorted(predictions.items(), key=lambda x: x[1], reverse=True)
return predictions
def get_x(r): return path/'images'/r['img_name']
def get_y(r): return [t for t in r['tags'].split(' ') if t in pop_tags]
learn = load_learner(path/'model-large-40e.pkl')
def predict_single_img(imf, thresh=0.2, learn=learn):
img = PILImage.create(imf)
#img.show() #show image
_, _, pred_pct = learn.predict(img) #predict while ignoring first 2 array inputs
img.show() #show image
return str(get_pred_classes(pred_pct, learn, thresh))
predict_single_img(path/'test/mask.jpeg')
iface = gr.Interface(fn=predict_single_img,
inputs=["image","number"],
outputs="text")
iface.launch()