Update inference.py

This commit is contained in:
aufr33 2021-06-01 03:13:50 +03:00 committed by GitHub
parent 53b8aa3ff9
commit 721350b0df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
import argparse
import os
import importlib
import cv2
import librosa
@ -96,7 +97,7 @@ def main():
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
p.add_argument('--input', '-i', required=True)
p.add_argument('--nn_architecture', '-n', type=str, default='default')
p.add_argument('--nn_architecture', '-n', type=str, choices=['default', '33966KB', '123821KB', '129605KB'], default='default')
p.add_argument('--model_params', '-m', type=str, default='')
p.add_argument('--window_size', '-w', type=int, default=512)
p.add_argument('--output_image', '-I', action='store_true')
@ -108,14 +109,7 @@ def main():
p.add_argument('--aggressiveness', '-A', type=float, default=0.07)
args = p.parse_args()
if args.nn_architecture == 'default':
from lib import nets
if args.nn_architecture == '33966KB':
from lib import nets_33966KB as nets
if args.nn_architecture == '123821KB':
from lib import nets_123821KB as nets
if args.nn_architecture == '129605KB':
from lib import nets_129605KB as nets
nets = importlib.import_module('lib.nets' + f'_{args.nn_architecture}'.replace('_default', ''), package=None)
dir = 'ensembled/temp'
for file in os.scandir(dir):