Anwarkh1 commited on
Commit
856919c
·
verified ·
1 Parent(s): 01c57ef

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +56 -23
main.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
3
 
4
- from fastapi import FastAPI, UploadFile, File
5
  from transformers import ViTForImageClassification, ViTFeatureExtractor
6
  import torch
7
  import torch.nn as nn
@@ -9,7 +9,7 @@ import torchvision.transforms as transforms
9
  from PIL import Image
10
  import io
11
 
12
- app = FastAPI()
13
 
14
  # Load the ViT model and its feature extractor
15
  model_name = "google/vit-base-patch16-224-in21k"
@@ -22,7 +22,6 @@ model.classifier = nn.Linear(model.config.hidden_size, num_classes)
22
  model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
23
  model.eval()
24
 
25
-
26
  # Define class labels
27
  class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
28
 
@@ -35,26 +34,60 @@ transform = transforms.Compose([
35
  transforms.ToTensor(),
36
  ])
37
 
38
- # Define API endpoint for model inference with class-specific thresholds
39
- @app.post('/predict')
40
- async def predict(file: UploadFile = File(...)):
41
- contents = await file.read()
42
- image = Image.open(io.BytesIO(contents))
43
- image = transform(image).unsqueeze(0) # Add batch dimension and move to device
44
-
45
  with torch.no_grad():
46
  outputs = model(image)
47
-
48
- # Calculate softmax probabilities
49
- probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
50
-
51
- # Get predicted class index and its probability
52
- predicted_idx = torch.argmax(torch.tensor(probabilities)).item()
53
- predicted_label = class_labels[predicted_idx]
54
- predicted_probability = probabilities[predicted_idx]
55
-
56
- # Check if the predicted probability meets the class-specific threshold
57
- if predicted_probability < thresholds[predicted_idx]:
58
- return {'predicted_class': 'uncertain', 'accuracy': float(predicted_probability)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  else:
60
- return {'predicted_class': predicted_label, 'accuracy': float(predicted_probability)}
 
 
 
 
1
  import os
2
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
3
 
4
+ from flask import Flask, request, render_template, jsonify
5
  from transformers import ViTForImageClassification, ViTFeatureExtractor
6
  import torch
7
  import torch.nn as nn
 
9
  from PIL import Image
10
  import io
11
 
12
+ app = Flask(__name__)
13
 
14
  # Load the ViT model and its feature extractor
15
  model_name = "google/vit-base-patch16-224-in21k"
 
22
  model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
23
  model.eval()
24
 
 
25
  # Define class labels
26
  class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
27
 
 
34
  transforms.ToTensor(),
35
  ])
36
 
37
+ @app.route('/')
38
+ def index():
39
+ return render_template('index.html', appName="Skin Cancer Classification Application")
40
+
41
+ def model_predict(image):
42
+ image = transform(image).unsqueeze(0) # Add batch dimension
 
43
  with torch.no_grad():
44
  outputs = model(image)
45
+ return outputs
46
+
47
+ @app.route('/predictApi', methods=["POST"])
48
+ def api():
49
+ try:
50
+ if 'fileup' not in request.files:
51
+ return jsonify({'Error': "Please try again. The Image doesn't exist"})
52
+ file = request.files.get('fileup')
53
+ image = Image.open(io.BytesIO(file.read()))
54
+ result = model_predict(image)
55
+
56
+ probabilities = torch.softmax(result.logits, dim=1).cpu().numpy()[0]
57
+ predicted_idx = torch.argmax(torch.tensor(probabilities)).item()
58
+ max_prob = probabilities[predicted_idx]
59
+ threshold = thresholds[predicted_idx]
60
+
61
+ if max_prob < threshold:
62
+ return jsonify({'Error': 'No cancer detected or benign lesion.'})
63
+ prediction = class_labels[predicted_idx]
64
+ return jsonify({'prediction': prediction})
65
+ except Exception as e:
66
+ return jsonify({'Error': 'An error occurred', 'Message': str(e)})
67
+
68
+ @app.route('/predict', methods=['GET', 'POST'])
69
+ def predict():
70
+ if request.method == 'POST':
71
+ try:
72
+ if 'fileup' not in request.files:
73
+ return render_template('index.html', prediction='No file selected.', appName="Skin Cancer Classification Application")
74
+ file = request.files['fileup']
75
+ image = Image.open(io.BytesIO(file.read()))
76
+ result = model_predict(image)
77
+
78
+ probabilities = torch.softmax(result.logits, dim=1).cpu().numpy()[0]
79
+ predicted_idx = torch.argmax(torch.tensor(probabilities)).item()
80
+ max_prob = probabilities[predicted_idx]
81
+ threshold = thresholds[predicted_idx]
82
+
83
+ if max_prob < threshold:
84
+ return render_template('index.html', prediction='No cancer detected or benign lesion.', appName="Skin Cancer Classification Application")
85
+ prediction = class_labels[predicted_idx]
86
+ return render_template('index.html', prediction=prediction, appName="Skin Cancer Classification Application")
87
+ except Exception as e:
88
+ return render_template('index.html', prediction='Error: ' + str(e), appName="Skin Cancer Classification Application")
89
  else:
90
+ return render_template('index.html', appName="Skin Cancer Classification Application")
91
+
92
+ if __name__ == '__main__':
93
+ app.run(debug=True)