Commit 710e5f8
Parent(s): f65b8d3

fix inversion

Files changed:
- gradio_app.py +1 -2
- main.py +6 -4
- src/diffusion_model_wrapper.py +5 -5
gradio_app.py
CHANGED
@@ -18,7 +18,6 @@ This demo supports both generated images and real images. To modify a real image
 '''
 
 stable, stable_config = setup(LPMConfig())
-stable_for_inversion, _ = setup(LPMConfig())
 
 def main_pipeline(
         prompt: str,
@@ -48,7 +47,7 @@ def main_pipeline(
         real_image_path="" if input_image is None else input_image
     )
 
-    result_images, result_proxy_words = main(stable, stable_config, stable_for_inversion, args)
+    result_images, result_proxy_words = main(stable, stable_config, args)
     result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
     result_images = [(im * 255).astype(np.uint8) for im in result_images]
     result_images = [Image.fromarray(im) for im in result_images]
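
For orientation, a minimal sketch of the calling convention after this commit: gradio_app.py builds a single shared pipeline and passes only that pipeline, its config, and the per-request args into main(); the separate stable_for_inversion instance is gone. The import paths below and the run_request wrapper are assumptions made for illustration, not lines from the repository.

# Sketch only: the locations of setup, main and LPMConfig are assumed from the
# file names in this commit, and run_request is an invented wrapper.
from main import setup, main
from src.config import LPMConfig  # hypothetical module path for LPMConfig

# One Stable Diffusion pipeline is created once at startup and reused for every
# request; no second inversion-only pipeline is kept around anymore.
stable, stable_config = setup(LPMConfig())

def run_request(args):
    # main() now takes just the shared pipeline, its config and the args;
    # when args.real_image_path is set, inversion happens inside main() itself.
    result_images, result_proxy_words = main(stable, stable_config, args)
    return result_images, result_proxy_words
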
main.py
CHANGED
@@ -1,13 +1,14 @@
 import json
 import os
+from dataclasses import dataclass, field
+from typing import List
+
 import pyrallis
 import torch
-from dataclasses import dataclass, field
 from torch.utils.data import DataLoader
 from torchvision.transforms import ToTensor
 from torchvision.utils import save_image
 from tqdm import tqdm
-from typing import List
 
 from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
     generate_original_image
@@ -34,7 +35,7 @@ def setup(args):
     return ldm_stable, ldm_stable_config
 
 
-def main(ldm_stable, ldm_stable_config, ldm_stable_inversion, args):
+def main(ldm_stable, ldm_stable_config, args):
 
     similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
     exp_path = save_args_dict(args, similar_words)
@@ -44,7 +45,8 @@ def main(ldm_stable, ldm_stable_config, ldm_stable_inversion, args):
     uncond_embeddings = None
 
     if args.real_image_path != "":
-        x_t, uncond_embeddings = invert_image(args, ldm_stable_inversion, ldm_stable_config, prompts, exp_path)
+        ldm_stable, ldm_stable_config = setup(args)
+        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)
 
     image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
     save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
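
Condensed, the change to main() means the real-image branch now rebuilds the pipeline with setup(args) and inverts the input with that same model, rather than receiving a dedicated ldm_stable_inversion argument. A rough sketch of the resulting flow; only names that appear in the diff are used, and real_image_branch itself is a made-up wrapper for illustration:

def real_image_branch(args, ldm_stable, ldm_stable_config, prompts, exp_path):
    # Made-up wrapper sketching the branch inside main() after this commit.
    x_t, uncond_embeddings = None, None

    if args.real_image_path != "":
        # Re-run setup so inversion uses the same freshly configured pipeline
        # that will later generate the edited images.
        ldm_stable, ldm_stable_config = setup(args)
        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)

    # For generated (non-real) inputs, x_t and uncond_embeddings stay None and
    # generate_original_image() samples the original image from scratch.
    return generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
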
src/diffusion_model_wrapper.py
CHANGED
@@ -1,13 +1,13 @@
-import torch
-import numpy as np
 from typing import Optional, List
 
+import numpy as np
+import torch
+from cv2 import dilate
 from diffusers import DDIMScheduler, StableDiffusionPipeline
 from tqdm import tqdm
-from cv2 import dilate
 
-from src.attention_utils import show_cross_attention
 from src.attention_based_segmentation import Segmentor
+from src.attention_utils import show_cross_attention
 from src.prompt_to_prompt_controllers import DummyController, AttentionStore
 
 
@@ -136,7 +136,7 @@ class DiffusionModelWrapper:
         if self.enbale_attn_controller_changes:
             attn = self.controller(attn, is_cross, place_in_unet)
 
-        if is_cross and self.prompt_mixing is not None:
+        if is_cross and self.prompt_mixing is not None and context[1] is not None:
             attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)
 
         if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
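
The diffusion_model_wrapper.py change tightens the cross-attention hook: prompt mixing is applied only when a prompt_mixing object is set and the second slot of context is actually present. A small, self-contained illustration of that guard pattern; the helper below is invented for the example, while the get_cross_attn call mirrors the line in the diff:

def maybe_mix_cross_attention(wrapper, attn, is_cross, context, place_in_unet, batch_size):
    # Invented helper: hand the attention map to prompt mixing only when both
    # the mixer and the second context entry exist; otherwise return attn as-is.
    if is_cross and wrapper.prompt_mixing is not None and context[1] is not None:
        return wrapper.prompt_mixing.get_cross_attn(wrapper, wrapper.diff_step, attn, place_in_unet, batch_size)
    return attn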