diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..e2e68308705c59d0e047f13c1116ed223f9b7fec --- /dev/null +++ b/.gitattributes @@ -0,0 +1,33 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..75c45fdcb59a4c94649d0e5fd317f22969ecaf3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +pretrained_models diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4b944d4a3d42d900933fa1cad13c099ad8c1a07 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1aae787b021c47ad6914bdc857cf7016244a496 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +--- +title: Diffusion Transformers (DiT) +emoji: 🚀 +colorFrom: yellow +colorTo: green +sdk: gradio +sdk_version: 3.6 +app_file: app.py +pinned: false +license: cc-by-nc-4.0 +--- + +The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ffac77c274d3eed27e0728636331397710993921 --- /dev/null +++ b/app.py @@ -0,0 +1,136 @@ +import torch +from torchvision.utils import make_grid +import math +from PIL import Image +from diffusion import create_diffusion +from diffusers.models import AutoencoderKL +import gradio as gr +from imagenet_class_data import IMAGENET_1K_CLASSES +from download import find_model +from models import DiT_XL_2 + + +def load_model(image_size=256): + assert image_size in [256, 512] + latent_size = image_size // 8 + model = DiT_XL_2(input_size=latent_size).to(device) + state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt") + model.load_state_dict(state_dict) + model.eval() + return model + + +torch.set_grad_enabled(False) +device = "cuda" if torch.cuda.is_available() else "cpu" +model = load_model(image_size=256) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) +current_image_size = 256 +current_vae_model = "stabilityai/sd-vae-ft-mse" + + +def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed): + image_size = int(image_size.split("x")[0]) + global current_image_size + if image_size != current_image_size: + global model + del model + # if device == "cuda": + # torch.cuda.empty_cache() + model = load_model(image_size=image_size) + current_image_size = image_size + + global current_vae_model + if vae_model != current_vae_model: + global vae + if device == "cuda": + vae.to("cpu") + del vae + vae = AutoencoderKL.from_pretrained(vae_model).to(device) + + # Seed PyTorch: + torch.manual_seed(seed) + + # Setup diffusion + diffusion = create_diffusion(str(num_sampling_steps)) + + # Create sampling noise: + latent_size = image_size // 8 + z = torch.randn(n, 4, latent_size, latent_size, device=device) + y = torch.tensor([class_label] * n, device=device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=cfg_scale) + + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae.decode(samples / 0.18215).sample + + # Convert to PIL.Image format: + samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy() + samples = [Image.fromarray(sample) for sample in samples] + return samples + + +description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with +transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.''' + +project_links = ''' +

+Project Page · +Colab · +Paper · +GitHub

