kiigii committed (verified)
Commit 8f6ec56 · Parent(s): 5684c31

Update pipeline_imagedream.py

Files changed (1)
  1. pipeline_imagedream.py +154 -3
pipeline_imagedream.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -17,6 +18,8 @@ except:
 
 from diffusers.image_processor import PipelineImageInput
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention import Attention
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.pipelines.stable_diffusion.pipeline_output import (
     StableDiffusionPipelineOutput,
 )
@@ -37,9 +40,6 @@ from transformers import (
     CLIPVisionModel,
 )
 
-from attention_processor import add_imagedream_attn_processor
-from camera_utils import get_camera
-
 
 class ImageDreamPipeline(StableDiffusionPipeline):
     def __init__(
@@ -417,3 +417,154 @@ class ImageDreamPipeline(StableDiffusionPipeline):
         return StableDiffusionPipelineOutput(
             images=image, nsfw_content_detected=has_nsfw_concept
         )
+
+
+# fmt: off
+# Copied from ImageDream
+# https://github.com/bytedance/ImageDream/blob/main/extern/ImageDream/imagedream/camera_utils.py
+
+
+def create_camera_to_world_matrix(elevation, azimuth):
+    elevation = np.radians(elevation)
+    azimuth = np.radians(azimuth)
+    # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
+    x = np.cos(elevation) * np.sin(azimuth)
+    y = np.sin(elevation)
+    z = np.cos(elevation) * np.cos(azimuth)
+
+    # Calculate camera position, target, and up vectors
+    camera_pos = np.array([x, y, z])
+    target = np.array([0, 0, 0])
+    up = np.array([0, 1, 0])
+
+    # Construct view matrix
+    forward = target - camera_pos
+    forward /= np.linalg.norm(forward)
+    right = np.cross(forward, up)
+    right /= np.linalg.norm(right)
+    new_up = np.cross(right, forward)
+    new_up /= np.linalg.norm(new_up)
+    cam2world = np.eye(4)
+    cam2world[:3, :3] = np.array([right, new_up, -forward]).T
+    cam2world[:3, 3] = camera_pos
+    return cam2world
+
+
+def convert_opengl_to_blender(camera_matrix):
+    if isinstance(camera_matrix, np.ndarray):
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+        camera_matrix_blender = np.dot(flip_yz, camera_matrix)
+    else:
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = torch.tensor(
+            [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
+        )
+        if camera_matrix.ndim == 3:
+            flip_yz = flip_yz.unsqueeze(0)
+        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
+    return camera_matrix_blender
+
+
+def normalize_camera(camera_matrix):
+    """normalize the camera location onto a unit-sphere"""
+    if isinstance(camera_matrix, np.ndarray):
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    else:
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            torch.norm(translation, dim=1, keepdim=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    return camera_matrix.reshape(-1, 16)
+
+
+def get_camera(
+    num_frames,
+    elevation=15,
+    azimuth_start=0,
+    azimuth_span=360,
+    blender_coord=True,
+    extra_view=False,
+):
+    angle_gap = azimuth_span / num_frames
+    cameras = []
+    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
+        camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
+        if blender_coord:
+            camera_matrix = convert_opengl_to_blender(camera_matrix)
+        cameras.append(camera_matrix.flatten())
+
+    if extra_view:
+        dim = len(cameras[0])
+        cameras.append(np.zeros(dim))
+    return torch.tensor(np.stack(cameras, 0)).float()
+# fmt: on
+
+
+def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
+    attn_procs = {}
+    for key, attn_processor in unet.attn_processors.items():
+        if "attn1" in key:
+            attn_procs[key] = ImageDreamAttnProcessor2_0()
+        else:
+            attn_procs[key] = attn_processor
+    unet.set_attn_processor(attn_procs)
+    return unet
+
+
+class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        num_views: int = 1,
+        *args,
+        **kwargs,
+    ):
+        if num_views == 1:
+            return super().__call__(
+                attn=attn,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                temb=temb,
+                *args,
+                **kwargs,
+            )
+
+        input_ndim = hidden_states.ndim
+        B = hidden_states.size(0)
+        if B % num_views:
+            raise ValueError(
+                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
+            )
+        real_B = B // num_views
+        if input_ndim == 4:
+            H, W = hidden_states.shape[2:]
+            hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
+        else:
+            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))
+        hidden_states = super().__call__(
+            attn=attn,
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            temb=temb,
+            *args,
+            **kwargs,
+        )
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
+        return hidden_states
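
For reference, a minimal usage sketch of the two entry points this commit inlines. The pipeline object `pipe` and the surrounding setup are hypothetical; the pipeline's own __call__ signature is not part of this diff:

    # Four camera-to-world poses, 90 degrees apart at 15 degrees elevation,
    # returned as a (4, 16) float tensor of flattened 4x4 matrices.
    cameras = get_camera(num_frames=4, elevation=15, azimuth_start=0)

    # Replace every self-attention ("attn1") processor on the UNet with
    # ImageDreamAttnProcessor2_0; cross-attention processors are left as-is.
    pipe.unet = add_imagedream_attn_processor(pipe.unet)

At denoising time the pipeline presumably passes `num_views` through to these processors (diffusers forwards `cross_attention_kwargs` to attention processors), so each attn1 layer folds the view axis into the sequence axis and attends across all views jointly.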