[BUG] deepspeed overlap_comm data race #5545

Open
yangyihang-bytedance opened this issue May 18, 2024 · 1 comment
Labels: bug, training

Comments

@yangyihang-bytedance

Describe the bug
As illustrated below, DeepSpeed's overlap buffer design presents a potential data race.
I have written a patch as a bugfix.

Could you kindly help diagnose and fix this issue?

(attached image: whiteboard_exported_image)
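
As a minimal, self-contained illustration of the race pattern in the diagram (illustrative code only, not DeepSpeed internals; the buffer size and the sum() standing in for the all-reduce are placeholders), the following runs on any CUDA machine:

import torch

# two CUDA streams, as in ZeRO stage 1/2 with overlap_comm enabled
default_stream = torch.cuda.current_stream()
reduction_stream = torch.cuda.Stream()

# a double buffer, analogous to self.ipg_buffer
buffers = [torch.empty(1 << 20, device="cuda") for _ in range(2)]
index = 0

for step in range(4):
    # the default (compute) stream fills the current buffer
    buffers[index].normal_()

    # the reduction stream waits for the compute stream -- this wait already exists
    reduction_stream.wait_stream(default_stream)
    with torch.cuda.stream(reduction_stream):
        reduced = buffers[index].sum()  # stands in for the all-reduce / reduce-scatter

    index = 1 - index
    # nothing makes the default stream wait for the reduction stream here, so two
    # iterations later buffers[index] can be overwritten while the reduction stream
    # may still be reading it; uncommenting the next line removes the race:
    # default_stream.wait_stream(reduction_stream)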

To Reproduce

debug.py

import argparse
import deepspeed.runtime.zero.stage_1_and_2
import torch
import torch.nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import set_seed
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import is_moe_param

def patch_deepspeed():
    def backward(self, loss, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        self.micro_step_id += 1

        if self.contiguous_gradients:
            self.ipg_buffer = []
            self.ipg_events = []  # added code (yihang): one event slot per buffer
            buf_0 = torch.empty(int(self.reduce_bucket_size),
                                dtype=self.dtype,
                                device=get_accelerator().current_device_name())
            self.ipg_buffer.append(buf_0)
            self.ipg_events.append(None)  # added code (yihang): event slot for buf_0

            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
            if self.overlap_comm:
                buf_1 = torch.empty(int(self.reduce_bucket_size),
                                    dtype=self.dtype,
                                    device=get_accelerator().current_device_name())
                self.ipg_buffer.append(buf_1)
                self.ipg_events.append(None)  # added code (yihang): event slot for buf_1

            self.ipg_index = 0

        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

        # Only for Stage 1, Mode 2
        if self.use_grad_accum_attribute:
            self.fill_grad_accum_attribute()

    def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):

        grad_reduc = self.get_gradient_for_reduction(param)
        if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
            self.reduce_ipg_grads()

            # added code (yihang): record an event on the buffer that was just reduced
            if self.contiguous_gradients and self.overlap_comm and not get_accelerator().is_synchronized_device():
                with get_accelerator().stream(self.reduction_stream):
                    current_event = torch.cuda.Event()  # yihang: create an event on the reduction_stream
                    current_event.record(self.reduction_stream)

                    self.ipg_events[self.ipg_index] = current_event

            if self.contiguous_gradients and self.overlap_comm:
                # Swap ipg_index between 0 and 1
                self.ipg_index = 1 - self.ipg_index

            # added code (yihang): before reusing this buffer, make the current
            # (default) stream wait until the previous reduction on it has finished
            prev_event = self.ipg_events[self.ipg_index]
            if prev_event is not None:
                prev_event.wait()

            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())

        param_id = self.get_param_id(param)
        assert self.params_already_reduced[param_id] == False, \
            f"The parameter {param_id} has already been reduced. \
            Gradient computed twice for this partition. \
            Multiple gradient reduction is currently not supported"

        if self.contiguous_gradients:
            if param.numel() > self.reduce_bucket_size:
                self.extra_large_param_to_reduce = param
            else:
                # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
                new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
                new_grad_tensor.copy_(grad_reduc.view(-1))
                grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)

        self.elements_in_ipg_bucket += param.numel()

        assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

        self.grads_in_ipg_bucket.append(grad_reduc)
        self.params_in_ipg_bucket.append((i, param, param_id))

        #make sure the average tensor function knows how to average the gradients
        if is_moe_param(param):
            self.ipg_bucket_has_moe_params = True

        self.report_ipg_memory_usage("End ipg_remove_grads", 0)

    deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.backward = backward
    deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.reduce_independent_p_g_buckets_and_remove_grads = reduce_independent_p_g_buckets_and_remove_grads

class DummyModule(torch.nn.Module):
    def __init__(self, n_layer=16, hidden_size=1024, vocab_size=1024) -> None:
        super(DummyModule, self).__init__()

        self._vocab_size = vocab_size

        self.embs = torch.nn.ModuleList([
            torch.nn.Embedding(vocab_size, hidden_size) for _ in range(n_layer)
        ])

    def init_weights(self):
        for emb in self.embs:
            emb.weight.data.normal_(mean=0.0, std=0.0002)
            if emb.padding_idx is not None:
                emb.weight.data[emb.padding_idx].zero_()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        idx = torch.arange(self._vocab_size, device=input.device, dtype=torch.long)  # arange avoids the deprecated torch.range

        hidden_states = input
        for emb in self.embs:
            hidden_states = hidden_states + emb(idx)

        return hidden_states.mean()

def train():
    torch.use_deterministic_algorithms(True)

    accelerator = Accelerator(project_dir="./outputs")

    device = accelerator.device
    dtype = torch.float16

    set_seed(42 + accelerator.process_index)

    model = DummyModule()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model, optimizer = accelerator.prepare(model, optimizer)

    for global_step in range(20):
        input = torch.rand(1, dtype=dtype, device=device)
        label = torch.rand(1, dtype=dtype, device=device)

        optimizer.zero_grad()
        output = model(input)
        
        loss = F.mse_loss(output.float(), label.float())

        accelerator.backward(loss)
        optimizer.step()

        global_grad_norm = -100.0
        if hasattr(optimizer.optimizer, '_global_grad_norm'):
            global_grad_norm = optimizer.optimizer._global_grad_norm

        if accelerator.process_index == 0:
            print(f'rank [{accelerator.process_index}] global_step [{global_step}] loss [{loss.item():.9}] grad_norm [{global_grad_norm:.9}]')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--patch_deepspeed", action='store_true', help="patch deepspeed")
    args = parser.parse_args()

    if args.patch_deepspeed:
        patch_deepspeed()

    train()

if __name__ == '__main__':
    main()

Unexpected behavior

System info (please complete the following information):

  • GPU count and types: one machine with 2x A100s

Launcher context

dp_zero1_fp16.yaml

{
    "train_micro_batch_size_per_gpu": 6,
    "steps_per_print": 100,
    "prescale_gradients": false,
    "zero_allow_untested_optimizer": true,
    "gradient_accumulation_steps": "auto",
    "bf16": {
        "enabled": false
    },
    "fp16": {
        "enabled": true
    },
    "wall_clock_breakdown": false,
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 1,
        "allgather_partitions": true,
        "reduce_scatter": true,
        "allgather_bucket_size": 1e8,
        "reduce_bucket_size": 1048576,
        "stage3_max_reuse_distance": 2e9,
        "overlap_comm": true,
        "contiguous_gradients": true
    }
}

debug.sh

#!/bin/bash

set -ex

num_processes=2

# enable PyTorch's CUDA Stream Sanitizer (CSAN)
export TORCH_CUDA_SANITIZER=1

# main_host / main_port are expected to be set in the environment
accelerate launch --main_process_ip $main_host --main_process_port $main_port \
    --num_machines 1 --machine_rank 0 --num_processes $num_processes \
    --use_deepspeed --deepspeed_config_file dp_zero1_fp16.yaml --deepspeed_multinode_launcher standard \
    debug.py

Additional context
A similar issue is present in ZeRO stage 3 as well, but I have not prepared a patch for it.

@yangyihang-bytedance added the bug and training labels on May 18, 2024
@yangyihang-bytedance
Author

Additional context

deepspeed==0.14.2
CSAN detected a possible data race on tensor with data pointer 140108946735104
Access by stream 0 during kernel:
aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
writing to argument(s) self, and to the output
With stack trace:
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 949, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 570, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 371, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(

Previous access by stream 152815408 during kernel:
aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
writing to argument(s) self, and to the output
With stack trace:
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 932, in reduce_independent_p_g_buckets_and_remove_grads
    self.reduce_ipg_grads()
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1367, in reduce_ipg_grads
    self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1127, in average_tensor
    self.allreduce_and_scatter(buckets[bucket_key],
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1031, in allreduce_and_scatter
    self.allreduce_and_copy_with_multiple_ranks(small_bucket,
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1005, in allreduce_and_copy_with_multiple_ranks
    for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks):
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 534, in _unflatten_dense_tensors
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 570, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 371, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(
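
For reference, a trace like the one above can also be produced by enabling the sanitizer in code rather than via the TORCH_CUDA_SANITIZER environment variable used in debug.sh (this is standard PyTorch functionality, not DeepSpeed-specific):

import torch.cuda._sanitizer as csan

# equivalent to exporting TORCH_CUDA_SANITIZER=1 before launching the script
csan.enable_cuda_sanitizer()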

github-merge-queue bot pushed a commit that referenced this issue Jun 10, 2024
`deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.average_tensor`
only makes the reduction stream wait for the default stream. This is fine
when the computation time is longer than the communication time, but when
the communication time is longer, the ipg_buffer may be rewritten before
the communication has completed.



![image](https://github.com/microsoft/DeepSpeed/assets/35059704/950cbf8a-f439-4cf9-a364-dcdfd47f46a0)



The easiest fix is to also make the default stream wait for the
reduction stream at the **same point**. For example, at point 1 the
`reduction stream` needs to wait for '2', so a wait_stream is added to the
`reduction stream` waiting for the `default stream`. Likewise, the
`default stream` needs to wait for 'A', so a wait_stream is added to the
`default stream` waiting for the `reduction stream` before 'B'.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/588a9469-d3f9-4c39-976d-3ae0502cf1d1)



Compared with the modification of
#5523, wait_stream does not
cause host synchronization.

Compared with the modification of
#5545, this change is simpler while the logic is the same: each stream
simply waits for what it needs to wait for.
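
A minimal sketch of the symmetric-wait idea described here (not the merged diff; the helper name, its signature, and the use of get_accelerator().default_stream() are assumptions based on the attributes that appear elsewhere in this issue):

from deepspeed.accelerator import get_accelerator

def swap_ipg_buffer_with_symmetric_waits(reduction_stream, ipg_index):
    """Hypothetical helper illustrating the pattern; not the merged patch."""
    default_stream = get_accelerator().default_stream()

    # existing direction: the reduction stream must not start reading the
    # bucket before the default stream has finished writing it
    reduction_stream.wait_stream(default_stream)

    # ... all-reduce / reduce-scatter work is launched on reduction_stream ...

    # the added direction: before the default stream reuses the other half
    # of the double buffer, it waits for the reduction stream to finish
    default_stream.wait_stream(reduction_stream)

    # only now is it safe to swap the buffer index
    return 1 - ipg_index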

---

With this modification, the losses of Qwen-1.5 with and without overlap_comm
are identical.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/4d48d54e-e55b-4230-8b99-93549910a43f)

---

Without the modification, by contrast, there is an obvious gap at small
sequence lengths, i.e. short computation times.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/c80af498-3358-4e36-9b13-8f266551d51d)

Co-authored-by: gp513 <guopeng34@huawei.com>
Co-authored-by: CurryRice233 <nmeia@qq.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this issue Jun 10, 2024