From 7fc6642c040066d074fed6f78772227c4a3a9407 Mon Sep 17 00:00:00 2001 From: Rice Cake Date: Wed, 28 Jun 2023 13:48:06 +0800 Subject: [PATCH] update index training script v2 (#643) * update index training script v2 * Apply Code Formatter Change --------- Co-authored-by: gak123 --- tools/infer/train-index-v2.py | 36 +++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/tools/infer/train-index-v2.py b/tools/infer/train-index-v2.py index 656a5a6..77dfa0b 100644 --- a/tools/infer/train-index-v2.py +++ b/tools/infer/train-index-v2.py @@ -2,9 +2,15 @@ 格式:直接cid为自带的index位;aid放不下了,通过字典来查,反正就5w个 """ import faiss, numpy as np, os +from sklearn.cluster import MiniBatchKMeans +import traceback +from multiprocessing import cpu_count # ###########如果是原始特征要先写save -inp_root = r"./logs/nene/3_feature768" +n_cpu = 0 +if n_cpu == 0: + n_cpu = cpu_count() +inp_root = r"./logs/anz/3_feature768" npys = [] listdir_res = list(os.listdir(inp_root)) for name in sorted(listdir_res): @@ -15,7 +21,27 @@ big_npy_idx = np.arange(big_npy.shape[0]) np.random.shuffle(big_npy_idx) big_npy = big_npy[big_npy_idx] print(big_npy.shape) # (6196072, 192)#fp32#4.43G -np.save("infer/big_src_feature_mi.npy", big_npy) +if big_npy.shape[0] > 2e5: + # if(1): + info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0] + print(info) + try: + big_npy = ( + MiniBatchKMeans( + n_clusters=10000, + verbose=True, + batch_size=256 * n_cpu, + compute_labels=False, + init="random", + ) + .fit(big_npy) + .cluster_centers_ + ) + except: + info = traceback.format_exc() + print(info) + +np.save("tools/infer/big_src_feature_mi.npy", big_npy) ##################train+add # big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy") @@ -26,13 +52,15 @@ index_ivf = faiss.extract_index_ivf(index) # index_ivf.nprobe = 1 index.train(big_npy) faiss.write_index( - index, "infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf) + index, "tools/infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf) ) print("adding") batch_size_add = 8192 for i in range(0, big_npy.shape[0], batch_size_add): index.add(big_npy[i : i + batch_size_add]) -faiss.write_index(index, "infer/added_IVF%s_Flat_mi_baseline_src_feat.index" % (n_ivf)) +faiss.write_index( + index, "tools/infer/added_IVF%s_Flat_mi_baseline_src_feat.index" % (n_ivf) +) """ 大小(都是FP32) big_src_feature 2.95G