Add files via upload

This commit is contained in:
Anjok07 2023-04-12 02:13:30 -05:00 committed by GitHub
parent 18d32660db
commit 6ffd7a244e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 44 deletions

View File

@ -118,7 +118,7 @@ class CascadedASPPNet(nn.Module):
self.offset = 128
def forward(self, x, aggressiveness=None):
def forward(self, x):
mix = x.detach()
x = x.clone()
@ -155,17 +155,12 @@ class CascadedASPPNet(nn.Module):
mode='replicate')
return mask * mix, aux1 * mix, aux2 * mix
else:
if aggressiveness:
mask[:, :, :aggressiveness['split_bin']] = torch.pow(mask[:, :, :aggressiveness['split_bin']], 1 + aggressiveness['value'] / 3)
mask[:, :, aggressiveness['split_bin']:] = torch.pow(mask[:, :, aggressiveness['split_bin']:], 1 + aggressiveness['value'])
return mask# * mix
return mask * mix
def predict(self, x_mag, aggressiveness=None):
h = self.forward(x_mag, aggressiveness)
def predict_mask(self, x):
mask = self.forward(x)
if self.offset > 0:
h = h[:, :, :, self.offset:-self.offset]
assert h.size()[3] > 0
mask = mask[:, :, :, self.offset:-self.offset]
return h
return mask

View File

@ -40,50 +40,32 @@ class BaseNet(nn.Module):
class CascadedNet(nn.Module):
def __init__(self, n_fft, nn_architecture):
def __init__(self, n_fft, nn_arch_size, nout=32, nout_lstm=128):
super(CascadedNet, self).__init__()
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64
self.nn_architecture = nn_architecture
nout = 64 if nn_arch_size == 218409 else nout
print('ARC SIZE: ', nn_architecture)
if nn_architecture == 218409:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64)
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
self.stg2_low_band_net = nn.Sequential(
BaseNet(18, 64, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(64, 32, 1, 1, 0)
self.stg2_low_band_net = nn.Sequential(
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64)
self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128)
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
self.out = nn.Conv2d(64, 2, 1, bias=False)
self.aux_out = nn.Conv2d(48, 2, 1, bias=False)
else:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 16, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(16, 8, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)
self.stg2_low_band_net = nn.Sequential(
BaseNet(10, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)
self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
self.out = nn.Conv2d(32, 2, 1, bias=False)
self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
self.out = nn.Conv2d(nout, 2, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
def forward(self, x):
x = x[:, :, :self.max_bin]