1
0
mirror of synced 2024-11-28 01:10:56 +01:00

bugfix: leaked semaphore error (#309)

* use config for n_cpu

* rm import

* fix process loop

* unuse mp.spawn

ref. https://discuss.pytorch.org/t/how-to-fix-a-sigsegv-in-pytorch-when-using-distributed-training-e-g-ddp/113518/10

* fix commentout
This commit is contained in:
N. Hiroto 2023-05-19 18:56:06 +09:00 committed by GitHub
parent 563c64ded9
commit 080b7cdc31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 17 deletions

View File

@ -154,7 +154,7 @@ if __name__ == "__main__":
f0method,
),
)
p.start()
ps.append(p)
for p in ps:
p.join()
p.start()
for i in range(n_p):
ps[i].join()

View File

@ -1,7 +1,6 @@
import torch, os, traceback, sys, warnings, shutil, numpy as np
os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
from multiprocessing import cpu_count
import threading
from time import sleep
from subprocess import Popen
@ -25,7 +24,6 @@ import ffmpeg
i18n = I18nAuto()
# 判断是否有能用来训练和加速推理的N卡
ncpu = cpu_count()
ngpu = torch.cuda.device_count()
gpu_infos = []
mem = []
@ -1436,10 +1434,10 @@ with gr.Blocks() as app:
)
np7 = gr.Slider(
minimum=0,
maximum=ncpu,
maximum=config.n_cpu,
step=1,
label=i18n("提取音高和处理数据使用的CPU进程数"),
value=ncpu,
value=config.n_cpu,
interactive=True,
)
with gr.Group(): # 暂时单人的, 后面支持最多4人的#数据处理

View File

@ -66,18 +66,22 @@ class EpochRecorder:
def main():
# n_gpus = torch.cuda.device_count()
n_gpus = torch.cuda.device_count()
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "51545"
mp.spawn(
run,
nprocs=n_gpus,
args=(
children = []
for i in range(n_gpus):
subproc = mp.Process(target=run, args=(
i,
n_gpus,
hps,
),
)
))
children.append(subproc)
subproc.start()
for i in range(n_gpus):
children[i].join()
def run(rank, n_gpus, hps):

View File

@ -115,10 +115,10 @@ class PreProcess:
p = multiprocessing.Process(
target=self.pipeline_mp, args=(infos[i::n_p],)
)
p.start()
ps.append(p)
for p in ps:
p.join()
p.start()
for i in range(n_p):
ps[i].join()
except:
println("Fail. %s" % traceback.format_exc())