# Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git
# (synced 2024-11-24 15:30:11 +01:00; 273 lines, 9.4 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from fractions import Fraction
|
|
|
|
import torch as th
|
|
from torch import distributed, nn
|
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
from .augment import FlipChannels, FlipSign, Remix, Shift
|
|
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
|
|
from .model import Demucs
|
|
from .parser import get_name, get_parser
|
|
from .raw import Rawset
|
|
from .tasnet import ConvTasNet
|
|
from .test import evaluate
|
|
from .train import train_model, validate_model
|
|
from .utils import human_seconds, load_model, save_model, sizeof_fmt
|
|
|
|
|
|
@dataclass
class SavedState:
    """Everything needed to resume training, stored in the checkpoint with ``th.save``.

    Instances are pickled whole, so renaming or reordering the fields would
    prevent existing checkpoints from loading.
    """

    # One dict per finished epoch (main() writes keys "train", "valid", "best", "duration").
    metrics: list = field(default_factory=list)
    # Model state_dict after the latest epoch; None before any epoch ran.
    last_state: dict = field(default=None)
    # state_dict with the lowest validation loss so far; None until one is recorded.
    best_state: dict = field(default=None)
    # Optimizer state_dict after the latest epoch; None before any epoch ran.
    optimizer: dict = field(default=None)
|
|
|
|
|
|
def _select_device(args):
    """Return the device string to run on.

    An explicit ``args.device`` wins; otherwise ``"cuda"`` is used when it is
    available, falling back to ``"cpu"``.
    """
    if args.device is not None:
        return args.device
    return "cuda" if th.cuda.is_available() else "cpu"


def _build_model(args):
    """Create the model selected on the command line (not yet moved to a device).

    ``args.test`` loads a saved model from ``args.models``; ``args.tasnet`` builds a
    ConvTasNet; otherwise a 4-source Demucs is built from the CLI hyper-parameters.
    """
    if args.test:
        return load_model(args.models / args.test)
    if args.tasnet:
        return ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
    return Demucs(
        audio_channels=args.audio_channels,
        channels=args.channels,
        context=args.context,
        depth=args.depth,
        glu=args.glu,
        growth=args.growth,
        kernel_size=args.kernel_size,
        lstm_layers=args.lstm_layers,
        rescale=args.rescale,
        rewrite=args.rewrite,
        sources=4,
        stride=args.conv_stride,
        upsample=args.upsample,
        samplerate=args.samplerate,
    )


def _build_datasets(args, samples):
    """Return ``(train_set, valid_set)`` built from raw audio or from MusDB.

    ``samples`` is the model-adjusted training length. Training excerpts are
    ``samples + args.data_stride`` long; the extra room is presumably consumed
    by the ``Shift(args.data_stride)`` augmentation (TODO confirm).
    """
    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples + args.data_stride,
                           channels=args.audio_channels,
                           streams=[0, 1, 2, 3, 4],
                           stride=args.data_stride)
        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
        return train_set, valid_set

    # Only rank 0 builds the metadata file; every rank then waits at the barrier
    # so that nobody reads it before it has been written.
    if not args.metadata.is_file() and args.rank == 0:
        build_musdb_metadata(args.metadata, args.musdb, args.workers)
    if args.world_size > 1:
        distributed.barrier()
    with open(args.metadata) as metadata_file:
        metadata = json.load(metadata_file)

    # StemsSet takes durations/strides in seconds, kept as exact fractions.
    duration = Fraction(samples + args.data_stride, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set


def main():
    """Command line entry point: train a separation model, then evaluate and export it.

    Parses arguments with ``get_parser``. Training resumes from
    ``<checkpoints>/<name>.th`` when that file exists. It runs until
    ``args.epochs`` epochs are recorded, keeping the state with the lowest
    validation loss. That best state is then evaluated and exported to
    ``<models>/<name>.th``. ``args.show`` only prints the model, and
    ``args.save_model`` only exports the best state of an existing checkpoint.
    """
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None:
        # Exit on every rank so workers don't continue without a dataset;
        # only rank 0 reports, to avoid duplicated messages.
        if args.rank == 0:
            print(
                "You must provide the path to the MusDB dataset with the --musdb flag. "
                "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
                file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    device = _select_device(args)

    th.manual_seed(args.seed)
    # Prevents too many threads to be started when running `museval` as it can be quite
    # inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda":
            if args.rank == 0:
                print("Error: distributed training is only available with cuda device", file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists():
        checkpoint.unlink()

    if args.test:
        # Test mode: run a single epoch with repeat=0.
        args.epochs = 1
        args.repeat = 0
    model = _build_model(args)
    model.to(device)
    if args.show:
        print(model)
        # Size estimate assumes 4 bytes per parameter.
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    # NOTE(review): the checkpoint is a pickled SavedState. Newer torch releases changed
    # th.load's `weights_only` default, which may reject it; confirm against the pinned torch.
    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        saved = SavedState()
    else:
        model.load_state_dict(saved.last_state)
        optimizer.load_state_dict(saved.optimizer)

    if args.save_model:
        if args.rank == 0:
            if saved.best_state is None:
                print(f"Error: no trained state found in {checkpoint}, nothing to export.",
                      file=sys.stderr)
                sys.exit(1)
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, args.models / f"{name}.th")
        return

    # Marker file written once the whole run has finished; cleared at start-up.
    done = args.logs / f"{name}.done"
    if args.rank == 0 and done.exists():
        done.unlink()

    if args.augment:
        augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
                                Remix(group_size=args.remix_group_size)).to(device)
    else:
        augment = Shift(args.data_stride)

    criterion = nn.MSELoss() if args.mse else nn.L1Loss()

    # Setting number of samples so that all convolution windows are full.
    # Prevents hard to debug mistake with the prediction being shifted compared
    # to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    train_set, valid_set = _build_datasets(args, samples)

    # Replay the history of a resumed run and recover the best validation loss.
    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model,
                                         device_ids=[th.cuda.current_device()],
                                         output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(epoch,
                                 train_set,
                                 dmodel,
                                 criterion,
                                 optimizer,
                                 augment,
                                 batch_size=args.batch_size,
                                 device=device,
                                 repeat=args.repeat,
                                 seed=args.seed,
                                 workers=args.workers,
                                 world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch,
                                    valid_set,
                                    model,
                                    criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            # Copy to CPU so the best state is not overwritten by later updates.
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }
        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration
        })
        if args.rank == 0:
            with open(metrics_path, "w") as metrics_file:
                json.dump(saved.metrics, metrics_file)

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            # Write to a temporary file, then swap it in. Path.replace overwrites an
            # existing checkpoint on every platform (Path.rename fails on Windows).
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.replace(checkpoint)

        print(f"Epoch {epoch:03d}: "
              f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
              f"duration={human_seconds(duration)}")

    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model,
             args.musdb,
             eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             workers=args.eval_workers)
    model.to("cpu")
    if args.rank == 0:
        # Like the checkpoints, the exported model is written by rank 0 only, so
        # several ranks never write the same file concurrently.
        save_model(model, args.models / f"{name}.th")
        print("done")
        done.write_text("done")
|
|
|
|
|
|
# Script entry point. The module uses relative imports, so it is presumably
# launched as part of its package (e.g. `python -m <package>`); confirm.
if __name__ == "__main__":
    main()
|