How can I make action sampling stay within the range specified by my environment when using onpolicy_trainer? #1142

Open · lidaken opened this issue May 9, 2024 · 6 comments
Labels: question (Further information is requested)

Comments

@lidaken commented May 9, 2024

Hi, I am new to tianshou and RL. I created an environment and trained it with PPO in tianshou, but I found that the sampled actions are outside the range defined by my action space. I searched and found `map_action`, but it does not seem to be used in the trainer. How can I solve this problem? Thanks a lot.
```python
# continuous actions:
orn_low = np.array([-30, -30, -30]) * np.pi / 180
orn_high = np.array([30, 30, 30]) * np.pi / 180
v_low = np.array([0.001])
v_high = np.array([0.1])
distance_low = np.array([0.01])
distance_high = np.array([0.5])
act_low = np.concatenate((orn_low, v_low, distance_low))
act_high = np.concatenate((orn_high, v_high, distance_high))
self.action_space = spaces.Box(low=act_low, high=act_high, dtype=np.float64)
# note: dtype must be the space's dtype, not the space object itself
self.action = np.zeros(self.action_space.shape, dtype=self.action_space.dtype)
```
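For context on what `map_action` does, here is a rough sketch (not the exact tianshou source) under the assumption that `action_scaling=True` and `action_bound_method="clip"` are used: the raw network output, which lives roughly in [-1, 1], is clipped and then rescaled linearly into `[low, high]` of the Box above before it is handed to `env.step`.

```python
import numpy as np

# Bounds copied from the action space above.
orn_low = np.array([-30, -30, -30]) * np.pi / 180
orn_high = np.array([30, 30, 30]) * np.pi / 180
act_low = np.concatenate((orn_low, [0.001], [0.01]))
act_high = np.concatenate((orn_high, [0.1], [0.5]))

def map_action_sketch(act: np.ndarray) -> np.ndarray:
    """Sketch of tianshou-style action scaling for a Box space:
    clip the raw policy output to [-1, 1] (action_bound_method="clip"),
    then rescale it linearly into [low, high] (action_scaling=True)."""
    act = np.clip(act, -1.0, 1.0)
    return act_low + (act + 1.0) * 0.5 * (act_high - act_low)

# A raw output of zeros maps to the midpoint of each dimension.
print(map_action_sketch(np.zeros(5)))
```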
```python
# model
net_a = Net(
    args.state_shape,
    hidden_sizes=args.hidden_sizes,
    activation=nn.Tanh,
    device=args.device,
)
actor = ActorProb(
    net_a,
    args.action_shape,
    unbounded=True,
    device=args.device,
).to(args.device)
net_c = Net(
    args.state_shape,
    hidden_sizes=args.hidden_sizes,
    activation=nn.Tanh,
    device=args.device,
)
critic = Critic(net_c, device=args.device).to(args.device)
actor_critic = ActorCritic(actor, critic)

torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in actor_critic.modules():
    if isinstance(m, torch.nn.Linear):
        # orthogonal initialization
        torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
        torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.bias)
        m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
    # decay learning rate to 0 linearly
    max_update_num = np.ceil(
        args.step_per_epoch / args.step_per_collect
    ) * args.epoch

    lr_scheduler = LambdaLR(
        optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
    )

def dist(*logits):
    return Independent(Normal(*logits), 1)

policy = PPOPolicy(
    actor,
    critic,
    optim,
    dist,
    discount_factor=args.gamma,
    gae_lambda=args.gae_lambda,
    max_grad_norm=args.max_grad_norm,
    vf_coef=args.vf_coef,
    ent_coef=args.ent_coef,
    reward_normalization=args.rew_norm,
    action_scaling=True,
    action_bound_method=args.bound_action_method,
    lr_scheduler=lr_scheduler,
    action_space=env.action_space,
    eps_clip=args.eps_clip,
    value_clip=args.value_clip,
    dual_clip=args.dual_clip,
    advantage_normalization=args.norm_adv,
    recompute_advantage=args.recompute_adv,
)

if not args.watch:
    # trainer
    #train_envs.render(args.watch)

    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        save_best_fn=save_best_fn,
        logger=logger,
        test_in_train=False,
    )
```
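Note that this mapping is not applied by the trainer itself: it happens in `Collector.collect`, which calls `policy.map_action(...)` on the raw network output right before `env.step`. Assuming the `policy` and `env` objects built above, a quick sanity check could look like this (illustrative only):

```python
import numpy as np

# `policy` and `env` are assumed to be the objects constructed in the script above.
raw_act = np.zeros(env.action_space.shape)   # pretend network output in [-1, 1]
mapped_act = policy.map_action(raw_act)      # what the Collector passes to env.step
print("low:   ", env.action_space.low)
print("high:  ", env.action_space.high)
print("mapped:", mapped_act)                 # should lie inside [low, high]
```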
@lidaken (Author) commented May 9, 2024

```python
# model
net_a = Net(
    args.state_shape,
    hidden_sizes=args.hidden_sizes,
    activation=nn.Tanh,
    device=args.device,
)
actor = ActorProb(
    net_a,
    args.action_shape,
    unbounded=True,
    device=args.device,
).to(args.device)
net_c = Net(
    args.state_shape,
    hidden_sizes=args.hidden_sizes,
    activation=nn.Tanh,
    device=args.device,
)
critic = Critic(net_c, device=args.device).to(args.device)
actor_critic = ActorCritic(actor, critic)

torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in actor_critic.modules():
    if isinstance(m, torch.nn.Linear):
        # orthogonal initialization
        torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
        torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.bias)
        m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
    # decay learning rate to 0 linearly
    max_update_num = np.ceil(
        args.step_per_epoch / args.step_per_collect
    ) * args.epoch

    lr_scheduler = LambdaLR(
        optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
    )

def dist(*logits):
    return Independent(Normal(*logits), 1)

policy = PPOPolicy(
    actor,
    critic,
    optim,
    dist,
    discount_factor=args.gamma,
    gae_lambda=args.gae_lambda,
    max_grad_norm=args.max_grad_norm,
    vf_coef=args.vf_coef,
    ent_coef=args.ent_coef,
    reward_normalization=args.rew_norm,
    action_scaling=True,
    action_bound_method=args.bound_action_method,
    lr_scheduler=lr_scheduler,
    action_space=env.action_space,
    eps_clip=args.eps_clip,
    value_clip=args.value_clip,
    dual_clip=args.dual_clip,
    advantage_normalization=args.norm_adv,
    recompute_advantage=args.recompute_adv,
)

# # load a previous policy
# if args.resume_path:
#     ckpt = torch.load(args.resume_path, map_location=args.device)
#     policy.load_state_dict(ckpt["model"])
#     train_envs.set_obs_rms(ckpt["obs_rms"])
#     test_envs.set_obs_rms(ckpt["obs_rms"])
#     print("Loaded agent from: ", args.resume_path)

# collector
# if args.training_num > 1:
#     buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
# else:
buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)


# Collector will call env.reset(), which leads to a pybullet error
test_collector = Collector(policy, test_envs)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ppo"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
    logger = WandbLogger(
        save_interval=1,
        name=log_name.replace(os.path.sep, "__"),
        run_id=args.resume_id,
        config=args,
        project=args.wandb_project,
    )
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
    logger = TensorboardLogger(writer)
else:  # wandb
    logger.load(writer)

def save_best_fn(policy):
    state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
    torch.save(state, os.path.join(log_path, "policy.pth"))

if not args.watch:
    # trainer
    #train_envs.render(args.watch)

    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        save_best_fn=save_best_fn,
        logger=logger,
        test_in_train=False,
    )

```
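A complementary way to verify that the actions actually reaching the environment are in range is to wrap the env with a small assertion wrapper before handing it to the Collector. The wrapper below is only a sketch; the name `AssertActionInRange` is made up for this example:

```python
import gymnasium as gym
import numpy as np

class AssertActionInRange(gym.ActionWrapper):
    """Illustrative wrapper: fail loudly if an out-of-range action reaches env.step."""

    def action(self, act):
        low, high = self.action_space.low, self.action_space.high
        assert np.all(act >= low - 1e-8) and np.all(act <= high + 1e-8), (
            f"out-of-range action {act}, expected within [{low}, {high}]"
        )
        return act

# usage sketch: wrap each env before building train_envs / test_envs
# env = AssertActionInRange(YourEnv())
```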

@lidaken (Author) commented May 9, 2024

```python
# continuous actions:
orn_low = np.array([-30, -30, -30]) * np.pi / 180
orn_high = np.array([30, 30, 30]) * np.pi / 180
v_low = np.array([0.001])
v_high = np.array([0.1])
distance_low = np.array([0.01])
distance_high = np.array([0.5])
act_low = np.concatenate((orn_low, v_low, distance_low))
act_high = np.concatenate((orn_high, v_high, distance_high))
self.action_space = spaces.Box(low=act_low, high=act_high, dtype=np.float64)
# note: dtype must be the space's dtype, not the space object itself
self.action = np.zeros(self.action_space.shape, dtype=self.action_space.dtype)
```

@lidaken (Author) commented May 9, 2024

Sorry, I don't know what's wrong with the code T_T

@lidaken (Author) commented May 9, 2024

While debugging, I found that `map_action` is called in collector.py, but the branch
`if isinstance(self.action_space, gym.spaces.Box) and isinstance(act, np.ndarray):`
is never entered. I added a check and found that my action_space is a `gym.spaces.box.Box`, which does not match the `gym.spaces.Box` used in the isinstance check. Can someone tell me how to solve this?
[two screenshots attached]
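This symptom is consistent with the environment being built on the legacy `gym` package while the collector's check is effectively against `gymnasium.spaces.Box` (tianshou imports gymnasium internally). A minimal demonstration of the mismatch, assuming both packages happen to be installed:

```python
import gym          # legacy package
import gymnasium
import numpy as np

legacy_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(5,), dtype=np.float64)

# A space created with the legacy gym package is not an instance of gymnasium's Box,
# so an isinstance check against gymnasium.spaces.Box fails:
print(isinstance(legacy_space, gymnasium.spaces.Box))   # False
print(isinstance(legacy_space, gym.spaces.Box))         # True
```

If that is what is happening here, rebuilding the environment and its spaces on gymnasium should make the `Box` branch of `map_action` trigger.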

@MischaPanch (Collaborator) commented
I can take a look soon. Could you please:

  1. Format your posts above to make them a bit more readable.
  2. Give some info on how you defined your environment. Tianshou only supports gymnasium envs, not gym - maybe that is the problem? Is the environment code available?
  3. Check whether you have the right versions installed. How did you install tianshou, from master or from PyPI? As mentioned above, you should not build your env from gym, so gym should not be installed, and gymnasium should be the version that automatically comes with tianshou.

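Regarding point 2: a minimal gymnasium-based skeleton with the action bounds from this issue could look as follows. The class name `MyRobotEnv`, the observation space, and the reward logic are placeholders, not part of the original code:

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class MyRobotEnv(gym.Env):
    """Illustrative skeleton only; observation and reward logic are placeholders."""

    def __init__(self):
        orn_low = np.array([-30, -30, -30]) * np.pi / 180
        orn_high = np.array([30, 30, 30]) * np.pi / 180
        act_low = np.concatenate((orn_low, [0.001], [0.01]))
        act_high = np.concatenate((orn_high, [0.1], [0.5]))
        # spaces come from gymnasium, not the legacy gym package
        self.action_space = spaces.Box(low=act_low, high=act_high, dtype=np.float64)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        obs = np.zeros(self.observation_space.shape)
        return obs, {}

    def step(self, action):
        obs = np.zeros(self.observation_space.shape)
        reward, terminated, truncated, info = 0.0, False, False, {}
        return obs, reward, terminated, truncated, info
```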
@MischaPanch added the question (Further information is requested) label on May 9, 2024
@lidaken (Author) commented May 10, 2024

> I can take a look soon. Could you please:
>
> 1. Format your posts above to make them a bit more readable.
> 2. Give some info on how you defined your environment. Tianshou only supports gymnasium envs, not gym - maybe that is the problem? Is the environment code available?
> 3. Check whether you have the right versions installed. How did you install tianshou, from master or from PyPI? As mentioned above, you should not build your env from gym, so gym should not be installed, and gymnasium should be the version that automatically comes with tianshou.

OK, thanks, I will re-upload the code.
