[BUG] Episode return is not recorded correctly in cleanRL's example #299

williamd4112 opened this issue Apr 4, 2024 · 0 comments
Describe the bug

In the cleanRL example, the episode returns of two independent episodes are added up because the episode return counter is not reset at time-out: if episode 0 times out with return 10, episode 1's recorded return continues from 10 instead of starting at 0.

To Reproduce

The following is a minimal example that reproduces the bug. I tried the RecordEpisodeStatistics wrapper from both the envpool repo and the cleanRL repo, and both exhibit this bug.

import gym
import numpy as np
import envpool

is_legacy_gym = True

# From: https://github.com/sail-sg/envpool/blob/main/examples/cleanrl_examples/ppo_atari_envpool.py
class RecordEpisodeStatistics(gym.Wrapper):

  def __init__(self, env, deque_size=100):
    super(RecordEpisodeStatistics, self).__init__(env)
    self.num_envs = getattr(env, "num_envs", 1)
    self.episode_returns = None
    self.episode_lengths = None
    # get if the env has lives
    self.has_lives = False
    env.reset()
    info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
    if info["lives"].sum() > 0:
      self.has_lives = True
      print("env has lives")

  def reset(self, **kwargs):
    if is_legacy_gym:
      observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
    else:
      observations, _ = super(RecordEpisodeStatistics, self).reset(**kwargs)
    self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
    self.lives = np.zeros(self.num_envs, dtype=np.int32)
    self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
    return observations

  def step(self, action):
    if is_legacy_gym:
      observations, rewards, dones, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
    else:
      observations, rewards, term, trunc, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
      dones = term + trunc
    self.episode_returns += infos["reward"]
    self.episode_lengths += 1
    self.returned_episode_returns[:] = self.episode_returns
    self.returned_episode_lengths[:] = self.episode_lengths
    all_lives_exhausted = infos["lives"] == 0
    if self.has_lives:
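      # BUG: counters are reset only when all lives are exhausted; a
      # TimeLimit truncation never resets them, so returns leak into
      # the next episode (see "Reason and Possible fixes" below).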
      self.episode_returns *= 1 - all_lives_exhausted
      self.episode_lengths *= 1 - all_lives_exhausted
    else:
      self.episode_returns *= 1 - dones
      self.episode_lengths *= 1 - dones
    infos["r"] = self.returned_episode_returns
    infos["l"] = self.returned_episode_lengths
    return (
      observations,
      rewards,
      dones,
      infos,
    )

# From: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool.py
# class RecordEpisodeStatistics(gym.Wrapper):
#     def __init__(self, env, deque_size=100):
#         super().__init__(env)
#         self.num_envs = getattr(env, "num_envs", 1)
#         self.episode_returns = None
#         self.episode_lengths = None

#     def reset(self, **kwargs):
#         observations = super().reset(**kwargs)
#         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         self.lives = np.zeros(self.num_envs, dtype=np.int32)
#         self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         return observations

#     def step(self, action):
#         observations, rewards, dones, infos = super().step(action)
#         self.episode_returns += infos["reward"]
#         self.episode_lengths += 1
#         self.returned_episode_returns[:] = self.episode_returns
#         self.returned_episode_lengths[:] = self.episode_lengths       
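#         # BUG: resets only on `terminated`; a TimeLimit truncation never
#         # resets the counters, so the same leak occurs here.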
#         self.episode_returns *= 1 - infos["terminated"]
#         self.episode_lengths *= 1 - infos["terminated"]
#         infos["r"] = self.returned_episode_returns
#         infos["l"] = self.returned_episode_lengths
#         return (
#             observations,
#             rewards,
#             dones,
#             infos,
#         )


if __name__ == "__main__":
  
  np.random.seed(1)
  
  envs = envpool.make(
    "UpNDown-v5",
    env_type="gym",
    num_envs=1,
    episodic_life=True,  # Espeholt et al., 2018, Tab. G.1
    repeat_action_probability=0,  # Hessel et al., 2022 (Muesli) Tab. 10
    full_action_space=False,  # Espeholt et al., 2018, Appendix G., "Following related work, experts use game-specific action sets."
    max_episode_steps=30,  # Set to a small value to hit the time limit faster
    reward_clip=True,
    seed=1,
  )
  envs = RecordEpisodeStatistics(envs)
 
  num_episodes = 2

  episode_count = 0
  cur_episode_len = 0
  cur_episode_return = 0

  my_episode_returns = []
  my_episode_lens = []

  # Track episode returns here to compare with the ones recorded with `RecordEpisodeStatistics`
  recorded_episode_returns = []
  recorded_episode_lens = []
  
  obs = envs.reset()
  while episode_count < num_episodes:   
      action = np.random.randint(0, envs.action_space.n, 1)
      obs, reward, done, info = envs.step(action)
      cur_episode_return += info["reward"][0]
      cur_episode_len += 1
      print(f"Ep={episode_count}, EpStep={cur_episode_len}, Return={info['r']}, MyReturn={cur_episode_return}, Terminated={info['terminated']}, Timeout={info['TimeLimit.truncated']}, Lives={info['lives']}")
      
      # info["terminated"] = True: Game over.
      # info["TimeLimit.truncated"] = True: Timeout, the environment will be reset (so the episode return should be reset too)
      if info["terminated"][0] or info["TimeLimit.truncated"][0]:
        recorded_episode_returns.append(info["r"][0]) # Append the episode return recorded in `RecordEpisodeStatistics`
        recorded_episode_lens.append(info["l"][0]) # Append the episode length recorded in `RecordEpisodeStatistics`
        my_episode_returns.append(cur_episode_return)
        my_episode_lens.append(cur_episode_len)
        print(f"Episode {episode_count}'s length is {cur_episode_len} (terminated={info['terminated']}, timeout={info['TimeLimit.truncated']})")
        
        episode_count += 1
        cur_episode_return *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])
        cur_episode_len *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])

  for episode_idx in range(num_episodes):
    print(f"Episode {episode_idx}'s return is supposed to be {my_episode_returns[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_returns[episode_idx]}")
    print(f"Episode {episode_idx}'s len is supposed to be {my_episode_lens[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_lens[episode_idx]}")
      

Executing the above code snippet, you should see the following printout:

env has lives
Ep=0, EpStep=1, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=2, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=3, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=4, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=5, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=6, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=7, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=8, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=9, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=10, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=11, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=12, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=13, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=14, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=15, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=16, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=17, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=18, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=19, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=20, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=21, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=22, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=23, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=24, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=25, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=26, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=27, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=28, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=2, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=3, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=4, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=5, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=6, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=7, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=8, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=9, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=10, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=11, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=12, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=13, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=14, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=15, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=16, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=17, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=18, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=19, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=20, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=21, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=22, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=23, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=24, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=25, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=26, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=27, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=28, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=29, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=30, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=31, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 1's length is 31 (terminated=[0], timeout=[ True])
Episode 0's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 10.0
Episode 0's len is supposed to be 29, but the wrapper `RecordEpisodeStatistics` gives 29
Episode 1's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 20.0
Episode 1's len is supposed to be 31, but the wrapper `RecordEpisodeStatistics` gives 60

Expected behavior

See the above example's output:

Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]

The return counter in the new episode (Ep=1) is not reset to zero but is carried over from the old episode. The expected behavior is to reset the return counter to zero upon time-out.

Screenshots

N/A

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
>>> import envpool, numpy, sys
>>> print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
0.8.4 1.21.6 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0] linux

Additional context

N/A

Reason and Possible fixes

Change both lines in the `has_lives` branch of `step`:

self.episode_returns *= 1 - all_lives_exhausted
self.episode_lengths *= 1 - all_lives_exhausted

to the following:

self.episode_returns *= 1 - (infos["terminated"] | infos["TimeLimit.truncated"])
self.episode_lengths *= 1 - (infos["terminated"] | infos["TimeLimit.truncated"])
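
For reference, this is what the patched `step` could look like with that minimal change applied. This is a sketch, not a tested patch; it assumes `infos["terminated"]` and `infos["TimeLimit.truncated"]` are always present in the envpool info dict, which the printout above suggests:

  def step(self, action):
    if is_legacy_gym:
      observations, rewards, dones, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
    else:
      observations, rewards, term, trunc, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
      dones = term + trunc
    self.episode_returns += infos["reward"]
    self.episode_lengths += 1
    self.returned_episode_returns[:] = self.episode_returns
    self.returned_episode_lengths[:] = self.episode_lengths
    # An episode is over on game over OR time-limit truncation; resetting
    # only when all lives are exhausted is what let returns leak across
    # a time-out.
    episode_over = infos["terminated"] | infos["TimeLimit.truncated"]
    if self.has_lives:
      self.episode_returns *= 1 - episode_over
      self.episode_lengths *= 1 - episode_over
    else:
      self.episode_returns *= 1 - dones
      self.episode_lengths *= 1 - dones
    infos["r"] = self.returned_episode_returns
    infos["l"] = self.returned_episode_lengths
    return observations, rewards, dones, infos

With this change, episode 1 in the printout above should report Return=[10.] instead of Return=[20.], matching MyReturn.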

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)