# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import logging
|
|
|
|
from pathlib import Path
|
|
|
|
import typing as tp
|
|
|
|
|
#from dora.log import fatal
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from diffq import DiffQuantizer
|
|
|
|
import torch.hub
|
|
|
|
|
|
|
|
from .model import Demucs
|
|
|
|
from .tasnet_v2 import ConvTasNet
|
|
|
|
from .utils import set_state
|
|
|
|
|
|
|
|
from .hdemucs import HDemucs
|
|
|
|
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Base URL prepended (together with the listing's `root:` prefix) to every
# entry parsed from the remote files listing.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# Folder next to this module holding the remote listings (e.g. 'files.txt').
REMOTE_ROOT = Path(__file__).parent.joinpath('remote')

# Stem names passed as `sources=` to the model constructors in this module.
SOURCES = ["drums", "bass", "other", "vocals"]
|
|
|
|
|
|
|
|
|
|
|
|
def demucs_unittest():
    """Build a tiny 4-channel HDemucs over SOURCES without loading weights.

    NOTE(review): `demucs_unittest` is defined again further down in this
    module; at runtime the name resolves to whichever definition is bound
    last, so this one may be shadowed — confirm which is intended.
    """
    return HDemucs(channels=4, sources=SOURCES)
|
|
|
|
|
|
|
|
|
|
|
|
def add_model_flags(parser):
    """Register the model-selection options on an argparse `parser`.

    `-s/--sig` and `-n/--name` live in a mutually exclusive group (neither is
    required); `--repo` is a standalone option parsed into a `Path`.

    NOTE(review): `get_model_from_args` in this module only reads `name` and
    `repo`; the parsed `sig` value is not consulted there.
    """
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        "-s", "--sig",
        help="Locally trained XP signature.",
    )
    exclusive.add_argument(
        "-n", "--name",
        default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
|
|
|
|
root: str = ''
|
|
|
|
models: tp.Dict[str, str] = {}
|
|
|
|
for line in remote_file_list.read_text().split('\n'):
|
|
|
|
line = line.strip()
|
|
|
|
if line.startswith('#'):
|
|
|
|
continue
|
|
|
|
elif line.startswith('root:'):
|
|
|
|
root = line.split(':', 1)[1].strip()
|
|
|
|
else:
|
|
|
|
sig = line.split('-', 1)[0]
|
|
|
|
assert sig not in models
|
|
|
|
models[sig] = ROOT_URL + root + line
|
|
|
|
return models
|
|
|
|
|
|
|
|
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Args:
        name: bag-of-models name or model signature. The special value
            ``'demucs_unittest'`` short-circuits to ``demucs_unittest()``.
        repo: optional local folder holding the models. When None, the remote
            listing ``REMOTE_ROOT / 'files.txt'`` is parsed to build the repo.

    Returns:
        The model returned by ``AnyModelRepo.get_model(name)``, put in eval mode.

    Raises:
        ModelLoadingError: if `repo` is given but is not an existing directory.
    """
    if name == 'demucs_unittest':
        # NOTE(review): `demucs_unittest` is defined more than once in this
        # module; this call resolves to whichever definition is bound last.
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            # `fatal` (dora.log) is not imported in this module, so calling it
            # raised NameError; raise the repo's loading error instead.
            raise ModelLoadingError(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model
|
|
|
|
|
|
|
|
def get_model_from_args(args):
    """
    Load local model package or pre-trained model.

    Forwards `args.name` and `args.repo` (as registered by `add_model_flags`)
    to `get_model` and returns its result.
    """
    name, repo = args.name, args.repo
    return get_model(name=name, repo=repo)
|
|
|
|
|
|
|
|
# NOTE(review): rebinds the identical `logger` already created earlier in this module.
logger = logging.getLogger(__name__)

# Base URL of the single-model checkpoints resolved by `get_url`.
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"

# Pretrained model name -> signature; `get_url` embeds the first 8 characters
# of the signature in the checkpoint file name.
PRETRAINED_MODELS = {
    # Demucs variants
    'demucs': 'e07c671f',
    'demucs48_hq': '28a1282c',
    'demucs_extra': '3646af93',
    'demucs_quantized': '07afea75',
    # Conv-TasNet variants
    'tasnet': 'beb46fac',
    'tasnet_extra': 'df3777b2',
    # tiny model used by `demucs_unittest`
    'demucs_unittest': '09ebc15f',
}

# NOTE(review): same value as the SOURCES defined earlier in this module.
SOURCES = ["drums", "bass", "other", "vocals"]
|
|
|
|
|
|
|
|
|
|
|
|
def get_url(name):
    """Return the checkpoint URL for the pretrained model `name`.

    The URL is ``ROOT + name + "-" + <first 8 chars of its signature> + ".th"``;
    an unknown `name` raises KeyError from the `PRETRAINED_MODELS` lookup.
    """
    short_sig = PRETRAINED_MODELS[name][:8]
    return f"{ROOT}{name}-{short_sig}.th"
|
|
|
|
|
|
|
|
def is_pretrained(name):
    """Tell whether `name` is a key of `PRETRAINED_MODELS`."""
    known = PRETRAINED_MODELS
    return name in known
|
|
|
|
|
|
|
|
|
|
|
|
def load_pretrained(name):
    """Build the pretrained model registered under `name`.

    Raises:
        ValueError: if `name` matches none of the supported model names.
    """
    # Ordered (name, builder) table; compared with `==` in order, like an
    # if/elif chain, and builders are only evaluated for the matching entry.
    builders = (
        ("demucs", lambda: demucs(pretrained=True)),
        ("demucs48_hq", lambda: demucs(pretrained=True, hq=True, channels=48)),
        ("demucs_extra", lambda: demucs(pretrained=True, extra=True)),
        ("demucs_quantized", lambda: demucs(pretrained=True, quantized=True)),
        ("demucs_unittest", lambda: demucs_unittest(pretrained=True)),
        ("tasnet", lambda: tasnet(pretrained=True)),
        ("tasnet_extra", lambda: tasnet(pretrained=True, extra=True)),
    )
    for candidate, build in builders:
        if name == candidate:
            return build()
    raise ValueError(f"Invalid pretrained name {name}")
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state(name, model, quantizer=None):
    """Fetch the checkpoint registered as `name` and load it into `model`.

    The URL comes from `get_url(name)` and is downloaded through
    `torch.hub.load_state_dict_from_url` onto the CPU with hash checking; the
    result is applied with `set_state`. A truthy `quantizer` is then detached.
    """
    payload = torch.hub.load_state_dict_from_url(
        get_url(name), map_location='cpu', check_hash=True)
    set_state(model, quantizer, payload)
    if not quantizer:
        return
    quantizer.detach()
|
|
|
|
|
|
|
|
|
|
|
|
def demucs_unittest(pretrained=True):
    """Build the 4-channel Demucs used for tests.

    When `pretrained` is true, the 'demucs_unittest' checkpoint is loaded into it.
    """
    net = Demucs(channels=4, sources=SOURCES)
    if not pretrained:
        return net
    _load_state('demucs_unittest', net)
    return net
|
|
|
|
|
|
|
|
|
|
|
|
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a Demucs model over SOURCES, optionally loading pretrained weights.

    Args:
        pretrained: load a checkpoint via `_load_state`. Must be true whenever
            `extra`, `quantized` or `hq` is set.
        extra: select the '_extra' checkpoint.
        quantized: select the '_quantized' checkpoint; a DiffQuantizer
            (group_size=8, min_size=1) is attached and passed to `_load_state`.
        hq: select the '_hq' checkpoint.
        channels: model width; any value other than 64 is appended to the
            checkpoint name (e.g. 48 -> 'demucs48').

    Returns:
        The Demucs model.

    Raises:
        ValueError: if a variant flag is set without `pretrained`, or if more
            than one of `extra`, `quantized`, `hq` is set.
    """
    if not pretrained and (extra or quantized or hq):
        # The condition includes `hq`, so the message must mention it too.
        raise ValueError("if extra, quantized or hq is True, pretrained must be True.")
    # Validate flag exclusivity before building the model so bad combinations
    # fail without paying for model construction.
    if pretrained and sum([extra, quantized, hq]) > 1:
        raise ValueError("Only one of extra, quantized, hq, can be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = 'demucs'
        if channels != 64:
            name += str(channels)
        quantizer = None
        if quantized:
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += '_quantized'
        if extra:
            name += '_extra'
        if hq:
            name += '_hq'
        _load_state(name, model, quantizer)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def tasnet(pretrained=True, extra=False):
    """Build a ConvTasNet (X=10) over SOURCES, optionally loading weights.

    Loads the 'tasnet_extra' checkpoint when `extra` is set, otherwise 'tasnet'.

    Raises:
        ValueError: if `extra` is set without `pretrained`.
    """
    if extra and not pretrained:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if not pretrained:
        return model
    _load_state('tasnet_extra' if extra else 'tasnet', model)
    return model
|