Update inference.py

aufr33 2021-09-26 12:43:43 +03:00 committed by GitHub
parent 33b7900ae2
commit d9468a1edd


@@ -8,6 +8,7 @@ import numpy as np
 import soundfile as sf
 import torch
 import time
+import math
 from tqdm import tqdm
 from lib import dataset
@@ -23,7 +24,7 @@ class VocalRemover(object):
         self.device = device
         self.window_size = window_size
 
-    def _execute(self, X_mag_pad, roi_size, n_window, aggressiveness):
+    def _execute(self, X_mag_pad, roi_size, n_window, params):
         self.model.eval()
         with torch.no_grad():
             preds = []
@@ -32,7 +33,7 @@ class VocalRemover(object):
                 X_mag_window = X_mag_pad[None, :, :, start:start + self.window_size]
                 X_mag_window = torch.from_numpy(X_mag_window).to(self.device)
 
-                pred = self.model.predict(X_mag_window, aggressiveness)
+                pred = self.model.predict(X_mag_window, params)
 
                 pred = pred.detach().cpu().numpy()
                 preds.append(pred[0])
@@ -47,7 +48,7 @@ class VocalRemover(object):
         return X_mag, X_phase
 
-    def inference(self, X_spec, aggressiveness):
+    def inference(self, X_spec, params):
         X_mag, X_phase = self.preprocess(X_spec)
 
         coef = X_mag.max()
@@ -59,12 +60,12 @@ class VocalRemover(object):
         X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
 
-        pred = self._execute(X_mag_pad, roi_size, n_window, aggressiveness)
+        pred = self._execute(X_mag_pad, roi_size, n_window, params)
         pred = pred[:, :, :n_frame]
 
         return pred * coef, X_mag, np.exp(1.j * X_phase)
 
-    def inference_tta(self, X_spec, aggressiveness):
+    def inference_tta(self, X_spec, params):
         X_mag, X_phase = self.preprocess(X_spec)
 
         coef = X_mag.max()
@@ -76,7 +77,7 @@ class VocalRemover(object):
         X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
 
-        pred = self._execute(X_mag_pad, roi_size, n_window, aggressiveness)
+        pred = self._execute(X_mag_pad, roi_size, n_window, params)
         pred = pred[:, :, :n_frame]
 
         pad_l += roi_size // 2
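Across these hunks the scalar aggressiveness argument is replaced by a params dict. A minimal sketch of the dict that main() now builds and hands to inference()/inference_tta(); the concrete values below are illustrative stand-ins for what actually comes from the CLI flags and the model-parameter JSON:

    # Illustrative values only; in main() these come from args and mp.param.
    pd = {
        'aggr_value': 0.07,        # --aggressiveness
        'aggr_split_bin': 755,     # mp.param['band'][1]['crop_stop']
        'aggr_correction': None,   # mp.param.get('aggr_correction'); may be absent
        'is_vocal_model': False,   # --is_vocal_model
    }
    # pred, X_mag, X_phase = vr.inference(X_spec, pd)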
@@ -85,39 +86,51 @@ class VocalRemover(object):
         X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
 
-        pred_tta = self._execute(X_mag_pad, roi_size, n_window, aggressiveness)
+        pred_tta = self._execute(X_mag_pad, roi_size, n_window, params)
         pred_tta = pred_tta[:, :, roi_size // 2:]
         pred_tta = pred_tta[:, :, :n_frame]
 
         return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
 
 
 def main():
+    nn_arch_sizes = [
+        31191, # default
+        33966, 123821, 537238 # custom
+    ]
+
     p = argparse.ArgumentParser()
     p.add_argument('--gpu', '-g', type=int, default=-1)
     p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
     p.add_argument('--input', '-i', required=True)
-    p.add_argument('--nn_architecture', '-n', type=str, choices=['default', '33966KB', '123821KB', '129605KB', '537238KB'], default='default')
+    p.add_argument('--nn_architecture', '-n', type=str, choices=['auto'] + list('{}KB'.format(s) for s in nn_arch_sizes), default='auto')
     p.add_argument('--model_params', '-m', type=str, default='')
     p.add_argument('--window_size', '-w', type=int, default=512)
     p.add_argument('--output_image', '-I', action='store_true')
     p.add_argument('--deepextraction', '-D', action='store_true')
     p.add_argument('--postprocess', '-p', action='store_true')
     p.add_argument('--is_vocal_model', '-vm', action='store_true')
-    p.add_argument('--tta', '-t', action='store_true')
-    p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation', 'mirroring', 'mirroring2'], default='mirroring')
-    p.add_argument('--aggressiveness', '-A', type=float, default=0.07)
-    p.add_argument('--no_vocals', '-nv', action='store_true')
+    p.add_argument('--tta', '-t', action='store_true', help='Test-Time-Augmentation')
+    p.add_argument('--high_end_process', '-H', type=str, choices=['mirroring', 'mirroring2', 'bypass', 'none'], default='mirroring')
+    p.add_argument('--aggressiveness', '-A', type=float, default=0.07, help='The strength of the vocal isolation. From 0.0 to 1.0.')
+    p.add_argument('--no_vocals', '-nv', action='store_true', help='Don\'t create Vocals stem.')
+    p.add_argument('--chunks', '-c', type=int, default=1, help='Split the input file into chunks to reduce RAM consumption.')
+    p.add_argument('--model_test_mode', '-mt', action='store_true', help='Include the model name in the output file name.')
+    p.add_argument('--normalize', action='store_true')
     args = p.parse_args()
 
-    nets = importlib.import_module('lib.nets' + f'_{args.nn_architecture}'.replace('_default', ''), package=None)
-
-    dir = 'ensembled/temp'
-    for file in os.scandir(dir):
+    separated_dir = 'separated'
+    ensembled_dir = 'ensembled/temp'
+
+    for file in os.scandir(ensembled_dir):
         os.remove(file.path)
 
-    mp = ModelParameters(args.model_params)
+    if 'auto' == args.nn_architecture:
+        model_size = math.ceil(os.stat(args.pretrained_model).st_size / 1024)
+        args.nn_architecture = '{}KB'.format(min(nn_arch_sizes, key=lambda x:abs(x-model_size)))
+
+    nets = importlib.import_module('lib.nets' + f'_{args.nn_architecture}'.replace('_{}KB'.format(nn_arch_sizes[0]), ''), package=None)
+
+    mp = ModelParameters(args.model_params)
 
     start_time = time.time()
     print('loading model...', end=' ')
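With --nn_architecture auto (the new default), the architecture is inferred from the checkpoint's file size, snapped to the nearest known size in KB. A small standalone sketch of that lookup, with a placeholder path:

    import math
    import os

    nn_arch_sizes = [31191, 33966, 123821, 537238]  # known checkpoint sizes in KB

    def detect_arch(model_path):
        # Round the file size up to whole KB, then pick the closest known size.
        model_size = math.ceil(os.stat(model_path).st_size / 1024)
        return '{}KB'.format(min(nn_arch_sizes, key=lambda s: abs(s - model_size)))

    # detect_arch('models/baseline.pth') would return '31191KB' for a checkpoint of roughly the default size.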
@@ -133,105 +146,177 @@ def main():
     print('loading & stft of wave source...', end=' ')
 
-    X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {}
+    X_spec = {}
+    input_is_mono = False
     basename = os.path.splitext(os.path.basename(args.input))[0]
     basenameb = '"{}"'.format(os.path.splitext(os.path.basename(args.input))[0])
 
     bands_n = len(mp.param['band'])
 
+    # high-end band
+    bp = mp.param['band'][bands_n]
+
+    wave, _ = librosa.load(args.input, bp['sr'], False, dtype=np.float32, res_type=bp['res_type'])
+
+    if wave.ndim == 1:
+        input_is_mono = True
+        wave = np.asarray([wave, wave])
+
+    if args.normalize:
+        wave /= max(np.max(wave), abs(np.min(wave)))
+
+    X_spec[bands_n] = spec_utils.wave_to_spectrogram(wave, bp['hl'], bp['n_fft'], mp, True)
+    X_spec[bands_n] = spec_utils.convert_channels(X_spec[bands_n], mp, bands_n)
+
+    if np.max(wave[0]) == 0.0:
+        print('Empty audio file!')
+        raise ValueError('Empty audio file')
+
+    if args.high_end_process != 'none':
+        input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
+        input_high_end = X_spec[bands_n][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
+
-    for d in range(bands_n, 0, -1):
+    # lower bands
+    for d in range(bands_n - 1, 0, -1):
         bp = mp.param['band'][d]
 
-        if d == bands_n: # high-end band
-            X_wave[d], _ = librosa.load(
-                args.input, bp['sr'], False, dtype=np.float32, res_type=bp['res_type'])
-
-            if X_wave[d].ndim == 1:
-                X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
-        else: # lower bands
-            X_wave[d] = librosa.resample(X_wave[d+1], mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type'])
-
-        X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], mp, True)
-
-        if d == bands_n and args.high_end_process != 'none':
-            input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
-            input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
-
-    X_spec_m = spec_utils.combine_spectrograms(X_spec_s, mp)
-
-    del X_wave, X_spec_s
+        wave = librosa.resample(wave, mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type'])
+        X_spec[d] = spec_utils.wave_to_spectrogram(wave, bp['hl'], bp['n_fft'], mp, True)
+        X_spec[d] = spec_utils.convert_channels(X_spec[d], mp, d)
+
+    X_spec = spec_utils.combine_spectrograms(X_spec, mp)
 
     print('done')
 
     vr = VocalRemover(model, device, args.window_size)
-    if args.tta:
-        pred, X_mag, X_phase = vr.inference_tta(X_spec_m, {'value': args.aggressiveness, 'split_bin': mp.param['band'][1]['crop_stop']})
-    else:
-        pred, X_mag, X_phase = vr.inference(X_spec_m, {'value': args.aggressiveness, 'split_bin': mp.param['band'][1]['crop_stop']})
-
-    if args.postprocess:
-        print('post processing...', end=' ')
-        pred_inv = np.clip(X_mag - pred, 0, np.inf)
-        pred = spec_utils.mask_silence(pred, pred_inv)
-        print('done')
-
-    if 'is_vocal_model' in mp.param or args.is_vocal_model: # swap
-        stems = {'inst': 'Vocals', 'vocals': 'Instruments'}
-    else:
-        stems = {'inst': 'Instruments', 'vocals': 'Vocals'}
-
-    print('inverse stft of {}...'.format(stems['inst']), end=' ')
-    y_spec_m = pred * X_phase
-    v_spec_m = X_spec_m - y_spec_m
-
-    if args.high_end_process == 'bypass':
-        wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
-    elif args.high_end_process == 'correlation':
-        print('Deprecated: correlation will be removed in the final release. Please use the mirroring instead.')
-
-        for i in range(input_high_end.shape[2]):
-            for c in range(2):
-                X_mag_max = np.amax(input_high_end[c, :, i])
-                b1 = mp.param['pre_filter_start']-input_high_end_h//2
-                b2 = mp.param['pre_filter_start']-1
-                if X_mag_max > 0 and np.sum(np.abs(v_spec_m[c, b1:b2, i])) / (b2 - b1) > 0.07:
-                    y_mag = np.median(y_spec_m[c, b1:b2, i])
-                    input_high_end[c, :, i] = np.true_divide(input_high_end[c, :, i], abs(X_mag_max) / min(abs(y_mag * 4), abs(X_mag_max)))
-
-        wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
-    elif args.high_end_process.startswith('mirroring'):
-        input_high_end_ = spec_utils.mirroring(args.high_end_process, y_spec_m, input_high_end, mp)
-        wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end_)
-    else:
-        wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp)
-    print('done')
-
-    model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0]
-    sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr'])
-
-    if not args.no_vocals:
-        print('inverse stft of {}...'.format(stems['vocals']), end=' ')
-
-        if args.high_end_process.startswith('mirroring'):
-            input_high_end_ = spec_utils.mirroring(args.high_end_process, v_spec_m, input_high_end, mp)
-            wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp, input_high_end_h, input_high_end_)
-        else:
-            wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
-        print('done')
-        sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
+    chunk_pfx = ''
+    chunk_size = X_spec.shape[2] // args.chunks
+    chunks_filelist = {'vocals': {}, 'inst': {}}
+
+    for chunk in range(0, args.chunks):
+        chunk_margin_r = 0
+
+        if chunk == 0:
+            chunk_offset_m, chunk_offset, chunk_margin = 0, 0, 0
+        else:
+            chunk_margin = chunk_size // 100 - 1
+            chunk_offset_m = chunk * chunk_size - chunk_margin - 1
+            chunk_offset = chunk * chunk_size - 1
+
+        if args.chunks > 1:
+            chunk_pfx = f'_chunk{chunk}'
+            print(f'Chunk {chunk}')
+
+        if chunk < args.chunks - 1:
+            chunk_margin_r = chunk_size // 100 - 1
+
+        pd = {
+            'aggr_value': args.aggressiveness,
+            'aggr_split_bin': mp.param['band'][1]['crop_stop'],
+            'aggr_correction': mp.param.get('aggr_correction'),
+            'is_vocal_model': args.is_vocal_model
+        }
+
+        if args.tta:
+            pred, X_mag, X_phase = vr.inference_tta(X_spec[:, :, chunk_offset_m:(chunk+1)*chunk_size+chunk_margin_r], pd)
+        else:
+            pred, X_mag, X_phase = vr.inference(X_spec[:, :, chunk_offset_m:(chunk+1)*chunk_size+chunk_margin_r], pd)
+
+        if args.postprocess:
+            print('post processing...', end=' ')
+            pred_inv = np.clip(X_mag - pred, 0, np.inf)
+            pred = spec_utils.mask_silence(pred, pred_inv)
+            print('done')
+
+        stems = {'inst': 'Instruments', 'vocals': 'Vocals'}
+        basename_enc = basename
+
+        print('inverse stft of {}...'.format(stems['inst']), end=' ')
+        y_spec_m = (pred * X_phase)[:, :, chunk_margin:pred.shape[2]-chunk_margin_r]
+
+        if args.chunks > 1:
+            import hashlib
+            basename_enc = hashlib.sha1(basename.encode('utf-8')).hexdigest()
+
+            if chunk > 0: # smoothing
+                y_spec_m[:, :, 0] = 0.5 * (y_spec_m[:, :, 0] + prev_chunk_edge)
+
+            prev_chunk_edge = y_spec_m[:, :, -1]
+
+        ffmpeg_tmp_fn = '{}_{}_inst'.format(basename_enc, time.time())
+
+        if args.high_end_process == 'bypass':
+            wave = spec_utils.cmb_spectrogram_to_wave_ffmpeg(y_spec_m, mp, ffmpeg_tmp_fn, input_high_end_h, input_high_end)
+        elif args.high_end_process.startswith('mirroring'):
+            input_high_end_ = spec_utils.mirroring(args.high_end_process, y_spec_m, input_high_end[:, :, chunk_offset:(chunk+1)*chunk_size], mp)
+            wave = spec_utils.cmb_spectrogram_to_wave_ffmpeg(y_spec_m, mp, ffmpeg_tmp_fn, input_high_end_h, input_high_end_)
+        else:
+            wave = spec_utils.cmb_spectrogram_to_wave_ffmpeg(y_spec_m, mp, ffmpeg_tmp_fn)
+        print('done')
+
+        model_name = ''
+        if args.model_test_mode:
+            model_name = '_' + os.path.splitext(os.path.basename(args.pretrained_model))[0]
+
+        if input_is_mono:
+            wave = wave.mean(axis=1, keepdims=True)
+
+        fn = os.path.join(separated_dir, '{}{}_{}{}.wav'.format(basename_enc, model_name, stems['inst'], chunk_pfx))
+        sf.write(fn, wave, mp.param['sr'])
+        chunks_filelist['inst'][chunk] = fn
+
+        if not args.no_vocals:
+            print('inverse stft of {}...'.format(stems['vocals']), end=' ')
+            ffmpeg_tmp_fn = '{}_{}_vocals'.format(basename_enc, time.time())
+            v_spec_m = X_spec[:, :, chunk_offset:(chunk+1)*chunk_size] - y_spec_m
+
+            if args.high_end_process.startswith('mirroring'):
+                input_high_end_ = spec_utils.mirroring(args.high_end_process, v_spec_m, input_high_end[:, :, chunk_offset:(chunk+1)*chunk_size], mp)
+                wave = spec_utils.cmb_spectrogram_to_wave_ffmpeg(v_spec_m, mp, ffmpeg_tmp_fn, input_high_end_h, input_high_end_)
+            else:
+                wave = spec_utils.cmb_spectrogram_to_wave_ffmpeg(v_spec_m, mp, ffmpeg_tmp_fn)
+            print('done')
+
+            if input_is_mono:
+                wave = wave.mean(axis=1, keepdims=True)
+
+            fn = os.path.join(separated_dir, '{}{}_{}{}.wav'.format(basename_enc, model_name, stems['vocals'], chunk_pfx))
+            sf.write(fn, wave, mp.param['sr'])
+            chunks_filelist['vocals'][chunk] = fn
+
+    for stem in stems:
+        if len(chunks_filelist[stem]) > 0 and args.chunks > 1:
+            import subprocess
+
+            fn = os.path.join(separated_dir, '{}{}_{}.wav'.format(basename_enc, model_name, stems[stem]))
+            fn2 = os.path.join(separated_dir, '{}{}_{}.wav'.format(basename, model_name, stems[stem]))
+            #os.system('sox "' + '" "'.join([f for f in chunks_filelist[stem].values()]) + f'" "{fn}"')
+            subprocess.run(['sox'] + [f for f in chunks_filelist[stem].values()] + [fn])
+
+            if not os.path.isfile(fn):
+                print('Error: failed to create output file. Make sure that you have installed sox.')
+
+            os.rename(fn, fn2)
+
+            for rf in chunks_filelist[stem].values():
+                os.remove(rf)
     if args.output_image:
         with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f:
             image = spec_utils.spectrogram_to_image(y_spec_m)
             _, bin_image = cv2.imencode('.jpg', image)
             bin_image.tofile(f)
-        with open('{}_{}.jpg'.format(basename, stems['vocals']), mode='wb') as f:
-            image = spec_utils.spectrogram_to_image(v_spec_m)
-            _, bin_image = cv2.imencode('.jpg', image)
-            bin_image.tofile(f)
+
+        if not args.no_vocals:
+            with open('{}_{}.jpg'.format(basename, stems['vocals']), mode='wb') as f:
+                image = spec_utils.spectrogram_to_image(v_spec_m)
+                _, bin_image = cv2.imencode('.jpg', image)
+                bin_image.tofile(f)
 
     if args.deepextraction:
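The new --chunks option splits the combined spectrogram along the time axis and runs inference per chunk, taking a small margin (about 1% of the chunk size) on interior edges so the first frame of a chunk can be averaged with the last frame of the previous one. A standalone sketch of the slice arithmetic used above:

    def chunk_slices(n_frames, chunks):
        # Reproduces the chunk_offset_m / chunk_margin_r arithmetic from main() (sketch only).
        chunk_size = n_frames // chunks
        slices = []
        for chunk in range(chunks):
            chunk_margin_r = chunk_size // 100 - 1 if chunk < chunks - 1 else 0
            if chunk == 0:
                chunk_offset_m = 0
            else:
                chunk_margin = chunk_size // 100 - 1
                chunk_offset_m = chunk * chunk_size - chunk_margin - 1
            # this is the time-axis slice handed to vr.inference() / vr.inference_tta()
            slices.append((chunk_offset_m, (chunk + 1) * chunk_size + chunk_margin_r))
        return slices

    # chunk_slices(2000, 2) -> [(0, 1009), (990, 2000)]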
@@ -239,17 +324,16 @@ def main():
             {
                 'algorithm':'deep',
                 'model_params':'modelparams/1band_sr44100_hl512.json',
-                'file1':"separated/{}_{}_{}.wav".format(basenameb, model_name, stems['vocals'], mp.param['sr']),
-                'file2':"separated/{}_{}_{}.wav".format(basenameb, model_name, stems['inst'], mp.param['sr']),
-                'output':'separated/{}_{}_{}_Deep_Extraction'.format(basenameb, model_name, stems['inst'], mp.param['sr'])
+                'file1':"{}/{}{}_{}.wav".format(separated_dir, basenameb, model_name, stems['vocals'], mp.param['sr']),
+                'file2':"{}/{}{}_{}.wav".format(separated_dir, basenameb, model_name, stems['inst'], mp.param['sr']),
+                'output':'{}/{}{}_{}_Deep_Extraction'.format(separated_dir, basenameb, model_name, stems['inst'], mp.param['sr'])
             }
         ]
 
         for i,e in tqdm(enumerate(deepext), desc="Performing Deep Extraction..."):
             os.system(f"python lib/spec_utils.py -a {e['algorithm']} -m {e['model_params']} {e['file1']} {e['file2']} -o {e['output']}")
 
-    dir = 'ensembled/temp'
-    for file in os.scandir(dir):
+    for file in os.scandir(ensembled_dir):
         os.remove(file.path)
 
     print('Complete!')
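Putting the new flags together, a typical run might look like the line below (file names are placeholders; sox must be installed and on PATH when --chunks is greater than 1, since the per-chunk WAVs are concatenated with it, and the stems land in the separated/ directory):

    python inference.py -i song.wav -P models/baseline.pth -n auto -c 4 -t --normalize -mt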