Spaces:
Runtime error
Runtime error
first commit
Browse files- app.py +85 -0
- model_healthy_bot.pth +3 -0
- tag2class_healthy_bot.json +9 -0
app.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List
|
2 |
+
import asyncio
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
|
7 |
+
from albumentations import Compose, LongestMaxSize, Normalize, PadIfNeeded
|
8 |
+
from albumentations.pytorch import ToTensorV2
|
9 |
+
import cv2
|
10 |
+
import streamlit as st
|
11 |
+
import torch
|
12 |
+
import PIL
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
class ClassifyModel:
|
16 |
+
def __init__(self):
|
17 |
+
self.model = None
|
18 |
+
self.class2tag = None
|
19 |
+
self.tag2class = None
|
20 |
+
self.transform = None
|
21 |
+
|
22 |
+
def load(self, path="/model"):
|
23 |
+
image_size = 512
|
24 |
+
self.transform = Compose(
|
25 |
+
[
|
26 |
+
LongestMaxSize(max_size=image_size),
|
27 |
+
PadIfNeeded(image_size, image_size, border_mode=cv2.BORDER_CONSTANT),
|
28 |
+
Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), always_apply=True),
|
29 |
+
ToTensorV2()
|
30 |
+
]
|
31 |
+
)
|
32 |
+
self.model = torch.jit.load("model_healthy_bot.pth")
|
33 |
+
with open("tag2class_healthy_bot.json") as fin:
|
34 |
+
self.tag2class = json.load(fin)
|
35 |
+
self.class2tag = {v: k for k, v in self.tag2class.items()}
|
36 |
+
logging.debug(f"class2tag: {self.class2tag}")
|
37 |
+
|
38 |
+
def predict(self, *imgs) -> List[str]:
|
39 |
+
logging.debug(f"batch size: {len(imgs)}")
|
40 |
+
input_ts = [self.transform(image=img)["image"] for img in imgs]
|
41 |
+
input_t = torch.stack(input_ts)
|
42 |
+
logging.debug(f"input_t: {input_t.shape}")
|
43 |
+
output_ts = self.model(input_t)
|
44 |
+
activation_fn = torch.nn.__dict__['Sigmoid']()
|
45 |
+
output_ts = activation_fn(output_ts)
|
46 |
+
labels = list(self.tag2class.keys())
|
47 |
+
logging.debug(f"output_ts: {output_ts.shape}")
|
48 |
+
#logging.debug(f"output_pb: {output_pb}")
|
49 |
+
res = []
|
50 |
+
trh = 0.5
|
51 |
+
for output_t in output_ts:
|
52 |
+
logit = (output_t > trh).long()
|
53 |
+
if logit[0] and any([*logit[1:3], *logit[4:]]):
|
54 |
+
output_t[0] = 0
|
55 |
+
indices = (output_t > trh).nonzero(as_tuple=True)[0]
|
56 |
+
prob = output_t[indices].tolist()
|
57 |
+
tag = [labels[i] for i in indices.tolist()]
|
58 |
+
res_dict = dict(zip(
|
59 |
+
list(self.tag2class.keys()),list(output_t.numpy())
|
60 |
+
))
|
61 |
+
logging.debug(f"all results: {res_dict}")
|
62 |
+
logging.debug(f"prob: {prob}")
|
63 |
+
logging.debug(f"result: {tag}")
|
64 |
+
res.append((tag,prob,res_dict))
|
65 |
+
return res
|
66 |
+
|
67 |
+
m = ClassifyModel()
|
68 |
+
m.load()
|
69 |
+
|
70 |
+
st.sidebar.title("About")
|
71 |
+
|
72 |
+
st.sidebar.info(
|
73 |
+
"This application identifies the crop health in the picture.")
|
74 |
+
|
75 |
+
|
76 |
+
st.title('Wheat Rust Identification')
|
77 |
+
model = load_tf_model(model_path)
|
78 |
+
st.write("Upload an image.")
|
79 |
+
uploaded_file = st.file_uploader("")
|
80 |
+
|
81 |
+
if uploaded_file is not None:
|
82 |
+
image = Image.open(uploaded_file)
|
83 |
+
img = np.array(image)
|
84 |
+
result = m.predict(img)
|
85 |
+
st.write(f"I think this has {result[0][0]}(confidence: {round(result[0][1],2)})")
|
model_healthy_bot.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d705f89edd648d4f6f55917cac6205d1eb8851b0d1afe872d5c3386df679a871
|
3 |
+
size 21884854
|
tag2class_healthy_bot.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"healthy": 0,
|
3 |
+
"leaf rust": 1,
|
4 |
+
"powdery mildew": 2,
|
5 |
+
"seedlings": 3,
|
6 |
+
"septoria": 4,
|
7 |
+
"stem rust": 5,
|
8 |
+
"yellow rust": 6
|
9 |
+
}
|