多进程最佳实践

优质
小牛编辑
134浏览
2023-12-01

译者:cvley

torch.multiprocessing 是 Python 的 multiprocessing 的直接替代模块。它支持完全相同的操作,但进行了扩展,这样所有的张量就可以通过一个 multiprocessing.Queue 进行传递,将数据移动到共享内存并只将句柄传递到另一个进程。

注意

当一个 Tensor 传递到另一个进程时,Tensor 的数据是共享的。如果 torch.Tensor.grad 不是 None, 也会被共享。在一个没有 torch.Tensor.grad 域的 Tensor 被送到其他进程时,一个标准的进程专用的 .grad Tensor 会被创建,而它在所有的进程中不会自动被共享,与 Tensor 数据的共享方式不同。

这就允许实现各种训练方法,比如 Hogwild、A3C,或者其他那些需要异步操作的方法。

共享 CUDA 张量

进程间共享 CUDA 张量仅支持 Python 3,使用的是 spawn 或者 forkserver 启动方法。Python 2 中的 multiprocessing 仅使用 fork 来创建子进程,而 CUDA 运行时不支持该方法。

警告

CUDA API 需要分配给其他进程的显存在它们还在使用的情况下一直有效。你需要仔细确保共享的 CUDA 张量若非必须,不会超出使用范围。这对于共享模型参数不会是一个问题,但传递其他类型的数据时需要谨慎。注意该限制并不适用于共享 CPU 内存。

也可以参考:使用 nn.DataParallel 替代 multiprocessing

最佳实践和提示

避免和处理死锁

当创建一个新进程时,很多情况会发生,最常见的就是后台线程间的死锁。如果任何一个线程有锁的状态或者引入了一个模块,然后调用了fork,子进程很有可能处于中断状态,并以另外的方式死锁或者失败。注意即使你没这么做,Python 内建的库也有可能这么做——无需舍近求远,multiprocessing即是如此。multiprocessing.Queue 实际上是一个非常复杂的类,可以创建多个线程用于串行、发送和接收对象,它们也会出现前面提到的问题。如果你发现自己遇到了这种情况,尝试使用 multiprocessing.queues.SimpleQueue,它不会使用额外的线程。

我们在尽最大努力为你化繁为简,确保不会发生死锁的情况,但有时也会出现失控的情况。如果你遇到任何暂时无法解决的问题,可以在论坛上求助,我们将会研究是否可以修复。

通过 Queue 传递重用缓存

记住每次将一个 Tensor 放进一个 multiprocessing.Queue 时,它就会被移动到共享内存中。如果它已经被共享,那将不会有操作,否则将会触发一次额外的内存拷贝,而这将会拖慢整个进程。即使你有一个进程池把数据发送到一个进程,并把缓存送回来——这近乎于无操作,在发送下一个批次的数据时避免拷贝。

异步多进程训练(如Hogwild)

使用 torch.multiprocessing,可以异步训练一个模型,参数要么一直共享,要么周期性同步。在第一个情况下,我们建议传递整个模型的对象,而对于后一种情况,我们将以仅传递 state_dict()

我们建议使用 multiprocessing.Queue在进程间传递 PyTorch 对象。当使用fork命令时,可以进行诸如继承共享内存中的张量和存储的操作,然而这个操作容易产生问题,应该小心使用,仅建议高级用户使用。Queue,尽管有时不是一个那么优雅的解决方案,但在所有的情况下都可以合理使用。

警告

你应该注意那些不在if __name__ == '__main__'中的全局声明。如果使用了一个不是fork的系统调用,它们将会在所有子进程中执行。

Hogwild

在示例仓库中可以找到一个具体的Hogwild实现,但除了完整的代码结构之外,下面也有一个简化的例子:

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()