From 34de3492f369bbe9ec881bb6c1f01eb9ddabf844 Mon Sep 17 00:00:00 2001 From: Ftps <63702646+Tps-F@users.noreply.github.com> Date: Mon, 28 Aug 2023 01:14:36 +0900 Subject: [PATCH] rewrite infer_cli --- infer_cli.py | 272 --------------------------------------- tools/infer_batch_rvc.py | 14 +- tools/infer_cli.py | 65 ++++++++++ 3 files changed, 73 insertions(+), 278 deletions(-) delete mode 100644 infer_cli.py create mode 100644 tools/infer_cli.py diff --git a/infer_cli.py b/infer_cli.py deleted file mode 100644 index 59f246b..0000000 --- a/infer_cli.py +++ /dev/null @@ -1,272 +0,0 @@ -from scipy.io import wavfile -from fairseq import checkpoint_utils -from lib.audio import load_audio -from lib.infer_pack.models import ( - SynthesizerTrnMs256NSFsid, - SynthesizerTrnMs256NSFsid_nono, - SynthesizerTrnMs768NSFsid, - SynthesizerTrnMs768NSFsid_nono, -) -from lib.train.vc_infer_pipeline import VC -from multiprocessing import cpu_count -import numpy as np -import torch -import sys -import glob -import argparse -import os -import sys -import pdb -import torch - -now_dir = os.getcwd() -sys.path.append(now_dir) - -#### -# USAGE -# -# In your Terminal or CMD or whatever -# python infer_cli.py [TRANSPOSE_VALUE] "[INPUT_PATH]" "[OUTPUT_PATH]" "[MODEL_PATH]" "[INDEX_FILE_PATH]" "[INFERENCE_DEVICE]" "[METHOD]" - -using_cli = False -device = "cuda:0" -is_half = False - -if len(sys.argv) > 0: - f0_up_key = int(sys.argv[1]) # transpose value - input_path = sys.argv[2] - output_path = sys.argv[3] - model_path = sys.argv[4] - file_index = sys.argv[5] # .index file - device = sys.argv[6] - f0_method = sys.argv[7] # pm or harvest or crepe - - using_cli = True - - # file_index2=sys.argv[8] - # index_rate=float(sys.argv[10]) #search feature ratio - # filter_radius=float(sys.argv[11]) #median filter - # resample_sr=float(sys.argv[12]) #resample audio in post processing - # rms_mix_rate=float(sys.argv[13]) #search feature - print(sys.argv) - - -class Config: - def __init__(self, device, is_half): - self.device = device - self.is_half = is_half - self.n_cpu = 0 - self.gpu_name = None - self.gpu_mem = None - self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() - - def device_config(self) -> tuple: - if torch.cuda.is_available() and device != "cpu": - i_device = int(self.device.split(":")[-1]) - self.gpu_name = torch.cuda.get_device_name(i_device) - if ( - ("16" in self.gpu_name and "V100" not in self.gpu_name.upper()) - or "P40" in self.gpu_name.upper() - or "1060" in self.gpu_name - or "1070" in self.gpu_name - or "1080" in self.gpu_name - ): - print("16系/10系显卡和P40强制单精度") - self.is_half = False - for config_file in ["32k.json", "40k.json", "48k.json"]: - with open(f"configs/{config_file}", "r") as f: - strr = f.read().replace("true", "false") - with open(f"configs/{config_file}", "w") as f: - f.write(strr) - with open("trainset_preprocess_pipeline_print.py", "r") as f: - strr = f.read().replace("3.7", "3.0") - with open("trainset_preprocess_pipeline_print.py", "w") as f: - f.write(strr) - else: - self.gpu_name = None - self.gpu_mem = int( - torch.cuda.get_device_properties(i_device).total_memory - / 1024 - / 1024 - / 1024 - + 0.4 - ) - if self.gpu_mem <= 4: - with open("trainset_preprocess_pipeline_print.py", "r") as f: - strr = f.read().replace("3.7", "3.0") - with open("trainset_preprocess_pipeline_print.py", "w") as f: - f.write(strr) - elif torch.backends.mps.is_available(): - print("没有发现支持的N卡, 使用MPS进行推理") - self.device = "mps" - else: - print("没有发现支持的N卡, 使用CPU进行推理") - self.device = "cpu" - self.is_half = False - - if self.n_cpu == 0: - self.n_cpu = cpu_count() - - if self.is_half: - # 6G显存配置 - x_pad = 3 - x_query = 10 - x_center = 60 - x_max = 65 - else: - # 5G显存配置 - x_pad = 1 - x_query = 6 - x_center = 38 - x_max = 41 - - if self.gpu_mem != None and self.gpu_mem <= 4: - x_pad = 1 - x_query = 5 - x_center = 30 - x_max = 32 - - return x_pad, x_query, x_center, x_max - - -config = Config(device, is_half) -now_dir = os.getcwd() -sys.path.append(now_dir) - -hubert_model = None - - -def load_hubert(): - global hubert_model - models, _, _ = checkpoint_utils.load_model_ensemble_and_task( - ["hubert_base.pt"], - suffix="", - ) - hubert_model = models[0] - hubert_model = hubert_model.to(config.device) - if config.is_half: - hubert_model = hubert_model.half() - else: - hubert_model = hubert_model.float() - hubert_model.eval() - - -def vc_single( - sid=0, - input_audio_path=None, - f0_up_key=0, - f0_file=None, - f0_method="pm", - file_index="", # .index file - file_index2="", - # file_big_npy, - index_rate=1.0, - filter_radius=3, - resample_sr=0, - rms_mix_rate=1.0, - model_path="", - output_path="", - protect=0.33, -): - global tgt_sr, net_g, vc, hubert_model, version - get_vc(model_path) - if input_audio_path is None: - return "You need to upload an audio file", None - - f0_up_key = int(f0_up_key) - audio = load_audio(input_audio_path, 16000) - audio_max = np.abs(audio).max() / 0.95 - - if audio_max > 1: - audio /= audio_max - times = [0, 0, 0] - - if hubert_model == None: - load_hubert() - - if_f0 = cpt.get("f0", 1) - - file_index = ( - ( - file_index.strip(" ") - .strip('"') - .strip("\n") - .strip('"') - .strip(" ") - .replace("trained", "added") - ) - if file_index != "" - else file_index2 - ) - - audio_opt = vc.pipeline( - hubert_model, - net_g, - sid, - audio, - input_audio_path, - times, - f0_up_key, - f0_method, - file_index, - # file_big_npy, - index_rate, - if_f0, - filter_radius, - tgt_sr, - resample_sr, - rms_mix_rate, - version, - f0_file=f0_file, - protect=protect, - ) - wavfile.write(output_path, tgt_sr, audio_opt) - return "processed" - - -def get_vc(model_path): - global n_spk, tgt_sr, net_g, vc, cpt, device, is_half, version - print("loading pth %s" % model_path) - cpt = torch.load(model_path, map_location="cpu") - tgt_sr = cpt["config"][-1] - cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk - if_f0 = cpt.get("f0", 1) - version = cpt.get("version", "v1") - if version == "v1": - if if_f0 == 1: - net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half) - else: - net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) - elif version == "v2": - if if_f0 == 1: - net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half) - else: - net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) - del net_g.enc_q - print(net_g.load_state_dict(cpt["weight"], strict=False)) - net_g.eval().to(device) - if is_half: - net_g = net_g.half() - else: - net_g = net_g.float() - vc = VC(tgt_sr, config) - n_spk = cpt["config"][-3] - # return {"visible": True,"maximum": n_spk, "__type__": "update"} - - -if using_cli: - vc_single( - sid=0, - input_audio_path=input_path, - f0_up_key=f0_up_key, - f0_file=None, - f0_method=f0_method, - file_index=file_index, - file_index2="", - index_rate=1, - filter_radius=3, - resample_sr=0, - rms_mix_rate=0, - model_path=model_path, - output_path=output_path, - ) diff --git a/tools/infer_batch_rvc.py b/tools/infer_batch_rvc.py index b8d73e7..763d17f 100644 --- a/tools/infer_batch_rvc.py +++ b/tools/infer_batch_rvc.py @@ -16,7 +16,7 @@ from configs.config import Config from infer.modules.vc.modules import VC -def args() -> tuple: +def arg_parse() -> tuple: parser = argparse.ArgumentParser() parser.add_argument("--f0up_key", type=int, default=0) parser.add_argument("--input_path", type=str, help="input path") @@ -39,6 +39,12 @@ def args() -> tuple: def main(): + load_dotenv() + args = arg_parse() + config = Config() + config.device = args.device if args.device else config.device + config.is_half = args.is_half if args.is_half else config.is_half + vc = VC(config) vc.get_vc(args.model_name) audios = os.listdir(args.input_path) for file in tq.tqdm(audios): @@ -59,12 +65,8 @@ def main(): args.protect, ) out_path = os.path.join(args.opt_path, file) - wavfile.write(out_path, vc.tgt_sr, wav_opt[1]) + wavfile.write(out_path, wav_opt[0], wav_opt[1]) if __name__ == "__main__": - load_dotenv() - args = args() - config = Config() - vc = VC(config) main() diff --git a/tools/infer_cli.py b/tools/infer_cli.py new file mode 100644 index 0000000..4a7dca5 --- /dev/null +++ b/tools/infer_cli.py @@ -0,0 +1,65 @@ +import argparse +import os +import sys +now_dir = os.getcwd() +sys.path.append(now_dir) +from scipy.io import wavfile + +from configs.config import Config +from infer.modules.vc.modules import VC +from dotenv import load_dotenv +#### +# USAGE +# +# In your Terminal or CMD or whatever + + +def arg_parse() -> tuple: + parser = argparse.ArgumentParser() + parser.add_argument("--f0up_key", type=int, default=0) + parser.add_argument("--input_path", type=str, help="input path") + parser.add_argument("--index_path", type=str, help="index path") + parser.add_argument("--f0method", type=str, default="harvest", help="harvest or pm") + parser.add_argument("--opt_path", type=str, help="opt path") + parser.add_argument("--model_name", type=str, help="store in assets/weight_root") + parser.add_argument("--index_rate", type=float, default=0.66, help="index rate") + parser.add_argument("--device", type=str, help="device") + parser.add_argument("--is_half", type=bool, help="use half -> True") + parser.add_argument("--filter_radius", type=int, default=3, help="filter radius") + parser.add_argument("--resample_sr", type=int, default=0, help="resample sr") + parser.add_argument("--rms_mix_rate", type=float, default=1, help="rms mix rate") + parser.add_argument("--protect", type=float, default=0.33, help="protect") + + args = parser.parse_args() + sys.argv = sys.argv[:1] + + return args + + +def main(): + load_dotenv() + args = arg_parse() + config = Config() + config.device = args.device if args.device else config.device + config.is_half = args.is_half if args.is_half else config.is_half + vc = VC(config) + vc.get_vc(args.model_name) + _, wav_opt = vc.vc_single( + 0, + args.input_path, + args.f0up_key, + None, + args.f0method, + args.index_path, + None, + args.index_rate, + args.filter_radius, + args.resample_sr, + args.rms_mix_rate, + args.protect, + ) + wavfile.write(args.opt_path, wav_opt[0], wav_opt[1]) + + +if __name__ == "__main__": + main()