From ac1397f3f9845ca8028235dbe358d951173817ea Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Thu, 21 Sep 2023 22:17:06 +0800 Subject: [PATCH] add input wav and delay time monitor (#1295) --- gui_v1.py | 85 ++++++++++++++++++++++++++---------- tools/torchgate/torchgate.py | 7 --- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/gui_v1.py b/gui_v1.py index f86809d..b36a08a 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -14,7 +14,7 @@ sys.path.append(now_dir) import multiprocessing logger = logging.getLogger(__name__) - +stream_latency = -1 class Harvest(multiprocessing.Process): def __init__(self, inp_q, opt_q): @@ -100,7 +100,8 @@ if __name__ == "__main__": def __init__(self) -> None: self.config = GUIConfig() self.flag_vc = False - + self.function = 'vc' + self.delay_time = 0 self.launcher() def load(self): @@ -112,6 +113,10 @@ if __name__ == "__main__": data["harvest"] = data["f0method"] == "harvest" data["crepe"] = data["f0method"] == "crepe" data["rmvpe"] = data["f0method"] == "rmvpe" + if data["sg_input_device"] not in input_devices: + data["sg_input_device"] = input_devices[sd.default.device[0]] + if data["sg_output_device"] not in output_devices: + data["sg_output_device"] = output_devices[sd.default.device[1]] except: with open("configs/config.json", "w") as j: data = { @@ -342,6 +347,22 @@ if __name__ == "__main__": [ sg.Button(i18n("开始音频转换"), key="start_vc"), sg.Button(i18n("停止音频转换"), key="stop_vc"), + sg.Radio( + i18n("输入监听"), + "function", + key="im", + default=False, + enable_events=True, + ), + sg.Radio( + i18n("输出变声"), + "function", + key="vc", + default=True, + enable_events=True, + ), + sg.Text(i18n("算法延迟(ms):")), + sg.Text("0", key="delay_time"), sg.Text(i18n("推理时间(ms):")), sg.Text("0", key="infer_time"), ], @@ -403,9 +424,16 @@ if __name__ == "__main__": } with open("configs/config.json", "w") as j: json.dump(settings, j) + global stream_latency + while stream_latency < 0: + time.sleep(0.01) + self.delay_time = stream_latency + values["block_time"] + values["crossfade_length"] + 0.01 + if values["I_noise_reduce"]: + self.delay_time += values["crossfade_length"] + self.window["delay_time"].update(int(self.delay_time * 1000)) if event == "stop_vc" and self.flag_vc == True: self.flag_vc = False - + stream_latency = -1 # Parameter hot update if event == "threhold": self.config.threhold = values["threhold"] @@ -423,11 +451,17 @@ if __name__ == "__main__": self.config.f0method = event elif event == "I_noise_reduce": self.config.I_noise_reduce = values["I_noise_reduce"] + if stream_latency > 0: + self.delay_time += (1 if values["I_noise_reduce"] else -1) * values["crossfade_length"] + self.window["delay_time"].update(int(self.delay_time * 1000)) elif event == "O_noise_reduce": self.config.O_noise_reduce = values["O_noise_reduce"] + elif event in ["vc", "im"]: + self.function = event elif event != "start_vc" and self.flag_vc == True: # Other parameters do not support hot update self.flag_vc = False + stream_latency = -1 def set_values(self, values): if len(values["pth_path"].strip()) == 0: @@ -565,7 +599,9 @@ if __name__ == "__main__": blocksize=self.block_frame, samplerate=self.config.samplerate, dtype="float32", - ): + ) as stream: + global stream_latency + stream_latency = stream.latency[-1] while self.flag_vc: time.sleep(self.config.block_time) logger.debug("Audio block passed.") @@ -597,7 +633,7 @@ if __name__ == "__main__": self.block_frame_16k : ].clone() # input noise reduction and resampling - if self.config.I_noise_reduce: + if self.config.I_noise_reduce and self.function == 'vc': input_wav = self.input_wav[ -self.crossfade_frame - self.block_frame - 2 * self.zc : ] @@ -621,23 +657,28 @@ if __name__ == "__main__": self.input_wav[-self.block_frame - 2 * self.zc :] )[160:] # infer - f0_extractor_frame = self.block_frame_16k + 800 - if self.config.f0method == "rmvpe": - f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160 - infer_wav = self.rvc.infer( - self.input_wav_res, - self.input_wav_res[-f0_extractor_frame:].cpu().numpy(), - self.block_frame_16k, - self.valid_rate, - self.pitch, - self.pitchf, - self.config.f0method, - ) - infer_wav = infer_wav[ - -self.crossfade_frame - self.sola_search_frame - self.block_frame : - ] + if self.function == 'vc': + f0_extractor_frame = self.block_frame_16k + 800 + if self.config.f0method == "rmvpe": + f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160 + infer_wav = self.rvc.infer( + self.input_wav_res, + self.input_wav_res[-f0_extractor_frame:].cpu().numpy(), + self.block_frame_16k, + self.valid_rate, + self.pitch, + self.pitchf, + self.config.f0method, + ) + infer_wav = infer_wav[ + -self.crossfade_frame - self.sola_search_frame - self.block_frame : + ] + else: + infer_wav = self.input_wav[ + -self.crossfade_frame - self.sola_search_frame - self.block_frame : + ].clone() # output noise reduction - if self.config.O_noise_reduce: + if (self.config.O_noise_reduce and self.function == 'vc') or (self.config.I_noise_reduce and self.function == 'im'): self.output_buffer[: -self.block_frame] = self.output_buffer[ self.block_frame : ].clone() @@ -646,7 +687,7 @@ if __name__ == "__main__": infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0) ).squeeze(0) # volume envelop mixing - if self.config.rms_mix_rate < 1: + if self.config.rms_mix_rate < 1 and self.function == 'vc': rms1 = librosa.feature.rms( y=self.input_wav_res[-160 * infer_wav.shape[0] // self.zc :] .cpu() diff --git a/tools/torchgate/torchgate.py b/tools/torchgate/torchgate.py index 086f2ab..f95ffef 100644 --- a/tools/torchgate/torchgate.py +++ b/tools/torchgate/torchgate.py @@ -211,13 +211,6 @@ class TorchGate(torch.nn.Module): Returns: torch.Tensor: The denoised audio signal, with the same shape as the input signal. """ - assert x.ndim == 2 - if x.shape[-1] < self.win_length * 2: - raise Exception(f"x must be bigger than {self.win_length * 2}") - - assert xn is None or xn.ndim == 1 or xn.ndim == 2 - if xn is not None and xn.shape[-1] < self.win_length * 2: - raise Exception(f"xn must be bigger than {self.win_length * 2}") # Compute short-time Fourier transform (STFT) X = torch.stft(