Update separate.py

Anjok07 2024-11-17 23:30:39 -06:00 committed by GitHub
parent ac0435bcfc
commit 2fd6f15912


@@ -13,6 +13,7 @@ from lib_v5.vr_network import nets_new
from lib_v5.vr_network.model_param_init import ModelParameters
from lib_v5.mel_band_roformer import MelBandRoformer
from lib_v5.bs_roformer import BSRoformer
#from lib_v5.scnet.scnet import SCNet
from pathlib import Path
from gui_data.constants import *
from gui_data.error_handling import *
@@ -25,11 +26,13 @@ import numpy as np
import onnxruntime as ort
import os
import torch
import torch.nn as nn
import warnings
import pydub
import soundfile as sf
import lib_v5.mdxnet as MdxnetSet
import math
import time
#import random
from onnx import load
from onnx2pytorch import ConvertModel
@@ -175,8 +178,8 @@ class SeperateAttributes:
#Roformer
self.roformer_config = model_data.mdx_c_configs
self.is_roformer = model_data.is_roformer
self.overwrite_protect_stamp = round(time.time())
if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
self.is_primary_stem_only = False
self.is_secondary_stem_only = False
@@ -392,8 +395,12 @@ class SeperateAttributes:
def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):
- def save_audio_file(path, source):
+ def save_audio_file(path: str, source):
source = spec_utils.normalize(source, self.is_normalization)
if os.path.isfile(path) and is_not_ensemble:
path = path.replace(".wav", f"_{self.overwrite_protect_stamp}.wav")
sf.write(path, source, samplerate, subtype=self.wav_type_set)
if is_not_ensemble:
@@ -740,113 +747,114 @@ class SeperateMDXC(SeperateAttributes):
if self.is_secondary_model or self.is_pre_proc_model:
return secondary_sources
- def overlap_add(self, result, x, weights, start, length):
+ def overlap_add(self, result, x, l, j, start, window):
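# x holds the batched model output; j selects the batch item, l is the valid sample count, and window is the fade envelope applied before accumulation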
if self.device == 'mps':
x = x.to(self.device)
- result[..., start:start+length] += x[..., :length] * weights[:length]
+ result[..., start:start + l] += x[j][..., :l] * window[..., :l]
return result
def demix(self, mix):
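# demix splits the mix into overlapping chunks, runs the separation model on each, and overlap-adds the windowed outputs back into full-length stems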
sr_pitched = 44100
org_mix = mix
if self.is_pitch_change:
mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
if self.is_roformer:
overlap = self.overlap_mdx23
device = self.device
# Determine the model type based on the configuration and instantiate it
if self.is_roformer:
if 'num_bands' in self.roformer_config.model:
model = MelBandRoformer(**self.roformer_config.model)
elif 'freqs_per_bands' in self.roformer_config.model:
model = BSRoformer(**self.roformer_config.model)
# elif 'bandsplit_ratios' in self.roformer_config.model:
# #print("**self.roformer_config.model", f"{self.roformer_config.model}")
# model = SCNet(**self.roformer_config.model)
else:
raise ValueError('Unknown model type in the configuration.')
# Load model checkpoint
checkpoint = torch.load(self.model_path, map_location='cpu')
model = model if not isinstance(model, torch.nn.DataParallel) else model.module
model.load_state_dict(checkpoint)
model.to(device).eval()
mix = torch.tensor(mix, dtype=torch.float32)
segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
C = self.roformer_config.audio.hop_length * (segment_size - 1)
step = int(overlap * self.roformer_config.audio.sample_rate)
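# S: number of output stems, C: chunk length in samples, step: hop between chunk starts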
# Create a Hamming weighting window as a PyTorch tensor
window = torch.tensor(signal.hamming(C), dtype=torch.float32)
device = next(model.parameters()).device
# Move the weighting window to the same device as the other tensors
window = window.to(device)
batch_len = int(mix.shape[1]/step)
#with torch.cuda.amp.autocast():
with torch.no_grad():
req_shape = (len(self.roformer_config.training.instruments), ) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32).to(device)
counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
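# result accumulates the windowed model outputs; counter accumulates the window weights so the overlap can be normalized out at the end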
for i in range(0, mix.shape[1], step):
self.running_inference_progress_bar(batch_len)
part = mix[:, i:i + C]
length = part.shape[-1]
if i + C > mix.shape[1]:
part = mix[:, -C:]
length = C
part = part.to(device)
x = model(part.unsqueeze(0))[0]
if i + C > mix.shape[1]:
# Fixed to add the final chunk correctly at the end of the tensor
result = self.overlap_add(result, x, window, result.shape[-1] - C, length)
counter[..., result.shape[-1] - C:] += window[:length]
else:
result = self.overlap_add(result, x, window, i, length)
counter[..., i:i+length] += window[:length]
estimated_sources = result / counter.clamp(min=1e-10)
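# Dividing by the accumulated window weight undoes the overlap; clamp guards against division by zero at uncovered edges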
else:
model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
model.load_state_dict(torch.load(self.model_path, map_location=cpu))
model.to(self.device).eval()
mix = torch.tensor(mix, dtype=torch.float32)
try:
S = model.num_target_instruments
except Exception as e:
S = model.module.num_target_instruments
# Load model checkpoint
checkpoint = torch.load(self.model_path, map_location='cpu')
model = model if not isinstance(model, torch.nn.DataParallel) else model.module
model.load_state_dict(checkpoint)
model.to(device).eval()
mix = torch.tensor(mix, dtype=torch.float32)
mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
batch_size = self.mdx_batch_size
chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
overlap = self.overlap_mdx23
segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
S = 1 if self.roformer_config.training.target_instrument else len(self.roformer_config.training.instruments)
C = self.roformer_config.audio.hop_length * (segment_size - 1)
N = self.overlap_mdx23
step = int(C // N)
fade_size = C // 10
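# The crossfade region spans 10% of the chunk length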
batch_size = self.roformer_config.inference.batch_size
length_init = mix.shape[-1]
hop_size = chunk_size // overlap
mix_shape = mix.shape[1]
pad_size = hop_size - (mix_shape - chunk_size) % hop_size
mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
# Padding the mix to account for border effects
if length_init > 2 * (C - step) and (C - step > 0):
mix = nn.functional.pad(mix, (C - step, C - step), mode='reflect')
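# Reflect padding mirrors real audio at the borders so edge chunks are not biased toward silence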
chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
X = X.to(self.device)
with torch.no_grad():
cnt = 0
for batch in batches:
self.running_inference_progress_bar(len(batches))
x = model(batch.to(self.device))
for w in x:
X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
cnt += 1
# Set up windows for fade-in/out
fadein = torch.linspace(0, 1, fade_size).to(device)
fadeout = torch.linspace(1, 0, fade_size).to(device)
window_start = torch.ones(C).to(device)
window_middle = torch.ones(C).to(device)
window_finish = torch.ones(C).to(device)
window_start[-fade_size:] *= fadeout # No fade-in at start
window_finish[:fade_size] *= fadein # No fade-out at end
window_middle[:fade_size] *= fadein
window_middle[-fade_size:] *= fadeout
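# The first chunk keeps full weight at the start and the last keeps full weight at the end, so the track edges are never attenuated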
estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
batch_len = int(mix.shape[1] / step)
with torch.inference_mode():
req_shape = (S, ) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32, device=device)
counter = torch.zeros(req_shape, dtype=torch.float32, device=device)
batch_data = []
batch_locations = []
i = 0
while i < mix.shape[1]:
part = mix[:, i:i + C].to(device)
length = part.shape[-1]
if length < C:
if length > C // 2 + 1:
part = nn.functional.pad(part, (0, C - length), mode='reflect')
else:
part = nn.functional.pad(part, (0, C - length, 0, 0), mode='constant', value=0)
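# Reflect padding requires the pad to be shorter than the signal, so short tail chunks fall back to zero padding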
batch_data.append(part)
batch_locations.append((i, length))
i += step
# Process in batches
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0)
x = model(arr)
for j in range(len(batch_locations)):
self.running_inference_progress_bar(batch_len)
start, l = batch_locations[j]
window = window_middle
if start == 0:
window = window_start
elif i >= mix.shape[1]:
window = window_finish
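# First chunk gets the no-fade-in window, the final batch gets the no-fade-out window, everything in between is double-faded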
result = self.overlap_add(result, x, l, j, start, window)
counter[..., start:start + l] += window[..., :l]
batch_data = []
batch_locations = []
# Normalize by the overlap counter and remove padding
estimated_sources = result / counter.clamp(min=1e-10)
if length_init > 2 * (C - step) and (C - step > 0):
estimated_sources = estimated_sources[..., (C - step):-(C - step)]
pitch_fix = lambda s: self.pitch_fix(s, sr_pitched, org_mix)
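
For reference, below is a minimal self-contained sketch of the windowed overlap-add scheme the new code uses. The helper name demo_overlap_add, the toy chunk/step/fade sizes, the single-stem shape, and the identity stand-in for the model are illustrative assumptions, not values from this commit:

import torch

def demo_overlap_add(mix, C=8, step=4, fade=2):
    # Build the three windows used above: middle chunks fade on both sides,
    # the first chunk skips the fade-in, the last skips the fade-out.
    fadein = torch.linspace(0, 1, fade)
    fadeout = torch.linspace(1, 0, fade)
    w_start, w_mid, w_end = torch.ones(C), torch.ones(C), torch.ones(C)
    w_start[-fade:] *= fadeout
    w_end[:fade] *= fadein
    w_mid[:fade] *= fadein
    w_mid[-fade:] *= fadeout

    result = torch.zeros_like(mix)
    counter = torch.zeros_like(mix)
    n = mix.shape[-1]
    for i in range(0, n, step):
        part = mix[..., i:i + C]
        l = part.shape[-1]
        out = part  # identity stand-in for model(part)
        w = w_start if i == 0 else (w_end if i + step >= n else w_mid)
        result[..., i:i + l] += out[..., :l] * w[:l]
        counter[..., i:i + l] += w[:l]
    # Normalize by the accumulated window weight, as in the commit
    return result / counter.clamp(min=1e-10)

x = torch.randn(2, 32)
print(torch.allclose(x, demo_overlap_add(x), atol=1e-6))  # True

Because every sample's total window weight is divided back out, the fades only shape how overlapping chunk predictions are blended; they do not change the overall level.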