refactor: Improve get_projects function to return sorted list and default project name, also moved default_project_name logic into the get_projects function.
This commit is contained in:
parent
62c6142250
commit
73d30c0a26
44
infer-web.py
44
infer-web.py
@ -810,11 +810,27 @@ def change_f0_method(f0method8):
|
|||||||
# start tab loss graph helper functions
|
# start tab loss graph helper functions
|
||||||
desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"]
|
desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"]
|
||||||
|
|
||||||
|
|
||||||
def get_projects():
|
def get_projects():
|
||||||
"""
|
"""
|
||||||
Get the list of projects.
|
Gets a list of project names from the index root directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of project names.
|
||||||
|
str: The default project name (first in the list).
|
||||||
|
dict: A dictionary of image paths keyed by desired_tags for the default project.
|
||||||
"""
|
"""
|
||||||
return [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
|
projects = [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
|
||||||
|
|
||||||
|
# Check if there are any projects before accessing
|
||||||
|
if projects:
|
||||||
|
default_project_name = projects[0]
|
||||||
|
else:
|
||||||
|
print("No projects found.")
|
||||||
|
default_project_name = None
|
||||||
|
|
||||||
|
return sorted(projects), default_project_name
|
||||||
|
|
||||||
|
|
||||||
def get_loss_graph_images(selection):
|
def get_loss_graph_images(selection):
|
||||||
"""
|
"""
|
||||||
@ -839,6 +855,7 @@ def get_loss_graph_images(selection):
|
|||||||
|
|
||||||
return graphs
|
return graphs
|
||||||
|
|
||||||
|
|
||||||
def get_loss_graph_tabs(project):
|
def get_loss_graph_tabs(project):
|
||||||
"""
|
"""
|
||||||
Create Gradio Tabs and Image fields for the loss graphs.
|
Create Gradio Tabs and Image fields for the loss graphs.
|
||||||
@ -860,6 +877,7 @@ def get_loss_graph_tabs(project):
|
|||||||
loss_graph_image_fields[tag] = image_field
|
loss_graph_image_fields[tag] = image_field
|
||||||
return loss_graph_tabs, list(loss_graph_image_fields.values())
|
return loss_graph_tabs, list(loss_graph_image_fields.values())
|
||||||
|
|
||||||
|
|
||||||
def update_loss_graph_images(selection):
|
def update_loss_graph_images(selection):
|
||||||
"""
|
"""
|
||||||
Update the loss graph images for a given project.
|
Update the loss graph images for a given project.
|
||||||
@ -882,23 +900,14 @@ def update_loss_graph_images(selection):
|
|||||||
|
|
||||||
return updated_values
|
return updated_values
|
||||||
|
|
||||||
|
|
||||||
def update_projects():
|
def update_projects():
|
||||||
"""
|
"""
|
||||||
Update the list of projects.
|
Update the list of projects.
|
||||||
"""
|
"""
|
||||||
projects = get_projects()
|
projects, default_project_name = get_projects()
|
||||||
return {"choices": sorted(projects), "__type__": "update"}
|
return {"choices": projects, "__type__": "update"}
|
||||||
|
|
||||||
projects = get_projects()
|
|
||||||
|
|
||||||
# Check if there are any projects before accessing
|
|
||||||
if projects:
|
|
||||||
default_project = projects[0]
|
|
||||||
default_loss_graph_images = get_loss_graph_images(projects[0])
|
|
||||||
else:
|
|
||||||
print("No projects found.")
|
|
||||||
default_project = None
|
|
||||||
default_loss_graph_images = []
|
|
||||||
|
|
||||||
# gradio app
|
# gradio app
|
||||||
with gr.Blocks(title="RVC WebUI") as app:
|
with gr.Blocks(title="RVC WebUI") as app:
|
||||||
@ -1522,11 +1531,12 @@ with gr.Blocks(title="RVC WebUI") as app:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
projects, default_project_name = get_projects()
|
||||||
voice_list_dropdown = gr.Dropdown(
|
voice_list_dropdown = gr.Dropdown(
|
||||||
label=i18n("选择语音"),
|
label=i18n("选择语音"),
|
||||||
choices=sorted(projects),
|
choices=projects,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
value=default_project
|
value=default_project_name
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
update_voice_list_button = gr.Button(
|
update_voice_list_button = gr.Button(
|
||||||
@ -1544,7 +1554,7 @@ with gr.Blocks(title="RVC WebUI") as app:
|
|||||||
api_name="infer_refresh"
|
api_name="infer_refresh"
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
tabs, image_fields = get_loss_graph_tabs(default_project)
|
tabs, image_fields = get_loss_graph_tabs(default_project_name)
|
||||||
|
|
||||||
voice_list_dropdown.change(
|
voice_list_dropdown.change(
|
||||||
fn=update_loss_graph_images,
|
fn=update_loss_graph_images,
|
||||||
|
Loading…
Reference in New Issue
Block a user