Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .discriminator import Discriminator | |
| from .identity import Identity | |
class MultiScaleDiscriminator(nn.Module):
    """Run a stack of ``Discriminator`` instances at progressively coarser scales.

    Scale 0 sees the input unchanged. Every later scale sees the previous
    scale's input downsampled once more by
    ``AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False)``,
    which maps a last-dimension length ``L`` to ``floor(L / 2)`` (for ``L >= 2``).

    Args:
        num_scales: Number of scales, i.e. number of discriminators. Defaults
            to 3, the value that was previously hard-coded.

    Raises:
        ValueError: If ``num_scales`` is less than 1.
    """

    def __init__(self, num_scales: int = 3):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        # Attribute names are part of the state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.discriminators = nn.ModuleList(
            [Discriminator() for _ in range(num_scales)]
        )
        # One pooling stage per scale: identity for the first, then one
        # 2x average-pool per additional scale (applied cumulatively in forward).
        self.pooling = nn.ModuleList(
            [Identity()]
            + [
                nn.AvgPool1d(
                    kernel_size=4, stride=2, padding=1, count_include_pad=False
                )
                for _ in range(1, num_scales)
            ]
        )

    def forward(self, x):
        """Apply every scale's discriminator and collect their outputs.

        Args:
            x: Tensor accepted by ``nn.AvgPool1d`` (``(N, C, L)`` or ``(C, L)``).
                Presumably a batch of waveforms ``(batch, 1, time)``; TODO
                confirm against the caller.

        Returns:
            list: One entry per scale, ordered from finest to coarsest. Each
            entry is whatever ``Discriminator.forward`` returns. The original
            author noted ``(feat, score)`` pairs, which cannot be verified here.
        """
        outputs = []
        for pool, disc in zip(self.pooling, self.discriminators):
            # ``x`` is rebound, so pooling is cumulative: each scale downsamples
            # the previous scale's input, not the original signal.
            x = pool(x)
            outputs.append(disc(x))
        return outputs