Update inference.py

This commit is contained in:
aufr33 2021-04-08 05:52:04 +03:00 committed by GitHub
parent d21f92860d
commit a718daba38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -98,7 +98,7 @@ def main():
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
p.add_argument('--input', '-i', required=True)
p.add_argument('--model_params', '-m', type=str, default='')
p.add_argument('--window_size', '-w', type=int, default=352)
p.add_argument('--window_size', '-w', type=int, default=320)
p.add_argument('--output_image', '-I', action='store_true')
p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--tta', '-t', action='store_true')
@ -106,10 +106,7 @@ def main():
p.add_argument('--aggressiveness', '-A', type=float, default=0.1)
args = p.parse_args()
if '' == args.model_params:
mp = ModelParameters(args.pretrained_model)
else:
mp = ModelParameters(args.model_params)
mp = ModelParameters(args.model_params)
start = time.time()
@ -133,7 +130,7 @@ def main():
for d in range(bands_n, 0, -1):
bp = mp.param['band'][d]
if d == bands_n: # high band
if d == bands_n: # high-end band
X_wave[d], _ = librosa.load(
args.input, bp['sr'], False, dtype=np.float32, res_type=bp['res_type'])
@ -168,9 +165,9 @@ def main():
print('done')
if 'is_vocal_model' in mp.param: # swap
stems = {'inst': 'Vocals', 'vocals': 'Instrumental'}
stems = {'inst': 'Vocals', 'vocals': 'Instruments'}
else:
stems = {'inst': 'Instrumental', 'vocals': 'Vocals'}
stems = {'inst': 'Instruments', 'vocals': 'Vocals'}
print('inverse stft of {}...'.format(stems['inst']), end=' ')
y_spec_m = pred * X_phase
@ -197,7 +194,6 @@ def main():
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr'])
print('inverse stft of {}...'.format(stems['vocals']), end=' ')
#v_spec_m = X_spec_m - y_spec_m
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
print('done')
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])