Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ from sklearn.utils._testing import ignore_warnings
|
|
11 |
from sklearn.exceptions import ConvergenceWarning
|
12 |
from sklearn.utils import shuffle
|
13 |
|
|
|
14 |
def load_mnist(classes, n_samples):
|
15 |
"""Load MNIST, select two classes, shuffle and return only n_samples."""
|
16 |
# Load data from http://openml.org/d/554
|
@@ -110,7 +111,12 @@ def plot(classes, max_iterations, num_samples, n_iter_no_change, validation_frac
|
|
110 |
fig2.tight_layout()
|
111 |
|
112 |
return fig1, fig2
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
114 |
with gr.Blocks() as demo:
|
115 |
gr.Markdown(info)
|
116 |
with gr.Row():
|
@@ -126,4 +132,4 @@ with gr.Blocks() as demo:
|
|
126 |
|
127 |
btn = gr.Button("Run")
|
128 |
btn.click(fn=plot, inputs=[classes, max_iterations, num_samples, n_iter_no_change, validation_fraction, tol], outputs=[out1, out2])
|
129 |
-
demo.launch()
|
|
|
11 |
from sklearn.exceptions import ConvergenceWarning
|
12 |
from sklearn.utils import shuffle
|
13 |
|
14 |
+
|
15 |
def load_mnist(classes, n_samples):
|
16 |
"""Load MNIST, select two classes, shuffle and return only n_samples."""
|
17 |
# Load data from http://openml.org/d/554
|
|
|
111 |
fig2.tight_layout()
|
112 |
|
113 |
return fig1, fig2
|
114 |
+
|
115 |
+
info = '''# Early stopping of Stochastic Gradient Descent\nThis example demonstrates the use of early stopping when training using Stochastic Gradient Descent.
|
116 |
+
Since when using the stochastic method, the loss function isn't guaranteed to decrease with each iteration, and convergence is only guaranteed in expectation. For this reason monitoring the convergence of the loss function might not be the optimal solution and can result in many redundant training steps.
|
117 |
+
An alternative is monitoring the convergence on a validation score, and early stopping the training once a convergence criterion is met. This enables us to find the least number of iterations which is sufficient to build a model that generalizes well to unseen data and reduces the chance of over-fitting the training data.
|
118 |
+
Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_early_stopping.html)'''
|
119 |
+
|
120 |
with gr.Blocks() as demo:
|
121 |
gr.Markdown(info)
|
122 |
with gr.Row():
|
|
|
132 |
|
133 |
btn = gr.Button("Run")
|
134 |
btn.click(fn=plot, inputs=[classes, max_iterations, num_samples, n_iter_no_change, validation_fraction, tol], outputs=[out1, out2])
|
135 |
+
demo.launch()
|