
Why does PPO need to store action_log_probs instead of using stop_gradient for better efficiency? #284

Open
Emerald01 opened this issue Oct 15, 2021 · 1 comment

@Emerald01

Hi,
I am looking at the PPO implementation, and I am curious about this part. (Many other implementations use the same workflow, so I am also curious to see whether I am missing anything.)

The action_log_probs tensor is created with its gradient removed (by setting requires_grad=False) and inserted into the storage buffer. It is generated by the following function and is later referred to as old_action_log_probs_batch in PPO:

def act(self, inputs, rnn_hxs, masks, deterministic=False):
    ...
    action_log_probs = dist.log_probs(action)

    return value, action, action_log_probs, rnn_hxs

In the PPO algorithm the ratio is calculated as follows, where action_log_probs comes from evaluate_actions():

values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
    obs_batch, recurrent_hidden_states_batch, masks_batch,
    actions_batch)
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

If I understand correctly, evaluate_actions() and act() output the same action_log_probs, because they use the same actor_critic and both call log_probs(action); the only difference is that old_action_log_probs_batch has its gradient removed, so backpropagation will not go through it.

So my question is: why do we bother to save old_action_log_probs_batch in the storage, when something like the following could be created on the fly?

values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
    obs_batch, recurrent_hidden_states_batch, masks_batch,
    actions_batch)
old_action_log_probs_batch = action_log_probs.detach()
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

Thank you for your attention. I look forward to the discussion.

Regards,
Tian

@yueyang130

In my understanding, the key point is that after sampling trajectories, the agent parameters are updated several times (up to args.ppo_epoch). For the first update, the situation is as you describe. From the second update onward, however, old_action_log_probs in the PPO implementation is still the value computed with the original parameters that sampled the trajectories, while old_action_log_probs in your version would be computed with parameters that have already been updated.
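
Here is a minimal, self-contained sketch of that point (a toy Gaussian policy, fake data, and made-up names like policy and log_probs; none of this is the repo's code). The stored log-probs are computed once at sampling time and stay fixed across the ppo_epoch passes, so the ratio drifts away from 1 once the parameters have been updated, while the on-the-fly .detach() version is recomputed from the already-updated parameters every pass and is therefore always exactly 1:

import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy Gaussian policy over 1-D actions; stands in for actor_critic.
policy = nn.Linear(4, 1)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

def log_probs(obs, actions):
    # Log-density of a unit-variance Gaussian centered at the policy output.
    mean = policy(obs)
    return -0.5 * (actions - mean).pow(2) - 0.5 * math.log(2 * math.pi)

# Fake rollout data: observations, sampled actions, advantages.
obs = torch.randn(32, 4)
actions = torch.randn(32, 1)
advantages = torch.randn(32, 1)

# What the repo does: compute the log-probs once, at sampling time, and store them.
with torch.no_grad():
    old_log_probs_stored = log_probs(obs, actions)

ppo_epoch, clip_param = 4, 0.2
for epoch in range(ppo_epoch):
    new_log_probs = log_probs(obs, actions)

    # (a) Stored version: current policy vs. the policy that sampled the data.
    ratio_stored = torch.exp(new_log_probs - old_log_probs_stored)

    # (b) On-the-fly version: exp(x - x.detach()) is numerically always 1,
    #     because the "old" term is just the current log-probs without gradient.
    ratio_on_the_fly = torch.exp(new_log_probs - new_log_probs.detach())

    print(epoch, ratio_stored.mean().item(), ratio_on_the_fly.mean().item())

    # Clipped surrogate loss using the stored ratio, as in PPO.
    surr1 = ratio_stored * advantages
    surr2 = torch.clamp(ratio_stored, 1.0 - clip_param, 1.0 + clip_param) * advantages
    loss = -torch.min(surr1, surr2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# At epoch 0 the two ratios agree (both are 1); after the first optimizer.step()
# ratio_stored moves away from 1, while ratio_on_the_fly stays exactly 1.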
