Harry2687 commited on
Commit
250ec14
·
1 Parent(s): 678f8d9

reorganised prediction code

Browse files
app.py CHANGED
@@ -1,54 +1,11 @@
1
  from shiny import App, reactive, render, ui
2
  from shiny.types import ImgData
3
 
4
- from PIL import Image
5
-
6
- import torch
7
- import torchvision.transforms as transforms
8
- import modules.model as model
9
-
10
- def forward_prop(image_path):
11
- imsize = 128
12
- classes = ('Female', 'Male')
13
- modelsave_name = 'model_parameters.pt'
14
-
15
- if torch.backends.mps.is_available():
16
- device = torch.device('mps')
17
- device_name = 'Apple Silicon GPU'
18
- elif torch.cuda.is_available():
19
- device = torch.device('cuda')
20
- device_name = 'CUDA'
21
- else:
22
- device = torch.device('cpu')
23
- device_name = 'CPU'
24
-
25
- torch.set_default_device(device)
26
-
27
- resnet = model.resnetModel_128()
28
- resnet.load_state_dict(torch.load(modelsave_name, map_location=device))
29
- resnet.eval()
30
-
31
- loader = transforms.Compose([
32
- transforms.Resize([imsize, imsize]),
33
- transforms.Grayscale(1),
34
- transforms.ToTensor(),
35
- transforms.Normalize(0, 1)
36
- ])
37
-
38
- image = Image.open(image_path).convert('RGB')
39
- image_tensor = loader(image)
40
- image_tensor = image_tensor.unsqueeze(0)
41
-
42
- X = image_tensor.to(device)
43
- y_pred = resnet.forward(X)
44
- predicted = torch.max(y_pred.data,1)[1]
45
-
46
- return f'Prediction: {classes[predicted]} with weight {y_pred[0][predicted].item()}. Predicted using {device_name}.'
47
-
48
 
49
  app_ui = ui.page_fluid(
50
- ui.panel_title('Image Uploader'),
51
- ui.input_file('image', 'Image', accept=['.png', '.jpg', '.jpeg']),
52
  ui.output_image('show_image'),
53
  ui.input_action_button('predict_gender', 'Predict'),
54
  ui.output_text('prediction')
@@ -61,7 +18,7 @@ def server(input, output, session):
61
  return None
62
 
63
  image_path = input.image()[0]['datapath']
64
- img: ImgData = {'src': image_path}
65
  return img
66
 
67
  @render.text
@@ -71,8 +28,11 @@ def server(input, output, session):
71
  return None
72
 
73
  image_path = input.image()[0]['datapath']
74
- prediction = forward_prop(image_path)
 
 
 
75
 
76
- return prediction
77
 
78
  app = App(app_ui, server)
 
1
  from shiny import App, reactive, render, ui
2
  from shiny.types import ImgData
3
 
4
+ from gender_cnn.predict import predict_gender
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  app_ui = ui.page_fluid(
7
+ ui.panel_title('Gender Classifier'),
8
+ ui.input_file('image', 'Upload image', accept=['.png', '.jpg', '.jpeg']),
9
  ui.output_image('show_image'),
10
  ui.input_action_button('predict_gender', 'Predict'),
11
  ui.output_text('prediction')
 
18
  return None
19
 
20
  image_path = input.image()[0]['datapath']
21
+ img: ImgData = {'src': image_path, 'height': '300px', 'width': '300px'}
22
  return img
23
 
24
  @render.text
 
28
  return None
29
 
30
  image_path = input.image()[0]['datapath']
31
+ output = predict_gender(image_path)
32
+ prediction = output['prediction']
33
+ weighting = output['weighting']
34
+ device = output['device']
35
 
36
+ return f'Prediction: {prediction}. Weighting: {str(round(weighting, 2))}. Device: {device}.'
37
 
38
  app = App(app_ui, server)
gender_cnn/__init__.py ADDED
File without changes
{modules → gender_cnn}/model.py RENAMED
File without changes
gender_cnn/predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ from .model import resnetModel_128
5
+
6
+ def predict_gender(image_path: str):
7
+ # Constants
8
+ imsize = 128
9
+ classes = ('Female', 'Male')
10
+ model_name = 'resnetModel_128_epoch_2.pt'
11
+
12
+ # Set Backend
13
+ if torch.backends.mps.is_available():
14
+ device = torch.device('mps')
15
+ device_name = 'Apple Silicon GPU'
16
+ elif torch.cuda.is_available():
17
+ device = torch.device('cuda')
18
+ device_name = 'CUDA'
19
+ else:
20
+ device = torch.device('cpu')
21
+ device_name = 'CPU'
22
+
23
+ # Init model
24
+ resnet = resnetModel_128().to(device)
25
+ resnet.load_state_dict(torch.load(model_name, map_location=device))
26
+ resnet.eval()
27
+
28
+ # Load and transform image
29
+ loader = transforms.Compose([
30
+ transforms.Resize([imsize, imsize]),
31
+ transforms.Grayscale(1),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(0, 1)
34
+ ])
35
+
36
+ image = Image.open(image_path).convert('RGB')
37
+ image_tensor = loader(image)
38
+ image_tensor = image_tensor.unsqueeze(0)
39
+
40
+ # Predict
41
+ X = image_tensor.to(device)
42
+ y_pred = resnet.forward(X)
43
+ pred_index = torch.max(y_pred.data,1)[1]
44
+ prediction = classes[pred_index]
45
+ weighting = y_pred[0][pred_index].item()
46
+
47
+ return {'prediction': prediction, 'weighting': weighting, 'device': device_name}