taellinglin commited on
Commit
9f495ed
·
verified ·
1 Parent(s): 84b2e15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -219
app.py CHANGED
@@ -1,219 +1,219 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- from torch import nn, optim
6
- from torch.utils.data import DataLoader
7
- from io import StringIO
8
- import os
9
- import base64
10
- # Import your modules
11
- from logistic_regression import LogisticRegressionModel
12
- from softmax_regression import SoftmaxRegressionModel
13
- from shallow_neural_network import ShallowNeuralNetwork
14
- import convolutional_neural_networks
15
- from dataset_loader import CustomMNISTDataset
16
- from final_project import train_final_model, get_dataset_options, FinalCNN
17
- import torchvision.transforms as transforms
18
-
19
- import torch
20
- import matplotlib.pyplot as plt
21
- from matplotlib import font_manager
22
- import matplotlib.pyplot as plt
23
- def number_to_char(number):
24
- if 0 <= number <= 9:
25
- return str(number) # 0-9
26
- elif 10 <= number <= 35:
27
- return chr(number + 87) # a-z (10 -> 'a', 35 -> 'z')
28
- elif 36 <= number <= 61:
29
- return chr(number + 65) # A-Z (36 -> 'A', 61 -> 'Z')
30
- else:
31
- return ''
32
-
33
- def visualize_predictions_svg(model, train_loader, stage):
34
- """Visualizes predictions and returns SVG string for Gradio display."""
35
- # Load the Daemon font
36
- font_path = './Daemon.otf' # Path to your Daemon font
37
- prop = font_manager.FontProperties(fname=font_path)
38
-
39
- fig, ax = plt.subplots(6, 3, figsize=(12, 16)) # 6 rows and 3 columns for 18 images
40
-
41
- model.eval()
42
- images, labels = next(iter(train_loader))
43
- images, labels = images[:18], labels[:18] # Get 18 images and labels
44
-
45
- with torch.no_grad():
46
- outputs = model(images)
47
- _, predictions = torch.max(outputs, 1)
48
-
49
- for i in range(18): # Iterate over 18 images
50
- ax[i // 3, i % 3].imshow(images[i].squeeze(), cmap='gray')
51
-
52
- # Convert predictions and labels to characters
53
- pred_char = number_to_char(predictions[i].item())
54
- label_char = number_to_char(labels[i].item())
55
-
56
- # Display = or != based on prediction
57
- if pred_char == label_char:
58
- title_text = f"{pred_char} = {label_char}"
59
- color = 'green' # Green if correct
60
- else:
61
- title_text = f"{pred_char} != {label_char}"
62
- color = 'red' # Red if incorrect
63
-
64
- # Set title with Daemon font and color
65
- ax[i // 3, i % 3].set_title(title_text, fontproperties=prop, fontsize=12, color=color)
66
- ax[i // 3, i % 3].axis('off')
67
-
68
-
69
- # Convert the figure to SVG
70
- svg_str = figure_to_svg(fig)
71
- save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg") # Save SVG to output folder
72
- plt.close(fig)
73
-
74
- return svg_str
75
-
76
- def figure_to_svg(fig):
77
- """Convert a matplotlib figure to SVG string."""
78
- from io import StringIO
79
- from matplotlib.backends.backend_svg import FigureCanvasSVG
80
- canvas = FigureCanvasSVG(fig)
81
- output = StringIO()
82
- canvas.print_svg(output)
83
- return output.getvalue()
84
-
85
- def save_svg_to_output_folder(svg_str, filename):
86
- """Save the SVG string to the output folder."""
87
- output_path = f'./output/{filename}' # Ensure your output folder exists
88
- with open(output_path, 'w') as f:
89
- f.write(svg_str)
90
-
91
-
92
- def plot_metrics_svg(losses, accuracies):
93
- """Generate training metrics as SVG string."""
94
- fig, ax = plt.subplots(1, 2, figsize=(12, 5))
95
-
96
- ax[0].plot(losses, label='Loss', color='red')
97
- ax[0].set_title('Training Loss')
98
- ax[0].set_xlabel('Epoch')
99
- ax[0].set_ylabel('Loss')
100
- ax[0].legend()
101
-
102
- ax[1].plot(accuracies, label='Accuracy', color='green')
103
- ax[1].set_title('Training Accuracy')
104
- ax[1].set_xlabel('Epoch')
105
- ax[1].set_ylabel('Accuracy')
106
- ax[1].legend()
107
-
108
- plt.tight_layout()
109
- svg_str = figure_to_svg(fig)
110
- save_svg_to_output_folder(svg_str, "training_metrics.svg") # Save metrics SVG to output folder
111
- plt.close(fig)
112
-
113
- return svg_str
114
-
115
- def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
116
- """Train the selected model with the chosen dataset."""
117
- transform = transforms.Compose([
118
- transforms.Resize((28, 28)),
119
- transforms.Grayscale(num_output_channels=1),
120
- transforms.ToTensor(),
121
- transforms.Normalize(mean=[0.5], std=[0.5])
122
- ])
123
-
124
- # Load dataset using CustomMNISTDataset
125
- train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
126
- train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
127
-
128
- # Select Model
129
- if module == "Logistic Regression":
130
- model = LogisticRegressionModel(input_size=1)
131
- elif module == "Softmax Regression":
132
- model = SoftmaxRegressionModel(input_size=2, num_classes=2)
133
- elif module == "Shallow Neural Networks":
134
- model = ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
135
- elif module == "Deep Networks":
136
- import deep_networks
137
- model = deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
138
- elif module == "Convolutional Neural Networks":
139
- model = convolutional_neural_networks.ConvolutionalNeuralNetwork()
140
- elif module == "AI Calligraphy":
141
- model = FinalCNN()
142
- else:
143
- return "Invalid module selection", None, None, None, None
144
-
145
- # Visualize before training
146
- before_svg = visualize_predictions_svg(model, train_loader, "Before")
147
-
148
- # Train the model
149
- criterion = nn.CrossEntropyLoss()
150
- optimizer = optim.SGD(model.parameters(), lr=lr)
151
-
152
- losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, epochs)
153
-
154
- # Visualize after training
155
- after_svg = visualize_predictions_svg(model, train_loader, "After")
156
-
157
- # Metrics SVG
158
- metrics_svg = plot_metrics_svg(losses, accuracies)
159
-
160
- return model, losses, accuracies, before_svg, after_svg, metrics_svg
161
-
162
-
163
- def list_datasets():
164
- """List all available datasets dynamically"""
165
- dataset_options = get_dataset_options()
166
- if not dataset_options:
167
- return ["No datasets found"]
168
- return dataset_options
169
-
170
- ### 🎯 Gradio Interface ###
171
- def run_module(module, dataset_name, epochs, lr):
172
- """Gradio interface callback"""
173
- # Train model
174
- model, losses, accuracies, before_svg, after_svg, metrics_svg = train_model_interface(
175
- module, dataset_name, epochs, lr
176
- )
177
-
178
- if model is None:
179
- return "Error: Invalid selection.", None, None, None, None
180
-
181
- # Simply pass the SVG strings to Gradio's gr.Image for rendering
182
- return (
183
- f"Training completed for {module} with {epochs} epochs.",
184
- before_svg, # Pass raw SVG for before training
185
- after_svg, # Pass raw SVG for after training
186
- metrics_svg # Return training metrics SVG directly
187
- )
188
-
189
- ### 🌟 Gradio UI ###
190
- with gr.Blocks() as app:
191
- with gr.Tab("Techniques"):
192
- gr.Markdown("### 🧠 Select Model to Train")
193
-
194
- module_select = gr.Dropdown(
195
- choices=[
196
- "AI Calligraphy"
197
- ],
198
- label="Select Module"
199
- )
200
-
201
- dataset_list = gr.Dropdown(choices=list_datasets(), label="Select Dataset")
202
- epochs = gr.Slider(10, 1024, value=100, step=10, label="Epochs")
203
- lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
204
-
205
- train_button = gr.Button("Train Model")
206
-
207
- output = gr.Textbox(label="Training Output")
208
- before_svg = gr.HTML(label="Before Training Predictions")
209
- after_svg = gr.HTML(label="After Training Predictions")
210
- metrics_svg = gr.HTML(label="Metrics")
211
-
212
- train_button.click(
213
- run_module,
214
- inputs=[module_select, dataset_list, epochs, lr],
215
- outputs=[output, before_svg, after_svg, metrics_svg]
216
- )
217
-
218
- # Launch Gradio app
219
- app.launch(server_name="127.0.0.1", server_port=5555, share=True)
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from torch import nn, optim
6
+ from torch.utils.data import DataLoader
7
+ from io import StringIO
8
+ import os
9
+ import base64
10
+ # Import your modules
11
+ from logistic_regression import LogisticRegressionModel
12
+ from softmax_regression import SoftmaxRegressionModel
13
+ from shallow_neural_network import ShallowNeuralNetwork
14
+ import convolutional_neural_networks
15
+ from dataset_loader import CustomMNISTDataset
16
+ from final_project import train_final_model, get_dataset_options, FinalCNN
17
+ import torchvision.transforms as transforms
18
+
19
+ import torch
20
+ import matplotlib.pyplot as plt
21
+ from matplotlib import font_manager
22
+ import matplotlib.pyplot as plt
23
+ def number_to_char(number):
24
+ if 0 <= number <= 9:
25
+ return str(number) # 0-9
26
+ elif 10 <= number <= 35:
27
+ return chr(number + 87) # a-z (10 -> 'a', 35 -> 'z')
28
+ elif 36 <= number <= 61:
29
+ return chr(number + 65) # A-Z (36 -> 'A', 61 -> 'Z')
30
+ else:
31
+ return ''
32
+
33
+ def visualize_predictions_svg(model, train_loader, stage):
34
+ """Visualizes predictions and returns SVG string for Gradio display."""
35
+ # Load the Daemon font
36
+ font_path = './Daemon.otf' # Path to your Daemon font
37
+ prop = font_manager.FontProperties(fname=font_path)
38
+
39
+ fig, ax = plt.subplots(6, 3, figsize=(12, 16)) # 6 rows and 3 columns for 18 images
40
+
41
+ model.eval()
42
+ images, labels = next(iter(train_loader))
43
+ images, labels = images[:18], labels[:18] # Get 18 images and labels
44
+
45
+ with torch.no_grad():
46
+ outputs = model(images)
47
+ _, predictions = torch.max(outputs, 1)
48
+
49
+ for i in range(18): # Iterate over 18 images
50
+ ax[i // 3, i % 3].imshow(images[i].squeeze(), cmap='gray')
51
+
52
+ # Convert predictions and labels to characters
53
+ pred_char = number_to_char(predictions[i].item())
54
+ label_char = number_to_char(labels[i].item())
55
+
56
+ # Display = or != based on prediction
57
+ if pred_char == label_char:
58
+ title_text = f"{pred_char} = {label_char}"
59
+ color = 'green' # Green if correct
60
+ else:
61
+ title_text = f"{pred_char} != {label_char}"
62
+ color = 'red' # Red if incorrect
63
+
64
+ # Set title with Daemon font and color
65
+ ax[i // 3, i % 3].set_title(title_text, fontproperties=prop, fontsize=12, color=color)
66
+ ax[i // 3, i % 3].axis('off')
67
+
68
+
69
+ # Convert the figure to SVG
70
+ svg_str = figure_to_svg(fig)
71
+ save_svg_to_output_folder(svg_str, f"{stage}_predictions.svg") # Save SVG to output folder
72
+ plt.close(fig)
73
+
74
+ return svg_str
75
+
76
+ def figure_to_svg(fig):
77
+ """Convert a matplotlib figure to SVG string."""
78
+ from io import StringIO
79
+ from matplotlib.backends.backend_svg import FigureCanvasSVG
80
+ canvas = FigureCanvasSVG(fig)
81
+ output = StringIO()
82
+ canvas.print_svg(output)
83
+ return output.getvalue()
84
+
85
+ def save_svg_to_output_folder(svg_str, filename):
86
+ """Save the SVG string to the output folder."""
87
+ output_path = f'./output/{filename}' # Ensure your output folder exists
88
+ with open(output_path, 'w') as f:
89
+ f.write(svg_str)
90
+
91
+
92
+ def plot_metrics_svg(losses, accuracies):
93
+ """Generate training metrics as SVG string."""
94
+ fig, ax = plt.subplots(1, 2, figsize=(12, 5))
95
+
96
+ ax[0].plot(losses, label='Loss', color='red')
97
+ ax[0].set_title('Training Loss')
98
+ ax[0].set_xlabel('Epoch')
99
+ ax[0].set_ylabel('Loss')
100
+ ax[0].legend()
101
+
102
+ ax[1].plot(accuracies, label='Accuracy', color='green')
103
+ ax[1].set_title('Training Accuracy')
104
+ ax[1].set_xlabel('Epoch')
105
+ ax[1].set_ylabel('Accuracy')
106
+ ax[1].legend()
107
+
108
+ plt.tight_layout()
109
+ svg_str = figure_to_svg(fig)
110
+ save_svg_to_output_folder(svg_str, "training_metrics.svg") # Save metrics SVG to output folder
111
+ plt.close(fig)
112
+
113
+ return svg_str
114
+
115
+ def train_model_interface(module, dataset_name, epochs=100, lr=0.01):
116
+ """Train the selected model with the chosen dataset."""
117
+ transform = transforms.Compose([
118
+ transforms.Resize((28, 28)),
119
+ transforms.Grayscale(num_output_channels=1),
120
+ transforms.ToTensor(),
121
+ transforms.Normalize(mean=[0.5], std=[0.5])
122
+ ])
123
+
124
+ # Load dataset using CustomMNISTDataset
125
+ train_dataset = CustomMNISTDataset(os.path.join("data", dataset_name, "raw"), transform=transform)
126
+ train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
127
+
128
+ # Select Model
129
+ if module == "Logistic Regression":
130
+ model = LogisticRegressionModel(input_size=1)
131
+ elif module == "Softmax Regression":
132
+ model = SoftmaxRegressionModel(input_size=2, num_classes=2)
133
+ elif module == "Shallow Neural Networks":
134
+ model = ShallowNeuralNetwork(input_size=2, hidden_size=5, output_size=2)
135
+ elif module == "Deep Networks":
136
+ import deep_networks
137
+ model = deep_networks.DeepNeuralNetwork(input_size=10, hidden_sizes=[20, 10], output_size=2)
138
+ elif module == "Convolutional Neural Networks":
139
+ model = convolutional_neural_networks.ConvolutionalNeuralNetwork()
140
+ elif module == "AI Calligraphy":
141
+ model = FinalCNN()
142
+ else:
143
+ return "Invalid module selection", None, None, None, None
144
+
145
+ # Visualize before training
146
+ before_svg = visualize_predictions_svg(model, train_loader, "Before")
147
+
148
+ # Train the model
149
+ criterion = nn.CrossEntropyLoss()
150
+ optimizer = optim.SGD(model.parameters(), lr=lr)
151
+
152
+ losses, accuracies = train_final_model(model, criterion, optimizer, train_loader, epochs)
153
+
154
+ # Visualize after training
155
+ after_svg = visualize_predictions_svg(model, train_loader, "After")
156
+
157
+ # Metrics SVG
158
+ metrics_svg = plot_metrics_svg(losses, accuracies)
159
+
160
+ return model, losses, accuracies, before_svg, after_svg, metrics_svg
161
+
162
+
163
+ def list_datasets():
164
+ """List all available datasets dynamically"""
165
+ dataset_options = get_dataset_options()
166
+ if not dataset_options:
167
+ return ["No datasets found"]
168
+ return dataset_options
169
+
170
+ ### 🎯 Gradio Interface ###
171
+ def run_module(module, dataset_name, epochs, lr):
172
+ """Gradio interface callback"""
173
+ # Train model
174
+ model, losses, accuracies, before_svg, after_svg, metrics_svg = train_model_interface(
175
+ module, dataset_name, epochs, lr
176
+ )
177
+
178
+ if model is None:
179
+ return "Error: Invalid selection.", None, None, None, None
180
+
181
+ # Simply pass the SVG strings to Gradio's gr.Image for rendering
182
+ return (
183
+ f"Training completed for {module} with {epochs} epochs.",
184
+ before_svg, # Pass raw SVG for before training
185
+ after_svg, # Pass raw SVG for after training
186
+ metrics_svg # Return training metrics SVG directly
187
+ )
188
+
189
+ ### 🌟 Gradio UI ###
190
+ with gr.Blocks() as app:
191
+ with gr.Tab("Techniques"):
192
+ gr.Markdown("### 🧠 Select Model to Train")
193
+
194
+ module_select = gr.Dropdown(
195
+ choices=[
196
+ "AI Calligraphy"
197
+ ],
198
+ label="Select Module"
199
+ )
200
+
201
+ dataset_list = gr.Dropdown(choices=list_datasets(), label="Select Dataset")
202
+ epochs = gr.Slider(10, 1024, value=100, step=10, label="Epochs")
203
+ lr = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Learning Rate")
204
+
205
+ train_button = gr.Button("Train Model")
206
+
207
+ output = gr.Textbox(label="Training Output")
208
+ before_svg = gr.HTML(label="Before Training Predictions")
209
+ after_svg = gr.HTML(label="After Training Predictions")
210
+ metrics_svg = gr.HTML(label="Metrics")
211
+
212
+ train_button.click(
213
+ run_module,
214
+ inputs=[module_select, dataset_list, epochs, lr],
215
+ outputs=[output, before_svg, after_svg, metrics_svg]
216
+ )
217
+
218
+ # Launch Gradio app
219
+ app.launch()