optimize the streaming f0 exatrators (#1168)
This commit is contained in:
parent
ad85b02ed9
commit
0fc160c03e
40
gui_v1.py
40
gui_v1.py
@ -261,9 +261,9 @@ if __name__ == "__main__":
|
|||||||
[
|
[
|
||||||
sg.Text(i18n("采样长度")),
|
sg.Text(i18n("采样长度")),
|
||||||
sg.Slider(
|
sg.Slider(
|
||||||
range=(0.09, 2.4),
|
range=(0.05, 2.4),
|
||||||
key="block_time",
|
key="block_time",
|
||||||
resolution=0.03,
|
resolution=0.01,
|
||||||
orientation="h",
|
orientation="h",
|
||||||
default_value=data.get("block_time", ""),
|
default_value=data.get("block_time", ""),
|
||||||
enable_events=True,
|
enable_events=True,
|
||||||
@ -455,18 +455,20 @@ if __name__ == "__main__":
|
|||||||
inp_q,
|
inp_q,
|
||||||
opt_q,
|
opt_q,
|
||||||
device,
|
device,
|
||||||
|
self.rvc if hasattr(self, "rvc") else None
|
||||||
)
|
)
|
||||||
self.config.samplerate = self.rvc.tgt_sr
|
self.config.samplerate = self.rvc.tgt_sr
|
||||||
self.config.crossfade_time = min(
|
self.config.crossfade_time = min(
|
||||||
self.config.crossfade_time, self.config.block_time
|
self.config.crossfade_time, self.config.block_time
|
||||||
)
|
)
|
||||||
self.block_frame = int(self.config.block_time * self.config.samplerate)
|
self.zc = self.rvc.tgt_sr // 100
|
||||||
|
self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
|
||||||
|
self.block_frame_16k = 160 * self.block_frame // self.zc
|
||||||
self.crossfade_frame = int(
|
self.crossfade_frame = int(
|
||||||
self.config.crossfade_time * self.config.samplerate
|
self.config.crossfade_time * self.config.samplerate
|
||||||
)
|
)
|
||||||
self.sola_search_frame = int(0.01 * self.config.samplerate)
|
self.sola_search_frame = int(0.01 * self.config.samplerate)
|
||||||
self.extra_frame = int(self.config.extra_time * self.config.samplerate)
|
self.extra_frame = int(self.config.extra_time * self.config.samplerate)
|
||||||
self.zc = self.rvc.tgt_sr // 100
|
|
||||||
self.input_wav: np.ndarray = np.zeros(
|
self.input_wav: np.ndarray = np.zeros(
|
||||||
int(
|
int(
|
||||||
np.ceil(
|
np.ceil(
|
||||||
@ -482,6 +484,7 @@ if __name__ == "__main__":
|
|||||||
),
|
),
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
|
self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc)
|
||||||
self.output_wav_cache: torch.Tensor = torch.zeros(
|
self.output_wav_cache: torch.Tensor = torch.zeros(
|
||||||
int(
|
int(
|
||||||
np.ceil(
|
np.ceil(
|
||||||
@ -573,18 +576,14 @@ if __name__ == "__main__":
|
|||||||
for i in range(db_threhold.shape[0]):
|
for i in range(db_threhold.shape[0]):
|
||||||
if db_threhold[i]:
|
if db_threhold[i]:
|
||||||
indata[i * hop_length : (i + 1) * hop_length] = 0
|
indata[i * hop_length : (i + 1) * hop_length] = 0
|
||||||
self.input_wav[:] = np.append(self.input_wav[self.block_frame :], indata)
|
self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :]
|
||||||
|
self.input_wav[-self.block_frame: ] = indata
|
||||||
|
|
||||||
# infer
|
# infer
|
||||||
inp = torch.from_numpy(self.input_wav).to(device)
|
inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device)
|
||||||
res1 = self.resampler(inp)
|
self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
|
||||||
###55%
|
self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :]
|
||||||
rate1 = self.block_frame / (
|
rate = (
|
||||||
self.extra_frame
|
|
||||||
+ self.crossfade_frame
|
|
||||||
+ self.sola_search_frame
|
|
||||||
+ self.block_frame
|
|
||||||
)
|
|
||||||
rate2 = (
|
|
||||||
self.crossfade_frame + self.sola_search_frame + self.block_frame
|
self.crossfade_frame + self.sola_search_frame + self.block_frame
|
||||||
) / (
|
) / (
|
||||||
self.extra_frame
|
self.extra_frame
|
||||||
@ -592,11 +591,14 @@ if __name__ == "__main__":
|
|||||||
+ self.sola_search_frame
|
+ self.sola_search_frame
|
||||||
+ self.block_frame
|
+ self.block_frame
|
||||||
)
|
)
|
||||||
|
f0_extractor_frame = self.block_frame_16k + 800
|
||||||
|
if self.config.f0method == 'rmvpe':
|
||||||
|
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
|
||||||
res2 = self.rvc.infer(
|
res2 = self.rvc.infer(
|
||||||
res1,
|
self.input_wav_res,
|
||||||
res1[-self.block_frame :].cpu().numpy(),
|
self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
|
||||||
rate1,
|
self.block_frame_16k,
|
||||||
rate2,
|
rate,
|
||||||
self.pitch,
|
self.pitch,
|
||||||
self.pitchf,
|
self.pitchf,
|
||||||
self.config.f0method,
|
self.config.f0method,
|
||||||
|
@ -601,7 +601,7 @@ class RMVPE:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
n_frames = mel.shape[-1]
|
n_frames = mel.shape[-1]
|
||||||
mel = F.pad(
|
mel = F.pad(
|
||||||
mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
|
mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="constant"
|
||||||
)
|
)
|
||||||
if "privateuseone" in str(self.device):
|
if "privateuseone" in str(self.device):
|
||||||
onnx_input_name = self.model.get_inputs()[0].name
|
onnx_input_name = self.model.get_inputs()[0].name
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
@ -48,7 +47,7 @@ if config.dml == True:
|
|||||||
# config.is_half=False########强制cpu测试
|
# config.is_half=False########强制cpu测试
|
||||||
class RVC:
|
class RVC:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device
|
self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device, last_rvc=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
初始化
|
初始化
|
||||||
@ -72,48 +71,64 @@ class RVC:
|
|||||||
self.index = faiss.read_index(index_path)
|
self.index = faiss.read_index(index_path)
|
||||||
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
|
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
|
||||||
logger.info("Index search enabled")
|
logger.info("Index search enabled")
|
||||||
|
self.pth_path = pth_path
|
||||||
self.index_path = index_path
|
self.index_path = index_path
|
||||||
self.index_rate = index_rate
|
self.index_rate = index_rate
|
||||||
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
|
||||||
["assets/hubert/hubert_base.pt"],
|
if last_rvc is None:
|
||||||
suffix="",
|
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||||
)
|
["assets/hubert/hubert_base.pt"],
|
||||||
hubert_model = models[0]
|
suffix="",
|
||||||
hubert_model = hubert_model.to(config.device)
|
)
|
||||||
if config.is_half:
|
hubert_model = models[0]
|
||||||
hubert_model = hubert_model.half()
|
hubert_model = hubert_model.to(config.device)
|
||||||
else:
|
if config.is_half:
|
||||||
hubert_model = hubert_model.float()
|
hubert_model = hubert_model.half()
|
||||||
hubert_model.eval()
|
|
||||||
self.model = hubert_model
|
|
||||||
cpt = torch.load(pth_path, map_location="cpu")
|
|
||||||
self.tgt_sr = cpt["config"][-1]
|
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
|
||||||
self.if_f0 = cpt.get("f0", 1)
|
|
||||||
self.version = cpt.get("version", "v1")
|
|
||||||
if self.version == "v1":
|
|
||||||
if self.if_f0 == 1:
|
|
||||||
self.net_g = SynthesizerTrnMs256NSFsid(
|
|
||||||
*cpt["config"], is_half=config.is_half
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
hubert_model = hubert_model.float()
|
||||||
elif self.version == "v2":
|
hubert_model.eval()
|
||||||
if self.if_f0 == 1:
|
self.model = hubert_model
|
||||||
self.net_g = SynthesizerTrnMs768NSFsid(
|
|
||||||
*cpt["config"], is_half=config.is_half
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
|
||||||
del self.net_g.enc_q
|
|
||||||
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
|
||||||
self.net_g.eval().to(device)
|
|
||||||
# print(2333333333,device,config.device,self.device)#net_g是device,hubert是config.device
|
|
||||||
if config.is_half:
|
|
||||||
self.net_g = self.net_g.half()
|
|
||||||
else:
|
else:
|
||||||
self.net_g = self.net_g.float()
|
self.model = last_rvc.model
|
||||||
self.is_half = config.is_half
|
|
||||||
|
if last_rvc is None or last_rvc.pth_path != self.pth_path:
|
||||||
|
cpt = torch.load(self.pth_path, map_location="cpu")
|
||||||
|
self.tgt_sr = cpt["config"][-1]
|
||||||
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||||
|
self.if_f0 = cpt.get("f0", 1)
|
||||||
|
self.version = cpt.get("version", "v1")
|
||||||
|
if self.version == "v1":
|
||||||
|
if self.if_f0 == 1:
|
||||||
|
self.net_g = SynthesizerTrnMs256NSFsid(
|
||||||
|
*cpt["config"], is_half=config.is_half
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||||
|
elif self.version == "v2":
|
||||||
|
if self.if_f0 == 1:
|
||||||
|
self.net_g = SynthesizerTrnMs768NSFsid(
|
||||||
|
*cpt["config"], is_half=config.is_half
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||||
|
del self.net_g.enc_q
|
||||||
|
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
||||||
|
self.net_g.eval().to(device)
|
||||||
|
# print(2333333333,device,config.device,self.device)#net_g是device,hubert是config.device
|
||||||
|
if config.is_half:
|
||||||
|
self.net_g = self.net_g.half()
|
||||||
|
else:
|
||||||
|
self.net_g = self.net_g.float()
|
||||||
|
self.is_half = config.is_half
|
||||||
|
else:
|
||||||
|
self.tgt_sr = last_rvc.tgt_sr
|
||||||
|
self.if_f0 = last_rvc.if_f0
|
||||||
|
self.version = last_rvc.version
|
||||||
|
self.net_g = last_rvc.net_g
|
||||||
|
self.is_half = last_rvc.is_half
|
||||||
|
|
||||||
|
if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"):
|
||||||
|
self.model_rmvpe = last_rvc.model_rmvpe
|
||||||
except:
|
except:
|
||||||
logger.warn(traceback.format_exc())
|
logger.warn(traceback.format_exc())
|
||||||
|
|
||||||
@ -149,7 +164,7 @@ class RVC:
|
|||||||
if method == "rmvpe":
|
if method == "rmvpe":
|
||||||
return self.get_f0_rmvpe(x, f0_up_key)
|
return self.get_f0_rmvpe(x, f0_up_key)
|
||||||
if method == "pm":
|
if method == "pm":
|
||||||
p_len = x.shape[0] // 160
|
p_len = x.shape[0] // 160 + 1
|
||||||
f0 = (
|
f0 = (
|
||||||
parselmouth.Sound(x, 16000)
|
parselmouth.Sound(x, 16000)
|
||||||
.to_pitch_ac(
|
.to_pitch_ac(
|
||||||
@ -181,9 +196,10 @@ class RVC:
|
|||||||
f0 = signal.medfilt(f0, 3)
|
f0 = signal.medfilt(f0, 3)
|
||||||
f0 *= pow(2, f0_up_key / 12)
|
f0 *= pow(2, f0_up_key / 12)
|
||||||
return self.get_f0_post(f0)
|
return self.get_f0_post(f0)
|
||||||
f0bak = np.zeros(x.shape[0] // 160, dtype=np.float64)
|
f0bak = np.zeros(x.shape[0] // 160 + 1, dtype=np.float64)
|
||||||
length = len(x)
|
length = len(x)
|
||||||
part_length = int(length / n_cpu / 160) * 160
|
part_length = 160 * ((length // 160 - 1) // n_cpu + 1)
|
||||||
|
n_cpu = (length // 160 - 1) // (part_length // 160) + 1
|
||||||
ts = ttime()
|
ts = ttime()
|
||||||
res_f0 = mm.dict()
|
res_f0 = mm.dict()
|
||||||
for idx in range(n_cpu):
|
for idx in range(n_cpu):
|
||||||
@ -205,7 +221,7 @@ class RVC:
|
|||||||
elif idx != n_cpu - 1:
|
elif idx != n_cpu - 1:
|
||||||
f0 = f0[2:-3]
|
f0 = f0[2:-3]
|
||||||
else:
|
else:
|
||||||
f0 = f0[2:-1]
|
f0 = f0[2:]
|
||||||
f0bak[
|
f0bak[
|
||||||
part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
|
part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
|
||||||
] = f0
|
] = f0
|
||||||
@ -259,8 +275,8 @@ class RVC:
|
|||||||
self,
|
self,
|
||||||
feats: torch.Tensor,
|
feats: torch.Tensor,
|
||||||
indata: np.ndarray,
|
indata: np.ndarray,
|
||||||
rate1,
|
block_frame_16k,
|
||||||
rate2,
|
rate,
|
||||||
cache_pitch,
|
cache_pitch,
|
||||||
cache_pitchf,
|
cache_pitchf,
|
||||||
f0method,
|
f0method,
|
||||||
@ -286,7 +302,7 @@ class RVC:
|
|||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "index") and self.index_rate != 0:
|
if hasattr(self, "index") and self.index_rate != 0:
|
||||||
leng_replace_head = int(rate1 * feats[0].shape[0])
|
leng_replace_head = int(rate * feats[0].shape[0])
|
||||||
npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
|
npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
|
||||||
score, ix = self.index.search(npy, k=8)
|
score, ix = self.index.search(npy, k=8)
|
||||||
weight = np.square(1 / score)
|
weight = np.square(1 / score)
|
||||||
@ -307,9 +323,11 @@ class RVC:
|
|||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method)
|
pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method)
|
||||||
cache_pitch[:] = np.append(cache_pitch[pitch[:-1].shape[0] :], pitch[:-1])
|
start_frame = block_frame_16k // 160
|
||||||
|
end_frame = len(cache_pitch) - (pitch.shape[0] - 4) + start_frame
|
||||||
|
cache_pitch[:] = np.append(cache_pitch[start_frame : end_frame], pitch[3:-1])
|
||||||
cache_pitchf[:] = np.append(
|
cache_pitchf[:] = np.append(
|
||||||
cache_pitchf[pitchf[:-1].shape[0] :], pitchf[:-1]
|
cache_pitchf[start_frame : end_frame], pitchf[3:-1]
|
||||||
)
|
)
|
||||||
p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
|
p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
|
||||||
else:
|
else:
|
||||||
@ -330,14 +348,14 @@ class RVC:
|
|||||||
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
|
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
|
||||||
infered_audio = (
|
infered_audio = (
|
||||||
self.net_g.infer(
|
self.net_g.infer(
|
||||||
feats, p_len, cache_pitch, cache_pitchf, sid, rate2
|
feats, p_len, cache_pitch, cache_pitchf, sid, rate
|
||||||
)[0][0, 0]
|
)[0][0, 0]
|
||||||
.data.cpu()
|
.data.cpu()
|
||||||
.float()
|
.float()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
infered_audio = (
|
infered_audio = (
|
||||||
self.net_g.infer(feats, p_len, sid, rate2)[0][0, 0]
|
self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
|
||||||
.data.cpu()
|
.data.cpu()
|
||||||
.float()
|
.float()
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user