"""Gradio front end for training, running, and downloading the animal classifier."""

import os

# NumPy reads NUMPY_EXPERIMENTAL_ARRAY_FUNCTION only at first import, so the
# variable must be set before anything that imports NumPy. Gradio does, and
# train_model / predict_model presumably do too (TODO confirm). The original
# script set it after those imports, so the setting never took effect.
os.environ["NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"] = "0"

import gradio as gr  # noqa: E402  (must follow the env var above)
from train_model import train  # noqa: E402
from predict_model import predict_all  # noqa: E402

# Checkpoint filename reported to the user and served for download.
# NOTE(review): assumes train() writes exactly this file into the current
# working directory -- confirm against train_model.train().
MODEL_PATH = "animal_classifier_resnet.pth"


def train_model():
    """Run training and return a status message for the Train tab."""
    train()
    return f"Model trained and saved as {MODEL_PATH}"


def download_model():
    """Return the checkpoint path for the ``gr.File`` output.

    Raises:
        gr.Error: if the checkpoint does not exist yet. Gradio shows this as a
            readable message in the UI instead of failing on a missing file.
    """
    if not os.path.isfile(MODEL_PATH):
        raise gr.Error(f"No trained model found at {MODEL_PATH}. Train the model first.")
    return MODEL_PATH


def run_predictions():
    """Run ``predict_all()`` and return its results, one per line.

    Returns an empty string when ``predict_all()`` returns ``None``.
    """
    results = predict_all()
    if results is None:
        return ""
    # str() keeps non-string entries (numbers, tuples, ...) from breaking join.
    return "\n".join(str(result) for result in results)


with gr.Blocks() as demo:
    gr.Markdown("# Animal Classifier Model")

    with gr.Tab("Train"):
        train_button = gr.Button("Train Model")
        train_output = gr.Textbox()
        train_button.click(train_model, outputs=train_output)

    with gr.Tab("Predict"):
        predict_button = gr.Button("Run Predictions")
        predict_output = gr.Textbox()
        predict_button.click(run_predictions, outputs=predict_output)

    with gr.Tab("Download"):
        gr.Markdown("## Download Trained Model")
        download_button = gr.Button("Download Model")
        # Declared explicitly (not inline in .click()) so the output component
        # has a name; it still renders directly below the button.
        download_output = gr.File()
        download_button.click(download_model, outputs=download_output)


if __name__ == "__main__":
    demo.launch()