from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional

import torch
import torch.nn as nn

from library.device_utils import clean_memory_on_device


def synchronize_device(device: torch.device):
    """Synchronize the given device; no-op for device types other than CUDA, XPU, and MPS."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Swap the weight tensors of two structurally identical layers between CPU and CUDA, reusing the CUDA storage."""
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []

    # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
    # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
    #     print(module_to_cpu.__class__, module_to_cuda.__class__)
    #     if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
    #         weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
    for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
        if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
            module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
            if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
                weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
            else:
                if module_to_cuda.weight.data.device.type != device.type:
                    # print(
                    #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
                    # )
                    module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

        stream.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    Swap the weight tensors of two structurally identical layers between CPU and a non-CUDA device. Not tested.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # device to cpu
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

    synchronize_device(device)

    # cpu to device
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = cuda_data_view

    synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
    """Move only the weight tensors of the layer's modules to the given device; biases and buffers are left untouched."""
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


class Offloader:
    """
    Common offloading logic: swaps block weights between CPU and the compute device on a background thread.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        if self.cuda_available:
            swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            if self.debug:
                start_time = time.perf_counter()
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda  # , event

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        future = self.futures.pop(block_idx)
        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class ModelOffloader(Offloader):
    """
    Offloads model blocks during the forward pass and swaps them back via registered backward hooks.
    """

    def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)

        # register backward hooks
        self.remove_handles = []
        for i, block in enumerate(blocks):
            hook = self.create_backward_hook(blocks, i)
            if hook is not None:
                handle = block.register_full_backward_hook(hook)
                self.remove_handles.append(handle)

    def __del__(self):
        for handle in self.remove_handles:
            handle.remove()

    def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        # -1 for 0-based index
        num_blocks_propagated = self.num_blocks - block_index - 1
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1
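        # Worked example (hypothetical numbers): with num_blocks=40 and blocks_to_swap=10, the
        # hook for block 35 sees num_blocks_propagated = 40 - 35 - 1 = 4, so it submits a swap
        # that offloads block 36 (already backpropagated) to CPU and prefetches block 6 onto the
        # device. The hook for block 5 does not swap; it only waits until block 4 has finished
        # moving onto the device before block 4's backward runs.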

        def backward_hook(module, grad_input, grad_output):
            if self.debug:
                print(f"Backward hook for block {block_index}")

            if swapping:
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
            if waiting:
                self._wait_blocks_move(block_idx_to_wait)
            return None

        return backward_hook

    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

        if self.debug:
            print("Prepare block devices before forward")

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
            return
        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
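

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows the
# intended call pattern for ModelOffloader in an inference-style forward loop.
# The toy nn.Linear blocks, their sizes, and the __main__ guard below are
# assumptions made for this example; real callers pass their model's blocks.
# Running it requires a CUDA device and the surrounding `library` package
# (for clean_memory_on_device).
# ---------------------------------------------------------------------------
if __name__ == "__main__" and torch.cuda.is_available():
    device = torch.device("cuda")
    num_blocks = 8
    blocks_to_swap = 4  # keep num_blocks - blocks_to_swap blocks resident on the device

    blocks = [nn.Linear(1024, 1024) for _ in range(num_blocks)]  # hypothetical stand-in blocks

    offloader = ModelOffloader(blocks, num_blocks, blocks_to_swap, device, debug=True)
    offloader.prepare_block_devices_before_forward(blocks)

    with torch.no_grad():
        x = torch.randn(4, 1024, device=device)
        for block_idx, block in enumerate(blocks):
            offloader.wait_for_block(block_idx)  # make sure this block's weights have arrived
            x = block(x)
            offloader.submit_move_blocks(blocks, block_idx)  # swap it out in the background

    print(x.shape)  # during training, the registered backward hooks swap blocks during backward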