YulianSa committed
Commit 829e08b · 1 Parent(s): e7033f8
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. .gitignore +4 -0
  3. app.py +390 -0
  4. assets/teaser.jpg +3 -0
  5. configs/infer.yml +52 -0
  6. data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply +0 -0
  7. data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply +0 -0
  8. data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply +0 -0
  9. data/basic_shapes_norm/basic_shapes.json +89 -0
  10. data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply +3 -0
  11. data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply +3 -0
  12. data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply +3 -0
  13. data/demo_glb/barbell.glb +3 -0
  14. data/demo_glb/book.glb +3 -0
  15. data/demo_glb/bunny.glb +3 -0
  16. data/demo_glb/desk.glb +3 -0
  17. data/demo_glb/man.glb +3 -0
  18. data/demo_glb/micky.glb +3 -0
  19. data/demo_glb/pac.glb +3 -0
  20. data/demo_glb/robot.glb +3 -0
  21. data/demo_glb/rocket.glb +3 -0
  22. data/demo_glb/sheep.glb +3 -0
  23. data/demo_glb/shelf.glb +3 -0
  24. data/demo_glb/table.glb +3 -0
  25. data/demo_glb/vent.glb +3 -0
  26. data/demo_glb/walkman.glb +3 -0
  27. pre-requirements.txt +36 -0
  28. primitive_anything/__init__.py +0 -0
  29. primitive_anything/michelangelo/__init__.py +51 -0
  30. primitive_anything/michelangelo/data/__init__.py +1 -0
  31. primitive_anything/michelangelo/data/templates.json +69 -0
  32. primitive_anything/michelangelo/data/transforms.py +407 -0
  33. primitive_anything/michelangelo/data/utils.py +59 -0
  34. primitive_anything/michelangelo/graphics/__init__.py +1 -0
  35. primitive_anything/michelangelo/graphics/primitives/__init__.py +9 -0
  36. primitive_anything/michelangelo/graphics/primitives/mesh.py +114 -0
  37. primitive_anything/michelangelo/graphics/primitives/volume.py +21 -0
  38. primitive_anything/michelangelo/models/__init__.py +1 -0
  39. primitive_anything/michelangelo/models/asl_diffusion/__init__.py +1 -0
  40. primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py +483 -0
  41. primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py +104 -0
  42. primitive_anything/michelangelo/models/asl_diffusion/base.py +13 -0
  43. primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py +393 -0
  44. primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py +80 -0
  45. primitive_anything/michelangelo/models/conditional_encoders/__init__.py +3 -0
  46. primitive_anything/michelangelo/models/conditional_encoders/clip.py +89 -0
  47. primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py +562 -0
  48. primitive_anything/michelangelo/models/modules/__init__.py +3 -0
  49. primitive_anything/michelangelo/models/modules/checkpoint.py +69 -0
  50. primitive_anything/michelangelo/models/modules/diffusion_transformer.py +218 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/demo_glb/*.glb filter=lfs diff=lfs merge=lfs -text
37
+ assets/*.jpg filter=lfs diff=lfs merge=lfs -text
38
+ data/basic_shapes_norm_pc10000/*.ply filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ **/__pycache__/
2
+ ckpt
3
+ gradio_cached_examples
4
+ results
app.py ADDED
@@ -0,0 +1,390 @@
1
+ import os
2
+ import time
3
+ import glob
4
+ import json
5
+ import yaml
6
+ import torch
7
+ import trimesh
8
+ import argparse
9
+ import mesh2sdf.core
10
+ import numpy as np
11
+ import skimage.measure
12
+ import seaborn as sns
13
+ from scipy.spatial.transform import Rotation
14
+ from mesh_to_sdf import get_surface_point_cloud
15
+ from accelerate.utils import set_seed
16
+ from accelerate import Accelerator
17
+ from huggingface_hub.file_download import hf_hub_download
18
+ from huggingface_hub import list_repo_files
19
+
20
+ from primitive_anything.utils import path_mkdir, count_parameters
21
+ from primitive_anything.utils.logger import print_log
22
+
23
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
24
+
25
+ import spaces
26
+
27
+ repo_id = "hyz317/PrimitiveAnything"
28
+ all_files = list_repo_files(repo_id, revision="main")
29
+ for file in all_files:
30
+ if os.path.exists(file):
31
+ continue
32
+ hf_hub_download(repo_id, file, local_dir="./ckpt")
33
+ hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")
34
+
35
+ def parse_args():
36
+ parser = argparse.ArgumentParser(description='Process 3D model files')
37
+
38
+ parser.add_argument(
39
+ '--input',
40
+ type=str,
41
+ default='./data/demo_glb/',
42
+ help='Input file or directory path (default: ./data/demo_glb/)'
43
+ )
44
+
45
+ parser.add_argument(
46
+ '--log_path',
47
+ type=str,
48
+ default='./results/demo',
49
+ help='Output directory path (default: results/demo)'
50
+ )
51
+
52
+ return parser.parse_args()
53
+
54
+ def get_input_files(input_path):
55
+ if os.path.isfile(input_path):
56
+ return [input_path]
57
+ elif os.path.isdir(input_path):
58
+ return glob.glob(os.path.join(input_path, '*'))
59
+ else:
60
+ raise ValueError(f"Input path {input_path} is neither a file nor a directory")
61
+
62
+ args = parse_args()
63
+
64
+ # Create output directory
65
+ LOG_PATH = args.log_path
66
+ os.makedirs(LOG_PATH, exist_ok=True)
67
+
68
+ print(f"Output directory: {LOG_PATH}")
69
+
70
+ CODE_SHAPE = {
71
+ 0: 'SM_GR_BS_CubeBevel_001.ply',
72
+ 1: 'SM_GR_BS_SphereSharp_001.ply',
73
+ 2: 'SM_GR_BS_CylinderSharp_001.ply',
74
+ }
75
+
76
+ shapename_map = {
77
+ 'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
78
+ 'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
79
+ 'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
80
+ }
81
+
82
+ #### config
83
+ bs_dir = 'data/basic_shapes_norm'
84
+ config_path = './configs/infer.yml'
85
+ AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
86
+ temperature= 0.0
87
+ #### init model
88
+ mesh_bs = {}
89
+ for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
90
+ bs_name = os.path.basename(bs_path)
91
+ bs = trimesh.load(bs_path)
92
+ bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
93
+ bs.visual = bs.visual.to_color()
94
+ mesh_bs[bs_name] = bs
95
+
96
+ def create_model(cfg_model):
97
+ kwargs = cfg_model
98
+ name = kwargs.pop('name')
99
+ model = get_model(name)(**kwargs)
100
+ print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
101
+ return model
102
+
103
+ from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
104
+ def get_model(name):
105
+ return {
106
+ 'discrete': PrimitiveTransformerDiscrete,
107
+ }[name]
108
+
109
+ with open(config_path, mode='r') as fp:
110
+ AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)
111
+
112
+ AR_checkpoint = torch.load(AR_checkpoint_path)
113
+
114
+ transformer = create_model(AR_train_cfg['model'])
115
+ transformer.load_state_dict(AR_checkpoint)
116
+
117
+ device = torch.device('cuda')
118
+ accelerator = Accelerator(
119
+ mixed_precision='fp16',
120
+ )
121
+ transformer = accelerator.prepare(transformer)
122
+ transformer.eval()
123
+ transformer.bs_pc = transformer.bs_pc.cuda()
124
+ transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
125
+ print('model loaded to device')
126
+
127
+
128
+ def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
129
+ scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
130
+ return_surface_pc_normals=False, normalized=False):
131
+ sample_start = time.time()
132
+ if surface_point_method == 'sample' and sign_method == 'depth':
133
+ print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
134
+ sign_method = 'normal'
135
+
136
+ surface_start = time.time()
137
+ bound_radius = 1 if normalized else None
138
+ surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
139
+ sample_point_count,
140
+ calculate_normals=sign_method == 'normal' or return_gradients)
141
+
142
+ surface_end = time.time()
143
+ print('surface point cloud time cost :', surface_end - surface_start)
144
+
145
+ normal_start = time.time()
146
+ if return_surface_pc_normals:
147
+ rng = np.random.default_rng()
148
+ assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
149
+ indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
150
+ points = surface_point_cloud.points[indices]
151
+ normals = surface_point_cloud.normals[indices]
152
+ surface_points = np.concatenate([points, normals], axis=-1)
153
+ else:
154
+ surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
155
+ normal_end = time.time()
156
+ print('normal time cost :', normal_end - normal_start)
157
+ sample_end = time.time()
158
+ print('sample surface point time cost :', sample_end - sample_start)
159
+ return surface_points
160
+
161
+
162
+ def normalize_vertices(vertices, scale=0.9):
163
+ bbmin, bbmax = vertices.min(0), vertices.max(0)
164
+ center = (bbmin + bbmax) * 0.5
165
+ scale = 2.0 * scale / (bbmax - bbmin).max()
166
+ vertices = (vertices - center) * scale
167
+ return vertices, center, scale
168
+
169
+
170
+ def export_to_watertight(normalized_mesh, octree_depth: int = 7):
171
+ """
172
+ Convert the non-watertight mesh to watertight.
173
+
174
+ Args:
175
+ normalized_mesh (trimesh.Trimesh): input mesh (need not be watertight)
176
+ octree_depth (int): the SDF grid resolution is 2 ** octree_depth
177
+
178
+ Returns:
179
+ mesh(trimesh.Trimesh): watertight mesh
180
+
181
+ """
182
+ size = 2 ** octree_depth
183
+ level = 2 / size
184
+
185
+ scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
186
+ sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
187
+ vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
188
+
189
+ # watertight mesh
190
+ vertices = vertices / size * 2 - 1 # -1 to 1
191
+ vertices = vertices / to_orig_scale + to_orig_center
192
+ mesh = trimesh.Trimesh(vertices, faces, normals=normals)
193
+
194
+ return mesh
195
+
196
+
197
+ def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
198
+ # mesh_list : list of trimesh
199
+ pc_normal_list = []
200
+ return_mesh_list = []
201
+ for mesh in mesh_list:
202
+ if marching_cubes:
203
+ mesh = export_to_watertight(mesh)
204
+ print("MC over!")
205
+ if dilated_offset > 0:
206
+ new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
207
+ mesh.vertices = new_vertices
208
+ print("dilate over!")
209
+
210
+ mesh.merge_vertices()
211
+ mesh.update_faces(mesh.unique_faces())
212
+ mesh.fix_normals()
213
+
214
+ return_mesh_list.append(mesh)
215
+
216
+ pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
217
+ pc_normal_list.append(pc_normal)
218
+ print("process mesh success")
219
+ return pc_normal_list, return_mesh_list
220
+
221
+
222
+ #### utils
223
+ def euler_to_quat(euler):
224
+ return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
225
+
226
+ def SRT_quat_to_matrix(scale, quat, translation):
227
+ rotation_matrix = Rotation.from_quat(quat).as_matrix()
228
+ transform_matrix = np.eye(4)
229
+ transform_matrix[:3, :3] = rotation_matrix * scale
230
+ transform_matrix[:3, 3] = translation
231
+ return transform_matrix
232
+
233
+
234
+ def write_output(primitives, name):
235
+ out_json = {}
236
+ out_json['operation'] = 0
237
+ out_json['type'] = 1
238
+ out_json['scene_id'] = None
239
+
240
+ new_group = []
241
+ model_scene = trimesh.Scene()
242
+ color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
243
+ color_map = (np.array(color_map) * 255).astype("uint8")
244
+ for idx, (scale, rotation, translation, type_code) in enumerate(zip(
245
+ primitives['scale'].squeeze().cpu().numpy(),
246
+ primitives['rotation'].squeeze().cpu().numpy(),
247
+ primitives['translation'].squeeze().cpu().numpy(),
248
+ primitives['type_code'].squeeze().cpu().numpy()
249
+ )):
250
+ if type_code == -1:
251
+ break
252
+ bs_name = CODE_SHAPE[type_code]
253
+ new_block = {}
254
+ new_block['type_id'] = shapename_map[bs_name]
255
+ new_block['data'] = {}
256
+ new_block['data']['location'] = translation.tolist()
257
+ new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
258
+ new_block['data']['scale'] = scale.tolist()
259
+ new_block['data']['color'] = ['808080']
260
+ new_group.append(new_block)
261
+
262
+ trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
263
+ bs = mesh_bs[bs_name].copy().apply_transform(trans)
264
+ new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
265
+ bs.visual.vertex_colors[:, :3] = new_vertex_colors
266
+ vertices = bs.vertices.copy()
267
+ vertices[:, 1] = bs.vertices[:, 2]
268
+ vertices[:, 2] = -bs.vertices[:, 1]
269
+ bs.vertices = vertices
270
+ model_scene.add_geometry(bs)
271
+ out_json['group'] = new_group
272
+
273
+ json_path = os.path.join(LOG_PATH, f'output_{name}.json')
274
+ with open(json_path, 'w') as json_file:
275
+ json.dump(out_json, json_file, indent=4)
276
+
277
+ glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
278
+ model_scene.export(glb_path)
279
+
280
+ return glb_path, out_json
281
+
282
+
283
+ @torch.no_grad()
284
+ def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
285
+ t1 = time.time()
286
+ set_seed(sample_seed)
287
+ input_mesh = trimesh.load(input_3d, force='mesh')
288
+
289
+ # scale mesh
290
+ vertices = input_mesh.vertices
291
+ bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
292
+ vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
293
+ vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
294
+ input_mesh.vertices = vertices
295
+
296
+ pc_list, mesh_list = process_mesh_to_surface_pc(
297
+ [input_mesh],
298
+ marching_cubes=do_marching_cubes,
299
+ dilated_offset=dilated_offset
300
+ )
301
+ pc_normal = pc_list[0] # 10000, 6
302
+ mesh = mesh_list[0]
303
+
304
+ pc_coor = pc_normal[:, :3]
305
+ normals = pc_normal[:, 3:]
306
+
307
+ if dilated_offset > 0:
308
+ # scale mesh and pc
309
+ vertices = mesh.vertices
310
+ bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
311
+ vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
312
+ vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
313
+ mesh.vertices = vertices
314
+ pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
315
+ pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6
316
+
317
+ input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
318
+ mesh.export(input_save_name)
319
+
320
+ assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
321
+ normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
322
+
323
+ input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
324
+
325
+ with accelerator.autocast():
326
+ if postprocess == 'postprocess1':
327
+ recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
328
+ else:
329
+ recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
330
+
331
+ output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4])
332
+
333
+ return input_save_name, output_glb, output_json
334
+
335
+
336
+ import gradio as gr
337
+
338
+ @spaces.GPU
339
+ def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
340
+ print(f"processing: {input_3d}")
341
+ # try:
342
+ preprocess_model_obj, output_model_obj, output_model_json = do_inference(
343
+ input_3d,
344
+ dilated_offset=dilated_offset,
345
+ do_marching_cubes=do_marching_cubes,
346
+ postprocess=postprocess_method
347
+ )
348
+ return output_model_obj
349
+ # except Exception as e:
350
+ # return f"Error processing file: {str(e)}"
351
+
352
+ # Title and reminder placeholders
353
+ title = "3D Model Processing Demo"
354
+ reminder = "Please upload your 3D model file and adjust parameters as needed."
355
+
356
+ with gr.Blocks(title=title) as demo:
357
+ # Title section
358
+ gr.Markdown(f"# {title}")
359
+ gr.Markdown(reminder)
360
+
361
+ with gr.Row():
362
+ with gr.Column():
363
+ # Input components
364
+ input_3d = gr.Model3D(label="Upload 3D Model File")
365
+ dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
366
+ do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
367
+ submit_btn = gr.Button("Process Model")
368
+
369
+ with gr.Column():
370
+ # Output components
371
+ output = gr.Model3D(label="Primitive Assembly Prediction")
372
+
373
+ submit_btn.click(
374
+ fn=process_3d_model,
375
+ inputs=[input_3d, dilated_offset, do_marching_cubes],
376
+ outputs=output
377
+ )
378
+
379
+
380
+ # Prepare examples properly
381
+ example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ] # Note: wrapped in list and filtered for GLB
382
+
383
+ example = gr.Examples(
384
+ examples=example_files,
385
+ inputs=[input_3d], # Only include the Model3D input
386
+ examples_per_page=14,
387
+ )
388
+
389
+ if __name__ == "__main__":
390
+ demo.launch()
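
Not part of the commit: a minimal sketch of driving the same inference path that the Gradio callback uses, assuming app.py's module-level setup (checkpoint download, config parsing, and moving the transformer to a CUDA device) has already run. The argument values mirror the defaults exposed in the UI; postprocess='postprocess1' selects the generate_w_recon_loss decoding branch.

import glob

demo_files = sorted(glob.glob('./data/demo_glb/*.glb'))
if demo_files:
    # do_inference returns the preprocessed mesh path, the assembled GLB path,
    # and the JSON description of the predicted primitives.
    processed_path, glb_path, primitives_json = do_inference(
        demo_files[0],
        dilated_offset=0.015,       # default offered by the Gradio UI
        do_marching_cubes=True,     # remesh to a watertight surface first
        postprocess='postprocess1',
    )
    print(glb_path, len(primitives_json['group']), 'primitives')
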
assets/teaser.jpg ADDED

Git LFS Details

  • SHA256: ae89c8078e3379126ff0f0723ee4598a99ac7515d8ac70d62f479b243d80b792
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
configs/infer.yml ADDED
@@ -0,0 +1,52 @@
1
+ dataset:
2
+ name: base
3
+ pc_dir: ./data/test_pc
4
+ bs_dir: data/basic_shapes_norm
5
+ max_length: 144
6
+ range_scale: [0, 1]
7
+ range_rotation: [-180, 180]
8
+ range_translation: [-1, 1]
9
+ rotation_type: euler
10
+ pc_format: pn
11
+ model:
12
+ attn_depth: 6
13
+ attn_heads: 6
14
+ bin_smooth_blur_sigma: -1
15
+ bs_pc_dir: data/basic_shapes_norm_pc10000
16
+ coarse_pre_gateloop_depth: 3
17
+ continuous_range_rotation:
18
+ - -181
19
+ - 181
20
+ continuous_range_scale:
21
+ - 0
22
+ - 1
23
+ continuous_range_translation:
24
+ - -1
25
+ - 1
26
+ dim: 768
27
+ dim_rotation_embed: 16
28
+ dim_scale_embed: 16
29
+ dim_translation_embed: 16
30
+ dim_type_embed: 48
31
+ dropout: 0.0
32
+ embed_order: ctrs
33
+ gateloop_use_heinsen: false
34
+ loss_weight:
35
+ eos: 1.0
36
+ reconstruction: 1.0
37
+ rotation: 1.0
38
+ scale: 1.0
39
+ translation: 1.0
40
+ type: 1.0
41
+ max_primitive_len: 144
42
+ name: discrete
43
+ num_discrete_rotation: 181
44
+ num_discrete_scale: 128
45
+ num_discrete_translation: 128
46
+ num_type: 3
47
+ shape_cond_with_cat: true
48
+ shape_cond_with_cross_attn: false
49
+ shape_cond_with_film: false
50
+ shape_condition_dim: 768
51
+ shape_condition_len: 77
52
+ shape_condition_model_type: michelangelo
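
Not part of the commit: a short sketch of how this config is consumed. app.py reads it with PyYAML and hands the model block, minus its name key (which selects PrimitiveTransformerDiscrete), to the model constructor as keyword arguments.

import yaml

with open('./configs/infer.yml') as fp:
    cfg = yaml.load(fp, Loader=yaml.FullLoader)

model_cfg = dict(cfg['model'])
name = model_cfg.pop('name')        # 'discrete' -> PrimitiveTransformerDiscrete
print(name, sorted(model_cfg))      # remaining keys become constructor kwargs
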
data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply ADDED
Binary file (10.1 kB).
 
data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply ADDED
Binary file (4.96 kB).
 
data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply ADDED
Binary file (27.5 kB).
 
data/basic_shapes_norm/basic_shapes.json ADDED
@@ -0,0 +1,89 @@
1
+ {
2
+ "SM_GR_BS_CubeBevel_001.ply": {
3
+ "name": "SM_GR_BS_CubeBevel_001.ply",
4
+ "tform_bs_to_normalized": [
5
+ [
6
+ 0.02,
7
+ 0.0,
8
+ 0.0,
9
+ 0.0
10
+ ],
11
+ [
12
+ 0.0,
13
+ 0.02,
14
+ 0.0,
15
+ 9.701276818911235e-18
16
+ ],
17
+ [
18
+ 0.0,
19
+ 0.0,
20
+ 0.019999999999999997,
21
+ -0.9999999999999999
22
+ ],
23
+ [
24
+ 0.0,
25
+ 0.0,
26
+ 0.0,
27
+ 1.0
28
+ ]
29
+ ]
30
+ },
31
+ "SM_GR_BS_CylinderSharp_001.ply": {
32
+ "name": "SM_GR_BS_CylinderSharp_001.ply",
33
+ "tform_bs_to_normalized": [
34
+ [
35
+ 0.006666668023003748,
36
+ 0.0,
37
+ 0.0,
38
+ -2.0345056221459462e-07
39
+ ],
40
+ [
41
+ 0.0,
42
+ 0.006666667683919426,
43
+ 0.0,
44
+ -5.086263794939386e-08
45
+ ],
46
+ [
47
+ 0.0,
48
+ 0.0,
49
+ 0.006666665445429783,
50
+ -0.9999998370794186
51
+ ],
52
+ [
53
+ 0.0,
54
+ 0.0,
55
+ 0.0,
56
+ 1.0
57
+ ]
58
+ ]
59
+ },
60
+ "SM_GR_BS_SphereSharp_001.ply": {
61
+ "name": "SM_GR_BS_SphereSharp_001.ply",
62
+ "tform_bs_to_normalized": [
63
+ [
64
+ 0.006666666666666667,
65
+ 0.0,
66
+ 0.0,
67
+ 0.0
68
+ ],
69
+ [
70
+ 0.0,
71
+ 0.006666666666666667,
72
+ 0.0,
73
+ 0.0
74
+ ],
75
+ [
76
+ 0.0,
77
+ 0.0,
78
+ 0.006666666666666667,
79
+ -1.0
80
+ ],
81
+ [
82
+ 0.0,
83
+ 0.0,
84
+ 0.0,
85
+ 1.0
86
+ ]
87
+ ]
88
+ }
89
+ }
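
Not part of the commit: a sketch of reading one tform_bs_to_normalized entry. Each value is a 4x4 homogeneous transform with the per-axis scale on the diagonal and the translation in the last column; the file path and key come from this JSON, while the interpretation of the matrix is inferred from its name.

import json
import numpy as np

with open('data/basic_shapes_norm/basic_shapes.json') as fp:
    basic_shapes = json.load(fp)

tform = np.array(basic_shapes['SM_GR_BS_CubeBevel_001.ply']['tform_bs_to_normalized'])
print(tform.shape)         # (4, 4)
print(np.diag(tform)[:3])  # per-axis scale (~0.02 for the bevelled cube)
print(tform[:3, 3])        # translation component
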
data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba980c1fb389e30783f09b07d35e788e08a97776d933b8bfd346147c9a7e86a0
3
+ size 510265
data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab8fb7aa7ec39237474d0a6e77da1d7070742f61af21e6b44dc9998fac1913cc
3
+ size 510265
data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8765da7294292422d077267c1b71b9ea055f831aab3840d869656632ee6e8569
3
+ size 510265
data/demo_glb/barbell.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a9b9c124c321d6d18342b12407fc7327bdd56c8720d317e7b8694c10c851936
3
+ size 769528
data/demo_glb/book.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be526e9bca2ce3a74387f2dde7f6a25c9502a7c6d4f9fc671b244d09c18a9d94
3
+ size 5369916
data/demo_glb/bunny.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f62b2169b7cda3662de660d0b2e8ce2a1ccfe3bd186f243d890440e8cf7a0766
3
+ size 27518016
data/demo_glb/desk.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c8fd8041f1e870ba285572f3fb3e129107678ab9a311524a8376cc404cc332e
3
+ size 33679548
data/demo_glb/man.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:063a56d0a56d3866cf36170bbafad93924fd350882ac0f69e727cb43dc203351
3
+ size 31784
data/demo_glb/micky.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc6410c2f2a588c5b064f9255a1c1657a9dff061ae6f7342df693c80eef0c69d
3
+ size 294576
data/demo_glb/pac.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc39816cc71440fbc31d99c24f24ed42c25daf7319ba07b8dd3e34c1ea083578
3
+ size 274004
data/demo_glb/robot.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf201fbe21d73428f88e4a7e428849148ebd500cc4ce6ac3929638a53c5376ae
3
+ size 28116940
data/demo_glb/rocket.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29601d218e0d51a4dced2f1ed2a898e80f0d223b1e04212d45b0dda4ad670d1c
3
+ size 1426588
data/demo_glb/sheep.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:232b1303e56ec1682536c72bc9409585930985492dcbdfa101cdfb96d0b4fbf2
3
+ size 28732
data/demo_glb/shelf.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d63916c1ef1d5b2fc4d56e20c76316bd977f93819323f73e7f5e1c59df21e284
3
+ size 3091336
data/demo_glb/table.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:743c96d7aa1bef88576c5f28d5e06144f93e27a2e6ea5ef8bd85669d1213af9f
3
+ size 20093692
data/demo_glb/vent.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98d8d0d2c8d164fc75361d7d1408e9102f9148482c0343e0cc87d21950e20ab1
3
+ size 1785468
data/demo_glb/walkman.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:992355cf45609881223561d1081a05483c2bad488ed7148f87243259aa36be1b
3
+ size 158156
pre-requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ --extra-index-url https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html
3
+ torch==2.2.0
4
+ torchvision==0.17.0
5
+ dgl
6
+ accelerate
7
+ beartype
8
+ einops
9
+ gateloop_transformer
10
+ matplotlib
11
+ scikit-learn
12
+ pandas
13
+ pytorch_custom_utils
14
+ gradio
15
+ pydantic==2.10.6
16
+ x_transformers
17
+ torch_redstone
18
+ torchdata==0.9.0
19
+ toolz
20
+ environs
21
+ jaxtyping
22
+ omegaconf
23
+ ema_pytorch
24
+ local_attention==1.9.15
25
+ taylor_series_linear_attention
26
+ transformers
27
+ vector_quantize_pytorch
28
+ open3d
29
+ trimesh
30
+ pytorch_lightning
31
+ scikit-image
32
+ opencv-python
33
+ mesh2sdf
34
+ seaborn
35
+ mesh_to_sdf
36
+ point_cloud_utils
primitive_anything/__init__.py ADDED
File without changes
primitive_anything/michelangelo/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ from torch import nn
6
+
7
+ from .utils.misc import instantiate_from_config
8
+ from ..utils import default, exists
9
+
10
+
11
+ def load_model():
12
+ model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml"))
13
+ # print(model_config)
14
+ if hasattr(model_config, "model"):
15
+ model_config = model_config.model
16
+ ckpt_path = "./ckpt/shapevae-256.ckpt"
17
+
18
+ model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
19
+ # model = model.cuda()
20
+ model = model.eval()
21
+
22
+ return model
23
+
24
+
25
+ class ShapeConditioner(nn.Module):
26
+ def __init__(
27
+ self,
28
+ *,
29
+ dim_latent = None
30
+ ):
31
+ super().__init__()
32
+ self.model = load_model()
33
+
34
+ self.dim_model_out = 768
35
+ dim_latent = default(dim_latent, self.dim_model_out)
36
+ self.dim_latent = dim_latent
37
+
38
+ def forward(
39
+ self,
40
+ shape = None,
41
+ shape_embed = None,
42
+ ):
43
+ assert exists(shape) ^ exists(shape_embed)
44
+
45
+ if not exists(shape_embed):
46
+ point_feature = self.model.encode_latents(shape)
47
+ shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
48
+ shape_head = point_feature[:, 0:1]
49
+ shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)
50
+ # shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp
51
+ return shape_head, shape_embed
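
Not part of the commit: a rough sketch of querying ShapeConditioner with a surface point cloud. It assumes the shapevae-256 checkpoint is present where load_model() expects it (./ckpt/shapevae-256.ckpt), that a CUDA device is available, and that the input follows the (batch, points, xyz+normal) layout app.py prepares; the random tensor below is a placeholder, not meaningful geometry.

import torch

conditioner = ShapeConditioner().cuda().eval()
pc = torch.randn(1, 10000, 6, device='cuda')   # placeholder xyz + normal samples
with torch.no_grad():
    shape_head, shape_embed = conditioner(shape=pc)
print(shape_head.shape, shape_embed.shape)
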
primitive_anything/michelangelo/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
primitive_anything/michelangelo/data/templates.json ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "shape": [
3
+ "a point cloud model of {}.",
4
+ "There is a {} in the scene.",
5
+ "There is the {} in the scene.",
6
+ "a photo of a {} in the scene.",
7
+ "a photo of the {} in the scene.",
8
+ "a photo of one {} in the scene.",
9
+ "itap of a {}.",
10
+ "itap of my {}.",
11
+ "itap of the {}.",
12
+ "a photo of a {}.",
13
+ "a photo of my {}.",
14
+ "a photo of the {}.",
15
+ "a photo of one {}.",
16
+ "a photo of many {}.",
17
+ "a good photo of a {}.",
18
+ "a good photo of the {}.",
19
+ "a bad photo of a {}.",
20
+ "a bad photo of the {}.",
21
+ "a photo of a nice {}.",
22
+ "a photo of the nice {}.",
23
+ "a photo of a cool {}.",
24
+ "a photo of the cool {}.",
25
+ "a photo of a weird {}.",
26
+ "a photo of the weird {}.",
27
+ "a photo of a small {}.",
28
+ "a photo of the small {}.",
29
+ "a photo of a large {}.",
30
+ "a photo of the large {}.",
31
+ "a photo of a clean {}.",
32
+ "a photo of the clean {}.",
33
+ "a photo of a dirty {}.",
34
+ "a photo of the dirty {}.",
35
+ "a bright photo of a {}.",
36
+ "a bright photo of the {}.",
37
+ "a dark photo of a {}.",
38
+ "a dark photo of the {}.",
39
+ "a photo of a hard to see {}.",
40
+ "a photo of the hard to see {}.",
41
+ "a low resolution photo of a {}.",
42
+ "a low resolution photo of the {}.",
43
+ "a cropped photo of a {}.",
44
+ "a cropped photo of the {}.",
45
+ "a close-up photo of a {}.",
46
+ "a close-up photo of the {}.",
47
+ "a jpeg corrupted photo of a {}.",
48
+ "a jpeg corrupted photo of the {}.",
49
+ "a blurry photo of a {}.",
50
+ "a blurry photo of the {}.",
51
+ "a pixelated photo of a {}.",
52
+ "a pixelated photo of the {}.",
53
+ "a black and white photo of the {}.",
54
+ "a black and white photo of a {}",
55
+ "a plastic {}.",
56
+ "the plastic {}.",
57
+ "a toy {}.",
58
+ "the toy {}.",
59
+ "a plushie {}.",
60
+ "the plushie {}.",
61
+ "a cartoon {}.",
62
+ "the cartoon {}.",
63
+ "an embroidered {}.",
64
+ "the embroidered {}.",
65
+ "a painting of the {}.",
66
+ "a painting of a {}."
67
+ ]
68
+
69
+ }
primitive_anything/michelangelo/data/transforms.py ADDED
@@ -0,0 +1,407 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ import numpy as np
5
+ import warnings
6
+ import random
7
+ from omegaconf.listconfig import ListConfig
8
+ from webdataset import pipelinefilter
9
+ import torch
10
+ import torchvision.transforms.functional as TVF
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
13
+ from typing import Sequence
14
+
15
+ from ..utils import instantiate_from_config
16
+
17
+
18
+ def _uid_buffer_pick(buf_dict, rng):
19
+ uid_keys = list(buf_dict.keys())
20
+ selected_uid = rng.choice(uid_keys)
21
+ buf = buf_dict[selected_uid]
22
+
23
+ k = rng.randint(0, len(buf) - 1)
24
+ sample = buf[k]
25
+ buf[k] = buf[-1]
26
+ buf.pop()
27
+
28
+ if len(buf) == 0:
29
+ del buf_dict[selected_uid]
30
+
31
+ return sample
32
+
33
+
34
+ def _add_to_buf_dict(buf_dict, sample):
35
+ key = sample["__key__"]
36
+ uid, uid_sample_id = key.split("_")
37
+ if uid not in buf_dict:
38
+ buf_dict[uid] = []
39
+ buf_dict[uid].append(sample)
40
+
41
+ return buf_dict
42
+
43
+
44
+ def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
45
+ """Shuffle the data in the stream.
46
+
47
+ This uses a buffer of size `bufsize`. Shuffling at
48
+ startup is less random; this is traded off against
49
+ yielding samples quickly.
50
+
51
+ data: iterator
52
+ bufsize: buffer size for shuffling
53
+ returns: iterator
54
+ rng: either random module or random.Random instance
55
+
56
+ """
57
+ if rng is None:
58
+ rng = random.Random(int((os.getpid() + time.time()) * 1e9))
59
+ initial = min(initial, bufsize)
60
+ buf_dict = dict()
61
+ current_samples = 0
62
+ for sample in data:
63
+ _add_to_buf_dict(buf_dict, sample)
64
+ current_samples += 1
65
+
66
+ if current_samples < bufsize:
67
+ try:
68
+ _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708
69
+ current_samples += 1
70
+ except StopIteration:
71
+ pass
72
+
73
+ if current_samples >= initial:
74
+ current_samples -= 1
75
+ yield _uid_buffer_pick(buf_dict, rng)
76
+
77
+ while current_samples > 0:
78
+ current_samples -= 1
79
+ yield _uid_buffer_pick(buf_dict, rng)
80
+
81
+
82
+ uid_shuffle = pipelinefilter(_uid_shuffle)
83
+
84
+
85
+ class RandomSample(object):
86
+ def __init__(self,
87
+ num_volume_samples: int = 1024,
88
+ num_near_samples: int = 1024):
89
+
90
+ super().__init__()
91
+
92
+ self.num_volume_samples = num_volume_samples
93
+ self.num_near_samples = num_near_samples
94
+
95
+ def __call__(self, sample):
96
+ rng = np.random.default_rng()
97
+
98
+ # 1. sample surface input
99
+ total_surface = sample["surface"]
100
+ ind = rng.choice(total_surface.shape[0], replace=False)
101
+ surface = total_surface[ind]
102
+
103
+ # 2. sample volume/near geometric points
104
+ vol_points = sample["vol_points"]
105
+ vol_label = sample["vol_label"]
106
+ near_points = sample["near_points"]
107
+ near_label = sample["near_label"]
108
+
109
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
110
+ vol_points = vol_points[ind]
111
+ vol_label = vol_label[ind]
112
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
113
+
114
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
115
+ near_points = near_points[ind]
116
+ near_label = near_label[ind]
117
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
118
+
119
+ # concat sampled volume and near points
120
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
121
+
122
+ sample = {
123
+ "surface": surface,
124
+ "geo_points": geo_points
125
+ }
126
+
127
+ return sample
128
+
129
+
130
+ class SplitRandomSample(object):
131
+ def __init__(self,
132
+ use_surface_sample: bool = False,
133
+ num_surface_samples: int = 4096,
134
+ num_volume_samples: int = 1024,
135
+ num_near_samples: int = 1024):
136
+
137
+ super().__init__()
138
+
139
+ self.use_surface_sample = use_surface_sample
140
+ self.num_surface_samples = num_surface_samples
141
+ self.num_volume_samples = num_volume_samples
142
+ self.num_near_samples = num_near_samples
143
+
144
+ def __call__(self, sample):
145
+
146
+ rng = np.random.default_rng()
147
+
148
+ # 1. sample surface input
149
+ surface = sample["surface"]
150
+
151
+ if self.use_surface_sample:
152
+ replace = surface.shape[0] < self.num_surface_samples
153
+ ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
154
+ surface = surface[ind]
155
+
156
+ # 2. sample volume/near geometric points
157
+ vol_points = sample["vol_points"]
158
+ vol_label = sample["vol_label"]
159
+ near_points = sample["near_points"]
160
+ near_label = sample["near_label"]
161
+
162
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
163
+ vol_points = vol_points[ind]
164
+ vol_label = vol_label[ind]
165
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
166
+
167
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
168
+ near_points = near_points[ind]
169
+ near_label = near_label[ind]
170
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
171
+
172
+ # concat sampled volume and near points
173
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
174
+
175
+ sample = {
176
+ "surface": surface,
177
+ "geo_points": geo_points
178
+ }
179
+
180
+ return sample
181
+
182
+
183
+ class FeatureSelection(object):
184
+
185
+ VALID_SURFACE_FEATURE_DIMS = {
186
+ "none": [0, 1, 2], # xyz
187
+ "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal
188
+ "normal": [0, 1, 2, 6, 7, 8]
189
+ }
190
+
191
+ def __init__(self, surface_feature_type: str):
192
+
193
+ self.surface_feature_type = surface_feature_type
194
+ self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
195
+
196
+ def __call__(self, sample):
197
+ sample["surface"] = sample["surface"][:, self.surface_dims]
198
+ return sample
199
+
200
+
201
+ class AxisScaleTransform(object):
202
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
203
+ assert isinstance(interval, (tuple, list, ListConfig))
204
+ self.interval = interval
205
+ self.min_val = interval[0]
206
+ self.max_val = interval[1]
207
+ self.inter_size = interval[1] - interval[0]
208
+ self.jitter = jitter
209
+ self.jitter_scale = jitter_scale
210
+
211
+ def __call__(self, sample):
212
+
213
+ surface = sample["surface"][..., 0:3]
214
+ geo_points = sample["geo_points"][..., 0:3]
215
+
216
+ scaling = torch.rand(1, 3) * self.inter_size + self.min_val
217
+ # print(scaling)
218
+ surface = surface * scaling
219
+ geo_points = geo_points * scaling
220
+
221
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
222
+ surface *= scale
223
+ geo_points *= scale
224
+
225
+ if self.jitter:
226
+ surface += self.jitter_scale * torch.randn_like(surface)
227
+ surface.clamp_(min=-1.015, max=1.015)
228
+
229
+ sample["surface"][..., 0:3] = surface
230
+ sample["geo_points"][..., 0:3] = geo_points
231
+
232
+ return sample
233
+
234
+
235
+ class ToTensor(object):
236
+
237
+ def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
238
+ self.tensor_keys = tensor_keys
239
+
240
+ def __call__(self, sample):
241
+ for key in self.tensor_keys:
242
+ if key not in sample:
243
+ continue
244
+
245
+ sample[key] = torch.tensor(sample[key], dtype=torch.float32)
246
+
247
+ return sample
248
+
249
+
250
+ class AxisScale(object):
251
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
252
+ assert isinstance(interval, (tuple, list, ListConfig))
253
+ self.interval = interval
254
+ self.jitter = jitter
255
+ self.jitter_scale = jitter_scale
256
+
257
+ def __call__(self, surface, *args):
258
+ scaling = torch.rand(1, 3) * 0.5 + 0.75
259
+ # print(scaling)
260
+ surface = surface * scaling
261
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
262
+ surface *= scale
263
+
264
+ args_outputs = []
265
+ for _arg in args:
266
+ _arg = _arg * scaling * scale
267
+ args_outputs.append(_arg)
268
+
269
+ if self.jitter:
270
+ surface += self.jitter_scale * torch.randn_like(surface)
271
+ surface.clamp_(min=-1, max=1)
272
+
273
+ if len(args) == 0:
274
+ return surface
275
+ else:
276
+ return surface, *args_outputs
277
+
278
+
279
+ class RandomResize(torch.nn.Module):
280
+ """Apply randomly Resize with a given probability."""
281
+
282
+ def __init__(
283
+ self,
284
+ size,
285
+ resize_radio=(0.5, 1),
286
+ allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
287
+ interpolation=InterpolationMode.BICUBIC,
288
+ max_size=None,
289
+ antialias=None,
290
+ ):
291
+ super().__init__()
292
+ if not isinstance(size, (int, Sequence)):
293
+ raise TypeError(f"Size should be int or sequence. Got {type(size)}")
294
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
295
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
296
+
297
+ self.size = size
298
+ self.max_size = max_size
299
+ # Backward compatibility with integer value
300
+ if isinstance(interpolation, int):
301
+ warnings.warn(
302
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
303
+ "Please use InterpolationMode enum."
304
+ )
305
+ interpolation = _interpolation_modes_from_int(interpolation)
306
+
307
+ self.interpolation = interpolation
308
+ self.antialias = antialias
309
+
310
+ self.resize_radio = resize_radio
311
+ self.allow_resize_interpolations = allow_resize_interpolations
312
+
313
+ def random_resize_params(self):
314
+ radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
315
+
316
+ if isinstance(self.size, int):
317
+ size = int(self.size * radio)
318
+ elif isinstance(self.size, Sequence):
319
+ size = list(self.size)
320
+ size = (int(size[0] * radio), int(size[1] * radio))
321
+ else:
322
+ raise RuntimeError()
323
+
324
+ interpolation = self.allow_resize_interpolations[
325
+ torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
326
+ ]
327
+ return size, interpolation
328
+
329
+ def forward(self, img):
330
+ size, interpolation = self.random_resize_params()
331
+ img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
332
+ img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
333
+ return img
334
+
335
+ def __repr__(self) -> str:
336
+ detail = f"(size={self.size}, interpolation={self.interpolation.value},"
337
+ detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
338
+ return f"{self.__class__.__name__}{detail}"
339
+
340
+
341
+ class Compose(object):
342
+ """Composes several transforms together. This transform does not support torchscript.
343
+ Please, see the note below.
344
+
345
+ Args:
346
+ transforms (list of ``Transform`` objects): list of transforms to compose.
347
+
348
+ Example:
349
+ >>> transforms.Compose([
350
+ >>> transforms.CenterCrop(10),
351
+ >>> transforms.ToTensor(),
352
+ >>> ])
353
+
354
+ .. note::
355
+ In order to script the transformations, please use ``torch.nn.Sequential`` as below.
356
+
357
+ >>> transforms = torch.nn.Sequential(
358
+ >>> transforms.CenterCrop(10),
359
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
360
+ >>> )
361
+ >>> scripted_transforms = torch.jit.script(transforms)
362
+
363
+ Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
364
+ `lambda` functions or ``PIL.Image``.
365
+
366
+ """
367
+
368
+ def __init__(self, transforms):
369
+ self.transforms = transforms
370
+
371
+ def __call__(self, *args):
372
+ for t in self.transforms:
373
+ args = t(*args)
374
+ return args
375
+
376
+ def __repr__(self):
377
+ format_string = self.__class__.__name__ + '('
378
+ for t in self.transforms:
379
+ format_string += '\n'
380
+ format_string += ' {0}'.format(t)
381
+ format_string += '\n)'
382
+ return format_string
383
+
384
+
385
+ def identity(*args, **kwargs):
386
+ if len(args) == 1:
387
+ return args[0]
388
+ else:
389
+ return args
390
+
391
+
392
+ def build_transforms(cfg):
393
+
394
+ if cfg is None:
395
+ return identity
396
+
397
+ transforms = []
398
+
399
+ for transform_name, cfg_instance in cfg.items():
400
+ transform_instance = instantiate_from_config(cfg_instance)
401
+ transforms.append(transform_instance)
402
+ print(f"Build transform: {transform_instance}")
403
+
404
+ transforms = Compose(transforms)
405
+
406
+ return transforms
407
+
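
Not part of the commit: a small sketch of AxisScale, which rescales a surface tensor (and any companion point sets passed with it) by a random per-axis factor, renormalizes everything back into the unit cube, then jitters and clamps the surface points. The tensors below are placeholders.

import torch

aug = AxisScale(interval=(0.75, 1.25), jitter=True, jitter_scale=0.005)
surface = torch.rand(4096, 3) * 2 - 1   # placeholder surface samples in [-1, 1]
queries = torch.rand(1024, 3) * 2 - 1   # companion query points
surface_aug, queries_aug = aug(surface, queries)
print(surface_aug.shape, queries_aug.shape)
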
primitive_anything/michelangelo/data/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
+ def worker_init_fn(_):
8
+ worker_info = torch.utils.data.get_worker_info()
9
+ worker_id = worker_info.id
10
+
11
+ # dataset = worker_info.dataset
12
+ # split_size = dataset.num_records // worker_info.num_workers
13
+ # # reset num_records to the true number to retain reliable length information
14
+ # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15
+ # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16
+ # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17
+
18
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
19
+
20
+
21
+ def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22
+ """
23
+
24
+ Args:
25
+ samples (list[dict]):
26
+ combine_tensors:
27
+ combine_scalars:
28
+
29
+ Returns:
30
+
31
+ """
32
+
33
+ result = {}
34
+
35
+ keys = samples[0].keys()
36
+
37
+ for key in keys:
38
+ result[key] = []
39
+
40
+ for sample in samples:
41
+ for key in keys:
42
+ val = sample[key]
43
+ result[key].append(val)
44
+
45
+ for key in keys:
46
+ val_list = result[key]
47
+ if isinstance(val_list[0], (int, float)):
48
+ if combine_scalars:
49
+ result[key] = np.array(result[key])
50
+
51
+ elif isinstance(val_list[0], torch.Tensor):
52
+ if combine_tensors:
53
+ result[key] = torch.stack(val_list)
54
+
55
+ elif isinstance(val_list[0], np.ndarray):
56
+ if combine_tensors:
57
+ result[key] = np.stack(val_list)
58
+
59
+ return result
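
Not part of the commit: a quick sketch of collation_fn, which gathers values per key across samples, stacking tensors into a batch and converting scalar fields to numpy arrays. The sample dicts are placeholders.

import torch

samples = [
    {'surface': torch.zeros(4096, 6), 'label': 1},
    {'surface': torch.ones(4096, 6), 'label': 0},
]
batch = collation_fn(samples)
print(batch['surface'].shape, batch['label'])   # torch.Size([2, 4096, 6]) [1 0]
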
primitive_anything/michelangelo/graphics/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
primitive_anything/michelangelo/graphics/primitives/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .volume import generate_dense_grid_points
4
+
5
+ from .mesh import (
6
+ MeshOutput,
7
+ save_obj,
8
+ savemeshtes2
9
+ )
primitive_anything/michelangelo/graphics/primitives/mesh.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing import Optional
8
+
9
+ import trimesh
10
+
11
+
12
+ def save_obj(pointnp_px3, facenp_fx3, fname):
13
+ fid = open(fname, "w")
14
+ write_str = ""
15
+ for pidx, p in enumerate(pointnp_px3):
16
+ pp = p
17
+ write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18
+
19
+ for i, f in enumerate(facenp_fx3):
20
+ f1 = f + 1
21
+ write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22
+ fid.write(write_str)
23
+ fid.close()
24
+ return
25
+
26
+
27
+ def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28
+ fol, na = os.path.split(fname)
29
+ na, _ = os.path.splitext(na)
30
+
31
+ matname = "%s/%s.mtl" % (fol, na)
32
+ fid = open(matname, "w")
33
+ fid.write("newmtl material_0\n")
34
+ fid.write("Kd 1 1 1\n")
35
+ fid.write("Ka 0 0 0\n")
36
+ fid.write("Ks 0.4 0.4 0.4\n")
37
+ fid.write("Ns 10\n")
38
+ fid.write("illum 2\n")
39
+ fid.write("map_Kd %s.png\n" % na)
40
+ fid.close()
41
+ ####
42
+
43
+ fid = open(fname, "w")
44
+ fid.write("mtllib %s.mtl\n" % na)
45
+
46
+ for pidx, p in enumerate(pointnp_px3):
47
+ pp = p
48
+ fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49
+
50
+ for pidx, p in enumerate(tcoords_px2):
51
+ pp = p
52
+ fid.write("vt %f %f\n" % (pp[0], pp[1]))
53
+
54
+ fid.write("usemtl material_0\n")
55
+ for i, f in enumerate(facenp_fx3):
56
+ f1 = f + 1
57
+ f2 = facetex_fx3[i] + 1
58
+ fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59
+ fid.close()
60
+
61
+ PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62
+ os.path.join(fol, "%s.png" % na))
63
+
64
+ return
65
+
66
+
67
+ class MeshOutput(object):
68
+
69
+ def __init__(self,
70
+ mesh_v: np.ndarray,
71
+ mesh_f: np.ndarray,
72
+ vertex_colors: Optional[np.ndarray] = None,
73
+ uvs: Optional[np.ndarray] = None,
74
+ mesh_tex_idx: Optional[np.ndarray] = None,
75
+ tex_map: Optional[np.ndarray] = None):
76
+
77
+ self.mesh_v = mesh_v
78
+ self.mesh_f = mesh_f
79
+ self.vertex_colors = vertex_colors
80
+ self.uvs = uvs
81
+ self.mesh_tex_idx = mesh_tex_idx
82
+ self.tex_map = tex_map
83
+
84
+ def contain_uv_texture(self):
85
+ return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86
+
87
+ def contain_vertex_colors(self):
88
+ return self.vertex_colors is not None
89
+
90
+ def export(self, fname):
91
+
92
+ if self.contain_uv_texture():
93
+ savemeshtes2(
94
+ self.mesh_v,
95
+ self.uvs,
96
+ self.mesh_f,
97
+ self.mesh_tex_idx,
98
+ self.tex_map,
99
+ fname
100
+ )
101
+
102
+ elif self.contain_vertex_colors():
103
+ mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104
+ mesh_obj.export(fname)
105
+
106
+ else:
107
+ save_obj(
108
+ self.mesh_v,
109
+ self.mesh_f,
110
+ fname
111
+ )
112
+
113
+
114
+
primitive_anything/michelangelo/graphics/primitives/volume.py ADDED
@@ -0,0 +1,21 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+
6
+ def generate_dense_grid_points(bbox_min: np.ndarray,
7
+ bbox_max: np.ndarray,
8
+ octree_depth: int,
9
+ indexing: str = "ij"):
10
+ length = bbox_max - bbox_min
11
+ num_cells = np.exp2(octree_depth)
12
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16
+ xyz = np.stack((xs, ys, zs), axis=-1)
17
+ xyz = xyz.reshape(-1, 3)
18
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19
+
20
+ return xyz, grid_size, length
21
+
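
Not part of the commit: a sketch of calling generate_dense_grid_points for a unit cube at octree_depth=7, which returns (2**7 + 1)**3 query points along with the grid size and the bounding-box extent.

import numpy as np

bbox_min = np.array([-1.0, -1.0, -1.0])
bbox_max = np.array([1.0, 1.0, 1.0])
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
print(xyz.shape, grid_size, length)   # (2146689, 3) [129, 129, 129] [2. 2. 2.]
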
primitive_anything/michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
primitive_anything/michelangelo/models/asl_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,483 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from einops import rearrange
14
+
15
+ from diffusers.schedulers import (
16
+ DDPMScheduler,
17
+ DDIMScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler
20
+ )
21
+
22
+ from ...utils import instantiate_from_config
23
+ # from ..tsal.tsal_base import ShapeAsLatentPLModule
24
+ from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
25
+ from .inference_utils import ddim_sample
26
+
27
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
28
+
29
+
30
+ def disabled_train(self, mode=True):
31
+ """Overwrite model.train with this function to make sure train/eval mode
32
+ does not change anymore."""
33
+ return self
34
+
35
+
36
+ class ASLDiffuser(pl.LightningModule):
37
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
38
+ # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
39
+ model: nn.Module
40
+
41
+ def __init__(self, *,
42
+ first_stage_config,
43
+ denoiser_cfg,
44
+ scheduler_cfg,
45
+ optimizer_cfg,
46
+ loss_cfg,
47
+ first_stage_key: str = "surface",
48
+ cond_stage_key: str = "image",
49
+ cond_stage_trainable: bool = True,
50
+ scale_by_std: bool = False,
51
+ z_scale_factor: float = 1.0,
52
+ ckpt_path: Optional[str] = None,
53
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
54
+
55
+ super().__init__()
56
+
57
+ self.first_stage_key = first_stage_key
58
+ self.cond_stage_key = cond_stage_key
59
+ self.cond_stage_trainable = cond_stage_trainable
60
+
61
+ # 1. initialize first stage.
62
+ # Note: the condition model contained in the first stage model.
63
+ self.first_stage_config = first_stage_config
64
+ self.first_stage_model = None
65
+ # self.instantiate_first_stage(first_stage_config)
66
+
67
+ # 2. initialize conditional stage
68
+ # self.instantiate_cond_stage(cond_stage_config)
69
+ self.cond_stage_model = {
70
+ "image": self.encode_image,
71
+ "image_unconditional_embedding": self.empty_img_cond,
72
+ "text": self.encode_text,
73
+ "text_unconditional_embedding": self.empty_text_cond,
74
+ "surface": self.encode_surface,
75
+ "surface_unconditional_embedding": self.empty_surface_cond,
76
+ }
77
+
78
+ # 3. diffusion model
79
+ self.model = instantiate_from_config(
80
+ denoiser_cfg, device=None, dtype=None
81
+ )
82
+
83
+ self.optimizer_cfg = optimizer_cfg
84
+
85
+ # 4. scheduling strategy
86
+ self.scheduler_cfg = scheduler_cfg
87
+
88
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
89
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
90
+
91
+ # 5. loss configures
92
+ self.loss_cfg = loss_cfg
93
+
94
+ self.scale_by_std = scale_by_std
95
+ if scale_by_std:
96
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
97
+ else:
98
+ self.z_scale_factor = z_scale_factor
99
+
100
+ self.ckpt_path = ckpt_path
101
+ if ckpt_path is not None:
102
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
103
+
104
+ def instantiate_first_stage(self, config):
105
+ model = instantiate_from_config(config)
106
+ self.first_stage_model = model.eval()
107
+ self.first_stage_model.train = disabled_train
108
+ for param in self.first_stage_model.parameters():
109
+ param.requires_grad = False
110
+
111
+ self.first_stage_model = self.first_stage_model.to(self.device)
112
+
113
+ # def instantiate_cond_stage(self, config):
114
+ # if not self.cond_stage_trainable:
115
+ # if config == "__is_first_stage__":
116
+ # print("Using first stage also as cond stage.")
117
+ # self.cond_stage_model = self.first_stage_model
118
+ # elif config == "__is_unconditional__":
119
+ # print(f"Training {self.__class__.__name__} as an unconditional model.")
120
+ # self.cond_stage_model = None
121
+ # # self.be_unconditional = True
122
+ # else:
123
+ # model = instantiate_from_config(config)
124
+ # self.cond_stage_model = model.eval()
125
+ # self.cond_stage_model.train = disabled_train
126
+ # for param in self.cond_stage_model.parameters():
127
+ # param.requires_grad = False
128
+ # else:
129
+ # assert config != "__is_first_stage__"
130
+ # assert config != "__is_unconditional__"
131
+ # model = instantiate_from_config(config)
132
+ # self.cond_stage_model = model
133
+
134
+ def init_from_ckpt(self, path, ignore_keys=()):
135
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
136
+
137
+ keys = list(state_dict.keys())
138
+ for k in keys:
139
+ for ik in ignore_keys:
140
+ if k.startswith(ik):
141
+ print("Deleting key {} from state_dict.".format(k))
142
+ del state_dict[k]
143
+
144
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
145
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
146
+ if len(missing) > 0:
147
+ print(f"Missing Keys: {missing}")
148
+ print(f"Unexpected Keys: {unexpected}")
149
+
150
+ @property
151
+ def zero_rank(self):
152
+ if self._trainer:
153
+ zero_rank = self.trainer.local_rank == 0
154
+ else:
155
+ zero_rank = True
156
+
157
+ return zero_rank
158
+
159
+ def configure_optimizers(self) -> Tuple[List, List]:
160
+
161
+ lr = self.learning_rate
162
+
163
+ trainable_parameters = list(self.model.parameters())
164
+ # if the conditional encoder is trainable
165
+
166
+ # if self.cond_stage_trainable:
167
+ # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
168
+ # trainable_parameters += conditioner_params
169
+ # print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
170
+
171
+ if self.optimizer_cfg is None:
172
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
173
+ schedulers = []
174
+ else:
175
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
176
+ scheduler_func = instantiate_from_config(
177
+ self.optimizer_cfg.scheduler,
178
+ max_decay_steps=self.trainer.max_steps,
179
+ lr_max=lr
180
+ )
181
+ scheduler = {
182
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
183
+ "interval": "step",
184
+ "frequency": 1
185
+ }
186
+ optimizers = [optimizer]
187
+ schedulers = [scheduler]
188
+
189
+ return optimizers, schedulers
190
+
191
+ @torch.no_grad()
192
+ def encode_text(self, text):
193
+
194
+ b = text.shape[0]
195
+ text_tokens = rearrange(text, "b t l -> (b t) l")
196
+ text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
197
+ text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
198
+ text_embed = text_embed.mean(dim=1)
199
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
200
+
201
+ return text_embed
202
+
203
+ @torch.no_grad()
204
+ def encode_image(self, img):
205
+
206
+ return self.first_stage_model.model.encode_image_embed(img)
207
+
208
+ @torch.no_grad()
209
+ def encode_surface(self, surface):
210
+
211
+ return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
212
+
213
+ @torch.no_grad()
214
+ def empty_text_cond(self, cond):
215
+
216
+ return torch.zeros_like(cond, device=cond.device)
217
+
218
+ @torch.no_grad()
219
+ def empty_img_cond(self, cond):
220
+
221
+ return torch.zeros_like(cond, device=cond.device)
222
+
223
+ @torch.no_grad()
224
+ def empty_surface_cond(self, cond):
225
+
226
+ return torch.zeros_like(cond, device=cond.device)
227
+
228
+ @torch.no_grad()
229
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
230
+
231
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
232
+ z_q = self.z_scale_factor * z_q
233
+
234
+ return z_q
235
+
236
+ @torch.no_grad()
237
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
238
+
239
+ z_q = 1. / self.z_scale_factor * z_q
240
+ latents = self.first_stage_model.decode(z_q, **kwargs)
241
+ return latents
242
+
243
+ @rank_zero_only
244
+ @torch.no_grad()
245
+ def on_train_batch_start(self, batch, batch_idx):
246
+ # only for very first batch
247
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
248
+ and batch_idx == 0 and self.ckpt_path is None:
249
+ # set rescale weight to 1./std of encodings
250
+ print("### USING STD-RESCALING ###")
251
+
252
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
253
+ z = z_q.detach()
254
+
255
+ del self.z_scale_factor
256
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
257
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
258
+
259
+ print("### USING STD-RESCALING ###")
260
+
261
+ def compute_loss(self, model_outputs, split):
262
+ """
263
+
264
+ Args:
265
+ model_outputs (dict):
266
+ - x_0:
267
+ - noise:
268
+ - noise_prior:
269
+ - noise_pred:
270
+ - noise_pred_prior:
271
+
272
+ split (str):
273
+
274
+ Returns:
275
+
276
+ """
277
+
278
+ pred = model_outputs["pred"]
279
+
280
+ if self.noise_scheduler.prediction_type == "epsilon":
281
+ target = model_outputs["noise"]
282
+ elif self.noise_scheduler.prediction_type == "sample":
283
+ target = model_outputs["x_0"]
284
+ else:
285
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
286
+
287
+ if self.loss_cfg.loss_type == "l1":
288
+ simple = F.l1_loss(pred, target, reduction="mean")
289
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
290
+ simple = F.mse_loss(pred, target, reduction="mean")
291
+ else:
292
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
293
+
294
+ total_loss = simple
295
+
296
+ loss_dict = {
297
+ f"{split}/total_loss": total_loss.clone().detach(),
298
+ f"{split}/simple": simple.detach(),
299
+ }
300
+
301
+ return total_loss, loss_dict
302
+
303
+ def forward(self, batch):
304
+ """
305
+
306
+ Args:
307
+ batch:
308
+
309
+ Returns:
310
+
311
+ """
312
+
313
+ if self.first_stage_model is None:
314
+ self.instantiate_first_stage(self.first_stage_config)
315
+
316
+ latents = self.encode_first_stage(batch[self.first_stage_key])
317
+
318
+ # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
319
+
320
+ conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
321
+
322
+ mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
323
+ conditions = conditions * mask.to(conditions)
324
+
325
+ # Sample noise that we"ll add to the latents
326
+ # [batch_size, n_token, latent_dim]
327
+ noise = torch.randn_like(latents)
328
+ bs = latents.shape[0]
329
+ # Sample a random timestep for each latent in the batch
330
+ timesteps = torch.randint(
331
+ 0,
332
+ self.noise_scheduler.config.num_train_timesteps,
333
+ (bs,),
334
+ device=latents.device,
335
+ )
336
+ timesteps = timesteps.long()
337
+ # Add noise to the latents according to the noise magnitude at each timestep
338
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
339
+
340
+ # diffusion model forward
341
+ noise_pred = self.model(noisy_z, timesteps, conditions)
342
+
343
+ diffusion_outputs = {
344
+ "x_0": noisy_z,
345
+ "noise": noise,
346
+ "pred": noise_pred
347
+ }
348
+
349
+ return diffusion_outputs
350
+
351
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
352
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
353
+ """
354
+
355
+ Args:
356
+ batch (dict): the batch sample, and it contains:
357
+ - surface (torch.FloatTensor):
358
+ - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
359
+ - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
360
+ - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
361
+ - text (list of str):
362
+
363
+ batch_idx (int):
364
+
365
+ optimizer_idx (int):
366
+
367
+ Returns:
368
+ loss (torch.FloatTensor):
369
+
370
+ """
371
+
372
+ diffusion_outputs = self(batch)
373
+
374
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
375
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
376
+
377
+ return loss
378
+
379
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
380
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
381
+ """
382
+
383
+ Args:
384
+ batch (dict): the batch sample, and it contains:
385
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
386
+ - surface_feats (torch.FloatTensor): [n_pts, c]
387
+ - text (list of str):
388
+
389
+ batch_idx (int):
390
+
391
+ optimizer_idx (int):
392
+
393
+ Returns:
394
+ loss (torch.FloatTensor):
395
+
396
+ """
397
+
398
+ diffusion_outputs = self(batch)
399
+
400
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
401
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
402
+
403
+ return loss
404
+
405
+ @torch.no_grad()
406
+ def sample(self,
407
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
408
+ sample_times: int = 1,
409
+ steps: Optional[int] = None,
410
+ guidance_scale: Optional[float] = None,
411
+ eta: float = 0.0,
412
+ return_intermediates: bool = False, **kwargs):
413
+
414
+ if self.first_stage_model is None:
415
+ self.instantiate_first_stage(self.first_stage_config)
416
+
417
+ if steps is None:
418
+ steps = self.scheduler_cfg.num_inference_steps
419
+
420
+ if guidance_scale is None:
421
+ guidance_scale = self.scheduler_cfg.guidance_scale
422
+ do_classifier_free_guidance = guidance_scale > 0
423
+
424
+ # conditional encode
425
+ xc = batch[self.cond_stage_key]
426
+ # cond = self.cond_stage_model[self.cond_stage_key](xc)
427
+ cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
428
+
429
+ if do_classifier_free_guidance:
430
+ """
431
+ Note: There are two kinds of uncond for text.
432
+ 1: using "" as uncond text; (in SAL diffusion)
433
+ 2: zeros_like(cond) as uncond text; (in MDM)
434
+ """
435
+ # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
436
+ un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
437
+ # un_cond = torch.zeros_like(cond, device=cond.device)
438
+ cond = torch.cat([un_cond, cond], dim=0)
439
+
440
+ outputs = []
441
+ latents = None
442
+
443
+ if not return_intermediates:
444
+ for _ in range(sample_times):
445
+ sample_loop = ddim_sample(
446
+ self.denoise_scheduler,
447
+ self.model,
448
+ shape=self.first_stage_model.latent_shape,
449
+ cond=cond,
450
+ steps=steps,
451
+ guidance_scale=guidance_scale,
452
+ do_classifier_free_guidance=do_classifier_free_guidance,
453
+ device=self.device,
454
+ eta=eta,
455
+ disable_prog=not self.zero_rank
456
+ )
457
+ for sample, t in sample_loop:
458
+ latents = sample
459
+ outputs.append(self.decode_first_stage(latents, **kwargs))
460
+ else:
461
+
462
+ sample_loop = ddim_sample(
463
+ self.denoise_scheduler,
464
+ self.model,
465
+ shape=self.first_stage_model.latent_shape,
466
+ cond=cond,
467
+ steps=steps,
468
+ guidance_scale=guidance_scale,
469
+ do_classifier_free_guidance=do_classifier_free_guidance,
470
+ device=self.device,
471
+ eta=eta,
472
+ disable_prog=not self.zero_rank
473
+ )
474
+
475
+ iter_size = steps // sample_times
476
+ i = 0
477
+ for sample, t in sample_loop:
478
+ latents = sample
479
+ if i % iter_size == 0 or i == steps - 1:
480
+ outputs.append(self.decode_first_stage(latents, **kwargs))
481
+ i += 1
482
+
483
+ return outputs
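The training forward pass above zeroes the condition embedding for roughly 10% of the samples so the same network also learns an unconditional prediction, which is what classifier-free guidance needs at sampling time. A minimal, self-contained sketch of that masking step (the batch size and embedding width below are illustrative assumptions, not values from this repo):

import torch

def drop_conditions(conditions: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor:
    # keep each sample's condition with probability (1 - drop_prob),
    # otherwise replace it with an all-zero "unconditional" embedding
    keep = torch.rand(conditions.shape[0], 1, 1, device=conditions.device) >= drop_prob
    return conditions * keep.to(conditions.dtype)

cond = torch.randn(4, 1, 768)                 # [bs, 1 token, width] -- illustrative only
dropped = drop_conditions(cond)
print(dropped.abs().sum(dim=(1, 2)) == 0)     # True for the rows that were dropped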
primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py ADDED
@@ -0,0 +1,104 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional
6
+ from diffusers.models.embeddings import Timesteps
7
+ import math
8
+
9
+ from ..modules.transformer_blocks import MLP
10
+ from ..modules.diffusion_transformer import UNetDiffusionTransformer
11
+
12
+
13
+ class ConditionalASLUDTDenoiser(nn.Module):
14
+
15
+ def __init__(self, *,
16
+ device: Optional[torch.device],
17
+ dtype: Optional[torch.dtype],
18
+ input_channels: int,
19
+ output_channels: int,
20
+ n_ctx: int,
21
+ width: int,
22
+ layers: int,
23
+ heads: int,
24
+ context_dim: int,
25
+ context_ln: bool = True,
26
+ skip_ln: bool = False,
27
+ init_scale: float = 0.25,
28
+ flip_sin_to_cos: bool = False,
29
+ use_checkpoint: bool = False):
30
+ super().__init__()
31
+
32
+ self.use_checkpoint = use_checkpoint
33
+
34
+ init_scale = init_scale * math.sqrt(1.0 / width)
35
+
36
+ self.backbone = UNetDiffusionTransformer(
37
+ device=device,
38
+ dtype=dtype,
39
+ n_ctx=n_ctx,
40
+ width=width,
41
+ layers=layers,
42
+ heads=heads,
43
+ skip_ln=skip_ln,
44
+ init_scale=init_scale,
45
+ use_checkpoint=use_checkpoint
46
+ )
47
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48
+ self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49
+ self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50
+
51
+ # timestep embedding
52
+ self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53
+ self.time_proj = MLP(
54
+ device=device, dtype=dtype, width=width, init_scale=init_scale
55
+ )
56
+
61
+
62
+ if context_ln:
63
+ self.context_embed = nn.Sequential(
64
+ nn.LayerNorm(context_dim, device=device, dtype=dtype),
65
+ nn.Linear(context_dim, width, device=device, dtype=dtype),
66
+ )
67
+ else:
68
+ self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
69
+
70
+ def forward(self,
71
+ model_input: torch.FloatTensor,
72
+ timestep: torch.LongTensor,
73
+ context: torch.FloatTensor):
74
+
75
+ r"""
76
+ Args:
77
+ model_input (torch.FloatTensor): [bs, n_data, c]
78
+ timestep (torch.LongTensor): [bs,]
79
+ context (torch.FloatTensor): [bs, context_tokens, c]
80
+
81
+ Returns:
82
+ sample (torch.FloatTensor): [bs, n_data, c]
83
+
84
+ """
85
+
86
+ _, n_data, _ = model_input.shape
87
+
88
+ # 1. time
89
+ t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90
+
91
+ # 2. conditions projector
92
+ context = self.context_embed(context)
93
+
94
+ # 3. denoiser
95
+ x = self.input_proj(model_input)
96
+ x = torch.cat([t_emb, context, x], dim=1)
97
+ x = self.backbone(x)
98
+ x = self.ln_post(x)
99
+ x = x[:, -n_data:]
100
+ sample = self.output_proj(x)
101
+
102
+ return sample
103
+
104
+
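As a rough shape check of the denoiser above (assuming the package imports cleanly; every hyperparameter here is made up for illustration), the output should match the latent input once the time token and context tokens are stripped off at the end:

import torch
from primitive_anything.michelangelo.models.asl_diffusion.asl_udt import ConditionalASLUDTDenoiser

denoiser = ConditionalASLUDTDenoiser(
    device=None, dtype=None,
    input_channels=64, output_channels=64,
    n_ctx=256, width=512, layers=4, heads=8,
    context_dim=768,
)
x = torch.randn(2, 256, 64)       # [bs, n_data, c] noisy latents
t = torch.randint(0, 1000, (2,))  # [bs] diffusion timesteps
ctx = torch.randn(2, 1, 768)      # [bs, context_tokens, context_dim]
assert denoiser(x, t, ctx).shape == x.shape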
primitive_anything/michelangelo/models/asl_diffusion/base.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseDenoiser(nn.Module):
8
+
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, t, context):
13
+ raise NotImplementedError
primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,393 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from diffusers.schedulers import (
14
+ DDPMScheduler,
15
+ DDIMScheduler,
16
+ KarrasVeScheduler,
17
+ DPMSolverMultistepScheduler
18
+ )
19
+
20
+ from ...utils import instantiate_from_config
21
+ from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
22
+ from .inference_utils import ddim_sample
23
+
24
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
25
+
26
+
27
+ def disabled_train(self, mode=True):
28
+ """Overwrite model.train with this function to make sure train/eval mode
29
+ does not change anymore."""
30
+ return self
31
+
32
+
33
+ class ClipASLDiffuser(pl.LightningModule):
34
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
35
+ cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
36
+ model: nn.Module
37
+
38
+ def __init__(self, *,
39
+ first_stage_config,
40
+ cond_stage_config,
41
+ denoiser_cfg,
42
+ scheduler_cfg,
43
+ optimizer_cfg,
44
+ loss_cfg,
45
+ first_stage_key: str = "surface",
46
+ cond_stage_key: str = "image",
47
+ scale_by_std: bool = False,
48
+ z_scale_factor: float = 1.0,
49
+ ckpt_path: Optional[str] = None,
50
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
51
+
52
+ super().__init__()
53
+
54
+ self.first_stage_key = first_stage_key
55
+ self.cond_stage_key = cond_stage_key
56
+
57
+ # 1. lazy initialize first stage
58
+ self.instantiate_first_stage(first_stage_config)
59
+
60
+ # 2. initialize conditional stage
61
+ self.instantiate_cond_stage(cond_stage_config)
62
+
63
+ # 3. diffusion model
64
+ self.model = instantiate_from_config(
65
+ denoiser_cfg, device=None, dtype=None
66
+ )
67
+
68
+ self.optimizer_cfg = optimizer_cfg
69
+
70
+ # 4. scheduling strategy
71
+ self.scheduler_cfg = scheduler_cfg
72
+
73
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
74
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
75
+
76
+ # 5. loss configures
77
+ self.loss_cfg = loss_cfg
78
+
79
+ self.scale_by_std = scale_by_std
80
+ if scale_by_std:
81
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
82
+ else:
83
+ self.z_scale_factor = z_scale_factor
84
+
85
+ self.ckpt_path = ckpt_path
86
+ if ckpt_path is not None:
87
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
88
+
89
+ def instantiate_non_trainable_model(self, config):
90
+ model = instantiate_from_config(config)
91
+ model = model.eval()
92
+ model.train = disabled_train
93
+ for param in model.parameters():
94
+ param.requires_grad = False
95
+
96
+ return model
97
+
98
+ def instantiate_first_stage(self, first_stage_config):
99
+ self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
100
+ self.first_stage_model.set_shape_model_only()
101
+
102
+ def instantiate_cond_stage(self, cond_stage_config):
103
+ self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
104
+
105
+ def init_from_ckpt(self, path, ignore_keys=()):
106
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
107
+
108
+ keys = list(state_dict.keys())
109
+ for k in keys:
110
+ for ik in ignore_keys:
111
+ if k.startswith(ik):
112
+ print("Deleting key {} from state_dict.".format(k))
113
+ del state_dict[k]
114
+
115
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
116
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
117
+ if len(missing) > 0:
118
+ print(f"Missing Keys: {missing}")
119
+ print(f"Unexpected Keys: {unexpected}")
120
+
121
+ @property
122
+ def zero_rank(self):
123
+ if self._trainer:
124
+ zero_rank = self.trainer.local_rank == 0
125
+ else:
126
+ zero_rank = True
127
+
128
+ return zero_rank
129
+
130
+ def configure_optimizers(self) -> Tuple[List, List]:
131
+
132
+ lr = self.learning_rate
133
+
134
+ trainable_parameters = list(self.model.parameters())
135
+ if self.optimizer_cfg is None:
136
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
137
+ schedulers = []
138
+ else:
139
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
140
+ scheduler_func = instantiate_from_config(
141
+ self.optimizer_cfg.scheduler,
142
+ max_decay_steps=self.trainer.max_steps,
143
+ lr_max=lr
144
+ )
145
+ scheduler = {
146
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
147
+ "interval": "step",
148
+ "frequency": 1
149
+ }
150
+ optimizers = [optimizer]
151
+ schedulers = [scheduler]
152
+
153
+ return optimizers, schedulers
154
+
155
+ @torch.no_grad()
156
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
157
+
158
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
159
+ z_q = self.z_scale_factor * z_q
160
+
161
+ return z_q
162
+
163
+ @torch.no_grad()
164
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
165
+
166
+ z_q = 1. / self.z_scale_factor * z_q
167
+ latents = self.first_stage_model.decode(z_q, **kwargs)
168
+ return latents
169
+
170
+ @rank_zero_only
171
+ @torch.no_grad()
172
+ def on_train_batch_start(self, batch, batch_idx):
173
+ # only for very first batch
174
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
175
+ and batch_idx == 0 and self.ckpt_path is None:
176
+ # set rescale weight to 1./std of encodings
177
+ print("### USING STD-RESCALING ###")
178
+
179
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
180
+ z = z_q.detach()
181
+
182
+ del self.z_scale_factor
183
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
184
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
185
+
186
+ print("### USING STD-RESCALING ###")
187
+
188
+ def compute_loss(self, model_outputs, split):
189
+ """
190
+
191
+ Args:
192
+ model_outputs (dict):
193
+ - x_0:
194
+ - noise:
195
+ - noise_prior:
196
+ - noise_pred:
197
+ - noise_pred_prior:
198
+
199
+ split (str):
200
+
201
+ Returns:
202
+
203
+ """
204
+
205
+ pred = model_outputs["pred"]
206
+
207
+ if self.noise_scheduler.prediction_type == "epsilon":
208
+ target = model_outputs["noise"]
209
+ elif self.noise_scheduler.prediction_type == "sample":
210
+ target = model_outputs["x_0"]
211
+ else:
212
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
213
+
214
+ if self.loss_cfg.loss_type == "l1":
215
+ simple = F.l1_loss(pred, target, reduction="mean")
216
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
217
+ simple = F.mse_loss(pred, target, reduction="mean")
218
+ else:
219
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
220
+
221
+ total_loss = simple
222
+
223
+ loss_dict = {
224
+ f"{split}/total_loss": total_loss.clone().detach(),
225
+ f"{split}/simple": simple.detach(),
226
+ }
227
+
228
+ return total_loss, loss_dict
229
+
230
+ def forward(self, batch):
231
+ """
232
+
233
+ Args:
234
+ batch:
235
+
236
+ Returns:
237
+
238
+ """
239
+
240
+ latents = self.encode_first_stage(batch[self.first_stage_key])
241
+ conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
242
+
243
+ # Sample noise that we"ll add to the latents
244
+ # [batch_size, n_token, latent_dim]
245
+ noise = torch.randn_like(latents)
246
+ bs = latents.shape[0]
247
+ # Sample a random timestep for each latent in the batch
248
+ timesteps = torch.randint(
249
+ 0,
250
+ self.noise_scheduler.config.num_train_timesteps,
251
+ (bs,),
252
+ device=latents.device,
253
+ )
254
+ timesteps = timesteps.long()
255
+ # Add noise to the latents according to the noise magnitude at each timestep
256
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
257
+
258
+ # diffusion model forward
259
+ noise_pred = self.model(noisy_z, timesteps, conditions)
260
+
261
+ diffusion_outputs = {
262
+ "x_0": noisy_z,
263
+ "noise": noise,
264
+ "pred": noise_pred
265
+ }
266
+
267
+ return diffusion_outputs
268
+
269
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271
+ """
272
+
273
+ Args:
274
+ batch (dict): the batch sample, and it contains:
275
+ - surface (torch.FloatTensor):
276
+ - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
277
+ - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
278
+ - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
279
+ - text (list of str):
280
+
281
+ batch_idx (int):
282
+
283
+ optimizer_idx (int):
284
+
285
+ Returns:
286
+ loss (torch.FloatTensor):
287
+
288
+ """
289
+
290
+ diffusion_outputs = self(batch)
291
+
292
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294
+
295
+ return loss
296
+
297
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
298
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299
+ """
300
+
301
+ Args:
302
+ batch (dict): the batch sample, and it contains:
303
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
304
+ - surface_feats (torch.FloatTensor): [n_pts, c]
305
+ - text (list of str):
306
+
307
+ batch_idx (int):
308
+
309
+ optimizer_idx (int):
310
+
311
+ Returns:
312
+ loss (torch.FloatTensor):
313
+
314
+ """
315
+
316
+ diffusion_outputs = self(batch)
317
+
318
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320
+
321
+ return loss
322
+
323
+ @torch.no_grad()
324
+ def sample(self,
325
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326
+ sample_times: int = 1,
327
+ steps: Optional[int] = None,
328
+ guidance_scale: Optional[float] = None,
329
+ eta: float = 0.0,
330
+ return_intermediates: bool = False, **kwargs):
331
+
332
+ if steps is None:
333
+ steps = self.scheduler_cfg.num_inference_steps
334
+
335
+ if guidance_scale is None:
336
+ guidance_scale = self.scheduler_cfg.guidance_scale
337
+ do_classifier_free_guidance = guidance_scale > 0
338
+
339
+ # conditional encode
340
+ xc = batch[self.cond_stage_key]
341
+
342
+ # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343
+
344
+ cond = self.cond_stage_model(xc)
345
+
346
+ if do_classifier_free_guidance:
347
+ un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348
+ cond = torch.cat([un_cond, cond], dim=0)
349
+
350
+ outputs = []
351
+ latents = None
352
+
353
+ if not return_intermediates:
354
+ for _ in range(sample_times):
355
+ sample_loop = ddim_sample(
356
+ self.denoise_scheduler,
357
+ self.model,
358
+ shape=self.first_stage_model.latent_shape,
359
+ cond=cond,
360
+ steps=steps,
361
+ guidance_scale=guidance_scale,
362
+ do_classifier_free_guidance=do_classifier_free_guidance,
363
+ device=self.device,
364
+ eta=eta,
365
+ disable_prog=not self.zero_rank
366
+ )
367
+ for sample, t in sample_loop:
368
+ latents = sample
369
+ outputs.append(self.decode_first_stage(latents, **kwargs))
370
+ else:
371
+
372
+ sample_loop = ddim_sample(
373
+ self.denoise_scheduler,
374
+ self.model,
375
+ shape=self.first_stage_model.latent_shape,
376
+ cond=cond,
377
+ steps=steps,
378
+ guidance_scale=guidance_scale,
379
+ do_classifier_free_guidance=do_classifier_free_guidance,
380
+ device=self.device,
381
+ eta=eta,
382
+ disable_prog=not self.zero_rank
383
+ )
384
+
385
+ iter_size = steps // sample_times
386
+ i = 0
387
+ for sample, t in sample_loop:
388
+ latents = sample
389
+ if i % iter_size == 0 or i == steps - 1:
390
+ outputs.append(self.decode_first_stage(latents, **kwargs))
391
+ i += 1
392
+
393
+ return outputs
primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from typing import Tuple, List, Union, Optional
6
+ from diffusers.schedulers import DDIMScheduler
7
+
8
+
9
+ __all__ = ["ddim_sample"]
10
+
11
+
12
+ def ddim_sample(ddim_scheduler: DDIMScheduler,
13
+ diffusion_model: torch.nn.Module,
14
+ shape: Union[List[int], Tuple[int]],
15
+ cond: torch.FloatTensor,
16
+ steps: int,
17
+ eta: float = 0.0,
18
+ guidance_scale: float = 3.0,
19
+ do_classifier_free_guidance: bool = True,
20
+ generator: Optional[torch.Generator] = None,
21
+ device: torch.device = "cuda:0",
22
+ disable_prog: bool = True):
23
+
24
+ assert steps > 0, f"steps must be > 0, got {steps}."
25
+
26
+ # init latents
27
+ bsz = cond.shape[0]
28
+ if do_classifier_free_guidance:
29
+ bsz = bsz // 2
30
+
31
+ latents = torch.randn(
32
+ (bsz, *shape),
33
+ generator=generator,
34
+ device=cond.device,
35
+ dtype=cond.dtype,
36
+ )
37
+ # scale the initial noise by the standard deviation required by the scheduler
38
+ latents = latents * ddim_scheduler.init_noise_sigma
39
+ # set timesteps
40
+ ddim_scheduler.set_timesteps(steps)
41
+ timesteps = ddim_scheduler.timesteps.to(device)
42
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43
+ # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44
+ extra_step_kwargs = {
45
+ "eta": eta,
46
+ "generator": generator
47
+ }
48
+
49
+ # reverse
50
+ for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51
+ # expand the latents if we are doing classifier free guidance
52
+ latent_model_input = (
53
+ torch.cat([latents] * 2)
54
+ if do_classifier_free_guidance
55
+ else latents
56
+ )
57
+ # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58
+ # predict the noise residual
59
+ timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60
+ timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61
+ noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62
+
63
+ # perform guidance
64
+ if do_classifier_free_guidance:
65
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66
+ noise_pred = noise_pred_uncond + guidance_scale * (
67
+ noise_pred_text - noise_pred_uncond
68
+ )
69
+ # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70
+ # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71
+ # compute the previous noisy sample x_t -> x_t-1
72
+ latents = ddim_scheduler.step(
73
+ noise_pred, t, latents, **extra_step_kwargs
74
+ ).prev_sample
75
+
76
+ yield latents, t
77
+
78
+
79
+ def karra_sample():
80
+ pass
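`ddim_sample` is a generator that yields the running latents after every denoising step, so the caller decides which intermediates to decode. A hedged driver sketch (the toy denoiser and the latent shape are placeholders, not the real model):

import torch
from diffusers.schedulers import DDIMScheduler
from primitive_anything.michelangelo.models.asl_diffusion.inference_utils import ddim_sample

class ToyDenoiser(torch.nn.Module):
    # any module with the (x, t, cond) signature expected by ddim_sample
    def forward(self, x, t, cond):
        return torch.zeros_like(x)

scheduler = DDIMScheduler(num_train_timesteps=1000)
cond = torch.randn(2, 1, 768)  # [uncond; cond] already concatenated for CFG
latents = None
for latents, t in ddim_sample(
    scheduler, ToyDenoiser(),
    shape=(256, 64), cond=cond, steps=25,
    guidance_scale=3.0, do_classifier_free_guidance=True,
    device=torch.device("cpu"),
):
    pass  # the last yielded `latents` is the final sample
print(latents.shape)  # torch.Size([1, 256, 64]); the batch is halved because cond holds both branches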
primitive_anything/michelangelo/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .clip import CLIPEncoder
primitive_anything/michelangelo/models/conditional_encoders/clip.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from dataclasses import dataclass
7
+ from torchvision.transforms import Normalize
8
+ from transformers import CLIPModel, CLIPTokenizer
9
+ from transformers.utils import ModelOutput
10
+ from typing import Iterable, Optional, Union, List
11
+
12
+
13
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14
+
15
+
16
+ @dataclass
17
+ class CLIPEmbedOutput(ModelOutput):
18
+ last_hidden_state: torch.FloatTensor = None
19
+ pooler_output: torch.FloatTensor = None
20
+ embeds: torch.FloatTensor = None
21
+
22
+
23
+ class CLIPEncoder(torch.nn.Module):
24
+
25
+ def __init__(self, model_path="openai/clip-vit-base-patch32"):
26
+
27
+ super().__init__()
28
+
29
+ # Load the CLIP model and processor
30
+ self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32
+ self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33
+
34
+ self.model.eval()
35
+ for p in self.model.parameters():
36
+ p.requires_grad = False
37
+
38
+ @torch.no_grad()
39
+ def encode_image(self, images: Iterable[Optional[ImageType]]):
40
+ pixel_values = self.image_preprocess(images)
41
+
42
+ vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43
+
44
+ pooler_output = vision_outputs[1] # pooled_output
45
+ image_features = self.model.visual_projection(pooler_output)
46
+
47
+ visual_embeds = CLIPEmbedOutput(
48
+ last_hidden_state=vision_outputs.last_hidden_state,
49
+ pooler_output=pooler_output,
50
+ embeds=image_features
51
+ )
52
+
53
+ return visual_embeds
54
+
55
+ @torch.no_grad()
56
+ def encode_text(self, texts: List[str]):
57
+ text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58
+
59
+ text_outputs = self.model.text_model(input_ids=text_inputs.input_ids)
60
+
61
+ pooler_output = text_outputs[1] # pooled_output
62
+ text_features = self.model.text_projection(pooler_output)
63
+
64
+ text_embeds = CLIPEmbedOutput(
65
+ last_hidden_state=text_outputs.last_hidden_state,
66
+ pooler_output=pooler_output,
67
+ embeds=text_features
68
+ )
69
+
70
+ return text_embeds
71
+
72
+ def forward(self,
73
+ images: Iterable[Optional[ImageType]],
74
+ texts: List[str]):
75
+
76
+ visual_embeds = self.encode_image(images)
77
+ text_embeds = self.encode_text(texts)
78
+
79
+ return visual_embeds, text_embeds
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
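Assuming the `openai/clip-vit-base-patch32` weights can be downloaded, the encoder above can be exercised with pre-resized image tensors and a list of captions; images are expected as [bs, 3, 224, 224] floats since only a `Normalize` is applied:

import torch
from primitive_anything.michelangelo.models.conditional_encoders.clip import CLIPEncoder

encoder = CLIPEncoder()                        # loads openai/clip-vit-base-patch32
images = torch.rand(2, 3, 224, 224)            # already resized/cropped to 224x224
visual, text = encoder(images, ["a chair", "a wooden table"])
print(visual.embeds.shape, text.embeds.shape)  # projected CLIP image / text embeddings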
primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py ADDED
@@ -0,0 +1,562 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms
7
+ from transformers import CLIPModel, CLIPTokenizer
+ import clip  # OpenAI CLIP (https://github.com/openai/CLIP); MoECLIPImageEncoder below relies on clip.load(...)
8
+ from collections import OrderedDict
9
+
10
+ from ...data.transforms import RandomResize
11
+
12
+
13
+ class AbstractEncoder(nn.Module):
14
+ embedding_dim: int
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def encode(self, *args, **kwargs):
20
+ raise NotImplementedError
21
+
22
+
23
+ class ClassEmbedder(nn.Module):
24
+ def __init__(self, embed_dim, n_classes=1000, key="class"):
25
+ super().__init__()
26
+ self.key = key
27
+ self.embedding = nn.Embedding(n_classes, embed_dim)
28
+
29
+ def forward(self, batch, key=None):
30
+ if key is None:
31
+ key = self.key
32
+ # this is for use in crossattn
33
+ c = batch[key][:, None]
34
+ c = self.embedding(c)
35
+ return c
36
+
37
+
38
+ class FrozenCLIPTextEmbedder(AbstractEncoder):
39
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
40
+
41
+ def __init__(
42
+ self,
43
+ version="openai/clip-vit-large-patch14",
44
+ tokenizer_version=None,
45
+ device="cuda",
46
+ max_length=77,
47
+ zero_embedding_radio: float = 0.1,
48
+ ):
49
+ super().__init__()
50
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
51
+
52
+ self.device = device
53
+ self.max_length = max_length
54
+ self.zero_embedding_radio = zero_embedding_radio
55
+
56
+ self.clip_dict = OrderedDict()
57
+ self.clip_name = os.path.split(version)[-1]
58
+
59
+ transformer = CLIPModel.from_pretrained(version).text_model
60
+
61
+ for param in transformer.parameters():
62
+ param.requires_grad = False
63
+ self.clip_dict[self.clip_name] = transformer
64
+
65
+ self._move_flag = False
66
+
67
+ @property
68
+ def clip(self):
69
+ return self.clip_dict[self.clip_name]
70
+
71
+ def move(self):
72
+ if self._move_flag:
73
+ return
74
+
75
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
76
+ self._move_flag = True
77
+
78
+ def unconditional_embedding(self, batch_size):
79
+ empty_text = [""] * batch_size
80
+ empty_z = self.forward(empty_text)
81
+ return empty_z
82
+
83
+ def forward(self, text):
84
+ self.move()
85
+
86
+ batch_encoding = self.tokenizer(
87
+ text,
88
+ truncation=True,
89
+ max_length=self.max_length,
90
+ return_length=True,
91
+ return_overflowing_tokens=False,
92
+ padding="max_length",
93
+ return_tensors="pt",
94
+ )
95
+
96
+ tokens = batch_encoding["input_ids"].to(self.device)
97
+ outputs = self.clip(input_ids=tokens)
98
+
99
+ z = outputs.last_hidden_state
100
+ return z
101
+
102
+ def encode(self, text):
103
+ batch_size = len(text)
104
+ batch_mask = torch.rand((batch_size,))
105
+ for i in range(batch_size):
106
+ if batch_mask[i] < self.zero_embedding_radio:
107
+ text[i] = ""
108
+
109
+ return self(text)
110
+
111
+ class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
112
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
113
+
114
+ def __init__(
115
+ self,
116
+ version="openai/clip-vit-large-patch14",
117
+ tokenizer_version=None,
118
+ device="cuda",
119
+ max_length=77,
120
+ zero_embedding_radio: float = 0.1,
121
+ ):
122
+ super().__init__()
123
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
124
+
125
+ self.device = device
126
+ self.max_length = max_length
127
+ self.zero_embedding_radio = zero_embedding_radio
128
+
129
+ self.clip_dict = OrderedDict()
130
+ self.clip_name = os.path.split(version)[-1]
131
+
132
+ transformer = CLIPModel.from_pretrained(version).text_model
133
+
134
+ for param in transformer.parameters():
135
+ param.requires_grad = False
136
+ self.clip_dict[self.clip_name] = transformer
137
+
138
+ self._move_flag = False
139
+
140
+ @property
141
+ def clip(self):
142
+ return self.clip_dict[self.clip_name]
143
+
144
+ def move(self):
145
+ if self._move_flag:
146
+ return
147
+
148
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
149
+ self._move_flag = True
150
+
151
+ def unconditional_embedding(self, batch_size):
152
+ empty_text = [""] * batch_size
153
+ empty_z = self.forward(empty_text)
154
+ return empty_z
155
+
156
+ def forward(self, text):
157
+ self.move()
158
+
159
+ batch_encoding = self.tokenizer(
160
+ text,
161
+ truncation=True,
162
+ max_length=self.max_length,
163
+ return_length=True,
164
+ return_overflowing_tokens=False,
165
+ padding="max_length",
166
+ return_tensors="pt",
167
+ )
168
+
169
+ tokens = batch_encoding["input_ids"].to(self.device)
170
+ outputs = self.clip(input_ids=tokens)
171
+
172
+ z = outputs.last_hidden_state
173
+ return z
174
+
175
+ def encode(self, text):
176
+ batch_size = len(text)
177
+ batch_mask = torch.rand((batch_size,))
178
+ for i in range(batch_size):
179
+ if batch_mask[i] < self.zero_embedding_radio:
180
+ text[i] = ""
181
+
182
+ return self(text)
183
+
184
+
185
+ class FrozenCLIPImageEmbedder(AbstractEncoder):
186
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
187
+
188
+ def __init__(
189
+ self,
190
+ version="openai/clip-vit-large-patch14",
191
+ device="cuda",
192
+ zero_embedding_radio=0.1,
193
+ normalize_embedding=True,
194
+ num_projection_vector=0,
195
+ linear_mapping_bias=True,
196
+ reverse_visual_projection=False,
197
+ ):
198
+ super().__init__()
199
+
200
+ self.device = device
201
+
202
+ self.clip_dict = OrderedDict()
203
+ self.clip_name = os.path.split(version)[-1]
204
+
205
+ clip_model = CLIPModel.from_pretrained(version)
206
+ clip_model.text_model = None
207
+ clip_model.text_projection = None
208
+ clip_model = clip_model.eval()
209
+ for param in clip_model.parameters():
210
+ param.requires_grad = False
211
+ self.clip_dict[self.clip_name] = clip_model
212
+
213
+ self.transform = transforms.Compose(
214
+ [
215
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
216
+ transforms.CenterCrop(224), # crop a (224, 224) square
217
+ transforms.Normalize(
218
+ mean=[0.48145466, 0.4578275, 0.40821073],
219
+ std=[0.26862954, 0.26130258, 0.27577711],
220
+ ),
221
+ ]
222
+ )
223
+ self.zero_embedding_radio = zero_embedding_radio
224
+
225
+ self.num_projection_vector = num_projection_vector
226
+ self.reverse_visual_projection = reverse_visual_projection
227
+ self.normalize_embedding = normalize_embedding
228
+
229
+ embedding_dim = (
230
+ clip_model.visual_projection.in_features
231
+ if reverse_visual_projection
232
+ else clip_model.visual_projection.out_features
233
+ )
234
+ self.embedding_dim = embedding_dim
235
+ if self.num_projection_vector > 0:
236
+ self.projection = nn.Linear(
237
+ embedding_dim,
238
+ clip_model.visual_projection.out_features * num_projection_vector,
239
+ bias=linear_mapping_bias,
240
+ )
241
+ nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
242
+
243
+ self._move_flag = False
244
+
245
+ @property
246
+ def clip(self):
247
+ return self.clip_dict[self.clip_name]
248
+
249
+ def unconditional_embedding(self, batch_size):
250
+ zero = torch.zeros(
251
+ batch_size,
252
+ 1,
253
+ self.embedding_dim,
254
+ device=self.device,
255
+ dtype=self.clip.visual_projection.weight.dtype,
256
+ )
257
+ if self.num_projection_vector > 0:
258
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
259
+ return zero
260
+
261
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
262
+ if value_range is not None:
263
+ low, high = value_range
264
+ image = (image - low) / (high - low)
265
+
266
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
267
+
268
+ if self.reverse_visual_projection:
269
+ z = self.clip.vision_model(self.transform(image))[1]
270
+ else:
271
+ z = self.clip.get_image_features(self.transform(image))
272
+
273
+ if self.normalize_embedding:
274
+ z = z / z.norm(dim=-1, keepdim=True)
275
+ if z.ndim == 2:
276
+ z = z.unsqueeze(dim=-2)
277
+
278
+ if zero_embedding_radio > 0:
279
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
280
+ z = z * mask.to(z)
281
+
282
+ if self.num_projection_vector > 0:
283
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
284
+
285
+ return z
286
+
287
+ def move(self):
288
+ if self._move_flag:
289
+ return
290
+
291
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
292
+ self._move_flag = True
293
+
294
+ def encode(self, image):
295
+ self.move()
296
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
297
+
298
+
299
+ class FrozenCLIPImageGridEmbedder(AbstractEncoder):
300
+
301
+ def __init__(
302
+ self,
303
+ version="openai/clip-vit-large-patch14",
304
+ device="cuda",
305
+ zero_embedding_radio=0.1,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.device = device
310
+
311
+ self.clip_dict = OrderedDict()
312
+ self.clip_name = os.path.split(version)[-1]
313
+
314
+ clip_model: CLIPModel = CLIPModel.from_pretrained(version)
315
+ clip_model.text_model = None
316
+ clip_model.text_projection = None
317
+ clip_model = clip_model.eval()
318
+ for param in clip_model.parameters():
319
+ param.requires_grad = False
320
+ self.clip_dict[self.clip_name] = clip_model
321
+
322
+ self.transform = transforms.Compose(
323
+ [
324
+ transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
325
+ transforms.CenterCrop(224), # crop a (224, 224) square
326
+ transforms.Normalize(
327
+ mean=[0.48145466, 0.4578275, 0.40821073],
328
+ std=[0.26862954, 0.26130258, 0.27577711],
329
+ ),
330
+ ]
331
+ )
332
+ self.zero_embedding_radio = zero_embedding_radio
333
+ self.embedding_dim = clip_model.vision_embed_dim
334
+
335
+ self._move_flag = False
336
+
337
+ @property
338
+ def clip(self):
339
+ return self.clip_dict[self.clip_name]
340
+
341
+ def move(self):
342
+ if self._move_flag:
343
+ return
344
+
345
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
346
+ self._move_flag = True
347
+
348
+ def unconditional_embedding(self, batch_size):
349
+ zero = torch.zeros(
350
+ batch_size,
351
+ self.clip.vision_model.embeddings.num_positions,
352
+ self.embedding_dim,
353
+ device=self.device,
354
+ dtype=self.clip.visual_projection.weight.dtype,
355
+ )
356
+ return zero
357
+
358
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
359
+ self.move()
360
+
361
+ if value_range is not None:
362
+ low, high = value_range
363
+ image = (image - low) / (high - low)
364
+
365
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
366
+
367
+ z = self.clip.vision_model(self.transform(image)).last_hidden_state
368
+
369
+ if zero_embedding_radio > 0:
370
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
371
+ z = z * mask.to(z)
372
+
373
+ return z
374
+
375
+ def encode(self, image):
376
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
377
+
378
+
379
+ class MoECLIPImageEncoder(nn.Module):
380
+ def __init__(
381
+ self,
382
+ versions,
383
+ hidden_state_dim,
384
+ num_projection_vector=8,
385
+ zero_embedding_radio=0.1,
386
+ device="cuda",
387
+ precision="fp16",
388
+ normalize=False,
389
+ clip_max=0,
390
+ transform_type="base",
391
+ argument_p=0.2,
392
+ ):
393
+ super().__init__()
394
+
395
+ self.device = torch.device(device)
396
+ self.hidden_state_dim = hidden_state_dim
397
+ self.zero_embedding_radio = zero_embedding_radio
398
+ self.num_projection_vector = num_projection_vector
399
+ self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
400
+ self.normalize = normalize
401
+ self.clip_max = clip_max
402
+
403
+ if transform_type == "base":
404
+ self.transform = transforms.Compose(
405
+ [
406
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
407
+ transforms.CenterCrop(224), # crop a (224, 224) square
408
+ transforms.Normalize(
409
+ mean=[0.48145466, 0.4578275, 0.40821073],
410
+ std=[0.26862954, 0.26130258, 0.27577711],
411
+ ),
412
+ ]
413
+ )
414
+ elif transform_type == "crop_blur_resize":
415
+ self.transform = transforms.Compose(
416
+ [
417
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
418
+ transforms.CenterCrop(224), # crop a (224, 224) square
419
+ transforms.RandomApply(
420
+ transforms=[
421
+ transforms.RandomResizedCrop(
422
+ size=224,
423
+ scale=(0.8, 1.0),
424
+ ratio=(0.99, 1.01),
425
+ interpolation=transforms.InterpolationMode.BICUBIC,
426
+ ),
427
+ ],
428
+ p=argument_p,
429
+ ),
430
+ transforms.RandomApply(
431
+ transforms=[
432
+ transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
433
+ ],
434
+ p=argument_p,
435
+ ),
436
+ transforms.RandomApply(
437
+ transforms=[
438
+ RandomResize(size=224, resize_radio=(0.2, 1)),
439
+ ],
440
+ p=argument_p,
441
+ ),
442
+ transforms.Normalize(
443
+ mean=[0.48145466, 0.4578275, 0.40821073],
444
+ std=[0.26862954, 0.26130258, 0.27577711],
445
+ ),
446
+ ]
447
+ )
448
+ else:
449
+ raise ValueError(f"invalid {transform_type=}")
450
+
451
+ if isinstance(versions, str):
452
+ versions = (versions,)
453
+
454
+ # If the CLIP models were registered as submodules of this class: 1) checkpoints would store redundant copies of their weights; 2) Lightning would call .to() on them and cast the LayerNorm weights to fp16 as well.
455
+ clips = OrderedDict()
456
+
457
+ for v in versions:
458
+ # Because the CLIP models are not submodules, passing device="cuda" here would wrongly place every CLIP model's weights on cuda:0, so load on CPU first.
459
+ clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
460
+ delattr(clips[v], "transformer")
461
+ clips[v].eval()
462
+ clips[v].requires_grad_(False)
463
+
464
+ self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
465
+
466
+ if self.num_projection_vector == 0:
467
+ self.projection = nn.Identity()
468
+ else:
469
+ self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
470
+ self.projection.to(dtype=self.dtype)
471
+ nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
472
+
473
+ self.clips = clips
474
+
475
+ self._move_flag = False
476
+
477
+ def move(self):
478
+ if self._move_flag:
479
+ return
480
+
481
+ def convert_weights(model: nn.Module):
482
+ """Convert applicable model parameters to fp16"""
483
+
484
+ def _convert_weights_to_fp16(l):
485
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
486
+ l.weight.data = l.weight.data.type(self.dtype)
487
+ if l.bias is not None:
488
+ l.bias.data = l.bias.data.type(self.dtype)
489
+
490
+ if isinstance(l, nn.MultiheadAttention):
491
+ for attr in [
492
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
493
+ "in_proj_bias",
494
+ "bias_k",
495
+ "bias_v",
496
+ ]:
497
+ tensor = getattr(l, attr)
498
+ if tensor is not None:
499
+ tensor.data = tensor.data.type(self.dtype)
500
+
501
+ for name in ["text_projection", "proj"]:
502
+ if hasattr(l, name):
503
+ attr = getattr(l, name)
504
+ if attr is not None:
505
+ attr.data = attr.data.type(self.dtype)
506
+
507
+ model.apply(_convert_weights_to_fp16)
508
+
509
+ for k in self.clips:
510
+ self.clips[k].to(self.device)
511
+ convert_weights(self.clips[k]) # fp32 -> self.dtype
512
+ self._move_flag = True
513
+
514
+ def unconditional_embedding(self, batch_size=None):
515
+ zero = torch.zeros(
516
+ batch_size,
517
+ self.clips_hidden_dim,
518
+ device=self.device,
519
+ dtype=self.dtype,
520
+ )
521
+ if self.num_projection_vector > 0:
522
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
523
+ return zero
524
+
525
+ def convert_embedding(self, z):
526
+ if self.num_projection_vector > 0:
527
+ z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
528
+ return z
529
+
530
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
531
+ if value_range is not None:
532
+ low, high = value_range
533
+ image = (image - low) / (high - low)
534
+
535
+ image = self.transform(image)
536
+
537
+ with torch.no_grad():
538
+ embs = []
539
+ for v in self.clips:
540
+ x = self.clips[v].encode_image(image)
541
+ if self.normalize:
542
+ x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
543
+ # clip_max only works with normalization
544
+ if self.clip_max > 0:
545
+ x = x.clamp(-self.clip_max, self.clip_max)
546
+ embs.append(x)
547
+
548
+ z = torch.cat(embs, dim=-1)
549
+ if self.normalize:
550
+ z /= z.size(-1) ** 0.5
551
+
552
+ if zero_embedding_radio > 0:
553
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
554
+ z = z * mask.to(z)
555
+
556
+ if self.num_projection_vector > 0:
557
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
558
+ return z
559
+
560
+ def encode(self, image):
561
+ self.move()
562
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
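The encoders above deliberately keep the frozen CLIP backbones in a plain `OrderedDict` instead of registering them as submodules, so their weights stay out of checkpoints and out of Lightning's device/dtype moves. A stripped-down sketch of that pattern (the tiny backbone here is a stand-in, not CLIP):

import torch
import torch.nn as nn
from collections import OrderedDict

class FrozenBackboneWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = nn.Linear(8, 8).eval()             # stand-in for a frozen CLIP model
        backbone.requires_grad_(False)
        self._backbones = OrderedDict(main=backbone)  # plain dict: NOT a registered submodule
        self.head = nn.Linear(8, 2)                   # only this shows up in state_dict()

    def forward(self, x):
        with torch.no_grad():
            feats = self._backbones["main"](x)
        return self.head(feats)

m = FrozenBackboneWrapper()
print(list(m.state_dict().keys()))  # ['head.weight', 'head.bias'] -- backbone weights are not saved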
primitive_anything/michelangelo/models/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import checkpoint
primitive_anything/michelangelo/models/modules/checkpoint.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4
+ """
5
+
6
+ import torch
7
+ from typing import Callable, Iterable, Sequence, Union
8
+
9
+
10
+ def checkpoint(
11
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12
+ inputs: Sequence[torch.Tensor],
13
+ params: Iterable[torch.Tensor],
14
+ flag: bool,
15
+ use_deepspeed: bool = False
16
+ ):
17
+ """
18
+ Evaluate a function without caching intermediate activations, allowing for
19
+ reduced memory at the expense of extra compute in the backward pass.
20
+ :param func: the function to evaluate.
21
+ :param inputs: the argument sequence to pass to `func`.
22
+ :param params: a sequence of parameters `func` depends on but does not
23
+ explicitly take as arguments.
24
+ :param flag: if False, disable gradient checkpointing.
25
+ :param use_deepspeed: if True, use deepspeed
26
+ """
27
+ if flag:
28
+ if use_deepspeed:
29
+ import deepspeed
30
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
31
+
32
+ args = tuple(inputs) + tuple(params)
33
+ return CheckpointFunction.apply(func, len(inputs), *args)
34
+ else:
35
+ return func(*inputs)
36
+
37
+
38
+ class CheckpointFunction(torch.autograd.Function):
39
+ @staticmethod
40
+ @torch.cuda.amp.custom_fwd
41
+ def forward(ctx, run_function, length, *args):
42
+ ctx.run_function = run_function
43
+ ctx.input_tensors = list(args[:length])
44
+ ctx.input_params = list(args[length:])
45
+
46
+ with torch.no_grad():
47
+ output_tensors = ctx.run_function(*ctx.input_tensors)
48
+ return output_tensors
49
+
50
+ @staticmethod
51
+ @torch.cuda.amp.custom_bwd
52
+ def backward(ctx, *output_grads):
53
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54
+ with torch.enable_grad():
55
+ # Fixes a bug where the first op in run_function modifies the
56
+ # Tensor storage in place, which is not allowed for detach()'d
57
+ # Tensors.
58
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59
+ output_tensors = ctx.run_function(*shallow_copies)
60
+ input_grads = torch.autograd.grad(
61
+ output_tensors,
62
+ ctx.input_tensors + ctx.input_params,
63
+ output_grads,
64
+ allow_unused=True,
65
+ )
66
+ del ctx.input_tensors
67
+ del ctx.input_params
68
+ del output_tensors
69
+ return (None, None) + input_grads
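A brief usage sketch of the `checkpoint` helper above on a toy layer; with `flag=True` the forward activations are discarded and the layer is re-run during backward (the layer and shapes are arbitrary):

import torch
import torch.nn as nn
from primitive_anything.michelangelo.models.modules.checkpoint import checkpoint

layer = nn.Linear(128, 128)
x = torch.randn(4, 128, requires_grad=True)

y = checkpoint(layer, (x,), layer.parameters(), flag=True)  # recompute activations in backward
y.sum().backward()
print(x.grad.shape, layer.weight.grad.shape)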
primitive_anything/michelangelo/models/modules/diffusion_transformer.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+
8
+ from .checkpoint import checkpoint
9
+ from .transformer_blocks import (
10
+ init_linear,
11
+ MLP,
12
+ MultiheadCrossAttention,
13
+ MultiheadAttention,
14
+ ResidualAttentionBlock
15
+ )
16
+
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self,
20
+ device: torch.device,
21
+ dtype: torch.dtype,
22
+ width: int):
23
+
24
+ super().__init__()
25
+
26
+ self.silu = nn.SiLU(inplace=True)
27
+ self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28
+ self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29
+
30
+ def forward(self, x, timestep):
31
+ emb = self.linear(timestep)
32
+ scale, shift = torch.chunk(emb, 2, dim=2)
33
+ x = self.layernorm(x) * (1 + scale) + shift
34
+ return x
35
+
36
+
37
+ class DitBlock(nn.Module):
38
+ def __init__(
39
+ self,
40
+ *,
41
+ device: torch.device,
42
+ dtype: torch.dtype,
43
+ n_ctx: int,
44
+ width: int,
45
+ heads: int,
46
+ context_dim: int,
47
+ qkv_bias: bool = False,
48
+ init_scale: float = 1.0,
49
+ use_checkpoint: bool = False
50
+ ):
51
+ super().__init__()
52
+
53
+ self.use_checkpoint = use_checkpoint
54
+
55
+ self.attn = MultiheadAttention(
56
+ device=device,
57
+ dtype=dtype,
58
+ n_ctx=n_ctx,
59
+ width=width,
60
+ heads=heads,
61
+ init_scale=init_scale,
62
+ qkv_bias=qkv_bias
63
+ )
64
+ self.ln_1 = AdaLayerNorm(device, dtype, width)
65
+
66
+ if context_dim is not None:
67
+ self.ln_2 = AdaLayerNorm(device, dtype, width)
68
+ self.cross_attn = MultiheadCrossAttention(
69
+ device=device,
70
+ dtype=dtype,
71
+ width=width,
72
+ heads=heads,
73
+ data_width=context_dim,
74
+ init_scale=init_scale,
75
+ qkv_bias=qkv_bias
76
+ )
77
+
78
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79
+ self.ln_3 = AdaLayerNorm(device, dtype, width)
80
+
81
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82
+ return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83
+
84
+ def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85
+ x = x + self.attn(self.ln_1(x, t))
86
+ if context is not None:
87
+ x = x + self.cross_attn(self.ln_2(x, t), context)
88
+ x = x + self.mlp(self.ln_3(x, t))
89
+ return x
90
+
91
+
92
+ class DiT(nn.Module):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ device: Optional[torch.device],
97
+ dtype: Optional[torch.dtype],
98
+ n_ctx: int,
99
+ width: int,
100
+ layers: int,
101
+ heads: int,
102
+ context_dim: int,
103
+ init_scale: float = 0.25,
104
+ qkv_bias: bool = False,
105
+ use_checkpoint: bool = False
106
+ ):
107
+ super().__init__()
108
+ self.n_ctx = n_ctx
109
+ self.width = width
110
+ self.layers = layers
111
+
112
+ self.resblocks = nn.ModuleList(
113
+ [
114
+ DitBlock(
115
+ device=device,
116
+ dtype=dtype,
117
+ n_ctx=n_ctx,
118
+ width=width,
119
+ heads=heads,
120
+ context_dim=context_dim,
121
+ qkv_bias=qkv_bias,
122
+ init_scale=init_scale,
123
+ use_checkpoint=use_checkpoint
124
+ )
125
+ for _ in range(layers)
126
+ ]
127
+ )
128
+
129
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130
+ for block in self.resblocks:
131
+ x = block(x, t, context)
132
+ return x
133
+
134
+
135
+ class UNetDiffusionTransformer(nn.Module):
136
+ def __init__(
137
+ self,
138
+ *,
139
+ device: Optional[torch.device],
140
+ dtype: Optional[torch.dtype],
141
+ n_ctx: int,
142
+ width: int,
143
+ layers: int,
144
+ heads: int,
145
+ init_scale: float = 0.25,
146
+ qkv_bias: bool = False,
147
+ skip_ln: bool = False,
148
+ use_checkpoint: bool = False
149
+ ):
150
+ super().__init__()
151
+
152
+ self.n_ctx = n_ctx
153
+ self.width = width
154
+ self.layers = layers
155
+
156
+ self.encoder = nn.ModuleList()
157
+ for _ in range(layers):
158
+ resblock = ResidualAttentionBlock(
159
+ device=device,
160
+ dtype=dtype,
161
+ n_ctx=n_ctx,
162
+ width=width,
163
+ heads=heads,
164
+ init_scale=init_scale,
165
+ qkv_bias=qkv_bias,
166
+ use_checkpoint=use_checkpoint
167
+ )
168
+ self.encoder.append(resblock)
169
+
170
+ self.middle_block = ResidualAttentionBlock(
171
+ device=device,
172
+ dtype=dtype,
173
+ n_ctx=n_ctx,
174
+ width=width,
175
+ heads=heads,
176
+ init_scale=init_scale,
177
+ qkv_bias=qkv_bias,
178
+ use_checkpoint=use_checkpoint
179
+ )
180
+
181
+ self.decoder = nn.ModuleList()
182
+ for _ in range(layers):
183
+ resblock = ResidualAttentionBlock(
184
+ device=device,
185
+ dtype=dtype,
186
+ n_ctx=n_ctx,
187
+ width=width,
188
+ heads=heads,
189
+ init_scale=init_scale,
190
+ qkv_bias=qkv_bias,
191
+ use_checkpoint=use_checkpoint
192
+ )
193
+ linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194
+ init_linear(linear, init_scale)
195
+
196
+ layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197
+
198
+ self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199
+
200
+ def forward(self, x: torch.Tensor):
201
+
202
+ enc_outputs = []
203
+ for block in self.encoder:
204
+ x = block(x)
205
+ enc_outputs.append(x)
206
+
207
+ x = self.middle_block(x)
208
+
209
+ for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210
+ x = torch.cat([enc_outputs.pop(), x], dim=-1)
211
+ x = linear(x)
212
+
213
+ if layer_norm is not None:
214
+ x = layer_norm(x)
215
+
216
+ x = resblock(x)
217
+
218
+ return x
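The decoder half concatenates each stored encoder activation with the current features and projects the doubled width back down, which is the transformer analogue of U-Net skip connections. A small shape-check sketch (hyperparameters are illustrative):

import torch
from primitive_anything.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer

unet = UNetDiffusionTransformer(
    device=None, dtype=None,
    n_ctx=64, width=256, layers=3, heads=8, skip_ln=True,
)
x = torch.randn(2, 64, 256)      # [bs, tokens, width]
assert unet(x).shape == x.shape  # skip connections preserve the sequence shape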