drop: overwrite of config jsons & feat: read index from assets/indices
This commit is contained in:
parent
3a80032e74
commit
6141253fba
1
.env
1
.env
@ -5,4 +5,5 @@ no_proxy = localhost, 127.0.0.1, ::1
|
|||||||
weight_root = assets/weights
|
weight_root = assets/weights
|
||||||
weight_uvr5_root = assets/uvr5_weights
|
weight_uvr5_root = assets/uvr5_weights
|
||||||
index_root = logs
|
index_root = logs
|
||||||
|
outside_index_root = assets/indices
|
||||||
rmvpe_root = assets/rmvpe
|
rmvpe_root = assets/rmvpe
|
||||||
|
2
assets/indices/.gitignore
vendored
Normal file
2
assets/indices/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
@ -2,6 +2,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
import shutil
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -65,7 +66,10 @@ class Config:
|
|||||||
def load_config_json() -> dict:
|
def load_config_json() -> dict:
|
||||||
d = {}
|
d = {}
|
||||||
for config_file in version_config_list:
|
for config_file in version_config_list:
|
||||||
with open(f"configs/{config_file}", "r") as f:
|
p = f"configs/inuse/{config_file}"
|
||||||
|
if not os.path.exists(p):
|
||||||
|
shutil.copy(f"configs/{config_file}", p)
|
||||||
|
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||||
d[config_file] = json.load(f)
|
d[config_file] = json.load(f)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -124,12 +128,13 @@ class Config:
|
|||||||
def use_fp32_config(self):
|
def use_fp32_config(self):
|
||||||
for config_file in version_config_list:
|
for config_file in version_config_list:
|
||||||
self.json_config[config_file]["train"]["fp16_run"] = False
|
self.json_config[config_file]["train"]["fp16_run"] = False
|
||||||
with open(f"configs/{config_file}", "r") as f:
|
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||||
strr = f.read().replace("true", "false")
|
strr = f.read().replace("true", "false")
|
||||||
with open(f"configs/{config_file}", "w") as f:
|
with open(f"configs/inuse/{config_file}", "w") as f:
|
||||||
f.write(strr)
|
f.write(strr)
|
||||||
|
logger.info("overwrite "+config_file)
|
||||||
self.preprocess_per = 3.0
|
self.preprocess_per = 3.0
|
||||||
logger.info("overwrite configs.json")
|
logger.info("overwrite preprocess_per to %d" % (self.preprocess_per))
|
||||||
|
|
||||||
def device_config(self) -> tuple:
|
def device_config(self) -> tuple:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
4
configs/inuse/.gitignore
vendored
Normal file
4
configs/inuse/.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
||||||
|
!v1
|
||||||
|
!v2
|
2
configs/inuse/v1/.gitignore
vendored
Normal file
2
configs/inuse/v1/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
2
configs/inuse/v2/.gitignore
vendored
Normal file
2
configs/inuse/v2/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
import shutil
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@ -147,7 +148,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
try:
|
try:
|
||||||
with open("configs/config.json", "r") as j:
|
if not os.path.exists("configs/inuse/config.json"):
|
||||||
|
shutil.copy("configs/config.json", "configs/inuse/config.json")
|
||||||
|
with open("configs/inuse/config.json", "r") as j:
|
||||||
data = json.load(j)
|
data = json.load(j)
|
||||||
data["sr_model"] = data["sr_type"] == "sr_model"
|
data["sr_model"] = data["sr_type"] == "sr_model"
|
||||||
data["sr_device"] = data["sr_type"] == "sr_device"
|
data["sr_device"] = data["sr_type"] == "sr_device"
|
||||||
@ -179,7 +182,7 @@ if __name__ == "__main__":
|
|||||||
self.output_devices_indices.index(sd.default.device[1])
|
self.output_devices_indices.index(sd.default.device[1])
|
||||||
]
|
]
|
||||||
except:
|
except:
|
||||||
with open("configs/config.json", "w") as j:
|
with open("configs/inuse/config.json", "w") as j:
|
||||||
data = {
|
data = {
|
||||||
"pth_path": "",
|
"pth_path": "",
|
||||||
"index_path": "",
|
"index_path": "",
|
||||||
@ -578,7 +581,7 @@ if __name__ == "__main__":
|
|||||||
].index(True)
|
].index(True)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
with open("configs/config.json", "w") as j:
|
with open("configs/inuse/config.json", "w") as j:
|
||||||
json.dump(settings, j)
|
json.dump(settings, j)
|
||||||
if self.stream is not None:
|
if self.stream is not None:
|
||||||
self.delay_time = (
|
self.delay_time = (
|
||||||
|
24
infer-web.py
24
infer-web.py
@ -131,16 +131,21 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
|||||||
weight_root = os.getenv("weight_root")
|
weight_root = os.getenv("weight_root")
|
||||||
weight_uvr5_root = os.getenv("weight_uvr5_root")
|
weight_uvr5_root = os.getenv("weight_uvr5_root")
|
||||||
index_root = os.getenv("index_root")
|
index_root = os.getenv("index_root")
|
||||||
|
outside_index_root = os.getenv("outside_index_root")
|
||||||
|
|
||||||
names = []
|
names = []
|
||||||
for name in os.listdir(weight_root):
|
for name in os.listdir(weight_root):
|
||||||
if name.endswith(".pth"):
|
if name.endswith(".pth"):
|
||||||
names.append(name)
|
names.append(name)
|
||||||
index_paths = []
|
index_paths = []
|
||||||
|
def lookup_indices(index_root):
|
||||||
|
global index_paths
|
||||||
for root, dirs, files in os.walk(index_root, topdown=False):
|
for root, dirs, files in os.walk(index_root, topdown=False):
|
||||||
for name in files:
|
for name in files:
|
||||||
if name.endswith(".index") and "trained" not in name:
|
if name.endswith(".index") and "trained" not in name:
|
||||||
index_paths.append("%s/%s" % (root, name))
|
index_paths.append("%s/%s" % (root, name))
|
||||||
|
lookup_indices(index_root)
|
||||||
|
lookup_indices(outside_index_root)
|
||||||
uvr5_names = []
|
uvr5_names = []
|
||||||
for name in os.listdir(weight_uvr5_root):
|
for name in os.listdir(weight_uvr5_root):
|
||||||
if name.endswith(".pth") or "onnx" in name:
|
if name.endswith(".pth") or "onnx" in name:
|
||||||
@ -658,6 +663,23 @@ def train_index(exp_dir1, version19):
|
|||||||
"%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
"%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
||||||
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
os.link(
|
||||||
|
"%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
||||||
|
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
||||||
|
"%s/%s_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
||||||
|
% (outside_index_root, exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
||||||
|
)
|
||||||
|
infos.append(
|
||||||
|
"链接索引到%s"
|
||||||
|
% (outside_index_root)
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
infos.append(
|
||||||
|
"链接索引到%s失败"
|
||||||
|
% (outside_index_root)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
infos.append("adding")
|
infos.append("adding")
|
||||||
yield "\n".join(infos)
|
yield "\n".join(infos)
|
||||||
@ -670,7 +692,7 @@ def train_index(exp_dir1, version19):
|
|||||||
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
|
||||||
)
|
)
|
||||||
infos.append(
|
infos.append(
|
||||||
"成功构建索引,added_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
"成功构建索引 added_IVF%s_Flat_nprobe_%s_%s_%s.index"
|
||||||
% (n_ivf, index_ivf.nprobe, exp_dir1, version19)
|
% (n_ivf, index_ivf.nprobe, exp_dir1, version19)
|
||||||
)
|
)
|
||||||
# faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19))
|
# faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19))
|
||||||
|
Loading…
Reference in New Issue
Block a user