# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from pathlib import Path


def get_parser():
    parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
    default_raw = None
    default_musdb = None
    if 'DEMUCS_RAW' in os.environ:
        default_raw = Path(os.environ['DEMUCS_RAW'])
    if 'DEMUCS_MUSDB' in os.environ:
        default_musdb = Path(os.environ['DEMUCS_MUSDB'])
    parser.add_argument(
        "--raw",
        type=Path,
        default=default_raw,
        help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
    parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
    parser.add_argument("-m",
                        "--musdb",
                        type=Path,
                        default=default_musdb,
                        help="Path to musdb root")
    parser.add_argument("--is_wav", action="store_true",
                        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
                        help="Folder where metadata information is stored.")
    parser.add_argument("--wav", type=Path,
                        help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
                             "subfolder.")
    parser.add_argument("--samplerate", type=int, default=44100)
    parser.add_argument("--audio_channels", type=int, default=2)
    parser.add_argument("--samples",
                        default=44100 * 10,
                        type=int,
                        help="number of samples to feed in")
    parser.add_argument("--data_stride",
                        default=44100,
                        type=int,
                        help="Stride for chunks, shorter = longer epochs")
    parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
    parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
    parser.add_argument("-d",
                        "--device",
                        help="Device to train on, default is cuda if available else cpu")
    parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    parser.add_argument("--test", help="Just run the test pipeline + one validation. "
                        "This should be a filename relative to the models/ folder.")
    parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
                        "on a pretrained model.")

    parser.add_argument("--rank", default=0, type=int)
    parser.add_argument("--world_size", default=1, type=int)
    parser.add_argument("--master")

    parser.add_argument("--checkpoints",
                        type=Path,
                        default=Path("checkpoints"),
                        help="Folder where to store checkpoints etc")
    parser.add_argument("--evals",
                        type=Path,
                        default=Path("evals"),
                        help="Folder where to store evals and waveforms")
    parser.add_argument("--save",
                        action="store_true",
                        help="Save estimated waveforms for the test set")
    parser.add_argument("--logs",
                        type=Path,
                        default=Path("logs"),
                        help="Folder where to store logs")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Folder where to store trained models")
    parser.add_argument("-R",
                        "--restart",
                        action='store_true',
                        help='Restart training, ignoring previous run')

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
    parser.add_argument("-r",
                        "--repeat",
                        type=int,
                        default=2,
                        help="Repeat the train set, longer epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
    parser.add_argument("--init", help="Initialize from a pre-trained model.")

    # Augmentation options
    parser.add_argument("--no_augment",
                        action="store_false",
                        dest="augment",
                        default=True,
                        help="No basic data augmentation.")
    parser.add_argument("--repitch", type=float, default=0.2,
                        help="Probability to do tempo/pitch change")
    parser.add_argument("--max_tempo", type=float, default=12,
                        help="Maximum relative tempo change in %% when using repitch.")

    parser.add_argument("--remix_group_size",
                        type=int,
                        default=4,
                        help="Shuffle sources using group of this size. Useful to somewhat "
                             "replicate multi-gpu training "
                             "on fewer GPUs.")
    parser.add_argument("--shifts",
                        type=int,
                        default=10,
                        help="Number of random shifts used for the shift trick.")
    parser.add_argument("--overlap",
                        type=float,
                        default=0.25,
                        help="Overlap when --split_valid is passed.")

    # See model.py for doc
    parser.add_argument("--growth",
                        type=float,
                        default=2.,
                        help="Number of channels between two layers will increase by this factor")
    parser.add_argument("--depth",
                        type=int,
                        default=6,
                        help="Number of layers for the encoder and decoder")
    parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
    parser.add_argument("--channels",
                        type=int,
                        default=64,
                        help="Number of channels for the first encoder layer")
    parser.add_argument("--kernel_size",
                        type=int,
                        default=8,
                        help="Kernel size for the (transposed) convolutions")
    parser.add_argument("--conv_stride",
                        type=int,
                        default=4,
                        help="Stride for the (transposed) convolutions")
    parser.add_argument("--context",
                        type=int,
                        default=3,
                        help="Context size for the decoder convolutions "
                             "before the transposed convolutions")
    parser.add_argument("--rescale",
                        type=float,
                        default=0.1,
                        help="Initial weight rescale reference")
    parser.add_argument("--no_resample", action="store_false",
                        default=True, dest="resample",
                        help="No resampling of the input/output by x2")
    parser.add_argument("--no_glu",
                        action="store_false",
                        default=True,
                        dest="glu",
                        help="Replace all GLUs by ReLUs")
    parser.add_argument("--no_rewrite",
                        action="store_false",
                        default=True,
                        dest="rewrite",
                        help="No 1x1 rewrite convolutions")
    parser.add_argument("--normalize", action="store_true")
    parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)

    # Tasnet options
    parser.add_argument("--tasnet", action="store_true")
    parser.add_argument("--split_valid",
                        action="store_true",
                        help="Predict chunk by chunk for valid and test. Required for tasnet")
    parser.add_argument("--X", type=int, default=8)

    # Other options
    parser.add_argument("--show",
                        action="store_true",
                        help="Show model architecture, size and exit")
    parser.add_argument("--save_model", action="store_true",
                        help="Skip training, just save final model "
                             "for the current checkpoint value.")
    parser.add_argument("--save_state",
                        help="Skip training, just save state "
                             "for the current checkpoint value. You should "
                             "provide a model name as argument.")

    # Quantization options
    parser.add_argument("--q-min-size", type=float, default=1,
                        help="Only quantize layers over this size (in MB)")
    parser.add_argument(
        "--qat", type=int, help="If provided, use QAT training with that many bits.")

    parser.add_argument("--diffq", type=float, default=0)
    parser.add_argument(
        "--ms-target", type=float, default=162,
        help="Model size target in MB, when using DiffQ. Best model will be kept "
             "only if it is smaller than this target.")

    return parser


def get_name(parser, args):
    """
    Return the name of an experiment given the args. Some parameters are ignored,
    for instance --workers, as they do not impact the final result.
    """
    ignore_args = set([
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    ])
    parts = []
    name_args = dict(args.__dict__)
    for name, value in name_args.items():
        if name in ignore_args:
            continue
        if value != parser.get_default(name):
            if isinstance(value, Path):
                parts.append(f"{name}={value.name}")
            else:
                parts.append(f"{name}={value}")
    if parts:
        name = " ".join(parts)
    else:
        name = "default"
    return name
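

# Illustrative usage sketch, not part of the upstream Demucs parser module: a rough
# guess at how a training entry point might consume get_parser() and get_name(). The
# log-folder layout below (args.logs / name) is an assumption for illustration only.
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    # get_name() keeps only the arguments that differ from the parser defaults, so
    # e.g. `--batch_size 16 --lr 1e-4` would yield something like
    # "batch_size=16 lr=0.0001", and no overrides at all yield "default".
    name = get_name(parser, args)
    print(f"Experiment name: {name}")
    print(f"Logs would go under: {args.logs / name}")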