From 0fc160c03e0864b4154cfa3994f09b2c495f6cbd Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Sat, 2 Sep 2023 15:45:50 +0800 Subject: [PATCH] optimize the streaming f0 exatrators (#1168) --- gui_v1.py | 40 +++++++------ infer/lib/rmvpe.py | 2 +- tools/rvc_for_realtime.py | 120 ++++++++++++++++++++++---------------- 3 files changed, 91 insertions(+), 71 deletions(-) diff --git a/gui_v1.py b/gui_v1.py index f178633..93a60e4 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -261,9 +261,9 @@ if __name__ == "__main__": [ sg.Text(i18n("采样长度")), sg.Slider( - range=(0.09, 2.4), + range=(0.05, 2.4), key="block_time", - resolution=0.03, + resolution=0.01, orientation="h", default_value=data.get("block_time", ""), enable_events=True, @@ -455,18 +455,20 @@ if __name__ == "__main__": inp_q, opt_q, device, + self.rvc if hasattr(self, "rvc") else None ) self.config.samplerate = self.rvc.tgt_sr self.config.crossfade_time = min( self.config.crossfade_time, self.config.block_time ) - self.block_frame = int(self.config.block_time * self.config.samplerate) + self.zc = self.rvc.tgt_sr // 100 + self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc + self.block_frame_16k = 160 * self.block_frame // self.zc self.crossfade_frame = int( self.config.crossfade_time * self.config.samplerate ) self.sola_search_frame = int(0.01 * self.config.samplerate) self.extra_frame = int(self.config.extra_time * self.config.samplerate) - self.zc = self.rvc.tgt_sr // 100 self.input_wav: np.ndarray = np.zeros( int( np.ceil( @@ -482,6 +484,7 @@ if __name__ == "__main__": ), dtype="float32", ) + self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc) self.output_wav_cache: torch.Tensor = torch.zeros( int( np.ceil( @@ -573,18 +576,14 @@ if __name__ == "__main__": for i in range(db_threhold.shape[0]): if db_threhold[i]: indata[i * hop_length : (i + 1) * hop_length] = 0 - self.input_wav[:] = np.append(self.input_wav[self.block_frame :], indata) + self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :] + self.input_wav[-self.block_frame: ] = indata + # infer - inp = torch.from_numpy(self.input_wav).to(device) - res1 = self.resampler(inp) - ###55% - rate1 = self.block_frame / ( - self.extra_frame - + self.crossfade_frame - + self.sola_search_frame - + self.block_frame - ) - rate2 = ( + inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device) + self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone() + self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :] + rate = ( self.crossfade_frame + self.sola_search_frame + self.block_frame ) / ( self.extra_frame @@ -592,11 +591,14 @@ if __name__ == "__main__": + self.sola_search_frame + self.block_frame ) + f0_extractor_frame = self.block_frame_16k + 800 + if self.config.f0method == 'rmvpe': + f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) res2 = self.rvc.infer( - res1, - res1[-self.block_frame :].cpu().numpy(), - rate1, - rate2, + self.input_wav_res, + self.input_wav_res[-f0_extractor_frame :].cpu().numpy(), + self.block_frame_16k, + rate, self.pitch, self.pitchf, self.config.f0method, diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py index b498857..57e4de5 100644 --- a/infer/lib/rmvpe.py +++ b/infer/lib/rmvpe.py @@ -601,7 +601,7 @@ class RMVPE: with torch.no_grad(): n_frames = mel.shape[-1] mel = F.pad( - mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect" + mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="constant" ) if "privateuseone" in str(self.device): onnx_input_name = self.model.get_inputs()[0].name diff --git a/tools/rvc_for_realtime.py b/tools/rvc_for_realtime.py index 4d2802f..5a7a44f 100644 --- a/tools/rvc_for_realtime.py +++ b/tools/rvc_for_realtime.py @@ -2,7 +2,6 @@ import os import sys import traceback import logging - logger = logging.getLogger(__name__) from time import time as ttime @@ -48,7 +47,7 @@ if config.dml == True: # config.is_half=False########强制cpu测试 class RVC: def __init__( - self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device + self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device, last_rvc=None, ) -> None: """ 初始化 @@ -72,48 +71,64 @@ class RVC: self.index = faiss.read_index(index_path) self.big_npy = self.index.reconstruct_n(0, self.index.ntotal) logger.info("Index search enabled") + self.pth_path = pth_path self.index_path = index_path self.index_rate = index_rate - models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( - ["assets/hubert/hubert_base.pt"], - suffix="", - ) - hubert_model = models[0] - hubert_model = hubert_model.to(config.device) - if config.is_half: - hubert_model = hubert_model.half() - else: - hubert_model = hubert_model.float() - hubert_model.eval() - self.model = hubert_model - cpt = torch.load(pth_path, map_location="cpu") - self.tgt_sr = cpt["config"][-1] - cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] - self.if_f0 = cpt.get("f0", 1) - self.version = cpt.get("version", "v1") - if self.version == "v1": - if self.if_f0 == 1: - self.net_g = SynthesizerTrnMs256NSFsid( - *cpt["config"], is_half=config.is_half - ) + + if last_rvc is None: + models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + ["assets/hubert/hubert_base.pt"], + suffix="", + ) + hubert_model = models[0] + hubert_model = hubert_model.to(config.device) + if config.is_half: + hubert_model = hubert_model.half() else: - self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) - elif self.version == "v2": - if self.if_f0 == 1: - self.net_g = SynthesizerTrnMs768NSFsid( - *cpt["config"], is_half=config.is_half - ) - else: - self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) - del self.net_g.enc_q - logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False)) - self.net_g.eval().to(device) - # print(2333333333,device,config.device,self.device)#net_g是device,hubert是config.device - if config.is_half: - self.net_g = self.net_g.half() + hubert_model = hubert_model.float() + hubert_model.eval() + self.model = hubert_model else: - self.net_g = self.net_g.float() - self.is_half = config.is_half + self.model = last_rvc.model + + if last_rvc is None or last_rvc.pth_path != self.pth_path: + cpt = torch.load(self.pth_path, map_location="cpu") + self.tgt_sr = cpt["config"][-1] + cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] + self.if_f0 = cpt.get("f0", 1) + self.version = cpt.get("version", "v1") + if self.version == "v1": + if self.if_f0 == 1: + self.net_g = SynthesizerTrnMs256NSFsid( + *cpt["config"], is_half=config.is_half + ) + else: + self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) + elif self.version == "v2": + if self.if_f0 == 1: + self.net_g = SynthesizerTrnMs768NSFsid( + *cpt["config"], is_half=config.is_half + ) + else: + self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) + del self.net_g.enc_q + logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False)) + self.net_g.eval().to(device) + # print(2333333333,device,config.device,self.device)#net_g是device,hubert是config.device + if config.is_half: + self.net_g = self.net_g.half() + else: + self.net_g = self.net_g.float() + self.is_half = config.is_half + else: + self.tgt_sr = last_rvc.tgt_sr + self.if_f0 = last_rvc.if_f0 + self.version = last_rvc.version + self.net_g = last_rvc.net_g + self.is_half = last_rvc.is_half + + if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"): + self.model_rmvpe = last_rvc.model_rmvpe except: logger.warn(traceback.format_exc()) @@ -149,7 +164,7 @@ class RVC: if method == "rmvpe": return self.get_f0_rmvpe(x, f0_up_key) if method == "pm": - p_len = x.shape[0] // 160 + p_len = x.shape[0] // 160 + 1 f0 = ( parselmouth.Sound(x, 16000) .to_pitch_ac( @@ -181,9 +196,10 @@ class RVC: f0 = signal.medfilt(f0, 3) f0 *= pow(2, f0_up_key / 12) return self.get_f0_post(f0) - f0bak = np.zeros(x.shape[0] // 160, dtype=np.float64) + f0bak = np.zeros(x.shape[0] // 160 + 1, dtype=np.float64) length = len(x) - part_length = int(length / n_cpu / 160) * 160 + part_length = 160 * ((length // 160 - 1) // n_cpu + 1) + n_cpu = (length // 160 - 1) // (part_length // 160) + 1 ts = ttime() res_f0 = mm.dict() for idx in range(n_cpu): @@ -205,7 +221,7 @@ class RVC: elif idx != n_cpu - 1: f0 = f0[2:-3] else: - f0 = f0[2:-1] + f0 = f0[2:] f0bak[ part_length * idx // 160 : part_length * idx // 160 + f0.shape[0] ] = f0 @@ -259,8 +275,8 @@ class RVC: self, feats: torch.Tensor, indata: np.ndarray, - rate1, - rate2, + block_frame_16k, + rate, cache_pitch, cache_pitchf, f0method, @@ -286,7 +302,7 @@ class RVC: t2 = ttime() try: if hasattr(self, "index") and self.index_rate != 0: - leng_replace_head = int(rate1 * feats[0].shape[0]) + leng_replace_head = int(rate * feats[0].shape[0]) npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32") score, ix = self.index.search(npy, k=8) weight = np.square(1 / score) @@ -307,9 +323,11 @@ class RVC: t3 = ttime() if self.if_f0 == 1: pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method) - cache_pitch[:] = np.append(cache_pitch[pitch[:-1].shape[0] :], pitch[:-1]) + start_frame = block_frame_16k // 160 + end_frame = len(cache_pitch) - (pitch.shape[0] - 4) + start_frame + cache_pitch[:] = np.append(cache_pitch[start_frame : end_frame], pitch[3:-1]) cache_pitchf[:] = np.append( - cache_pitchf[pitchf[:-1].shape[0] :], pitchf[:-1] + cache_pitchf[start_frame : end_frame], pitchf[3:-1] ) p_len = min(feats.shape[1], 13000, cache_pitch.shape[0]) else: @@ -330,14 +348,14 @@ class RVC: # print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2) infered_audio = ( self.net_g.infer( - feats, p_len, cache_pitch, cache_pitchf, sid, rate2 + feats, p_len, cache_pitch, cache_pitchf, sid, rate )[0][0, 0] .data.cpu() .float() ) else: infered_audio = ( - self.net_g.infer(feats, p_len, sid, rate2)[0][0, 0] + self.net_g.infer(feats, p_len, sid, rate)[0][0, 0] .data.cpu() .float() )