GRATITUD3 nupurkmr9 committed on
Commit f971083 · 0 Parent(s):

Duplicate from nupurkmr9/custom-diffusion

Co-authored-by: Nupur Kumari <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ method.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
+ training_data/
+ results/
+
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "custom-diffusion"]
+ 	path = custom-diffusion
+ 	url = https://github.com/adobe-research/custom-diffusion
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
+ repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.10.1
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.991
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 hysts
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Custom-Diffusion + SD Training
+ emoji: 🏢
+ colorFrom: red
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.12.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: nupurkmr9/custom-diffusion
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,318 @@
+ #!/usr/bin/env python
+ """Demo app for https://github.com/adobe-research/custom-diffusion.
+
+ The code in this repo is partly adapted from the following repository:
+ https://huggingface.co/spaces/hysts/LoRA-SD-training
+ """
+
+ from __future__ import annotations
+ import os
+ import pathlib
+ import sys
+
+ import gradio as gr
+ import torch
+
+ from inference import InferencePipeline
+ from trainer import Trainer
+ from uploader import upload
+
+ TITLE = '# Custom Diffusion + StableDiffusion Training UI'
+ DESCRIPTION = '''This is a demo for [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion).
+ After duplicating this Space, it is recommended to upgrade to a GPU in Settings to use it.
+ <a href="https://huggingface.co/spaces/nupurkmr9/custom-diffusion?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+ '''
+ DETAILDESCRIPTION = '''
+ Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
+ We fine-tune only a subset of model parameters, namely the key and value projection matrices in the cross-attention layers, and the modifier token used to represent the object.
+ This also reduces the extra storage needed for each additional concept to 75MB.
+ Our method further allows you to combine concepts. A demo for multiple concepts will be added soon.
+ <center>
+ <img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
+ </center>
+ '''
+
+ ORIGINAL_SPACE_ID = 'nupurkmr9/custom-diffusion'
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate it and use it with a paid private T4 GPU.
+
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
+ '''
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
+     SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
+
+ else:
+     SETTINGS = 'Settings'
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
+ <center>
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
+ "T4 small" is sufficient to run this demo.
+ </center>
+ '''
+
+ os.system('git clone https://github.com/adobe-research/custom-diffusion')
+ sys.path.append('custom-diffusion')
+
+ def show_warning(warning_text: str) -> gr.Blocks:
+     with gr.Blocks() as demo:
+         with gr.Box():
+             gr.Markdown(warning_text)
+     return demo
+
+
+ def update_output_files() -> dict:
+     paths = sorted(pathlib.Path('results').glob('*.pt'))
+     paths = [path.as_posix() for path in paths]  # type: ignore
+     return gr.update(value=paths or None)
+
+
+ def create_training_demo(trainer: Trainer,
+                          pipe: InferencePipeline) -> gr.Blocks:
+     with gr.Blocks() as demo:
+         base_model = gr.Dropdown(
+             choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
+             value='CompVis/stable-diffusion-v1-4',
+             label='Base Model',
+             visible=True)
+         resolution = gr.Dropdown(choices=['512', '768'],
+                                  value='512',
+                                  label='Resolution',
+                                  visible=True)
+
+         with gr.Row():
+             with gr.Box():
+                 gr.Markdown('Training Data')
+                 concept_images = gr.Files(label='Images for your concept')
+                 with gr.Row():
+                     class_prompt = gr.Textbox(label='Class Prompt',
+                                               max_lines=1, placeholder='Example: "cat"')
+                     with gr.Column():
+                         modifier_token = gr.Checkbox(label='modifier token',
+                                                      value=True)
+                         train_text_encoder = gr.Checkbox(label='Train Text Encoder',
+                                                          value=False)
+                 concept_prompt = gr.Textbox(label='Concept Prompt',
+                                             max_lines=1, placeholder='Example: "photo of a \<new1\> cat"')
+                 gr.Markdown('''
+                 - We use the "\<new1\>" modifier token in front of the concept, e.g., "\<new1\> cat". The modifier token is enabled by default.
+                 - If "Train Text Encoder" is checked, disable "modifier token" and use any unique text to describe the concept, e.g., "ktn cat".
+                 - For a new object concept, an example concept prompt is "photo of a \<new1\> cat" with "cat" as the class prompt.
+                 - For a style concept, use "painting in the style of \<new1\> art" as the concept prompt and "art" as the class prompt.
+                 - The class prompt should be the object category.
+                 ''')
+             with gr.Box():
+                 gr.Markdown('Training Parameters')
+                 num_training_steps = gr.Number(
+                     label='Number of Training Steps', value=1000, precision=0)
+                 learning_rate = gr.Number(label='Learning Rate', value=0.00001)
+                 batch_size = gr.Number(
+                     label='Batch Size', value=1, precision=0)
+                 with gr.Row():
+                     use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
+                     gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
+                 with gr.Accordion('Other Parameters', open=False):
+                     gradient_accumulation = gr.Number(
+                         label='Number of Gradient Accumulation Steps',
+                         value=1,
+                         precision=0)
+                     gen_images = gr.Checkbox(label='Generated images as regularization',
+                                              value=False)
+                 gr.Markdown('''
+                 - Training for 1000 steps takes ~10 minutes and ~21GB of GPU memory on a 3090 GPU.
+                 - The results in the paper were trained with batch size 4 (8 including class regularization samples).
+                 - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of a slower backward pass.
+                 - Note that your trained models will be deleted when a new training run is started. You can save a trained model via the "Upload" tab.
+                 ''')
+
+         run_button = gr.Button('Start Training')
+         with gr.Box():
+             with gr.Row():
+                 check_status_button = gr.Button('Check Training Status')
+                 with gr.Column():
+                     with gr.Box():
+                         gr.Markdown('Message')
+                         training_status = gr.Markdown()
+                     output_files = gr.Files(label='Trained Weight Files')
+
+         run_button.click(fn=pipe.clear,
+                          inputs=None,
+                          outputs=None)
+         run_button.click(fn=trainer.run,
+                          inputs=[
+                              base_model,
+                              resolution,
+                              concept_images,
+                              concept_prompt,
+                              class_prompt,
+                              num_training_steps,
+                              learning_rate,
+                              train_text_encoder,
+                              modifier_token,
+                              gradient_accumulation,
+                              batch_size,
+                              use_8bit_adam,
+                              gradient_checkpointing,
+                              gen_images
+                          ],
+                          outputs=[
+                              training_status,
+                              output_files,
+                          ],
+                          queue=False)
+         check_status_button.click(fn=trainer.check_if_running,
+                                   inputs=None,
+                                   outputs=training_status,
+                                   queue=False)
+         check_status_button.click(fn=update_output_files,
+                                   inputs=None,
+                                   outputs=output_files,
+                                   queue=False)
+     return demo
+
+
+ def find_weight_files() -> list[str]:
+     curr_dir = pathlib.Path(__file__).parent
+     paths = sorted(curr_dir.rglob('*.bin'))
+     paths = [path for path in paths if '.lfs' not in path.name]
+     return [path.relative_to(curr_dir).as_posix() for path in paths]
+
+
+ def reload_custom_diffusion_weight_list() -> dict:
+     return gr.update(choices=find_weight_files())
+
+
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 base_model = gr.Dropdown(
+                     choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
+                     value='CompVis/stable-diffusion-v1-4',
+                     label='Base Model',
+                     visible=True)
+                 resolution = gr.Dropdown(choices=[512, 768],
+                                          value=512,
+                                          label='Resolution',
+                                          visible=True)
+                 reload_button = gr.Button('Reload Weight List')
+                 weight_name = gr.Dropdown(choices=find_weight_files(),
+                                           value='custom-diffusion-models/cat.bin',
+                                           label='Custom Diffusion Weight File')
+                 prompt = gr.Textbox(
+                     label='Prompt',
+                     max_lines=1,
+                     placeholder='Example: "\<new1\> cat in outer space"')
+                 seed = gr.Slider(label='Seed',
+                                  minimum=0,
+                                  maximum=100000,
+                                  step=1,
+                                  value=42)
+                 with gr.Accordion('Other Parameters', open=False):
+                     num_steps = gr.Slider(label='Number of Steps',
+                                           minimum=0,
+                                           maximum=500,
+                                           step=1,
+                                           value=200)
+                     guidance_scale = gr.Slider(label='CFG Scale',
+                                                minimum=0,
+                                                maximum=50,
+                                                step=0.1,
+                                                value=6)
+                     eta = gr.Slider(label='DDIM eta',
+                                     minimum=0,
+                                     maximum=1.,
+                                     step=0.1,
+                                     value=1.)
+                     batch_size = gr.Slider(label='Batch Size',
+                                            minimum=0,
+                                            maximum=10.,
+                                            step=1,
+                                            value=2)
+
+                 run_button = gr.Button('Generate')
+
+                 gr.Markdown('''
+                 - Models with names starting with "custom-diffusion-models/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones under "results/" (e.g., "results/delta.bin") are your trained models.
+                 - After training, press the "Reload Weight List" button to load your trained model names.
+                 - Lower the default batch size and number of steps for faster sampling.
+                 ''')
+             with gr.Column():
+                 result = gr.Image(label='Result')
+
+         reload_button.click(fn=reload_custom_diffusion_weight_list,
+                             inputs=None,
+                             outputs=weight_name)
+         prompt.submit(fn=pipe.run,
+                       inputs=[
+                           base_model,
+                           weight_name,
+                           prompt,
+                           seed,
+                           num_steps,
+                           guidance_scale,
+                           eta,
+                           batch_size,
+                           resolution
+                       ],
+                       outputs=result,
+                       queue=False)
+         run_button.click(fn=pipe.run,
+                          inputs=[
+                              base_model,
+                              weight_name,
+                              prompt,
+                              seed,
+                              num_steps,
+                              guidance_scale,
+                              eta,
+                              batch_size,
+                              resolution
+                          ],
+                          outputs=result,
+                          queue=False)
+     return demo
+
+
+ def create_upload_demo() -> gr.Blocks:
+     with gr.Blocks() as demo:
+         model_name = gr.Textbox(label='Model Name')
+         hf_token = gr.Textbox(
+             label='Hugging Face Token (with write permission)')
+         upload_button = gr.Button('Upload')
+         with gr.Box():
+             gr.Markdown('Message')
+             result = gr.Markdown()
+         gr.Markdown('''
+         - You can upload your trained model to your private Model repo (e.g., https://huggingface.co/{your_username}/{model_name}).
+         - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
+         ''')
+
+         upload_button.click(fn=upload,
+                             inputs=[model_name, hf_token],
+                             outputs=result)
+
+     return demo
+
+
+ pipe = InferencePipeline()
+ trainer = Trainer()
+
+ with gr.Blocks(css='style.css') as demo:
+     if os.getenv('IS_SHARED_UI'):
+         show_warning(SHARED_UI_WARNING)
+     if not torch.cuda.is_available():
+         show_warning(CUDA_NOT_AVAILABLE_WARNING)
+
+     gr.Markdown(TITLE)
+     gr.Markdown(DESCRIPTION)
+     gr.Markdown(DETAILDESCRIPTION)
+
+     with gr.Tabs():
+         with gr.TabItem('Train'):
+             create_training_demo(trainer, pipe)
+         with gr.TabItem('Test'):
+             create_inference_demo(pipe)
+         with gr.TabItem('Upload'):
+             create_upload_demo()
+
+ demo.queue(default_enabled=False).launch(share=False)
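
As a side note to the `DETAILDESCRIPTION` above: the "subset of model parameters" being fine-tuned are the key/value projection matrices of the cross-attention layers. A minimal sketch of what selecting those parameters looks like with diffusers; this is an illustration of the idea under the diffusers `UNet2DConditionModel` naming convention (cross-attention blocks are `attn2`, with `to_k`/`to_v` projections), not the script the Space actually launches (that is `custom-diffusion/src/diffuser_training.py`):

```python
# Sketch: freeze the UNet, then unfreeze only the cross-attention
# key/value projections, which is the parameter subset Custom Diffusion trains.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet')

unet.requires_grad_(False)
trainable = []
for name, param in unet.named_parameters():
    # 'attn2' is the cross-attention block; 'to_k'/'to_v' are its projections.
    if 'attn2.to_k' in name or 'attn2.to_v' in name:
        param.requires_grad_(True)
        trainable.append(param)

optimizer = torch.optim.AdamW(trainable, lr=1e-5)
n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in unet.parameters())
print(f'training {n_trainable / 1e6:.1f}M of {n_total / 1e6:.1f}M parameters')
```

The small trainable fraction is why the per-concept delta files below are only ~75MB.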
custom-diffusion-models/barn.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e6ffb7953740286e005fb5ceffc3e985f93b3de97cd46202cf7d66d2171094b
+ size 76690626
custom-diffusion-models/cat.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08754e711b9ecaa36785dc64ad0c08317a93d106629c5f42cc5b9a406fe4aefc
+ size 76690626
custom-diffusion-models/chair.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e2edb6eaddf540ab9e1a0aa75f3e46ee77c9ee41e8d8e87127777d5dd3ba4b7
+ size 76690626
custom-diffusion-models/dog.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dbe8a8843279fa01f2eaa3e9b0b34267e5b1949456f81b5bc17fb2a0d23086fe
+ size 76690626
custom-diffusion-models/flower.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a37dcb14359c2baae984a218f62758eb52c842c7557e790063d4cd4daa120e5b
+ size 76690626
custom-diffusion-models/moongate.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:865ccdd950d4384af1b4cf45d955db4b26ec3736eb03bccab70fee4f51abb441
+ size 76687301
custom-diffusion-models/table.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3771413ef6b319fd1f6df0ec2490febe17524cc3afd67b63042bed85af8cb9c2
+ size 76690626
custom-diffusion-models/teddybear.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0656c6eafed89a146f7ab971913c9ee35b4f0a96a4e1aa8eb8ccc28326a8164
+ size 76690626
custom-diffusion-models/tortoise_plushy.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38597922778e96c54e3a57379a11a356e62b76be990097922def3f9b764db48d
+ size 76690626
custom-diffusion-models/wooden_pot.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:414cbd2bef3e7e65d860c4df17c8f8b8616f5dd8676634e0b228be8ed039f176
+ size 76690626
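
The `.bin` entries above are Git LFS pointer files rather than the weights themselves: three `key value` lines following https://git-lfs.github.com/spec/v1, with the actual ~77MB delta stored in LFS (the `*.bin` rule in `.gitattributes` routes them there). A small sketch of parsing one such pointer; the sample text is the `cat.bin` pointer shown above:

```python
# Sketch: parse a Git LFS pointer file into its key/value fields.
def parse_lfs_pointer(text: str) -> dict[str, str]:
    # Each line is 'key value'; split only on the first space.
    return dict(line.split(' ', 1) for line in text.strip().splitlines())

pointer = '''version https://git-lfs.github.com/spec/v1
oid sha256:08754e711b9ecaa36785dc64ad0c08317a93d106629c5f42cc5b9a406fe4aefc
size 76690626'''

info = parse_lfs_pointer(pointer)
print(info['oid'], int(info['size']))  # the real weight file lives in LFS storage
```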
inference.py ADDED
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+
+ import gc
+ import pathlib
+ import sys
+
+ import gradio as gr
+ import PIL.Image
+ import numpy as np
+
+ import torch
+ from diffusers import StableDiffusionPipeline
+ sys.path.insert(0, './custom-diffusion')
+
+
+ class InferencePipeline:
+     def __init__(self):
+         self.pipe = None
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.weight_path = None
+
+     def clear(self) -> None:
+         self.weight_path = None
+         del self.pipe
+         self.pipe = None
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     @staticmethod
+     def get_weight_path(name: str) -> pathlib.Path:
+         curr_dir = pathlib.Path(__file__).parent
+         return curr_dir / name
+
+     def load_pipe(self, model_id: str, filename: str) -> None:
+         weight_path = self.get_weight_path(filename)
+         if weight_path == self.weight_path:
+             return
+         self.weight_path = weight_path
+         weight = torch.load(self.weight_path, map_location=self.device)
+
+         if self.device.type == 'cpu':
+             pipe = StableDiffusionPipeline.from_pretrained(model_id)
+         else:
+             pipe = StableDiffusionPipeline.from_pretrained(
+                 model_id, torch_dtype=torch.float16)
+         pipe = pipe.to(self.device)
+
+         from src import diffuser_training
+         diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
+
+         self.pipe = pipe
+
+     def run(
+         self,
+         base_model: str,
+         weight_name: str,
+         prompt: str,
+         seed: int,
+         n_steps: int,
+         guidance_scale: float,
+         eta: float,
+         batch_size: int,
+         resolution: int,
+     ) -> PIL.Image.Image:
+         if not torch.cuda.is_available():
+             raise gr.Error('CUDA is not available.')
+
+         self.load_pipe(base_model, weight_name)
+
+         generator = torch.Generator(device=self.device).manual_seed(seed)
+         out = self.pipe([prompt] * batch_size,
+                         num_inference_steps=n_steps,
+                         guidance_scale=guidance_scale,
+                         height=resolution, width=resolution,
+                         eta=eta,
+                         generator=generator)  # type: ignore
+         out = out.images
+         out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
+         return out
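
A usage sketch for `InferencePipeline` outside of the Gradio UI; this assumes a CUDA GPU (`run` raises otherwise), the cloned `custom-diffusion` submodule on the path, and the pretrained `cat.bin` delta shipped with this Space:

```python
# Sketch: run Custom Diffusion inference directly, without the web UI.
from inference import InferencePipeline

pipe = InferencePipeline()
image = pipe.run(
    base_model='CompVis/stable-diffusion-v1-4',
    weight_name='custom-diffusion-models/cat.bin',
    prompt='<new1> cat in outer space',  # <new1> is the trained modifier token
    seed=42,
    n_steps=200,
    guidance_scale=6.0,
    eta=1.0,
    batch_size=2,
    resolution=512,
)
image.save('out.jpg')  # the batch is returned as one horizontally stacked image
```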
method.jpg ADDED

Git LFS Details

  • SHA256: 12a48301b17741a6c1bea4208b7dcb5613b2cfe974f9d6c8e1de331d6dd8a0a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ accelerate==0.15.0
+ bitsandbytes==0.35.4
+ diffusers==0.10.2
+ ftfy==6.1.1
+ Pillow==9.3.0
+ torch==1.13.0
+ torchvision==0.14.0
+ transformers==4.25.1
+ triton==2.0.0.dev20220701
+ xformers==0.0.13
+ clip_retrieval
style.css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+   text-align: center;
+ }
trainer.py ADDED
@@ -0,0 +1,134 @@
+ from __future__ import annotations
+
+ import os
+ import pathlib
+ import shlex
+ import shutil
+ import subprocess
+
+ import gradio as gr
+ import PIL.Image
+ import torch
+
+ os.environ['PYTHONPATH'] = f'custom-diffusion:{os.getenv("PYTHONPATH", "")}'
+
+
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
+     w, h = image.size
+     if w == h:
+         return image
+     elif w > h:
+         new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
+         new_image.paste(image, (0, (w - h) // 2))
+         return new_image
+     else:
+         new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
+         new_image.paste(image, ((h - w) // 2, 0))
+         return new_image
+
+
+ class Trainer:
+     def __init__(self):
+         self.is_running = False
+         self.is_running_message = 'Another training is in progress.'
+
+         self.output_dir = pathlib.Path('results')
+         self.instance_data_dir = self.output_dir / 'training_data'
+         self.class_data_dir = self.output_dir / 'regularization_data'
+
+     def check_if_running(self) -> dict:
+         if self.is_running:
+             return gr.update(value=self.is_running_message)
+         else:
+             return gr.update(value='No training is running.')
+
+     def cleanup_dirs(self) -> None:
+         shutil.rmtree(self.output_dir, ignore_errors=True)
+
+     def prepare_dataset(self, concept_images: list, resolution: int) -> None:
+         self.instance_data_dir.mkdir(parents=True)
+         for i, temp_path in enumerate(concept_images):
+             image = PIL.Image.open(temp_path.name)
+             image = pad_image(image)
+             image = image.resize((resolution, resolution))
+             image = image.convert('RGB')
+             out_path = self.instance_data_dir / f'{i:03d}.jpg'
+             image.save(out_path, format='JPEG', quality=100)
+
+     def run(
+         self,
+         base_model: str,
+         resolution_s: str,
+         concept_images: list | None,
+         concept_prompt: str,
+         class_prompt: str,
+         n_steps: int,
+         learning_rate: float,
+         train_text_encoder: bool,
+         modifier_token: bool,
+         gradient_accumulation: int,
+         batch_size: int,
+         use_8bit_adam: bool,
+         gradient_checkpointing: bool,
+         gen_images: bool,
+     ) -> tuple[dict, list[pathlib.Path]]:
+         if not torch.cuda.is_available():
+             raise gr.Error('CUDA is not available.')
+
+         if self.is_running:
+             return gr.update(value=self.is_running_message), []
+
+         if concept_images is None:
+             raise gr.Error('You need to upload images.')
+         if not concept_prompt:
+             raise gr.Error('The concept prompt is missing.')
+
+         resolution = int(resolution_s)
+
+         self.cleanup_dirs()
+         self.prepare_dataset(concept_images, resolution)
+
+         command = f'''
+         accelerate launch custom-diffusion/src/diffuser_training.py \
+           --pretrained_model_name_or_path={base_model} \
+           --instance_data_dir={self.instance_data_dir} \
+           --output_dir={self.output_dir} \
+           --instance_prompt="{concept_prompt}" \
+           --class_data_dir={self.class_data_dir} \
+           --with_prior_preservation --prior_loss_weight=1.0 \
+           --class_prompt="{class_prompt}" \
+           --resolution={resolution} \
+           --train_batch_size={batch_size} \
+           --gradient_accumulation_steps={gradient_accumulation} \
+           --learning_rate={learning_rate} \
+           --lr_scheduler="constant" \
+           --lr_warmup_steps=0 \
+           --max_train_steps={n_steps} \
+           --num_class_images=200 \
+           --scale_lr
+         '''
+         if modifier_token:
+             command += ' --modifier_token "<new1>"'
+         if not gen_images:
+             command += ' --real_prior'
+         if use_8bit_adam:
+             command += ' --use_8bit_adam'
+         if train_text_encoder:
+             command += ' --train_text_encoder'
+         if gradient_checkpointing:
+             command += ' --gradient_checkpointing'
+
+         with open(self.output_dir / 'train.sh', 'w') as f:
+             command_s = ' '.join(command.split())
+             f.write(command_s)
+
+         self.is_running = True
+         res = subprocess.run(shlex.split(command))
+         self.is_running = False
+
+         if res.returncode == 0:
+             result_message = 'Training Completed!'
+         else:
+             result_message = 'Training Failed!'
+         weight_paths = sorted(self.output_dir.glob('*.bin'))
+         return gr.update(value=result_message), weight_paths
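
`pad_image` above keeps the aspect ratio of uploaded training images by letterboxing them onto a square black canvas (side length equal to the longer edge) before the square resize, so nothing gets distorted. A quick sanity check:

```python
# Sketch: pad_image centers a non-square image on a square black canvas.
import PIL.Image
from trainer import pad_image

img = PIL.Image.new('RGB', (640, 480), (255, 0, 0))  # a landscape red image
square = pad_image(img)
print(square.size)  # (640, 640); the 480px-tall image is pasted 80px from the top
```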
uploader.py ADDED
@@ -0,0 +1,20 @@
+ import gradio as gr
+ from huggingface_hub import HfApi
+
+
+ def upload(model_name: str, hf_token: str) -> dict:
+     api = HfApi(token=hf_token)
+     user_name = api.whoami()['name']
+     model_id = f'{user_name}/{model_name}'
+     try:
+         api.create_repo(model_id, repo_type='model', private=True)
+         api.upload_folder(repo_id=model_id,
+                           folder_path='results',
+                           path_in_repo='results',
+                           repo_type='model')
+         url = f'https://huggingface.co/{model_id}'
+         message = f'Your model was successfully uploaded to [{url}]({url}).'
+     except Exception as e:
+         message = str(e)
+
+     return gr.update(value=message, visible=True)
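
Once uploaded, the delta can be pulled back down with `huggingface_hub` for local inference. A sketch, assuming a hypothetical repo id, the `results/delta.bin` layout that `uploader.py` produces, and a recent `huggingface_hub` version that accepts the `token` argument:

```python
# Sketch: download a previously uploaded Custom Diffusion delta from the Hub.
from huggingface_hub import hf_hub_download

delta_path = hf_hub_download(
    repo_id='your-username/my-custom-cat',  # hypothetical private repo id
    filename='results/delta.bin',           # folder layout used by uploader.py
    token='hf_...',                         # a read token suffices for download
)
print(delta_path)  # local cache path, usable as a weight file for inference
```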