From 73d30c0a26123e851480c18c888ef55ae1d6f5dc Mon Sep 17 00:00:00 2001 From: Wernervanrun Date: Mon, 12 Aug 2024 14:58:36 +0200 Subject: [PATCH] refactor: Improve get_projects function to return sorted list and default project name, also moved default_project_name logic into the get_projects function. --- infer-web.py | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/infer-web.py b/infer-web.py index 8d87b41..ba94b83 100644 --- a/infer-web.py +++ b/infer-web.py @@ -810,11 +810,27 @@ def change_f0_method(f0method8): # start tab loss graph helper functions desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"] + def get_projects(): """ - Get the list of projects. + Gets a list of project names from the index root directory. + + Returns: + list: A list of project names. + str: The default project name (first in the list). + dict: A dictionary of image paths keyed by desired_tags for the default project. """ - return [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))] + projects = [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))] + + # Check if there are any projects before accessing + if projects: + default_project_name = projects[0] + else: + print("No projects found.") + default_project_name = None + + return sorted(projects), default_project_name + def get_loss_graph_images(selection): """ @@ -839,6 +855,7 @@ def get_loss_graph_images(selection): return graphs + def get_loss_graph_tabs(project): """ Create Gradio Tabs and Image fields for the loss graphs. @@ -860,6 +877,7 @@ def get_loss_graph_tabs(project): loss_graph_image_fields[tag] = image_field return loss_graph_tabs, list(loss_graph_image_fields.values()) + def update_loss_graph_images(selection): """ Update the loss graph images for a given project. @@ -882,23 +900,14 @@ def update_loss_graph_images(selection): return updated_values + def update_projects(): """ Update the list of projects. """ - projects = get_projects() - return {"choices": sorted(projects), "__type__": "update"} + projects, default_project_name = get_projects() + return {"choices": projects, "__type__": "update"} -projects = get_projects() - -# Check if there are any projects before accessing -if projects: - default_project = projects[0] - default_loss_graph_images = get_loss_graph_images(projects[0]) -else: - print("No projects found.") - default_project = None - default_loss_graph_images = [] # gradio app with gr.Blocks(title="RVC WebUI") as app: @@ -1522,11 +1531,12 @@ with gr.Blocks(title="RVC WebUI") as app: ) ) with gr.Row(): + projects, default_project_name = get_projects() voice_list_dropdown = gr.Dropdown( label=i18n("选择语音"), - choices=sorted(projects), + choices=projects, interactive=True, - value=default_project + value=default_project_name ) with gr.Column(): update_voice_list_button = gr.Button( @@ -1544,7 +1554,7 @@ with gr.Blocks(title="RVC WebUI") as app: api_name="infer_refresh" ) with gr.Row(): - tabs, image_fields = get_loss_graph_tabs(default_project) + tabs, image_fields = get_loss_graph_tabs(default_project_name) voice_list_dropdown.change( fn=update_loss_graph_images,