From d21f92860d5c8aae081ffd114880b102d9d592d2 Mon Sep 17 00:00:00 2001 From: aufr33 <65520685+aufr33@users.noreply.github.com> Date: Thu, 8 Apr 2021 04:40:02 +0300 Subject: [PATCH] Update inference.py --- inference.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/inference.py b/inference.py index 10cd895..9c16245 100644 --- a/inference.py +++ b/inference.py @@ -97,9 +97,6 @@ def main(): p.add_argument('--gpu', '-g', type=int, default=-1) p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') p.add_argument('--input', '-i', required=True) - #p.add_argument('--sr', '-r', type=int, default=44100) - #p.add_argument('--n_fft', type=int, default=2048) # combined - #p.add_argument('--hop_length', '-l', type=int, default=1024) p.add_argument('--model_params', '-m', type=str, default='') p.add_argument('--window_size', '-w', type=int, default=352) p.add_argument('--output_image', '-I', action='store_true') @@ -126,8 +123,6 @@ def main(): model.to(device) print('done') - #end = time.time() - #print(end - start) print('loading & stft of wave source...', end=' ') @@ -183,7 +178,7 @@ def main(): if args.high_end_process == 'bypass': wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end) - elif args.high_end_process == 'correlation': + elif args.high_end_process == 'correlation': # It needs improvement for i in range(input_high_end.shape[2]): for c in range(2): X_mag_max = np.amax(input_high_end[c, :, i]) @@ -201,12 +196,11 @@ def main(): model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0] sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr']) - if True: - print('inverse stft of {}...'.format(stems['vocals']), end=' ') - #v_spec_m = X_spec_m - y_spec_m - wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) - print('done') - sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr']) + print('inverse stft of {}...'.format(stems['vocals']), end=' ') + #v_spec_m = X_spec_m - y_spec_m + wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) + print('done') + sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr']) if args.output_image: with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f: @@ -218,7 +212,7 @@ def main(): _, bin_image = cv2.imencode('.jpg', image) bin_image.tofile(f) - print(time.time() - start) + print('Runtime: {0:.{1}f}s'.format(time.time() - start, 1)) if __name__ == '__main__': main()