Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IndexError: InlinedVector::at(size_type) const failed bounds check #957

Open
caixiiaoyang opened this issue Sep 24, 2023 · 0 comments
Open

Comments

@caixiiaoyang
Copy link

Please describe the bug
IndexError: InlinedVector::at(size_type) const failed bounds check

System information and environment

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
  • Python version: 3.8.10
  • CUDA version: 11.3
  • NCCL version: 2.9
  • CuPy version: 11.3
  • GPU model and memory: 2× A100 (80 GB each)
  • Alpa version: 0.2.3
  • TensorFlow version: 2.8.0
  • JAX version: 0.3.22

To Reproduce
Steps to reproduce the behavior:
1. Train a LLaMA model implemented in Flax (see the code snippet below); training fails with the following error.
2. See error
2023-09-24 12:29:49,782 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 10.233.115.148:6379...
2023-09-24 12:29:49,795 INFO worker.py:1528 -- Connected to Ray cluster.
Training/epoch 0: 0%| | 0/7473 [00:01<?, ?it/s]
Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in <module>
main()
File "./Trainer/train_ray_batch.py", line 139, in main
state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
File "/home/mpi/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 121, in __call__
self._decode_args_and_get_executable(*args))
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
File "/home/mpi/.local/lib/python3.8/site-packages/jax/linear_util.py", line 309, in memoized_fun
ans = call(fun, *args)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 223, in _compile_parallel_executable
return method.compile_executable(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/parallel_method.py", line 108, in compile_executable
return compile_shard_executable(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 78, in compile_shard_executable
return shard_parallel_internal(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 139, in shard_parallel_internal
hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
xe.run_auto_sharding(hlo.get_module(), compile_options)
jax._src.traceback_util.UnfilteredStackTrace: IndexError: InlinedVector::at(size_type) const failed bounds check

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in <module>
main()
File "./Trainer/train_ray_batch.py", line 139, in main
state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
xe.run_auto_sharding(hlo.get_module(), compile_options)
IndexError: InlinedVector::at(size_type) const failed bounds check

Screenshots
image

Code snippet to reproduce the problem
@alpa.parallelize(batch_argnums=(1, 2, 3, 4))
def train_step(state, seq, seq_mask, labels, labels_mask):
    """Run one optimizer step on a batch and return ``(new_state, loss)``.

    Arguments 1-4 (``seq``, ``seq_mask``, ``labels``, ``labels_mask``) are
    declared as batch arguments to ``alpa.parallelize``.

    Args:
        state: Train state; this function reads ``apply_fn``, ``params``,
            ``opt_state``, ``master_copy`` and ``dynamic_scale`` and calls
            ``apply_gradients`` / ``replace`` on it.
        seq: Input token ids, presumably ``(batch, seq_len)`` -- TODO confirm.
        seq_mask: Attention mask passed to the model together with ``seq``.
        labels: Targets passed to ``cross_entropy_loss``.
        labels_mask: Mask passed to ``cross_entropy_loss``.

    Returns:
        The updated train state and the loss computed for this batch.
    """

    def train_forward(params):
        """Return the masked cross-entropy loss of the model under ``params``."""
        # Positions 0..L-1 (L = last dim of ``seq``), broadcast to seq's shape.
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape
        )
        outputs = state.apply_fn(
            params,
            seq,
            seq_mask,
            position_ids,
            deterministic=False,
            return_dict=False,
        )
        logits = outputs[0]
        return cross_entropy_loss(logits, labels, mask=labels_mask)

    dynamic_scale = state.dynamic_scale
    if dynamic_scale:
        # Loss-scaled gradient computation; also returns an updated scale and
        # the ``is_fin`` flag used below.
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
    else:
        # Without this branch ``loss``/``grads`` would be unbound for a state
        # that has no dynamic scale. ``alpa.value_and_grad`` is Alpa's
        # counterpart of ``jax.value_and_grad`` for use inside parallelize.
        grad_fn = alpa.value_and_grad(train_forward)
        loss, grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads)

    if dynamic_scale:
        # Take the new values where ``is_fin`` is true, otherwise keep the old
        # ones (``is_fin`` presumably flags finite gradients); the updated
        # scale is stored either way.
        keep_if_finite = functools.partial(jnp.where, is_fin)
        new_state = new_state.replace(
            opt_state=jax.tree_map(
                keep_if_finite, new_state.opt_state, state.opt_state
            ),
            params=jax.tree_map(keep_if_finite, new_state.params, state.params),
            master_copy=jax.tree_map(
                keep_if_finite, new_state.master_copy, state.master_copy
            ),
            dynamic_scale=dynamic_scale,
        )

    return new_state, loss

def main() -> None:
    """Train a LLaMA model with Alpa on a Ray cluster (reproduction script).

    Builds the GSM training data loader and constructs the model, optionally
    reading its config and weights from ``ckpt_dir``. It then sets up AdamW
    with a warmup schedule and runs ``n_epochs`` passes of ``train_step``
    over the data.
    """
    global llama_model
    alpa.init(cluster="ray")
    lr = 0.001
    batch_size = 1
    max_len = 640
    n_epochs = 7

    load_pretrained_model = False
    ckpt_dir = "./JAX_model/7B"

    # prepare dataset
    tokenizer = LLaMATokenizer("./JAX_model/tokenizer.model")
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
    dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

    # set config
    if load_pretrained_model:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            config_params = json.loads(f.read())
        config_params.update({'vocab_size': len(tokenizer), 'max_seq_len': max_len})
        llama_config = LLaMAConfig(**config_params)
    else:
        llama_config = LLaMAConfig(num_hidden_layers=4)
    llama_model = LLaMAForCausalLMModule(llama_config)

    # init model
    # Token ids are integers; plain ``jnp.ones`` would produce float32 ids.
    input_ids = jnp.ones((batch_size, max_len), dtype=jnp.int32)
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
    )
    params = llama_model.init(
        input_ids, attention_mask, position_ids, return_dict=False, init_cache=False
    )

    if load_pretrained_model:
        param = restore(Path(ckpt_dir) / "consolidated.nra", replace_keys=False)
        # NOTE(review): Flax variable collections are conventionally keyed
        # 'params'; confirm 'param' is the key this model reads, otherwise the
        # restored weights end up under an unused key.
        params['param'] = param

    # len() already returns an int, so no rounding is needed.
    n_steps = len(dataloader)

    # Linear warmup over the first epoch. Because end_value == peak_value, the
    # cosine phase should not decay, so the rate stays at ``lr`` afterwards.
    schedule = warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = adamw(learning_rate=schedule)

    use_master_copy = True
    dynamic_scale = DynamicScale()
    alpa.global_config.flax_always_use_fp16_embedding = True
    state = TrainState.create(
        apply_fn=llama_model.run,
        params=params,
        tx=optimizer,
        dynamic_scale=dynamic_scale,
        use_master_copy=use_master_copy,
    )

    for epoch in range(n_epochs):
        with tqdm(dataloader) as tepoch:
            tepoch.set_description(f"Training/epoch {epoch}")
            for batch in tepoch:
                seq, seq_mask, labels, labels_mask = batch
                state, loss = train_step(state, seq, seq_mask, labels, labels_mask)

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Additional information
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant