torchgfn icon indicating copy to clipboard operation
torchgfn copied to clipboard

Add a test for the `Trajectories.extend` method

Open saleml opened this issue 1 year ago • 0 comments

As was done in https://github.com/saleml/torchgfn/pull/79 for the master branch:

def test_extend_trajectories_on_cuda():
    """Smoke-test ``Trajectories.extend`` on trajectories sampled with ``device_str="cuda"``.

    Two batches of 10 trajectories are sampled from a ``HyperGrid``
    environment, each batch is re-wrapped in the ``Trajectories`` class
    imported from ``src.gfn.containers.trajectories``, and the second
    container is appended to the first with ``extend``.

    The test contains no assertions: it passes as long as nothing raises.
    """
    import os
    import sys

    # NOTE(review): "__file__" is a string literal here, not the module
    # variable, so this path normalizes to the current working directory.
    # Confirm that is intended before changing it; the import below relies
    # on whatever directory ends up at the front of sys.path.
    search_root = os.path.abspath("__file__" + "/../")
    sys.path.insert(0, search_root)

    from src.gfn.containers.trajectories import Trajectories as Traj

    # Seed before anything is built or sampled.
    torch.manual_seed(0)

    env = HyperGrid(ndim=4, height=8, R0=0.01, device_str="cuda")
    estimator = LogitPFEstimator(env=env, module_name="NeuralNet")
    actions_sampler = DiscreteActionsSampler(estimator=estimator)
    sampler = TrajectoriesSampler(env=env, actions_sampler=actions_sampler)

    def rewrap(sampled):
        """Build a ``Traj`` from the fields of an already-sampled batch."""
        return Traj(
            env=sampler.env,
            states=sampled.states,
            actions=sampled.actions,
            when_is_done=sampled.when_is_done,
            is_backward=sampler.is_backward,
            log_rewards=sampled.log_rewards,
            log_probs=sampled.log_probs,
        )

    # Both batches are drawn before either is re-wrapped, in this order.
    first_batch = sampler.sample(n_trajectories=10)
    second_batch = sampler.sample(n_trajectories=10)

    combined = rewrap(first_batch)
    addition = rewrap(second_batch)

    combined.extend(addition)

saleml avatar Aug 02 '23 01:08 saleml