cavargas10 commited on
Commit
4f8013a
·
verified ·
1 Parent(s): 04fa6ac

Delete trellis/utils

Browse files
trellis/utils/__init__.py DELETED
File without changes
trellis/utils/data_utils.py DELETED
@@ -1,226 +0,0 @@
1
- from typing import *
2
- import math
3
- import torch
4
- import numpy as np
5
- from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
6
- import torch.distributed as dist
7
-
8
-
9
- def recursive_to_device(
10
- data: Any,
11
- device: torch.device,
12
- non_blocking: bool = False,
13
- ) -> Any:
14
- """
15
- Recursively move all tensors in a data structure to a device.
16
- """
17
- if hasattr(data, "to"):
18
- return data.to(device, non_blocking=non_blocking)
19
- elif isinstance(data, (list, tuple)):
20
- return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
21
- elif isinstance(data, dict):
22
- return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
23
- else:
24
- return data
25
-
26
-
27
- def load_balanced_group_indices(
28
- load: List[int],
29
- num_groups: int,
30
- equal_size: bool = False,
31
- ) -> List[List[int]]:
32
- """
33
- Split indices into groups with balanced load.
34
- """
35
- if equal_size:
36
- group_size = len(load) // num_groups
37
- indices = np.argsort(load)[::-1]
38
- groups = [[] for _ in range(num_groups)]
39
- group_load = np.zeros(num_groups)
40
- for idx in indices:
41
- min_group_idx = np.argmin(group_load)
42
- groups[min_group_idx].append(idx)
43
- if equal_size and len(groups[min_group_idx]) == group_size:
44
- group_load[min_group_idx] = float('inf')
45
- else:
46
- group_load[min_group_idx] += load[idx]
47
- return groups
48
-
49
-
50
- def cycle(data_loader: DataLoader) -> Iterator:
51
- while True:
52
- for data in data_loader:
53
- if isinstance(data_loader.sampler, ResumableSampler):
54
- data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined]
55
- yield data
56
- if isinstance(data_loader.sampler, DistributedSampler):
57
- data_loader.sampler.epoch += 1
58
- if isinstance(data_loader.sampler, ResumableSampler):
59
- data_loader.sampler.epoch += 1
60
- data_loader.sampler.idx = 0
61
-
62
-
63
- class ResumableSampler(Sampler):
64
- """
65
- Distributed sampler that is resumable.
66
-
67
- Args:
68
- dataset: Dataset used for sampling.
69
- rank (int, optional): Rank of the current process within :attr:`num_replicas`.
70
- By default, :attr:`rank` is retrieved from the current distributed
71
- group.
72
- shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
73
- indices.
74
- seed (int, optional): random seed used to shuffle the sampler if
75
- :attr:`shuffle=True`. This number should be identical across all
76
- processes in the distributed group. Default: ``0``.
77
- drop_last (bool, optional): if ``True``, then the sampler will drop the
78
- tail of the data to make it evenly divisible across the number of
79
- replicas. If ``False``, the sampler will add extra indices to make
80
- the data evenly divisible across the replicas. Default: ``False``.
81
- """
82
-
83
- def __init__(
84
- self,
85
- dataset: Dataset,
86
- shuffle: bool = True,
87
- seed: int = 0,
88
- drop_last: bool = False,
89
- ) -> None:
90
- self.dataset = dataset
91
- self.epoch = 0
92
- self.idx = 0
93
- self.drop_last = drop_last
94
- self.world_size = dist.get_world_size() if dist.is_initialized() else 1
95
- self.rank = dist.get_rank() if dist.is_initialized() else 0
96
- # If the dataset length is evenly divisible by # of replicas, then there
97
- # is no need to drop any data, since the dataset will be split equally.
98
- if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
99
- # Split to nearest available length that is evenly divisible.
100
- # This is to ensure each rank receives the same amount of data when
101
- # using this Sampler.
102
- self.num_samples = math.ceil(
103
- (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type]
104
- )
105
- else:
106
- self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
107
- self.total_size = self.num_samples * self.world_size
108
- self.shuffle = shuffle
109
- self.seed = seed
110
-
111
- def __iter__(self) -> Iterator:
112
- if self.shuffle:
113
- # deterministically shuffle based on epoch and seed
114
- g = torch.Generator()
115
- g.manual_seed(self.seed + self.epoch)
116
- indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
117
- else:
118
- indices = list(range(len(self.dataset))) # type: ignore[arg-type]
119
-
120
- if not self.drop_last:
121
- # add extra samples to make it evenly divisible
122
- padding_size = self.total_size - len(indices)
123
- if padding_size <= len(indices):
124
- indices += indices[:padding_size]
125
- else:
126
- indices += (indices * math.ceil(padding_size / len(indices)))[
127
- :padding_size
128
- ]
129
- else:
130
- # remove tail of data to make it evenly divisible.
131
- indices = indices[: self.total_size]
132
- assert len(indices) == self.total_size
133
-
134
- # subsample
135
- indices = indices[self.rank : self.total_size : self.world_size]
136
-
137
- # resume from previous state
138
- indices = indices[self.idx:]
139
-
140
- return iter(indices)
141
-
142
- def __len__(self) -> int:
143
- return self.num_samples
144
-
145
- def state_dict(self) -> dict[str, int]:
146
- return {
147
- 'epoch': self.epoch,
148
- 'idx': self.idx,
149
- }
150
-
151
- def load_state_dict(self, state_dict):
152
- self.epoch = state_dict['epoch']
153
- self.idx = state_dict['idx']
154
-
155
-
156
- class BalancedResumableSampler(ResumableSampler):
157
- """
158
- Distributed sampler that is resumable and balances the load among the processes.
159
-
160
- Args:
161
- dataset: Dataset used for sampling.
162
- rank (int, optional): Rank of the current process within :attr:`num_replicas`.
163
- By default, :attr:`rank` is retrieved from the current distributed
164
- group.
165
- shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
166
- indices.
167
- seed (int, optional): random seed used to shuffle the sampler if
168
- :attr:`shuffle=True`. This number should be identical across all
169
- processes in the distributed group. Default: ``0``.
170
- drop_last (bool, optional): if ``True``, then the sampler will drop the
171
- tail of the data to make it evenly divisible across the number of
172
- replicas. If ``False``, the sampler will add extra indices to make
173
- the data evenly divisible across the replicas. Default: ``False``.
174
- """
175
-
176
- def __init__(
177
- self,
178
- dataset: Dataset,
179
- shuffle: bool = True,
180
- seed: int = 0,
181
- drop_last: bool = False,
182
- batch_size: int = 1,
183
- ) -> None:
184
- assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
185
- super().__init__(dataset, shuffle, seed, drop_last)
186
- self.batch_size = batch_size
187
- self.loads = dataset.loads
188
-
189
- def __iter__(self) -> Iterator:
190
- if self.shuffle:
191
- # deterministically shuffle based on epoch and seed
192
- g = torch.Generator()
193
- g.manual_seed(self.seed + self.epoch)
194
- indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
195
- else:
196
- indices = list(range(len(self.dataset))) # type: ignore[arg-type]
197
-
198
- if not self.drop_last:
199
- # add extra samples to make it evenly divisible
200
- padding_size = self.total_size - len(indices)
201
- if padding_size <= len(indices):
202
- indices += indices[:padding_size]
203
- else:
204
- indices += (indices * math.ceil(padding_size / len(indices)))[
205
- :padding_size
206
- ]
207
- else:
208
- # remove tail of data to make it evenly divisible.
209
- indices = indices[: self.total_size]
210
- assert len(indices) == self.total_size
211
-
212
- # balance load among processes
213
- num_batches = len(indices) // (self.batch_size * self.world_size)
214
- balanced_indices = []
215
- for i in range(num_batches):
216
- start_idx = i * self.batch_size * self.world_size
217
- end_idx = (i + 1) * self.batch_size * self.world_size
218
- batch_indices = indices[start_idx:end_idx]
219
- batch_loads = [self.loads[idx] for idx in batch_indices]
220
- groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
221
- balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
222
-
223
- # resume from previous state
224
- indices = balanced_indices[self.idx:]
225
-
226
- return iter(indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/dist_utils.py DELETED
@@ -1,93 +0,0 @@
1
- import os
2
- import io
3
- from contextlib import contextmanager
4
- import torch
5
- import torch.distributed as dist
6
- from torch.nn.parallel import DistributedDataParallel as DDP
7
-
8
-
9
- def setup_dist(rank, local_rank, world_size, master_addr, master_port):
10
- os.environ['MASTER_ADDR'] = master_addr
11
- os.environ['MASTER_PORT'] = master_port
12
- os.environ['WORLD_SIZE'] = str(world_size)
13
- os.environ['RANK'] = str(rank)
14
- os.environ['LOCAL_RANK'] = str(local_rank)
15
- torch.cuda.set_device(local_rank)
16
- dist.init_process_group('nccl', rank=rank, world_size=world_size)
17
-
18
-
19
- def read_file_dist(path):
20
- """
21
- Read the binary file distributedly.
22
- File is only read once by the rank 0 process and broadcasted to other processes.
23
-
24
- Returns:
25
- data (io.BytesIO): The binary data read from the file.
26
- """
27
- if dist.is_initialized() and dist.get_world_size() > 1:
28
- # read file
29
- size = torch.LongTensor(1).cuda()
30
- if dist.get_rank() == 0:
31
- with open(path, 'rb') as f:
32
- data = f.read()
33
- data = torch.ByteTensor(
34
- torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
35
- ).cuda()
36
- size[0] = data.shape[0]
37
- # broadcast size
38
- dist.broadcast(size, src=0)
39
- if dist.get_rank() != 0:
40
- data = torch.ByteTensor(size[0].item()).cuda()
41
- # broadcast data
42
- dist.broadcast(data, src=0)
43
- # convert to io.BytesIO
44
- data = data.cpu().numpy().tobytes()
45
- data = io.BytesIO(data)
46
- return data
47
- else:
48
- with open(path, 'rb') as f:
49
- data = f.read()
50
- data = io.BytesIO(data)
51
- return data
52
-
53
-
54
- def unwrap_dist(model):
55
- """
56
- Unwrap the model from distributed training.
57
- """
58
- if isinstance(model, DDP):
59
- return model.module
60
- return model
61
-
62
-
63
- @contextmanager
64
- def master_first():
65
- """
66
- A context manager that ensures master process executes first.
67
- """
68
- if not dist.is_initialized():
69
- yield
70
- else:
71
- if dist.get_rank() == 0:
72
- yield
73
- dist.barrier()
74
- else:
75
- dist.barrier()
76
- yield
77
-
78
-
79
- @contextmanager
80
- def local_master_first():
81
- """
82
- A context manager that ensures local master process executes first.
83
- """
84
- if not dist.is_initialized():
85
- yield
86
- else:
87
- if dist.get_rank() % torch.cuda.device_count() == 0:
88
- yield
89
- dist.barrier()
90
- else:
91
- dist.barrier()
92
- yield
93
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/elastic_utils.py DELETED
@@ -1,228 +0,0 @@
1
- from abc import abstractmethod
2
- from contextlib import contextmanager
3
- from typing import Tuple
4
- import torch
5
- import torch.nn as nn
6
- import numpy as np
7
-
8
-
9
- class MemoryController:
10
- """
11
- Base class for memory management during training.
12
- """
13
-
14
- _last_input_size = None
15
- _last_mem_ratio = []
16
-
17
- @contextmanager
18
- def record(self):
19
- pass
20
-
21
- def update_run_states(self, input_size=None, mem_ratio=None):
22
- if self._last_input_size is None:
23
- self._last_input_size = input_size
24
- elif self._last_input_size!= input_size:
25
- raise ValueError(f'Input size should not change for different ElasticModules.')
26
- self._last_mem_ratio.append(mem_ratio)
27
-
28
- @abstractmethod
29
- def get_mem_ratio(self, input_size):
30
- pass
31
-
32
- @abstractmethod
33
- def state_dict(self):
34
- pass
35
-
36
- @abstractmethod
37
- def log(self):
38
- pass
39
-
40
-
41
- class LinearMemoryController(MemoryController):
42
- """
43
- A simple controller for memory management during training.
44
- The memory usage is modeled as a linear function of:
45
- - the number of input parameters
46
- - the ratio of memory the model use compared to the maximum usage (with no checkpointing)
47
- memory_usage = k * input_size * mem_ratio + b
48
- The controller keeps track of the memory usage and gives the
49
- expected memory ratio to keep the memory usage under a target
50
- """
51
- def __init__(
52
- self,
53
- buffer_size=1000,
54
- update_every=500,
55
- target_ratio=0.8,
56
- available_memory=None,
57
- max_mem_ratio_start=0.1,
58
- params=None,
59
- device=None
60
- ):
61
- self.buffer_size = buffer_size
62
- self.update_every = update_every
63
- self.target_ratio = target_ratio
64
- self.device = device or torch.cuda.current_device()
65
- self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
66
-
67
- self._memory = np.zeros(buffer_size, dtype=np.float32)
68
- self._input_size = np.zeros(buffer_size, dtype=np.float32)
69
- self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
70
- self._buffer_ptr = 0
71
- self._buffer_length = 0
72
- self._params = tuple(params) if params is not None else (0.0, 0.0)
73
- self._max_mem_ratio = max_mem_ratio_start
74
- self.step = 0
75
-
76
- def __repr__(self):
77
- return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
78
-
79
- def _add_sample(self, memory, input_size, mem_ratio):
80
- self._memory[self._buffer_ptr] = memory
81
- self._input_size[self._buffer_ptr] = input_size
82
- self._mem_ratio[self._buffer_ptr] = mem_ratio
83
- self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
84
- self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
85
-
86
- @contextmanager
87
- def record(self):
88
- torch.cuda.reset_peak_memory_stats(self.device)
89
- self._last_input_size = None
90
- self._last_mem_ratio = []
91
- yield
92
- self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
93
- self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
94
- self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
95
- self.step += 1
96
- if self.step % self.update_every == 0:
97
- self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
98
- self._fit_params()
99
-
100
- def _fit_params(self):
101
- memory_usage = self._memory[:self._buffer_length]
102
- input_size = self._input_size[:self._buffer_length]
103
- mem_ratio = self._mem_ratio[:self._buffer_length]
104
-
105
- x = input_size * mem_ratio
106
- y = memory_usage
107
- k, b = np.polyfit(x, y, 1)
108
- self._params = (k, b)
109
- # self._visualize()
110
-
111
- def _visualize(self):
112
- import matplotlib.pyplot as plt
113
- memory_usage = self._memory[:self._buffer_length]
114
- input_size = self._input_size[:self._buffer_length]
115
- mem_ratio = self._mem_ratio[:self._buffer_length]
116
- k, b = self._params
117
-
118
- plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
119
- x = np.array([0.0, 20000.0])
120
- plt.plot(x, k * x + b, c='r')
121
- plt.savefig(f'linear_memory_controller_{self.step}.png')
122
- plt.cla()
123
-
124
- def get_mem_ratio(self, input_size):
125
- k, b = self._params
126
- if k == 0: return np.random.rand() * self._max_mem_ratio
127
- pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
128
- return min(self._max_mem_ratio, max(0.0, pred))
129
-
130
- def state_dict(self):
131
- return {
132
- 'params': self._params,
133
- }
134
-
135
- def load_state_dict(self, state_dict):
136
- self._params = tuple(state_dict['params'])
137
-
138
- def log(self):
139
- return {
140
- 'params/k': self._params[0],
141
- 'params/b': self._params[1],
142
- 'memory': self._last_memory,
143
- 'input_size': self._last_input_size,
144
- 'mem_ratio': self._last_mem_ratio,
145
- }
146
-
147
-
148
- class ElasticModule(nn.Module):
149
- """
150
- Module for training with elastic memory management.
151
- """
152
- def __init__(self):
153
- super().__init__()
154
- self._memory_controller: MemoryController = None
155
-
156
- @abstractmethod
157
- def _get_input_size(self, *args, **kwargs) -> int:
158
- """
159
- Get the size of the input data.
160
-
161
- Returns:
162
- int: The size of the input data.
163
- """
164
- pass
165
-
166
- @abstractmethod
167
- def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
168
- """
169
- Forward with a given memory ratio.
170
- """
171
- pass
172
-
173
- def register_memory_controller(self, memory_controller: MemoryController):
174
- self._memory_controller = memory_controller
175
-
176
- def forward(self, *args, **kwargs):
177
- if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
178
- _, ret = self._forward_with_mem_ratio(*args, **kwargs)
179
- else:
180
- input_size = self._get_input_size(*args, **kwargs)
181
- mem_ratio = self._memory_controller.get_mem_ratio(input_size)
182
- mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
183
- self._memory_controller.update_run_states(input_size, mem_ratio)
184
- return ret
185
-
186
-
187
- class ElasticModuleMixin:
188
- """
189
- Mixin for training with elastic memory management.
190
- """
191
- def __init__(self, *args, **kwargs):
192
- super().__init__(*args, **kwargs)
193
- self._memory_controller: MemoryController = None
194
-
195
- @abstractmethod
196
- def _get_input_size(self, *args, **kwargs) -> int:
197
- """
198
- Get the size of the input data.
199
-
200
- Returns:
201
- int: The size of the input data.
202
- """
203
- pass
204
-
205
- @abstractmethod
206
- @contextmanager
207
- def with_mem_ratio(self, mem_ratio=1.0) -> float:
208
- """
209
- Context manager for training with a reduced memory ratio compared to the full memory usage.
210
-
211
- Returns:
212
- float: The exact memory ratio used during the forward pass.
213
- """
214
- pass
215
-
216
- def register_memory_controller(self, memory_controller: MemoryController):
217
- self._memory_controller = memory_controller
218
-
219
- def forward(self, *args, **kwargs):
220
- if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
221
- ret = super().forward(*args, **kwargs)
222
- else:
223
- input_size = self._get_input_size(*args, **kwargs)
224
- mem_ratio = self._memory_controller.get_mem_ratio(input_size)
225
- with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
226
- ret = super().forward(*args, **kwargs)
227
- self._memory_controller.update_run_states(input_size, exact_mem_ratio)
228
- return ret
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/general_utils.py DELETED
@@ -1,201 +0,0 @@
1
- import re
2
- import numpy as np
3
- import cv2
4
- import torch
5
- import contextlib
6
-
7
-
8
- # Dictionary utils
9
- def _dict_merge(dicta, dictb, prefix=''):
10
- """
11
- Merge two dictionaries.
12
- """
13
- assert isinstance(dicta, dict), 'input must be a dictionary'
14
- assert isinstance(dictb, dict), 'input must be a dictionary'
15
- dict_ = {}
16
- all_keys = set(dicta.keys()).union(set(dictb.keys()))
17
- for key in all_keys:
18
- if key in dicta.keys() and key in dictb.keys():
19
- if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
20
- dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
21
- else:
22
- raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
23
- elif key in dicta.keys():
24
- dict_[key] = dicta[key]
25
- else:
26
- dict_[key] = dictb[key]
27
- return dict_
28
-
29
-
30
- def dict_merge(dicta, dictb):
31
- """
32
- Merge two dictionaries.
33
- """
34
- return _dict_merge(dicta, dictb, prefix='')
35
-
36
-
37
- def dict_foreach(dic, func, special_func={}):
38
- """
39
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
40
- """
41
- assert isinstance(dic, dict), 'input must be a dictionary'
42
- for key in dic.keys():
43
- if isinstance(dic[key], dict):
44
- dic[key] = dict_foreach(dic[key], func)
45
- else:
46
- if key in special_func.keys():
47
- dic[key] = special_func[key](dic[key])
48
- else:
49
- dic[key] = func(dic[key])
50
- return dic
51
-
52
-
53
- def dict_reduce(dicts, func, special_func={}):
54
- """
55
- Reduce a list of dictionaries. Leaf values must be scalars.
56
- """
57
- assert isinstance(dicts, list), 'input must be a list of dictionaries'
58
- assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
59
- assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
60
- all_keys = set([key for dict_ in dicts for key in dict_.keys()])
61
- reduced_dict = {}
62
- for key in all_keys:
63
- vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
64
- if isinstance(vlist[0], dict):
65
- reduced_dict[key] = dict_reduce(vlist, func, special_func)
66
- else:
67
- if key in special_func.keys():
68
- reduced_dict[key] = special_func[key](vlist)
69
- else:
70
- reduced_dict[key] = func(vlist)
71
- return reduced_dict
72
-
73
-
74
- def dict_any(dic, func):
75
- """
76
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
77
- """
78
- assert isinstance(dic, dict), 'input must be a dictionary'
79
- for key in dic.keys():
80
- if isinstance(dic[key], dict):
81
- if dict_any(dic[key], func):
82
- return True
83
- else:
84
- if func(dic[key]):
85
- return True
86
- return False
87
-
88
-
89
- def dict_all(dic, func):
90
- """
91
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
92
- """
93
- assert isinstance(dic, dict), 'input must be a dictionary'
94
- for key in dic.keys():
95
- if isinstance(dic[key], dict):
96
- if not dict_all(dic[key], func):
97
- return False
98
- else:
99
- if not func(dic[key]):
100
- return False
101
- return True
102
-
103
-
104
- def dict_flatten(dic, sep='.'):
105
- """
106
- Flatten a nested dictionary into a dictionary with no nested dictionaries.
107
- """
108
- assert isinstance(dic, dict), 'input must be a dictionary'
109
- flat_dict = {}
110
- for key in dic.keys():
111
- if isinstance(dic[key], dict):
112
- sub_dict = dict_flatten(dic[key], sep=sep)
113
- for sub_key in sub_dict.keys():
114
- flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
115
- else:
116
- flat_dict[key] = dic[key]
117
- return flat_dict
118
-
119
-
120
- # Context utils
121
- @contextlib.contextmanager
122
- def nested_contexts(*contexts):
123
- with contextlib.ExitStack() as stack:
124
- for ctx in contexts:
125
- stack.enter_context(ctx())
126
- yield
127
-
128
-
129
- # Image utils
130
- def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
131
- num_images = len(images)
132
- if nrow is None and ncol is None:
133
- if aspect_ratio is not None:
134
- nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
135
- else:
136
- nrow = int(np.sqrt(num_images))
137
- ncol = (num_images + nrow - 1) // nrow
138
- elif nrow is None and ncol is not None:
139
- nrow = (num_images + ncol - 1) // ncol
140
- elif nrow is not None and ncol is None:
141
- ncol = (num_images + nrow - 1) // nrow
142
- else:
143
- assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
144
-
145
- if images[0].ndim == 2:
146
- grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
147
- else:
148
- grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
149
- for i, img in enumerate(images):
150
- row = i // ncol
151
- col = i % ncol
152
- grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
153
- return grid
154
-
155
-
156
- def notes_on_image(img, notes=None):
157
- img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
158
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
159
- if notes is not None:
160
- img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
161
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
162
- return img
163
-
164
-
165
- def save_image_with_notes(img, path, notes=None):
166
- """
167
- Save an image with notes.
168
- """
169
- if isinstance(img, torch.Tensor):
170
- img = img.cpu().numpy().transpose(1, 2, 0)
171
- if img.dtype == np.float32 or img.dtype == np.float64:
172
- img = np.clip(img * 255, 0, 255).astype(np.uint8)
173
- img = notes_on_image(img, notes)
174
- cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
175
-
176
-
177
- # debug utils
178
-
179
- def atol(x, y):
180
- """
181
- Absolute tolerance.
182
- """
183
- return torch.abs(x - y)
184
-
185
-
186
- def rtol(x, y):
187
- """
188
- Relative tolerance.
189
- """
190
- return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
191
-
192
-
193
- # print utils
194
- def indent(s, n=4):
195
- """
196
- Indent a string.
197
- """
198
- lines = s.split('\n')
199
- for i in range(1, len(lines)):
200
- lines[i] = ' ' * n + lines[i]
201
- return '\n'.join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/grad_clip_utils.py DELETED
@@ -1,81 +0,0 @@
1
- from typing import *
2
- import torch
3
- import numpy as np
4
- import torch.utils
5
-
6
-
7
- class AdaptiveGradClipper:
8
- """
9
- Adaptive gradient clipping for training.
10
- """
11
- def __init__(
12
- self,
13
- max_norm=None,
14
- clip_percentile=95.0,
15
- buffer_size=1000,
16
- ):
17
- self.max_norm = max_norm
18
- self.clip_percentile = clip_percentile
19
- self.buffer_size = buffer_size
20
-
21
- self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
22
- self._max_norm = max_norm
23
- self._buffer_ptr = 0
24
- self._buffer_length = 0
25
-
26
- def __repr__(self):
27
- return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
28
-
29
- def state_dict(self):
30
- return {
31
- 'grad_norm': self._grad_norm,
32
- 'max_norm': self._max_norm,
33
- 'buffer_ptr': self._buffer_ptr,
34
- 'buffer_length': self._buffer_length,
35
- }
36
-
37
- def load_state_dict(self, state_dict):
38
- self._grad_norm = state_dict['grad_norm']
39
- self._max_norm = state_dict['max_norm']
40
- self._buffer_ptr = state_dict['buffer_ptr']
41
- self._buffer_length = state_dict['buffer_length']
42
-
43
- def log(self):
44
- return {
45
- 'max_norm': self._max_norm,
46
- }
47
-
48
- def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
49
- """Clip the gradient norm of an iterable of parameters.
50
-
51
- The norm is computed over all gradients together, as if they were
52
- concatenated into a single vector. Gradients are modified in-place.
53
-
54
- Args:
55
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
56
- single Tensor that will have gradients normalized
57
- norm_type (float): type of the used p-norm. Can be ``'inf'`` for
58
- infinity norm.
59
- error_if_nonfinite (bool): if True, an error is thrown if the total
60
- norm of the gradients from :attr:`parameters` is ``nan``,
61
- ``inf``, or ``-inf``. Default: False (will switch to True in the future)
62
- foreach (bool): use the faster foreach-based implementation.
63
- If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
64
- fall back to the slow implementation for other device types.
65
- Default: ``None``
66
-
67
- Returns:
68
- Total norm of the parameter gradients (viewed as a single vector).
69
- """
70
- max_norm = self._max_norm if self._max_norm is not None else float('inf')
71
- grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
72
-
73
- if torch.isfinite(grad_norm):
74
- self._grad_norm[self._buffer_ptr] = grad_norm
75
- self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
76
- self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
77
- if self._buffer_length == self.buffer_size:
78
- self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
79
- self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
80
-
81
- return grad_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/loss_utils.py DELETED
@@ -1,92 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch.autograd import Variable
4
- from math import exp
5
- from lpips import LPIPS
6
-
7
-
8
- def smooth_l1_loss(pred, target, beta=1.0):
9
- diff = torch.abs(pred - target)
10
- loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
11
- return loss.mean()
12
-
13
-
14
- def l1_loss(network_output, gt):
15
- return torch.abs((network_output - gt)).mean()
16
-
17
-
18
- def l2_loss(network_output, gt):
19
- return ((network_output - gt) ** 2).mean()
20
-
21
-
22
- def gaussian(window_size, sigma):
23
- gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
24
- return gauss / gauss.sum()
25
-
26
-
27
- def create_window(window_size, channel):
28
- _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29
- _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30
- window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31
- return window
32
-
33
-
34
- def psnr(img1, img2, max_val=1.0):
35
- mse = F.mse_loss(img1, img2)
36
- return 20 * torch.log10(max_val / torch.sqrt(mse))
37
-
38
-
39
- def ssim(img1, img2, window_size=11, size_average=True):
40
- channel = img1.size(-3)
41
- window = create_window(window_size, channel)
42
-
43
- if img1.is_cuda:
44
- window = window.cuda(img1.get_device())
45
- window = window.type_as(img1)
46
-
47
- return _ssim(img1, img2, window, window_size, channel, size_average)
48
-
49
- def _ssim(img1, img2, window, window_size, channel, size_average=True):
50
- mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
51
- mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
52
-
53
- mu1_sq = mu1.pow(2)
54
- mu2_sq = mu2.pow(2)
55
- mu1_mu2 = mu1 * mu2
56
-
57
- sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
58
- sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
59
- sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
60
-
61
- C1 = 0.01 ** 2
62
- C2 = 0.03 ** 2
63
-
64
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
65
-
66
- if size_average:
67
- return ssim_map.mean()
68
- else:
69
- return ssim_map.mean(1).mean(1).mean(1)
70
-
71
-
72
- loss_fn_vgg = None
73
- def lpips(img1, img2, value_range=(0, 1)):
74
- global loss_fn_vgg
75
- if loss_fn_vgg is None:
76
- loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
77
- # normalize to [-1, 1]
78
- img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
79
- img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
80
- return loss_fn_vgg(img1, img2).mean()
81
-
82
-
83
- def normal_angle(pred, gt):
84
- pred = pred * 2.0 - 1.0
85
- gt = gt * 2.0 - 1.0
86
- norms = pred.norm(dim=-1) * gt.norm(dim=-1)
87
- cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
88
- cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
89
- ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
90
- if ang.isnan():
91
- return -1
92
- return ang
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/postprocessing_utils.py DELETED
@@ -1,587 +0,0 @@
1
- from typing import *
2
- import numpy as np
3
- import torch
4
- import utils3d
5
- import nvdiffrast.torch as dr
6
- from tqdm import tqdm
7
- import trimesh
8
- import trimesh.visual
9
- import xatlas
10
- import pyvista as pv
11
- from pymeshfix import _meshfix
12
- import igraph
13
- import cv2
14
- from PIL import Image
15
- from .random_utils import sphere_hammersley_sequence
16
- from .render_utils import render_multiview
17
- from ..renderers import GaussianRenderer
18
- from ..representations import Strivec, Gaussian, MeshExtractResult
19
-
20
-
21
- @torch.no_grad()
22
- def _fill_holes(
23
- verts,
24
- faces,
25
- max_hole_size=0.04,
26
- max_hole_nbe=32,
27
- resolution=128,
28
- num_views=500,
29
- debug=False,
30
- verbose=False
31
- ):
32
- """
33
- Rasterize a mesh from multiple views and remove invisible faces.
34
- Also includes postprocessing to:
35
- 1. Remove connected components that are have low visibility.
36
- 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
37
-
38
- Args:
39
- verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
40
- faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
41
- max_hole_size (float): Maximum area of a hole to fill.
42
- resolution (int): Resolution of the rasterization.
43
- num_views (int): Number of views to rasterize the mesh.
44
- verbose (bool): Whether to print progress.
45
- """
46
- # Construct cameras
47
- yaws = []
48
- pitchs = []
49
- for i in range(num_views):
50
- y, p = sphere_hammersley_sequence(i, num_views)
51
- yaws.append(y)
52
- pitchs.append(p)
53
- yaws = torch.tensor(yaws).cuda()
54
- pitchs = torch.tensor(pitchs).cuda()
55
- radius = 2.0
56
- fov = torch.deg2rad(torch.tensor(40)).cuda()
57
- projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
58
- views = []
59
- for (yaw, pitch) in zip(yaws, pitchs):
60
- orig = torch.tensor([
61
- torch.sin(yaw) * torch.cos(pitch),
62
- torch.cos(yaw) * torch.cos(pitch),
63
- torch.sin(pitch),
64
- ]).cuda().float() * radius
65
- view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
66
- views.append(view)
67
- views = torch.stack(views, dim=0)
68
-
69
- # Rasterize
70
- visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
71
- rastctx = utils3d.torch.RastContext(backend='cuda')
72
- for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
73
- view = views[i]
74
- buffers = utils3d.torch.rasterize_triangle_faces(
75
- rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
76
- )
77
- face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
78
- face_id = torch.unique(face_id).long()
79
- visblity[face_id] += 1
80
- visblity = visblity.float() / num_views
81
-
82
- # Mincut
83
- ## construct outer faces
84
- edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
85
- boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
86
- connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
87
- outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
88
- for i in range(len(connected_components)):
89
- outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
90
- outer_face_indices = outer_face_indices.nonzero().reshape(-1)
91
-
92
- ## construct inner faces
93
- inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
94
- if verbose:
95
- tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
96
- if inner_face_indices.shape[0] == 0:
97
- return verts, faces
98
-
99
- ## Construct dual graph (faces as nodes, edges as edges)
100
- dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
101
- dual_edge2edge = edges[dual_edge2edge]
102
- dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
103
- if verbose:
104
- tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
105
-
106
- ## solve mincut problem
107
- ### construct main graph
108
- g = igraph.Graph()
109
- g.add_vertices(faces.shape[0])
110
- g.add_edges(dual_edges.cpu().numpy())
111
- g.es['weight'] = dual_edges_weights.cpu().numpy()
112
-
113
- ### source and target
114
- g.add_vertex('s')
115
- g.add_vertex('t')
116
-
117
- ### connect invisible faces to source
118
- g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
119
-
120
- ### connect outer faces to target
121
- g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
122
-
123
- ### solve mincut
124
- cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
125
- remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
126
- if verbose:
127
- tqdm.write(f'Mincut solved, start checking the cut')
128
-
129
- ### check if the cut is valid with each connected component
130
- to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
131
- if debug:
132
- tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
133
- valid_remove_cc = []
134
- cutting_edges = []
135
- for cc in to_remove_cc:
136
- #### check if the connected component has low visibility
137
- visblity_median = visblity[remove_face_indices[cc]].median()
138
- if debug:
139
- tqdm.write(f'visblity_median: {visblity_median}')
140
- if visblity_median > 0.25:
141
- continue
142
-
143
- #### check if the cuting loop is small enough
144
- cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
145
- cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
146
- cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
147
- if len(cc_new_boundary_edge_indices) > 0:
148
- cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
149
- cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
150
- cc_new_boundary_edges_cc_area = []
151
- for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
152
- _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
153
- _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
154
- cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
155
- if debug:
156
- cutting_edges.append(cc_new_boundary_edge_indices)
157
- tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
158
- if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
159
- continue
160
-
161
- valid_remove_cc.append(cc)
162
-
163
- if debug:
164
- face_v = verts[faces].mean(dim=1).cpu().numpy()
165
- vis_dual_edges = dual_edges.cpu().numpy()
166
- vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
167
- vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
168
- vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
169
- vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
170
- if len(valid_remove_cc) > 0:
171
- vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
172
- utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
173
-
174
- vis_verts = verts.cpu().numpy()
175
- vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
176
- utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
177
-
178
-
179
- if len(valid_remove_cc) > 0:
180
- remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
181
- mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
182
- mask[remove_face_indices] = 0
183
- faces = faces[mask]
184
- faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
185
- if verbose:
186
- tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
187
- else:
188
- if verbose:
189
- tqdm.write(f'Removed 0 faces by mincut')
190
-
191
- mesh = _meshfix.PyTMesh()
192
- mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
193
- mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
194
- verts, faces = mesh.return_arrays()
195
- verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
196
-
197
- return verts, faces
198
-
199
-
200
- def postprocess_mesh(
201
- vertices: np.array,
202
- faces: np.array,
203
- simplify: bool = True,
204
- simplify_ratio: float = 0.9,
205
- fill_holes: bool = True,
206
- fill_holes_max_hole_size: float = 0.04,
207
- fill_holes_max_hole_nbe: int = 32,
208
- fill_holes_resolution: int = 1024,
209
- fill_holes_num_views: int = 1000,
210
- debug: bool = False,
211
- verbose: bool = False,
212
- ):
213
- """
214
- Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
215
-
216
- Args:
217
- vertices (np.array): Vertices of the mesh. Shape (V, 3).
218
- faces (np.array): Faces of the mesh. Shape (F, 3).
219
- simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
220
- simplify_ratio (float): Ratio of faces to keep after simplification.
221
- fill_holes (bool): Whether to fill holes in the mesh.
222
- fill_holes_max_hole_size (float): Maximum area of a hole to fill.
223
- fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
224
- fill_holes_resolution (int): Resolution of the rasterization.
225
- fill_holes_num_views (int): Number of views to rasterize the mesh.
226
- verbose (bool): Whether to print progress.
227
- """
228
-
229
- if verbose:
230
- tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
231
-
232
- # Simplify
233
- if simplify and simplify_ratio > 0:
234
- mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
235
- mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
236
- vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
237
- if verbose:
238
- tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
239
-
240
- # Remove invisible faces
241
- if fill_holes:
242
- vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
243
- vertices, faces = _fill_holes(
244
- vertices, faces,
245
- max_hole_size=fill_holes_max_hole_size,
246
- max_hole_nbe=fill_holes_max_hole_nbe,
247
- resolution=fill_holes_resolution,
248
- num_views=fill_holes_num_views,
249
- debug=debug,
250
- verbose=verbose,
251
- )
252
- vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
253
- if verbose:
254
- tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
255
-
256
- return vertices, faces
257
-
258
-
259
- def parametrize_mesh(vertices: np.array, faces: np.array):
260
- """
261
- Parametrize a mesh to a texture space, using xatlas.
262
-
263
- Args:
264
- vertices (np.array): Vertices of the mesh. Shape (V, 3).
265
- faces (np.array): Faces of the mesh. Shape (F, 3).
266
- """
267
-
268
- vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
269
-
270
- vertices = vertices[vmapping]
271
- faces = indices
272
-
273
- return vertices, faces, uvs
274
-
275
-
276
- def bake_texture(
277
- vertices: np.array,
278
- faces: np.array,
279
- uvs: np.array,
280
- observations: List[np.array],
281
- masks: List[np.array],
282
- extrinsics: List[np.array],
283
- intrinsics: List[np.array],
284
- texture_size: int = 2048,
285
- near: float = 0.1,
286
- far: float = 10.0,
287
- mode: Literal['fast', 'opt'] = 'opt',
288
- lambda_tv: float = 1e-2,
289
- verbose: bool = False,
290
- ):
291
- """
292
- Bake texture to a mesh from multiple observations.
293
-
294
- Args:
295
- vertices (np.array): Vertices of the mesh. Shape (V, 3).
296
- faces (np.array): Faces of the mesh. Shape (F, 3).
297
- uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
298
- observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
299
- masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
300
- extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
301
- intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
302
- texture_size (int): Size of the texture.
303
- near (float): Near plane of the camera.
304
- far (float): Far plane of the camera.
305
- mode (Literal['fast', 'opt']): Mode of texture baking.
306
- lambda_tv (float): Weight of total variation loss in optimization.
307
- verbose (bool): Whether to print progress.
308
- """
309
- vertices = torch.tensor(vertices).cuda()
310
- faces = torch.tensor(faces.astype(np.int32)).cuda()
311
- uvs = torch.tensor(uvs).cuda()
312
- observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
313
- masks = [torch.tensor(m>0).bool().cuda() for m in masks]
314
- views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
315
- projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]
316
-
317
- if mode == 'fast':
318
- texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
319
- texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
320
- rastctx = utils3d.torch.RastContext(backend='cuda')
321
- for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
322
- with torch.no_grad():
323
- rast = utils3d.torch.rasterize_triangle_faces(
324
- rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
325
- )
326
- uv_map = rast['uv'][0].detach().flip(0)
327
- mask = rast['mask'][0].detach().bool() & masks[0]
328
-
329
- # nearest neighbor interpolation
330
- uv_map = (uv_map * texture_size).floor().long()
331
- obs = observation[mask]
332
- uv_map = uv_map[mask]
333
- idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
334
- texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
335
- texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
336
-
337
- mask = texture_weights > 0
338
- texture[mask] /= texture_weights[mask][:, None]
339
- texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
340
-
341
- # inpaint
342
- mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
343
- texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
344
-
345
- elif mode == 'opt':
346
- rastctx = utils3d.torch.RastContext(backend='cuda')
347
- observations = [observations.flip(0) for observations in observations]
348
- masks = [m.flip(0) for m in masks]
349
- _uv = []
350
- _uv_dr = []
351
- for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
352
- with torch.no_grad():
353
- rast = utils3d.torch.rasterize_triangle_faces(
354
- rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
355
- )
356
- _uv.append(rast['uv'].detach())
357
- _uv_dr.append(rast['uv_dr'].detach())
358
-
359
- texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
360
- optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
361
-
362
- def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
363
- return start_lr * (end_lr / start_lr) ** (step / total_steps)
364
-
365
- def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
366
- return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
367
-
368
- def tv_loss(texture):
369
- return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
370
- torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
371
-
372
- total_steps = 2500
373
- with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
374
- for step in range(total_steps):
375
- optimizer.zero_grad()
376
- selected = np.random.randint(0, len(views))
377
- uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
378
- render = dr.texture(texture, uv, uv_dr)[0]
379
- loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
380
- if lambda_tv > 0:
381
- loss += lambda_tv * tv_loss(texture)
382
- loss.backward()
383
- optimizer.step()
384
- # annealing
385
- optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5)
386
- pbar.set_postfix({'loss': loss.item()})
387
- pbar.update()
388
- texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
389
- mask = 1 - utils3d.torch.rasterize_triangle_faces(
390
- rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
391
- )['mask'][0].detach().cpu().numpy().astype(np.uint8)
392
- texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
393
- else:
394
- raise ValueError(f'Unknown mode: {mode}')
395
-
396
- return texture
397
-
398
-
399
- def to_glb(
400
- app_rep: Union[Strivec, Gaussian],
401
- mesh: MeshExtractResult,
402
- simplify: float = 0.95,
403
- fill_holes: bool = True,
404
- fill_holes_max_size: float = 0.04,
405
- texture_size: int = 1024,
406
- debug: bool = False,
407
- verbose: bool = True,
408
- ) -> trimesh.Trimesh:
409
- """
410
- Convert a generated asset to a glb file.
411
-
412
- Args:
413
- app_rep (Union[Strivec, Gaussian]): Appearance representation.
414
- mesh (MeshExtractResult): Extracted mesh.
415
- simplify (float): Ratio of faces to remove in simplification.
416
- fill_holes (bool): Whether to fill holes in the mesh.
417
- fill_holes_max_size (float): Maximum area of a hole to fill.
418
- texture_size (int): Size of the texture.
419
- debug (bool): Whether to print debug information.
420
- verbose (bool): Whether to print progress.
421
- """
422
- vertices = mesh.vertices.cpu().numpy()
423
- faces = mesh.faces.cpu().numpy()
424
-
425
- # mesh postprocess
426
- vertices, faces = postprocess_mesh(
427
- vertices, faces,
428
- simplify=simplify > 0,
429
- simplify_ratio=simplify,
430
- fill_holes=fill_holes,
431
- fill_holes_max_hole_size=fill_holes_max_size,
432
- fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)),
433
- fill_holes_resolution=1024,
434
- fill_holes_num_views=1000,
435
- debug=debug,
436
- verbose=verbose,
437
- )
438
-
439
- # parametrize mesh
440
- vertices, faces, uvs = parametrize_mesh(vertices, faces)
441
-
442
- # bake texture
443
- observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
444
- masks = [np.any(observation > 0, axis=-1) for observation in observations]
445
- extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
446
- intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
447
- texture = bake_texture(
448
- vertices, faces, uvs,
449
- observations, masks, extrinsics, intrinsics,
450
- texture_size=texture_size, mode='opt',
451
- lambda_tv=0.01,
452
- verbose=verbose
453
- )
454
- texture = Image.fromarray(texture)
455
-
456
- # rotate mesh (from z-up to y-up)
457
- vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
458
- material = trimesh.visual.material.PBRMaterial(
459
- roughnessFactor=1.0,
460
- baseColorTexture=texture,
461
- baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
462
- )
463
- mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
464
- return mesh
465
-
466
-
467
- def simplify_gs(
468
- gs: Gaussian,
469
- simplify: float = 0.95,
470
- verbose: bool = True,
471
- ):
472
- """
473
- Simplify 3D Gaussians
474
- NOTE: this function is not used in the current implementation for the unsatisfactory performance.
475
-
476
- Args:
477
- gs (Gaussian): 3D Gaussian.
478
- simplify (float): Ratio of Gaussians to remove in simplification.
479
- """
480
- if simplify <= 0:
481
- return gs
482
-
483
- # simplify
484
- observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
485
- observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]
486
-
487
- # Following https://arxiv.org/pdf/2411.06019
488
- renderer = GaussianRenderer({
489
- "resolution": 1024,
490
- "near": 0.8,
491
- "far": 1.6,
492
- "ssaa": 1,
493
- "bg_color": (0,0,0),
494
- })
495
- new_gs = Gaussian(**gs.init_params)
496
- new_gs._features_dc = gs._features_dc.clone()
497
- new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
498
- new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
499
- new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
500
- new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
501
- new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())
502
-
503
- start_lr = [1e-4, 1e-3, 5e-3, 0.025]
504
- end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
505
- optimizer = torch.optim.Adam([
506
- {"params": new_gs._xyz, "lr": start_lr[0]},
507
- {"params": new_gs._rotation, "lr": start_lr[1]},
508
- {"params": new_gs._scaling, "lr": start_lr[2]},
509
- {"params": new_gs._opacity, "lr": start_lr[3]},
510
- ], lr=start_lr[0])
511
-
512
- def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
513
- return start_lr * (end_lr / start_lr) ** (step / total_steps)
514
-
515
- def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
516
- return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
517
-
518
- _zeta = new_gs.get_opacity.clone().detach().squeeze()
519
- _lambda = torch.zeros_like(_zeta)
520
- _delta = 1e-7
521
- _interval = 10
522
- num_target = int((1 - simplify) * _zeta.shape[0])
523
-
524
- with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
525
- for i in range(2500):
526
- # prune
527
- if i % 100 == 0:
528
- mask = new_gs.get_opacity.squeeze() > 0.05
529
- mask = torch.nonzero(mask).squeeze()
530
- new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
531
- new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
532
- new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
533
- new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
534
- new_gs._features_dc = new_gs._features_dc[mask]
535
- new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
536
- _zeta = _zeta[mask]
537
- _lambda = _lambda[mask]
538
- # update optimizer state
539
- for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
540
- stored_state = optimizer.state[param_group['params'][0]]
541
- if 'exp_avg' in stored_state:
542
- stored_state['exp_avg'] = stored_state['exp_avg'][mask]
543
- stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
544
- del optimizer.state[param_group['params'][0]]
545
- param_group['params'][0] = new_param
546
- optimizer.state[param_group['params'][0]] = stored_state
547
-
548
- opacity = new_gs.get_opacity.squeeze()
549
-
550
- # sparisfy
551
- if i % _interval == 0:
552
- _zeta = _lambda + opacity.detach()
553
- if opacity.shape[0] > num_target:
554
- index = _zeta.topk(num_target)[1]
555
- _m = torch.ones_like(_zeta, dtype=torch.bool)
556
- _m[index] = 0
557
- _zeta[_m] = 0
558
- _lambda = _lambda + opacity.detach() - _zeta
559
-
560
- # sample a random view
561
- view_idx = np.random.randint(len(observations))
562
- observation = observations[view_idx]
563
- extrinsic = extrinsics[view_idx]
564
- intrinsic = intrinsics[view_idx]
565
-
566
- color = renderer.render(new_gs, extrinsic, intrinsic)['color']
567
- rgb_loss = torch.nn.functional.l1_loss(color, observation)
568
- loss = rgb_loss + \
569
- _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))
570
-
571
- optimizer.zero_grad()
572
- loss.backward()
573
- optimizer.step()
574
-
575
- # update lr
576
- for j in range(len(optimizer.param_groups)):
577
- optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j])
578
-
579
- pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
580
- pbar.update()
581
-
582
- new_gs._xyz = new_gs._xyz.data
583
- new_gs._rotation = new_gs._rotation.data
584
- new_gs._scaling = new_gs._scaling.data
585
- new_gs._opacity = new_gs._opacity.data
586
-
587
- return new_gs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/random_utils.py DELETED
@@ -1,30 +0,0 @@
1
- import numpy as np
2
-
3
- PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
4
-
5
- def radical_inverse(base, n):
6
- val = 0
7
- inv_base = 1.0 / base
8
- inv_base_n = inv_base
9
- while n > 0:
10
- digit = n % base
11
- val += digit * inv_base_n
12
- n //= base
13
- inv_base_n *= inv_base
14
- return val
15
-
16
- def halton_sequence(dim, n):
17
- return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
18
-
19
- def hammersley_sequence(dim, n, num_samples):
20
- return [n / num_samples] + halton_sequence(dim - 1, n)
21
-
22
- def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
23
- u, v = hammersley_sequence(2, n, num_samples)
24
- u += offset[0] / num_samples
25
- v += offset[1]
26
- if remap:
27
- u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
28
- theta = np.arccos(1 - 2 * u) - np.pi / 2
29
- phi = v * 2 * np.pi
30
- return [phi, theta]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/utils/render_utils.py DELETED
@@ -1,120 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from tqdm import tqdm
4
- import utils3d
5
- from PIL import Image
6
-
7
- from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
8
- from ..representations import Octree, Gaussian, MeshExtractResult
9
- from ..modules import sparse as sp
10
- from .random_utils import sphere_hammersley_sequence
11
-
12
-
13
- def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
14
- is_list = isinstance(yaws, list)
15
- if not is_list:
16
- yaws = [yaws]
17
- pitchs = [pitchs]
18
- if not isinstance(rs, list):
19
- rs = [rs] * len(yaws)
20
- if not isinstance(fovs, list):
21
- fovs = [fovs] * len(yaws)
22
- extrinsics = []
23
- intrinsics = []
24
- for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
25
- fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
26
- yaw = torch.tensor(float(yaw)).cuda()
27
- pitch = torch.tensor(float(pitch)).cuda()
28
- orig = torch.tensor([
29
- torch.sin(yaw) * torch.cos(pitch),
30
- torch.cos(yaw) * torch.cos(pitch),
31
- torch.sin(pitch),
32
- ]).cuda() * r
33
- extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
34
- intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
35
- extrinsics.append(extr)
36
- intrinsics.append(intr)
37
- if not is_list:
38
- extrinsics = extrinsics[0]
39
- intrinsics = intrinsics[0]
40
- return extrinsics, intrinsics
41
-
42
-
43
- def get_renderer(sample, **kwargs):
44
- if isinstance(sample, Octree):
45
- renderer = OctreeRenderer()
46
- renderer.rendering_options.resolution = kwargs.get('resolution', 512)
47
- renderer.rendering_options.near = kwargs.get('near', 0.8)
48
- renderer.rendering_options.far = kwargs.get('far', 1.6)
49
- renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
50
- renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
51
- renderer.pipe.primitive = sample.primitive
52
- elif isinstance(sample, Gaussian):
53
- renderer = GaussianRenderer()
54
- renderer.rendering_options.resolution = kwargs.get('resolution', 512)
55
- renderer.rendering_options.near = kwargs.get('near', 0.8)
56
- renderer.rendering_options.far = kwargs.get('far', 1.6)
57
- renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
58
- renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
59
- renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
60
- renderer.pipe.use_mip_gaussian = True
61
- elif isinstance(sample, MeshExtractResult):
62
- renderer = MeshRenderer()
63
- renderer.rendering_options.resolution = kwargs.get('resolution', 512)
64
- renderer.rendering_options.near = kwargs.get('near', 1)
65
- renderer.rendering_options.far = kwargs.get('far', 100)
66
- renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
67
- else:
68
- raise ValueError(f'Unsupported sample type: {type(sample)}')
69
- return renderer
70
-
71
-
72
- def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
73
- renderer = get_renderer(sample, **options)
74
- rets = {}
75
- for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
76
- if isinstance(sample, MeshExtractResult):
77
- res = renderer.render(sample, extr, intr)
78
- if 'normal' not in rets: rets['normal'] = []
79
- rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
80
- else:
81
- res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
82
- if 'color' not in rets: rets['color'] = []
83
- if 'depth' not in rets: rets['depth'] = []
84
- rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
85
- if 'percent_depth' in res:
86
- rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
87
- elif 'depth' in res:
88
- rets['depth'].append(res['depth'].detach().cpu().numpy())
89
- else:
90
- rets['depth'].append(None)
91
- return rets
92
-
93
-
94
- def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
95
- yaws = torch.linspace(0, 2 * 3.1415, num_frames)
96
- pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
97
- yaws = yaws.tolist()
98
- pitch = pitch.tolist()
99
- extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
100
- return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
101
-
102
-
103
- def render_multiview(sample, resolution=512, nviews=30):
104
- r = 2
105
- fov = 40
106
- cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
107
- yaws = [cam[0] for cam in cams]
108
- pitchs = [cam[1] for cam in cams]
109
- extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
110
- res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
111
- return res['color'], extrinsics, intrinsics
112
-
113
-
114
- def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
115
- yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
116
- yaw_offset = offset[0]
117
- yaw = [y + yaw_offset for y in yaw]
118
- pitch = [offset[1] for _ in range(4)]
119
- extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
120
- return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)