Update inference.py

This commit is contained in:
aufr33 2021-04-30 08:43:44 +03:00 committed by GitHub
parent 47ef8cf83f
commit fb26b9c108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,6 @@ import time
from tqdm import tqdm
from lib import dataset
from lib import nets
from lib import spec_utils
from lib.model_param_init import ModelParameters
@ -97,18 +96,31 @@ def main():
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
p.add_argument('--input', '-i', required=True)
p.add_argument('--nn_architecture', '-n', type=str, default='default')
p.add_argument('--model_params', '-m', type=str, default='')
p.add_argument('--window_size', '-w', type=int, default=320)
p.add_argument('--window_size', '-w', type=int, default=512)
p.add_argument('--output_image', '-I', action='store_true')
p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--tta', '-t', action='store_true')
p.add_argument('--high_end_process', '-H', type=str, choices=['none', 'bypass', 'correlation'], default='none')
p.add_argument('--aggressiveness', '-A', type=float, default=0.1)
p.add_argument('--aggressiveness', '-A', type=float, default=0.09)
args = p.parse_args()
if args.nn_architecture == 'default':
from lib import nets
if args.nn_architecture == '33966KB':
from lib import nets_33966KB as nets
if args.nn_architecture == '123821KB':
from lib import nets_123821KB as nets
if args.nn_architecture == '129605KB':
from lib import nets_129605KB as nets
#if '' == args.model_params:
# mp = ModelParameters(args.pretrained_model)
#else:
mp = ModelParameters(args.model_params)
start = time.time()
start_time = time.time()
print('loading model...', end=' ')
@ -139,7 +151,7 @@ def main():
else: # lower bands
X_wave[d] = librosa.resample(X_wave[d+1], mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type'])
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'])
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], mp.param['mid_side'], mp.param['reverse'])
if d == bands_n and args.high_end_process in ['bypass', 'correlation']:
input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
@ -175,7 +187,7 @@ def main():
if args.high_end_process == 'bypass':
wave = spec_utils.cmb_spectrogram_to_wave(y_spec_m, mp, input_high_end_h, input_high_end)
elif args.high_end_process == 'correlation': # It needs improvement
elif args.high_end_process == 'correlation':
for i in range(input_high_end.shape[2]):
for c in range(2):
X_mag_max = np.amax(input_high_end[c, :, i])
@ -193,7 +205,9 @@ def main():
model_name = os.path.splitext(os.path.basename(args.pretrained_model))[0]
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['inst'])), wave, mp.param['sr'])
if True:
print('inverse stft of {}...'.format(stems['vocals']), end=' ')
#v_spec_m = X_spec_m - y_spec_m
wave = spec_utils.cmb_spectrogram_to_wave(v_spec_m, mp)
print('done')
sf.write(os.path.join('separated', '{}_{}_{}.wav'.format(basename, model_name, stems['vocals'])), wave, mp.param['sr'])
@ -208,7 +222,7 @@ def main():
_, bin_image = cv2.imencode('.jpg', image)
bin_image.tofile(f)
print('Runtime: {0:.{1}f}s'.format(time.time() - start, 1))
print('Total time: {0:.{1}f}s'.format(time.time() - start_time, 1))
if __name__ == '__main__':
main()