plumetracknets icon indicating copy to clipboard operation
plumetracknets copied to clipboard

Agent can't successfully achieve the goal using the parameters downloaded from figshare.

Open FanboZhao opened this issue 1 month ago • 0 comments

Code (Five VRNN model)

plume_env is unchanged from the original source code.

import torch
import warnings
import numpy as np
import matplotlib.pyplot as plt

from torch.serialization import SourceChangeWarning
warnings.filterwarnings("ignore", category=SourceChangeWarning)

from plume_env import PlumeEnvironment
if __name__ == "__main__":
    # Roll out one pretrained agent in the plume environment for at most
    # `max_steps` steps (or until the episode ends), then plot and save the
    # visited locations.
    model_path = "plume-data/net1.pt"
    output_path = "trajectory1.png"
    max_steps = 300

    env = PlumeEnvironment()
    device = torch.device('cpu')

    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    actor_critic, ob_rms = torch.load(model_path, map_location=device)
    # NOTE(review): `ob_rms` is loaded but never applied to the observations.
    # If the policy was trained on normalised observations, feeding raw
    # observations here would degrade behaviour -- TODO confirm against the
    # training/evaluation code that produced this checkpoint.

    recurrent_hidden_states = torch.zeros(
        1, actor_critic.recurrent_hidden_state_size, device=device)
    # A zero mask on the first step marks the start of an episode for the
    # recurrent policy.
    masks = torch.zeros(1, 1, device=device)
    obs = env.reset()
    obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)

    traj = []

    for step in range(max_steps):
        with torch.no_grad():
            _, action, _, recurrent_hidden_states, _ = actor_critic.act(
                obs,
                recurrent_hidden_states,
                masks,
                deterministic=True)

        obs, reward, done, info = env.step(
            action.squeeze(0).detach().cpu().numpy())

        print("obs :", obs)
        obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        masks.fill_(0.0 if done else 1.0)

        traj.append(info['location'])
        print("location:", info['location'])

        if done:
            # Stepping an environment after `done` without a reset is
            # undefined, and would add post-terminal points to the plot.
            print("Episode finished after", step + 1, "steps")
            break

    traj = np.array(traj)
    plt.figure(figsize=(6, 6))
    plt.plot(traj[:, 0], traj[:, 1], '-o', markersize=3, linewidth=1)
    # The first recorded point is the location after the first step, since
    # locations are only read from `info` returned by `env.step`.
    plt.scatter(traj[0, 0], traj[0, 1], s=60, marker='o', label='start', zorder=5)
    plt.scatter(traj[-1, 0], traj[-1, 1], s=60, marker='X', label='end', zorder=5)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Agent trajectory')
    plt.legend()
    plt.axis('equal')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=200)
    # Report the file that was actually written (the original message named
    # a different file than the one saved).
    print(f"Saved as {output_path}")
    plt.show()

Trajectory (Five VRNN model)

Image Image Image Image Image

FanboZhao avatar Nov 18 '25 00:11 FanboZhao