File size: 327 Bytes
1a030c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch.nn as nn
import torch
def model_device(m: nn.Module) -> torch.device:
    """Return the device on which module ``m`` keeps its tensors.

    The device of the first parameter yielded by ``m.parameters()`` is
    returned. If ``m`` has no parameters, the first tensor yielded by
    ``m.buffers()`` is used instead. Only a single tensor is inspected, so a
    module whose tensors are spread across several devices is not detected.

    Args:
        m: The module to inspect.

    Returns:
        The ``torch.device`` of the first parameter (or, failing that, buffer).

    Raises:
        ValueError: If ``m`` has neither parameters nor buffers.
    """
    # next(..., None) avoids leaking a bare StopIteration out of this function,
    # which would be converted to RuntimeError if called inside a generator.
    tensor = next(m.parameters(), None)
    if tensor is None:
        # Buffer-only modules (e.g. ones holding running statistics) still
        # have a well-defined device.
        tensor = next(m.buffers(), None)
    if tensor is None:
        raise ValueError(
            f"cannot determine device of {type(m).__name__}: "
            "it has no parameters or buffers"
        )
    return tensor.device
def model_numel(m: nn.Module, requires_grad=False):
    """Count the scalar elements held in the parameters of ``m``.

    Args:
        m: Module whose parameters (as yielded by ``m.parameters()``) are
            counted.
        requires_grad: When truthy, only parameters whose ``requires_grad``
            flag is set contribute to the count; otherwise every parameter
            is counted.

    Returns:
        The total number of elements, as an int (0 if there are none).
    """
    trainable_only = bool(requires_grad)
    total = 0
    for param in m.parameters():
        # Skip frozen parameters only when the caller asked for trainable ones.
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
|