From a718daba382c68a551ae2dd394db953227e7113c Mon Sep 17 00:00:00 2001 From: aufr33 <65520685+aufr33@users.noreply.github.com> Date: Thu, 8 Apr 2021 05:52:04 +0300 Subject: [PATCH] Update inference.py --- inference.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/inference.py b/inference.py index 9c16245..e520f72 100644 --- a/inference.py +++ b/inference.py @@ -98,7 +98,7 @@ def main(): p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') p.add_argument('--input', '-i', required=True) p.add_argument('--model_params', '-m', type=str, default='') - p.add_argument('--window_size', '-w', type=int, default=352) + p.add_argument('--window_size', '-w', type=int, default=320) p.add_argument('--output_image', '-I', action='store_true') p.add_argument('--postprocess', '-p', action='store_true') p.add_argument('--tta', '-t', action='store_true') @@ -106,10 +106,7 @@ def main(): p.add_argument('--aggressiveness', '-A', type=float, default=0.1) args = p.parse_args() - if '' == args.model_params: - mp = ModelParameters(args.pretrained_model) - else: - mp = ModelParameters(args.model_params) + mp = ModelParameters(args.model_params) start = time.time() @@ -133,7 +130,7 @@ def main(): for d in range(bands_n, 0, -1): bp = mp.param['band'][d] - if d == bands_n: # high band + if d == bands_n: # high-end band X_wave[d], _ = librosa.load( args.input, bp['sr'], False, dtype=np.float32, res_type=bp['res_type']) @@ -168,9 +165,9 @@ def main(): print('done') if 'is_vocal_model' in mp.param: # swap - stems = {'inst': 'Vocals', 'vocals': 'Instrumental'} + stems = {'inst': 'Vocals', 'vocals': 'Instruments'} else: - stems = {'inst': 'Instrumental', 'vocals': 'Vocals'} + stems = {'inst': 'Instruments', 'vocals': 'Vocals'} print('inverse stft of {}...'.format(stems['inst']), end=' ') y_spec_m = pred * X_phase @@ -197,7 +194,6 @@ def main(): sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr']) print('inverse stft of {}...'.format(stems['vocals']), end=' ') - #v_spec_m = X_spec_m - y_spec_m wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) print('done') sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])