
[Bug]: in "RecurrentPPO" not work "model.policy.evaluate_actions()"

Open drulye opened this issue 11 months ago • 2 comments

🐛 Bug

in "RecurrentPPO" not work "model.policy.evaluate_actions()"

drulye avatar Jan 15 '25 19:01 drulye

# Importing libraries
from collections import namedtuple
from sb3_contrib import RecurrentPPO
import numpy as np
import torch as th

# !!! Set to True to SHOW ERROR !!!
SHOW_ERROR = False  # True, False

# Initialize environment and model
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1").learn(200)
env = model.get_env()

# Reset environment
observation_lst = env.reset()
observation_lst = observation_lst[0]
state = None
done = "True"
reward_sum = 0

# Run simulation for 100 steps
for step in range(100):

    # Predict actions and update LSTM state
    actions, state = model.predict(
        observation=np.array(object=observation_lst, dtype=np.float32),
        state=state,
        episode_start=np.array(object=[done], dtype=np.bool_),
        deterministic=True,
    )
    actions = actions.item()

    # Prepare tensors for policy evaluation
    obs = th.tensor(data=np.array(object=[observation_lst], dtype=np.float32), dtype=th.float32, device=model.policy.device)
    act = th.tensor(data=np.array(object=[actions], dtype=np.int64), dtype=th.int64, device=model.policy.device)
    # NamedTuple mimicking the (pi, vf) states container that the recurrent policy expects
    RNNStates = namedtuple("RNNStates", ["pi", "vf"])
    pi = th.tensor(data=state[0], dtype=th.float32, device=model.policy.device)
    vf = th.tensor(data=state[1], dtype=th.float32, device=model.policy.device)
    lstm_states = RNNStates(pi=pi, vf=vf)
    episode_starts = th.tensor(data=np.array(object=[done], dtype=np.bool_), dtype=th.int64, device=model.policy.device)
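
    # Note (assumption, based on sb3_contrib's _process_sequence): get_distribution() and
    # predict_values() appear to take the actor's (hidden, cell) pair directly, so this
    # namedtuple happens to work for them, while evaluate_actions() expects lstm_states.pi
    # and lstm_states.vf to each be an (hidden, cell) tuple of tensors, which is likely
    # what triggers the reshape error shown below.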

    # Evaluate policy outputs
    action_probability_distribution = model.policy.get_distribution(obs=obs, lstm_states=lstm_states, episode_starts=episode_starts)
    state_value_estimates_a = model.policy.predict_values(obs=obs, lstm_states=lstm_states, episode_starts=episode_starts)
    if SHOW_ERROR:
        state_value_estimates_b, log_probabilities_of_actions, action_distribution_entropy = model.policy.evaluate_actions(obs=obs, actions=act, lstm_states=lstm_states, episode_starts=episode_starts)

    # Extract key metrics
    _action_probability_distribution_0 = action_probability_distribution[0].distribution.probs.detach().cpu().numpy()[0][0]
    _action_probability_distribution_1 = action_probability_distribution[0].distribution.probs.detach().cpu().numpy()[0][1]
    _state_value_estimates_a = state_value_estimates_a.detach().cpu().numpy()[0][0]
    if SHOW_ERROR:
        _state_value_estimates_b = state_value_estimates_b.detach().cpu().numpy()[0][0]
        _log_probabilities_of_actions = log_probabilities_of_actions.detach().cpu().numpy()[0]
        _action_distribution_entropy = action_distribution_entropy.detach().cpu().numpy()[0]

    # Print information for current step
    print("----------------------------------------------------------------------------------------------------------")
    print(f"step: {step} | actions: {actions} | done: {done}")
    print("----------------------------------")
    print("ACTION_PROBABILITY_DISTRIBUTION  :", action_probability_distribution[0].distribution.probs)
    print("ACTION_PROBABILITY_DISTRIBUTION_0:", _action_probability_distribution_0)
    print("ACTION_PROBABILITY_DISTRIBUTION_1:", _action_probability_distribution_1)
    print("----------------------------------")
    print("STATE_VALUE_ESTIMATES_A          :", state_value_estimates_a)
    print("STATE_VALUE_ESTIMATES_A          :", _state_value_estimates_a)
    if SHOW_ERROR:
        print("----------------------------------")
        print("STATE_VALUE_ESTIMATES_B          :", state_value_estimates_b)
        print("STATE_VALUE_ESTIMATES_B          :", _state_value_estimates_b)
        print("LOG_PROBABILITIES_OF_ACTIONS     :", log_probabilities_of_actions)
        print("LOG_PROBABILITIES_OF_ACTIONS     :", _log_probabilities_of_actions)
        print("ACTION_DISTRIBUTION_ENTROPY      :", action_distribution_entropy)
        print("ACTION_DISTRIBUTION_ENTROPY      :", _action_distribution_entropy)
    print("----------------------------------------------------------------------------------------------------------")

    # Perform action in environment and retrieve results
    observation_lst, reward, done, info = env.step(actions=np.array(object=[actions], dtype=np.int64))
    observation_lst = observation_lst[0]
    reward_sum += reward[0]
    done = bool(done[0])

# Print final step and total reward
print(f"final step: {step} | total reward: {reward_sum}")

# Traceback (most recent call last):
#   File "/home/drulye/vscode/project/error_rppo_evaluate_actions.py", line 48, in <module>
#     state_value_estimates_b, log_probabilities_of_actions, action_distribution_entropy = model.policy.evaluate_actions(obs=obs, actions=act, lstm_states=lstm_states, episode_starts=episode_starts)
#                                                                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/drulye/anaconda3/lib/python3.12/site-packages/sb3_contrib/common/recurrent/policies.py", line 331, in evaluate_actions
#     latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
#                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#   File "/home/drulye/anaconda3/lib/python3.12/site-packages/sb3_contrib/common/recurrent/policies.py", line 186, in _process_sequence
#     features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1)
#                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: shape '[256, -1, 4]' is invalid for input of size 4

drulye avatar Jan 16 '25 19:01 drulye

Hello, I'm not sure what you are trying to achieve; evaluate_actions() is used here: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/c070fc2faedaf28bf62299cc36c8be5ac68d15fd/sb3_contrib/ppo_recurrent/ppo_recurrent.py#L345-L350

So I would recommend setting a debugger there, analyzing the shapes of the tensors, and comparing them to the ones in your script.
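
For reference, a minimal, unofficial sketch of how the states returned by model.predict() could be repackaged before calling evaluate_actions() manually, reusing obs, act, state and episode_starts from the script above. It assumes, based on _process_sequence, that lstm_states.pi and lstm_states.vf must each be an (hidden, cell) pair of tensors with shape (n_lstm_layers, n_seq, lstm_hidden_size); reusing the actor's states for the critic is only a shortcut to make the shapes line up, not an official example:

# Minimal sketch (assumption): repackage the states from model.predict() so that
# evaluate_actions() receives (hidden, cell) pairs instead of single tensors.
from collections import namedtuple

import torch as th

LSTMStates = namedtuple("LSTMStates", ["pi", "vf"])

# state from model.predict() is assumed to be the actor's (hidden, cell) pair of
# numpy arrays with shape (n_lstm_layers, n_envs, lstm_hidden_size).
h = th.tensor(state[0], dtype=th.float32, device=model.policy.device)
c = th.tensor(state[1], dtype=th.float32, device=model.policy.device)

# Each field is itself an (hidden, cell) tuple; the critic reuses the actor's
# states here, which is only a guess for a quick shape check.
lstm_states = LSTMStates(pi=(h, c), vf=(h, c))

values, log_prob, entropy = model.policy.evaluate_actions(
    obs=obs, actions=act, lstm_states=lstm_states, episode_starts=episode_starts
)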

araffin avatar Feb 03 '25 18:02 araffin