Distributed Operation Barrier of PyTorch

Original document: https://www.yuque.com/lart/ugkv9f/gy7sva

The concept of barrier

For the concept of a barrier, refer to the introduction on Wikipedia: a synchronization barrier is a synchronization method in parallel computing. For a group of processes or threads, a barrier in a program is a point at which each thread or process must stop and wait until all threads or processes in the group have reached it, after which all of them may continue executing the code that follows.

It should be noted that the barrier operation is not unique to PyTorch. It is a basic concept in parallel computing, and the same idea appears in other parallel computing scenarios as well. This article focuses on its behavior in PyTorch.
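To make the general concept concrete before turning to PyTorch, here is a minimal sketch using Python's standard-library threading.Barrier (independent of PyTorch): three threads arrive at different times, and none of them proceeds until all three have arrived.

import threading
import time

barrier = threading.Barrier(3)  # lifted only once 3 threads are waiting

def worker(i):
    time.sleep(i)  # the threads reach the barrier at different times
    print(f"thread {i} reached the barrier")
    barrier.wait()  # block until all 3 threads have called wait()
    print(f"thread {i} passed the barrier")

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()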

torch.distributed.barrier(group=None, async_op=False, device_ids=None)

Synchronizes all processes.

This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait().

Parameters
group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
async_op (bool, optional) – Whether this op should be an async op
device_ids ([int], optional) – List of device/GPU ids. Valid only for NCCL backend.

Returns
Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
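A minimal sketch of the async_op parameter, assuming a process group has already been initialized via dist.init_process_group:

import torch.distributed as dist

work = dist.barrier(async_op=True)  # returns a work handle immediately
# ... other work that does not depend on the synchronization point ...
work.wait()  # blocks until every process in the group has entered the barrier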

During multi-GPU training, different GPUs usually correspond to different processes. Sometimes we want one process to perform a task alone while holding back the progress of the others, and barrier is the tool for this.
A practical scenario is dataset preparation: only process 0 needs to do the processing, and the other processes do not perform this task, but their subsequent work depends on the prepared data. Therefore, the other processes should be blocked while process 0 works, entering a waiting state, and released together once the preparation is done.

Under this requirement, a typical construction based on a context manager is as follows:

# https://github.com/ultralytics/yolov5/blob/7d56d451241e94cd9dbe4fcb9bfba0e92c6e0e23/utils/torch_utils.py#L29-L38
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that makes all processes in distributed training
    wait for the local master (rank 0) to do something first.
    """
    if local_rank not in [-1, 0]:
        # Non-master processes block here until rank 0 reaches its barrier.
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        # Rank 0 arrives last; this matching barrier releases the others.
        dist.barrier(device_ids=[0])
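A hypothetical usage sketch (local_rank, create_dataset, and data_dir are placeholders, not part of the original snippet): only rank 0 performs the one-time preparation, while the other ranks wait inside the context manager and then load the cached result.

with torch_distributed_zero_first(local_rank):
    # Non-zero ranks block at the first barrier inside the context manager;
    # rank 0 runs the body, and its own barrier call then releases everyone.
    dataset = create_dataset(data_dir)  # placeholder for one-time preparation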

Details about barriers

# -*- coding: utf-8 -*-

import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_test_v0(local_rank, world_size):
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def main():
    world_size = 2
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    # Swap ddp_test_v0 for ddp_test_v1 to run the second example.
    mp.spawn(ddp_test_v0, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()

Two examples are shown here. In fact, dist.barrier works by matching calls across processes: each call blocks its process until every process in the group has entered a barrier, at which point all blocked processes are released and resume normal operation.
Let's look at the first example:

def ddp_test(local_rank, world_size):
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

Its output is:

first before barrier1
first before barrier0


first after barrier0

inter 0
second before barrier0

second after barrier0

0 exit
first after barrier1

inter 1
second before barrier1

second after barrier1

1 exit

Process finished with exit code 0

As you can see, there are several details:

  • Up to the first barrier, both GPU processes produce output (first before barrier0 and first before barrier1).

    • Since local_rank=0 skips the first barrier and only blocks at the one in its own branch, it continues and prints several more lines, while local_rank=1 blocks right away and prints only first before barrier1.
  • At second before barrier0, process 0 reaches the barrier in its branch; this matches the barrier that process 1 is waiting in, so process 1 is no longer blocked and resumes normal execution. Due to the time taken by the intermediate operations, process 0 first prints its own second after barrier0 and exits, and then process 1 starts printing its own results.

It is worth noting here that the barrier calls in different processes correspond to one another: every process must execute barrier once before all of them are released and can proceed normally.
For the second piece of code:

def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

The output is:

1 before barrier1
0 before barrier0


5.002117395401001
5.0021262168884281 after barrier1


0 after barrier0

0 after barrier0

0 after barrier0

0 exit
1 after barrier1

1 exit

Process finished with exit code 0

An important point can be seen: the two values printed by print(time.time() - start) are essentially identical. No matter how long each process is delayed beforehand, a barrier is lifted only when the last process arrives, so the measured wait is determined by the longest interval between the matching barrier calls. This further reflects the mutual constraint between barriers. After process 0 reaches its second barrier, process 1 is released to run again; this time, however, process 0 is the first to finish.
In addition, it can be verified that if one of the matching barrier calls is removed from one branch of the code, the other process will be caught in an infinite wait.
For example:

def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        # dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(3)
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

output:

0 before barrier0
1 before barrier1


5.002458572387695
1 after barrier1

1 after barrier1

1 exit
5.002473831176758
0 after barrier0

0 after barrier0

Traceback (most recent call last):
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 67, in <module>
    main()
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 63, in main
    mp.spawn(ddp_test_v1, args=(world_size,), nprocs=world_size, join=True)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 75, in join
    ready = multiprocessing.connection.wait(
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

Process 0 waits indefinitely at its second barrier until it is killed manually (hence the KeyboardInterrupt and SIGKILL in the output above).
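As a hedged aside, not from the original article: the timeout argument of init_process_group can turn such a silent hang into an explicit error. Whether the timeout is actually enforced with the NCCL backend depends on the PyTorch version and environment variables such as NCCL_BLOCKING_WAIT, so treat this as an assumption to verify.

import datetime

import torch.distributed as dist

# Assumption: with NCCL, enforcing this timeout may additionally require
# NCCL_BLOCKING_WAIT=1 (or async error handling in newer versions).
dist.init_process_group(
    backend="nccl",
    world_size=world_size,
    rank=local_rank,
    timeout=datetime.timedelta(seconds=60),
)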
The blocking behavior of barrier itself is also described in this answer:

  • when a process encounters a barrier it will block
  • the position of the barrier is not important (not all processes have to enter the same if-statement, for instance)
  • a process is blocked by a barrier until all processes have encountered a barrier, upon which the barrier is lifted for all processes

https://stackoverflow.com/a/59766443
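A minimal sketch of the second point (prepare_data and load_prepared_data are hypothetical placeholders): the two dist.barrier calls sit in different branches of an if-statement, yet they still match each other.

# The barrier calls match across processes even though they appear in
# different branches; a barrier is lifted once every process in the
# group has entered *some* barrier call.
if local_rank == 0:
    prepare_data()        # hypothetical rank-0-only work
    dist.barrier()        # matches the barrier in the else-branch
else:
    dist.barrier()        # blocks until rank 0 arrives at its barrier
    load_prepared_data()  # hypothetical; preparation is guaranteed done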
