plumetracknets
plumetracknets copied to clipboard
The agent can't successfully achieve the goal using the parameters downloaded from figshare.
Code (Five VRNN model)
Plume_env is the same as in the source code.
import torch
import warnings
import numpy as np
import matplotlib.pyplot as plt
from torch.serialization import SourceChangeWarning
warnings.filterwarnings("ignore", category=SourceChangeWarning)
from plume_env import PlumeEnvironment
if __name__ == "__main__":
    # Roll out a single episode of a pretrained recurrent actor-critic in
    # PlumeEnvironment on the CPU, record the agent's location at every step,
    # and save/show a plot of the resulting trajectory.
    env = PlumeEnvironment()
    model_path = "plume-data/net1.pt"
    output_path = "trajectory1.png"
    max_steps = 300
    device = torch.device('cpu')

    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoints obtained from a trusted source.
    # NOTE(review): ob_rms is loaded but never applied below. If the
    # checkpoint was trained on normalized observations, feeding raw
    # observations would not match training -- TODO confirm.
    actor_critic, ob_rms = torch.load(model_path, map_location=device)

    # Initial recurrent state and mask for a batch of one environment.
    recurrent_hidden_states = torch.zeros(
        1, actor_critic.recurrent_hidden_state_size, device=device)
    masks = torch.zeros(1, 1, device=device)

    # Add a leading batch dimension to the observation before feeding it in.
    obs = env.reset()
    obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)

    traj = []
    for _ in range(max_steps):
        with torch.no_grad():
            # Only the action and the updated hidden state are used here.
            _value, action, _, recurrent_hidden_states, _activity = (
                actor_critic.act(
                    obs,
                    recurrent_hidden_states,
                    masks,
                    deterministic=True))

        obs, _reward, done, info = env.step(
            action.squeeze(0).detach().cpu().numpy())
        print("obs :", obs)
        obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        masks.fill_(0.0 if done else 1.0)

        traj.append(info['location'])
        print("location:", info['location'])

        # Stop at the end of the episode: the environment is never reset in
        # this script, so stepping past `done` would append post-terminal
        # positions to the plotted trajectory.
        if done:
            break

    traj = np.array(traj)

    plt.figure(figsize=(6, 6))
    plt.plot(traj[:, 0], traj[:, 1], '-o', markersize=3, linewidth=1)
    plt.scatter(traj[0, 0], traj[0, 1], s=60, marker='o', label='start',
                zorder=5)
    plt.scatter(traj[-1, 0], traj[-1, 1], s=60, marker='X', label='end',
                zorder=5)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Agent trajectory')
    plt.legend()
    plt.axis('equal')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=200)
    # Report the path that was actually written (the message previously named
    # a different file than the one saved).
    print(f"Saved as {output_path}")
    plt.show()