1
0
mirror of synced 2024-11-30 18:24:32 +01:00

refactor: Improve get_projects function to return sorted list and default project name, also moved default_project_name logic into the get_projects function.

This commit is contained in:
Wernervanrun 2024-08-12 14:58:36 +02:00
parent 62c6142250
commit 73d30c0a26

View File

@ -810,11 +810,27 @@ def change_f0_method(f0method8):
# start tab loss graph helper functions # start tab loss graph helper functions
desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"] desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"]
def get_projects(): def get_projects():
""" """
Get the list of projects. Gets a list of project names from the index root directory.
Returns:
list: A list of project names.
str: The default project name (first in the list).
dict: A dictionary of image paths keyed by desired_tags for the default project.
""" """
return [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))] projects = [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
# Check if there are any projects before accessing
if projects:
default_project_name = projects[0]
else:
print("No projects found.")
default_project_name = None
return sorted(projects), default_project_name
def get_loss_graph_images(selection): def get_loss_graph_images(selection):
""" """
@ -839,6 +855,7 @@ def get_loss_graph_images(selection):
return graphs return graphs
def get_loss_graph_tabs(project): def get_loss_graph_tabs(project):
""" """
Create Gradio Tabs and Image fields for the loss graphs. Create Gradio Tabs and Image fields for the loss graphs.
@ -860,6 +877,7 @@ def get_loss_graph_tabs(project):
loss_graph_image_fields[tag] = image_field loss_graph_image_fields[tag] = image_field
return loss_graph_tabs, list(loss_graph_image_fields.values()) return loss_graph_tabs, list(loss_graph_image_fields.values())
def update_loss_graph_images(selection): def update_loss_graph_images(selection):
""" """
Update the loss graph images for a given project. Update the loss graph images for a given project.
@ -882,23 +900,14 @@ def update_loss_graph_images(selection):
return updated_values return updated_values
def update_projects(): def update_projects():
""" """
Update the list of projects. Update the list of projects.
""" """
projects = get_projects() projects, default_project_name = get_projects()
return {"choices": sorted(projects), "__type__": "update"} return {"choices": projects, "__type__": "update"}
projects = get_projects()
# Check if there are any projects before accessing
if projects:
default_project = projects[0]
default_loss_graph_images = get_loss_graph_images(projects[0])
else:
print("No projects found.")
default_project = None
default_loss_graph_images = []
# gradio app # gradio app
with gr.Blocks(title="RVC WebUI") as app: with gr.Blocks(title="RVC WebUI") as app:
@ -1522,11 +1531,12 @@ with gr.Blocks(title="RVC WebUI") as app:
) )
) )
with gr.Row(): with gr.Row():
projects, default_project_name = get_projects()
voice_list_dropdown = gr.Dropdown( voice_list_dropdown = gr.Dropdown(
label=i18n("选择语音"), label=i18n("选择语音"),
choices=sorted(projects), choices=projects,
interactive=True, interactive=True,
value=default_project value=default_project_name
) )
with gr.Column(): with gr.Column():
update_voice_list_button = gr.Button( update_voice_list_button = gr.Button(
@ -1544,7 +1554,7 @@ with gr.Blocks(title="RVC WebUI") as app:
api_name="infer_refresh" api_name="infer_refresh"
) )
with gr.Row(): with gr.Row():
tabs, image_fields = get_loss_graph_tabs(default_project) tabs, image_fields = get_loss_graph_tabs(default_project_name)
voice_list_dropdown.change( voice_list_dropdown.change(
fn=update_loss_graph_images, fn=update_loss_graph_images,