1
0
mirror of synced 2024-12-18 02:16:08 +01:00

feat(gui): trained_xxxx.index hint (fix #1070)

This commit is contained in:
源文雨 2024-01-25 21:53:29 +09:00
parent 918ab7b62f
commit 3b2a1d081c
2 changed files with 12 additions and 10 deletions

View File

@ -530,7 +530,7 @@ if __name__ == "__main__":
self.gui_config.sg_hostapi = self.hostapis[0] self.gui_config.sg_hostapi = self.hostapis[0]
self.window["sg_hostapi"].Update(values=self.hostapis) self.window["sg_hostapi"].Update(values=self.hostapis)
self.window["sg_hostapi"].Update(value=self.gui_config.sg_hostapi) self.window["sg_hostapi"].Update(value=self.gui_config.sg_hostapi)
if self.gui_config.sg_input_device not in self.input_devices: if self.gui_config.sg_input_device not in self.input_devices and len(self.input_devices) > 0:
self.gui_config.sg_input_device = self.input_devices[0] self.gui_config.sg_input_device = self.input_devices[0]
self.window["sg_input_device"].Update(values=self.input_devices) self.window["sg_input_device"].Update(values=self.input_devices)
self.window["sg_input_device"].Update( self.window["sg_input_device"].Update(

View File

@ -372,6 +372,7 @@ class RVC:
if hasattr(self, "index") and self.index_rate != 0: if hasattr(self, "index") and self.index_rate != 0:
npy = feats[0][skip_head // 2 :].cpu().numpy().astype("float32") npy = feats[0][skip_head // 2 :].cpu().numpy().astype("float32")
score, ix = self.index.search(npy, k=8) score, ix = self.index.search(npy, k=8)
if (ix >= 0).all():
weight = np.square(1 / score) weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True) weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
@ -381,6 +382,7 @@ class RVC:
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
+ (1 - self.index_rate) * feats[0][skip_head // 2 :] + (1 - self.index_rate) * feats[0][skip_head // 2 :]
) )
else: printt("Invalid index. You MUST use added_xxxx.index but not trained_xxxx.index!")
else: else:
printt("Index search FAILED or disabled") printt("Index search FAILED or disabled")
except: except: