
[recipe, algo] feat: Representation-based Exploration (RepExp)

Open • jens321 opened this issue 6 days ago • 1 comment

What does this PR do?


Adds support for training and evaluating RepExp, the method introduced in Section 5 of the paper Representation-Based Exploration for Language Models: From Test-Time to Post-Training (https://arxiv.org/abs/2510.11686).

Checklist Before Starting

  • [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pull/1830
  • [x] Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with commas, like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test


Please refer to Figure 2 in https://arxiv.org/abs/2510.11686 for evaluation results.

API and Usage Example


The general format for training with our method is:

sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED

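where $TASK is the task name (e.g. math), $SPARSE_DIM is the output dimension of the sparse random projection applied to the model's hidden states, $BETA is the weight on the elliptical (intrinsic) reward, and $SEED is the random seed.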

For example, to train on MATH with the hyperparameters used in the paper, run:

sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
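For intuition about $SPARSE_DIM (32 in the example above): the worker projects each mean hidden state down to $SPARSE_DIM dimensions with scikit-learn's SparseRandomProjection before building the per-prompt covariance, so the inverse covariance it maintains is only $SPARSE_DIM x $SPARSE_DIM rather than hidden-size x hidden-size. The NumPy sketch below re-creates the entry distribution that scikit-learn documents for density="auto"; it is an illustration under that assumption, not the recipe's code, and the hidden size 2048 and the helper name are made up.

```python
import numpy as np


def sparse_random_projection_matrix(d: int, k: int, rng: np.random.Generator) -> np.ndarray:
    """Dense (d, k) projection whose entries follow scikit-learn's density="auto" scheme.

    With s = sqrt(d), each entry is -sqrt(s / k) or +sqrt(s / k) with probability 1 / (2s)
    each, and 0 with probability 1 - 1 / s.
    """
    s = np.sqrt(d)
    probs = [1 / (2 * s), 1 - 1 / s, 1 / (2 * s)]
    signs = rng.choice([-1.0, 0.0, 1.0], size=(d, k), p=probs)
    return signs * np.sqrt(s / k)


rng = np.random.default_rng(0)
d, k = 2048, 32  # d is an illustrative hidden size; k plays the role of $SPARSE_DIM
proj = sparse_random_projection_matrix(d, k, rng)

hidden = rng.normal(size=(8, d))  # stand-in for 8 mean response hidden states
projected = hidden @ proj         # shape (8, 32): the covariance is then only 32 x 32
print(projected.shape)  # (8, 32)
# Squared norms are preserved in expectation, so these ratios scatter around 1.
print(np.linalg.norm(projected, axis=1) / np.linalg.norm(hidden, axis=1))
```

A smaller $SPARSE_DIM makes the covariance updates cheaper but distorts distances between rollouts more.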

After training, evaluate the model on the test set in two steps.

  1. Merge the model checkpoint.

This step is needed because the checkpoint is saved as multiple shards (depending on the number of GPUs), which must be merged into a single checkpoint.

sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
  2. Evaluate the merged model.
sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev

Design & Code Changes


All changes are contained in the recipe/rep_exp folder and summarized below:

  • rep_exp/main_rep_exp.py: copy of verl/trainer/main_ppo.py but imports EllipticalRewardModelWorker instead of the standard RewardModelWorker
  • rep_exp/rep_exp_trainer.py: copy of verl/trainer/ppo/ray_trainer.py that additionally computes hidden states and elliptical reward scores, as follows:
with marked_timer("reward", timing_raw, color="yellow"):
    # compute reward model score
    if self.use_rm and "rm_scores" not in batch.batch.keys():
        if self.config.reward_model.elliptical.enable:
            hidden_states = self.rm_wg.compute_hidden_states(batch)
            batch = batch.union(hidden_states)
            reward_tensor = self.rm_wg.compute_rm_score(batch)
        else:
            reward_tensor = self.rm_wg.compute_rm_score(batch)
        batch = batch.union(reward_tensor)

It also adds a few lines of code that keep track of the best validation pass@1 seen so far.
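The PR does not show the lines that track the best validation pass@1. As a rough sketch only, best-so-far tracking could look like the following; BestMetricTracker and the metric key "val/pass@1" are hypothetical names, not the recipe's code or verl's metric names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestMetricTracker:
    """Hypothetical helper: remembers the best value of one validation metric and its step."""

    metric_key: str
    best_value: Optional[float] = None
    best_step: Optional[int] = None

    def update(self, metrics: dict, step: int) -> bool:
        """Record metrics[metric_key] at `step`; return True if it is a new best."""
        value = metrics.get(self.metric_key)
        if value is None:
            return False  # metric not computed at this step (e.g. no validation run)
        if self.best_value is None or value > self.best_value:
            self.best_value, self.best_step = value, step
            return True
        return False


tracker = BestMetricTracker(metric_key="val/pass@1")  # placeholder key
tracker.update({"val/pass@1": 0.41}, step=50)
tracker.update({"val/pass@1": 0.38}, step=100)
print(tracker.best_step, tracker.best_value)  # 50 0.41
```

The stored best_step corresponds to the X used in the merge and evaluation commands.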

  • rep_exp/metric_utils.py: adds and extends utility functions that report additional RepExp-specific metrics, which can be helpful for debugging.
  • rep_exp/workers/elliptical_reward_model_worker.py: adds the EllipticalRewardModelWorker class which provides functionality for (1) computing hidden states and (2) computing elliptical reward scores based on the hidden states.
class EllipticalRewardModelWorker(RewardModelWorker):
    def __init__(self, config):
        super().__init__(config)
        self.lamb = config.elliptical.lamb
        self.normalization = config.elliptical.normalization
        self.sparse_dim = config.elliptical.sparse_dim
        self.sparse_matrix = None
        self.randomize_sparse_matrix = config.elliptical.randomize_sparse_matrix
        self.persist_covariance = config.elliptical.persist_covariance
        self.cov_inv_dict = {}
        self.mean_hidden_states_mu_dict = {}
        self.hidden_mean_counter_dict = {}

    @staticmethod
    def _construct_sparse_matrix(features: torch.Tensor, sparse_dim: int) -> torch.Tensor:
        from sklearn.random_projection import SparseRandomProjection

        sparse_proj = SparseRandomProjection(sparse_dim, density="auto")
        sparse_proj.fit(features)
        sparse_matrix = sparse_proj.components_
        sparse_matrix_coo = sparse_matrix.tocoo()

        # Convert the row and col lists to numpy arrays and then to a LongTensor (speed up)
        indices = torch.LongTensor(np.array([sparse_matrix_coo.row, sparse_matrix_coo.col]))
        values = torch.FloatTensor(sparse_matrix_coo.data)

        sparse_mat = torch.sparse_coo_tensor(indices, values, [sparse_dim, features.shape[1]]).t()

        return sparse_mat

    def _build_model(self, config):
        # the following line is necessary
        from torch.distributed.fsdp import CPUOffload
        from transformers import AutoConfig, AutoModel

        use_shm = config.model.get("use_shm", False)
        # download the checkpoint from hdfs
        local_path = copy_to_local(config.model.path, use_shm=use_shm)

        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)
            self.input_tokenizer = hf_tokenizer(
                input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)
            )
            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))

        trust_remote_code = config.model.get("trust_remote_code", False)
        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        model_config.num_labels = 1

        # the reward model is only used for inference (no optimizer), so it is loaded directly in bf16
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModel.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            apply_monkey_patch(
                model=reward_module,
                use_remove_padding=config.model.get("use_remove_padding", False),
                ulysses_sp_size=self.ulysses_sequence_parallel_size,
            )

            reward_module.to(torch.bfloat16)

        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        if config.strategy == "fsdp":
            reward_module = FSDP(
                reward_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,  # zero3
                sync_module_states=True,
                cpu_offload=CPUOffload(offload_params=True),
                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
                device_mesh=self.device_mesh,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "offload_policy": cpu_offload,
                "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
                "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
            }
            full_state = reward_module.state_dict()
            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
        else:
            raise NotImplementedError(f"Unknown strategy: {config.strategy}")
        return reward_module

    def _forward_micro_batch(self, micro_batch, start_of_response: int):
        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

            if self.use_remove_padding:
                raise NotImplementedError("Remove padding is not implemented for elliptical reward model")
            else:
                output = self.reward_module(
                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
                )

                sequence_lengths = attention_mask[:, start_of_response:].sum(dim=1)
                mean_hidden_states = []
                for i, seq_len in enumerate(sequence_lengths):
                    mean_hidden_states.append(
                        output.last_hidden_state[i, start_of_response : start_of_response + seq_len].mean(dim=0)
                    )
                mean_hidden_states = torch.stack(mean_hidden_states)

            return mean_hidden_states

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="brown")
    def compute_hidden_states(self, data: DataProto):
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        # Support all hardwares
        data = data.to(get_device_id())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            rm_input_ids = data.batch["input_ids"]
            rm_attention_mask = data.batch["attention_mask"]
            rm_position_ids = data.batch["position_ids"]
            rm_inputs = {
                "input_ids": rm_input_ids,
                "attention_mask": rm_attention_mask,
                "position_ids": rm_position_ids,
            }
            rm_data = DataProto.from_dict(rm_inputs)

        # Support all hardwares
        rm_data = rm_data.to(get_device_id())

        # perform forward computation
        with self.ulysses_sharding_manager:
            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
            output = []
            for micro_batch in micro_batches:
                mean_hidden_states = self._forward_micro_batch(
                    micro_batch, start_of_response=data.batch["prompts"].shape[-1]
                )
                output.append(mean_hidden_states)
            mean_hidden_states = torch.cat(output, dim=0)  # (batch_size, hidden_size)

            # NOTE(Jens): this has not been thoroughly checked
            if use_dynamic_bsz:
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == mean_hidden_states.size(0), f"{len(indices)} vs. {mean_hidden_states.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                mean_hidden_states = mean_hidden_states[revert_indices]

            # these are only the mean response hidden states; elliptical rewards are computed later in compute_rm_score
            output = DataProto.from_dict(tensors={"mean_hidden_states": mean_hidden_states})

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # reshard the root FSDP module to free its unsharded parameters
        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:
            self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output

    def _compute_bonuses(self, hidden_states, cov_inv, prompt_index: int):
        if self.config.elliptical.reward_type == "leave_one_out":
            if self.persist_covariance:
                raise NotImplementedError("Leave-one-out with persistence is not implemented")
            else:
                bonuses = []
                for i, hidden_state in enumerate(hidden_states):
                    chosen_samp = hidden_state.unsqueeze(1)
                    middle_part = torch.inverse(1 - chosen_samp.t() @ cov_inv @ chosen_samp)
                    leave_one_out_cov_inv = cov_inv + cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
                    bonus = (chosen_samp.t() @ leave_one_out_cov_inv @ chosen_samp).flatten().float()
                    bonuses.append(bonus)

                bonuses = torch.concat(bonuses)

        elif self.config.elliptical.reward_type == "leverage":
            if self.persist_covariance:
                hidden_mean = self.mean_hidden_states_mu_dict[prompt_index]
                hidden_mean_counter = self.hidden_mean_counter_dict[prompt_index]

                hidden_states = hidden_states - hidden_mean

                numerator = cov_inv @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv
                denominator = -1 / hidden_mean_counter + hidden_mean.t() @ cov_inv @ hidden_mean
                cov_inv_mean_adjusted = cov_inv - numerator / denominator
                batch_cov_inv = cov_inv_mean_adjusted.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
            else:
                batch_cov_inv = cov_inv.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)

            bonuses = (hidden_states.unsqueeze(1) @ batch_cov_inv @ hidden_states.unsqueeze(2)).flatten().float()

        else:
            raise ValueError(f"Unknown elliptical reward_type: {self.config.elliptical.reward_type}")

        return bonuses

    def _normalize_bonuses(self, bonuses):
        if self.normalization == "none":
            pass
        elif self.normalization == "rnd":
            std = torch.std(bonuses)
            if std > 0:
                bonuses = bonuses / std
        elif self.normalization == "z_score":
            mean = torch.mean(bonuses)
            std = torch.std(bonuses)
            if std > 0:
                bonuses = (bonuses - mean) / std
            else:
                bonuses = bonuses - mean
        else:
            raise ValueError(f"Unknown normalization: {self.normalization}")

        return bonuses

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    @DistProfiler.annotate(color="brown")
    def compute_rm_score(self, data: DataProto):
        if self.sparse_matrix is None:
            d = data.batch["mean_hidden_states"].shape[-1]
            sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim)
            if not self.randomize_sparse_matrix:
                self.sparse_matrix = sparse_matrix
        else:
            sparse_matrix = self.sparse_matrix

        mean_hidden_states = data.batch["mean_hidden_states"].cuda().float()

        # sparse project
        mean_hidden_states = mean_hidden_states @ sparse_matrix.cuda()

        # upgrade to float64
        mean_hidden_states = mean_hidden_states.to(torch.float64)

        seen_uids = set()
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).cuda()
        raw_bonuses_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).cuda()
        for i in range(len(data)):
            data_item = data[i]
            uid = data_item.non_tensor_batch["uid"]
            if uid in seen_uids:
                continue

            seen_uids.add(uid)
            mask = data.non_tensor_batch["uid"] == uid
            filtered_mean_hidden_states = mean_hidden_states[mask]

            prompt_index = data_item.non_tensor_batch["extra_info"]["index"]

            if self.persist_covariance:
                # first update the mean hidden states mu
                if prompt_index not in self.mean_hidden_states_mu_dict:
                    self.mean_hidden_states_mu_dict[prompt_index] = filtered_mean_hidden_states.mean(dim=0)
                    self.hidden_mean_counter_dict[prompt_index] = mask.sum()
                else:
                    total_count = self.hidden_mean_counter_dict[prompt_index] + mask.sum()
                    old_mu = self.mean_hidden_states_mu_dict[prompt_index]
                    new_mu = (
                        old_mu * self.hidden_mean_counter_dict[prompt_index]
                        + filtered_mean_hidden_states.mean(dim=0) * mask.sum()
                    ) / total_count
                    self.mean_hidden_states_mu_dict[prompt_index] = new_mu
                    self.hidden_mean_counter_dict[prompt_index] = total_count

                # NOTE: we don't center here since otherwise the covariance will accumulate stale means
                final_mean_hidden_states = filtered_mean_hidden_states

                if prompt_index not in self.cov_inv_dict:
                    d = final_mean_hidden_states.shape[-1]
                    self.cov_inv_dict[prompt_index] = torch.eye(d, dtype=torch.float64).cuda() * self.lamb**-1
                cov_inv = self.cov_inv_dict[prompt_index]
            else:
                centered_mean_hidden_states = filtered_mean_hidden_states - filtered_mean_hidden_states.mean(dim=0)
                final_mean_hidden_states = centered_mean_hidden_states

                d = final_mean_hidden_states.shape[-1]
                cov_inv = torch.eye(d, dtype=torch.float64).cuda() * self.lamb**-1

            # update inverse covariance matrix with rank-1 updates
            for hidden_state in final_mean_hidden_states:
                chosen_samp = hidden_state.unsqueeze(1)
                middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp)
                cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv

            if self.persist_covariance:
                self.cov_inv_dict[prompt_index] = cov_inv

            raw_bonuses = self._compute_bonuses(final_mean_hidden_states, cov_inv, prompt_index)
            normalized_bonuses = self._normalize_bonuses(raw_bonuses)

            prompt_ids = data.batch["prompts"][mask]
            prompt_length = prompt_ids.shape[-1]
            valid_response_lengths = data.batch["attention_mask"][mask, prompt_length:].sum(-1)

            raw_bonuses_tensor[mask, valid_response_lengths - 1] = raw_bonuses
            reward_tensor[mask, valid_response_lengths - 1] = normalized_bonuses

        output = DataProto.from_dict(
            tensors={"rm_scores": reward_tensor}, non_tensors={"raw_bonuses": raw_bonuses_tensor.cpu().numpy()}
        )
        return output.to("cpu")
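compute_rm_score keeps the inverse of the regularized covariance A = lamb * I + sum_i x_i x_i^T up to date with Sherman-Morrison rank-1 updates, and _compute_bonuses reads the bonus off that inverse: x^T A^-1 x for "leverage", and x^T (A - x x^T)^-1 x for "leave_one_out". With persist_covariance, A accumulates uncentered features across steps, and the "leverage" branch instead uses the inverse of A - n * mu mu^T. The NumPy sketch below checks these identities against direct inversion on random features; NumPy stands in for the worker's PyTorch code, and the helper names are hypothetical.

```python
import numpy as np


def rank1_inverse_updates(features: np.ndarray, lamb: float) -> np.ndarray:
    """(lamb * I + sum_i x_i x_i^T)^{-1}, built with one Sherman-Morrison update per row."""
    d = features.shape[1]
    cov_inv = np.eye(d) / lamb  # inverse of the prior lamb * I
    for x in features:
        x = x[:, None]  # column vector, shape (d, 1)
        denom = 1.0 + (x.T @ cov_inv @ x).item()
        cov_inv = cov_inv - (cov_inv @ x @ x.T @ cov_inv) / denom
    return cov_inv


def leverage_bonuses(features: np.ndarray, cov_inv: np.ndarray) -> np.ndarray:
    """x_i^T A^{-1} x_i for each row (the "leverage" reward type)."""
    return np.einsum("bi,ij,bj->b", features, cov_inv, features)


def leave_one_out_bonuses(features: np.ndarray, cov_inv: np.ndarray) -> np.ndarray:
    """x_i^T (A - x_i x_i^T)^{-1} x_i via a Sherman-Morrison downdate ("leave_one_out")."""
    bonuses = []
    for x in features:
        x = x[:, None]
        denom = 1.0 - (x.T @ cov_inv @ x).item()
        loo_inv = cov_inv + (cov_inv @ x @ x.T @ cov_inv) / denom
        bonuses.append((x.T @ loo_inv @ x).item())
    return np.array(bonuses)


def mean_adjusted_inverse(cov_inv: np.ndarray, mu: np.ndarray, n: int) -> np.ndarray:
    """(A - n * mu mu^T)^{-1} from A^{-1}, as in the persistent "leverage" branch."""
    v = cov_inv @ mu  # A^{-1} mu (A^{-1} is symmetric)
    return cov_inv - np.outer(v, v) / (-1.0 / n + mu @ v)


rng = np.random.default_rng(0)
lamb = 0.1
feats = rng.normal(size=(8, 4))     # 8 rollouts of one prompt, 4-dim projected features
feats = feats - feats.mean(axis=0)  # centered, as in the non-persistent branch

cov_inv = rank1_inverse_updates(feats, lamb)
direct = np.linalg.inv(lamb * np.eye(4) + feats.T @ feats)
print(np.allclose(cov_inv, direct))  # True

lev = leverage_bonuses(feats, cov_inv)
loo = leave_one_out_bonuses(feats, cov_inv)
print(np.allclose(loo, lev / (1 - lev)))  # True
```

The last check shows that, within one prompt group, leave_one_out = leverage / (1 - leverage), so both reward types rank rollouts the same way; they differ only in scale, which the normalization step then changes.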
  • rep_exp/reward_manager: adds the EllipticalRewardManager class that handles combining the external reward (from a verifier) and the elliptical reward.
@register("elliptical")
class EllipticalRewardManager(NaiveRewardManager):
    """The reward manager."""

    def __init__(
        self,
        tokenizer,
        num_examine,
        compute_score=None,
        reward_fn_key="data_source",
        beta: float = 1.0,
        turn_off_elliptical_if_none_correct: bool = False,
        turn_off_elliptical_if_some_correct: bool = False,
        turn_off_elliptical_if_all_correct: bool = False,
        turn_off_elliptical_if_rollout_incorrect: bool = False,
        alpha: float = 1.0,
    ) -> None:
        """
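        Initialize the EllipticalRewardManager instance.

        Args:
            tokenizer: The tokenizer used to decode token IDs into text.
            num_examine: The number of batches of decoded responses to print to the console for debugging purpose.
            compute_score: Unused; `default_compute_score` is always passed to the parent class.
            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
                "data_source".
            beta: Weight on the intrinsic (elliptical) reward.
            turn_off_elliptical_if_none_correct: Zero the intrinsic reward for prompts where no rollout is correct.
            turn_off_elliptical_if_some_correct: Zero it for prompts where some, but not all, rollouts are correct.
            turn_off_elliptical_if_all_correct: Zero it for prompts where every rollout is correct.
            turn_off_elliptical_if_rollout_incorrect: Zero it for each rollout whose extrinsic reward is 0.
            alpha: Weight on the extrinsic (verifier) reward.
        """
        # NOTE: the `compute_score` argument is ignored; `default_compute_score` is always used.
        super().__init__(tokenizer, num_examine, default_compute_score, reward_fn_key)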
        self.beta = beta
        self.turn_off_elliptical_if_none_correct = turn_off_elliptical_if_none_correct
        self.turn_off_elliptical_if_some_correct = turn_off_elliptical_if_some_correct
        self.turn_off_elliptical_if_all_correct = turn_off_elliptical_if_all_correct
        self.turn_off_elliptical_if_rollout_incorrect = turn_off_elliptical_if_rollout_incorrect
        self.alpha = alpha

    def __call__(self, data: DataProto, return_dict=False):
        if "rm_scores" not in data.batch:
            # this means we're doing validation, so we don't need to compute the elliptical reward
            return super().__call__(data, return_dict=return_dict)

        reward_extra_info = defaultdict(list)

        intrinsic_reward_tensor = data.batch["rm_scores"]
        data.pop(batch_keys=["rm_scores"])

        extrinsic_reward_result = super().__call__(data, return_dict=True)
        extrinsic_reward_tensor = extrinsic_reward_result["reward_tensor"]
        extrinsic_reward_extra_info = extrinsic_reward_result["reward_extra_info"]

        self._maybe_turn_off_elliptical(data, extrinsic_reward_tensor, intrinsic_reward_tensor)

        reward_tensor = self.alpha * extrinsic_reward_tensor + self.beta * intrinsic_reward_tensor

        # Intrinsic reward extra info
        reward_extra_info["intrinsic_reward"] = intrinsic_reward_tensor.numpy()
        reward_extra_info["beta_scaled_intrinsic_reward"] = self.beta * intrinsic_reward_tensor.numpy()
        reward_extra_info["extrinsic_reward"] = extrinsic_reward_tensor.numpy()
        reward_extra_info["alpha_scaled_extrinsic_reward"] = self.alpha * extrinsic_reward_tensor.numpy()
        reward_extra_info["total_reward"] = reward_tensor.numpy()

        # Update with extrinsic reward extra info
        reward_extra_info.update(extrinsic_reward_extra_info)

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        else:
            return reward_tensor

    def _maybe_turn_off_elliptical(
        self, data: DataProto, extrinsic_reward_tensor: torch.Tensor, intrinsic_reward_tensor: torch.Tensor
    ) -> None:
        """
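        Zero the intrinsic (elliptical) reward in place, depending on which flags are enabled:
            (1) per rollout, if that rollout's extrinsic reward is 0 (turn_off_elliptical_if_rollout_incorrect)
            (2) per prompt, if no rollout has the correct answer (turn_off_elliptical_if_none_correct)
            (3) per prompt, if some but not all rollouts have the correct answer (turn_off_elliptical_if_some_correct)
            (4) per prompt, if all rollouts have the correct answer (turn_off_elliptical_if_all_correct)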

        Args:
            data (DataProto): The data proto containing the batch data.
            extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
            intrinsic_reward_tensor (torch.Tensor): The intrinsic reward tensor.

        Returns:
            None
        """
        if self.turn_off_elliptical_if_rollout_incorrect:
            mask = extrinsic_reward_tensor.sum(dim=-1) == 0
            intrinsic_reward_tensor[mask] = 0.0

        visited_uids = set()
        for uid in data.non_tensor_batch["uid"]:
            if uid in visited_uids:
                continue

            visited_uids.add(uid)
            mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)

            # Potentially turn off elliptical if **no** rollout has the correct answer
            if self.turn_off_elliptical_if_none_correct and extrinsic_reward_tensor[mask].sum() == 0:
                intrinsic_reward_tensor[mask] = 0.0

            # Potentially turn off elliptical if **some** rollouts have the correct answer
            if (
                self.turn_off_elliptical_if_some_correct
                and extrinsic_reward_tensor[mask].sum() > 0
                and extrinsic_reward_tensor[mask].sum() < mask.sum()
            ):
                intrinsic_reward_tensor[mask] = 0.0

            # Potentially turn off elliptical if **all** rollouts have the correct answer
            if self.turn_off_elliptical_if_all_correct and extrinsic_reward_tensor[mask].sum() == mask.sum():
                intrinsic_reward_tensor[mask] = 0.0
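The reward manager's logic reduces to an optional gate on the intrinsic reward followed by a weighted sum, alpha * extrinsic + beta * intrinsic. The sketch below restates it at the sequence level (one 0/1 verifier score and one bonus per rollout), which is a simplification: the real manager operates on token-level tensors in which each score sits at the last valid response token. The function and argument names are hypothetical.

```python
import numpy as np


def combine_rewards(uids, extrinsic, intrinsic, alpha=1.0, beta=1.0,
                    off_if_none_correct=False, off_if_some_correct=False,
                    off_if_all_correct=False, off_if_rollout_incorrect=False):
    """Return alpha * extrinsic + beta * intrinsic after zeroing gated intrinsic rewards.

    Assumes one 0/1 extrinsic score and one intrinsic bonus per rollout; rollouts of the
    same prompt share a uid.
    """
    uids = np.asarray(uids)
    extrinsic = np.asarray(extrinsic, dtype=float)
    intrinsic = np.array(intrinsic, dtype=float)  # copy so the caller's array is not modified

    if off_if_rollout_incorrect:
        intrinsic[extrinsic == 0] = 0.0  # per rollout: no bonus for a wrong answer

    for uid in np.unique(uids):
        group = uids == uid  # all rollouts of one prompt
        n_correct, n_total = extrinsic[group].sum(), group.sum()
        if off_if_none_correct and n_correct == 0:
            intrinsic[group] = 0.0
        if off_if_some_correct and 0 < n_correct < n_total:
            intrinsic[group] = 0.0
        if off_if_all_correct and n_correct == n_total:
            intrinsic[group] = 0.0

    return alpha * extrinsic + beta * intrinsic


rewards = combine_rewards(
    uids=["a", "a", "b", "b"],
    extrinsic=[1, 1, 0, 1],           # prompt "a": all correct; prompt "b": one of two correct
    intrinsic=[2.0, -1.0, 0.5, 1.5],  # e.g. normalized elliptical bonuses
    beta=0.1,
    off_if_all_correct=True,
)
print(rewards)  # the bonus survives only for prompt "b"
```

The none/some/all gates compare a group's summed score with its size, so, like the original check extrinsic_reward_tensor[mask].sum() == mask.sum(), they assume binary verifier scores.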
  • rep_exp/utils/tracking.py and rep_exp/utils/aggregate_logger.py: adds the JsonEvalLogger which can be used to log final evaluation results to a json file.
class JsonEvalLogger:
    """
    A logger that logs to a json file.
    Args:
        save_path: The path to the checkpoint to resume from.
        task: The task name, used to name the experiment.
    """

    def __init__(self, save_path: str, task: str):
        self.root = "eval"
        if save_path is not None and save_path != "":
            self.experiment_name = save_path.split("/")[-2]
            self.checkpoint_type = save_path.split("/")[-1]
        else:
            self.experiment_name = f"{task}_untrained"
            self.checkpoint_type = ""

    def flush(self):
        pass

    def log(self, data, step):
        # Create eval folder
        save_folder = os.path.join(self.root, self.experiment_name, self.checkpoint_type)
        os.makedirs(save_folder, exist_ok=True)

        # Save to json
        with open(os.path.join(save_folder, "eval.json"), "w") as f:
            json.dump(data, f)
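JsonEvalLogger derives its output location from save_path: the second-to-last path component becomes the experiment folder, the last component becomes the checkpoint folder, and every call to log overwrites eval.json there. The stand-alone helper below (a hypothetical name that mirrors only the path logic of __init__ and log) makes the resulting layout explicit.

```python
import os
from typing import Optional


def eval_json_path(save_path: Optional[str], task: str, root: str = "eval") -> str:
    """Mirror JsonEvalLogger's naming: <root>/<experiment_name>/<checkpoint_type>/eval.json."""
    if save_path:  # same as `save_path is not None and save_path != ""`
        parts = save_path.split("/")
        experiment_name, checkpoint_type = parts[-2], parts[-1]
    else:
        experiment_name, checkpoint_type = f"{task}_untrained", ""
    return os.path.join(root, experiment_name, checkpoint_type, "eval.json")


print(eval_json_path("/ckpts/math_run/global_step_100", "math"))
# eval/math_run/global_step_100/eval.json  (POSIX paths)
print(eval_json_path(None, "math"))
# eval/math_untrained/eval.json
```

Because the path is split on "/", a trailing slash in save_path shifts the components: the checkpoint name becomes the experiment folder and the checkpoint folder is empty, so paths should be passed without one.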
  • rep_exp/data_preprocess.py: contains preprocessing logic for each of MATH, GSM8K, and AIME 2024 that produces the dataset splits (train, dev, test).
  • rep_exp/reward_score/__init__.py: copy of verl/utils/reward_score/__init__.py, but uses math_verify instead of math_dapo for the math, math_dapo, math_dapo_reasoning, and aime* data sources:
elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"):
    # res = math_dapo.compute_score(solution_str, ground_truth)
    from verl.utils.reward_score import math_verify

    res = math_verify.compute_score(solution_str, ground_truth)
  • rep_exp/plot_pass_at_k.py: a sample script that plots a pass@k curve from the JSON files saved by the evaluation script.

  • rep_exp/config/rep_exp_trainer.yaml: overrides base configuration values and adds the RepExp-specific configuration parameters.

  • rep_exp/train_elliptical.sh: training script

  • rep_exp/model_merge.sh: script to merge model checkpoints

  • rep_exp/eval.sh: evaluation script
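The PR does not say how plot_pass_at_k.py estimates pass@k from the logged JSON files. A common choice is the unbiased estimator from the Codex paper (Chen et al., 2021): with n samples per problem, c of them correct, pass@k = 1 - C(n - c, k) / C(n, k), averaged over problems. The plain-Python sketch below implements that estimator; the per-problem counts are made up, since the eval.json schema is not shown.

```python
from math import comb
from typing import Sequence


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k for one problem: n samples drawn, c of them correct."""
    if k > n:
        raise ValueError("k cannot exceed the number of samples n")
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(num_correct: Sequence[int], n: int, k: int) -> float:
    """Average pass@k over problems, given the number of correct samples per problem."""
    return sum(pass_at_k(n, c, k) for c in num_correct) / len(num_correct)


# Hypothetical number of correct samples (out of n = 4) for three problems.
counts = [0, 1, 4]
for k in (1, 2, 4):
    print(k, round(mean_pass_at_k(counts, n=4, k=k), 4))
# 1 0.4167
# 2 0.5
# 4 0.6667
```

Computing every k from the same n samples, rather than drawing k fresh samples per point, gives a lower-variance curve.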

Checklist Before Submitting

Important: Please check all the following items before requesting a review; otherwise the reviewer might deprioritize this PR for review.

jens321, Nov 24 '25 20:11

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

CLAassistant, Nov 24 '25 20:11

@jens321 Hi, thanks for your contribution. We're moving recipes to a separate project, verl-project/verl-recipe. Could you submit a PR to that project? https://github.com/volcengine/verl/pull/4283

wuxibin89, Nov 25 '25 05:11

@wuxibin89 Thanks for the quick reply! Sounds good, will go ahead and close this one then.

jens321, Nov 25 '25 11:11