torchgfn
Add a test for the `Trajectories.extend` method
As was done for the master branch in https://github.com/saleml/torchgfn/pull/79, add the following test:
def test_extend_trajectories_on_cuda():
    """Check that ``Trajectories.extend`` works on batches sampled on CUDA.

    Two batches of 10 trajectories are sampled from a 4-D HyperGrid created
    with ``device_str="cuda"``. Both batches are re-wrapped in the
    ``Trajectories`` class imported from the local ``src.gfn`` package, so the
    in-repo implementation is exercised. The second batch is then appended to
    the first.

    Skipped when CUDA is not available.
    """
    import os
    import sys
    import unittest

    # pytest reports unittest.SkipTest as a skipped test, so no pytest import
    # is needed to avoid a hard failure on CPU-only machines.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is not available")

    # NOTE(review): the original code used os.path.abspath("__file__" + "/../")
    # with "__file__" quoted as a *string literal*. That expression always
    # normalises to the current working directory, never to this file's
    # location. That effective behaviour is kept but spelled out here, so
    # `src.gfn` resolves when pytest is launched from the repository root.
    # TODO: derive the root from __file__ once the test file's location is fixed.
    repo_root = os.getcwd()
    if sys.path[:1] != [repo_root]:
        sys.path.insert(0, repo_root)
    from src.gfn.containers.trajectories import Trajectories as Traj

    torch.manual_seed(0)
    env = HyperGrid(ndim=4, height=8, R0=0.01, device_str="cuda")
    sampler = TrajectoriesSampler(
        env=env,
        actions_sampler=DiscreteActionsSampler(
            estimator=LogitPFEstimator(env=env, module_name="NeuralNet"),
        ),
    )

    def _as_local_traj(trajectories):
        # Rebuild a sampled batch as an instance of the in-repo Trajectories class.
        return Traj(
            env=sampler.env,
            states=trajectories.states,
            actions=trajectories.actions,
            when_is_done=trajectories.when_is_done,
            is_backward=sampler.is_backward,
            log_rewards=trajectories.log_rewards,
            log_probs=trajectories.log_probs,
        )

    # Sample both batches before re-wrapping, matching the original RNG order.
    sampled_1 = sampler.sample(n_trajectories=10)
    sampled_2 = sampler.sample(n_trajectories=10)
    trajectories_1 = _as_local_traj(sampled_1)
    trajectories_2 = _as_local_traj(sampled_2)

    # Assumes when_is_done holds one entry per trajectory and that extend
    # concatenates the two batches. TODO: confirm against Trajectories.extend.
    expected_len = len(trajectories_1.when_is_done) + len(trajectories_2.when_is_done)
    trajectories_1.extend(trajectories_2)
    assert len(trajectories_1.when_is_done) == expected_len