Spaces:
Running
Running
import os | |
import cv2 | |
import numpy as np | |
from flask import Flask, render_template, request, redirect, url_for, jsonify | |
from tensorflow.keras.layers import Layer | |
from tensorflow.keras.models import load_model | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
# Flask application and its upload settings: uploaded and generated images
# are written under static/uploads, and only png/jpg/jpeg uploads are accepted.
app = Flask(__name__)
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    ALLOWED_EXTENSIONS={'png', 'jpg', 'jpeg'},
)
class Sampling(tf.keras.layers.Layer):
    """Reparameterization-trick sampling layer.

    Given ``(z_mean, z_log_var)``, returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is drawn
    from a standard normal distribution.
    """

    def call(self, inputs):
        """Sample ``z`` from the Gaussian described by ``inputs``.

        Args:
            inputs: pair ``(z_mean, z_log_var)`` of tensors with the same shape.

        Returns:
            Tensor with the shape and dtype of ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Noise takes z_mean's full dynamic shape, so any rank works (the
        # previous version assumed (batch, dim)). It also takes z_mean's
        # dtype, so the product below does not mix float16 and float32
        # under mixed precision.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# Load both translation models once, at import time. The checkpoints
# presumably contain the custom Sampling layer, which is why it is supplied
# through custom_objects so Keras can rebuild the saved architecture.
# NOTE(review): at module import there is presumably no enclosing
# tf.function graph, so tf.init_scope() is likely a no-op here; confirm
# before removing it.
with tf.init_scope():
    ct_to_mri_model, mri_to_ct_model = (
        load_model(checkpoint, custom_objects={'Sampling': Sampling})
        for checkpoint in (
            'models/ct_to_mri_epoch_39.h5',
            'models/mri_to_ct_epoch_39.h5',
        )
    )
def allowed_file(filename):
    """Report whether *filename* carries an accepted image extension.

    The extension is whatever follows the last '.', compared
    case-insensitively against ``app.config['ALLOWED_EXTENSIONS']``.
    A name containing no '.' at all is rejected.
    """
    _, separator, extension = filename.rpartition('.')
    if not separator:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def process_image(image_path, model, target_size=(256, 256)):
    """Run one image file through a translation model and return an 8-bit image.

    Args:
        image_path: path of the image file to read with OpenCV.
        model: object with a Keras-style ``predict(batch)``. It is fed a batch
            of one float32 RGB image scaled to [0, 1]. If ``predict`` returns
            a list/tuple, only its first element is used.
        target_size: ``(width, height)`` the input is resized to before
            inference (``cv2.resize`` argument order). Defaults to 256x256.

    Returns:
        ``np.uint8`` array: the squeezed prediction, clipped to [0, 1] and
        scaled to 0-255.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError("Could not load image")
    # OpenCV decodes colour images in BGR order; the model is fed RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    img = img.astype(np.float32) / 255.0
    batch = np.expand_dims(img, axis=0)  # add the leading batch axis

    prediction = model.predict(batch)
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]
    prediction = np.squeeze(prediction)
    # Clip before scaling: casting floats outside [0, 255] to uint8 wraps
    # around (or is undefined), which shows up as speckle artifacts.
    prediction = np.clip(prediction, 0.0, 1.0)
    return (prediction * 255).astype(np.uint8)
def clean_uploads(upload_dir=None):
    """Delete every regular file in the upload folder (best-effort).

    The folder is created first if it is missing. It is otherwise only
    created by the script's ``__main__`` block, so when the app is served
    another way, ``os.listdir`` would raise FileNotFoundError. Files that
    cannot be removed are reported and skipped; subdirectories are left alone.

    Args:
        upload_dir: folder to clean. Defaults to ``app.config['UPLOAD_FOLDER']``.
    """
    if upload_dir is None:
        upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    for filename in os.listdir(upload_dir):
        file_path = os.path.join(upload_dir, filename)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except OSError as e:
            # Best-effort cleanup: one undeletable file must not fail the request.
            print(f'Error deleting {file_path}: {e}')
def index():
    """Render the landing page template ``index.html``.

    NOTE(review): no @app.route decorator is visible in this source; the
    view is presumably registered elsewhere. Confirm.
    """
    page = 'index.html'
    return render_template(page)
def try_now():
    """Render the upload form template ``try_now.html``.

    NOTE(review): translate() redirects via url_for('try_now'), so this view
    must be registered under the endpoint name 'try_now'. No decorator is
    visible in this source; confirm.
    """
    page = 'try_now.html'
    return render_template(page)
def samples():
    """Render the ``samples.html`` template."""
    page = 'samples.html'
    return render_template(page)
def model_info():
    """Render the ``model_info.html`` template."""
    page = 'model_info.html'
    return render_template(page)
def translate():
    """Translate an uploaded scan between CT and MRI and render the result.

    Reads a multipart form with a ``file`` field and a ``direction`` field.
    ``'ct_to_mri'`` selects the CT->MRI model; any other value selects the
    MRI->CT model.

    Returns:
        One of:
        - the rendered ``result.html`` page on success;
        - a redirect to the request URL when no file was sent;
        - a redirect to the ``try_now`` page for a disallowed file type;
        - a plain-text error with HTTP 500 if processing fails.
    """
    import uuid  # function-scope import keeps this fix self-contained

    # NOTE(review): this deletes *all* previous uploads, including files a
    # concurrent user's browser may not have fetched yet; consider age-based
    # cleanup instead.
    clean_uploads()  # Clean previous uploads
    if 'file' not in request.files:
        return redirect(request.url)
    file = request.files['file']
    if not file.filename:  # covers both '' and None
        return redirect(request.url)
    if not allowed_file(file.filename):
        return redirect(url_for('try_now'))

    # Per-request names. With fixed names, concurrent users overwrite each
    # other's files and can be shown someone else's scan. Browsers may also
    # display a cached earlier result for an unchanged URL.
    token = uuid.uuid4().hex
    original_name = f'{token}_original.png'
    result_name = f'{token}_result.png'
    upload_dir = app.config['UPLOAD_FOLDER']

    # Save original image
    upload_path = os.path.join(upload_dir, original_name)
    file.save(upload_path)

    # Choose model based on translation direction
    direction = request.form.get('direction')
    model = ct_to_mri_model if direction == 'ct_to_mri' else mri_to_ct_model
    try:
        result = process_image(upload_path, model)
        result_path = os.path.join(upload_dir, result_name)
        # cmap/vmin/vmax only affect single-channel output (they are ignored
        # for RGB arrays). They render it as true grayscale on the fixed 0-255
        # scale, instead of a colormap with per-image contrast stretching.
        plt.imsave(result_path, result, cmap='gray', vmin=0, vmax=255)
        return render_template('result.html',
                               original=url_for('static', filename=f'uploads/{original_name}'),
                               result=url_for('static', filename=f'uploads/{result_name}'))
    except Exception as e:
        # Request boundary: report the failure with a proper error status.
        return f"Error processing image: {str(e)}", 500
if __name__ == '__main__':
    # Script entry point: ensure the upload folder exists, then serve on all
    # interfaces (0.0.0.0), port 5000, with debug mode off for production.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    app.run(debug=False, host='0.0.0.0', port=5000)