diff --git a/inference.py b/inference.py index e520f72..62211b2 100644 --- a/inference.py +++ b/inference.py @@ -10,7 +10,6 @@ import time from tqdm import tqdm from lib import dataset -from lib import nets from lib import spec_utils from lib.model_param_init import ModelParameters @@ -97,18 +96,31 @@ def main(): p.add_argument('--gpu', '-g', type=int, default=-1) p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') p.add_argument('--input', '-i', required=True) + p.add_argument('--nn_architecture', '-n', type=str, default='default') p.add_argument('--model_params', '-m', type=str, default='') - p.add_argument('--window_size', '-w', type=int, default=320) + p.add_argument('--window_size', '-w', type=int, default=512) p.add_argument('--output_image', '-I', action='store_true') p.add_argument('--postprocess', '-p', action='store_true') p.add_argument('--tta', '-t', action='store_true') p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation'], default='none') - p.add_argument('--aggressiveness', '-A', type=float, default=0.1) + p.add_argument('--aggressiveness', '-A', type=float, default=0.09) args = p.parse_args() + if args.nn_architecture == 'default': + from lib import nets + if args.nn_architecture == '33966KB': + from lib import nets_33966KB as nets + if args.nn_architecture == '123821KB': + from lib import nets_123821KB as nets + if args.nn_architecture == '129605KB': + from lib import nets_129605KB as nets + + #if '' == args.model_params: + # mp = ModelParameters(args.pretrained_model) + #else: mp = ModelParameters(args.model_params) - start = time.time() + start_time = time.time() print('loading model...', end=' ') @@ -139,7 +151,7 @@ def main(): else: # lower bands X_wave[d] = librosa.resample(X_wave[d+1], mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type']) - X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side']) + X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'], mp.param['reverse']) if d == bands_n and args.high_end_process in ['bypass', 'correlation']: input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start']) @@ -175,7 +187,7 @@ def main(): if args.high_end_process == 'bypass': wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end) - elif args.high_end_process == 'correlation': # It needs improvement + elif args.high_end_process == 'correlation': for i in range(input_high_end.shape[2]): for c in range(2): X_mag_max = np.amax(input_high_end[c, :, i]) @@ -193,10 +205,12 @@ def main(): model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0] sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr']) - print('inverse stft of {}...'.format(stems['vocals']), end=' ') - wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) - print('done') - sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr']) + if True: + print('inverse stft of {}...'.format(stems['vocals']), end=' ') + #v_spec_m = X_spec_m - y_spec_m + wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) + print('done') + sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr']) if args.output_image: with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f: @@ -208,7 +222,7 @@ def main(): _, bin_image = cv2.imencode('.jpg', image) bin_image.tofile(f) - print('Runtime: {0:.{1}f}s'.format(time.time() - start, 1)) + print('Total time: {0:.{1}f}s'.format(time.time() - start_time, 1)) if __name__ == '__main__': main()