change gradio interface to use files instead of images. expand valid list to sync with updated model.
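For orientation, here is a minimal sketch of the file-based wiring this commit switches to, assuming Gradio 4's default behavior where gr.File(file_count="multiple") hands the callback a list of file paths (unlike gr.Image, which hands it a single decoded image). The function name identify_files, the placeholder box, and the labels are illustrative only, not taken from the repo:

import gradio as gr
from PIL import Image

def identify_files(paths, confidence):
    # paths: list of file path strings delivered by gr.File(file_count="multiple")
    # confidence: value picked in the gr.Radio below (None or 0.23)
    first = Image.open(paths[0])                      # preview only the first upload
    boxes = [((10, 10, 100, 100), "example label")]   # (x1, y1, x2, y2) box + label for AnnotatedImage
    summary = "\n".join(f"{p}: example label" for p in paths)
    return (first, boxes), summary

demo = gr.Interface(
    fn=identify_files,
    inputs=[gr.File(file_count="multiple", file_types=["image"]),
            gr.Radio(choices=[("Best Guess", None), ("Reduce False Positives", 0.23)])],
    outputs=[gr.AnnotatedImage(), gr.Textbox()],
)
# demo.launch()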
app.py CHANGED
@@ -5,25 +5,41 @@ import numpy as np
 import requests
 import torch
 import heapq
+import argparse
+
 from PIL import Image

 from huggingface_hub import snapshot_download
 from huggingface_hub import hf_hub_download

-
+parser = argparse.ArgumentParser()
+parser.add_argument('--show_all',action = 'store_true')
+parser.add_argument('--local',action = 'store')
+
+args = parser.parse_args()
+show_all = args.show_all
+local = args.local
+
 def readLines(filename):
     with open(filename,'r') as f:
         return(f.read().splitlines())

+if local:
+    if len(local) == 0:
+        local = "jaguaridfull"
+    print(f"Using model in {local}")
+    model = tf.saved_model.load(local)
+    labels = readLines('master_labels.txt')
+else:
+    model_path = snapshot_download(repo_id="eshieh2/jaguarid_pantanal")
+    model = tf.saved_model.load(f"{model_path}/saved_model")
+    labels = readLines('labels.txt')
+serving = model.signatures['serving_default']
+
-labels = readLines('labels.txt')
 labels = [x for x in labels if x] #remove blanks
 label_count = len(labels)
 valid = readLines('valid.txt')

-model_path = snapshot_download(repo_id="eshieh2/jaguarid_pantanal")
-model = tf.saved_model.load(f"{model_path}/saved_model")
-serving = model.signatures['serving_default']

 detector_path = hf_hub_download(repo_id= "eshieh2/jaguarhead",
                                 filename = "jaguarheadv5.pt")
@@ -31,12 +47,48 @@ detector = torch.hub.load('ultralytics/yolov5', 'custom', path = detector_path)

 topk = 3

-def classify_image(in_image):
+from PIL import ExifTags
+
+def open_image_correct_orientation(image_path):
+    img = Image.open(image_path)
+
+    # Try to get EXIF orientation data
+    try:
+        for tag in ExifTags.TAGS:
+            if ExifTags.TAGS[tag] == "Orientation":
+                orientation_tag = tag
+                break
+        exif = img._getexif()
+        if exif and orientation_tag in exif:
+            orientation = exif[orientation_tag]
+            new_image = None
+            if orientation == 3:
+                new_image = img.rotate(180, expand=True)
+            elif orientation == 6:
+                new_image = img.rotate(270, expand=True)
+            elif orientation == 8:
+                new_image = img.rotate(90, expand=True)
+            if new_image is not None:
+                # Strip EXIF data by saving without it
+                data = list(new_image.getdata())
+                image_without_exif = Image.new(new_image.mode, new_image.size)
+                image_without_exif.putdata(data)
+                img = image_without_exif
+
+    except (AttributeError, KeyError, IndexError):
+        pass # If EXIF data is missing, do nothing
+
+    return img
+
+def classify_image(in_image,confidence):
+    in_image = open_image_correct_orientation(in_image)
     if in_image is None:
         return None
     width,height = in_image.size
     heads = detector(in_image)
     masks = [] # tuple of box coords and string
+    retLabel = ''
+    index = 1
     for head in heads.xyxy[0]:
         x,y,x2,y2,pct,cls = head.numpy()
         w = x2 - x
@@ -49,11 +101,13 @@ def classify_image(in_image):
         prediction = serving(tf.convert_to_tensor(inp))['output_0']
         prediction = tf.squeeze(prediction)
         pred = {labels[i]: float(prediction[i]) for i in range(label_count)}
+        if confidence is not None:
+            pred = {key: value for key, value in pred.items() if value > confidence}
         #print(pred)
         rect = (int(x),int(y),int(x2),int(y2))
+        label = ''
         if topk is not None:
             top = heapq.nlargest(topk,pred,key=pred.get)
-            label = ''
             for t in top:
                 if show_all or t.lower() in valid:
                     if len(label) != 0:
@@ -65,14 +119,35 @@ def classify_image(in_image):
         else:
             max_key = max(pred, key=pred.get)
             if show_all or max_key.lower() in valid:
-
+                label = f"{max_key}:{pred[max_key]}"
             else:
-
-
+                label = f"unknown"
+        masks.append((rect,label))
+        if index > 1:
+            retLabel += ", "
+        retLabel += f"[{index} {label}]"
+        index += 1
+
+    return (in_image,masks),retLabel
+
+def classify_multiple(files,confidence):
+    textResult = ''
+    firstResult = None
+    for file in files:
+        result = classify_image(file,confidence)
+        if firstResult is None:
+            firstResult = result[0]
+        textResult = textResult + f"{os.path.basename(file)}: {result[1]}\n"
+
+    return firstResult, textResult
+#image = gr.Image(type='pil')
+image = gr.File(label = "Input Images", file_count = 'multiple', file_types = ['image'])
+conf = gr.Radio(choices = [("Best Guess", None), ("Reduce False Positives",0.23)], label = "Confidence")

-
-
+output = gr.AnnotatedImage(label = "First Result")
+output2 = gr.Textbox(label = "File Results")
 title = "JaguarID AI identification App"
-
+valid_display = [word.capitalize() for word in valid]
+desc = f"Identifies the following: {', '.join(valid_display)} . Confidence displayed after the name is on a 0.0 - 1.0 scale. Must have a clear front/side head view."

-gr.Interface(fn=
+gr.Interface(fn=classify_multiple, inputs=[image,conf], outputs=[output,output2], examples = [[["medrosa.jpg"]],[["guaraci.jpg"]],[["marcela.jpg"]]], title = title, description = desc).launch()
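A note on the EXIF handling added above: recent Pillow releases ship an equivalent helper, so a one-line alternative (not what this commit uses, and untested against the app) would be ImageOps.exif_transpose, which also covers the mirrored orientations (2, 4, 5, 7) that the manual branches skip:

from PIL import Image, ImageOps

def open_image_upright(image_path):
    # Applies the EXIF Orientation tag (rotations and flips) and returns a copy
    # with the tag cleared -- roughly what open_image_correct_orientation does by hand.
    return ImageOps.exif_transpose(Image.open(image_path))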
valid.txt CHANGED
@@ -1,11 +1,18 @@
 bagua
+bernard
+bororo
 guaraci
+ibaca
+jaju
 kasimir
+krishna
+kyyavera
 manath
-margo
 marcela
+margo
 medrosa
 ousado
 patricia
-
+rio
 ti
+xingu