diff --git a/.env b/.env index 4dbb109..3e17125 100644 --- a/.env +++ b/.env @@ -5,4 +5,5 @@ no_proxy = localhost, 127.0.0.1, ::1 weight_root = assets/weights weight_uvr5_root = assets/uvr5_weights index_root = logs +outside_index_root = assets/indices rmvpe_root = assets/rmvpe diff --git a/assets/indices/.gitignore b/assets/indices/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/assets/indices/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/configs/config.py b/configs/config.py index 8410c58..fe18769 100644 --- a/configs/config.py +++ b/configs/config.py @@ -2,6 +2,7 @@ import argparse import os import sys import json +import shutil from multiprocessing import cpu_count import torch @@ -65,7 +66,10 @@ class Config: def load_config_json() -> dict: d = {} for config_file in version_config_list: - with open(f"configs/{config_file}", "r") as f: + p = f"configs/inuse/{config_file}" + if not os.path.exists(p): + shutil.copy(f"configs/{config_file}", p) + with open(f"configs/inuse/{config_file}", "r") as f: d[config_file] = json.load(f) return d @@ -124,12 +128,13 @@ class Config: def use_fp32_config(self): for config_file in version_config_list: self.json_config[config_file]["train"]["fp16_run"] = False - with open(f"configs/{config_file}", "r") as f: + with open(f"configs/inuse/{config_file}", "r") as f: strr = f.read().replace("true", "false") - with open(f"configs/{config_file}", "w") as f: + with open(f"configs/inuse/{config_file}", "w") as f: f.write(strr) + logger.info("overwrite "+config_file) self.preprocess_per = 3.0 - logger.info("overwrite configs.json") + logger.info("overwrite preprocess_per to %d" % (self.preprocess_per)) def device_config(self) -> tuple: if torch.cuda.is_available(): diff --git a/configs/inuse/.gitignore b/configs/inuse/.gitignore new file mode 100644 index 0000000..4194236 --- /dev/null +++ b/configs/inuse/.gitignore @@ -0,0 +1,4 @@ +* +!.gitignore +!v1 +!v2 diff --git a/configs/inuse/v1/.gitignore b/configs/inuse/v1/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/configs/inuse/v1/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/configs/inuse/v2/.gitignore b/configs/inuse/v2/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/configs/inuse/v2/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/gui_v1.py b/gui_v1.py index a43d79f..fa57881 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -1,6 +1,7 @@ import os import sys from dotenv import load_dotenv +import shutil load_dotenv() @@ -147,7 +148,9 @@ if __name__ == "__main__": def load(self): try: - with open("configs/config.json", "r") as j: + if not os.path.exists("configs/inuse/config.json"): + shutil.copy("configs/config.json", "configs/inuse/config.json") + with open("configs/inuse/config.json", "r") as j: data = json.load(j) data["sr_model"] = data["sr_type"] == "sr_model" data["sr_device"] = data["sr_type"] == "sr_device" @@ -179,7 +182,7 @@ if __name__ == "__main__": self.output_devices_indices.index(sd.default.device[1]) ] except: - with open("configs/config.json", "w") as j: + with open("configs/inuse/config.json", "w") as j: data = { "pth_path": "", "index_path": "", @@ -578,7 +581,7 @@ if __name__ == "__main__": ].index(True) ], } - with open("configs/config.json", "w") as j: + with open("configs/inuse/config.json", "w") as j: json.dump(settings, j) if self.stream is not None: self.delay_time = ( diff --git a/infer-web.py b/infer-web.py index 6918a92..99b5494 100644 --- a/infer-web.py +++ b/infer-web.py @@ -131,16 +131,21 @@ class ToolButton(gr.Button, gr.components.FormComponent): weight_root = os.getenv("weight_root") weight_uvr5_root = os.getenv("weight_uvr5_root") index_root = os.getenv("index_root") +outside_index_root = os.getenv("outside_index_root") names = [] for name in os.listdir(weight_root): if name.endswith(".pth"): names.append(name) index_paths = [] -for root, dirs, files in os.walk(index_root, topdown=False): - for name in files: - if name.endswith(".index") and "trained" not in name: - index_paths.append("%s/%s" % (root, name)) +def lookup_indices(index_root): + global index_paths + for root, dirs, files in os.walk(index_root, topdown=False): + for name in files: + if name.endswith(".index") and "trained" not in name: + index_paths.append("%s/%s" % (root, name)) +lookup_indices(index_root) +lookup_indices(outside_index_root) uvr5_names = [] for name in os.listdir(weight_uvr5_root): if name.endswith(".pth") or "onnx" in name: @@ -658,6 +663,23 @@ def train_index(exp_dir1, version19): "%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), ) + try: + os.link( + "%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index" + % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), + "%s/%s_IVF%s_Flat_nprobe_%s_%s_%s.index" + % (outside_index_root, exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), + ) + infos.append( + "链接索引到%s" + % (outside_index_root) + ) + except: + infos.append( + "链接索引到%s失败" + % (outside_index_root) + ) + infos.append("adding") yield "\n".join(infos) @@ -670,7 +692,7 @@ def train_index(exp_dir1, version19): % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), ) infos.append( - "成功构建索引,added_IVF%s_Flat_nprobe_%s_%s_%s.index" + "成功构建索引 added_IVF%s_Flat_nprobe_%s_%s_%s.index" % (n_ivf, index_ivf.nprobe, exp_dir1, version19) ) # faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19))