File size: 3,823 Bytes
1ccf66a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import cv2
import numpy as np
from flask import Flask, render_template, request, redirect, url_for, jsonify
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import tensorflow as tf

app = Flask(__name__)
# UPLOAD_FOLDER is a path relative to the process working directory; the
# /translate route writes its input/output images there and links them via
# the 'static' endpoint. ALLOWED_EXTENSIONS is checked by allowed_file().
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    ALLOWED_EXTENSIONS={'png', 'jpg', 'jpeg'},
)

class Sampling(tf.keras.layers.Layer):
    """Reparameterisation-trick sampling layer.

    Called with a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * eps``, where ``eps`` is standard-normal
    noise of shape ``(batch, dim)`` taken from the first two dimensions of
    ``z_mean`` (presumably ``z_mean`` is 2-D ``(batch, latent_dim)`` — TODO
    confirm against the saved models).
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise_shape = (mean_shape[0], mean_shape[1])
        noise = tf.keras.backend.random_normal(shape=noise_shape)
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise

# Load both translation models once, at import time, inside tf.init_scope().
# custom_objects maps the 'Sampling' name to the class above so Keras can
# deserialize the saved .h5 files if they reference that layer.
with tf.init_scope():
    ct_to_mri_model = load_model(
        'models/ct_to_mri_epoch_39.h5',
        custom_objects={'Sampling': Sampling},
    )
    mri_to_ct_model = load_model(
        'models/mri_to_ct_epoch_39.h5',
        custom_objects={'Sampling': Sampling},
    )

def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last '.', compared case-insensitively.
    A name without any '.' is rejected.
    """
    _stem, separator, extension = filename.rpartition('.')
    return bool(separator) and extension.lower() in app.config['ALLOWED_EXTENSIONS']

def process_image(image_path, model):
    """Run *model* on the image stored at *image_path* and return a uint8 array.

    Preprocessing: read with OpenCV (3-channel BGR by default), convert to
    RGB, resize to 256x256, scale to float32 in [0, 1] and add a leading
    batch axis. If the model returns a list/tuple, only its first element is
    used. Singleton axes are squeezed away and values are mapped from
    [0, 1] to [0, 255].

    Args:
        image_path: Path of the image file to read.
        model: A Keras model exposing ``predict``.

    Returns:
        The prediction as a ``np.uint8`` array with singleton axes removed.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError("Could not load image")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256))
    img = img.astype(np.float32) / 255.0
    img = np.expand_dims(img, axis=0)  # batch of one: (1, 256, 256, 3)

    # verbose=0 keeps Keras from printing a progress bar for every request.
    prediction = model.predict(img, verbose=0)
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]  # multi-output model: keep the first output

    prediction = np.squeeze(prediction)
    # Clip before casting: converting floats outside the uint8 range does not
    # saturate (it can wrap), so slight model overshoot would become noise.
    prediction = (np.clip(prediction, 0.0, 1.0) * 255).astype(np.uint8)
    return prediction

def clean_uploads():
    """Delete every regular file in the upload folder (best effort).

    The folder is created first if it is missing, so this also works when
    the app is not started via ``python <this file>`` (e.g. under a WSGI
    server), where ``os.listdir`` would otherwise raise FileNotFoundError.
    Sub-directories are left alone. A file that cannot be deleted is
    reported and skipped rather than aborting the caller.
    """
    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    for filename in os.listdir(upload_dir):
        file_path = os.path.join(upload_dir, filename)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except OSError as e:
            print(f'Error deleting {file_path}: {e}')

@app.route('/')
def index():
    """Serve the site root by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)

@app.route('/try_now')
def try_now():
    """Render ``try_now.html``; /translate redirects back here on bad uploads."""
    template_name = 'try_now.html'
    return render_template(template_name)

@app.route('/samples')
def samples():
    """Render the ``samples.html`` template."""
    template_name = 'samples.html'
    return render_template(template_name)

@app.route('/model_info')
def model_info():
    """Render the ``model_info.html`` template."""
    template_name = 'model_info.html'
    return render_template(template_name)

@app.route('/translate', methods=['POST'])
def translate():
    """Translate an uploaded scan with the selected model and show the result.

    Form fields:
        file: The uploaded image (png/jpg/jpeg, see ``allowed_file``).
        direction: ``'ct_to_mri'`` selects ``ct_to_mri_model``; any other
            value (or none) selects ``mri_to_ct_model``.

    The upload is saved as ``<UPLOAD_FOLDER>/original.png`` and the model
    output as ``<UPLOAD_FOLDER>/result.png``. Then ``result.html`` is rendered
    with static URLs for both.

    Returns:
        The rendered result page. Invalid submissions (no file part, empty
        filename, disallowed extension) get a redirect to the ``try_now``
        page. Failures while processing the image get a plain-text error
        with HTTP 500.
    """
    # NOTE(review): the fixed file names are shared by every client, so
    # concurrent requests can overwrite each other's images.
    clean_uploads()  # Clean previous uploads

    # Send bad submissions back to the upload form. Redirecting to
    # request.url would re-request this POST-only route with GET (405).
    if 'file' not in request.files:
        return redirect(url_for('try_now'))

    file = request.files['file']
    if not file.filename or not allowed_file(file.filename):
        return redirect(url_for('try_now'))

    # Save original image
    upload_path = os.path.join(app.config['UPLOAD_FOLDER'], 'original.png')
    file.save(upload_path)

    # Choose model based on translation direction
    direction = request.form.get('direction')
    model = ct_to_mri_model if direction == 'ct_to_mri' else mri_to_ct_model

    try:
        result = process_image(upload_path, model)
        result_path = os.path.join(app.config['UPLOAD_FOLDER'], 'result.png')
        # For a 2-D (single-channel) result, imsave would otherwise apply its
        # default colour map with limits autoscaled to the data. Render it as
        # true grayscale over the full uint8 range instead. cmap/vmin/vmax
        # are ignored for RGB(A) arrays.
        plt.imsave(result_path, result, cmap='gray', vmin=0, vmax=255)
    except Exception as e:
        app.logger.exception('Image translation failed')
        return f"Error processing image: {str(e)}", 500

    return render_template('result.html',
                           original=url_for('static', filename='uploads/original.png'),
                           result=url_for('static', filename='uploads/result.png'))

if __name__ == '__main__':
    # Ensure the upload folder exists, then start Flask's built-in
    # development server on all interfaces, port 5000, debugger off.
    # NOTE(review): app.run() is a development server; a production
    # deployment should go through a WSGI server instead.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    app.run(debug=False, port=5000, host='0.0.0.0')