Universal checkpoint for zero stage 3 #5475

Open
wants to merge 18 commits into
base: master
from

Conversation

xylian86

This PR enables the universal checkpoint for zero stage 3.

Notes:

  • The current implementation supports data parallelism only.
  • Development is ongoing for universal checkpoint Stage 3 with tensor-slicing model parallelism.
  • Pipeline parallelism is not supported by ZeRO Stage 3, and hence is not included in this universal checkpoint implementation.

In this PR:

  • I've updated deepspeed/checkpoint/ds_to_universal.py to support converting Zero checkpoints into Universal checkpoints.
  • I've updated deepspeed/runtime/zero/stage3.py to enable loading Universal checkpoints using the Stage 3 optimizer.
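
For readers trying this out, the two steps above (convert, then load) can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's code: the checkpoint paths are hypothetical, the `--input_folder` / `--output_folder` flags are assumed to match the conversion script's argument parser, and `checkpoint.load_universal` is assumed to be the config key that switches the engine to the universal format. Please verify both against the DeepSpeed version you are running.

```python
import json

# Hypothetical paths, for illustration only (not taken from this PR).
ZERO_CKPT_DIR = "checkpoints/global_step100"
UNIVERSAL_CKPT_DIR = "checkpoints/global_step100_universal"


def build_convert_command(input_folder, output_folder):
    """Build the argv list for the conversion script (flag names are assumed)."""
    return [
        "python", "deepspeed/checkpoint/ds_to_universal.py",
        "--input_folder", input_folder,
        "--output_folder", output_folder,
    ]


def build_ds_config(load_universal=True):
    """Build a minimal DeepSpeed-style config dict for ZeRO stage 3."""
    return {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {"stage": 3},
        # Assumed key: tells the engine to read the universal format on load.
        "checkpoint": {"load_universal": load_universal},
    }


convert_cmd = build_convert_command(ZERO_CKPT_DIR, UNIVERSAL_CKPT_DIR)
ds_config = build_ds_config()

# The command is only printed here, not executed.
print(" ".join(convert_cmd))
print(json.dumps(ds_config, indent=2))
```

The converted folder would then be passed to the usual checkpoint-loading call, possibly with a different data-parallel world size than the one that wrote the original checkpoint.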

@xylian86
Author

xylian86 commented Apr 29, 2024 via email

@tjruwase tjruwase requested review from samadejacobs, tohtana and lekurile and removed request for mrwyattii May 2, 2024 23:05
Contributor

@tohtana tohtana left a comment

We have a test for universal checkpointing. It currently covers DP scaling only, but it would be good to exercise the ZeRO-3 feature with it. You can just add "3" to the test argument.

deepspeed/runtime/zero/stage3.py
deepspeed/checkpoint/ds_to_universal.py
return int(text) if text.isdigit() else text


def natural_keys(text):
Contributor

Thank you for introducing this interesting approach.
We have a similar sort in _merge_zero_shards that uses a different approach, and it is not good to have two different sorting implementations for the same purpose. Can you replace this one with natural_keys?

Author

Thanks for the suggestions!
This natural_keys function is actually reused from zero_to_fp32.py.

You’re right; it’s not ideal to have two different implementations for the same function. How about I replace the one in _merge_zero_shards with this natural_keys?

Contributor

This sounds good, thank you
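
For context on why the sort order matters in this thread, here is a minimal sketch of natural-order sorting in the style of the snippet under review. The helper name `atoi` and the shard file names are assumptions made for illustration; only the `return int(text) if text.isdigit() else text` line and the `natural_keys` name appear in the diff above.

```python
import re


def atoi(text):
    # Digit runs become ints so they compare numerically, not character by character.
    return int(text) if text.isdigit() else text


def natural_keys(text):
    # re.split with a capturing group alternates non-digit and digit chunks,
    # so keys built from similarly shaped names compare element by element.
    return [atoi(chunk) for chunk in re.split(r"(\d+)", text)]


# Hypothetical shard file names, for illustration only.
shards = [
    "zero_pp_rank_10_mp_rank_00_optim_states.pt",
    "zero_pp_rank_2_mp_rank_00_optim_states.pt",
    "zero_pp_rank_1_mp_rank_00_optim_states.pt",
]

lexicographic = sorted(shards)              # rank_10 sorts before rank_1 and rank_2
natural = sorted(shards, key=natural_keys)  # rank_1, rank_2, rank_10

for name in natural:
    print(name)
```

With a plain string sort, rank 10 lands ahead of ranks 1 and 2, which would merge shards in the wrong order once there are ten or more ranks. Using one shared key function in both places avoids the two code paths disagreeing.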

@xylian86
Author

xylian86 commented Jun 3, 2024

[Image: convergence curve for ZeRO 3 using the current implementation]

@xylian86 xylian86 requested a review from loadams as a code owner June 5, 2024 13:20
4 participants