Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
|
7 |
+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
|
8 |
+
|
9 |
+
# Page-wide Streamlit settings: wide layout, sidebar state chosen automatically.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto",
)

# Export the XLA device flag through the process environment.
# NOTE(review): presumably consumed by TensorFlow; it is set after the
# ydata_synthetic imports above, so confirm it is still read in time.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
|
11 |
+
def run():
    """Render the Streamlit page that trains a GAN and samples synthetic data.

    The page lets the user upload a CSV file, pick the GAN model, the
    numerical/categorical columns and the training hyper-parameters, then
    trains a ydata-synthetic ``RegularSynthesizer``, plots count plots of the
    categorical columns for the real and the synthetic data side by side and
    offers the synthetic dataset as a CSV download.
    """
    st.sidebar.image('YData_logo.svg')
    st.title('Generate synthetic data for a tabular classification dataset using [ydata-synthetic](https://github.com/ydataai/ydata-synthetic)')
    st.markdown('This streamlit application can generate synthetic data for your dataset. Please read all the instructions in the sidebar before you start the process.')
    data = st.file_uploader('Upload a preprocessed dataset in csv format')

    # Static explanatory content shown in the sidebar.
    st.sidebar.title('About')
    st.sidebar.markdown('[ydata-synthetic](https://github.com/ydataai/ydata-synthetic) is an open-source library and is used to generate synthetic data mimicking the real world data.')
    st.sidebar.header('What is synthetic data?')
    st.sidebar.markdown('Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals privacy.')
    st.sidebar.header('Why Synthetic Data?')
    st.sidebar.markdown('''Synthetic data can be used for many applications:
- Privacy
- Remove bias
- Balance datasets
- Augment datasets''')

    st.sidebar.header('Steps to follow')
    st.sidebar.markdown('''
- Upload any preprocessed tabular classification dataset.
- Choose the parameters in the adjacent window appropriately.
- Since this is a demo, please choose less number of epochs for quick completion of training.
- After choosing all parameters, Click the button under the parameters to start training.
- After the training is complete, you will see a graph comparing both real data set and synthetic dataset. Categorical columns are used to compare.
- You will also see a button to download your synthetic dataset. Click that button to download your dataset.''')

    st.sidebar.markdown('''[![Repo](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/ydataai/ydata-synthetic)''',unsafe_allow_html=True)

    @st.cache
    def train(df):
        """Fit the selected synthesizer on ``df`` and return sampled rows.

        The hyper-parameters (``batch_size``, ``learning_rate``, ``beta_1``,
        ``beta_2``, ``noise_dim``, ``layer_dim``, ``epochs``, ``log_step``,
        ``model``, ``num_cols``, ``cat_cols``, ``samples``) are read from the
        enclosing ``run`` scope, where the widgets below assign them.
        """
        gan_args = ModelParameters(batch_size=batch_size,
                                   # The widget value is expressed in units of 1e-3.
                                   lr=learning_rate * 0.001,
                                   betas=(beta_1, beta_2),
                                   noise_dim=noise_dim,
                                   layers_dim=layer_dim)
        train_args = TrainParameters(epochs=epochs,
                                     sample_interval=log_step)
        # `model` is the model name string chosen in the selectbox.
        # NOTE(review): confirm that `n_discriminator` and this positional
        # `fit` signature are accepted by both 'cgan' and 'wgangp'.
        synthesizer = RegularSynthesizer(modelname=model, model_parameters=gan_args, n_discriminator=3)
        # Train on the frame that was passed in, not on the closure variable.
        synthesizer.fit(df, train_args, num_cols, cat_cols)
        # Persist the trained model, then reload it before sampling.
        synthesizer.save('data_synth.pkl')
        synthesizer = RegularSynthesizer.load('data_synth.pkl')
        return synthesizer.sample(samples)

    @st.cache
    def convert_df(df):
        """Serialize ``df`` to UTF-8 encoded CSV bytes for the download button."""
        return df.to_csv().encode('utf-8')

    if data is not None:
        # Replace the uploaded file handle with the parsed frame, dropping
        # rows that contain missing values.
        data = pd.read_csv(data)
        data.dropna(inplace=True)
        st.header('Choose the parameters!!')
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            # Widget keys must be unique across the page; the model name is
            # kept as a string because RegularSynthesizer is selected by name.
            model = st.selectbox('Choose the GAN model', ['cgan', 'wgangp'], key='gan_model')
            num_cols = st.multiselect('Choose the numerical columns', data.columns, key='num_cols')
            # Only columns not already picked as numerical can be categorical.
            cat_cols = st.multiselect('Choose categorical columns', [x for x in data.columns if x not in num_cols], key='cat_cols')

        with col2:
            noise_dim = st.number_input('Select noise dimension', 0, 200, 128, 1)
            layer_dim = st.number_input('Select the layer dimension', 0, 200, 128, 1)
            batch_size = st.number_input('Select batch size', 0, 500, 500, 1)

        with col3:
            log_step = st.number_input('Select sample interval', 0, 200, 100, 1)
            epochs = st.number_input('Select the number of epochs', 0, 50, 2, 1)
            learning_rate = st.number_input('Select learning rate(x1e-3', 0.01, 0.1, 0.05, 0.01)

        with col4:
            beta_1 = st.slider('Select first beta co-efficient', 0.0, 1.0, 0.5)
            beta_2 = st.slider('Select second beta co-efficient', 0.0, 1.0, 0.9)
            samples = st.number_input('Select the number of synthetic samples to be generated', 0, 400000, step=1000)

    if st.button('Click here to start the training process'):
        if data is not None:
            st.write('Model Training is in progress. It may take a few minutes. Please wait for a while.')
            data_synn = train(data)
            st.success('Synthetic dataset with the given number of samples is generated!!')
            st.subheader('Real Data vs Synthetic Data')
            if cat_cols:
                # squeeze=False keeps `axes` two-dimensional even when a single
                # categorical column is selected, so axes[i, j] always works.
                f, axes = plt.subplots(len(cat_cols), 2, figsize=(20, 25), squeeze=False)
                f.suptitle('Real data vs Synthetic data')
                for i, j in enumerate(cat_cols):
                    sns.countplot(x=j, data=data, ax=axes[i, 0])
                    sns.countplot(x=j, data=data_synn, ax=axes[i, 1])
                st.pyplot(f)
            else:
                # A figure with zero rows cannot be created; say so instead of failing.
                st.info('No categorical columns were selected, so there is nothing to compare.')
            st.download_button(
                label="Download data as CSV",
                data=convert_df(data_synn),
                file_name='data_syn.csv',
                mime='text/csv')
            st.balloons()
        else:
            st.write('Upload a dataset to train!!')
|
108 |
+
# Script entry point: build the page only when this file is executed directly.
if __name__ == '__main__':
    run()
|