fix Config, GUIConfig and self (#340)
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
This commit is contained in:
parent
0729c9d6f2
commit
a2ef4cca76
19
gui.py
19
gui.py
@ -16,7 +16,7 @@ now_dir = os.getcwd()
|
|||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
from config import Config as MyConfig
|
from config import Config as MyConfig
|
||||||
|
|
||||||
is_half = MyConfig().is_half
|
Config = Config()
|
||||||
import PySimpleGUI as sg
|
import PySimpleGUI as sg
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import noisereduce as nr
|
import noisereduce as nr
|
||||||
@ -71,7 +71,7 @@ class RVC:
|
|||||||
)
|
)
|
||||||
self.model = models[0]
|
self.model = models[0]
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
if is_half == True:
|
if Config.is_half:
|
||||||
self.model = self.model.half()
|
self.model = self.model.half()
|
||||||
else:
|
else:
|
||||||
self.model = self.model.float()
|
self.model = self.model.float()
|
||||||
@ -81,25 +81,24 @@ class RVC:
|
|||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
||||||
self.if_f0 = cpt.get("f0", 1)
|
self.if_f0 = cpt.get("f0", 1)
|
||||||
self.version = cpt.get("version", "v1")
|
self.version = cpt.get("version", "v1")
|
||||||
|
|
||||||
if self.version == "v1":
|
if self.version == "v1":
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
self.net_g = SynthesizerTrnMs256NSFsid(
|
self.net_g = SynthesizerTrnMs256NSFsid(
|
||||||
*cpt["config"], is_half=self.config.is_half
|
*cpt["config"], is_half=Config.is_half
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||||
elif self.version == "v2":
|
elif self.version == "v2":
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
self.net_g = SynthesizerTrnMs768NSFsid(
|
self.net_g = SynthesizerTrnMs768NSFsid(
|
||||||
*cpt["config"], is_half=self.config.is_half
|
*cpt["config"], is_half=Config.is_half
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||||
del self.net_g.enc_q
|
del self.net_g.enc_q
|
||||||
print(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
print(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
||||||
self.net_g.eval().to(device)
|
self.net_g.eval().to(device)
|
||||||
if is_half == True:
|
if Config.is_half:
|
||||||
self.net_g = self.net_g.half()
|
self.net_g = self.net_g.half()
|
||||||
else:
|
else:
|
||||||
self.net_g = self.net_g.float()
|
self.net_g = self.net_g.float()
|
||||||
@ -160,7 +159,7 @@ class RVC:
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model.extract_features(**inputs)
|
logits = self.model.extract_features(**inputs)
|
||||||
feats = model.final_proj(logits[0]) if self.version == "v1" else logits[0]
|
feats = self.model.final_proj(logits[0]) if self.version == "v1" else logits[0]
|
||||||
|
|
||||||
####索引优化
|
####索引优化
|
||||||
try:
|
try:
|
||||||
@ -174,7 +173,7 @@ class RVC:
|
|||||||
weight = np.square(1 / score)
|
weight = np.square(1 / score)
|
||||||
weight /= weight.sum(axis=1, keepdims=True)
|
weight /= weight.sum(axis=1, keepdims=True)
|
||||||
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
||||||
if is_half == True:
|
if Config.is_half:
|
||||||
npy = npy.astype("float16")
|
npy = npy.astype("float16")
|
||||||
feats = (
|
feats = (
|
||||||
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
|
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
|
||||||
@ -220,7 +219,7 @@ class RVC:
|
|||||||
return infered_audio
|
return infered_audio
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class GUIConfig:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.hubert_path: str = ""
|
self.hubert_path: str = ""
|
||||||
self.pth_path: str = ""
|
self.pth_path: str = ""
|
||||||
@ -240,7 +239,7 @@ class Config:
|
|||||||
|
|
||||||
class GUI:
|
class GUI:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.config = Config()
|
self.config = GUIConfig()
|
||||||
self.flag_vc = False
|
self.flag_vc = False
|
||||||
|
|
||||||
self.launcher()
|
self.launcher()
|
||||||
|
Loading…
Reference in New Issue
Block a user