Senayfre commited on
Commit
80549b2
·
1 Parent(s): e7354c2

first commit

Browse files
Files changed (3) hide show
  1. app.py +85 -0
  2. model_healthy_bot.pth +3 -0
  3. tag2class_healthy_bot.json +9 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+ import asyncio
3
+ import json
4
+ import logging
5
+ import os
6
+
7
+ from albumentations import Compose, LongestMaxSize, Normalize, PadIfNeeded
8
+ from albumentations.pytorch import ToTensorV2
9
+ import cv2
10
+ import streamlit as st
11
+ import torch
12
+ import PIL
13
+ import numpy as np
14
+
15
+ class ClassifyModel:
16
+ def __init__(self):
17
+ self.model = None
18
+ self.class2tag = None
19
+ self.tag2class = None
20
+ self.transform = None
21
+
22
+ def load(self, path="/model"):
23
+ image_size = 512
24
+ self.transform = Compose(
25
+ [
26
+ LongestMaxSize(max_size=image_size),
27
+ PadIfNeeded(image_size, image_size, border_mode=cv2.BORDER_CONSTANT),
28
+ Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), always_apply=True),
29
+ ToTensorV2()
30
+ ]
31
+ )
32
+ self.model = torch.jit.load("model_healthy_bot.pth")
33
+ with open("tag2class_healthy_bot.json") as fin:
34
+ self.tag2class = json.load(fin)
35
+ self.class2tag = {v: k for k, v in self.tag2class.items()}
36
+ logging.debug(f"class2tag: {self.class2tag}")
37
+
38
+ def predict(self, *imgs) -> List[str]:
39
+ logging.debug(f"batch size: {len(imgs)}")
40
+ input_ts = [self.transform(image=img)["image"] for img in imgs]
41
+ input_t = torch.stack(input_ts)
42
+ logging.debug(f"input_t: {input_t.shape}")
43
+ output_ts = self.model(input_t)
44
+ activation_fn = torch.nn.__dict__['Sigmoid']()
45
+ output_ts = activation_fn(output_ts)
46
+ labels = list(self.tag2class.keys())
47
+ logging.debug(f"output_ts: {output_ts.shape}")
48
+ #logging.debug(f"output_pb: {output_pb}")
49
+ res = []
50
+ trh = 0.5
51
+ for output_t in output_ts:
52
+ logit = (output_t > trh).long()
53
+ if logit[0] and any([*logit[1:3], *logit[4:]]):
54
+ output_t[0] = 0
55
+ indices = (output_t > trh).nonzero(as_tuple=True)[0]
56
+ prob = output_t[indices].tolist()
57
+ tag = [labels[i] for i in indices.tolist()]
58
+ res_dict = dict(zip(
59
+ list(self.tag2class.keys()),list(output_t.numpy())
60
+ ))
61
+ logging.debug(f"all results: {res_dict}")
62
+ logging.debug(f"prob: {prob}")
63
+ logging.debug(f"result: {tag}")
64
+ res.append((tag,prob,res_dict))
65
+ return res
66
+
67
+ m = ClassifyModel()
68
+ m.load()
69
+
70
+ st.sidebar.title("About")
71
+
72
+ st.sidebar.info(
73
+ "This application identifies the crop health in the picture.")
74
+
75
+
76
+ st.title('Wheat Rust Identification')
77
+ model = load_tf_model(model_path)
78
+ st.write("Upload an image.")
79
+ uploaded_file = st.file_uploader("")
80
+
81
+ if uploaded_file is not None:
82
+ image = Image.open(uploaded_file)
83
+ img = np.array(image)
84
+ result = m.predict(img)
85
+ st.write(f"I think this has {result[0][0]}(confidence: {round(result[0][1],2)})")
model_healthy_bot.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d705f89edd648d4f6f55917cac6205d1eb8851b0d1afe872d5c3386df679a871
3
+ size 21884854
tag2class_healthy_bot.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "healthy": 0,
3
+ "leaf rust": 1,
4
+ "powdery mildew": 2,
5
+ "seedlings": 3,
6
+ "septoria": 4,
7
+ "stem rust": 5,
8
+ "yellow rust": 6
9
+ }