optimize: cpt as #836 mentioned
This commit is contained in:
parent
76c18b547c
commit
23f64d3aa8
14
infer-web.py
14
infer-web.py
@ -1,9 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
|
||||||
sys.path.append(now_dir)
|
|
||||||
import traceback, pdb
|
import traceback, pdb
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -40,7 +36,7 @@ from sklearn.cluster import MiniBatchKMeans
|
|||||||
|
|
||||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
now_dir = os.getcwd()
|
||||||
tmp = os.path.join(now_dir, "TEMP")
|
tmp = os.path.join(now_dir, "TEMP")
|
||||||
shutil.rmtree(tmp, ignore_errors=True)
|
shutil.rmtree(tmp, ignore_errors=True)
|
||||||
shutil.rmtree(
|
shutil.rmtree(
|
||||||
@ -156,6 +152,7 @@ for name in os.listdir(weight_uvr5_root):
|
|||||||
if name.endswith(".pth") or "onnx" in name:
|
if name.endswith(".pth") or "onnx" in name:
|
||||||
uvr5_names.append(name.replace(".pth", ""))
|
uvr5_names.append(name.replace(".pth", ""))
|
||||||
|
|
||||||
|
cpt = None
|
||||||
|
|
||||||
def vc_single(
|
def vc_single(
|
||||||
sid,
|
sid,
|
||||||
@ -172,7 +169,7 @@ def vc_single(
|
|||||||
rms_mix_rate,
|
rms_mix_rate,
|
||||||
protect,
|
protect,
|
||||||
): # spk_item, input_audio0, vc_transform0,f0_file,f0method0
|
): # spk_item, input_audio0, vc_transform0,f0_file,f0method0
|
||||||
global tgt_sr, net_g, vc, hubert_model, version
|
global tgt_sr, net_g, vc, hubert_model, version, cpt
|
||||||
if input_audio_path is None:
|
if input_audio_path is None:
|
||||||
return "You need to upload an audio", None
|
return "You need to upload an audio", None
|
||||||
f0_up_key = int(f0_up_key)
|
f0_up_key = int(f0_up_key)
|
||||||
@ -429,9 +426,7 @@ def get_vc(sid, to_return_protect0, to_return_protect1):
|
|||||||
else:
|
else:
|
||||||
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||||
del net_g, cpt
|
del net_g, cpt
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
cpt = None
|
|
||||||
return {"visible": False, "__type__": "update"}
|
return {"visible": False, "__type__": "update"}
|
||||||
person = "%s/%s" % (weight_root, sid)
|
person = "%s/%s" % (weight_root, sid)
|
||||||
print("loading %s" % person)
|
print("loading %s" % person)
|
||||||
@ -1253,6 +1248,7 @@ def change_info_(ckpt_path):
|
|||||||
|
|
||||||
|
|
||||||
def export_onnx(ModelPath, ExportedPath):
|
def export_onnx(ModelPath, ExportedPath):
|
||||||
|
global cpt
|
||||||
cpt = torch.load(ModelPath, map_location="cpu")
|
cpt = torch.load(ModelPath, map_location="cpu")
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||||
vec_channels = 256 if cpt.get("version", "v1") == "v1" else 768
|
vec_channels = 256 if cpt.get("version", "v1") == "v1" else 768
|
||||||
|
Loading…
x
Reference in New Issue
Block a user