import base64
import io

import numpy as np
import tensorflow as tf
from PIL import Image

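class MRIClassifier:
    """Binary MRI classifier wrapping a saved Keras model."""

    def __init__(self, model_path="./model"):
        # Load the saved Keras model from disk
        self.model = tf.keras.models.load_model(model_path)
        # Input size the model expects; 224x224 is the EfficientNetB0 default
        # and must match the size used during training
        self.img_size = (224, 224)
        print(f"Model loaded with input size: {self.img_size}")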
        
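    def preprocess(self, image):
        # Convert to 3-channel RGB first: MRI scans are often grayscale ("L")
        # or may carry an alpha channel, and the model expects 3 channels
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize to the model's expected spatial size (width, height)
        image = image.resize(self.img_size)

        # Convert to a float32 array scaled to [0, 1]; this scaling must match
        # the preprocessing used when the model was trained
        image_array = np.asarray(image, dtype=np.float32) / 255.0

        # Add a batch dimension: (H, W, 3) -> (1, H, W, 3)
        return np.expand_dims(image_array, axis=0)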
        
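    def predict(self, image_bytes):
        # Decode the raw bytes into a PIL image
        image = Image.open(io.BytesIO(image_bytes))

        # Preprocess into a (1, H, W, 3) batch
        image_array = self.preprocess(image)

        # Single output, assumed to be a sigmoid probability of "VAD-Demented"
        prediction = self.model.predict(image_array, verbose=0)
        pred_value = float(prediction[0][0])

        # Threshold at 0.5 and report the confidence of the *predicted* label,
        # so a "Non-Demented" result is scored 1 - p rather than p
        if pred_value > 0.5:
            predicted_class, score = "VAD-Demented", pred_value
        else:
            predicted_class, score = "Non-Demented", 1.0 - pred_value

        return [{"label": predicted_class, "score": score}]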

# Load the model once at import time so it is reused across requests
classifier = MRIClassifier()

def inference(model_inputs):
    """Classify one image given as raw bytes or as {"image": <base64 string>}."""
    if isinstance(model_inputs, dict) and "image" in model_inputs:
        # JSON payload carrying a base64-encoded image
        image_bytes = base64.b64decode(model_inputs["image"])
    else:
        # Raw image bytes passed through directly by the serving API
        image_bytes = model_inputs

    return classifier.predict(image_bytes)