最近我在做PyTorch的Dataloader相关的开发,有一个问题让我比较在意:PyTorch的Dataloader在启动多个进程读取样本的时候,这些数据是怎么在进程之间进行传输的?会不会引入多余的内存拷贝?
Dataloader使用multiprocess.Queue
来传输数据
先简单介绍一下Dataloader的多进程模式,Dataloader在构造的时候,若num_workers
不为0,就会启动num_workers
个worker进程,然后主进程会向worker进程分发读取任务,worker进程读到数据之后,再把数据放到队列中供主进程取用。Worker进程所执行的代码片段如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
|
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last)
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
data_queue.put((idx, data))
del data, idx, index, r # save memory
|
其中,data_queue
通常是一个torch.multiprocessing.Queue
的实例。
看到这里,似乎Dataloader只是平平无奇地使用了torch.multiprocessing.Queue
的接口,难道torch.multiprocessing.Queue
有一些高级的技巧?
继续看torch.multiprocessing.Queue
的代码,发现它只是简单地把multiprocessing.Queue
包了一下。众所周知,
multiprocessing.Queue
在Linux使用socket来实现,那难道读上来的数据需要在socket之间传来传去吗,效率也太低了吧?!不对,PyTorch一定还有其他骚操作。
Tensor
(CPU Tensor)在multiprocessing.Queue
中的序列化和反序列化
在通常的用法里,Dataloader从Dataset里读出来的数据都会被collate_fn
转成CPU Tensor,那我们就继续看看,Tensor是怎么在队列中序列化和反序列化的。
可以看到,torch.Tensor
重载了__reduce_ex__()
函数,序列化的时候只会用到 Tensor.storage
, Tensor.storage_offset
, size
, stride
, requires_grad
和backward_hooks
;
而反序列化的torch._utils._rebuild_tensor_v2()
也只会用到以上信息。
multiprocessing.Queue
是使用pickle
来做序列化和反序列化的,而重载__reduce_ex__()
正是自定义序列化反序列化方式的方法之一。
那看起来,CPU Tensor在进程中传输时,是在接收进程中把Tensor重新构建了一遍,而构建Tensor时候用到的信息, Tensor.storage_offset
, size
, stride
, requires_grad
和backward_hooks
,都只是用于描述Tensor的meta信息,实际和数据相关的,就只有Tensor.storage
了。
Tensor.Storage
的序列化与反序列化
Tensor.Storage
同样重载了pickle的序列化与反序列化过程,在torch/multiprocessing/reduction.py
中,给Tensor.Storage
注册了reduce函数reduce_storage
.
这里为什么使用copyreg
库而不是重载__reduce__()
, 在copyreg
的注释里说copyreg是专用于C extension的.按照这个说法,在reductions.py里为Tensor
注册的reduce function应该是没有起效的。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
def reduce_storage(storage):
from . import get_sharing_strategy
if storage.is_cuda:
raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
elif get_sharing_strategy() == 'file_system':
metadata = storage._share_filename_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
storage._shared_incref()
elif storage.size() == 0:
# This is special cased because Empty tensors
# (with size 0) cannot be mmapped.
return (rebuild_storage_empty, (type(storage),))
else:
fd, size = storage._share_fd_()
df = multiprocessing.reduction.DupFd(fd)
cache_key = fd_id(fd)
metadata = (df, size)
rebuild = rebuild_storage_fd # type: ignore[assignment]
shared_cache[cache_key] = StorageWeakRef(storage)
return (rebuild, (type(storage),) + metadata)
|
这个函数首先根据环境中的sharing strategy来决定共享内存的使用方式,然后若storage原本不在共享内存中的话,就把它拷到共享内存中去,比如_share_fd_()
的实现。
小结
看到这里,我们可以大概得出一个结论了。
Worker进程从Dataset中读出来的Tensor本身是普通的CPU Tensor,但当把它放到multiprocessing.Queue
中去的时候,这个Tensor的数据会被拷到共享内存中,Queue只会发送这个Tensor所具有的meta信息,主进程接到这些meta信息之后,就可以从共享内存中的数据重新构建Tensor。
这里有一点值得注意,如果你想要验证这个过程,在发送进程调用multiprocessing.Queue.put()
之后,立即调用Tensor.is_shared()
并不会返回True
,因为put()
是非阻塞的,只有当Tensor
被QueueFeedThread
序列化完成之后再调用is_shared()
,才会得到预期中的结果。
Dataloader的小心机
在default_collate
中,有这样一段小代码:
1
2
3
4
5
6
7
8
9
|
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
|
含义是,如果batch中的数据已经是Tensor了,那么,如果这是一个Worker进程,就开一段共享内存把这个batch放进去。因为collate的时候无论如何都会有一次内存拷贝(除非底层的Dataset有其他保证),那么这个操作就省掉了之后放进队列中的那一次内存拷贝。
不过我随便找了几个自定义了collate_fn
的模型看了一下他们写的collate过程,是没有把这一点考虑进去的。这也算是Dataloader的一个小心机吧,有缘人就用得上。