We are seeing numerical differences between shardings in random-number initialization on GPUs. For example, with a mesh of (DP, FSDP, TP), the numerical output of our initialization changes drastically depending on how many devices we allocate to each axis. As a result, when we use TP we see divergences in the network.
System info (python version, jaxlib version, accelerator, etc.)
IIUC this is a bug (unintended behavior) even with jax_threefry_partitionable=False, and we don't yet know what's causing it. Good to know that setting jax_threefry_partitionable=True fixes it, though!
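To illustrate the workaround mentioned above, here is a minimal sketch of checking that random initialization is sharding-invariant once jax_threefry_partitionable is enabled. This is an assumed repro shape, not the reporter's actual code: it simulates 8 devices on CPU via an XLA flag (which must be set before importing jax), draws the same random array under two different mesh layouts, and compares the results. The mesh axis names ("dp", "tp") and the (8, 8) shape are arbitrary choices for illustration.

```python
import os
# Simulate 8 host devices so we can build multi-device meshes on CPU.
# Must be set before the first `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# With partitionable threefry, random generation commutes with sharding:
# values depend only on the key and shape, not on the device mesh.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.PRNGKey(0)

def init(sharding):
    # jit the draw with a sharding constraint, mimicking sharded init.
    @jax.jit
    def f():
        x = jax.random.normal(key, (8, 8))
        return jax.lax.with_sharding_constraint(x, sharding)
    return f()

devices = np.array(jax.devices())
# Two different splits of the same 8 devices across (dp, tp).
mesh_a = Mesh(devices.reshape(8, 1), ("dp", "tp"))
mesh_b = Mesh(devices.reshape(2, 4), ("dp", "tp"))

a = init(NamedSharding(mesh_a, P("dp", "tp")))
b = init(NamedSharding(mesh_b, P("dp", "tp")))

print(bool(jnp.array_equal(a, b)))
```

With the flag set to True, both mesh layouts should produce bit-identical values; with the default old-style threefry, the outputs may differ across shardings, which matches the divergence described in this issue.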