From 9a20c3b28f8bec5f4e064ae8bae8083b2e7e5a38 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 10:25:05 +0800 Subject: [PATCH] Format code (#932) Co-authored-by: github-actions[bot] --- infer_cli.py | 8 ++---- tools/calc_rvc_model_similarity.py | 40 ++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/infer_cli.py b/infer_cli.py index e382eff..63170ae 100644 --- a/infer_cli.py +++ b/infer_cli.py @@ -234,16 +234,12 @@ def get_vc(model_path): version = cpt.get("version", "v1") if version == "v1": if if_f0 == 1: - net_g = SynthesizerTrnMs256NSFsid( - *cpt["config"], is_half=is_half - ) + net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half) else: net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) elif version == "v2": if if_f0 == 1: - net_g = SynthesizerTrnMs768NSFsid( - *cpt["config"], is_half=is_half - ) + net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half) else: net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) del net_g.enc_q diff --git a/tools/calc_rvc_model_similarity.py b/tools/calc_rvc_model_similarity.py index 174a00d..edc1cf8 100644 --- a/tools/calc_rvc_model_similarity.py +++ b/tools/calc_rvc_model_similarity.py @@ -1,10 +1,11 @@ # This code references https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py # Fill in the path of the model to be queried and the root directory of the reference models, and this script will return the similarity between the model to be queried and all reference models. -import sys,os +import sys, os import torch import torch.nn as nn import torch.nn.functional as F + def cal_cross_attn(to_q, to_k, to_v, rand_input): hidden_dim, embed_dim = to_q.shape attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False) @@ -16,41 +17,50 @@ def cal_cross_attn(to_q, to_k, to_v, rand_input): return torch.einsum( "ik, jk -> ik", - F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1), - attn_to_v(rand_input) + F.softmax( + torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), + dim=-1, + ), + attn_to_v(rand_input), ) + def model_hash(filename): try: with open(filename, "rb") as file: import hashlib + m = hashlib.sha256() file.seek(0x100000) m.update(file.read(0x10000)) return m.hexdigest()[0:8] except FileNotFoundError: - return 'NOFILE' + return "NOFILE" + def eval(model, n, input): qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight" uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight" vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight" - atoq, atok, atov = model[qk][:,:,0], model[uk][:,:,0], model[vk][:,:,0] + atoq, atok, atov = model[qk][:, :, 0], model[uk][:, :, 0], model[vk][:, :, 0] attn = cal_cross_attn(atoq, atok, atov, input) return attn -def main(path,root): + +def main(path, root): torch.manual_seed(114514) model_a = torch.load(path, map_location="cpu")["weight"] - print("query:\t\t%s\t%s"%(path,model_hash(path))) + print("query:\t\t%s\t%s" % (path, model_hash(path))) map_attn_a = {} map_rand_input = {} for n in range(6): - hidden_dim, embed_dim,_ = model_a[f"enc_p.encoder.attn_layers.{n}.conv_v.weight"].shape + hidden_dim, embed_dim, _ = model_a[ + f"enc_p.encoder.attn_layers.{n}.conv_v.weight" + ].shape rand_input = torch.randn([embed_dim, hidden_dim]) map_attn_a[n] = eval(model_a, n, rand_input) @@ -59,7 +69,7 @@ def main(path,root): del model_a for name in sorted(list(os.listdir(root))): - path="%s/%s"%(root,name) + path = "%s/%s" % (root, name) model_b = torch.load(path, map_location="cpu")["weight"] sims = [] @@ -70,9 +80,13 @@ def main(path,root): sim = torch.mean(torch.cosine_similarity(attn_a, attn_b)) sims.append(sim) - print("reference:\t%s\t%s\t%s"%(path,model_hash(path),f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")) + print( + "reference:\t%s\t%s\t%s" + % (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%") + ) + if __name__ == "__main__": - query_path=r"weights\mi v3.pth" - reference_root=r"weights" - main(query_path,reference_root) + query_path = r"weights\mi v3.pth" + reference_root = r"weights" + main(query_path, reference_root)