From 31b78e38f9b5b32c18746e822e4d47f7b1563c95 Mon Sep 17 00:00:00 2001 From: aufr33 <65520685+aufr33@users.noreply.github.com> Date: Sun, 26 Sep 2021 12:36:53 +0300 Subject: [PATCH] Delete nets_129605KB.py --- lib/nets_129605KB.py | 116 ------------------------------------------- 1 file changed, 116 deletions(-) delete mode 100644 lib/nets_129605KB.py diff --git a/lib/nets_129605KB.py b/lib/nets_129605KB.py deleted file mode 100644 index f90c19b..0000000 --- a/lib/nets_129605KB.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - -from lib import layers_129605KB as layers - - -class BaseASPPNet(nn.Module): - - def __init__(self, nin, ch, dilations=(4, 8, 16, 32)): - super(BaseASPPNet, self).__init__() - self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) - self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) - self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) - self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) - self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1) - - self.aspp = layers.ASPPModule(ch * 16, ch * 32, dilations) - - self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1) - self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) - self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) - self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) - self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) - - def __call__(self, x): - h, e1 = self.enc1(x) - h, e2 = self.enc2(h) - h, e3 = self.enc3(h) - h, e4 = self.enc4(h) - h, e5 = self.enc5(h) - - h = self.aspp(h) - - h = self.dec5(h, e5) - h = self.dec4(h, e4) - h = self.dec3(h, e3) - h = self.dec2(h, e2) - h = self.dec1(h, e1) - - return h - - -class CascadedASPPNet(nn.Module): - - def __init__(self, n_fft): - super(CascadedASPPNet, self).__init__() - self.stg1_low_band_net = BaseASPPNet(2, 16) - self.stg1_high_band_net = BaseASPPNet(2, 16) - - self.stg2_bridge = layers.Conv2DBNActiv(18, 8, 1, 1, 0) - self.stg2_full_band_net = BaseASPPNet(8, 16) - - self.stg3_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) - self.stg3_full_band_net = BaseASPPNet(16, 32) - - self.out = nn.Conv2d(32, 2, 1, bias=False) - self.aux1_out = nn.Conv2d(16, 2, 1, bias=False) - self.aux2_out = nn.Conv2d(16, 2, 1, bias=False) - - self.max_bin = n_fft // 2 - self.output_bin = n_fft // 2 + 1 - - self.offset = 128 - - def forward(self, x, aggressiveness=None): - mix = x.detach() - x = x.clone() - - x = x[:, :, :self.max_bin] - - bandw = x.size()[2] // 2 - aux1 = torch.cat([ - self.stg1_low_band_net(x[:, :, :bandw]), - self.stg1_high_band_net(x[:, :, bandw:]) - ], dim=2) - - h = torch.cat([x, aux1], dim=1) - aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) - - h = torch.cat([x, aux1, aux2], dim=1) - h = self.stg3_full_band_net(self.stg3_bridge(h)) - - mask = torch.sigmoid(self.out(h)) - mask = F.pad( - input=mask, - pad=(0, 0, 0, self.output_bin - mask.size()[2]), - mode='replicate') - - if self.training: - aux1 = torch.sigmoid(self.aux1_out(aux1)) - aux1 = F.pad( - input=aux1, - pad=(0, 0, 0, self.output_bin - aux1.size()[2]), - mode='replicate') - aux2 = torch.sigmoid(self.aux2_out(aux2)) - aux2 = F.pad( - input=aux2, - pad=(0, 0, 0, self.output_bin - aux2.size()[2]), - mode='replicate') - return mask * mix, aux1 * mix, aux2 * mix - else: - if aggressiveness: - mask[:, :, :aggressiveness['split_bin']] = torch.pow(mask[:, :, :aggressiveness['split_bin']], 1 + aggressiveness['value'] / 3) - mask[:, :, aggressiveness['split_bin']:] = torch.pow(mask[:, :, aggressiveness['split_bin']:], 1 + aggressiveness['value']) - - return mask * mix - - def predict(self, x_mag, aggressiveness=None): - h = self.forward(x_mag, aggressiveness) - - if self.offset > 0: - h = h[:, :, :, self.offset:-self.offset] - assert h.size()[3] > 0 - - return h