38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
def get_synthesizer(pth_path, device=torch.device("cpu")):
    """Load a synthesizer checkpoint and build the generator network for inference.

    The checkpoint is read onto the CPU first. The matching synthesizer class
    is chosen from the checkpoint's ``version`` and ``f0`` entries, the weights
    are loaded, and the network is converted to float32, put in eval mode and
    moved to ``device``.

    Args:
        pth_path: Path (or file-like object) accepted by ``torch.load``. It must
            hold a dict with a ``"config"`` sequence (passed positionally to the
            model constructor) and a ``"weight"`` state dict containing
            ``"emb_g.weight"``. It may also contain ``"f0"`` (default ``1``) and
            ``"version"`` (default ``"v1"``).
        device: Device the returned network is moved to. Defaults to CPU.

    Returns:
        ``(net_g, cpt)``: the ready-to-use network and the loaded checkpoint
        dict. Note that ``cpt["config"][-3]`` has been overwritten in place
        (see below).

    Raises:
        ValueError: if the checkpoint's ``version`` is not ``"v1"`` or ``"v2"``.
    """
    # NOTE(review): torch.load unpickles the file, so it can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    cpt = torch.load(pth_path, map_location=torch.device("cpu"))

    # Overwrite config[-3] with the row count of the "emb_g" embedding weight.
    # Presumably this is the number of speakers the model was trained with, and
    # the stored config value may be stale. TODO confirm.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")

    # Fail loudly on unknown versions. Otherwise no branch below assigns net_g
    # and the caller gets an opaque UnboundLocalError.
    if version not in ("v1", "v2"):
        raise ValueError(
            f"Unsupported synthesizer checkpoint version: {version!r} "
            "(expected 'v1' or 'v2')"
        )

    # Imported lazily: the model classes are only needed for construction.
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs256NSFsid,
        SynthesizerTrnMs256NSFsid_nono,
        SynthesizerTrnMs768NSFsid,
        SynthesizerTrnMs768NSFsid_nono,
    )

    # f0 == 1 selects the pitch-conditioned variant; any other value selects
    # the "_nono" variant. Only the f0 variants receive is_half=False.
    if version == "v1":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
    else:  # version == "v2"
        if if_f0 == 1:
            net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False)
        else:
            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])

    # Drop the enc_q submodule; presumably it is only used during training.
    # TODO confirm.
    del net_g.enc_q

    # strict=False tolerates missing/unexpected keys. This presumably includes
    # the checkpoint's enc_q.* weights, since that submodule was just deleted.
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g = net_g.float()
    # nn.Module.eval() and .to() act in place, so net_g is the moved module.
    net_g.eval().to(device)
    return net_g, cpt
|