1
0
mirror of synced 2024-11-27 17:00:54 +01:00

refactor: Improve get_projects function to return sorted list and default project name, also moved default_project_name logic into the get_projects function.

This commit is contained in:
Wernervanrun 2024-08-12 14:58:36 +02:00
parent 62c6142250
commit 73d30c0a26

View File

@ -810,11 +810,27 @@ def change_f0_method(f0method8):
# start tab loss graph helper functions
desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"]
def get_projects():
"""
Get the list of projects.
Gets a list of project names from the index root directory.
Returns:
list: A list of project names.
str: The default project name (first in the list).
dict: A dictionary of image paths keyed by desired_tags for the default project.
"""
return [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
projects = [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
# Check if there are any projects before accessing
if projects:
default_project_name = projects[0]
else:
print("No projects found.")
default_project_name = None
return sorted(projects), default_project_name
def get_loss_graph_images(selection):
"""
@ -839,6 +855,7 @@ def get_loss_graph_images(selection):
return graphs
def get_loss_graph_tabs(project):
"""
Create Gradio Tabs and Image fields for the loss graphs.
@ -860,6 +877,7 @@ def get_loss_graph_tabs(project):
loss_graph_image_fields[tag] = image_field
return loss_graph_tabs, list(loss_graph_image_fields.values())
def update_loss_graph_images(selection):
"""
Update the loss graph images for a given project.
@ -882,23 +900,14 @@ def update_loss_graph_images(selection):
return updated_values
def update_projects():
"""
Update the list of projects.
"""
projects = get_projects()
return {"choices": sorted(projects), "__type__": "update"}
projects, default_project_name = get_projects()
return {"choices": projects, "__type__": "update"}
projects = get_projects()
# Check if there are any projects before accessing
if projects:
default_project = projects[0]
default_loss_graph_images = get_loss_graph_images(projects[0])
else:
print("No projects found.")
default_project = None
default_loss_graph_images = []
# gradio app
with gr.Blocks(title="RVC WebUI") as app:
@ -1522,11 +1531,12 @@ with gr.Blocks(title="RVC WebUI") as app:
)
)
with gr.Row():
projects, default_project_name = get_projects()
voice_list_dropdown = gr.Dropdown(
label=i18n("选择语音"),
choices=sorted(projects),
choices=projects,
interactive=True,
value=default_project
value=default_project_name
)
with gr.Column():
update_voice_list_button = gr.Button(
@ -1544,7 +1554,7 @@ with gr.Blocks(title="RVC WebUI") as app:
api_name="infer_refresh"
)
with gr.Row():
tabs, image_fields = get_loss_graph_tabs(default_project)
tabs, image_fields = get_loss_graph_tabs(default_project_name)
voice_list_dropdown.change(
fn=update_loss_graph_images,