[BUG] Version >0.14.0 leads to RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #5538

Closed
pacman100 opened this issue May 15, 2024 · 15 comments
Labels: bug (Something isn't working), training

@pacman100
Contributor

pacman100 commented May 15, 2024

Describe the bug
[BUG] Version >0.14.0 leads to RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

To Reproduce
Steps to reproduce the behavior:

  1. Self hosted runner failures: https://github.com/huggingface/transformers/actions/runs/9089085470/job/24979681853
  2. A test from it for quick reproduction:
cd transformers
export CUDA_VISIBLE_DEVICES=0,1
export RUN_SLOW=true
pytest -sv tests/deepspeed/test_deepspeed.py::TrainerIntegrationDeepSpeed::test_can_resume_training_normal_zero3_fp16_ds_optim_ds_scheduler

output error trace:

_ TrainerIntegrationDeepSpeed.test_can_resume_training_normal_zero3_fp16_ds_optim_ds_scheduler _

a = (<test_deepspeed.TrainerIntegrationDeepSpeed testMethod=test_can_resume_training_normal_zero3_fp16_ds_optim_ds_scheduler>,)
kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

/usr/local/lib/python3.8/dist-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/deepspeed/test_deepspeed.py:825: in test_can_resume_training_normal
    trainer.train()
src/transformers/trainer.py:1885: in train
    return inner_training_loop(
src/transformers/trainer.py:2216: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
src/transformers/trainer.py:3250: in training_step
    self.accelerator.backward(loss)
/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py:2117: in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
/usr/local/lib/python3.8/dist-packages/accelerate/utils/deepspeed.py:175: in backward
    self.engine.step()
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py:2169: in step
    self._take_model_step(lr_kwargs)
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py:2075: in _take_model_step
    self.optimizer.step()
/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/stage3.py:2047: in step
    self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <deepspeed.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3 object at 0x7fed3aa26490>
sub_group_id = 0, total_norm = tensor(12.0575, device='cuda:0')

    @instrument_w_nvtx
    def unscale_and_clip_grads(self, sub_group_id, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale
    
>       self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/stage3.py:2117: RuntimeError
----------------------------- Captured stdout call -----------------------------
[2024-05-15 02:35:40,119] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.14.2, git-hash=unknown, git-branch=unknown
[2024-05-15 02:35:40,121] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2024-05-15 02:35:41,676] [INFO] [logging.py:96:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2024-05-15 02:35:41,676] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2024-05-15 02:35:41,676] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2024-05-15 02:35:41,677] [INFO] [utils.py:56:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2024-05-15 02:35:41,677] [INFO] [logging.py:96:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer, MiCS is enabled False, Hierarchical params gather False
[2024-05-15 02:35:41,677] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 3 optimizer
Adam Optimizer #9 is created with AVX512 arithmetic capability.
Config: alpha=0.100000, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2024-05-15 02:35:41,860] [INFO] [utils.py:779:see_memory_usage] Stage 3 initialize beginning
[2024-05-15 02:35:41,861] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.75 GB         CA 0.0 GB         Max_CA 1 GB 
[2024-05-15 02:35:41,861] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:41,862] [INFO] [stage3.py:130:__init__] Reduce bucket size 1
[2024-05-15 02:35:41,862] [INFO] [stage3.py:131:__init__] Prefetch bucket size 0
[2024-05-15 02:35:41,999] [INFO] [utils.py:779:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
[2024-05-15 02:35:42,000] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,000] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
Parameter Offload: Total persistent parameters: 2 in 2 params
[2024-05-15 02:35:42,131] [INFO] [utils.py:779:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
[2024-05-15 02:35:42,131] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,132] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,270] [INFO] [utils.py:779:see_memory_usage] Before creating fp16 partitions
[2024-05-15 02:35:42,270] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,271] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,403] [INFO] [utils.py:779:see_memory_usage] After creating fp16 partitions: 1
[2024-05-15 02:35:42,404] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,404] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,543] [INFO] [utils.py:779:see_memory_usage] Before creating fp32 partitions
[2024-05-15 02:35:42,544] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,544] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,675] [INFO] [utils.py:779:see_memory_usage] After creating fp32 partitions
[2024-05-15 02:35:42,676] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,676] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,805] [INFO] [utils.py:779:see_memory_usage] Before initializing optimizer states
[2024-05-15 02:35:42,806] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,806] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,941] [INFO] [utils.py:779:see_memory_usage] After initializing optimizer states
[2024-05-15 02:35:42,942] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:42,942] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:42,942] [INFO] [stage3.py:486:_setup_for_real_optimizer] optimizer state initialized
[2024-05-15 02:35:43,073] [INFO] [utils.py:779:see_memory_usage] After initializing ZeRO optimizer
[2024-05-15 02:35:43,074] [INFO] [utils.py:780:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB 
[2024-05-15 02:35:43,074] [INFO] [utils.py:787:see_memory_usage] CPU Virtual Memory:  used = 13.56 GB, percent = 21.9%
[2024-05-15 02:35:43,074] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = adamw
[2024-05-15 02:35:43,074] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using configured LR scheduler = WarmupLR
[2024-05-15 02:35:43,074] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = <deepspeed.runtime.lr_schedules.WarmupLR object at 0x7fed3c460220>
[2024-05-15 02:35:43,074] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[0.1], mom=[[0.9, 0.999]]
[2024-05-15 02:35:43,074] [INFO] [config.py:996:print] DeepSpeedEngine configuration:
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   activation_checkpointing_config  {
    "partition_activations": false, 
    "contiguous_memory_optimization": false, 
    "cpu_checkpointing": false, 
    "number_checkpoints": null, 
    "synchronize_checkpoint_boundary": false, 
    "profile": false
}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   amp_enabled .................. False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   amp_params ................... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   autotuning_config ............ {
    "enabled": false, 
    "start_step": null, 
    "end_step": null, 
    "metric_path": null, 
    "arg_mappings": null, 
    "metric": "throughput", 
    "model_info": null, 
    "results_dir": "autotuning_results", 
    "exps_dir": "autotuning_exps", 
    "overwrite": true, 
    "fast": true, 
    "start_profile_step": 3, 
    "end_profile_step": 5, 
    "tuner_type": "gridsearch", 
    "tuner_early_stopping": 5, 
    "tuner_num_trials": 50, 
    "model_info_path": null, 
    "mp_size": 1, 
    "max_train_batch_size": null, 
    "min_train_batch_size": 1, 
    "max_train_micro_batch_size_per_gpu": 1.024000e+03, 
    "min_train_micro_batch_size_per_gpu": 1, 
    "num_tuning_micro_batch_sizes": 3
}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   bfloat16_enabled ............. False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   bfloat16_immediate_grad_update  False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   checkpoint_parallel_write_pipeline  False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   checkpoint_tag_validation_enabled  True
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   checkpoint_tag_validation_fail  False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   comms_config ................. <deepspeed.comm.config.DeepSpeedCommsConfig object at 0x7fed3a9bfd60>
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   communication_data_type ...... None
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   compile_config ............... enabled=False backend='inductor' kwargs={}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   curriculum_enabled_legacy .... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   curriculum_params_legacy ..... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   data_efficiency_enabled ...... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   dataloader_drop_last ......... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   disable_allgather ............ False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   dump_state ................... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   dynamic_loss_scale_args ...... {'init_scale': 2, 'scale_window': 1000, 'delayed_shift': 2, 'consecutive_hysteresis': False, 'min_scale': 1}
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   eigenvalue_enabled ........... False
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   eigenvalue_gas_boundary_resolution  1
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   eigenvalue_layer_name ........ bert.encoder.layer
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   eigenvalue_layer_num ......... 0
[2024-05-15 02:35:43,075] [INFO] [config.py:1000:print]   eigenvalue_max_iter .......... 100
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   eigenvalue_stability ......... 1e-06
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   eigenvalue_tol ............... 0.01
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   eigenvalue_verbose ........... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   elasticity_enabled ........... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   flops_profiler_config ........ {
    "enabled": false, 
    "recompute_fwd_factor": 0.0, 
    "profile_step": 1, 
    "module_depth": -1, 
    "top_modules": 1, 
    "detailed": true, 
    "output_file": null
}
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   fp16_auto_cast ............... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   fp16_enabled ................. True
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   fp16_master_weights_and_gradients  False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   global_rank .................. 0
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   grad_accum_dtype ............. None
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   gradient_accumulation_steps .. 1
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   gradient_clipping ............ 1.0
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   gradient_predivide_factor .... 1.0
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   graph_harvesting ............. False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   initial_dynamic_scale ........ 2
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   load_universal_checkpoint .... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   loss_scale ................... 0
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   memory_breakdown ............. False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   mics_hierarchial_params_gather  False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   mics_shard_size .............. -1
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   nebula_config ................ {
    "enabled": false, 
    "persistent_storage_path": null, 
    "persistent_time_interval": 100, 
    "num_of_version_in_retention": 2, 
    "enable_nebula_load": true, 
    "load_path": null
}
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   optimizer_legacy_fusion ...... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   optimizer_name ............... adamw
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   optimizer_params ............. {'lr': 0.1, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0.0}
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0, 'pipe_partitioned': True, 'grad_partitioned': True}
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   pld_enabled .................. False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   pld_params ................... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   prescale_gradients ........... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   scheduler_name ............... WarmupLR
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   scheduler_params ............. {'warmup_min_lr': 0, 'warmup_max_lr': 0.1, 'warmup_num_steps': 0}
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   seq_parallel_communication_data_type  torch.float32
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   sparse_attention ............. None
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   sparse_gradients_enabled ..... False
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   steps_per_print .............. inf
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   train_batch_size ............. 8
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   train_micro_batch_size_per_gpu  8
[2024-05-15 02:35:43,076] [INFO] [config.py:1000:print]   use_data_before_expert_parallel_  False
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   use_node_local_storage ....... False
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   wall_clock_breakdown ......... False
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   weight_quantization_config ... None
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   world_size ................... 1
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   zero_allow_untested_optimizer  False
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   zero_config .................. stage=3 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=1 use_multi_rank_bucket_allreduce=True allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='cpu', nvme_path=None, buffer_count=5, buffer_size=100,000,000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='cpu', nvme_path=None, buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False, ratio=1.0) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=0 param_persistence_threshold=10 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False zero_hpz_partition_size=1 zero_quantized_weights=False zero_quantized_nontrainable_weights=False zero_quantized_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True pipeline_loading_checkpoint=False override_module_apply=True
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   zero_enabled ................. True
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   zero_force_ds_cpu_optimizer .. True
[2024-05-15 02:35:43,077] [INFO] [config.py:1000:print]   zero_optimization_stage ...... 3
[2024-05-15 02:35:43,077] [INFO] [config.py:986:print_user_config]   json = {
    "fp16": {
        "enabled": true, 
        "loss_scale": 0, 
        "loss_scale_window": 1000, 
        "initial_scale_power": 1, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "bf16": {
        "enabled": false
    }, 
    "optimizer": {
        "type": "AdamW", 
        "params": {
            "lr": 0.1, 
            "betas": [0.9, 0.999], 
            "eps": 1e-08, 
            "weight_decay": 0.0
        }
    }, 
    "scheduler": {
        "type": "WarmupLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_max_lr": 0.1, 
            "warmup_num_steps": 0
        }
    }, 
    "zero_optimization": {
        "stage": 3, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "offload_param": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "sub_group_size": 1.000000e+09, 
        "reduce_bucket_size": 1, 
        "stage3_prefetch_bucket_size": 0.9, 
        "stage3_param_persistence_threshold": 10, 
        "stage3_max_live_parameters": 1.000000e+09, 
        "stage3_max_reuse_distance": 1.000000e+09, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "gradient_accumulation_steps": 1, 
    "gradient_clipping": 1.0, 
    "steps_per_print": inf, 
    "train_batch_size": 8, 
    "train_micro_batch_size_per_gpu": 8, 
    "wall_clock_breakdown": false
}
----------------------------- Captured stderr call -----------------------------
PyTorch: setting up devices
torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using auto half precision backend
***** Running training *****
  Num examples = 128
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 48
  Number of trainable parameters = 2

  0%|          | 0/48 [00:00<?, ?it/s]
_ TrainerIntegrationDeepSpeed.test_can_resume_training_normal_zero3_fp16_ds_optim_hf_scheduler _

a = (<test_deepspeed.TrainerIntegrationDeepSpeed testMethod=test_can_resume_training_normal_zero3_fp16_ds_optim_hf_scheduler>,)
kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

/usr/local/lib/python3.8/dist-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/deepspeed/test_deepspeed.py:825: in test_can_resume_training_normal
    trainer.train()
src/transformers/trainer.py:1885: in train
    return inner_training_loop(
src/transformers/trainer.py:2216: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
src/transformers/trainer.py:3250: in training_step
    self.accelerator.backward(loss)
/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py:2117: in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
/usr/local/lib/python3.8/dist-packages/accelerate/utils/deepspeed.py:175: in backward
    self.engine.step()
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py:2169: in step
    self._take_model_step(lr_kwargs)
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py:2075: in _take_model_step
    self.optimizer.step()
/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/stage3.py:2047: in step
    self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <deepspeed.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3 object at 0x7fed3a9bfd60>
sub_group_id = 0, total_norm = tensor(12.0575, device='cuda:0')

    @instrument_w_nvtx
    def unscale_and_clip_grads(self, sub_group_id, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale
    
>       self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
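
For reference, the scale factor computed in `unscale_and_clip_grads` can be worked through by hand with the values from the trace and config above. The following is a plain-Python sketch of that arithmetic only, not DeepSpeed code: the helper name `combined_scale` is hypothetical, and `loss_scale = 2` is an assumption (the initial dynamic scale from `dynamic_loss_scale_args`, i.e. no overflow has changed it yet).

```python
# Plain-Python sketch of the arithmetic in unscale_and_clip_grads.
# No torch or DeepSpeed involved; combined_scale() is a hypothetical helper.

def combined_scale(total_norm, loss_scale, clip_grad):
    """Return the factor that the scaled gradients are divided by."""
    scale = loss_scale
    if clip_grad > 0.0:
        # total_norm is the norm of the *scaled* gradients (norm * loss_scale)
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            scale = clip * loss_scale
    return scale

total_norm = 12.0575  # from the traceback (a cuda:0 tensor there, a float here)
loss_scale = 2.0      # ASSUMPTION: still at init_scale from dynamic_loss_scale_args
clip_grad = 1.0       # gradient_clipping in the config

scale = combined_scale(total_norm, loss_scale, clip_grad)
unscaled_norm = total_norm / loss_scale  # gradient norm before clipping
clipped_norm = total_norm / scale        # gradient norm after dividing by scale
print(scale, unscaled_norm, clipped_norm)
```

Under these assumptions the unscaled norm is above the clipping threshold, so the gradients end up divided by roughly the scaled norm itself and the resulting norm lands just under `gradient_clipping`. In the failing run `total_norm` is a tensor on `cuda:0`, so `combined_scale` is a CUDA tensor too, while the error message indicates that the `.grad` being multiplied lives on the CPU (presumably because the optimizer is offloaded to CPU in this config).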
@loadams
Contributor

loadams commented May 15, 2024

Hi @pacman100 - thanks for making this issue here to better track it. Does this also happen with the latest changes in the master branch?

@pacman100
Contributor Author

> Hi @pacman100 - thanks for making this issue here to better track it. Does this also happen with the latest changes in the master branch?

I confirm that the test passes when using the master branch. It would be great to have a patch release if possible.

@bug-fixed

The problem also exists with Zero3-offload. It seems to lie in the parameter partitioning part of Zero3: if the model has multiple parallel modules or frozen parameters, the offload procedure cannot load the correct parameter mapping. Please take a look and fix it. Thanks.

@tjruwase
Contributor

@bug-fixed, are you able to share a repro for the zero3-offload case? Thanks!

@bug-fixed

@tjruwase please try this example (https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh) with zero3-offload. Thanks.

@jomayeri
Contributor

jomayeri commented May 20, 2024

@bug-fixed that repro does not work. Please provide a more precise single script reproduction.

@bug-fixed

@jomayeri, thanks for the response. The file needed by the script can be downloaded here: https://huggingface.co/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-13b-v1.5/tree/main.

Unfortunately, it is difficult for me to prepare a more concise script; I apologize for this. With only Llama-3, Zero3-offload works fine. But when I tested it using the script above, i.e., with a vision transformer and another simple linear module added, the problem occurred. I guess many factors may lead to the problem. Please note that the conclusion in my previous comment might be wrong because of my very limited knowledge of DeepSpeed. Here is partial error information for you to check:

**deepspeed_aio:   fstat for read failed on /lscratch/26730337/offload/param/zero_stage_3/bfloat16params/rank0/291_param.tensor.swp error = 2**

[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/engine.py", line 1582, in _configure_zero_optimizer
[cn1112:0]:    optimizer = DeepSpeedZeroOptimizer_Stage3(
[cn1112:0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/zero/stage3.py", line 362, in __init__
[cn1112:0]:    self._setup_for_real_optimizer()
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/zero/stage3.py", line 472, in _setup_for_real_optimizer
[cn1112:0]:    self._create_fp32_partitions()                                                             
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/zero/stage3.py", line 845, in _create_fp32_partitions
[cn1112:0]:    self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)                                                                                                                                      
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/zero/stage3.py", line 762, in _swap_in_sub_group_to_flat_buffer
[cn1112:0]:    param.nvme_swapper.swap_in([param], async_op=False) 
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py", line 306, in swap_in
[cn1112:0]:    swap_in_tensors(self.aio_read_handle, swap_in_buffers, swap_in_paths)
[cn1112:0]:  File "/vf/users/Panaji/anaconda3/envs/th21_ds/mypkgs/DeepSpeed_0142/deepspeed/runtime/swap_tensor/utils.py", line 21, in swap_in_tensors
[cn1112:0]:    assert (swap_handle.async_pread(buffer, path) == 0)

It shows that a parameter file was not saved to storage. I guess one possible reason is that it failed to build the correct parameter mapping.

My Zero3-offload config is:

"zero_optimization": {
  "stage": 3,
  "offload_optimizer": {
    "device": "nvme",
    "nvme_path": "/lscratch/26730337/offload/optimizer",
    "pin_memory": true,
    "ratio": 0.2,
    "buffer_count": 4,
    "fast_init": false
  },
  "offload_param": {
    "device": "nvme",
    "nvme_path": "/lscratch/26730337/offload/param",
    "pin_memory": true,
    "buffer_count": 5,
    "buffer_size": 1e9,
    "max_in_cpu": 1e9
  },
  "overlap_comm": true,
  "contiguous_gradients": true,
  "sub_group_size": 1e9,
  "reduce_bucket_size": "auto",
  "stage3_prefetch_bucket_size": 0,
  "stage3_param_persistence_threshold": "auto",
  "stage3_max_live_parameters": 1e9,
  "stage3_max_reuse_distance": 0,
  "gather_16bit_weights_on_model_save": true
},

@bug-fixed

@tjruwase I have updated my comment, please kindly check it. Thanks.

@jomayeri
Contributor

@bug-fixed Does the same thing happen when you offload to CPU?

@lihe07

lihe07 commented May 21, 2024

@jomayeri Just encountered this problem. I use CPU offloading, and here is my deepspeed config:

    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }

The specific traceback is

    ret_val = func(*args, **kwargs)  File ".../site-packages/deepspeed/runtime/zero/stage3.py", line 2117, in unscale_and_clip_grads

              ^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/deepspeed/runtime/zero/stage3.py", line 2047, in step
    self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cpu!
    self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
  File ".../site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/deepspeed/runtime/zero/stage3.py", line 2117, in unscale_and_clip_grads
    self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@lihe07

lihe07 commented May 21, 2024

I found a workaround. Just manually patching your runtime/zero/stage3.py according to PR 5461 will fix everything.

@loadams
Contributor

loadams commented May 21, 2024

> I found a workaround. Just manually patching your runtime/zero/stage3.py according to PR 5461 will fix everything.

@lihe07 - so using the latest deepspeed built from source works? You don't hit any issues with Zero stage 3?

@lihe07

lihe07 commented May 22, 2024

> so using the latest deepspeed built from source works? You don't hit any issues with Zero stage 3?

@loadams I directly modified the source in my deepspeed 0.14.2 installation, and ZeRO stage 3 is working smoothly now.

The status of the latest code should depend on Pull 5493, as it re-introduced the buggy optimization.

@bug-fixed

> @bug-fixed Does the same thing happen when you offload to CPU?

@jomayeri The machine I'm working on has very limited memory and is shared with others, so it is difficult for me to test the "device": "cpu" option.

@jomayeri
Contributor

Both this issue and #5422 refer to this line in Zero3, which appears to have been reverted to the correct state in master. If any user is having similar issues (@bug-fixed), please open a separate thread.
