Update inference.py

This commit is contained in:
aufr33 2021-04-08 04:40:02 +03:00 committed by GitHub
parent 331e03436f
commit d21f92860d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -97,9 +97,6 @@ def main():
p.add_argument('--gpu', '-g', type=int, default=-1) p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
p.add_argument('--input', '-i', required=True) p.add_argument('--input', '-i', required=True)
#p.add_argument('--sr', '-r', type=int, default=44100)
#p.add_argument('--n_fft', type=int, default=2048) # combined
#p.add_argument('--hop_length', '-l', type=int, default=1024)
p.add_argument('--model_params', '-m', type=str, default='') p.add_argument('--model_params', '-m', type=str, default='')
p.add_argument('--window_size', '-w', type=int, default=352) p.add_argument('--window_size', '-w', type=int, default=352)
p.add_argument('--output_image', '-I', action='store_true') p.add_argument('--output_image', '-I', action='store_true')
@ -126,8 +123,6 @@ def main():
model.to(device) model.to(device)
print('done') print('done')
#end = time.time()
#print(end - start)
print('loading & stft of wave source...', end=' ') print('loading & stft of wave source...', end=' ')
@ -183,7 +178,7 @@ def main():
if args.high_end_process == 'bypass': if args.high_end_process == 'bypass':
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end) wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
elif args.high_end_process == 'correlation': elif args.high_end_process == 'correlation': # It needs improvement
for i in range(input_high_end.shape[2]): for i in range(input_high_end.shape[2]):
for c in range(2): for c in range(2):
X_mag_max = np.amax(input_high_end[c, :, i]) X_mag_max = np.amax(input_high_end[c, :, i])
@ -201,12 +196,11 @@ def main():
model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0] model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0]
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr']) sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr'])
if True: print('inverse stft of {}...'.format(stems['vocals']), end=' ')
print('inverse stft of {}...'.format(stems['vocals']), end=' ') #v_spec_m = X_spec_m - y_spec_m
#v_spec_m = X_spec_m - y_spec_m wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp) print('done')
print('done') sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
if args.output_image: if args.output_image:
with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f: with open('{}_{}.jpg'.format(basename, stems['inst']), mode='wb') as f:
@ -218,7 +212,7 @@ def main():
_, bin_image = cv2.imencode('.jpg', image) _, bin_image = cv2.imencode('.jpg', image)
bin_image.tofile(f) bin_image.tofile(f)
print(time.time() - start) print('Runtime: {0:.{1}f}s'.format(time.time() - start, 1))
if __name__ == '__main__': if __name__ == '__main__':
main() main()