[recipe, algo] feat: Representation-based Exploration (RepExp)
What does this PR do?
Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.
Add support for the training and evaluation of the RepExp method introduced in section 5 of the paper Representation-Based Exploration for Language Models: From Test-Time to Post-Training.
Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pull/1830
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`.
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.
Please refer to Figure 2 in https://arxiv.org/abs/2510.11686 for evaluation results.
API and Usage Example
Demonstrate how the API changes if any, and provide usage example(s) if possible.
The general format for training with our method is as follows:

```sh
sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
```

where `$TASK` is the task name, `$SPARSE_DIM` is the dimension of the sparse random projection, `$BETA` is the weight on the elliptical (intrinsic) reward, and `$SEED` is the random seed.
For example, to train on MATH with the original parameters from the paper, one would run:

```sh
sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
```
Once training is done, one can evaluate the model on the test set in two steps.
- Merge the model checkpoint. This is necessary because the checkpoint is saved in multiple shards (depending on the number of GPUs), and they need to be merged into a single checkpoint.

```sh
sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
```

- Evaluate the merged model.

```sh
sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
```
Design & Code Changes
Demonstrate the high-level design if this PR is complex, and list the specific changes.
All changes are contained in the recipe/rep_exp folder and summarized below:
- `rep_exp/main_rep_exp.py`: copy of `verl/trainer/main_ppo.py` but imports `EllipticalRewardModelWorker` instead of the standard `RewardModelWorker`.
- `rep_exp/rep_exp_trainer.py`: copy of `verl/trainer/ppo/ray_trainer.py` but adds computing of hidden states and elliptical reward scores as follows:
```python
with marked_timer("reward", timing_raw, color="yellow"):
    # compute reward model score
    if self.use_rm and "rm_scores" not in batch.batch.keys():
        if self.config.reward_model.elliptical.enable:
            hidden_states = self.rm_wg.compute_hidden_states(batch)
            batch = batch.union(hidden_states)
            reward_tensor = self.rm_wg.compute_rm_score(batch)
        else:
            reward_tensor = self.rm_wg.compute_rm_score(batch)
        batch = batch.union(reward_tensor)
```
In addition, there are a few lines of code that help keep track of the best pass@1 seen so far on validation.
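For illustration, the bookkeeping could look roughly like the sketch below; the helper name, tracker keys, and the `val/pass@1` metric key are assumptions, not necessarily the exact names used in `rep_exp_trainer.py`:

```python
def update_best_pass_at_1(tracker: dict, val_metrics: dict, global_step: int) -> dict:
    """Hypothetical helper: remember the best validation pass@1 (and its step) seen so far."""
    current = val_metrics.get("val/pass@1")  # metric key is an assumption
    if current is not None and current > tracker.get("best_pass_at_1", float("-inf")):
        tracker["best_pass_at_1"] = current
        tracker["best_global_step"] = global_step  # checkpoint to pick for evaluation later
    return tracker

# Example: after two validation rounds, the tracker points at global step 200.
tracker = {}
tracker = update_best_pass_at_1(tracker, {"val/pass@1": 0.31}, global_step=100)
tracker = update_best_pass_at_1(tracker, {"val/pass@1": 0.37}, global_step=200)
print(tracker)  # {'best_pass_at_1': 0.37, 'best_global_step': 200}
```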
- `rep_exp/metric_utils.py`: adds and extends some utility functions that provide additional metrics for our method that could be helpful for debugging.
- `rep_exp/workers/elliptical_reward_model_worker.py`: adds the `EllipticalRewardModelWorker` class, which provides functionality for (1) computing hidden states and (2) computing elliptical reward scores based on the hidden states:
```python
class EllipticalRewardModelWorker(RewardModelWorker):
    def __init__(self, config):
        super().__init__(config)
        self.lamb = config.elliptical.lamb
        self.normalization = config.elliptical.normalization
        self.sparse_dim = config.elliptical.sparse_dim
        self.sparse_matrix = None
        self.randomize_sparse_matrix = config.elliptical.randomize_sparse_matrix
        self.persist_covariance = config.elliptical.persist_covariance
        self.cov_inv_dict = {}
        self.mean_hidden_states_mu_dict = {}
        self.hidden_mean_counter_dict = {}

    @staticmethod
    def _construct_sparse_matrix(features: torch.Tensor, sparse_dim: int) -> torch.Tensor:
        from sklearn.random_projection import SparseRandomProjection

        sparse_proj = SparseRandomProjection(sparse_dim, density="auto")
        sparse_proj.fit(features)
        sparse_matrix = sparse_proj.components_
        sparse_matrix_coo = sparse_matrix.tocoo()
        # Convert the row and col lists to numpy arrays and then to a LongTensor (speed up)
        indices = torch.LongTensor(np.array([sparse_matrix_coo.row, sparse_matrix_coo.col]))
        values = torch.FloatTensor(sparse_matrix_coo.data)
        sparse_mat = torch.sparse_coo_tensor(indices, values, [sparse_dim, features.shape[1]]).t()
        return sparse_mat

    def _build_model(self, config):
        # the following line is necessary
        from torch.distributed.fsdp import CPUOffload
        from transformers import AutoConfig, AutoModel

        use_shm = config.model.get("use_shm", False)
        # download the checkpoint from hdfs
        local_path = copy_to_local(config.model.path, use_shm=use_shm)
        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)
            self.input_tokenizer = hf_tokenizer(
                input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)
            )
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))

        trust_remote_code = config.model.get("trust_remote_code", False)
        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        model_config.num_labels = 1

        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
        )
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModel.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )
            apply_monkey_patch(
                model=reward_module,
                use_remove_padding=config.model.get("use_remove_padding", False),
                ulysses_sp_size=self.ulysses_sequence_parallel_size,
            )
            reward_module.to(torch.bfloat16)

        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)
        if config.strategy == "fsdp":
            reward_module = FSDP(
                reward_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,  # zero3
                sync_module_states=True,
                cpu_offload=CPUOffload(offload_params=True),
                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
                device_mesh=self.device_mesh,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "offload_policy": cpu_offload,
                "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
                "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
            }
            full_state = reward_module.state_dict()
            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
        else:
            raise NotImplementedError(f"Unknown strategy: {config.strategy}")
        return reward_module

    def _forward_micro_batch(self, micro_batch, start_of_response: int):
        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
            if self.use_remove_padding:
                raise NotImplementedError("Remove padding is not implemented for elliptical reward model")
            else:
                output = self.reward_module(
                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
                )
                sequence_lengths = attention_mask[:, start_of_response:].sum(dim=1)
                mean_hidden_states = []
                for i, seq_len in enumerate(sequence_lengths):
                    mean_hidden_states.append(
                        output.last_hidden_state[i, start_of_response : start_of_response + seq_len].mean(dim=0)
                    )
                mean_hidden_states = torch.stack(mean_hidden_states)
            return mean_hidden_states

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="brown")
    def compute_hidden_states(self, data: DataProto):
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        # Support all hardwares
        data = data.to(get_device_id())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            rm_input_ids = data.batch["input_ids"]
            rm_attention_mask = data.batch["attention_mask"]
            rm_position_ids = data.batch["position_ids"]
            rm_inputs = {
                "input_ids": rm_input_ids,
                "attention_mask": rm_attention_mask,
                "position_ids": rm_position_ids,
            }
            rm_data = DataProto.from_dict(rm_inputs)

        # Support all hardwares
        rm_data = rm_data.to(get_device_id())

        # perform forward computation
        with self.ulysses_sharding_manager:
            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
            output = []
            for micro_batch in micro_batches:
                mean_hidden_states = self._forward_micro_batch(
                    micro_batch, start_of_response=data.batch["prompts"].shape[-1]
                )
                output.append(mean_hidden_states)
            mean_hidden_states = torch.cat(output, dim=0)  # (batch_size, hidden_size)

            # NOTE(Jens): this has not been thoroughly checked
            if use_dynamic_bsz:
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == mean_hidden_states.size(0), f"{len(indices)} vs. {mean_hidden_states.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                mean_hidden_states = mean_hidden_states[revert_indices]

            # Note that this is only the scores, may not be the final rewards used to train RL
            output = DataProto.from_dict(tensors={"mean_hidden_states": mean_hidden_states})

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:
            self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output

    def _compute_bonuses(self, hidden_states, cov_inv, prompt_index: int):
        if self.config.elliptical.reward_type == "leave_one_out":
            if self.persist_covariance:
                raise NotImplementedError("Leave-one-out with persistence is not implemented")
            else:
                bonuses = []
                for i, hidden_state in enumerate(hidden_states):
                    chosen_samp = hidden_state.unsqueeze(1)
                    middle_part = torch.inverse(1 - chosen_samp.t() @ cov_inv @ chosen_samp)
                    leave_one_out_cov_inv = cov_inv + cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
                    bonus = (chosen_samp.t() @ leave_one_out_cov_inv @ chosen_samp).flatten().float()
                    bonuses.append(bonus)
                bonuses = torch.concat(bonuses)
        elif self.config.elliptical.reward_type == "leverage":
            if self.persist_covariance:
                hidden_mean = self.mean_hidden_states_mu_dict[prompt_index]
                hidden_mean_counter = self.hidden_mean_counter_dict[prompt_index]
                hidden_states = hidden_states - hidden_mean
                numerator = cov_inv @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv
                denominator = -1 / hidden_mean_counter + hidden_mean.t() @ cov_inv @ hidden_mean
                cov_inv_mean_adjusted = cov_inv - numerator / denominator
                batch_cov_inv = cov_inv_mean_adjusted.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
            else:
                batch_cov_inv = cov_inv.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
            bonuses = (hidden_states.unsqueeze(1) @ batch_cov_inv @ hidden_states.unsqueeze(2)).flatten().float()
        return bonuses

    def _normalize_bonuses(self, bonuses):
        if self.normalization == "none":
            pass
        elif self.normalization == "rnd":
            std = torch.std(bonuses)
            if std > 0:
                bonuses = bonuses / std
        elif self.normalization == "z_score":
            mean = torch.mean(bonuses)
            std = torch.std(bonuses)
            if std > 0:
                bonuses = (bonuses - mean) / std
            else:
                bonuses = bonuses - mean
        else:
            raise ValueError(f"Unknown normalization: {self.normalization}")
        return bonuses

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    @DistProfiler.annotate(color="brown")
    def compute_rm_score(self, data: DataProto):
        if self.sparse_matrix is None:
            d = data.batch["mean_hidden_states"].shape[-1]
            sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim)
            if not self.randomize_sparse_matrix:
                self.sparse_matrix = sparse_matrix
        else:
            sparse_matrix = self.sparse_matrix

        mean_hidden_states = data.batch["mean_hidden_states"].cuda().float()
        # sparse project
        mean_hidden_states = mean_hidden_states @ sparse_matrix.cuda()
        # upgrade to float64
        mean_hidden_states = mean_hidden_states.to(torch.float64)

        seen_uids = set()
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).cuda()
        raw_bonuses_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).cuda()
        for i in range(len(data)):
            data_item = data[i]
            uid = data_item.non_tensor_batch["uid"]
            if uid in seen_uids:
                continue
            seen_uids.add(uid)

            mask = data.non_tensor_batch["uid"] == uid
            filtered_mean_hidden_states = mean_hidden_states[mask]
            prompt_index = data_item.non_tensor_batch["extra_info"]["index"]

            if self.persist_covariance:
                # first update the mean hidden states mu
                if prompt_index not in self.mean_hidden_states_mu_dict:
                    self.mean_hidden_states_mu_dict[prompt_index] = filtered_mean_hidden_states.mean(dim=0)
                    self.hidden_mean_counter_dict[prompt_index] = mask.sum()
                else:
                    total_count = self.hidden_mean_counter_dict[prompt_index] + mask.sum()
                    old_mu = self.mean_hidden_states_mu_dict[prompt_index]
                    new_mu = (
                        old_mu * self.hidden_mean_counter_dict[prompt_index]
                        + filtered_mean_hidden_states.mean(dim=0) * mask.sum()
                    ) / total_count
                    self.mean_hidden_states_mu_dict[prompt_index] = new_mu
                    self.hidden_mean_counter_dict[prompt_index] = total_count
                # NOTE: we don't center here since otherwise the covariance will accumulate stale means
                final_mean_hidden_states = filtered_mean_hidden_states
                if prompt_index not in self.cov_inv_dict:
                    d = final_mean_hidden_states.shape[-1]
                    self.cov_inv_dict[prompt_index] = torch.eye(d, dtype=torch.float64).cuda() * self.lamb**-1
                cov_inv = self.cov_inv_dict[prompt_index]
            else:
                centered_mean_hidden_states = filtered_mean_hidden_states - filtered_mean_hidden_states.mean(dim=0)
                final_mean_hidden_states = centered_mean_hidden_states
                d = final_mean_hidden_states.shape[-1]
                cov_inv = torch.eye(d, dtype=torch.float64).cuda() * self.lamb**-1

            # update inverse covariance matrix with rank-1 updates
            for hidden_state in final_mean_hidden_states:
                chosen_samp = hidden_state.unsqueeze(1)
                middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp)
                cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
            if self.persist_covariance:
                self.cov_inv_dict[prompt_index] = cov_inv

            raw_bonuses = self._compute_bonuses(final_mean_hidden_states, cov_inv, prompt_index)
            normalized_bonuses = self._normalize_bonuses(raw_bonuses)

            prompt_ids = data.batch["prompts"][mask]
            prompt_length = prompt_ids.shape[-1]
            valid_response_lengths = data.batch["attention_mask"][mask, prompt_length:].sum(-1)
            raw_bonuses_tensor[mask, valid_response_lengths - 1] = raw_bonuses
            reward_tensor[mask, valid_response_lengths - 1] = normalized_bonuses

        output = DataProto.from_dict(
            tensors={"rm_scores": reward_tensor}, non_tensors={"raw_bonuses": raw_bonuses_tensor.cpu().numpy()}
        )
        return output.to("cpu")
```
- `rep_exp/reward_manager`: adds the `EllipticalRewardManager` class that handles combining the external reward (from a verifier) and the elliptical reward:
@register("elliptical")
class EllipticalRewardManager(NaiveRewardManager):
"""The reward manager."""
def __init__(
self,
tokenizer,
num_examine,
compute_score=None,
reward_fn_key="data_source",
beta: int = 1.0,
turn_off_elliptical_if_none_correct: bool = False,
turn_off_elliptical_if_some_correct: bool = False,
turn_off_elliptical_if_all_correct: bool = False,
turn_off_elliptical_if_rollout_incorrect: bool = False,
alpha: float = 1.0,
) -> None:
"""
Initialize the NaiveRewardManager instance.
Args:
tokenizer: The tokenizer used to decode token IDs into text.
num_examine: The number of batches of decoded responses to print to the console for debugging purpose.
compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.
reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
"data_source".
"""
super().__init__(tokenizer, num_examine, default_compute_score, reward_fn_key)
self.beta = beta
self.turn_off_elliptical_if_none_correct = turn_off_elliptical_if_none_correct
self.turn_off_elliptical_if_some_correct = turn_off_elliptical_if_some_correct
self.turn_off_elliptical_if_all_correct = turn_off_elliptical_if_all_correct
self.turn_off_elliptical_if_rollout_incorrect = turn_off_elliptical_if_rollout_incorrect
self.alpha = alpha
def __call__(self, data: DataProto, return_dict=False):
if "rm_scores" not in data.batch:
# this means we're doing validation, so we don't need to compute the elliptical reward
return super().__call__(data, return_dict=return_dict)
reward_extra_info = defaultdict(list)
intrinsic_reward_tensor = data.batch["rm_scores"]
data.pop(batch_keys=["rm_scores"])
extrinsic_reward_result = super().__call__(data, return_dict=True)
extrinsic_reward_tensor = extrinsic_reward_result["reward_tensor"]
extrinsic_reward_extra_info = extrinsic_reward_result["reward_extra_info"]
self._maybe_turn_off_elliptical(data, extrinsic_reward_tensor, intrinsic_reward_tensor)
reward_tensor = self.alpha * extrinsic_reward_tensor + self.beta * intrinsic_reward_tensor
# Intrinsic reward extra info
reward_extra_info["intrinsic_reward"] = intrinsic_reward_tensor.numpy()
reward_extra_info["beta_scaled_intrinsic_reward"] = self.beta * intrinsic_reward_tensor.numpy()
reward_extra_info["extrinsic_reward"] = extrinsic_reward_tensor.numpy()
reward_extra_info["alpha_scaled_extrinsic_reward"] = self.alpha * extrinsic_reward_tensor.numpy()
reward_extra_info["total_reward"] = reward_tensor.numpy()
# Update with extrinsic reward extra info
reward_extra_info.update(extrinsic_reward_extra_info)
if return_dict:
return {
"reward_tensor": reward_tensor,
"reward_extra_info": reward_extra_info,
}
else:
return reward_tensor
def _maybe_turn_off_elliptical(
self, data: DataProto, extrinsic_reward_tensor: torch.Tensor, intrinsic_reward_tensor: torch.Tensor
) -> None:
"""
Potentially turn off the elliptical reward for samples that have one of the following properties:
(1) any of the rollouts have the correct answer
(2) all of the rollouts have the correct answer
Args:
data (DataProto): The data proto containing the batch data.
extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
intrinsic_reward_tensor (torch.Tensor): The intrinsic reward tensor.
Returns:
None
"""
if self.turn_off_elliptical_if_rollout_incorrect:
mask = extrinsic_reward_tensor.sum(dim=-1) == 0
intrinsic_reward_tensor[mask] = 0.0
visited_uids = set()
for uid in data.non_tensor_batch["uid"]:
if uid in visited_uids:
continue
visited_uids.add(uid)
mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
# Potentially turn off elliptical if **no** rollout has the correct answer
if self.turn_off_elliptical_if_none_correct and extrinsic_reward_tensor[mask].sum() == 0:
intrinsic_reward_tensor[mask] = 0.0
# Potentially turn off elliptical if **some** rollouts have the correct answer
if (
self.turn_off_elliptical_if_some_correct
and extrinsic_reward_tensor[mask].sum() > 0
and extrinsic_reward_tensor[mask].sum() < mask.sum()
):
intrinsic_reward_tensor[mask] = 0.0
# Potentially turn off elliptical if **all** rollouts have the correct answer
if self.turn_off_elliptical_if_all_correct and extrinsic_reward_tensor[mask].sum() == mask.sum():
intrinsic_reward_tensor[mask] = 0.0
- `rep_exp/utils/tracking.py` and `rep_exp/utils/aggregate_logger.py`: add the `JsonEvalLogger`, which can be used to log final evaluation results to a JSON file:
```python
class JsonEvalLogger:
    """
    A logger that logs to a json file.

    Args:
        save_path: The path to the checkpoint to resume from.
        task: The task name, used to name the experiment.
    """

    def __init__(self, save_path: str, task: str):
        self.root = "eval"
        if save_path is not None and save_path != "":
            self.experiment_name = save_path.split("/")[-2]
            self.checkpoint_type = save_path.split("/")[-1]
        else:
            self.experiment_name = f"{task}_untrained"
            self.checkpoint_type = ""

    def flush(self):
        pass

    def log(self, data, step):
        # Create eval folder
        save_folder = os.path.join(self.root, self.experiment_name, self.checkpoint_type)
        os.makedirs(save_folder, exist_ok=True)
        # Save to json
        with open(os.path.join(save_folder, "eval.json"), "w") as f:
            json.dump(data, f)
```
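For reference, a hypothetical invocation of the logger (the checkpoint path and metric keys are illustrative, not taken from the PR) would write its results to `eval/<experiment_name>/<checkpoint_type>/eval.json`:

```python
# Given .../my_experiment/global_step_100, results land in
# eval/my_experiment/global_step_100/eval.json
logger = JsonEvalLogger(save_path="checkpoints/my_experiment/global_step_100", task="math")
logger.log({"test/pass@1": 0.42, "test/pass@8": 0.63}, step=0)  # metric keys are assumptions
```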
- `rep_exp/data_preprocess.py`: contains a script for each of MATH, GSM8K, and AIME 2024 that provides the logic for getting the dataset splits (train, dev, test).
- `rep_exp/reward_score/__init__.py`: copy of `verl/utils/reward_score/__init__.py` but uses `math_verify` for DAPO training:
```python
elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"):
    # res = math_dapo.compute_score(solution_str, ground_truth)
    from verl.utils.reward_score import math_verify

    res = math_verify.compute_score(solution_str, ground_truth)
```
- `rep_exp/plot_pass_at_k.py`: sample script that provides basic plotting code for a pass@k curve based on the logged JSON files saved after running the evaluation script (see the sketch below).
- `rep_exp/config/rep_exp_trainer.yaml`: overwrites and adds any RepExp-specific configuration parameters.
- `rep_exp/train_elliptical.sh`: training script.
- `rep_exp/model_merge.sh`: script to merge model checkpoints.
- `rep_exp/eval.sh`: evaluation script.
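As a rough idea of what the plotting step does, below is a minimal sketch; the JSON layout and metric keys are assumptions rather than the PR's exact format:

```python
import json

import matplotlib.pyplot as plt

# Assumes eval.json maps metric names like "test/pass@1", "test/pass@2", ... to scores;
# the actual keys written by the evaluation script may differ.
with open("eval/my_experiment/global_step_100/eval.json") as f:
    metrics = json.load(f)

ks = sorted(int(name.split("@")[-1]) for name in metrics if "pass@" in name)
scores = [metrics[f"test/pass@{k}"] for k in ks]

plt.plot(ks, scores, marker="o")
plt.xscale("log", base=2)
plt.xlabel("k")
plt.ylabel("pass@k")
plt.title("pass@k (illustrative)")
plt.savefig("pass_at_k.png", dpi=200)
```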
Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- [x] Read the Contribute Guide.
- [x] Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update the documentation.
- [x] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: algorithm support.
- [x] Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
@jens321 Hi, thanks for your contribution. We're moving recipes to a separate project, verl-project/verl-recipe; could you submit a PR to that project instead? https://github.com/volcengine/verl/pull/4283
@wuxibin89 Thanks for the quick reply! Sounds good, will go ahead and close this one then.