From 3b2a1d081cab9b31afd3846bf1b0dcd4ec1d8e49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jan 2024 21:53:29 +0900 Subject: [PATCH] feat(gui): trained_xxxx.index hint (fix #1070) --- gui_v1.py | 2 +- tools/rvc_for_realtime.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/gui_v1.py b/gui_v1.py index fa57881..4331336 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -530,7 +530,7 @@ if __name__ == "__main__": self.gui_config.sg_hostapi = self.hostapis[0] self.window["sg_hostapi"].Update(values=self.hostapis) self.window["sg_hostapi"].Update(value=self.gui_config.sg_hostapi) - if self.gui_config.sg_input_device not in self.input_devices: + if self.gui_config.sg_input_device not in self.input_devices and len(self.input_devices) > 0: self.gui_config.sg_input_device = self.input_devices[0] self.window["sg_input_device"].Update(values=self.input_devices) self.window["sg_input_device"].Update( diff --git a/tools/rvc_for_realtime.py b/tools/rvc_for_realtime.py index 1f89cc6..f0c6bb9 100644 --- a/tools/rvc_for_realtime.py +++ b/tools/rvc_for_realtime.py @@ -372,15 +372,17 @@ class RVC: if hasattr(self, "index") and self.index_rate != 0: npy = feats[0][skip_head // 2 :].cpu().numpy().astype("float32") score, ix = self.index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) - if self.config.is_half: - npy = npy.astype("float16") - feats[0][skip_head // 2 :] = ( - torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate - + (1 - self.index_rate) * feats[0][skip_head // 2 :] - ) + if (ix >= 0).all(): + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + if self.config.is_half: + npy = npy.astype("float16") + feats[0][skip_head // 2 :] = ( + torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate + + (1 - self.index_rate) * feats[0][skip_head // 2 :] + ) + else: printt("Invalid index. You MUST use added_xxxx.index but not trained_xxxx.index!") else: printt("Index search FAILED or disabled") except: