diff --git a/train/process_ckpt.py b/train/process_ckpt.py index 3e4fc7e..3840345 100644 --- a/train/process_ckpt.py +++ b/train/process_ckpt.py @@ -69,10 +69,16 @@ def merge(path1,path2,alpha1,sr,f0,info,name): return opt ckpt1 = torch.load(path1, map_location="cpu") ckpt2 = torch.load(path2, map_location="cpu") - if("model"in ckpt1):ckpt1=extract(ckpt1) - else:ckpt1=ckpt1["weight"] - if("model"in ckpt2):ckpt2=extract(ckpt2) - else:ckpt2=ckpt2["weight"] + opt["config"] = ckpt1["config"] + ''' + if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] + elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] + elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] + ''' + if("model"in ckpt1): ckpt1=extract(ckpt1) + else: ckpt1=ckpt1["weight"] + if("model"in ckpt2): ckpt2=extract(ckpt2) + else: ckpt2=ckpt2["weight"] if(sorted(list(ckpt1.keys()))!=sorted(list(ckpt2.keys()))):return "Fail to merge the models. The model architectures are not the same." opt = OrderedDict() opt["weight"] = {} @@ -85,12 +91,6 @@ def merge(path1,path2,alpha1,sr,f0,info,name): opt["weight"][key] = (alpha1*(ckpt1[key].float())+(1-alpha1)*(ckpt2[key].float())).half() # except: # pdb.set_trace() - opt["config"] = ckpt1["config"] - ''' - if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] - elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] - elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] - ''' opt["sr"]=sr opt["f0"]=1 if f0=="是"else 0 opt["info"]=info