Update inference.py

This commit is contained in:
aufr33 2021-05-09 11:23:25 +03:00 committed by GitHub
parent 50384e1179
commit 8244ad1617
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -103,7 +103,7 @@ def main():
p.add_argument('--deepextraction', '-D', action='store_true') p.add_argument('--deepextraction', '-D', action='store_true')
p.add_argument('--postprocess', '-p', action='store_true') p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--tta', '-t', action='store_true') p.add_argument('--tta', '-t', action='store_true')
p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation'], default='none') p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation', 'mirroring'], default='none')
p.add_argument('--aggressiveness', '-A', type=float, default=0.07) p.add_argument('--aggressiveness', '-A', type=float, default=0.07)
args = p.parse_args() args = p.parse_args()
@ -158,7 +158,7 @@ def main():
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'], mp.param['reverse']) X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'], mp.param['reverse'])
if d == bands_n and args.high_end_process in ['bypass', 'correlation']: if d == bands_n and args.high_end_process in ['bypass', 'correlation', 'mirroring']:
input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start']) input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :] input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
@ -202,6 +202,11 @@ def main():
y_mag = np.median(y_spec_m[c, b1:b2, i]) y_mag = np.median(y_spec_m[c, b1:b2, i])
input_high_end[c, :, i] = np.true_divide(input_high_end[c, :, i], abs(X_mag_max) / min(abs(y_mag * 4), abs(X_mag_max))) input_high_end[c, :, i] = np.true_divide(input_high_end[c, :, i], abs(X_mag_max) / min(abs(y_mag * 4), abs(X_mag_max)))
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
elif args.high_end_process == 'mirroring':
mirror = np.multiply(np.flip(y_spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :], 1), 1.7)
input_high_end = np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end) wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
else: else:
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp) wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp)