From 2fd6f15912b8913025e38ef2242dc2c3afb059db Mon Sep 17 00:00:00 2001
From: Anjok07 <68268275+Anjok07@users.noreply.github.com>
Date: Sun, 17 Nov 2024 23:30:39 -0600
Subject: [PATCH] Update separate.py

---
 separate.py | 180 +++++++++++++++++++++++++++-------------------------
 1 file changed, 94 insertions(+), 86 deletions(-)

diff --git a/separate.py b/separate.py
index 2236154..72dfabc 100644
--- a/separate.py
+++ b/separate.py
@@ -13,6 +13,7 @@ from lib_v5.vr_network import nets_new
 from lib_v5.vr_network.model_param_init import ModelParameters
 from lib_v5.mel_band_roformer import MelBandRoformer
 from lib_v5.bs_roformer import BSRoformer
+#from lib_v5.scnet.scnet import SCNet
 from pathlib import Path
 from gui_data.constants import *
 from gui_data.error_handling import *
@@ -25,11 +26,13 @@ import numpy as np
 import onnxruntime as ort
 import os
 import torch
+import torch.nn as nn
 import warnings
 import pydub
 import soundfile as sf
 import lib_v5.mdxnet as MdxnetSet
 import math
+import time
 #import random
 from onnx import load
 from onnx2pytorch import ConvertModel
@@ -175,8 +178,8 @@ class SeperateAttributes:
         #Roformer
         self.roformer_config = model_data.mdx_c_configs
         self.is_roformer = model_data.is_roformer
+        self.overwrite_protect_stamp = round(time.time())
-
 
         if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
             self.is_primary_stem_only = False
             self.is_secondary_stem_only = False
@@ -392,8 +395,12 @@ class SeperateAttributes:
 
     def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):
 
-        def save_audio_file(path, source):
+        def save_audio_file(path: str, source):
             source = spec_utils.normalize(source, self.is_normalization)
+
+            if os.path.isfile(path) and is_not_ensemble:
+                path = path.replace(".wav", f"_{self.overwrite_protect_stamp}.wav")
+
             sf.write(path, source, samplerate, subtype=self.wav_type_set)
 
         if is_not_ensemble:
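[Note on the overwrite guard above: the timestamp is taken once per
SeperateAttributes instance, so all stems written by the same separation
share one suffix. A minimal standalone sketch of the same idea, independent
of the class (names here are illustrative, not from separate.py):

    import os
    import time

    def protect_path(path: str, stamp: int) -> str:
        # Append a per-run timestamp before the extension if the
        # target file already exists, instead of overwriting it.
        if os.path.isfile(path):
            base, ext = os.path.splitext(path)
            path = f"{base}_{stamp}{ext}"
        return path

    stamp = round(time.time())  # taken once per run
    out_path = protect_path("01_output_(Vocals).wav", stamp)

Using os.path.splitext sidesteps a quirk of the str.replace(".wav", ...)
approach, which rewrites every ".wav" occurrence in the filename, not just
the final extension.]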
@@ -740,113 +747,114 @@ class SeperateMDXC(SeperateAttributes):
         if self.is_secondary_model or self.is_pre_proc_model:
             return secondary_sources
 
-    def overlap_add(self, result, x, weights, start, length):
+    def overlap_add(self, result, x, l, j, start, window):
         if self.device == 'mps':
             x = x.to(self.device)
-        result[..., start:start+length] += x[..., :length] * weights[:length]
+        result[..., start:start + l] += x[j][..., :l] * window[..., :l]
         return result
+
     def demix(self, mix):
         sr_pitched = 441000
         org_mix = mix
         if self.is_pitch_change:
             mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
 
-        if self.is_roformer:
-            overlap = self.overlap_mdx23
-            device = self.device
+        device = self.device
 
-            # Determine the model type based on the configuration and instantiate it
+        # Determine the model type based on the configuration and instantiate it
+        if self.is_roformer:
             if 'num_bands' in self.roformer_config.model:
                 model = MelBandRoformer(**self.roformer_config.model)
             elif 'freqs_per_bands' in self.roformer_config.model:
                 model = BSRoformer(**self.roformer_config.model)
+            # elif 'bandsplit_ratios' in self.roformer_config.model:
+            #     #print("**self.roformer_config.model", f"{self.roformer_config.model}")
+            #     model = SCNet(**self.roformer_config.model)
             else:
                 raise ValueError('Unknown model type in the configuration.')
-
-            # Load model checkpoint
-            checkpoint = torch.load(self.model_path, map_location='cpu')
-            model = model if not isinstance(model, torch.nn.DataParallel) else model.module
-            model.load_state_dict(checkpoint)
-            model.to(device).eval()
-            mix = torch.tensor(mix, dtype=torch.float32)
-
-            segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
-
-            S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
-            C = self.roformer_config.audio.hop_length * (segment_size - 1)
-            step = int(overlap * self.roformer_config.audio.sample_rate)
-
-            # Create a weighting table and convert it to a PyTorch tensor
-            window = torch.tensor(signal.hamming(C), dtype=torch.float32)
-
-            device = next(model.parameters()).device
-            # Transfer the window to the same device as the other tensors
-            window = window.to(device)
-
-            batch_len = int(mix.shape[1]/step)
-
-            #with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                req_shape = (len(self.roformer_config.training.instruments), ) + tuple(mix.shape)
-                result = torch.zeros(req_shape, dtype=torch.float32).to(device)
-                counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
-                for i in range(0, mix.shape[1], step):
-                    self.running_inference_progress_bar(batch_len)
-                    part = mix[:, i:i + C]
-                    length = part.shape[-1]
-                    if i + C > mix.shape[1]:
-                        part = mix[:, -C:]
-                        length = C
-                    part = part.to(device)
-                    x = model(part.unsqueeze(0))[0]
-                    if i + C > mix.shape[1]:
-                        # Fixed to add correctly at the end of the tensor
-                        result = self.overlap_add(result, x, window, result.shape[-1] - C, length)
-                        counter[..., result.shape[-1] - C:] += window[:length]
-                    else:
-                        result = self.overlap_add(result, x, window, i, length)
-                        counter[..., i:i+length] += window[:length]
-
-                estimated_sources = result / counter.clamp(min=1e-10)
         else:
             model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
-            model.load_state_dict(torch.load(self.model_path, map_location=cpu))
-            model.to(self.device).eval()
-            mix = torch.tensor(mix, dtype=torch.float32)
-            try:
-                S = model.num_target_instruments
-            except Exception as e:
-                S = model.module.num_target_instruments
+
+        # Load model checkpoint
+        checkpoint = torch.load(self.model_path, map_location='cpu')
+        model = model if not isinstance(model, torch.nn.DataParallel) else model.module
+        model.load_state_dict(checkpoint)
+        model.to(device).eval()
+        mix = torch.tensor(mix, dtype=torch.float32)
 
-            mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
-
-            batch_size = self.mdx_batch_size
-            chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
-            overlap = self.overlap_mdx23
+        segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
+        S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
+        C = self.roformer_config.audio.hop_length * (segment_size - 1)
+        N = self.overlap_mdx23
+        step = int(C // N)
+        fade_size = C // 10
+        batch_size = self.roformer_config.inference.batch_size
+        length_init = mix.shape[-1]
 
-            hop_size = chunk_size // overlap
-            mix_shape = mix.shape[1]
-            pad_size = hop_size - (mix_shape - chunk_size) % hop_size
-            mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
+        # Padding the mix to account for border effects
+        if length_init > 2 * (C - step) and (C - step > 0):
+            mix = nn.functional.pad(mix, (C - step, C - step), mode='reflect')
 
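[Note on the chunking arithmetic above: C is the model's context length in
samples, N (overlap_mdx23) is the overlap count, and step = C // N is the
hop between chunk starts, so consecutive chunks share C - step samples; the
reflect padding of C - step per side gives the borders roughly the same
overlap coverage as the interior. A rough sketch with illustrative numbers
only (real values of hop_length and dim_t come from the model config):

    hop_length = 1024
    dim_t = 256
    C = hop_length * (dim_t - 1)  # 261120 samples per chunk (~5.9 s at 44.1 kHz)
    N = 4                         # higher N = more overlap, slower inference
    step = C // N                 # 65280 samples between chunk starts
    pad = C - step                # 195840 samples of reflect padding per side
]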
-            chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
-            batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
-
-            X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
-            X = X.to(self.device)
-
-            with torch.no_grad():
-                cnt = 0
-                for batch in batches:
-                    self.running_inference_progress_bar(len(batches))
-                    x = model(batch.to(self.device))
-
-                    for w in x:
-                        X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
-                        cnt += 1
+        # Set up windows for fade-in/out
+        fadein = torch.linspace(0, 1, fade_size).to(device)
+        fadeout = torch.linspace(1, 0, fade_size).to(device)
+        window_start = torch.ones(C).to(device)
+        window_middle = torch.ones(C).to(device)
+        window_finish = torch.ones(C).to(device)
+        window_start[-fade_size:] *= fadeout  # No fade-in at start
+        window_finish[:fade_size] *= fadein  # No fade-out at end
+        window_middle[:fade_size] *= fadein
+        window_middle[-fade_size:] *= fadeout
 
-                estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
+        batch_len = int(mix.shape[1] / step)
+
+        with torch.inference_mode():
+            req_shape = (S, ) + tuple(mix.shape)
+            result = torch.zeros(req_shape, dtype=torch.float32, device=device)
+            counter = torch.zeros(req_shape, dtype=torch.float32, device=device)
+            batch_data = []
+            batch_locations = []
+
+            i = 0
+
+            while i < mix.shape[1]:
+                part = mix[:, i:i + C].to(device)
+                length = part.shape[-1]
+                if length < C:
+                    if length > C // 2 + 1:
+                        part = nn.functional.pad(part, (0, C - length), mode='reflect')
+                    else:
+                        part = nn.functional.pad(part, (0, C - length, 0, 0), mode='constant', value=0)
+
+                batch_data.append(part)
+                batch_locations.append((i, length))
+                i += step
+
+                # Process in batches
+                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
+                    arr = torch.stack(batch_data, dim=0)
+                    x = model(arr)
+
+                    for j in range(len(batch_locations)):
+                        self.running_inference_progress_bar(batch_len)
+                        start, l = batch_locations[j]
+                        window = window_middle
+                        if start == 0:
+                            window = window_start
+                        elif i >= mix.shape[1]:
+                            window = window_finish
+
+                        result = self.overlap_add(result, x, l, j, start, window)
+                        counter[..., start:start + l] += window[..., :l]
+
+                    batch_data = []
+                    batch_locations = []
+
+            # Normalize by the overlap counter and remove padding
+            estimated_sources = result / counter.clamp(min=1e-10)
+
+            if length_init > 2 * (C - step) and (C - step > 0):
+                estimated_sources = estimated_sources[..., (C - step):-(C - step)]
 
         pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
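[For reference, a self-contained toy of the fade-windowed overlap-add
scheme that the new demix() implements, with an identity stand-in for the
separation model; all names below are illustrative, not from separate.py.
It assumes the signal is much longer than the chunk size and omits the
tail padding and batching that the real code performs:

    import torch

    def overlap_add_identity(x: torch.Tensor, C: int = 200, N: int = 4) -> torch.Tensor:
        """Windowed overlap-add over a 1-D signal with an identity 'model'."""
        step = C // N   # hop between chunk starts
        fade = C // 10  # cross-fade length at chunk edges

        # Three windows: the first chunk keeps full gain at its start, the
        # last keeps full gain at its end, interior chunks fade both sides.
        fadein = torch.linspace(0, 1, fade)
        fadeout = torch.linspace(1, 0, fade)
        win_start = torch.ones(C)
        win_start[-fade:] *= fadeout
        win_mid = win_start.clone()
        win_mid[:fade] *= fadein
        win_end = torch.ones(C)
        win_end[:fade] *= fadein

        result = torch.zeros_like(x)
        counter = torch.zeros_like(x)
        n = x.shape[-1]
        i = 0
        while i < n:
            part = x[i:i + C]  # a real separator would transform this chunk
            l = part.shape[-1]
            if i == 0:
                win = win_start
            elif i + step >= n:
                win = win_end
            else:
                win = win_mid
            result[i:i + l] += part * win[:l]
            counter[i:i + l] += win[:l]
            i += step
        return result / counter.clamp(min=1e-10)

    x = torch.randn(44100)
    assert torch.allclose(overlap_add_identity(x), x, atol=1e-5)

Dividing the accumulated result by the accumulated window counter keeps the
reconstruction unity-gain wherever chunks overlap, which is what makes the
chunk seams click-free regardless of the overlap count N.]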