rohithk-03 commited on
Commit
f5fe239
·
1 Parent(s): 95e893d

update model code

Browse files
Files changed (2) hide show
  1. app.py +23 -6
  2. model.py +1 -1
app.py CHANGED
@@ -12,25 +12,28 @@ import os
12
  import requests
13
  import requests
14
  import cloudinary
15
- import model
16
  import cloudinary.uploader
17
  from a import main
18
  import numpy as np
 
19
  # Initialize Flask app
20
  app = Flask(__name__)
21
 
22
  GDRIVE_MODEL_URL = "https://drive.google.com/uc?id=1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
23
  LOCAL_MODEL_PATH = "checkpoint32.pth"
24
-
25
- print(GDRIVE_MODEL_URL)
26
 
27
 
28
  def download_file_from_google_drive():
29
  gdown.download(GDRIVE_MODEL_URL, LOCAL_MODEL_PATH, quiet=False)
30
 
31
 
32
- file_id = "1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
33
- destination = "checkpoint32.pth"
 
 
 
34
 
35
 
36
  def download_model():
@@ -46,6 +49,7 @@ def download_model():
46
 
47
 
48
  download_file_from_google_drive()
 
49
 
50
 
51
  @app.route("/")
@@ -227,6 +231,10 @@ def predict():
227
  # Save file temporarily
228
  temp_path = os.path.join(tempfile.gettempdir(), file.filename)
229
  file.save(temp_path)
 
 
 
 
230
  if file.filename.lower().endswith((".png", ".jpg", ".jpeg")):
231
  image = Image.open(temp_path)
232
  image_save_path = os.path.join(
@@ -250,7 +258,16 @@ def predict():
250
  if (is_mri_image(temp_path)):
251
  return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
252
  a, b = model.check_file(temp_path)
253
- return jsonify({"message": a, "confidence": b, "saved_path": image_save_path})
 
 
 
 
 
 
 
 
 
254
 
255
 
256
  if __name__ == "__main__":
 
12
  import requests
13
  import requests
14
  import cloudinary
15
+ # import model
16
  import cloudinary.uploader
17
  from a import main
18
  import numpy as np
19
+ import torchvision.transforms as transforms
20
  # Initialize Flask app
21
  app = Flask(__name__)
22
 
23
  GDRIVE_MODEL_URL = "https://drive.google.com/uc?id=1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
24
  LOCAL_MODEL_PATH = "checkpoint32.pth"
25
+ d = "https://drive.google.com/uc?id=1GfrlFNoa7E4liMHyMuF73nA21yT9SNSb"
 
26
 
27
 
28
  def download_file_from_google_drive():
29
  gdown.download(GDRIVE_MODEL_URL, LOCAL_MODEL_PATH, quiet=False)
30
 
31
 
32
+ da = "a.pth"
33
+
34
+
35
+ def download_file_from_google_drived():
36
+ gdown.download(d, da, quiet=False)
37
 
38
 
39
  def download_model():
 
49
 
50
 
51
  download_file_from_google_drive()
52
+ download_file_from_google_drived()
53
 
54
 
55
  @app.route("/")
 
231
  # Save file temporarily
232
  temp_path = os.path.join(tempfile.gettempdir(), file.filename)
233
  file.save(temp_path)
234
+ transform = transforms.Compose([
235
+ transforms.Resize((224, 224)),
236
+ transforms.ToTensor(),
237
+ ])
238
  if file.filename.lower().endswith((".png", ".jpg", ".jpeg")):
239
  image = Image.open(temp_path)
240
  image_save_path = os.path.join(
 
258
  if (is_mri_image(temp_path)):
259
  return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
260
  a, b = model.check_file(temp_path)
261
+ image = Image.open(temp_path).convert("RGB")
262
+ output = model(transform(image).unsqueeze(0).to(device))
263
+ stage = output.item()
264
+ if stage <= 2.0:
265
+ stage = "Mild"
266
+ elif stage >= 2.0 and stage <= 3.2:
267
+ stage = "Moderate"
268
+ else:
269
+ stage = "Severe"
270
+ return jsonify({"message": a, "confidence": b, "stage": stage, "saved_path": image_save_path})
271
 
272
 
273
  if __name__ == "__main__":
model.py CHANGED
@@ -161,4 +161,4 @@ def check_file(image_path):
161
  model.to(device)
162
  model = nn.DataParallel(model)
163
  output, confidence = test_model(model, test_loader, device)
164
- return "control" if output.item() == 0 else "ms", confidence.item()
 
161
  model.to(device)
162
  model = nn.DataParallel(model)
163
  output, confidence = test_model(model, test_loader, device)
164
+ return "No ms detected" if output.item() == 0 else "MS Detected", confidence.item()