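Update inference.py: remove commented-out code, drop the redundant `if True:` wrapper around the vocals export, note that the 'correlation' high-end process needs improvement, and print the runtime as 'Runtime: N.Ns'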

This commit is contained in:
aufr33 2021-04-08 04:40:02 +03:00 committed by GitHub
parent 331e03436f
commit d21f92860d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -97,9 +97,6 @@ def main():
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
p.add_argument('--input', '-i', required=True)
#p.add_argument('--sr', '-r', type=int, default=44100)
#p.add_argument('--n_fft', type=int, default=2048) # combined
#p.add_argument('--hop_length', '-l', type=int, default=1024)
p.add_argument('--model_params', '-m', type=str, default='')
p.add_argument('--window_size', '-w', type=int, default=352)
p.add_argument('--output_image', '-I', action='store_true')
@ -126,8 +123,6 @@ def main():
model.to(device)
print('done')
#end = time.time()
#print(end - start)
print('loading & stft of wave source...', end=' ')
@ -183,7 +178,7 @@ def main():
if args.high_end_process == 'bypass':
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
elif args.high_end_process == 'correlation':
elif args.high_end_process == 'correlation': # It needs improvement
for i in range(input_high_end.shape[2]):
for c in range(2):
X_mag_max = np.amax(input_high_end[c, :, i])
@ -201,12 +196,11 @@ def main():
model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0]
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr'])
if True:
print('inverse stft of {}...'.format(stems['vocals']), end=' ')
#v_spec_m = X_spec_m - y_spec_m
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
print('done')
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
print('inverse stft of {}...'.format(stems['vocals']), end=' ')
#v_spec_m = X_spec_m - y_spec_m
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
print('done')
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
if args.output_image:
with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f:
@ -218,7 +212,7 @@ def main():
_, bin_image = cv2.imencode('.jpg', image)
bin_image.tofile(f)
print(time.time() - start)
print('Runtime: {0:.{1}f}s'.format(time.time() - start, 1))
if __name__ == '__main__':
main()