# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities."""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
    """Initialize distributed training through Dora if needed, then record
    the global `rank` and `world_size` used by the other helpers here."""
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def average(metrics, count=1.):
    """Average `metrics` across all workers. `metrics` can be a dict of floats
    or a sequence of floats; `count` weights this worker's contribution
    (e.g. its local batch size)."""
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


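# Illustrative sketch, not part of the original module: a typical way to use
# `average` at the end of a training epoch. The metric names and `batch_size`
# below are hypothetical.
def _example_average_usage(batch_size=16):
    # Each worker computes its own local metrics...
    local_metrics = {'loss': 0.42, 'accuracy': 0.91}
    # ...and `average` reduces them across workers, weighting each worker
    # by how many examples it processed.
    return average(local_metrics, count=batch_size)

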
def wrap(model):
    """Wrap `model` with DistributedDataParallel when running on more than
    one process; return it unchanged otherwise."""
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


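# Illustrative sketch (assumption, not original code): the usual order of
# operations when preparing a model for distributed training with the helpers
# above. The toy Linear model is hypothetical.
def _example_wrap_usage():
    init()
    model = torch.nn.Linear(512, 512)
    if torch.cuda.is_available():
        model = model.cuda()
    # `wrap` is a no-op on a single process and returns a
    # DistributedDataParallel wrapper otherwise.
    return wrap(model)

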
def barrier():
    """Synchronize all processes; a no-op when running on a single process."""
    if world_size > 1:
        torch.distributed.barrier()


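# Illustrative sketch (hypothetical example): a common pattern where rank 0
# writes a checkpoint and the other processes wait at the barrier so that
# nobody reads a half-written file. The path is a placeholder.
def _example_barrier_usage(model, path="checkpoint.th"):
    if rank == 0:
        torch.save(model.state_dict(), path)
    barrier()
    # All processes reach this point only after the checkpoint is on disk.

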
def share(obj=None, src=0):
    """Broadcast the picklable object `obj` from the process of rank `src`
    to every other process, and return it on all of them."""
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # `size` is now set to the length of the pickled obj on all processes.

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # `buffer` now holds the pickled obj on all processes.

    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj


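# Illustrative sketch (hypothetical example): broadcasting metadata that only
# rank 0 computes, e.g. a list of track names scanned from disk, so that every
# process works from the same list. The values are placeholders.
def _example_share_usage():
    tracks = None
    if rank == 0:
        tracks = ["track_01.wav", "track_02.wav"]
    # After `share`, every rank holds the same list.
    tracks = share(tracks, src=0)
    return tracks

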
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader that behaves correctly under distributed training.
    If a gradient is going to be computed (i.e. for training),
    you must set `shuffle=True`.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Training case (a backward pass will be computed): use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # We ignore `shuffle` here, as DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # Shard the dataset manually, as DistributedSampler would otherwise
        # replicate some examples to even out the shard sizes.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
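

# Illustrative sketch (assumption, not original code): building train and
# validation loaders with `loader`. The datasets and loader arguments are
# hypothetical.
def _example_loader_usage(train_set, valid_set):
    # Training: shuffle=True, so a DistributedSampler both shuffles and
    # shards the data across processes.
    train_loader = loader(train_set, batch_size=16, shuffle=True, num_workers=2)
    # Validation: shuffle=False, so each process evaluates a disjoint shard
    # and no example is counted twice.
    valid_loader = loader(valid_set, batch_size=16, shuffle=False, num_workers=2)
    return train_loader, valid_loader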