9a855e0457
fix: merged model fails to load & improve the Colab notebook
import os
import traceback

import torch
from collections import OrderedDict

def savee(ckpt, sr, if_f0, name, epoch):
    """Save a slimmed-down inference checkpoint from an in-memory state dict."""
    try:
        opt = OrderedDict()
        opt["weight"] = {}
        # Keep only the weights needed for inference, stored in fp16; enc_q is the
        # posterior encoder used only during training, so it is dropped.
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        # Hard-coded synthesizer hyperparameters for each supported sample rate.
        if sr == "40k":
            opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000]
        elif sr == "48k":
            opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 6, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 48000]
        elif sr == "32k":
            opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 32000]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()

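# Minimal usage sketch (illustrative only; the variable name, sample rate, voice
# name, and epoch value below are assumptions, not defined in this module):
#
#     state_dict = net_g.state_dict()              # generator weights from training
#     savee(state_dict, "40k", 1, "my_voice", 20)  # -> weights/my_voice.pth
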
def show_info(path):
    """Return a human-readable summary of a saved checkpoint."""
    try:
        a = torch.load(path, map_location="cpu")
        return "Model info: %s\nSample rate: %s\nPitch guidance (f0): %s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
        )
    except Exception:
        return traceback.format_exc()

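# Usage sketch (the path below is an assumption):
#
#     print(show_info("weights/my_voice.pth"))
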
def extract_small_model(path, name, sr, if_f0, info):
    """Extract a small inference model from a full training checkpoint on disk."""
    try:
        ckpt = torch.load(path, map_location="cpu")
        # Full training checkpoints wrap the state dict under "model".
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {}
        # Drop the training-only posterior encoder and store the weights in fp16.
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        # Hard-coded synthesizer hyperparameters for each supported sample rate.
        if sr == "40k":
            opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000]
        elif sr == "48k":
            opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 6, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 48000]
        elif sr == "32k":
            opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 32000]
        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()

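# Usage sketch (the checkpoint path, voice name, and settings are assumptions;
# they stand in for wherever your full training checkpoint actually lives):
#
#     extract_small_model("logs/my_voice/G_20000.pth", "my_voice", "40k", 1, "")
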
def change_info(path, info, name):
    """Rewrite the "info" field of a checkpoint and save it under a new name."""
    try:
        ckpt = torch.load(path, map_location="cpu")
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        torch.save(ckpt, "weights/%s" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()

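# Usage sketch (the paths and description text are assumptions):
#
#     change_info("weights/my_voice.pth", "trained on 10h of data", "my_voice_v2.pth")
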
def merge(path1, path2, alpha1, sr, f0, info, name):
    """Linearly blend two checkpoints: alpha1 * model1 + (1 - alpha1) * model2."""
    try:
        def extract(ckpt):
            # Pull inference weights out of a full training checkpoint,
            # skipping the training-only posterior encoder (enc_q).
            a = ckpt["model"]
            opt = OrderedDict()
            opt["weight"] = {}
            for key in a.keys():
                if "enc_q" in key:
                    continue
                opt["weight"][key] = a[key]
            return opt

        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        # Reuse the first model's config so the merged model stays loadable,
        # instead of rebuilding it from hard-coded per-sample-rate values.
        cfg = ckpt1["config"]
        if "model" in ckpt1:
            ckpt1 = extract(ckpt1)
        else:
            ckpt1 = ckpt1["weight"]
        if "model" in ckpt2:
            ckpt2 = extract(ckpt2)
        else:
            ckpt2 = ckpt2["weight"]
        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
            return "Fail to merge the models. The model architectures are not the same."
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Speaker embedding tables can differ in size; interpolate only the shared rows.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key][:min_shape0].float())
                    + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
                ).half()
            else:
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
                ).half()
        opt["config"] = cfg
        opt["sr"] = sr
        opt["f0"] = 1 if f0 == "是" else 0  # "是" is the UI value meaning "yes"
        opt["info"] = info
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()
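
# Usage sketch (the paths, names, and the 0.5 blend weight are assumptions).
# alpha1 near 1.0 keeps the result close to the first model, near 0.0 close to
# the second:
#
#     merge("weights/voice_a.pth", "weights/voice_b.pth", 0.5, "40k", "是", "a+b blend", "voice_ab")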