Update inference.py

This commit is contained in:
aufr33 2021-05-27 05:43:02 +03:00 committed by GitHub
parent 9822fba54b
commit 1f00ddc7a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -102,8 +102,9 @@ def main():
p.add_argument('--output_image', '-I', action='store_true')
p.add_argument('--deepextraction', '-D', action='store_true')
p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--is_vocal_model', '-vm', action='store_true')
p.add_argument('--tta', '-t', action='store_true')
p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation', 'mirroring'], default='none')
p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation', 'mirroring', 'mirroring2'], default='none')
p.add_argument('--aggressiveness', '-A', type=float, default=0.07)
args = p.parse_args()
@ -158,7 +159,7 @@ def main():
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'], mp.param['reverse'])
if d == bands_n and args.high_end_process in ['bypass', 'correlation', 'mirroring']:
if d == bands_n and args.high_end_process != 'none':
input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
@ -181,7 +182,7 @@ def main():
pred = spec_utils.mask_silence(pred, pred_inv)
print('done')
if 'is_vocal_model' in mp.param: # swap
if 'is_vocal_model' in mp.param or args.is_vocal_model: # swap
stems = {'inst': 'Vocals', 'vocals': 'Instruments'}
else:
stems = {'inst': 'Instruments', 'vocals': 'Vocals'}
@ -203,11 +204,10 @@ def main():
input_high_end[c, :, i] = np.true_divide(input_high_end[c, :, i], abs(X_mag_max) / min(abs(y_mag * 4), abs(X_mag_max)))
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
elif args.high_end_process == 'mirroring':
mirror = np.multiply(np.flip(y_spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :], 1), 1.7)
input_high_end_ = np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end_)
elif args.high_end_process.startswith('mirroring'):
input_high_end_ = spec_utils.mirroring(args.high_end_process, y_spec_m, input_high_end, mp)
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end_)
else:
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp)
@ -253,10 +253,10 @@ def main():
if True:
print('inverse stft of {}...'.format(stems['vocals']), end=' ')
if args.high_end_process == 'mirroring':
mirror = np.multiply(np.flip(v_spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :], 1), 1.7)
input_high_end_ = np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp, input_high_end_h, input_high_end_)
if args.high_end_process.startswith('mirroring'):
input_high_end_ = spec_utils.mirroring(args.high_end_process, v_spec_m, input_high_end, mp)
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp, input_high_end_h, input_high_end_)
else:
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
print('done')