Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git (synced 2024-11-30 18:24:28 +01:00)

Update separate.py

This commit is contained in:
parent ac0435bcfc
commit 2fd6f15912

separate.py (180 changed lines)
@@ -13,6 +13,7 @@ from lib_v5.vr_network import nets_new
from lib_v5.vr_network.model_param_init import ModelParameters
from lib_v5.mel_band_roformer import MelBandRoformer
from lib_v5.bs_roformer import BSRoformer
#from lib_v5.scnet.scnet import SCNet
from pathlib import Path
from gui_data.constants import *
from gui_data.error_handling import *
@@ -25,11 +26,13 @@ import numpy as np
import onnxruntime as ort
import os
import torch
import torch.nn as nn
import warnings
import pydub
import soundfile as sf
import lib_v5.mdxnet as MdxnetSet
import math
import time
#import random
from onnx import load
from onnx2pytorch import ConvertModel
@@ -175,8 +178,8 @@ class SeperateAttributes:
        #Roformer
        self.roformer_config = model_data.mdx_c_configs
        self.is_roformer = model_data.is_roformer
        self.overwrite_protect_stamp = round(time.time())

        if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
            self.is_primary_stem_only = False
            self.is_secondary_stem_only = False
@@ -392,8 +395,12 @@ class SeperateAttributes:

    def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):

        def save_audio_file(path, source):
        def save_audio_file(path: str, source):
            source = spec_utils.normalize(source, self.is_normalization)

            if os.path.isfile(path) and is_not_ensemble:
                path = path.replace(".wav", f"_{self.overwrite_protect_stamp}.wav")

            sf.write(path, source, samplerate, subtype=self.wav_type_set)

            if is_not_ensemble:
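Aside (not part of the commit): the overwrite protection added above amounts to appending a per-run timestamp before the .wav extension whenever the target file already exists and the run is not part of an ensemble. A minimal sketch of that idea, with a hypothetical helper name:

import os
import time

def protect_from_overwrite(path: str, stamp: int) -> str:
    # `stamp` is captured once per run (e.g. round(time.time())),
    # so every stem written in the same run shares the same suffix.
    if os.path.isfile(path):
        # "song_(Vocals).wav" -> "song_(Vocals)_1717171717.wav"
        path = path.replace(".wav", f"_{stamp}.wav")
    return path

stamp = round(time.time())
out_path = protect_from_overwrite("output/song_(Vocals).wav", stamp)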
@@ -740,113 +747,114 @@ class SeperateMDXC(SeperateAttributes):
        if self.is_secondary_model or self.is_pre_proc_model:
            return secondary_sources

    def overlap_add(self, result, x, weights, start, length):
    def overlap_add(self, result, x, l, j, start, window):
        if self.device == 'mps':
            x = x.to(self.device)
        result[..., start:start+length] += x[..., :length] * weights[:length]
        result[..., start:start + l] += x[j][..., :l] * window[..., :l]
        return result

    def demix(self, mix):
        sr_pitched = 441000
        org_mix = mix
        if self.is_pitch_change:
            mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)

        if self.is_roformer:
            overlap = self.overlap_mdx23
            device = self.device
            device = self.device

            # Determine the model type based on the configuration and instantiate it
            # Determine the model type based on the configuration and instantiate it
            if self.is_roformer:
                if 'num_bands' in self.roformer_config.model:
                    model = MelBandRoformer(**self.roformer_config.model)
                elif 'freqs_per_bands' in self.roformer_config.model:
                    model = BSRoformer(**self.roformer_config.model)
                # elif 'bandsplit_ratios' in self.roformer_config.model:
                #     #print("**self.roformer_config.model", f"{self.roformer_config.model}")
                #     model = SCNet(**self.roformer_config.model)
                else:
                    raise ValueError('Unknown model type in the configuration.')

            # Load model checkpoint
            checkpoint = torch.load(self.model_path, map_location='cpu')
            model = model if not isinstance(model, torch.nn.DataParallel) else model.module
            model.load_state_dict(checkpoint)
            model.to(device).eval()
            mix = torch.tensor(mix, dtype=torch.float32)

            segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size

            S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
            C = self.roformer_config.audio.hop_length * (segment_size - 1)
            step = int(overlap * self.roformer_config.audio.sample_rate)

            # Create a weighting table and convert it to a PyTorch tensor
            window = torch.tensor(signal.hamming(C), dtype=torch.float32)

            device = next(model.parameters()).device
            # Move the weighting window to the same device as the other tensors
            window = window.to(device)

            batch_len = int(mix.shape[1]/step)

            #with torch.cuda.amp.autocast():
            with torch.no_grad():
                req_shape = (len(self.roformer_config.training.instruments), ) + tuple(mix.shape)
                result = torch.zeros(req_shape, dtype=torch.float32).to(device)
                counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
                for i in range(0, mix.shape[1], step):
                    self.running_inference_progress_bar(batch_len)
                    part = mix[:, i:i + C]
                    length = part.shape[-1]
                    if i + C > mix.shape[1]:
                        part = mix[:, -C:]
                        length = C
                    part = part.to(device)
                    x = model(part.unsqueeze(0))[0]
                    if i + C > mix.shape[1]:
                        # Fixed to add the last chunk correctly at the end of the tensor
                        result = self.overlap_add(result, x, window, result.shape[-1] - C, length)
                        counter[..., result.shape[-1] - C:] += window[:length]
                    else:
                        result = self.overlap_add(result, x, window, i, length)
                        counter[..., i:i+length] += window[:length]

            estimated_sources = result / counter.clamp(min=1e-10)
        else:
            model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
            model.load_state_dict(torch.load(self.model_path, map_location=cpu))
            model.to(self.device).eval()
            mix = torch.tensor(mix, dtype=torch.float32)

            try:
                S = model.num_target_instruments
            except Exception as e:
                S = model.module.num_target_instruments
            # Load model checkpoint
            checkpoint = torch.load(self.model_path, map_location='cpu')
            model = model if not isinstance(model, torch.nn.DataParallel) else model.module
            model.load_state_dict(checkpoint)
            model.to(device).eval()
            mix = torch.tensor(mix, dtype=torch.float32)

            mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size

            batch_size = self.mdx_batch_size
            chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
            overlap = self.overlap_mdx23
            segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
            S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
            C = self.roformer_config.audio.hop_length * (segment_size - 1)
            N = self.overlap_mdx23
            step = int(C // N)
            fade_size = C // 10
            batch_size = self.roformer_config.inference.batch_size
            length_init = mix.shape[-1]

            hop_size = chunk_size // overlap
            mix_shape = mix.shape[1]
            pad_size = hop_size - (mix_shape - chunk_size) % hop_size
            mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
            # Padding the mix to account for border effects
            if length_init > 2 * (C - step) and (C - step > 0):
                mix = nn.functional.pad(mix, (C - step, C - step), mode='reflect')

            chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
            batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]

            X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
            X = X.to(self.device)

            with torch.no_grad():
                cnt = 0
                for batch in batches:
                    self.running_inference_progress_bar(len(batches))
                    x = model(batch.to(self.device))

                    for w in x:
                        X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
                        cnt += 1
            # Set up windows for fade-in/out
            fadein = torch.linspace(0, 1, fade_size).to(device)
            fadeout = torch.linspace(1, 0, fade_size).to(device)
            window_start = torch.ones(C).to(device)
            window_middle = torch.ones(C).to(device)
            window_finish = torch.ones(C).to(device)
            window_start[-fade_size:] *= fadeout # No fade-in at start
            window_finish[:fade_size] *= fadein # No fade-out at end
            window_middle[:fade_size] *= fadein
            window_middle[-fade_size:] *= fadeout

            estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
            batch_len = int(mix.shape[1] / step)

            with torch.inference_mode():
                req_shape = (S, ) + tuple(mix.shape)
                result = torch.zeros(req_shape, dtype=torch.float32, device=device)
                counter = torch.zeros(req_shape, dtype=torch.float32, device=device)
                batch_data = []
                batch_locations = []

                i = 0

                while i < mix.shape[1]:
                    part = mix[:, i:i + C].to(device)
                    length = part.shape[-1]
                    if length < C:
                        if length > C // 2 + 1:
                            part = nn.functional.pad(part, (0, C - length), mode='reflect')
                        else:
                            part = nn.functional.pad(part, (0, C - length, 0, 0), mode='constant', value=0)

                    batch_data.append(part)
                    batch_locations.append((i, length))
                    i += step

                    # Process in batches
                    if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                        arr = torch.stack(batch_data, dim=0)
                        x = model(arr)

                        for j in range(len(batch_locations)):
                            self.running_inference_progress_bar(batch_len)
                            start, l = batch_locations[j]
                            window = window_middle
                            if start == 0:
                                window = window_start
                            elif i >= mix.shape[1]:
                                window = window_finish

                            result = self.overlap_add(result, x, l, j, start, window)
                            counter[..., start:start + l] += window[..., :l]

                        batch_data = []
                        batch_locations = []

            # Normalize by the overlap counter and remove padding
            estimated_sources = result / counter.clamp(min=1e-10)

            if length_init > 2 * (C - step) and (C - step > 0):
                estimated_sources = estimated_sources[..., (C - step):-(C - step)]

        pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
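Aside (not part of the commit): both branches of the updated demix rely on the same windowed overlap-add pattern: slice the mix into chunks of C samples every `step` samples, run the model on each chunk, weight the output with a window, accumulate into a result buffer and a matching counter buffer, then divide. A minimal self-contained sketch of that pattern, with the separation model replaced by a stand-in callable and arbitrary chunk sizes:

import torch

def overlap_add_demix(mix: torch.Tensor, model, C: int, step: int) -> torch.Tensor:
    # mix: (channels, samples). Blend model outputs from overlapping chunks
    # with a Hann window so seams between chunks average smoothly.
    window = torch.hann_window(C)
    result = torch.zeros_like(mix)
    counter = torch.zeros_like(mix)
    for i in range(0, mix.shape[-1], step):
        part = mix[..., i:i + C]
        length = part.shape[-1]
        if length < C:  # pad the final short chunk
            part = torch.nn.functional.pad(part, (0, C - length))
        out = model(part)  # stand-in for the Roformer / TFC_TDF model
        result[..., i:i + length] += out[..., :length] * window[:length]
        counter[..., i:i + length] += window[:length]
    return result / counter.clamp(min=1e-10)  # normalize the overlap weighting

# Usage with an identity "model" on a stereo signal:
mix = torch.randn(2, 44100)
separated = overlap_add_demix(mix, lambda x: x, C=4096, step=1024)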