Update pipeline_imagedream.py

pipeline_imagedream.py  CHANGED  (+154 -3)
@@ -1,5 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -17,6 +18,8 @@ except:
 
 from diffusers.image_processor import PipelineImageInput
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention import Attention
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.pipelines.stable_diffusion.pipeline_output import (
     StableDiffusionPipelineOutput,
 )
@@ -37,9 +40,6 @@ from transformers import (
     CLIPVisionModel,
 )
 
-from attention_processor import add_imagedream_attn_processor
-from camera_utils import get_camera
-
 
 class ImageDreamPipeline(StableDiffusionPipeline):
     def __init__(
@@ -417,3 +417,154 @@ class ImageDreamPipeline(StableDiffusionPipeline):
         return StableDiffusionPipelineOutput(
             images=image, nsfw_content_detected=has_nsfw_concept
         )
+
+
+# fmt: off
+# Copied from ImageDream
+# https://github.com/bytedance/ImageDream/blob/main/extern/ImageDream/imagedream/camera_utils.py
+
+
+def create_camera_to_world_matrix(elevation, azimuth):
+    elevation = np.radians(elevation)
+    azimuth = np.radians(azimuth)
+    # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
+    x = np.cos(elevation) * np.sin(azimuth)
+    y = np.sin(elevation)
+    z = np.cos(elevation) * np.cos(azimuth)
+
+    # Calculate camera position, target, and up vectors
+    camera_pos = np.array([x, y, z])
+    target = np.array([0, 0, 0])
+    up = np.array([0, 1, 0])
+
+    # Construct view matrix
+    forward = target - camera_pos
+    forward /= np.linalg.norm(forward)
+    right = np.cross(forward, up)
+    right /= np.linalg.norm(right)
+    new_up = np.cross(right, forward)
+    new_up /= np.linalg.norm(new_up)
+    cam2world = np.eye(4)
+    cam2world[:3, :3] = np.array([right, new_up, -forward]).T
+    cam2world[:3, 3] = camera_pos
+    return cam2world
+
+
+def convert_opengl_to_blender(camera_matrix):
+    if isinstance(camera_matrix, np.ndarray):
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+        camera_matrix_blender = np.dot(flip_yz, camera_matrix)
+    else:
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = torch.tensor(
+            [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
+        )
+        if camera_matrix.ndim == 3:
+            flip_yz = flip_yz.unsqueeze(0)
+        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
+    return camera_matrix_blender
+
+
+def normalize_camera(camera_matrix):
+    """normalize the camera location onto a unit-sphere"""
+    if isinstance(camera_matrix, np.ndarray):
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    else:
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            torch.norm(translation, dim=1, keepdim=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    return camera_matrix.reshape(-1, 16)
+
+
+def get_camera(
+    num_frames,
+    elevation=15,
+    azimuth_start=0,
+    azimuth_span=360,
+    blender_coord=True,
+    extra_view=False,
+):
+    angle_gap = azimuth_span / num_frames
+    cameras = []
+    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
+        camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
+        if blender_coord:
+            camera_matrix = convert_opengl_to_blender(camera_matrix)
+        cameras.append(camera_matrix.flatten())
+
+    if extra_view:
+        dim = len(cameras[0])
+        cameras.append(np.zeros(dim))
+    return torch.tensor(np.stack(cameras, 0)).float()
+# fmt: on
+
+
+def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
+    attn_procs = {}
+    for key, attn_processor in unet.attn_processors.items():
+        if "attn1" in key:
+            attn_procs[key] = ImageDreamAttnProcessor2_0()
+        else:
+            attn_procs[key] = attn_processor
+    unet.set_attn_processor(attn_procs)
+    return unet
+
+
+class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        num_views: int = 1,
+        *args,
+        **kwargs,
+    ):
+        if num_views == 1:
+            return super().__call__(
+                attn=attn,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                temb=temb,
+                *args,
+                **kwargs,
+            )
+
+        input_ndim = hidden_states.ndim
+        B = hidden_states.size(0)
+        if B % num_views:
+            raise ValueError(
+                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
+            )
+        real_B = B // num_views
+        if input_ndim == 4:
+            H, W = hidden_states.shape[2:]
+            hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
+        else:
+            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))
+        hidden_states = super().__call__(
+            attn=attn,
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            temb=temb,
+            *args,
+            **kwargs,
+        )
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
+        return hidden_states
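
For reference (not part of the commit): a minimal sketch of what the vendored get_camera helper returns, assuming this file is importable as pipeline_imagedream; the argument values below are illustrative, not taken from the diff.

import torch

from pipeline_imagedream import get_camera  # hypothetical import path for this file

# Four views spaced 90 degrees apart at zero elevation, plus one zero-filled
# "extra" view appended because extra_view=True.
cameras = get_camera(num_frames=4, elevation=0, extra_view=True)

assert cameras.shape == torch.Size([5, 16])  # one flattened 4x4 camera-to-world matrix per row
assert cameras.dtype == torch.float32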
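Likewise, a sketch of wiring the new processors into a UNet; the checkpoint path is a placeholder, and forwarding num_views through cross_attention_kwargs is an assumption about the surrounding pipeline, not something this diff shows.

from diffusers import UNet2DConditionModel

from pipeline_imagedream import add_imagedream_attn_processor  # hypothetical import path

# "path/to/imagedream-checkpoint" is a placeholder, not a real repository id.
unet = UNet2DConditionModel.from_pretrained("path/to/imagedream-checkpoint", subfolder="unet")
unet = add_imagedream_attn_processor(unet)

# Only self-attention ("attn1") processors are replaced with ImageDreamAttnProcessor2_0;
# cross-attention processors are left untouched. At denoising time, num_views would
# presumably reach the processors via cross_attention_kwargs, e.g.:
#   unet(latents, t, encoder_hidden_states=embeds, cross_attention_kwargs={"num_views": 5})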