2022-12-18 21:18:56 -06:00

180 lines
5.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""
import logging
from pathlib import Path
import typing as tp
from dora.log import fatal
import logging
from diffq import DiffQuantizer
import torch.hub
from .model import Demucs
from .tasnet_v2 import ConvTasNet
from .utils import set_state
from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
logger = logging.getLogger(__name__)
# Base URL that, together with any `root:` prefix, is prepended to each entry
# listed in remote/files.txt (see `_parse_remote_files`).
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# Directory next to this module that holds files.txt and serves as the root
# for bag-of-models lookups when no local repo is given.
REMOTE_ROOT = Path(__file__).parent.joinpath('remote')
# Source names, in order, handed to the models built in this module.
SOURCES = ["drums", "bass", "other", "vocals"]
def demucs_unittest():
    """Build an untrained `HDemucs` with ``channels=4`` over `SOURCES`.

    NOTE(review): a second `demucs_unittest` defined further down in this module
    rebinds the name, so this definition is shadowed once the module is loaded.
    """
    return HDemucs(channels=4, sources=SOURCES)
def add_model_flags(parser):
    """Register the model-selection options on an argparse `parser`.

    ``-s/--sig`` and ``-n/--name`` form an optional mutually exclusive group
    (``--name`` defaults to ``mdx_extra_q``); ``--repo`` is independent of that
    group and is converted to a `Path`.
    """
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument("-s", "--sig", help="Locally trained XP signature.")
    exclusive.add_argument(
        "-n", "--name",
        default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
root: str = ''
models: tp.Dict[str, str] = {}
for line in remote_file_list.read_text().split('\n'):
line = line.strip()
if line.startswith('#'):
continue
elif line.startswith('root:'):
root = line.split(':', 1)[1].strip()
else:
sig = line.split('-', 1)[0]
assert sig not in models
models[sig] = ROOT_URL + root + line
return models
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """Return the model (single model or bag) identified by `name`, in eval mode.

    `name` is a bag-of-models name or a pretrained signature. With ``repo=None``
    the model repo is built from ``REMOTE_ROOT / 'files.txt'`` and bags are
    looked up under ``REMOTE_ROOT``; otherwise `repo` must be an existing
    directory and serves both roles. The special name ``'demucs_unittest'``
    returns ``demucs_unittest()`` directly.
    """
    if name == 'demucs_unittest':
        # NOTE(review): `demucs_unittest` is defined twice in this module; the
        # name is resolved at call time, so whichever binding is last applies.
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is not None:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_root = repo
    else:
        model_repo = RemoteRepo(_parse_remote_files(REMOTE_ROOT / 'files.txt'))
        bag_root = REMOTE_ROOT
    bag_repo = BagOnlyRepo(bag_root, model_repo)
    model = AnyModelRepo(model_repo, bag_repo).get_model(name)
    model.eval()
    return model
def get_model_from_args(args):
    """Load the model selected on the command line.

    Forwards ``args.name`` and ``args.repo`` (the options registered by
    `add_model_flags`) to `get_model`; ``args.sig`` is not consulted here.
    """
    name, repo = args.name, args.repo
    return get_model(name=name, repo=repo)
# NOTE(review): `logger` and `SOURCES` are rebound here to the same values they
# were given earlier in this module; the rebinding is redundant but harmless.
logger = logging.getLogger(__name__)
# Base URL of the v3.0 checkpoints (see `get_url`).
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
# Pretrained model name -> signature embedded in its checkpoint file name.
PRETRAINED_MODELS = dict(
    demucs='e07c671f',
    demucs48_hq='28a1282c',
    demucs_extra='3646af93',
    demucs_quantized='07afea75',
    tasnet='beb46fac',
    tasnet_extra='df3777b2',
    demucs_unittest='09ebc15f',
)
SOURCES = ["drums", "bass", "other", "vocals"]
def get_url(name):
    """Return the checkpoint URL of pretrained model `name` under `ROOT`.

    The file name is ``<name>-<first 8 chars of its signature>.th``. Raises
    ``KeyError`` if `name` is not a key of `PRETRAINED_MODELS`.
    """
    signature = PRETRAINED_MODELS[name]
    return f"{ROOT}{name}-{signature[:8]}.th"
def is_pretrained(name):
    """Return True when `name` is one of the keys of `PRETRAINED_MODELS`."""
    return name in PRETRAINED_MODELS.keys()
def load_pretrained(name):
    """Build the pretrained model registered under `name`.

    `name` must be one of the keys of `PRETRAINED_MODELS`; each maps to the
    matching `demucs`, `demucs_unittest` or `tasnet` call with
    ``pretrained=True``.

    Raises:
        ValueError: if `name` is not a known pretrained model.
    """
    # Ordered (name, factory) pairs, compared with `==` in sequence; the
    # factories are lambdas so nothing is built until a name matches.
    loaders = (
        ("demucs", lambda: demucs(pretrained=True)),
        ("demucs48_hq", lambda: demucs(pretrained=True, hq=True, channels=48)),
        ("demucs_extra", lambda: demucs(pretrained=True, extra=True)),
        ("demucs_quantized", lambda: demucs(pretrained=True, quantized=True)),
        ("demucs_unittest", lambda: demucs_unittest(pretrained=True)),
        ("tasnet", lambda: tasnet(pretrained=True)),
        ("tasnet_extra", lambda: tasnet(pretrained=True, extra=True)),
    )
    for known_name, build in loaders:
        if name == known_name:
            return build()
    raise ValueError(f"Invalid pretrained name {name}")
def _load_state(name, model, quantizer=None):
    """Fetch the checkpoint of pretrained model `name` and load it into `model`.

    The state comes from ``get_url(name)`` through
    ``torch.hub.load_state_dict_from_url`` (mapped to CPU, hash check enabled)
    and is applied with ``set_state(model, quantizer, state)``. A truthy
    `quantizer` is ``detach()``-ed afterwards.
    """
    checkpoint_url = get_url(name)
    state = torch.hub.load_state_dict_from_url(
        checkpoint_url, map_location='cpu', check_hash=True)
    set_state(model, quantizer, state)
    if not quantizer:
        return
    quantizer.detach()
def demucs_unittest(pretrained=True):
    """Build a `Demucs` with ``channels=4`` over `SOURCES`.

    When `pretrained` is true, the ``'demucs_unittest'`` checkpoint is loaded
    into it via `_load_state`.
    """
    net = Demucs(channels=4, sources=SOURCES)
    if not pretrained:
        return net
    _load_state('demucs_unittest', net)
    return net
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a v3 `Demucs` model, optionally loading released weights.

    The checkpoint name is ``'demucs'``, followed by the channel count when it
    differs from 64, followed by ``'_quantized'``, ``'_extra'`` or ``'_hq'``
    for the selected variant. That name is resolved through `get_url`, so a
    combination without a matching entry in `PRETRAINED_MODELS` (e.g.
    ``channels=48`` without ``hq``) fails there with ``KeyError``.

    Args:
        pretrained: load weights through `_load_state` when true.
        extra, quantized, hq: checkpoint variant; at most one may be true, and
            each requires ``pretrained=True``.
        channels: channel count passed to `Demucs`.

    Returns:
        The constructed model.

    Raises:
        ValueError: on an invalid combination of flags.
    """
    if not pretrained and (extra or quantized or hq):
        raise ValueError("if extra, quantized or hq is True, pretrained must be True.")
    # Validate the variant flags before building the network so an invalid
    # combination fails without constructing (and discarding) a model.
    if sum([extra, quantized, hq]) > 1:
        raise ValueError("Only one of extra, quantized, hq, can be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = 'demucs'
        if channels != 64:
            name += str(channels)
        quantizer = None
        if quantized:
            # Quantized checkpoints are loaded through a DiffQuantizer attached
            # to the model (passed on to `_load_state`).
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += '_quantized'
        if extra:
            name += '_extra'
        if hq:
            name += '_hq'
        _load_state(name, model, quantizer)
    return model
def tasnet(pretrained=True, extra=False):
    """Build a `ConvTasNet` (``X=10``) over `SOURCES`.

    When `pretrained` is true, the ``'tasnet'`` checkpoint (or
    ``'tasnet_extra'`` if `extra` is true) is loaded via `_load_state`.

    Raises:
        ValueError: if `extra` is requested without `pretrained`.
    """
    if not pretrained and extra:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if pretrained:
        _load_state('tasnet_extra' if extra else 'tasnet', model)
    return model