diff --git a/UVR.py b/UVR.py
index b792e8e..1be154a 100644
--- a/UVR.py
+++ b/UVR.py
@@ -599,7 +599,7 @@ class MainWindow(TkinterDnD.Tk):
 
         self.options_mdxnetModel_Optionmenu = ttk.OptionMenu(self.options_Frame,
                                                   self.mdxnetModel_var,
-                                                  None, 'UVR-MDX-NET 1', 'UVR-MDX-NET 2', 'UVR-MDX-NET 3', 'UVR-MDX-NET Karaoke')
+                                                  None, 'UVR-MDX-NET 1', 'UVR-MDX-NET 2', 'UVR-MDX-NET 3', 'UVR-MDX-NET Karaoke')#, 'UVR-MDX-NET Full-B')
         # Ensemble Mode
         self.options_ensChoose_Label = tk.Label(master=self.options_Frame, text='Choose Ensemble', anchor=tk.CENTER,
diff --git a/inference_MDX.py b/inference_MDX.py
index ff3b327..ba3340f 100644
--- a/inference_MDX.py
+++ b/inference_MDX.py
@@ -65,7 +65,7 @@ class Predictor():
         self.onnx_models = {}
         c = 0
 
-        self.models = get_models('tdf_extra', load=False, device=cpu, stems='vocals')
+        self.models = get_models('tdf_extra', load=False, device=cpu, stems=modeltype)
         widget_text.write(base_text + 'Loading ONNX model... ')
         update_progress(**progress_kwargs,
                         step=0.1)
@@ -701,6 +701,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
     global default_noisereduc_s
     global _basename
     global _mixture
+    global modeltype
     global progress_kwargs
     global base_text
     global model_set
@@ -743,15 +744,24 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
     if data['mdxnetModel'] == 'UVR-MDX-NET 1':
         model_set = 'UVR_MDXNET_9703.onnx'
         model_set_name = 'UVR_MDXNET_9703'
+        modeltype = 'vocals-one'
     if data['mdxnetModel'] == 'UVR-MDX-NET 2':
         model_set = 'UVR_MDXNET_9682.onnx'
         model_set_name = 'UVR_MDXNET_9682'
+        modeltype = 'vocals-one'
     if data['mdxnetModel'] == 'UVR-MDX-NET 3':
         model_set = 'UVR_MDXNET_9662.onnx'
         model_set_name = 'UVR_MDXNET_9662'
+        modeltype = 'vocals-one'
    if data['mdxnetModel'] == 'UVR-MDX-NET Karaoke':
         model_set = 'UVR_MDXNET_KARA.onnx'
         model_set_name = 'UVR_MDXNET_Karaoke'
+        modeltype = 'vocals-one'
+    if data['mdxnetModel'] == 'UVR-MDX-NET Full-B':
+        model_set = 'UVR_MDXNET_FULL.onnx'
+        model_set_name = 'UVR_MDXNET_FULL'
+        modeltype = 'vocals-two'
+
 
     stime = time.perf_counter()
     progress_var.set(0)
diff --git a/inference_v5_ensemble.py b/inference_v5_ensemble.py
index af2c42c..477d474 100644
--- a/inference_v5_ensemble.py
+++ b/inference_v5_ensemble.py
@@ -71,7 +71,7 @@ class Predictor():
         self.onnx_models = {}
         c = 0
 
-        self.models = get_models('tdf_extra', load=False, device=cpu, stems='vocals')
+        self.models = get_models('tdf_extra', load=False, device=cpu, stems=modeltype)
         widget_text.write(base_text + 'Loading ONNX model... ')
         update_progress(**progress_kwargs,
                         step=0.1)
@@ -90,6 +90,7 @@ class Predictor():
 
         print(run_type)
         print(str(device))
+        print('model_set: ', model_set)
         self.onnx_models[c] = ort.InferenceSession(os.path.join('models/MDX_Net_Models', model_set), providers=run_type)
         widget_text.write('Done!\n')
 
@@ -609,12 +610,10 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
     global base_name
     global progress_kwargs
     global base_text
+    global modeltype
     global model_set
     global model_set_name
     global ModelName_2
-
-    model_set = 'UVR_MDXNET_9703.onnx'
-    model_set_name = 'UVR_MDXNET_9703'
 
     # Update default settings
     default_chunks = data['chunks']
@@ -1191,7 +1190,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
         vr_ensem_mdx_c_name = 'pass'
     else:
         vr_ensem_mdx_c_name = data['vr_ensem_mdx_c']
-        vr_ensem_mdx_c = f'models/Main_Models/{vr_ensem_mdx_c_name}.pth'
+        vr_ensem_mdx_c = f'models/Main_Models/{vr_ensem_mdx_c_name}.pth'
 
     #MDX-Net Model
@@ -1202,7 +1201,8 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
     if data['mdx_ensem'] == 'UVR-MDX-NET 3':
         mdx_ensem = 'UVR_MDXNET_9662'
     if data['mdx_ensem'] == 'UVR-MDX-NET Karaoke':
-        mdx_ensem = 'UVR_MDXNET_Karaoke'
+        mdx_ensem = 'UVR_MDXNET_KARA'
+
 
     #MDX-Net Model 2
@@ -1925,6 +1925,23 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
     else:
         text_widget.write('Ensemble Mode - Running Model - ' + mdx_name + '\n\n')
 
+    if mdx_name == 'UVR_MDXNET_9703':
+        mdx_ensem_b = 'UVR_MDXNET_9703'
+        model_set = 'UVR_MDXNET_9703.onnx'
+        model_set_name = 'UVR_MDXNET_9703'
+        modeltype = 'vocals-one'
+    if mdx_name == 'UVR_MDXNET_9682':
+        model_set = 'UVR_MDXNET_9682.onnx'
+        model_set_name = 'UVR_MDXNET_9682'
+        modeltype = 'vocals-one'
+    if mdx_name == 'UVR_MDXNET_9662':
+        model_set = 'UVR_MDXNET_9662.onnx'
+        model_set_name = 'UVR_MDXNET_9662'
+        modeltype = 'vocals-one'
+    if mdx_name == 'UVR_MDXNET_Karaoke':
+        model_set = 'UVR_MDXNET_KARA.onnx'
+        model_set_name = 'UVR_MDXNET_Karaoke'
+        modeltype = 'vocals-one'
 
     update_progress(**progress_kwargs,
                     step=0)
@@ -2116,7 +2133,7 @@ def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress
             text_widget.write(base_text + 'FFmpeg might be missing or corrupted, please check error log.\n')
             text_widget.write(base_text + 'Moving on... ')
         else:
-            text_widget.write(base_text + 'Failed to save output(s) as Flac(s).\n')
+            text_widget.write('\n' + base_text + 'Failed to save output(s) as Flac(s).\n')
             text_widget.write(base_text + 'Please check error log.\n')
             text_widget.write(base_text + 'Moving on... ')
     try:
diff --git a/models.py b/models.py
index 15e5fb3..0e7a704 100644
--- a/models.py
+++ b/models.py
@@ -5,139 +5,27 @@ import numpy as np
 import librosa
 
 dim_c = 4
-k = 3
 model_path = 'model'
-n_fft_scale = {'bass': 8, 'drums':2, 'other':4, 'vocals':3, '*':2}
-
-
-class Conv_TDF(nn.Module):
-    def __init__(self, c, l, f, k, bn, bias=True):
-
-        super(Conv_TDF, self).__init__()
-
-        self.use_tdf = bn is not None
-
-        self.H = nn.ModuleList()
-        for i in range(l):
-            self.H.append(
-                nn.Sequential(
-                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k//2),
-                    nn.BatchNorm2d(c),
-                    nn.ReLU(),
-                )
-            )
-
-        if self.use_tdf:
-            if bn==0:
-                self.tdf = nn.Sequential(
-                    nn.Linear(f,f, bias=bias),
-                    nn.BatchNorm2d(c),
-                    nn.ReLU()
-                )
-            else:
-                self.tdf = nn.Sequential(
-                    nn.Linear(f,f//bn, bias=bias),
-                    nn.BatchNorm2d(c),
-                    nn.ReLU(),
-                    nn.Linear(f//bn,f, bias=bias),
-                    nn.BatchNorm2d(c),
-                    nn.ReLU()
-                )
-
-    def forward(self, x):
-        for h in self.H:
-            x = h(x)
-
-        return x + self.tdf(x) if self.use_tdf else x
+n_fft_scale = {'vocals-one':6144, 'vocals-two':7680, '*':2}
 
 
 class Conv_TDF_net_trim(nn.Module):
-    def __init__(self, device, load, model_name, target_name, lr, epoch,
-                 L, l, g, dim_f, dim_t, k=3, hop=1024, bn=None, bias=True):
+    def __init__(self, device, load, model_name, target_name,
+                 L, dim_f, dim_t, hop=1024):
 
         super(Conv_TDF_net_trim, self).__init__()
 
-        self.dim_f, self.dim_t = 2**dim_f, 2**dim_t
-        self.n_fft = self.dim_f * n_fft_scale[target_name]
+        self.dim_f, self.dim_t = dim_f, 2**dim_t
+        self.n_fft = n_fft_scale[target_name]
         self.hop = hop
         self.n_bins = self.n_fft//2+1
         self.chunk_size = hop * (self.dim_t-1)
         self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
         self.target_name = target_name
-        self.blender = 'blender' in model_name
-
+        print(target_name)
         out_c = dim_c*4 if target_name=='*' else dim_c
-        in_c = dim_c*2 if self.blender else dim_c
-        #out_c = dim_c*2 if self.blender else dim_c
         self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(device)
-
-        self.n = L//2
-        if load:
-
-            self.first_conv = nn.Sequential(
-                nn.Conv2d(in_channels=in_c, out_channels=g, kernel_size=1, stride=1),
-                nn.BatchNorm2d(g),
-                nn.ReLU(),
-            )
-
-            f = self.dim_f
-            c = g
-            self.ds_dense = nn.ModuleList()
-            self.ds = nn.ModuleList()
-            for i in range(self.n):
-                self.ds_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))
-
-                scale = (2,2)
-                self.ds.append(
-                    nn.Sequential(
-                        nn.Conv2d(in_channels=c, out_channels=c+g, kernel_size=scale, stride=scale),
-                        nn.BatchNorm2d(c+g),
-                        nn.ReLU()
-                    )
-                )
-                f = f//2
-                c += g
-
-            self.mid_dense = Conv_TDF(c, l, f, k, bn, bias=bias)
-            #if bn is None and mid_tdf:
-            #    self.mid_dense = Conv_TDF(c, l, f, k, bn=0, bias=False)
-
-            self.us_dense = nn.ModuleList()
-            self.us = nn.ModuleList()
-            for i in range(self.n):
-                scale = (2,2)
-                self.us.append(
-                    nn.Sequential(
-                        nn.ConvTranspose2d(in_channels=c, out_channels=c-g, kernel_size=scale, stride=scale),
-                        nn.BatchNorm2d(c-g),
-                        nn.ReLU()
-                    )
-                )
-                f = f*2
-                c -= g
-
-                self.us_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))
-
-
-            self.final_conv = nn.Sequential(
-                nn.Conv2d(in_channels=c, out_channels=out_c, kernel_size=1, stride=1),
-            )
-
-
-            model_cfg = f'L{L}l{l}g{g}'
-            model_cfg += ', ' if (bn is None or bn==0) else f'bn{bn}, '
-
-            stft_cfg = f'f{dim_f}t{dim_t}, '
-
-            model_name = model_name[:model_name.index('(')+1] + model_cfg + stft_cfg + model_name[model_name.index('(')+1:]
-
-            try:
-                self.load_state_dict(
-                    torch.load('{0}/{1}/{2}_lr{3}_e{4:05}.ckpt'.format(model_path, model_name, target_name, lr, epoch), map_location=device)
-                )
-                print(f'Loading model ({target_name})')
-            except FileNotFoundError:
-                print(f'Random init ({target_name})')
-
+        print(self.n_bins)
 
     def stft(self, x):
         x = x.reshape([-1, self.chunk_size])
@@ -154,32 +42,6 @@ class Conv_TDF_net_trim(nn.Module):
         x = x.permute([0,2,3,1])
         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
         return x.reshape([-1,c,self.chunk_size])
-
-
-    def forward(self, x):
-
-        x = self.first_conv(x)
-
-        x = x.transpose(-1,-2)
-
-        ds_outputs = []
-        for i in range(self.n):
-            x = self.ds_dense[i](x)
-            ds_outputs.append(x)
-            x = self.ds[i](x)
-
-        x = self.mid_dense(x)
-
-        for i in range(self.n):
-            x = self.us[i](x)
-            x *= ds_outputs[-i-1]
-            x = self.us_dense[i](x)
-
-        x = x.transpose(-1,-2)
-
-        x = self.final_conv(x)
-
-        return x
 
 def stft(wave, nfft, hl):
     wave_left = np.asfortranarray(wave[0])
@@ -193,7 +55,6 @@ def stft(wave, nfft, hl):
 def istft(spec, hl):
     spec_left = np.asfortranarray(spec[0])
     spec_right = np.asfortranarray(spec[1])
-
     wave_left = librosa.istft(spec_left, hop_length=hl)
     wave_right = librosa.istft(spec_right, hop_length=hl)
     wave = np.asfortranarray([wave_left, wave_right])
@@ -220,18 +81,24 @@ def spec_effects(wave, algorithm='default', value=None):
     return wave
 
 
-def get_models(name, device, load=True, stems='vocals'):
+def get_models(name, device, load=True, stems='vocals-onevocals-two'):
     if name=='tdf_extra':
         models = []
-        if 'vocals' in stems:
+        if 'vocals-one' in stems:
             models.append(
                 Conv_TDF_net_trim(
                     device=device, load=load,
-                    model_name='Conv-TDF', target_name='vocals',
-                    lr=0.0001, epoch=0,
-                    L=11, l=3, g=32, bn=8, bias=False,
-                    dim_f=11, dim_t=8
+                    model_name='Conv-TDF', target_name='vocals-one',
+                    L=11, dim_f=2048, dim_t=8
+                )
+            )
+        if 'vocals-two' in stems:
+            models.append(
+                Conv_TDF_net_trim(
+                    device=device, load=load,
+                    model_name='Conv-TDF', target_name='vocals-two',
+                    L=11, dim_f=3072, dim_t=8
                 )
             )
         return models
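
Quick reference for the dispatch above: each GUI choice now binds an ONNX filename, a display name, and the stem key that get_models() consumes. The same mapping as a single lookup table, a sketch only (MDX_MODELS is a hypothetical name; the values come from the patch itself):

# Sketch: the per-model if-chain from inference_MDX.py condensed into a dict.
MDX_MODELS = {
    'UVR-MDX-NET 1':       ('UVR_MDXNET_9703.onnx', 'UVR_MDXNET_9703',    'vocals-one'),
    'UVR-MDX-NET 2':       ('UVR_MDXNET_9682.onnx', 'UVR_MDXNET_9682',    'vocals-one'),
    'UVR-MDX-NET 3':       ('UVR_MDXNET_9662.onnx', 'UVR_MDXNET_9662',    'vocals-one'),
    'UVR-MDX-NET Karaoke': ('UVR_MDXNET_KARA.onnx', 'UVR_MDXNET_Karaoke', 'vocals-one'),
    'UVR-MDX-NET Full-B':  ('UVR_MDXNET_FULL.onnx', 'UVR_MDXNET_FULL',    'vocals-two'),
}

model_set, model_set_name, modeltype = MDX_MODELS['UVR-MDX-NET Full-B']
print(model_set, model_set_name, modeltype)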
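
Side note on the reworked get_models(): selection is by substring membership, so the new default stems='vocals-onevocals-two' deliberately matches both keys, while the inference scripts now pass exactly one key through the new modeltype global. A minimal sketch of that dispatch, with pick_models as a hypothetical stand-in for get_models():

def pick_models(stems='vocals-onevocals-two'):
    # Shows only the substring-membership dispatch the patch relies on.
    picked = []
    if 'vocals-one' in stems:   # True for 'vocals-one' and for the combined default
        picked.append('vocals-one')
    if 'vocals-two' in stems:   # True for 'vocals-two' and for the combined default
        picked.append('vocals-two')
    return picked

print(pick_models())               # ['vocals-one', 'vocals-two']
print(pick_models('vocals-two'))   # ['vocals-two'], what Full-B requests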
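
Side note on the STFT geometry in models.py: n_fft_scale used to hold per-stem multipliers applied to 2**dim_f; it now stores the FFT size directly, and dim_f is passed as a literal bin count while dim_t stays an exponent. A minimal sketch mirroring the arithmetic in Conv_TDF_net_trim.__init__, to sanity-check both stem keys:

# Sketch: derived quantities for the two stem keys added by this patch.
n_fft_scale = {'vocals-one': 6144, 'vocals-two': 7680}
hop = 1024

for target_name, dim_f in (('vocals-one', 2048), ('vocals-two', 3072)):
    n_fft = n_fft_scale[target_name]  # direct lookup, no 2**dim_f multiplier
    dim_t = 2 ** 8                    # dim_t=8 in get_models() is still an exponent
    n_bins = n_fft // 2 + 1           # one-sided spectrogram height
    chunk_size = hop * (dim_t - 1)    # samples consumed per chunk
    pad_rows = n_bins - dim_f         # rows the freq_pad buffer restores
    print(target_name, n_fft, n_bins, chunk_size, pad_rows)

# vocals-one 6144 3073 261120 1025
# vocals-two 7680 3841 261120 769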