Spaces:
Sleeping
Sleeping
rohithk-03
commited on
Commit
·
f5fe239
1
Parent(s):
95e893d
update model code
Browse files
app.py
CHANGED
@@ -12,25 +12,28 @@ import os
|
|
12 |
import requests
|
13 |
import requests
|
14 |
import cloudinary
|
15 |
-
import model
|
16 |
import cloudinary.uploader
|
17 |
from a import main
|
18 |
import numpy as np
|
|
|
19 |
# Initialize Flask app
|
20 |
app = Flask(__name__)
|
21 |
|
22 |
GDRIVE_MODEL_URL = "https://drive.google.com/uc?id=1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
|
23 |
LOCAL_MODEL_PATH = "checkpoint32.pth"
|
24 |
-
|
25 |
-
print(GDRIVE_MODEL_URL)
|
26 |
|
27 |
|
28 |
def download_file_from_google_drive():
|
29 |
gdown.download(GDRIVE_MODEL_URL, LOCAL_MODEL_PATH, quiet=False)
|
30 |
|
31 |
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
def download_model():
|
@@ -46,6 +49,7 @@ def download_model():
|
|
46 |
|
47 |
|
48 |
download_file_from_google_drive()
|
|
|
49 |
|
50 |
|
51 |
@app.route("/")
|
@@ -227,6 +231,10 @@ def predict():
|
|
227 |
# Save file temporarily
|
228 |
temp_path = os.path.join(tempfile.gettempdir(), file.filename)
|
229 |
file.save(temp_path)
|
|
|
|
|
|
|
|
|
230 |
if file.filename.lower().endswith((".png", ".jpg", ".jpeg")):
|
231 |
image = Image.open(temp_path)
|
232 |
image_save_path = os.path.join(
|
@@ -250,7 +258,16 @@ def predict():
|
|
250 |
if (is_mri_image(temp_path)):
|
251 |
return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
|
252 |
a, b = model.check_file(temp_path)
|
253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
|
256 |
if __name__ == "__main__":
|
|
|
12 |
import requests
|
13 |
import requests
|
14 |
import cloudinary
|
15 |
+
# import model
|
16 |
import cloudinary.uploader
|
17 |
from a import main
|
18 |
import numpy as np
|
19 |
+
import torchvision.transforms as transforms
|
20 |
# Initialize Flask app
|
21 |
app = Flask(__name__)
|
22 |
|
23 |
GDRIVE_MODEL_URL = "https://drive.google.com/uc?id=1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
|
24 |
LOCAL_MODEL_PATH = "checkpoint32.pth"
|
25 |
+
d = "https://drive.google.com/uc?id=1GfrlFNoa7E4liMHyMuF73nA21yT9SNSb"
|
|
|
26 |
|
27 |
|
28 |
def download_file_from_google_drive():
|
29 |
gdown.download(GDRIVE_MODEL_URL, LOCAL_MODEL_PATH, quiet=False)
|
30 |
|
31 |
|
32 |
+
da = "a.pth"
|
33 |
+
|
34 |
+
|
35 |
+
def download_file_from_google_drived():
|
36 |
+
gdown.download(d, da, quiet=False)
|
37 |
|
38 |
|
39 |
def download_model():
|
|
|
49 |
|
50 |
|
51 |
download_file_from_google_drive()
|
52 |
+
download_file_from_google_drived()
|
53 |
|
54 |
|
55 |
@app.route("/")
|
|
|
231 |
# Save file temporarily
|
232 |
temp_path = os.path.join(tempfile.gettempdir(), file.filename)
|
233 |
file.save(temp_path)
|
234 |
+
transform = transforms.Compose([
|
235 |
+
transforms.Resize((224, 224)),
|
236 |
+
transforms.ToTensor(),
|
237 |
+
])
|
238 |
if file.filename.lower().endswith((".png", ".jpg", ".jpeg")):
|
239 |
image = Image.open(temp_path)
|
240 |
image_save_path = os.path.join(
|
|
|
258 |
if (is_mri_image(temp_path)):
|
259 |
return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
|
260 |
a, b = model.check_file(temp_path)
|
261 |
+
image = Image.open(temp_path).convert("RGB")
|
262 |
+
output = model(transform(image).unsqueeze(0).to(device))
|
263 |
+
stage = output.item()
|
264 |
+
if stage <= 2.0:
|
265 |
+
stage = "Mild"
|
266 |
+
elif stage >= 2.0 and stage <= 3.2:
|
267 |
+
stage = "Moderate"
|
268 |
+
else:
|
269 |
+
stage = "Severe"
|
270 |
+
return jsonify({"message": a, "confidence": b, "stage": stage, "saved_path": image_save_path})
|
271 |
|
272 |
|
273 |
if __name__ == "__main__":
|
model.py
CHANGED
@@ -161,4 +161,4 @@ def check_file(image_path):
|
|
161 |
model.to(device)
|
162 |
model = nn.DataParallel(model)
|
163 |
output, confidence = test_model(model, test_loader, device)
|
164 |
-
return "
|
|
|
161 |
model.to(device)
|
162 |
model = nn.DataParallel(model)
|
163 |
output, confidence = test_model(model, test_loader, device)
|
164 |
+
return "No ms detected" if output.item() == 0 else "MS Detected", confidence.item()
|