diff --git a/lib_v5/mdxnet.py b/lib_v5/mdxnet.py index c0a61fe..3293c89 100644 --- a/lib_v5/mdxnet.py +++ b/lib_v5/mdxnet.py @@ -1,15 +1,11 @@ -from abc import ABCMeta - import torch import torch.nn as nn -from pytorch_lightning import LightningModule from .modules import TFC_TDF +from pytorch_lightning import LightningModule dim_s = 4 class AbstractMDXNet(LightningModule): - __metaclass__ = ABCMeta - def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap): super().__init__() self.target_name = target_name @@ -24,7 +20,7 @@ class AbstractMDXNet(LightningModule): self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False) self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False) - def configure_optimizers(self): + def get_optimizer(self): if self.optimizer == 'rmsprop': return torch.optim.RMSprop(self.parameters(), self.lr) @@ -37,7 +33,7 @@ class ConvTDFNet(AbstractMDXNet): super(ConvTDFNet, self).__init__( target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap) - self.save_hyperparameters() + #self.save_hyperparameters() self.num_blocks = num_blocks self.l = l diff --git a/lib_v5/spec_utils.py b/lib_v5/spec_utils.py index 0b12a92..f35a05d 100644 --- a/lib_v5/spec_utils.py +++ b/lib_v5/spec_utils.py @@ -1,18 +1,32 @@ +import audioread import librosa import numpy as np import soundfile as sf import math -import random -import math import platform import traceback from . import pyrb -#cur +from scipy.signal import correlate, hilbert +import io + OPERATING_SYSTEM = platform.system() SYSTEM_ARCH = platform.platform() SYSTEM_PROC = platform.processor() ARM = 'arm' +AUTO_PHASE = "Automatic" +POSITIVE_PHASE = "Positive Phase" +NEGATIVE_PHASE = "Negative Phase" +NONE_P = "None", +LOW_P = "Shifts: Low", +MED_P = "Shifts: Medium", +HIGH_P = "Shifts: High", +VHIGH_P = "Shifts: Very High" +MAXIMUM_P = "Shifts: Maximum" + +progress_value = 0 +last_update_time = 0 + if OPERATING_SYSTEM == 'Windows': from pyrubberband import pyrb else: @@ -25,6 +39,11 @@ else: MAX_SPEC = 'Max Spec' MIN_SPEC = 'Min Spec' +LIN_ENSE = 'Linear Ensemble' + +MAX_WAV = MAX_SPEC +MIN_WAV = MIN_SPEC + AVERAGE = 'Average' def crop_center(h1, h2): @@ -57,122 +76,42 @@ def make_padding(width, cropsize, offset): return left, right, roi_size -def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): - if reverse: - wave_left = np.flip(np.asfortranarray(wave[0])) - wave_right = np.flip(np.asfortranarray(wave[1])) - elif mid_side: - wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) - wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) - elif mid_side_b2: - wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) - wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) - else: - wave_left = np.asfortranarray(wave[0]) - wave_right = np.asfortranarray(wave[1]) - - spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length) - spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) - - spec = np.asfortranarray([spec_left, spec_right]) - - return spec - -def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): - import threading - - if reverse: - wave_left = np.flip(np.asfortranarray(wave[0])) - wave_right = np.flip(np.asfortranarray(wave[1])) - elif mid_side: - wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) - wave_right = 
np.asfortranarray(np.subtract(wave[0], wave[1])) - elif mid_side_b2: - wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) - wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) - else: - wave_left = np.asfortranarray(wave[0]) - wave_right = np.asfortranarray(wave[1]) - - def run_thread(**kwargs): - global spec_left - spec_left = librosa.stft(**kwargs) - - thread = threading.Thread(target=run_thread, kwargs={'y': wave_left, 'n_fft': n_fft, 'hop_length': hop_length}) - thread.start() - spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) - thread.join() - - spec = np.asfortranarray([spec_left, spec_right]) - - return spec - def normalize(wave, is_normalize=False): - """Save output music files""" + """Normalize audio""" + maxv = np.abs(wave).max() if maxv > 1.0: - print(f"\nNormalization Set {is_normalize}: Input above threshold for clipping. Max:{maxv}") if is_normalize: - print(f"The result was normalized.") + print("Above clipping threshold.") wave /= maxv - else: - print(f"The result was not normalized.") - else: - print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") return wave -def normalize_two_stem(wave, mix, is_normalize=False): - """Save output music files""" - - maxv = np.abs(wave).max() - max_mix = np.abs(mix).max() - - if maxv > 1.0: - print(f"\nNormalization Set {is_normalize}: Primary source above threshold for clipping. Max:{maxv}") - print(f"\nNormalization Set {is_normalize}: Mixture above threshold for clipping. Max:{max_mix}") - if is_normalize: - print(f"The result was normalized.") - wave /= maxv - mix /= maxv - else: - print(f"The result was not normalized.") - else: - print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") - - - print(f"\nNormalization Set {is_normalize}: Primary source - Max:{np.abs(wave).max()}") - print(f"\nNormalization Set {is_normalize}: Mixture - Max:{np.abs(mix).max()}") - - return wave, mix +def auto_transpose(audio_array:np.ndarray): + """ + Ensure that the audio array is in the (channels, samples) format. -def combine_spectrograms(specs, mp): - l = min([specs[i].shape[2] for i in specs]) - spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64) - offset = 0 - bands_n = len(mp.param['band']) - - for d in range(1, bands_n + 1): - h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start'] - spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l] - offset += h - - if offset > mp.param['bins']: - raise ValueError('Too much bins') - - # lowpass fiter - if mp.param['pre_filter_start'] > 0: # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']: - if bands_n == 1: - spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop']) - else: - gp = 1 - for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']): - g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0) - gp = g - spec_c[:, b, :] *= g - - return np.asfortranarray(spec_c) + Parameters: + audio_array (ndarray): Input audio array. + + Returns: + ndarray: Transposed audio array if necessary. 
+ """ + # If the second dimension is 2 (indicating stereo channels), transpose the array + if audio_array.shape[1] == 2: + return audio_array.T + return audio_array + +def write_array_to_mem(audio_data, subtype): + if isinstance(audio_data, np.ndarray): + audio_buffer = io.BytesIO() + sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format='WAV') + audio_buffer.seek(0) + return audio_buffer + else: + return audio_data + def spectrogram_to_image(spec, mode='magnitude'): if mode == 'magnitude': if np.iscomplexobj(spec): @@ -251,7 +190,6 @@ def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32): message = f'{error_name}: "{e}"\n{traceback_text}"' print('Post Process Failed: ', message) - return mask def align_wave_head_and_tail(a, b): @@ -259,47 +197,112 @@ def align_wave_head_and_tail(a, b): return a[:l,:l], b[:l,:l] -def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse, clamp=False): +def convert_channels(spec, mp, band): + cc = mp.param['band'][band].get('convert_channels') + + if 'mid_side_c' == cc: + spec_left = np.add(spec[0], spec[1] * .25) + spec_right = np.subtract(spec[1], spec[0] * .25) + elif 'mid_side' == cc: + spec_left = np.add(spec[0], spec[1]) / 2 + spec_right = np.subtract(spec[0], spec[1]) + elif 'stereo_n' == cc: + spec_left = np.add(spec[0], spec[1] * .25) / 0.9375 + spec_right = np.add(spec[1], spec[0] * .25) / 0.9375 + else: + return spec + + return np.asfortranarray([spec_left, spec_right]) + +def combine_spectrograms(specs, mp, is_v51_model=False): + l = min([specs[i].shape[2] for i in specs]) + spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64) + offset = 0 + bands_n = len(mp.param['band']) + + for d in range(1, bands_n + 1): + h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start'] + spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l] + offset += h + + if offset > mp.param['bins']: + raise ValueError('Too much bins') + + # lowpass fiter + + if mp.param['pre_filter_start'] > 0: + if is_v51_model: + spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param['pre_filter_start'], mp.param['pre_filter_stop']) + else: + if bands_n == 1: + spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop']) + else: + gp = 1 + for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']): + g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0) + gp = g + spec_c[:, b, :] *= g + + return np.asfortranarray(spec_c) + +def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False): + + if wave.ndim == 1: + wave = np.asfortranarray([wave,wave]) + + if not is_v51_model: + if mp.param['reverse']: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mp.param['mid_side']: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mp.param['mid_side_b2']: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, 
spec_right]) + + if is_v51_model: + spec = convert_channels(spec, mp, band) + + return spec + +def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True): spec_left = np.asfortranarray(spec[0]) spec_right = np.asfortranarray(spec[1]) - + wave_left = librosa.istft(spec_left, hop_length=hop_length) wave_right = librosa.istft(spec_right, hop_length=hop_length) - - if reverse: - return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) - elif mid_side: - return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) - elif mid_side_b2: - return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) + + if is_v51_model: + cc = mp.param['band'][band].get('convert_channels') + if 'mid_side_c' == cc: + return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)]) + elif 'mid_side' == cc: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif 'stereo_n' == cc: + return np.asfortranarray([np.subtract(wave_left, wave_right * .25), np.subtract(wave_right, wave_left * .25)]) else: - return np.asfortranarray([wave_left, wave_right]) + if mp.param['reverse']: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mp.param['mid_side']: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif mp.param['mid_side_b2']: + return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) -def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2): - import threading - - spec_left = np.asfortranarray(spec[0]) - spec_right = np.asfortranarray(spec[1]) + return np.asfortranarray([wave_left, wave_right]) - def run_thread(**kwargs): - global wave_left - wave_left = librosa.istft(**kwargs) - - thread = threading.Thread(target=run_thread, kwargs={'stft_matrix': spec_left, 'hop_length': hop_length}) - thread.start() - wave_right = librosa.istft(spec_right, hop_length=hop_length) - thread.join() - - if reverse: - return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) - elif mid_side: - return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) - elif mid_side_b2: - return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) - else: - return np.asfortranarray([wave_left, wave_right]) - -def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None): +def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False): bands_n = len(mp.param['band']) offset = 0 @@ -308,31 +311,60 @@ def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None): spec_s = np.ndarray(shape=(2, bp['n_fft'] // 2 + 1, spec_m.shape[2]), dtype=complex) h = bp['crop_stop'] - bp['crop_start'] spec_s[:, bp['crop_start']:bp['crop_stop'], :] = spec_m[:, offset:offset+h, :] - + offset += h if d == bands_n: # higher if extra_bins_h: # if --high_end_process bypass max_bin = bp['n_fft'] // 2 spec_s[:, max_bin-extra_bins_h:max_bin, :] = extra_bins[:, :extra_bins_h, :] if bp['hpf_start'] > 0: - spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + if is_v51_model: + spec_s *= get_hp_filter_mask(spec_s.shape[1], bp['hpf_start'], bp['hpf_stop'] - 1) + else: + spec_s = fft_hp_filter(spec_s, 
bp['hpf_start'], bp['hpf_stop'] - 1) if bands_n == 1: - wave = spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']) + wave = spectrogram_to_wave(spec_s, bp['hl'], mp, d, is_v51_model) else: - wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp, d, is_v51_model)) else: sr = mp.param['band'][d+1]['sr'] if d == 1: # lower - spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) - wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']), bp['sr'], sr, res_type=wav_resolution) + if is_v51_model: + spec_s *= get_lp_filter_mask(spec_s.shape[1], bp['lpf_start'], bp['lpf_stop']) + else: + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp, d, is_v51_model), bp['sr'], sr, res_type=wav_resolution) else: # mid - spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) - spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) - wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + if is_v51_model: + spec_s *= get_hp_filter_mask(spec_s.shape[1], bp['hpf_start'], bp['hpf_stop'] - 1) + spec_s *= get_lp_filter_mask(spec_s.shape[1], bp['lpf_start'], bp['lpf_stop']) + else: + spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + + wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp, d, is_v51_model)) wave = librosa.resample(wave2, bp['sr'], sr, res_type=wav_resolution) return wave +def get_lp_filter_mask(n_bins, bin_start, bin_stop): + mask = np.concatenate([ + np.ones((bin_start - 1, 1)), + np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], + np.zeros((n_bins - bin_stop, 1)) + ], axis=0) + + return mask + +def get_hp_filter_mask(n_bins, bin_start, bin_stop): + mask = np.concatenate([ + np.zeros((bin_stop + 1, 1)), + np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], + np.ones((n_bins - bin_start - 2, 1)) + ], axis=0) + + return mask + def fft_lp_filter(spec, bin_start, bin_stop): g = 1.0 for b in range(bin_start, bin_stop): @@ -353,6 +385,30 @@ def fft_hp_filter(spec, bin_start, bin_stop): return spec +def spectrogram_to_wave_old(spec, hop_length=1024): + if spec.ndim == 2: + wave = librosa.istft(spec, hop_length=hop_length) + elif spec.ndim == 3: + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + +def wave_to_spectrogram_old(wave, hop_length, n_fft): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + def mirroring(a, spec_m, input_high_end, mp): if 'mirroring' == a: mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1) @@ -367,7 +423,7 @@ def mirroring(a, spec_m, input_high_end, mp): return np.where(np.abs(input_high_end) <= 
np.abs(mi), input_high_end, mi) def adjust_aggr(mask, is_non_accom_stem, aggressiveness): - aggr = aggressiveness['value'] + aggr = aggressiveness['value'] * 2 if aggr != 0: if is_non_accom_stem: @@ -383,9 +439,6 @@ def adjust_aggr(mask, is_non_accom_stem, aggressiveness): mask[ch, :aggressiveness['split_bin']] = np.power(mask[ch, :aggressiveness['split_bin']], 1 + aggr[ch] / 3) mask[ch, aggressiveness['split_bin']:] = np.power(mask[ch, aggressiveness['split_bin']:], 1 + aggr[ch]) - # if is_non_accom_stem: - # mask = (1.0 - mask) - return mask def stft(wave, nfft, hl): @@ -460,32 +513,51 @@ def invert_audio(specs, invert_p=True): return v_spec def invert_stem(mixture, stem): - mixture = wave_to_spectrogram_no_mp(mixture) stem = wave_to_spectrogram_no_mp(stem) output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem])) return -output.T -def ensembling(a, specs): - for i in range(1, len(specs)): - if i == 1: - spec = specs[0] +def ensembling(a, inputs, is_wavs=False): - ln = min([spec.shape[2], specs[i].shape[2]]) - spec = spec[:,:,:ln] - specs[i] = specs[i][:,:,:ln] + for i in range(1, len(inputs)): + if i == 1: + input = inputs[0] + + if is_wavs: + ln = min([input.shape[1], inputs[i].shape[1]]) + input = input[:,:ln] + inputs[i] = inputs[i][:,:ln] + else: + ln = min([input.shape[2], inputs[i].shape[2]]) + input = input[:,:,:ln] + inputs[i] = inputs[i][:,:,:ln] if MIN_SPEC == a: - spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec) + input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input) if MAX_SPEC == a: - spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec) - if AVERAGE == a: - spec = np.where(np.abs(specs[i]) == np.abs(spec), specs[i], spec) + input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input) - return spec + #linear_ensemble + #input = ensemble_wav(inputs, split_size=1) -def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path): + return input + +def ensemble_for_align(waves): + + specs = [] + + for wav in waves: + spec = wave_to_spectrogram_no_mp(wav.T) + specs.append(spec) + + wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T + wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True) + + return wav_aligned + +def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False): wavs_ = [] @@ -498,13 +570,17 @@ def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save for i in range(len(audio_input)): wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100) wavs_.append(wave) - spec = wave_to_spectrogram_no_mp(wave) + spec = wave if is_wave else wave_to_spectrogram_no_mp(wave) specs.append(spec) wave_shapes = [w.shape[1] for w in wavs_] target_shape = wavs_[wave_shapes.index(max(wave_shapes))] - output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs)) + if is_wave: + output = ensembling(algorithm, specs, is_wavs=True) + else: + output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs)) + output = to_shape(output, target_shape.shape) sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set) @@ -528,26 +604,142 @@ def to_shape_minimize(x: np.ndarray, target_shape): return np.pad(x, tuple(padding_list), mode='constant') -def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False): +def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024): + """ + 
Detect silence at the beginning of an audio signal. + + :param audio: np.array, audio signal + :param sr: int, sample rate + :param silence_threshold: float, magnitude threshold below which is considered silence + :param frame_length: int, the number of samples to consider for each check + + :return: float, duration of the leading silence in milliseconds + """ + + if len(audio.shape) == 2: + # If stereo, pick the channel with more energy to determine the silence + channel = np.argmax(np.sum(np.abs(audio), axis=1)) + audio = audio[channel] + + for i in range(0, len(audio), frame_length): + if np.max(np.abs(audio[i:i+frame_length])) > silence_threshold: + return (i / sr) * 1000 + + return (len(audio) / sr) * 1000 + +def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024): + """ + Adjust the leading silence of the target_audio to match the leading silence of the reference_audio. + + :param target_audio: np.array, audio signal that will have its silence adjusted + :param reference_audio: np.array, audio signal used as a reference + :param sr: int, sample rate + :param silence_threshold: float, magnitude threshold below which is considered silence + :param frame_length: int, the number of samples to consider for each check + + :return: np.array, target_audio adjusted to have the same leading silence as reference_audio + """ + + def find_silence_end(audio): + if len(audio.shape) == 2: + # If stereo, pick the channel with more energy to determine the silence + channel = np.argmax(np.sum(np.abs(audio), axis=1)) + audio_mono = audio[channel] + else: + audio_mono = audio + + for i in range(0, len(audio_mono), frame_length): + if np.max(np.abs(audio_mono[i:i+frame_length])) > silence_threshold: + return i + return len(audio_mono) + + ref_silence_end = find_silence_end(reference_audio) + target_silence_end = find_silence_end(target_audio) + silence_difference = ref_silence_end - target_silence_end + + try: + ref_silence_end_p = (ref_silence_end / 44100) * 1000 + target_silence_end_p = (target_silence_end / 44100) * 1000 + silence_difference_p = ref_silence_end_p - target_silence_end_p + print("silence_difference: ", silence_difference_p) + except Exception as e: + pass + + if silence_difference > 0: # Add silence to target_audio + if len(target_audio.shape) == 2: # stereo + silence_to_add = np.zeros((target_audio.shape[0], silence_difference)) + else: # mono + silence_to_add = np.zeros(silence_difference) + return np.hstack((silence_to_add, target_audio)) + elif silence_difference < 0: # Remove silence from target_audio + if len(target_audio.shape) == 2: # stereo + return target_audio[:, -silence_difference:] + else: # mono + return target_audio[-silence_difference:] + else: # No adjustment needed + return target_audio + +def match_array_shapes(array_1:np.ndarray, array_2:np.ndarray, is_swap=False): + + if is_swap: + array_1, array_2 = array_1.T, array_2.T + + #print("before", array_1.shape, array_2.shape) + if array_1.shape[1] > array_2.shape[1]: + array_1 = array_1[:,:array_2.shape[1]] + elif array_1.shape[1] < array_2.shape[1]: + padding = array_2.shape[1] - array_1.shape[1] + array_1 = np.pad(array_1, ((0,0), (0,padding)), 'constant', constant_values=0) + + #print("after", array_1.shape, array_2.shape) + + if is_swap: + array_1, array_2 = array_1.T, array_2.T + + return array_1 + +def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray): + + if len(array_1) > len(array_2): + array_1 = array_1[:len(array_2)] + elif len(array_1) < 
len(array_2): + padding = len(array_2) - len(array_1) + array_1 = np.pad(array_1, (0, padding), 'constant', constant_values=0) + + return array_1 + +def change_pitch_semitones(y, sr, semitone_shift): + factor = 2 ** (semitone_shift / 12) # Convert semitone shift to factor for resampling + y_pitch_tuned = [] + for y_channel in y: + y_pitch_tuned.append(librosa.resample(y_channel, sr, sr*factor, res_type=wav_resolution)) + y_pitch_tuned = np.array(y_pitch_tuned) + new_sr = sr * factor + return y_pitch_tuned, new_sr + +def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True): wav, sr = librosa.load(audio_file, sr=44100, mono=False) if wav.ndim == 1: wav = np.asfortranarray([wav,wav]) - if is_pitch: - wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None) - wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None) + if not is_time_correction: + wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0] else: - wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None) - wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None) + if is_pitch: + wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None) + wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None) + else: + wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None) + wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None) - if wav_1.shape > wav_2.shape: - wav_2 = to_shape(wav_2, wav_1.shape) - if wav_1.shape < wav_2.shape: - wav_1 = to_shape(wav_1, wav_2.shape) - - wav_mix = np.asfortranarray([wav_1, wav_2]) + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wav_mix = np.asfortranarray([wav_1, wav_2]) sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set) save_format(export_path) @@ -587,7 +779,7 @@ def average_dual_sources(wav_1, wav_2, value): wave = (wav_1 * value) + (wav_2 * (1-value)) return wave - + def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray): if wav_1.shape > wav_2.shape: @@ -602,91 +794,434 @@ def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray): return wav_2 -def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_normalization, command_Text, progress_bar_main_var, save_format): +def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray): + + if wav_1_shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1_shape) + + return wav_2 + +def combine_arrarys(audio_sources, is_swap=False): + source = np.zeros_like(max(audio_sources, key=np.size)) + + for v in audio_sources: + v = match_array_shapes(v, source, is_swap=is_swap) + source += v + + return source + +def combine_audio(paths: list, audio_file_base=None, wav_type_set='FLOAT', save_format=None): + + source = combine_arrarys([load_audio(i) for i in paths]) + save_path = f"{audio_file_base}_combined.wav" + sf.write(save_path, source.T, 44100, subtype=wav_type_set) + save_format(save_path) + +def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9): + # Reduce the volume + inst_source = inst_source * (1 - reduction_rate) + + mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True) + + return mix_reduced + +def organize_inputs(inputs): + input_list = { + "target":None, + "reference":None, + "reverb":None, + "inst":None + } + + for i in inputs: + if i.endswith("_(Vocals).wav"): + input_list["reference"] = i + elif "_RVC_" in i: + input_list["target"] = i + elif i.endswith("reverbed_stem.wav"): + 
input_list["reverb"] = i + elif i.endswith("_(Instrumental).wav"): + input_list["inst"] = i + + return input_list + +def check_if_phase_inverted(wav1, wav2, is_mono=False): + # Load the audio files + if not is_mono: + wav1 = np.mean(wav1, axis=0) + wav2 = np.mean(wav2, axis=0) + + # Compute the correlation + correlation = np.corrcoef(wav1[:1000], wav2[:1000]) + + return correlation[0,1] < 0 + +def align_audio(file1, + file2, + file2_aligned, + file_subtracted, + wav_type_set, + is_save_aligned, + command_Text, + save_format, + align_window:list, + align_intro_val:list, + db_analysis:tuple, + set_progress_bar, + phase_option, + phase_shifts, + is_match_silence, + is_spec_match): + + global progress_value + progress_value = 0 + is_mono = False + + def get_diff(a, b): + corr = np.correlate(a, b, "full") + diff = corr.argmax() - (b.shape[0] - 1) + + return diff + + def progress_bar(length): + global progress_value + progress_value += 1 + + if (0.90/length*progress_value) >= 0.9: + length = progress_value + 1 + + set_progress_bar(0.1, (0.9/length*progress_value)) + + # read tracks + wav1, sr1 = librosa.load(file1, sr=44100, mono=False) + wav2, sr2 = librosa.load(file2, sr=44100, mono=False) + + if wav1.ndim == 1 and wav2.ndim == 1: + is_mono = True + elif wav1.ndim == 1: + wav1 = np.asfortranarray([wav1,wav1]) + elif wav2.ndim == 1: + wav2 = np.asfortranarray([wav2,wav2]) + + # Check if phase is inverted + if phase_option == AUTO_PHASE: + if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): + wav2 = -wav2 + elif phase_option == POSITIVE_PHASE: + wav2 = +wav2 + elif phase_option == NEGATIVE_PHASE: + wav2 = -wav2 + + if is_match_silence: + wav2 = adjust_leading_silence(wav2, wav1) + + wav1_length = int(librosa.get_duration(y=wav1, sr=44100)) + wav2_length = int(librosa.get_duration(y=wav2, sr=44100)) + + if not is_mono: + wav1 = wav1.transpose() + wav2 = wav2.transpose() + + wav2_org = wav2.copy() + + command_Text("Processing files... 
\n") + seconds_length = min(wav1_length, wav2_length) + + wav2_aligned_sources = [] + + for sec_len in align_intro_val: + # pick a position at 1 second in and get diff + sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len) + index = sr1*sec_seg # 1 second in, assuming sr1 = sr2 = 44100 + + if is_mono: + samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1] + diff = get_diff(samp1, samp2) + #print(f"Estimated difference: {diff}\n") + else: + index = sr1*sec_seg # 1 second in, assuming sr1 = sr2 = 44100 + samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0] + samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1] + diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r) + #print(f"Estimated difference Left Channel: {diff}\nEstimated difference Right Channel: {diff_r}\n") + + # make aligned track 2 + if diff > 0: + zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2)) + wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0) + elif diff < 0: + wav2_aligned = wav2_org[-diff:] + else: + wav2_aligned = wav2_org + #command_Text(f"Audio files already aligned.\n") + + if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources): + wav2_aligned_sources.append(wav2_aligned) + + #print("Unique Sources: ", len(wav2_aligned_sources)) + + unique_sources = len(wav2_aligned_sources) + + sub_mapper_big_mapper = {} + + for s in wav2_aligned_sources: + wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True) + + if align_window: + wav_sub = time_correction(wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts) + wav_sub_size = np.abs(wav_sub).mean() + sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size:wav_sub}} + else: + wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20) + db_range = db_analysis[1] + + for db_adjustment in db_range: + # Adjust the dB of track2 + s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20)) + wav_sub = wav1 - s_adjusted + wav_sub_size = np.abs(wav_sub).mean() + sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size:wav_sub}} + + #print(sub_mapper_big_mapper.keys(), min(sub_mapper_big_mapper.keys())) + + sub_mapper_value_list = list(sub_mapper_big_mapper.values()) + + if is_spec_match and len(sub_mapper_value_list) >= 2: + #print("using spec ensemble with align") + wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) + else: + #print("using linear ensemble with align") + wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values())) + + #print(f"Mix Mean: {np.abs(wav1).mean()}\nInst Mean: {np.abs(wav2).mean()}") + #print('Final: ', np.abs(wav_sub).mean()) + wav_sub = np.clip(wav_sub, -1, +1) + + command_Text(f"Saving inverted track... 
") + + if is_save_aligned or is_spec_match: + wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True) + wav2_aligned = wav1 - wav_sub + + if is_spec_match: + if wav1.ndim == 1 and wav2.ndim == 1: + wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T + wav1 = np.asfortranarray([wav1, wav1]).T + + wav2_aligned = ensemble_for_align([wav2_aligned, wav1]) + wav_sub = wav1 - wav2_aligned + + if is_save_aligned: + sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set) + save_format(file2_aligned) + + sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set) + save_format(file_subtracted) + +def phase_shift_hilbert(signal, degree): + analytic_signal = hilbert(signal) + return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag + +def get_phase_shifted_tracks(track, phase_shift): + if phase_shift == 180: + return [track, -track] + + step = phase_shift + end = 180 - (180 % step) if 180 % step == 0 else 181 + phase_range = range(step, end, step) + + flipped_list = [track, -track] + for i in phase_range: + flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)]) + + return flipped_list + +def time_correction(mix:np.ndarray, instrumental:np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P): + # Function to align two tracks using cross-correlation + + def align_tracks(track1, track2): + # A dictionary to store each version of track2_shifted and its mean absolute value + shifted_tracks = {} + + # Loop to adjust dB of track2 + track2 = track2 * np.power(10, db_analysis[0] / 20) + db_range = db_analysis[1] + + if phase_shifts == 190: + track2_flipped = [track2] + else: + track2_flipped = get_phase_shifted_tracks(track2, phase_shifts) + + for db_adjustment in db_range: + for t in track2_flipped: + # Adjust the dB of track2 + track2_adjusted = t * (10 ** (db_adjustment / 20)) + corr = correlate(track1, track2_adjusted) + delay = np.argmax(np.abs(corr)) - (len(track1) - 1) + track2_shifted = np.roll(track2_adjusted, shift=delay) + + # Compute the mean absolute value of track2_shifted + track2_shifted_sub = track1 - track2_shifted + mean_abs_value = np.abs(track2_shifted_sub).mean() + + # Store track2_shifted and its mean absolute value in the dictionary + shifted_tracks[mean_abs_value] = track2_shifted + + # Return the version of track2_shifted with the smallest mean absolute value + + return shifted_tracks[min(shifted_tracks.keys())] + + # Make sure the audio files have the same shape + + assert mix.shape == instrumental.shape, f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}" + + seconds_length = seconds_length // 2 + + sub_mapper = {} + + progress_update_interval = 120 + total_iterations = 0 + + if len(align_window) > 2: + progress_update_interval = 320 + + for secs in align_window: + step = secs / 2 + window_size = int(sr * secs) + step_size = int(sr * step) + + if len(mix.shape) == 1: + total_mono = (len(range(0, len(mix) - window_size, step_size))//progress_update_interval)*unique_sources + total_iterations += total_mono + else: + total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size))*2 + total_stereo = (total_stereo_//progress_update_interval) * unique_sources + total_iterations += total_stereo + + #print(total_iterations) + + for secs in align_window: + sub = np.zeros_like(mix) + divider = np.zeros_like(mix) + step = secs 
/ 2 + window_size = int(sr * secs) + step_size = int(sr * step) + window = np.hanning(window_size) + + # For the mono case: + if len(mix.shape) == 1: + # The files are mono + counter = 0 + for i in range(0, len(mix) - window_size, step_size): + counter += 1 + if counter % progress_update_interval == 0: + progress_bar(total_iterations) + window_mix = mix[i:i+window_size] * window + window_instrumental = instrumental[i:i+window_size] * window + window_instrumental_aligned = align_tracks(window_mix, window_instrumental) + sub[i:i+window_size] += window_mix - window_instrumental_aligned + divider[i:i+window_size] += window + else: + # The files are stereo + counter = 0 + for ch in range(mix.shape[1]): + for i in range(0, len(mix[:, ch]) - window_size, step_size): + counter += 1 + if counter % progress_update_interval == 0: + progress_bar(total_iterations) + window_mix = mix[i:i+window_size, ch] * window + window_instrumental = instrumental[i:i+window_size, ch] * window + window_instrumental_aligned = align_tracks(window_mix, window_instrumental) + sub[i:i+window_size, ch] += window_mix - window_instrumental_aligned + divider[i:i+window_size, ch] += window + + # Normalize the result by the overlap count + sub = np.where(divider > 1e-6, sub / divider, sub) + sub_size = np.abs(sub).mean() + sub_mapper = {**sub_mapper, **{sub_size: sub}} + + #print("SUB_LEN", len(list(sub_mapper.values()))) + + sub = ensemble_wav(list(sub_mapper.values()), split_size=12) + + return sub + +def ensemble_wav(waveforms, split_size=240): + # Create a dictionary to hold the thirds of each waveform and their mean absolute values + waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)} + + # Initialize the final waveform + final_waveform = [] + + # For chunk + for third_idx in range(split_size): + # Compute the mean absolute value of each third from each waveform + means = [np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))] + + # Find the index of the waveform with the lowest mean absolute value for this third + min_index = np.argmin(means) + + # Add the least noisy third to the final waveform + final_waveform.append(waveform_thirds[min_index][third_idx]) + + # Concatenate all the thirds to create the final waveform + final_waveform = np.concatenate(final_waveform) + + return final_waveform + +def ensemble_wav_min(waveforms): + for i in range(1, len(waveforms)): + if i == 1: + wave = waveforms[0] + + ln = min(len(wave), len(waveforms[i])) + wave = wave[:ln] + waveforms[i] = waveforms[i][:ln] + + wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave) + + return wave + +def align_audio_test(wav1, wav2, sr1=44100): def get_diff(a, b): corr = np.correlate(a, b, "full") diff = corr.argmax() - (b.shape[0] - 1) return diff - progress_bar_main_var.set(10) - # read tracks - wav1, sr1 = librosa.load(file1, sr=44100, mono=False) - wav2, sr2 = librosa.load(file2, sr=44100, mono=False) wav1 = wav1.transpose() wav2 = wav2.transpose() - command_Text(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n") + #print(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n") wav2_org = wav2.copy() - progress_bar_main_var.set(20) - command_Text("Processing files... 
\n") - - # pick random position and get diff - - counts = {} # counting up for each diff value - progress = 20 - - check_range = 64 - - base = (64 / check_range) - - for i in range(check_range): - index = int(random.uniform(44100 * 2, min(wav1.shape[0], wav2.shape[0]) - 44100 * 2)) - shift = int(random.uniform(-22050,+22050)) - samp1 = wav1[index :index +44100, 0] # currently use left channel - samp2 = wav2[index+shift:index+shift+44100, 0] - progress += 1 * base - progress_bar_main_var.set(progress) - diff = get_diff(samp1, samp2) - diff -= shift - - if abs(diff) < 22050: - if not diff in counts: - counts[diff] = 0 - counts[diff] += 1 - - # use max counted diff value - max_count = 0 - est_diff = 0 - for diff in counts.keys(): - if counts[diff] > max_count: - max_count = counts[diff] - est_diff = diff - - command_Text(f"Estimated difference is {est_diff} (count: {max_count})\n") - - progress_bar_main_var.set(90) - - audio_files = [] - - def save_aligned_audio(wav2_aligned): - command_Text(f"Aligned File 2 with File 1.\n") - command_Text(f"Saving files... ") - sf.write(file2_aligned, normalize(wav2_aligned, is_normalization), sr2, subtype=wav_type_set) - save_format(file2_aligned) - min_len = min(wav1.shape[0], wav2_aligned.shape[0]) - wav_sub = wav1[:min_len] - wav2_aligned[:min_len] - audio_files.append(file2_aligned) - return min_len, wav_sub + # pick a position at 1 second in and get diff + index = sr1#*seconds_length # 1 second in, assuming sr1 = sr2 = 44100 + samp1 = wav1[index : index + sr1, 0] # currently use left channel + samp2 = wav2[index : index + sr1, 0] + diff = get_diff(samp1, samp2) # make aligned track 2 - if est_diff > 0: - wav2_aligned = np.append(np.zeros((est_diff, 2)), wav2_org, axis=0) - min_len, wav_sub = save_aligned_audio(wav2_aligned) - elif est_diff < 0: - wav2_aligned = wav2_org[-est_diff:] - min_len, wav_sub = save_aligned_audio(wav2_aligned) + if diff > 0: + wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0) + elif diff < 0: + wav2_aligned = wav2_org[-diff:] else: - command_Text(f"Audio files already aligned.\n") - command_Text(f"Saving inverted track... 
") - min_len = min(wav1.shape[0], wav2.shape[0]) - wav_sub = wav1[:min_len] - wav2[:min_len] + wav2_aligned = wav2_org + + return wav2_aligned - wav_sub = np.clip(wav_sub, -1, +1) - - sf.write(file_subtracted, normalize(wav_sub, is_normalization), sr1, subtype=wav_type_set) - save_format(file_subtracted) - - progress_bar_main_var.set(95) \ No newline at end of file +def load_audio(audio_file): + wav, sr = librosa.load(audio_file, sr=44100, mono=False) + + if wav.ndim == 1: + wav = np.asfortranarray([wav,wav]) + + return wav + +def rerun_mp3(audio_file, sample_rate=44100): + with audioread.audio_open(audio_file) as f: + track_length = int(f.duration) + + return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0] diff --git a/lib_v5/tfc_tdf_v3.py b/lib_v5/tfc_tdf_v3.py new file mode 100644 index 0000000..8f72595 --- /dev/null +++ b/lib_v5/tfc_tdf_v3.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from functools import partial + +class STFT: + def __init__(self, n_fft, hop_length, dim_f): + self.n_fft = n_fft + self.hop_length = hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == 'BatchNorm': + return nn.BatchNorm2d(c) + elif norm_type == 'InstanceNorm': + return nn.InstanceNorm2d(c, affine=True) + elif 'GroupNorm' in norm_type: + g = int(norm_type.replace('GroupNorm', '')) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = 
nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class TFC_TDF_net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + norm = get_norm(norm_type=config.model.norm) + act = get_act(act_type=config.model.act) + + self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + n = config.model.num_scales + scale = config.model.scale + l = config.model.num_blocks_per_scale + c = config.model.num_channels + g = config.model.growth + bn = config.model.bottleneck_factor + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.encoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act) + block.downscale = Downscale(c, c + g, scale, norm, act) + f = f // scale[1] + c += g + self.encoder_blocks.append(block) + + self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act) + + self.decoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.upscale = Upscale(c, c - g, scale, norm, act) + f = f * scale[1] + c -= g + block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act) + self.decoder_blocks.append(block) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + encoder_outputs = [] + for block in self.encoder_blocks: + x = block.tfc_tdf(x) + encoder_outputs.append(x) + x = block.downscale(x) + + x = self.bottleneck_block(x) + + for block in self.decoder_blocks: + x = block.upscale(x) + x = torch.cat([x, encoder_outputs.pop()], 1) + x = block.tfc_tdf(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + + return x + +
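
Note on the reworked band filtering in spec_utils.py: combine_spectrograms and cmb_spectrogram_to_wave now take an is_v51_model flag and, when it is set, multiply each band's spectrogram by a precomputed crossfade mask (get_lp_filter_mask / get_hp_filter_mask) instead of attenuating bins in place with fft_lp_filter / fft_hp_filter. A minimal sketch of how the low-pass mask broadcasts over a (channels, bins, frames) spectrogram, assuming the lib_v5 package and its dependencies are importable; the bin values below are illustrative only, not taken from any model parameter file:

import numpy as np
from lib_v5.spec_utils import get_lp_filter_mask

n_bins = 1025                     # n_fft // 2 + 1 for a 2048-point FFT (illustrative)
lp_mask = get_lp_filter_mask(n_bins, bin_start=900, bin_stop=1000)
# lp_mask has shape (n_bins, 1): ones below bin_start, a linear 1 -> 0 ramp
# up to bin_stop, then zeros above it.

spec = np.random.randn(2, n_bins, 64) + 1j * np.random.randn(2, n_bins, 64)
spec_lp = spec * lp_mask          # broadcasts across channels and frames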
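
The new alignment path (align_audio / time_correction) collects several candidate subtraction results and blends them with ensemble_wav, which splits each candidate into split_size chunks and keeps, per chunk, the one with the lowest mean absolute value. A minimal usage sketch, assuming the lib_v5 package is importable and the two inputs are same-length 44.1 kHz renders of the same stem; the file names are hypothetical:

import soundfile as sf
from lib_v5.spec_utils import ensemble_wav

take_1, sr = sf.read("vocal_residual_a.wav")   # hypothetical inputs
take_2, _ = sf.read("vocal_residual_b.wav")

length = min(len(take_1), len(take_2))
candidates = [take_1[:length], take_2[:length]]

# Per chunk, the candidate with the smallest mean |amplitude| wins.
blended = ensemble_wav(candidates, split_size=240)
sf.write("vocal_residual_blended.wav", blended, sr)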