mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-28 01:10:56 +01:00
Add files via upload
This commit is contained in:
parent
ff919b039c
commit
f777f7aa69
@ -1,15 +1,11 @@
|
|||||||
from abc import ABCMeta
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pytorch_lightning import LightningModule
|
|
||||||
from .modules import TFC_TDF
|
from .modules import TFC_TDF
|
||||||
|
from pytorch_lightning import LightningModule
|
||||||
|
|
||||||
dim_s = 4
|
dim_s = 4
|
||||||
|
|
||||||
class AbstractMDXNet(LightningModule):
|
class AbstractMDXNet(LightningModule):
|
||||||
__metaclass__ = ABCMeta
|
|
||||||
|
|
||||||
def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
|
def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.target_name = target_name
|
self.target_name = target_name
|
||||||
@ -24,7 +20,7 @@ class AbstractMDXNet(LightningModule):
|
|||||||
self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
|
self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
|
||||||
self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
|
self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def get_optimizer(self):
|
||||||
if self.optimizer == 'rmsprop':
|
if self.optimizer == 'rmsprop':
|
||||||
return torch.optim.RMSprop(self.parameters(), self.lr)
|
return torch.optim.RMSprop(self.parameters(), self.lr)
|
||||||
|
|
||||||
@ -37,7 +33,7 @@ class ConvTDFNet(AbstractMDXNet):
|
|||||||
|
|
||||||
super(ConvTDFNet, self).__init__(
|
super(ConvTDFNet, self).__init__(
|
||||||
target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
|
target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
|
||||||
self.save_hyperparameters()
|
#self.save_hyperparameters()
|
||||||
|
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.l = l
|
self.l = l
|
||||||
|
1033
lib_v5/spec_utils.py
1033
lib_v5/spec_utils.py
File diff suppressed because it is too large
Load Diff
234
lib_v5/tfc_tdf_v3.py
Normal file
234
lib_v5/tfc_tdf_v3.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
class STFT:
|
||||||
|
def __init__(self, n_fft, hop_length, dim_f):
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
||||||
|
self.dim_f = dim_f
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
window = self.window.to(x.device)
|
||||||
|
batch_dims = x.shape[:-2]
|
||||||
|
c, t = x.shape[-2:]
|
||||||
|
x = x.reshape([-1, t])
|
||||||
|
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False)
|
||||||
|
x = x.permute([0, 3, 1, 2])
|
||||||
|
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
|
||||||
|
return x[..., :self.dim_f, :]
|
||||||
|
|
||||||
|
def inverse(self, x):
|
||||||
|
window = self.window.to(x.device)
|
||||||
|
batch_dims = x.shape[:-3]
|
||||||
|
c, f, t = x.shape[-3:]
|
||||||
|
n = self.n_fft // 2 + 1
|
||||||
|
f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
|
||||||
|
x = torch.cat([x, f_pad], -2)
|
||||||
|
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
|
||||||
|
x = x.permute([0, 2, 3, 1])
|
||||||
|
x = x[..., 0] + x[..., 1] * 1.j
|
||||||
|
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
|
||||||
|
x = x.reshape([*batch_dims, 2, -1])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm(norm_type):
|
||||||
|
def norm(c, norm_type):
|
||||||
|
if norm_type == 'BatchNorm':
|
||||||
|
return nn.BatchNorm2d(c)
|
||||||
|
elif norm_type == 'InstanceNorm':
|
||||||
|
return nn.InstanceNorm2d(c, affine=True)
|
||||||
|
elif 'GroupNorm' in norm_type:
|
||||||
|
g = int(norm_type.replace('GroupNorm', ''))
|
||||||
|
return nn.GroupNorm(num_groups=g, num_channels=c)
|
||||||
|
else:
|
||||||
|
return nn.Identity()
|
||||||
|
|
||||||
|
return partial(norm, norm_type=norm_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_act(act_type):
|
||||||
|
if act_type == 'gelu':
|
||||||
|
return nn.GELU()
|
||||||
|
elif act_type == 'relu':
|
||||||
|
return nn.ReLU()
|
||||||
|
elif act_type[:3] == 'elu':
|
||||||
|
alpha = float(act_type.replace('elu', ''))
|
||||||
|
return nn.ELU(alpha)
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
|
class Upscale(nn.Module):
|
||||||
|
def __init__(self, in_c, out_c, scale, norm, act):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
norm(in_c),
|
||||||
|
act,
|
||||||
|
nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Downscale(nn.Module):
|
||||||
|
def __init__(self, in_c, out_c, scale, norm, act):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
norm(in_c),
|
||||||
|
act,
|
||||||
|
nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TFC_TDF(nn.Module):
|
||||||
|
def __init__(self, in_c, c, l, f, bn, norm, act):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
for i in range(l):
|
||||||
|
block = nn.Module()
|
||||||
|
|
||||||
|
block.tfc1 = nn.Sequential(
|
||||||
|
norm(in_c),
|
||||||
|
act,
|
||||||
|
nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
|
||||||
|
)
|
||||||
|
block.tdf = nn.Sequential(
|
||||||
|
norm(c),
|
||||||
|
act,
|
||||||
|
nn.Linear(f, f // bn, bias=False),
|
||||||
|
norm(c),
|
||||||
|
act,
|
||||||
|
nn.Linear(f // bn, f, bias=False),
|
||||||
|
)
|
||||||
|
block.tfc2 = nn.Sequential(
|
||||||
|
norm(c),
|
||||||
|
act,
|
||||||
|
nn.Conv2d(c, c, 3, 1, 1, bias=False),
|
||||||
|
)
|
||||||
|
block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
|
||||||
|
|
||||||
|
self.blocks.append(block)
|
||||||
|
in_c = c
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for block in self.blocks:
|
||||||
|
s = block.shortcut(x)
|
||||||
|
x = block.tfc1(x)
|
||||||
|
x = x + block.tdf(x)
|
||||||
|
x = block.tfc2(x)
|
||||||
|
x = x + s
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TFC_TDF_net(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
norm = get_norm(norm_type=config.model.norm)
|
||||||
|
act = get_act(act_type=config.model.act)
|
||||||
|
|
||||||
|
self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
|
||||||
|
self.num_subbands = config.model.num_subbands
|
||||||
|
|
||||||
|
dim_c = self.num_subbands * config.audio.num_channels * 2
|
||||||
|
n = config.model.num_scales
|
||||||
|
scale = config.model.scale
|
||||||
|
l = config.model.num_blocks_per_scale
|
||||||
|
c = config.model.num_channels
|
||||||
|
g = config.model.growth
|
||||||
|
bn = config.model.bottleneck_factor
|
||||||
|
f = config.audio.dim_f // self.num_subbands
|
||||||
|
|
||||||
|
self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
|
||||||
|
|
||||||
|
self.encoder_blocks = nn.ModuleList()
|
||||||
|
for i in range(n):
|
||||||
|
block = nn.Module()
|
||||||
|
block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
|
||||||
|
block.downscale = Downscale(c, c + g, scale, norm, act)
|
||||||
|
f = f // scale[1]
|
||||||
|
c += g
|
||||||
|
self.encoder_blocks.append(block)
|
||||||
|
|
||||||
|
self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
|
||||||
|
|
||||||
|
self.decoder_blocks = nn.ModuleList()
|
||||||
|
for i in range(n):
|
||||||
|
block = nn.Module()
|
||||||
|
block.upscale = Upscale(c, c - g, scale, norm, act)
|
||||||
|
f = f * scale[1]
|
||||||
|
c -= g
|
||||||
|
block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
|
||||||
|
self.decoder_blocks.append(block)
|
||||||
|
|
||||||
|
self.final_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
|
||||||
|
act,
|
||||||
|
nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f)
|
||||||
|
|
||||||
|
def cac2cws(self, x):
|
||||||
|
k = self.num_subbands
|
||||||
|
b, c, f, t = x.shape
|
||||||
|
x = x.reshape(b, c, k, f // k, t)
|
||||||
|
x = x.reshape(b, c * k, f // k, t)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def cws2cac(self, x):
|
||||||
|
k = self.num_subbands
|
||||||
|
b, c, f, t = x.shape
|
||||||
|
x = x.reshape(b, c // k, k, f, t)
|
||||||
|
x = x.reshape(b, c // k, f * k, t)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x = self.stft(x)
|
||||||
|
|
||||||
|
mix = x = self.cac2cws(x)
|
||||||
|
|
||||||
|
first_conv_out = x = self.first_conv(x)
|
||||||
|
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
|
||||||
|
encoder_outputs = []
|
||||||
|
for block in self.encoder_blocks:
|
||||||
|
x = block.tfc_tdf(x)
|
||||||
|
encoder_outputs.append(x)
|
||||||
|
x = block.downscale(x)
|
||||||
|
|
||||||
|
x = self.bottleneck_block(x)
|
||||||
|
|
||||||
|
for block in self.decoder_blocks:
|
||||||
|
x = block.upscale(x)
|
||||||
|
x = torch.cat([x, encoder_outputs.pop()], 1)
|
||||||
|
x = block.tfc_tdf(x)
|
||||||
|
|
||||||
|
x = x.transpose(-1, -2)
|
||||||
|
|
||||||
|
x = x * first_conv_out # reduce artifacts
|
||||||
|
|
||||||
|
x = self.final_conv(torch.cat([mix, x], 1))
|
||||||
|
|
||||||
|
x = self.cws2cac(x)
|
||||||
|
|
||||||
|
if self.num_target_instruments > 1:
|
||||||
|
b, c, f, t = x.shape
|
||||||
|
x = x.reshape(b, self.num_target_instruments, -1, f, t)
|
||||||
|
|
||||||
|
x = self.stft.inverse(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user