# Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git
# (synced 2024-11-24 07:20:10 +01:00) — Python, 67 lines, 2.2 KiB.
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""Loading pretrained models.
|
|
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
import typing as tp
|
|
|
|
from dora.log import fatal
|
|
|
|
from .hdemucs import HDemucs
|
|
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
|
|
|
|
logger = logging.getLogger(__name__)

# Base URL handed to RemoteRepo when no local repository is given.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# Directory next to this module: holds the `files.txt` manifest of remote
# model files and is the folder passed to BagOnlyRepo for the remote case.
REMOTE_ROOT = Path(__file__).parent.joinpath('remote')

# Source (stem) names given to the model built by `demucs_unittest`.
SOURCES = [
    "drums",
    "bass",
    "other",
    "vocals",
]
|
|
|
|
|
|
def demucs_unittest():
    """Construct a fresh, untrained ``HDemucs`` used by unit tests.

    No weights are loaded. The model is built with ``channels=4`` and the
    stem names from ``SOURCES``.

    Returns:
        The newly constructed ``HDemucs`` instance.
    """
    return HDemucs(sources=SOURCES, channels=4)
|
|
|
|
|
|
def add_model_flags(parser):
|
|
group = parser.add_mutually_exclusive_group(required=False)
|
|
group.add_argument("-s", "--sig", help="Locally trained XP signature.")
|
|
group.add_argument("-n", "--name", default="mdx_extra_q",
|
|
help="Pretrained model name or signature. Default is mdx_extra_q.")
|
|
parser.add_argument("--repo", type=Path,
|
|
help="Folder containing all pre-trained models for use with -n.")
|
|
|
|
|
|
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """Load a pretrained model or bag of models by name.

    `name` must be a bag of models name or a pretrained signature, looked up
    in the remote model repo (served from ``ROOT_URL``) or, if `repo` is not
    None, in that local repo.

    Args:
        name: bag-of-models name or pretrained signature. The special value
            ``'demucs_unittest'`` returns the untrained test model instead.
        repo: optional local directory to resolve models from. It must be an
            existing directory; otherwise dora's ``fatal`` helper is called
            (presumably terminating the process — confirm before relying on it).

    Returns:
        Whatever ``AnyModelRepo.get_model`` returns for ``name``, or the
        ``demucs_unittest`` model.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        # The manifest lives next to this module. Decode it explicitly as UTF-8
        # so the result does not depend on the platform's locale encoding
        # (e.g. cp1252 on Windows).
        manifest = (REMOTE_ROOT / 'files.txt').read_text(encoding='utf-8')
        stripped = (line.strip() for line in manifest.split('\n'))
        remote_files = [entry for entry in stripped if entry]
        model_repo = RemoteRepo(ROOT_URL, remote_files)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    return any_repo.get_model(name)
|
|
|
|
|
|
def get_model_from_args(args):
    """
    Load local model package or pre-trained model.

    Args:
        args: parsed ``argparse`` namespace, such as one produced by a parser
            configured with ``add_model_flags``. ``args.sig`` (if present and
            not None) takes precedence over ``args.name``. ``args.repo`` is
            forwarded as the local repository (None means the remote one).

    Returns:
        The model returned by ``get_model`` for the selected name.
    """
    # ``-s/--sig`` and ``-n/--name`` are mutually exclusive, but ``--name``
    # always carries its default value. Without this check, an explicit
    # signature would be ignored and the default model silently loaded.
    # ``getattr`` keeps namespaces built without a ``sig`` attribute working.
    sig = getattr(args, 'sig', None)
    name = sig if sig is not None else args.name
    return get_model(name=name, repo=args.repo)
|