# Source: uploaded by rudra0410hf to Hugging Face
# ("Upload folder using huggingface_hub", commit 1ccf66a, verified).
import os
import cv2
import numpy as np
from flask import Flask, render_template, request, redirect, url_for, jsonify
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import tensorflow as tf
app = Flask(__name__)
# Uploaded inputs and generated outputs are written to UPLOAD_FOLDER. The path
# is relative to the process working directory.
# NOTE(review): the images are later served through the 'static' endpoint,
# which Flask resolves relative to the application root, so the two only
# coincide when the app is launched from its own directory — confirm.
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    ALLOWED_EXTENSIONS={'png', 'jpg', 'jpeg'},
)
class Sampling(tf.keras.layers.Layer):
    """Reparameterization layer: return ``mean + exp(0.5 * log_var) * eps``.

    ``inputs`` is the pair ``(z_mean, z_log_var)``. ``eps`` is drawn from a
    standard normal with shape ``(batch, dim)``, taken from the first two
    dimensions of ``z_mean``. The model loaders in this module pass this
    class via ``custom_objects`` under the name 'Sampling'.
    """

    def call(self, inputs):
        mean, log_variance = inputs
        mean_shape = tf.shape(mean)
        # Noise is drawn before the std-dev is computed, as in the original order.
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        std_dev = tf.exp(0.5 * log_variance)
        return mean + std_dev * noise
# Load both translation models once, at import time; translate() picks one of
# these module-level models per request.
# `custom_objects` maps the layer name 'Sampling' to the Sampling class so
# load_model can rebuild that custom layer from the .h5 files.
# `tf.init_scope()` lifts execution out of any enclosing graph-building
# context; at module level this is presumably a no-op kept as a safeguard.
# NOTE(review): the 'models/...' paths are relative to the current working
# directory, not to this file — confirm the app is always started from its root.
with tf.init_scope():
    ct_to_mri_model = load_model('models/ct_to_mri_epoch_39.h5',
                                 custom_objects={'Sampling': Sampling})
    mri_to_ct_model = load_model('models/mri_to_ct_epoch_39.h5',
                                 custom_objects={'Sampling': Sampling})
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared case-insensitively.
    Names that contain no dot at all are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def process_image(image_path, model, *, target_size=(256, 256)):
    """Run *model* on the image stored at *image_path* and return uint8 pixels.

    The image is read with OpenCV, converted from BGR to RGB, resized and
    scaled to float32 in [0, 1]. It is then passed to ``model.predict`` as a
    batch of one with shape ``(1, height, width, 3)``. The prediction is
    squeezed (the batch axis and any other size-1 axes are dropped), clipped
    to [0, 1] and rescaled to [0, 255].

    Args:
        image_path: Path of the image file to translate.
        model: Keras-style model exposing ``predict``.
        target_size: ``(width, height)`` passed to ``cv2.resize``; defaults to
            ``(256, 256)``.

    Returns:
        ``np.ndarray`` of dtype uint8 holding the squeezed model output.

    Raises:
        ValueError: if OpenCV cannot read *image_path*.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError("Could not load image")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    batch = np.expand_dims(img.astype(np.float32) / 255.0, axis=0)

    prediction = model.predict(batch)
    # Multi-output models return a list/tuple; the first output is used as the
    # image. NOTE(review): confirm output 0 is the translated image.
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]
    prediction = np.squeeze(prediction)
    # Clip before casting: floats outside [0, 1] would otherwise wrap around
    # (or give platform-dependent values for negatives) when cast to uint8.
    # In-range values are unaffected.
    prediction = np.clip(prediction, 0.0, 1.0)
    return (prediction * 255).astype(np.uint8)
def clean_uploads():
    """Delete every regular file directly inside the upload folder.

    The folder is created first if it does not exist. The ``__main__`` block
    is the only other place that creates it, so under a WSGI server
    ``os.listdir`` would otherwise raise FileNotFoundError on the first
    request. Sub-directories are left untouched. A file that cannot be
    deleted is logged and skipped rather than aborting the request.
    """
    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    for filename in os.listdir(upload_dir):
        file_path = os.path.join(upload_dir, filename)
        if not os.path.isfile(file_path):
            continue
        try:
            os.unlink(file_path)
        except OSError as e:
            app.logger.warning('Error deleting %s: %s', file_path, e)
@app.route('/')
def index():
    """Render index.html for the site root ('/')."""
    page = 'index.html'
    return render_template(page)
@app.route('/try_now')
def try_now():
    """Render try_now.html; translate() redirects here when an upload is rejected."""
    page = 'try_now.html'
    return render_template(page)
@app.route('/samples')
def samples():
    """Render samples.html for the '/samples' route."""
    page = 'samples.html'
    return render_template(page)
@app.route('/model_info')
def model_info():
    """Render model_info.html for the '/model_info' route."""
    page = 'model_info.html'
    return render_template(page)
@app.route('/translate', methods=['POST'])
def translate():
    """Translate an uploaded scan with the selected model and show the result.

    Form fields:
        file: the image to translate; its name must pass ``allowed_file``.
        direction: ``'ct_to_mri'`` selects the CT->MRI model; any other value,
            including a missing field, selects the MRI->CT model.

    Returns:
        On success, ``result.html`` rendered with URLs for the saved original
        and result images. A missing or disallowed file gives a redirect to
        the ``try_now`` page. If processing fails, the response is an
        HTML-escaped error message with HTTP status 500.
    """
    import html
    import uuid

    file = request.files.get('file')
    # Redirect to the upload page rather than request.url. This route only
    # accepts POST, so redirecting back to it would fail with 405 on the
    # browser's follow-up GET.
    if file is None or not file.filename or not allowed_file(file.filename):
        return redirect(url_for('try_now'))

    # Only discard the previous run once there is a usable upload to replace it.
    clean_uploads()

    # NOTE(review): these file names are fixed and shared by every client, so
    # concurrent requests overwrite (and can see) each other's images.
    # Per-request names would need a matching change to clean_uploads().
    upload_dir = app.config['UPLOAD_FOLDER']
    upload_path = os.path.join(upload_dir, 'original.png')
    file.save(upload_path)

    direction = request.form.get('direction')
    model = ct_to_mri_model if direction == 'ct_to_mri' else mri_to_ct_model

    try:
        result = process_image(upload_path, model)
        result_path = os.path.join(upload_dir, 'result.png')
        # Without cmap/vmin/vmax, a 2-D (single-channel) result would be
        # saved through the default colormap, with contrast stretched to its
        # min/max. Gray with a fixed 0-255 range keeps the pixel values.
        # matplotlib ignores these arguments for RGB(A) arrays.
        plt.imsave(result_path, result, cmap='gray', vmin=0, vmax=255)
    except Exception as e:
        app.logger.exception('Error processing image')
        # The message is sent back as HTML, so escape it.
        return f"Error processing image: {html.escape(str(e))}", 500

    # The image URLs never change, so add a per-response query token;
    # otherwise browsers may show a cached image from an earlier request.
    token = uuid.uuid4().hex
    return render_template(
        'result.html',
        original=url_for('static', filename='uploads/original.png', v=token),
        result=url_for('static', filename='uploads/result.png', v=token),
    )
if __name__ == '__main__':
    # Create the upload directory up front: /translate saves into it and
    # clean_uploads() lists it.
    upload_folder = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_folder, exist_ok=True)
    # Serve on every interface, port 5000. Keep the debugger off: it allows
    # arbitrary code execution and must not be exposed on a reachable host.
    app.run(host='0.0.0.0', port=5000, debug=False)