add is_half argument to extract_feature_print.py (#1683)
* for GTX1650 * testing torch-profiler * no more profiler & change default audio * longer slice * fluid container * cache rmvpe and ui tweaks * get my changes back after merge * format code * only load rmvpe when necessary * fix rmvpe & config bug * fix is_half again * manual sync with upstream * revert other changes for pull request
This commit is contained in:
parent
f3281d1e37
commit
c2ab680e18
2
.github/workflows/unitest.yml
vendored
2
.github/workflows/unitest.yml
vendored
@ -33,4 +33,4 @@ jobs:
|
|||||||
python infer/modules/train/preprocess.py logs/mute/0_gt_wavs 48000 8 logs/mi-test True 3.7
|
python infer/modules/train/preprocess.py logs/mute/0_gt_wavs 48000 8 logs/mi-test True 3.7
|
||||||
touch logs/mi-test/extract_f0_feature.log
|
touch logs/mi-test/extract_f0_feature.log
|
||||||
python infer/modules/train/extract/extract_f0_print.py logs/mi-test $(nproc) pm
|
python infer/modules/train/extract/extract_f0_print.py logs/mi-test $(nproc) pm
|
||||||
python infer/modules/train/extract_feature_print.py cpu 1 0 0 logs/mi-test v1
|
python infer/modules/train/extract_feature_print.py cpu 1 0 0 logs/mi-test v1 True
|
||||||
|
@ -290,7 +290,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"!python3 extract_f0_print.py logs/{MODELNAME} {THREADCOUNT} {ALGO}\n",
|
"!python3 extract_f0_print.py logs/{MODELNAME} {THREADCOUNT} {ALGO}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"!python3 extract_feature_print.py cpu 1 0 0 logs/{MODELNAME}"
|
"!python3 extract_feature_print.py cpu 1 0 0 logs/{MODELNAME} True"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -309,7 +309,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"!python3 extract_f0_print.py logs/{MODELNAME} {THREADCOUNT} {ALGO}\n",
|
"!python3 extract_f0_print.py logs/{MODELNAME} {THREADCOUNT} {ALGO}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"!python3 extract_feature_print.py cpu 1 0 0 logs/{MODELNAME}"
|
"!python3 extract_feature_print.py cpu 1 0 0 logs/{MODELNAME} True"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -344,7 +344,7 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19, gpus_rmvp
|
|||||||
ps = []
|
ps = []
|
||||||
for idx, n_g in enumerate(gpus):
|
for idx, n_g in enumerate(gpus):
|
||||||
cmd = (
|
cmd = (
|
||||||
'"%s" infer/modules/train/extract_feature_print.py %s %s %s %s "%s/logs/%s" %s'
|
'"%s" infer/modules/train/extract_feature_print.py %s %s %s %s "%s/logs/%s" %s %s'
|
||||||
% (
|
% (
|
||||||
config.python_cmd,
|
config.python_cmd,
|
||||||
config.device,
|
config.device,
|
||||||
@ -354,6 +354,7 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19, gpus_rmvp
|
|||||||
now_dir,
|
now_dir,
|
||||||
exp_dir,
|
exp_dir,
|
||||||
version19,
|
version19,
|
||||||
|
config.is_half,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(cmd)
|
logger.info(cmd)
|
||||||
|
@ -8,14 +8,16 @@ os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
|
|||||||
device = sys.argv[1]
|
device = sys.argv[1]
|
||||||
n_part = int(sys.argv[2])
|
n_part = int(sys.argv[2])
|
||||||
i_part = int(sys.argv[3])
|
i_part = int(sys.argv[3])
|
||||||
if len(sys.argv) == 6:
|
if len(sys.argv) == 7:
|
||||||
exp_dir = sys.argv[4]
|
exp_dir = sys.argv[4]
|
||||||
version = sys.argv[5]
|
version = sys.argv[5]
|
||||||
|
is_half = bool(sys.argv[6])
|
||||||
else:
|
else:
|
||||||
i_gpu = sys.argv[4]
|
i_gpu = sys.argv[4]
|
||||||
exp_dir = sys.argv[5]
|
exp_dir = sys.argv[5]
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
|
||||||
version = sys.argv[6]
|
version = sys.argv[6]
|
||||||
|
is_half = bool(sys.argv[7])
|
||||||
import fairseq
|
import fairseq
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -91,6 +93,7 @@ models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
|||||||
model = models[0]
|
model = models[0]
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
printt("move model to %s" % device)
|
printt("move model to %s" % device)
|
||||||
|
if is_half:
|
||||||
if device not in ["mps", "cpu"]:
|
if device not in ["mps", "cpu"]:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -113,9 +116,7 @@ else:
|
|||||||
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
|
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
|
||||||
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
||||||
inputs = {
|
inputs = {
|
||||||
"source": feats.half().to(device)
|
"source": feats.half().to(device) if is_half and device not in ["mps", "cpu"] else feats.to(device),
|
||||||
if device not in ["mps", "cpu"]
|
|
||||||
else feats.to(device),
|
|
||||||
"padding_mask": padding_mask.to(device),
|
"padding_mask": padding_mask.to(device),
|
||||||
"output_layer": 9 if version == "v1" else 12, # layer 9
|
"output_layer": 9 if version == "v1" else 12, # layer 9
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user