''' + +examples = [ + ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000], + ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1], + ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1], + ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7], + ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0], + ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200, + 4, 1], + ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3], + ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2], + +] + +with gr.Blocks() as demo: + gr.Markdown("

Scalable Diffusion Models with Transformers (DiT)

") + gr.Markdown(project_links) + gr.Markdown(description) + + with gr.Tabs(): + with gr.TabItem('Generate'): + with gr.Row(): + with gr.Column(): + with gr.Row(): + image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution') + vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"], + default="stabilityai/sd-vae-ft-mse", label='VAE Decoder') + with gr.Row(): + i1k_class = gr.inputs.Dropdown( + list(IMAGENET_1K_CLASSES.values()), + default='golden retriever', + type="index", label='ImageNet-1K Class' + ) + cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale') + steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps') + n = gr.inputs.Slider(minimum=1, maximum=16, step=1, default=1, label='Number of Samples') + seed = gr.inputs.Number(default=0, label='Seed') + button = gr.Button("Generate", variant="primary") + with gr.Column(): + output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto") + button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], outputs=[output]) + with gr.Row(): + ex = gr.Examples(examples=examples, fn=generate, + inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], + outputs=[output], + cache_examples=True) + + demo.launch() diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c536a98da92c4d051458803737661e5ecf974c2 --- /dev/null +++ b/diffusion/__init__.py @@ -0,0 +1,46 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/diffusion/__pycache__/__init__.cpython-39.pyc b/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b5c8ed4caa8afc6c000283303604810c8f21509 Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/diffusion/__pycache__/diffusion_utils.cpython-39.pyc b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..568e49b4392362148d812797b29ff8625a22d9fd Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc differ diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5577bd197c50b380f1d4ddd3443e9ca6d87a924c Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/diffusion/__pycache__/respace.cpython-39.pyc b/diffusion/__pycache__/respace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cfa72a7618953768035871077c702652937b3c9 Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-39.pyc differ diff --git a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ccbcefeca4348e2e627d22723b50492c894e66ba --- /dev/null +++ b/diffusion/gaussian_diffusion.py @@ -0,0 +1,873 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/diffusion/respace.py b/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2cc0435d1ace54466585db9043b284973d454e --- /dev/null +++ b/diffusion/respace.py @@ -0,0 +1,129 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/diffusion/timestep_sampler.py b/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..cfdcf3b3e7481a84b7ff30a5f579195f271d21ef --- /dev/null +++ b/download.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for downloading pre-trained DiT models +""" +from torchvision.datasets.utils import download_url +import torch +import os + + +pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} + + +def find_model(model_name): + """ + Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints + return download_model(model_name) + else: # Load a custom DiT checkpoint: + assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}' + return torch.load(model_name, map_location=lambda storage, loc: storage) + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. + """ + assert model_name in pretrained_models + local_path = f'pretrained_models/{model_name}' + if not os.path.isfile(local_path): + os.makedirs('pretrained_models', exist_ok=True) + web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' + download_url(web_path, 'pretrained_models') + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +if __name__ == "__main__": + # Download all DiT checkpoints + for model in pretrained_models: + download_model(model) + print('Done.') diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..42a01d277fc24c6b3ff121343348579bd6e75479 --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png new file mode 100644 index 0000000000000000000000000000000000000000..929ccf2b10b96d8651c23ec766d3866a99aeff46 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png differ diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png new file mode 100644 index 0000000000000000000000000000000000000000..d02d4de55875dc68e2af562798efc3ce9910efb4 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png differ diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png new file mode 100644 index 0000000000000000000000000000000000000000..1bccab369fb801a6b20e9e18cf82cd59eced3bb9 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png differ diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png new file mode 100644 index 0000000000000000000000000000000000000000..dc29f02a7d6bd3a4d71f7c2d17121aff8692b577 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png differ diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..1893450dbb2cd025d2b1c10aaf47bd642d142151 --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png new file mode 100644 index 0000000000000000000000000000000000000000..fe8e9dded1f9107d6c0ac527b18fa555a7dbf96d Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png differ diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png new file mode 100644 index 0000000000000000000000000000000000000000..ca65fb2db579c355ff026bc657f7ea0ec381b779 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png differ diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png new file mode 100644 index 0000000000000000000000000000000000000000..d6dac6bf8ff5203dd7229adcfbdfdcbb7f0a053b Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png differ diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png new file mode 100644 index 0000000000000000000000000000000000000000..7324c35c02e1a2d66951827533b8db49feb176b6 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png differ diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..45a65e26742bc3cb66cc8eb085e895606ccb90c1 --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png new file mode 100644 index 0000000000000000000000000000000000000000..5a30e3d619f7d3f8c9ae93a644fb2329819e8a3b Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png differ diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png new file mode 100644 index 0000000000000000000000000000000000000000..cc06ee1f0b48593fd8170164a51ce11b9a0ef212 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png differ diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png new file mode 100644 index 0000000000000000000000000000000000000000..e8ff9b2de5150afebcff7c48d1de1f82c77255bf Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png differ diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png new file mode 100644 index 0000000000000000000000000000000000000000..596665ee0e55e9822e5fc3d34c369aa8552d93f7 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png differ diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..cc2544c2fe11d2f7368858348cb99772daf2eab3 --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png new file mode 100644 index 0000000000000000000000000000000000000000..c7a7cdf3f95315f8cd87c4abd8d5b8c03e30273b Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png differ diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png new file mode 100644 index 0000000000000000000000000000000000000000..0f0458f72d60b5622e1ebb73ac7b448a148bc531 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png differ diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png new file mode 100644 index 0000000000000000000000000000000000000000..f2b7c1d78f55e9a31ce1cbcde57f3cbcc24df0f2 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png differ diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png new file mode 100644 index 0000000000000000000000000000000000000000..d8813c0322c0a6744454eddc0cd2689f5a32d2d0 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png differ diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..cf5f8231d7fe826c59e62505a484fc49222d5f9b --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png new file mode 100644 index 0000000000000000000000000000000000000000..04fafa07b652ddc3b21ed6c535a20fe983483ffb Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png differ diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png new file mode 100644 index 0000000000000000000000000000000000000000..4a6a3d26ccfe7c65e69e43e0dc4cf55d48da09c3 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png differ diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png new file mode 100644 index 0000000000000000000000000000000000000000..ace17cf260b0726e968f6f44d3600ba17d57c79e Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png differ diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png new file mode 100644 index 0000000000000000000000000000000000000000..fd043f10f30cb37029c0947e2c87b7cce58f5920 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png differ diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..3c4754345c0b8531ad80977366f5a907fd94545c --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png new file mode 100644 index 0000000000000000000000000000000000000000..f3ad3af3fc31c9570d2d92829ef2f9d6d2b46349 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png differ diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png new file mode 100644 index 0000000000000000000000000000000000000000..84ca66ecb7c6ae9f23e9c790037e37653a86bde0 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png differ diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png new file mode 100644 index 0000000000000000000000000000000000000000..a246622dd0beb9632ae76fdf033e9fcaa3c9ad8c Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png differ diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png new file mode 100644 index 0000000000000000000000000000000000000000..5e1ab2903964c4f7b964a01ff8a28f1450b3117f Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png differ diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..225a3463fd33f2eb8726076b6347e21039a801de --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png new file mode 100644 index 0000000000000000000000000000000000000000..3e23bab873c45546d249fecb845b7f5c84661742 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png differ diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png new file mode 100644 index 0000000000000000000000000000000000000000..4413095f76aa125986448f93aef1971a61fa8ae0 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png differ diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png new file mode 100644 index 0000000000000000000000000000000000000000..554e274a289b0df3f0a19d57b1d37e6dfc86cd6f Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png differ diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png new file mode 100644 index 0000000000000000000000000000000000000000..9c1bac4e9627743b5c9b2f6b620a18b244d84398 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png differ diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..51926d3808e002ed81f38a16100fd05d59dc5658 --- /dev/null +++ b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json @@ -0,0 +1 @@ +{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png": null} \ No newline at end of file diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png new file mode 100644 index 0000000000000000000000000000000000000000..70dc335566745a187df069aa849668217a32182c Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png differ diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png new file mode 100644 index 0000000000000000000000000000000000000000..be9c1c9e4d29a996ccd9e1220463229385c1e8bd Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png differ diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png new file mode 100644 index 0000000000000000000000000000000000000000..171e084f4fd9ccc91fd90bb64f861e6ef56c17a6 Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png differ diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png new file mode 100644 index 0000000000000000000000000000000000000000..665b1eef870aacdde6e3d7b55ec8e3da2d377bde Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png differ diff --git a/gradio_cached_examples/24/log.csv b/gradio_cached_examples/24/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..e42b6baeea589b126c02d031fa10b2fbb5c600ef --- /dev/null +++ b/gradio_cached_examples/24/log.csv @@ -0,0 +1,9 @@ +Generated Images,flag,username,timestamp +/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec,,,2022-12-21 22:19:52.975486 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd,,,2022-12-21 22:21:49.400813 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63,,,2022-12-21 22:23:45.918671 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e,,,2022-12-21 22:25:42.360174 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998,,,2022-12-21 22:27:38.970601 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61,,,2022-12-21 22:28:17.384826 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8,,,2022-12-21 22:28:44.680872 +/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf,,,2022-12-21 22:29:12.090413 diff --git a/imagenet_class_data.py b/imagenet_class_data.py new file mode 100644 index 0000000000000000000000000000000000000000..0e52b050f2fffaed156eae333a245cf660e8da0f --- /dev/null +++ b/imagenet_class_data.py @@ -0,0 +1,1003 @@ +# https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt + +IMAGENET_1K_CLASSES = \ +{0: 'tench, Tinca tinca', + 1: 'goldfish, Carassius auratus', + 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', + 3: 'tiger shark, Galeocerdo cuvieri', + 4: 'hammerhead, hammerhead shark', + 5: 'electric ray, crampfish, numbfish, torpedo', + 6: 'stingray', + 7: 'cock', + 8: 'hen', + 9: 'ostrich, Struthio camelus', + 10: 'brambling, Fringilla montifringilla', + 11: 'goldfinch, Carduelis carduelis', + 12: 'house finch, linnet, Carpodacus mexicanus', + 13: 'junco, snowbird', + 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 15: 'robin, American robin, Turdus migratorius', + 16: 'bulbul', + 17: 'jay', + 18: 'magpie', + 19: 'chickadee', + 20: 'water ouzel, dipper', + 21: 'kite', + 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', + 23: 'vulture', + 24: 'great grey owl, great gray owl, Strix nebulosa', + 25: 'European fire salamander, Salamandra salamandra', + 26: 'common newt, Triturus vulgaris', + 27: 'eft', + 28: 'spotted salamander, Ambystoma maculatum', + 29: 'axolotl, mud puppy, Ambystoma mexicanum', + 30: 'bullfrog, Rana catesbeiana', + 31: 'tree frog, tree-frog', + 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 33: 'loggerhead, loggerhead turtle, Caretta caretta', + 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', + 35: 'mud turtle', + 36: 'terrapin', + 37: 'box turtle, box tortoise', + 38: 'banded gecko', + 39: 'common iguana, iguana, Iguana iguana', + 40: 'American chameleon, anole, Anolis carolinensis', + 41: 'whiptail, whiptail lizard', + 42: 'agama', + 43: 'frilled lizard, Chlamydosaurus kingi', + 44: 'alligator lizard', + 45: 'Gila monster, Heloderma suspectum', + 46: 'green lizard, Lacerta viridis', + 47: 'African chameleon, Chamaeleo chamaeleon', + 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', + 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', + 50: 'American alligator, Alligator mississipiensis', + 51: 'triceratops', + 52: 'thunder snake, worm snake, Carphophis amoenus', + 53: 'ringneck snake, ring-necked snake, ring snake', + 54: 'hognose snake, puff adder, sand viper', + 55: 'green snake, grass snake', + 56: 'king snake, kingsnake', + 57: 'garter snake, grass snake', + 58: 'water snake', + 59: 'vine snake', + 60: 'night snake, Hypsiglena torquata', + 61: 'boa constrictor, Constrictor constrictor', + 62: 'rock python, rock snake, Python sebae', + 63: 'Indian cobra, Naja naja', + 64: 'green mamba', + 65: 'sea snake', + 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', + 69: 'trilobite', + 70: 'harvestman, daddy longlegs, Phalangium opilio', + 71: 'scorpion', + 72: 'black and gold garden spider, Argiope aurantia', + 73: 'barn spider, Araneus cavaticus', + 74: 'garden spider, Aranea diademata', + 75: 'black widow, Latrodectus mactans', + 76: 'tarantula', + 77: 'wolf spider, hunting spider', + 78: 'tick', + 79: 'centipede', + 80: 'black grouse', + 81: 'ptarmigan', + 82: 'ruffed grouse, partridge, Bonasa umbellus', + 83: 'prairie chicken, prairie grouse, prairie fowl', + 84: 'peacock', + 85: 'quail', + 86: 'partridge', + 87: 'African grey, African gray, Psittacus erithacus', + 88: 'macaw', + 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 90: 'lorikeet', + 91: 'coucal', + 92: 'bee eater', + 93: 'hornbill', + 94: 'hummingbird', + 95: 'jacamar', + 96: 'toucan', + 97: 'drake', + 98: 'red-breasted merganser, Mergus serrator', + 99: 'goose', + 100: 'black swan, Cygnus atratus', + 101: 'tusker', + 102: 'echidna, spiny anteater, anteater', + 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', + 104: 'wallaby, brush kangaroo', + 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', + 106: 'wombat', + 107: 'jellyfish', + 108: 'sea anemone, anemone', + 109: 'brain coral', + 110: 'flatworm, platyhelminth', + 111: 'nematode, nematode worm, roundworm', + 112: 'conch', + 113: 'snail', + 114: 'slug', + 115: 'sea slug, nudibranch', + 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 117: 'chambered nautilus, pearly nautilus, nautilus', + 118: 'Dungeness crab, Cancer magister', + 119: 'rock crab, Cancer irroratus', + 120: 'fiddler crab', + 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', + 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', + 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', + 124: 'crayfish, crawfish, crawdad, crawdaddy', + 125: 'hermit crab', + 126: 'isopod', + 127: 'white stork, Ciconia ciconia', + 128: 'black stork, Ciconia nigra', + 129: 'spoonbill', + 130: 'flamingo', + 131: 'little blue heron, Egretta caerulea', + 132: 'American egret, great white heron, Egretta albus', + 133: 'bittern', + 134: 'crane', + 135: 'limpkin, Aramus pictus', + 136: 'European gallinule, Porphyrio porphyrio', + 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 138: 'bustard', + 139: 'ruddy turnstone, Arenaria interpres', + 140: 'red-backed sandpiper, dunlin, Erolia alpina', + 141: 'redshank, Tringa totanus', + 142: 'dowitcher', + 143: 'oystercatcher, oyster catcher', + 144: 'pelican', + 145: 'king penguin, Aptenodytes patagonica', + 146: 'albatross, mollymawk', + 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', + 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 149: 'dugong, Dugong dugon', + 150: 'sea lion', + 151: 'Chihuahua', + 152: 'Japanese spaniel', + 153: 'Maltese dog, Maltese terrier, Maltese', + 154: 'Pekinese, Pekingese, Peke', + 155: 'Shih-Tzu', + 156: 'Blenheim spaniel', + 157: 'papillon', + 158: 'toy terrier', + 159: 'Rhodesian ridgeback', + 160: 'Afghan hound, Afghan', + 161: 'basset, basset hound', + 162: 'beagle', + 163: 'bloodhound, sleuthhound', + 164: 'bluetick', + 165: 'black-and-tan coonhound', + 166: 'Walker hound, Walker foxhound', + 167: 'English foxhound', + 168: 'redbone', + 169: 'borzoi, Russian wolfhound', + 170: 'Irish wolfhound', + 171: 'Italian greyhound', + 172: 'whippet', + 173: 'Ibizan hound, Ibizan Podenco', + 174: 'Norwegian elkhound, elkhound', + 175: 'otterhound, otter hound', + 176: 'Saluki, gazelle hound', + 177: 'Scottish deerhound, deerhound', + 178: 'Weimaraner', + 179: 'Staffordshire bullterrier, Staffordshire bull terrier', + 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', + 181: 'Bedlington terrier', + 182: 'Border terrier', + 183: 'Kerry blue terrier', + 184: 'Irish terrier', + 185: 'Norfolk terrier', + 186: 'Norwich terrier', + 187: 'Yorkshire terrier', + 188: 'wire-haired fox terrier', + 189: 'Lakeland terrier', + 190: 'Sealyham terrier, Sealyham', + 191: 'Airedale, Airedale terrier', + 192: 'cairn, cairn terrier', + 193: 'Australian terrier', + 194: 'Dandie Dinmont, Dandie Dinmont terrier', + 195: 'Boston bull, Boston terrier', + 196: 'miniature schnauzer', + 197: 'giant schnauzer', + 198: 'standard schnauzer', + 199: 'Scotch terrier, Scottish terrier, Scottie', + 200: 'Tibetan terrier, chrysanthemum dog', + 201: 'silky terrier, Sydney silky', + 202: 'soft-coated wheaten terrier', + 203: 'West Highland white terrier', + 204: 'Lhasa, Lhasa apso', + 205: 'flat-coated retriever', + 206: 'curly-coated retriever', + 207: 'golden retriever', + 208: 'Labrador retriever', + 209: 'Chesapeake Bay retriever', + 210: 'German short-haired pointer', + 211: 'vizsla, Hungarian pointer', + 212: 'English setter', + 213: 'Irish setter, red setter', + 214: 'Gordon setter', + 215: 'Brittany spaniel', + 216: 'clumber, clumber spaniel', + 217: 'English springer, English springer spaniel', + 218: 'Welsh springer spaniel', + 219: 'cocker spaniel, English cocker spaniel, cocker', + 220: 'Sussex spaniel', + 221: 'Irish water spaniel', + 222: 'kuvasz', + 223: 'schipperke', + 224: 'groenendael', + 225: 'malinois', + 226: 'briard', + 227: 'kelpie', + 228: 'komondor', + 229: 'Old English sheepdog, bobtail', + 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', + 231: 'collie', + 232: 'Border collie', + 233: 'Bouvier des Flandres, Bouviers des Flandres', + 234: 'Rottweiler', + 235: 'German shepherd, German shepherd dog, German police dog, alsatian', + 236: 'Doberman, Doberman pinscher', + 237: 'miniature pinscher', + 238: 'Greater Swiss Mountain dog', + 239: 'Bernese mountain dog', + 240: 'Appenzeller', + 241: 'EntleBucher', + 242: 'boxer', + 243: 'bull mastiff', + 244: 'Tibetan mastiff', + 245: 'French bulldog', + 246: 'Great Dane', + 247: 'Saint Bernard, St Bernard', + 248: 'Eskimo dog, husky', + 249: 'malamute, malemute, Alaskan malamute', + 250: 'Siberian husky', + 251: 'dalmatian, coach dog, carriage dog', + 252: 'affenpinscher, monkey pinscher, monkey dog', + 253: 'basenji', + 254: 'pug, pug-dog', + 255: 'Leonberg', + 256: 'Newfoundland, Newfoundland dog', + 257: 'Great Pyrenees', + 258: 'Samoyed, Samoyede', + 259: 'Pomeranian', + 260: 'chow, chow chow', + 261: 'keeshond', + 262: 'Brabancon griffon', + 263: 'Pembroke, Pembroke Welsh corgi', + 264: 'Cardigan, Cardigan Welsh corgi', + 265: 'toy poodle', + 266: 'miniature poodle', + 267: 'standard poodle', + 268: 'Mexican hairless', + 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', + 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', + 271: 'red wolf, maned wolf, Canis rufus, Canis niger', + 272: 'coyote, prairie wolf, brush wolf, Canis latrans', + 273: 'dingo, warrigal, warragal, Canis dingo', + 274: 'dhole, Cuon alpinus', + 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 276: 'hyena, hyaena', + 277: 'red fox, Vulpes vulpes', + 278: 'kit fox, Vulpes macrotis', + 279: 'Arctic fox, white fox, Alopex lagopus', + 280: 'grey fox, gray fox, Urocyon cinereoargenteus', + 281: 'tabby, tabby cat', + 282: 'tiger cat', + 283: 'Persian cat', + 284: 'Siamese cat, Siamese', + 285: 'Egyptian cat', + 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', + 287: 'lynx, catamount', + 288: 'leopard, Panthera pardus', + 289: 'snow leopard, ounce, Panthera uncia', + 290: 'jaguar, panther, Panthera onca, Felis onca', + 291: 'lion, king of beasts, Panthera leo', + 292: 'tiger, Panthera tigris', + 293: 'cheetah, chetah, Acinonyx jubatus', + 294: 'brown bear, bruin, Ursus arctos', + 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', + 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 297: 'sloth bear, Melursus ursinus, Ursus ursinus', + 298: 'mongoose', + 299: 'meerkat, mierkat', + 300: 'tiger beetle', + 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 302: 'ground beetle, carabid beetle', + 303: 'long-horned beetle, longicorn, longicorn beetle', + 304: 'leaf beetle, chrysomelid', + 305: 'dung beetle', + 306: 'rhinoceros beetle', + 307: 'weevil', + 308: 'fly', + 309: 'bee', + 310: 'ant, emmet, pismire', + 311: 'grasshopper, hopper', + 312: 'cricket', + 313: 'walking stick, walkingstick, stick insect', + 314: 'cockroach, roach', + 315: 'mantis, mantid', + 316: 'cicada, cicala', + 317: 'leafhopper', + 318: 'lacewing, lacewing fly', + 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + 320: 'damselfly', + 321: 'admiral', + 322: 'ringlet, ringlet butterfly', + 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 324: 'cabbage butterfly', + 325: 'sulphur butterfly, sulfur butterfly', + 326: 'lycaenid, lycaenid butterfly', + 327: 'starfish, sea star', + 328: 'sea urchin', + 329: 'sea cucumber, holothurian', + 330: 'wood rabbit, cottontail, cottontail rabbit', + 331: 'hare', + 332: 'Angora, Angora rabbit', + 333: 'hamster', + 334: 'porcupine, hedgehog', + 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', + 336: 'marmot', + 337: 'beaver', + 338: 'guinea pig, Cavia cobaya', + 339: 'sorrel', + 340: 'zebra', + 341: 'hog, pig, grunter, squealer, Sus scrofa', + 342: 'wild boar, boar, Sus scrofa', + 343: 'warthog', + 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 345: 'ox', + 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 347: 'bison', + 348: 'ram, tup', + 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', + 350: 'ibex, Capra ibex', + 351: 'hartebeest', + 352: 'impala, Aepyceros melampus', + 353: 'gazelle', + 354: 'Arabian camel, dromedary, Camelus dromedarius', + 355: 'llama', + 356: 'weasel', + 357: 'mink', + 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', + 359: 'black-footed ferret, ferret, Mustela nigripes', + 360: 'otter', + 361: 'skunk, polecat, wood pussy', + 362: 'badger', + 363: 'armadillo', + 364: 'three-toed sloth, ai, Bradypus tridactylus', + 365: 'orangutan, orang, orangutang, Pongo pygmaeus', + 366: 'gorilla, Gorilla gorilla', + 367: 'chimpanzee, chimp, Pan troglodytes', + 368: 'gibbon, Hylobates lar', + 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 370: 'guenon, guenon monkey', + 371: 'patas, hussar monkey, Erythrocebus patas', + 372: 'baboon', + 373: 'macaque', + 374: 'langur', + 375: 'colobus, colobus monkey', + 376: 'proboscis monkey, Nasalis larvatus', + 377: 'marmoset', + 378: 'capuchin, ringtail, Cebus capucinus', + 379: 'howler monkey, howler', + 380: 'titi, titi monkey', + 381: 'spider monkey, Ateles geoffroyi', + 382: 'squirrel monkey, Saimiri sciureus', + 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', + 384: 'indri, indris, Indri indri, Indri brevicaudatus', + 385: 'Indian elephant, Elephas maximus', + 386: 'African elephant, Loxodonta africana', + 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 389: 'barracouta, snoek', + 390: 'eel', + 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', + 392: 'rock beauty, Holocanthus tricolor', + 393: 'anemone fish', + 394: 'sturgeon', + 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 396: 'lionfish', + 397: 'puffer, pufferfish, blowfish, globefish', + 398: 'abacus', + 399: 'abaya', + 400: "academic gown, academic robe, judge's robe", + 401: 'accordion, piano accordion, squeeze box', + 402: 'acoustic guitar', + 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 404: 'airliner', + 405: 'airship, dirigible', + 406: 'altar', + 407: 'ambulance', + 408: 'amphibian, amphibious vehicle', + 409: 'analog clock', + 410: 'apiary, bee house', + 411: 'apron', + 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', + 413: 'assault rifle, assault gun', + 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 415: 'bakery, bakeshop, bakehouse', + 416: 'balance beam, beam', + 417: 'balloon', + 418: 'ballpoint, ballpoint pen, ballpen, Biro', + 419: 'Band Aid', + 420: 'banjo', + 421: 'bannister, banister, balustrade, balusters, handrail', + 422: 'barbell', + 423: 'barber chair', + 424: 'barbershop', + 425: 'barn', + 426: 'barometer', + 427: 'barrel, cask', + 428: 'barrow, garden cart, lawn cart, wheelbarrow', + 429: 'baseball', + 430: 'basketball', + 431: 'bassinet', + 432: 'bassoon', + 433: 'bathing cap, swimming cap', + 434: 'bath towel', + 435: 'bathtub, bathing tub, bath, tub', + 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', + 437: 'beacon, lighthouse, beacon light, pharos', + 438: 'beaker', + 439: 'bearskin, busby, shako', + 440: 'beer bottle', + 441: 'beer glass', + 442: 'bell cote, bell cot', + 443: 'bib', + 444: 'bicycle-built-for-two, tandem bicycle, tandem', + 445: 'bikini, two-piece', + 446: 'binder, ring-binder', + 447: 'binoculars, field glasses, opera glasses', + 448: 'birdhouse', + 449: 'boathouse', + 450: 'bobsled, bobsleigh, bob', + 451: 'bolo tie, bolo, bola tie, bola', + 452: 'bonnet, poke bonnet', + 453: 'bookcase', + 454: 'bookshop, bookstore, bookstall', + 455: 'bottlecap', + 456: 'bow', + 457: 'bow tie, bow-tie, bowtie', + 458: 'brass, memorial tablet, plaque', + 459: 'brassiere, bra, bandeau', + 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 461: 'breastplate, aegis, egis', + 462: 'broom', + 463: 'bucket, pail', + 464: 'buckle', + 465: 'bulletproof vest', + 466: 'bullet train, bullet', + 467: 'butcher shop, meat market', + 468: 'cab, hack, taxi, taxicab', + 469: 'caldron, cauldron', + 470: 'candle, taper, wax light', + 471: 'cannon', + 472: 'canoe', + 473: 'can opener, tin opener', + 474: 'cardigan', + 475: 'car mirror', + 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', + 477: "carpenter's kit, tool kit", + 478: 'carton', + 479: 'car wheel', + 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', + 481: 'cassette', + 482: 'cassette player', + 483: 'castle', + 484: 'catamaran', + 485: 'CD player', + 486: 'cello, violoncello', + 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 488: 'chain', + 489: 'chainlink fence', + 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', + 491: 'chain saw, chainsaw', + 492: 'chest', + 493: 'chiffonier, commode', + 494: 'chime, bell, gong', + 495: 'china cabinet, china closet', + 496: 'Christmas stocking', + 497: 'church, church building', + 498: 'cinema, movie theater, movie theatre, movie house, picture palace', + 499: 'cleaver, meat cleaver, chopper', + 500: 'cliff dwelling', + 501: 'cloak', + 502: 'clog, geta, patten, sabot', + 503: 'cocktail shaker', + 504: 'coffee mug', + 505: 'coffeepot', + 506: 'coil, spiral, volute, whorl, helix', + 507: 'combination lock', + 508: 'computer keyboard, keypad', + 509: 'confectionery, confectionary, candy store', + 510: 'container ship, containership, container vessel', + 511: 'convertible', + 512: 'corkscrew, bottle screw', + 513: 'cornet, horn, trumpet, trump', + 514: 'cowboy boot', + 515: 'cowboy hat, ten-gallon hat', + 516: 'cradle', + 517: 'crane', + 518: 'crash helmet', + 519: 'crate', + 520: 'crib, cot', + 521: 'Crock Pot', + 522: 'croquet ball', + 523: 'crutch', + 524: 'cuirass', + 525: 'dam, dike, dyke', + 526: 'desk', + 527: 'desktop computer', + 528: 'dial telephone, dial phone', + 529: 'diaper, nappy, napkin', + 530: 'digital clock', + 531: 'digital watch', + 532: 'dining table, board', + 533: 'dishrag, dishcloth', + 534: 'dishwasher, dish washer, dishwashing machine', + 535: 'disk brake, disc brake', + 536: 'dock, dockage, docking facility', + 537: 'dogsled, dog sled, dog sleigh', + 538: 'dome', + 539: 'doormat, welcome mat', + 540: 'drilling platform, offshore rig', + 541: 'drum, membranophone, tympan', + 542: 'drumstick', + 543: 'dumbbell', + 544: 'Dutch oven', + 545: 'electric fan, blower', + 546: 'electric guitar', + 547: 'electric locomotive', + 548: 'entertainment center', + 549: 'envelope', + 550: 'espresso maker', + 551: 'face powder', + 552: 'feather boa, boa', + 553: 'file, file cabinet, filing cabinet', + 554: 'fireboat', + 555: 'fire engine, fire truck', + 556: 'fire screen, fireguard', + 557: 'flagpole, flagstaff', + 558: 'flute, transverse flute', + 559: 'folding chair', + 560: 'football helmet', + 561: 'forklift', + 562: 'fountain', + 563: 'fountain pen', + 564: 'four-poster', + 565: 'freight car', + 566: 'French horn, horn', + 567: 'frying pan, frypan, skillet', + 568: 'fur coat', + 569: 'garbage truck, dustcart', + 570: 'gasmask, respirator, gas helmet', + 571: 'gas pump, gasoline pump, petrol pump, island dispenser', + 572: 'goblet', + 573: 'go-kart', + 574: 'golf ball', + 575: 'golfcart, golf cart', + 576: 'gondola', + 577: 'gong, tam-tam', + 578: 'gown', + 579: 'grand piano, grand', + 580: 'greenhouse, nursery, glasshouse', + 581: 'grille, radiator grille', + 582: 'grocery store, grocery, food market, market', + 583: 'guillotine', + 584: 'hair slide', + 585: 'hair spray', + 586: 'half track', + 587: 'hammer', + 588: 'hamper', + 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 590: 'hand-held computer, hand-held microcomputer', + 591: 'handkerchief, hankie, hanky, hankey', + 592: 'hard disc, hard disk, fixed disk', + 593: 'harmonica, mouth organ, harp, mouth harp', + 594: 'harp', + 595: 'harvester, reaper', + 596: 'hatchet', + 597: 'holster', + 598: 'home theater, home theatre', + 599: 'honeycomb', + 600: 'hook, claw', + 601: 'hoopskirt, crinoline', + 602: 'horizontal bar, high bar', + 603: 'horse cart, horse-cart', + 604: 'hourglass', + 605: 'iPod', + 606: 'iron, smoothing iron', + 607: "jack-o'-lantern", + 608: 'jean, blue jean, denim', + 609: 'jeep, landrover', + 610: 'jersey, T-shirt, tee shirt', + 611: 'jigsaw puzzle', + 612: 'jinrikisha, ricksha, rickshaw', + 613: 'joystick', + 614: 'kimono', + 615: 'knee pad', + 616: 'knot', + 617: 'lab coat, laboratory coat', + 618: 'ladle', + 619: 'lampshade, lamp shade', + 620: 'laptop, laptop computer', + 621: 'lawn mower, mower', + 622: 'lens cap, lens cover', + 623: 'letter opener, paper knife, paperknife', + 624: 'library', + 625: 'lifeboat', + 626: 'lighter, light, igniter, ignitor', + 627: 'limousine, limo', + 628: 'liner, ocean liner', + 629: 'lipstick, lip rouge', + 630: 'Loafer', + 631: 'lotion', + 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', + 633: "loupe, jeweler's loupe", + 634: 'lumbermill, sawmill', + 635: 'magnetic compass', + 636: 'mailbag, postbag', + 637: 'mailbox, letter box', + 638: 'maillot', + 639: 'maillot, tank suit', + 640: 'manhole cover', + 641: 'maraca', + 642: 'marimba, xylophone', + 643: 'mask', + 644: 'matchstick', + 645: 'maypole', + 646: 'maze, labyrinth', + 647: 'measuring cup', + 648: 'medicine chest, medicine cabinet', + 649: 'megalith, megalithic structure', + 650: 'microphone, mike', + 651: 'microwave, microwave oven', + 652: 'military uniform', + 653: 'milk can', + 654: 'minibus', + 655: 'miniskirt, mini', + 656: 'minivan', + 657: 'missile', + 658: 'mitten', + 659: 'mixing bowl', + 660: 'mobile home, manufactured home', + 661: 'Model T', + 662: 'modem', + 663: 'monastery', + 664: 'monitor', + 665: 'moped', + 666: 'mortar', + 667: 'mortarboard', + 668: 'mosque', + 669: 'mosquito net', + 670: 'motor scooter, scooter', + 671: 'mountain bike, all-terrain bike, off-roader', + 672: 'mountain tent', + 673: 'mouse, computer mouse', + 674: 'mousetrap', + 675: 'moving van', + 676: 'muzzle', + 677: 'nail', + 678: 'neck brace', + 679: 'necklace', + 680: 'nipple', + 681: 'notebook, notebook computer', + 682: 'obelisk', + 683: 'oboe, hautboy, hautbois', + 684: 'ocarina, sweet potato', + 685: 'odometer, hodometer, mileometer, milometer', + 686: 'oil filter', + 687: 'organ, pipe organ', + 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 689: 'overskirt', + 690: 'oxcart', + 691: 'oxygen mask', + 692: 'packet', + 693: 'paddle, boat paddle', + 694: 'paddlewheel, paddle wheel', + 695: 'padlock', + 696: 'paintbrush', + 697: "pajama, pyjama, pj's, jammies", + 698: 'palace', + 699: 'panpipe, pandean pipe, syrinx', + 700: 'paper towel', + 701: 'parachute, chute', + 702: 'parallel bars, bars', + 703: 'park bench', + 704: 'parking meter', + 705: 'passenger car, coach, carriage', + 706: 'patio, terrace', + 707: 'pay-phone, pay-station', + 708: 'pedestal, plinth, footstall', + 709: 'pencil box, pencil case', + 710: 'pencil sharpener', + 711: 'perfume, essence', + 712: 'Petri dish', + 713: 'photocopier', + 714: 'pick, plectrum, plectron', + 715: 'pickelhaube', + 716: 'picket fence, paling', + 717: 'pickup, pickup truck', + 718: 'pier', + 719: 'piggy bank, penny bank', + 720: 'pill bottle', + 721: 'pillow', + 722: 'ping-pong ball', + 723: 'pinwheel', + 724: 'pirate, pirate ship', + 725: 'pitcher, ewer', + 726: "plane, carpenter's plane, woodworking plane", + 727: 'planetarium', + 728: 'plastic bag', + 729: 'plate rack', + 730: 'plow, plough', + 731: "plunger, plumber's helper", + 732: 'Polaroid camera, Polaroid Land camera', + 733: 'pole', + 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', + 735: 'poncho', + 736: 'pool table, billiard table, snooker table', + 737: 'pop bottle, soda bottle', + 738: 'pot, flowerpot', + 739: "potter's wheel", + 740: 'power drill', + 741: 'prayer rug, prayer mat', + 742: 'printer', + 743: 'prison, prison house', + 744: 'projectile, missile', + 745: 'projector', + 746: 'puck, hockey puck', + 747: 'punching bag, punch bag, punching ball, punchball', + 748: 'purse', + 749: 'quill, quill pen', + 750: 'quilt, comforter, comfort, puff', + 751: 'racer, race car, racing car', + 752: 'racket, racquet', + 753: 'radiator', + 754: 'radio, wireless', + 755: 'radio telescope, radio reflector', + 756: 'rain barrel', + 757: 'recreational vehicle, RV, R.V.', + 758: 'reel', + 759: 'reflex camera', + 760: 'refrigerator, icebox', + 761: 'remote control, remote', + 762: 'restaurant, eating house, eating place, eatery', + 763: 'revolver, six-gun, six-shooter', + 764: 'rifle', + 765: 'rocking chair, rocker', + 766: 'rotisserie', + 767: 'rubber eraser, rubber, pencil eraser', + 768: 'rugby ball', + 769: 'rule, ruler', + 770: 'running shoe', + 771: 'safe', + 772: 'safety pin', + 773: 'saltshaker, salt shaker', + 774: 'sandal', + 775: 'sarong', + 776: 'sax, saxophone', + 777: 'scabbard', + 778: 'scale, weighing machine', + 779: 'school bus', + 780: 'schooner', + 781: 'scoreboard', + 782: 'screen, CRT screen', + 783: 'screw', + 784: 'screwdriver', + 785: 'seat belt, seatbelt', + 786: 'sewing machine', + 787: 'shield, buckler', + 788: 'shoe shop, shoe-shop, shoe store', + 789: 'shoji', + 790: 'shopping basket', + 791: 'shopping cart', + 792: 'shovel', + 793: 'shower cap', + 794: 'shower curtain', + 795: 'ski', + 796: 'ski mask', + 797: 'sleeping bag', + 798: 'slide rule, slipstick', + 799: 'sliding door', + 800: 'slot, one-armed bandit', + 801: 'snorkel', + 802: 'snowmobile', + 803: 'snowplow, snowplough', + 804: 'soap dispenser', + 805: 'soccer ball', + 806: 'sock', + 807: 'solar dish, solar collector, solar furnace', + 808: 'sombrero', + 809: 'soup bowl', + 810: 'space bar', + 811: 'space heater', + 812: 'space shuttle', + 813: 'spatula', + 814: 'speedboat', + 815: "spider web, spider's web", + 816: 'spindle', + 817: 'sports car, sport car', + 818: 'spotlight, spot', + 819: 'stage', + 820: 'steam locomotive', + 821: 'steel arch bridge', + 822: 'steel drum', + 823: 'stethoscope', + 824: 'stole', + 825: 'stone wall', + 826: 'stopwatch, stop watch', + 827: 'stove', + 828: 'strainer', + 829: 'streetcar, tram, tramcar, trolley, trolley car', + 830: 'stretcher', + 831: 'studio couch, day bed', + 832: 'stupa, tope', + 833: 'submarine, pigboat, sub, U-boat', + 834: 'suit, suit of clothes', + 835: 'sundial', + 836: 'sunglass', + 837: 'sunglasses, dark glasses, shades', + 838: 'sunscreen, sunblock, sun blocker', + 839: 'suspension bridge', + 840: 'swab, swob, mop', + 841: 'sweatshirt', + 842: 'swimming trunks, bathing trunks', + 843: 'swing', + 844: 'switch, electric switch, electrical switch', + 845: 'syringe', + 846: 'table lamp', + 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 848: 'tape player', + 849: 'teapot', + 850: 'teddy, teddy bear', + 851: 'television, television system', + 852: 'tennis ball', + 853: 'thatch, thatched roof', + 854: 'theater curtain, theatre curtain', + 855: 'thimble', + 856: 'thresher, thrasher, threshing machine', + 857: 'throne', + 858: 'tile roof', + 859: 'toaster', + 860: 'tobacco shop, tobacconist shop, tobacconist', + 861: 'toilet seat', + 862: 'torch', + 863: 'totem pole', + 864: 'tow truck, tow car, wrecker', + 865: 'toyshop', + 866: 'tractor', + 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', + 868: 'tray', + 869: 'trench coat', + 870: 'tricycle, trike, velocipede', + 871: 'trimaran', + 872: 'tripod', + 873: 'triumphal arch', + 874: 'trolleybus, trolley coach, trackless trolley', + 875: 'trombone', + 876: 'tub, vat', + 877: 'turnstile', + 878: 'typewriter keyboard', + 879: 'umbrella', + 880: 'unicycle, monocycle', + 881: 'upright, upright piano', + 882: 'vacuum, vacuum cleaner', + 883: 'vase', + 884: 'vault', + 885: 'velvet', + 886: 'vending machine', + 887: 'vestment', + 888: 'viaduct', + 889: 'violin, fiddle', + 890: 'volleyball', + 891: 'waffle iron', + 892: 'wall clock', + 893: 'wallet, billfold, notecase, pocketbook', + 894: 'wardrobe, closet, press', + 895: 'warplane, military plane', + 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 897: 'washer, automatic washer, washing machine', + 898: 'water bottle', + 899: 'water jug', + 900: 'water tower', + 901: 'whiskey jug', + 902: 'whistle', + 903: 'wig', + 904: 'window screen', + 905: 'window shade', + 906: 'Windsor tie', + 907: 'wine bottle', + 908: 'wing', + 909: 'wok', + 910: 'wooden spoon', + 911: 'wool, woolen, woollen', + 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', + 913: 'wreck', + 914: 'yawl', + 915: 'yurt', + 916: 'web site, website, internet site, site', + 917: 'comic book', + 918: 'crossword puzzle, crossword', + 919: 'street sign', + 920: 'traffic light, traffic signal, stoplight', + 921: 'book jacket, dust cover, dust jacket, dust wrapper', + 922: 'menu', + 923: 'plate', + 924: 'guacamole', + 925: 'consomme', + 926: 'hot pot, hotpot', + 927: 'trifle', + 928: 'ice cream, icecream', + 929: 'ice lolly, lolly, lollipop, popsicle', + 930: 'French loaf', + 931: 'bagel, beigel', + 932: 'pretzel', + 933: 'cheeseburger', + 934: 'hotdog, hot dog, red hot', + 935: 'mashed potato', + 936: 'head cabbage', + 937: 'broccoli', + 938: 'cauliflower', + 939: 'zucchini, courgette', + 940: 'spaghetti squash', + 941: 'acorn squash', + 942: 'butternut squash', + 943: 'cucumber, cuke', + 944: 'artichoke, globe artichoke', + 945: 'bell pepper', + 946: 'cardoon', + 947: 'mushroom', + 948: 'Granny Smith', + 949: 'strawberry', + 950: 'orange', + 951: 'lemon', + 952: 'fig', + 953: 'pineapple, ananas', + 954: 'banana', + 955: 'jackfruit, jak, jack', + 956: 'custard apple', + 957: 'pomegranate', + 958: 'hay', + 959: 'carbonara', + 960: 'chocolate sauce, chocolate syrup', + 961: 'dough', + 962: 'meat loaf, meatloaf', + 963: 'pizza, pizza pie', + 964: 'potpie', + 965: 'burrito', + 966: 'red wine', + 967: 'espresso', + 968: 'cup', + 969: 'eggnog', + 970: 'alp', + 971: 'bubble', + 972: 'cliff, drop, drop-off', + 973: 'coral reef', + 974: 'geyser', + 975: 'lakeside, lakeshore', + 976: 'promontory, headland, head, foreland', + 977: 'sandbar, sand bar', + 978: 'seashore, coast, seacoast, sea-coast', + 979: 'valley, vale', + 980: 'volcano', + 981: 'ballplayer, baseball player', + 982: 'groom, bridegroom', + 983: 'scuba diver', + 984: 'rapeseed', + 985: 'daisy', + 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + 987: 'corn', + 988: 'acorn', + 989: 'hip, rose hip, rosehip', + 990: 'buckeye, horse chestnut, conker', + 991: 'coral fungus', + 992: 'agaric', + 993: 'gyromitra', + 994: 'stinkhorn, carrion fungus', + 995: 'earthstar', + 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', + 997: 'bolete', + 998: 'ear, spike, capitulum', + 999: 'toilet tissue, toilet paper, bathroom tissue'} \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb9a01d9aab53ef1f5c5b8d9dcf231d1e49ce95 --- /dev/null +++ b/models.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import numpy as np +import math +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + +def DiT_XL_2(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +def DiT_XL_4(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + +def DiT_XL_8(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + +def DiT_L_2(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + +def DiT_L_4(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + +def DiT_L_8(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + +def DiT_B_2(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + +def DiT_B_4(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + +def DiT_B_8(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + +def DiT_S_2(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + +def DiT_S_4(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +def DiT_S_8(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0ebd02af4bc54eb452dbaae36fca4c5621b1b99b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch +torchvision +numpy +tqdm +timm +pillow +diffusers +accelerate