RL4LMs
RL4LMs copied to clipboard
Problem with BLEURT reward function
BLEURT reward function fails with TypeError: cannot pickle '_thread.RLock' object
in multiprocessing environments.
This is probably because the TensorFlow model used by BLEURT cannot be pickled when it is sent to the environment subprocess.
Tested in both a local environment and Google Colab.
Here is the full stacktrace:
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:149 in __init__ │
│ │
│ 146 │ │ self._train_eval_config = train_eval_config │
│ 147 │ │ self._tracker = tracker │
│ 148 │ │ self._experiment_name = experiment_name │
│ ❱ 149 │ │ self._setup() │
│ 150 │ │
│ 151 │ def _setup(self): │
│ 152 │ │ # load trainer state from available previous checkpoint if available │
│ │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:162 in _setup │
│ │
│ 159 │ │ │ self._train_eval_config.get("metrics", [])) │
│ 160 │ │ self._samples_by_split = build_datapool( │
│ 161 │ │ │ self._datapool_config) │
│ ❱ 162 │ │ self._env = build_env(self._env_config, self._reward_fn, │
│ 163 │ │ │ │ │ │ │ self._tokenizer, self._samples_by_split["train"]) │
│ 164 │ │ self._alg = build_alg(self._on_policy_alg_config, │
│ 165 │ │ │ │ │ │ │ self._env, self._tracker, │
│ │
│ /home/eublefar/RL4LMs/rl4lms/envs/text_generation/training_utils.py:90 in build_env │
│ │
│ 87 │ │ "samples": train_samples, │
│ 88 │ } │
│ 89 │ env_kwargs = {**env_kwargs, **env_config.get("args", {})} │
│ ❱ 90 │ env = make_vec_env(TextGenEnv, │
│ 91 │ │ │ │ │ n_envs=env_config.get( │
│ 92 │ │ │ │ │ │ "n_envs", 1), │
│ 93 │ │ │ │ │ vec_env_cls=SubprocVecEnv, │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/e │
│ nv_util.py:105 in make_vec_env │
│ │
│ 102 │ │ # Default: use a DummyVecEnv │
│ 103 │ │ vec_env_cls = DummyVecEnv │
│ 104 │ │
│ ❱ 105 │ return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_en │
│ 106 │
│ 107 │
│ 108 def make_atari_env( │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/subproc_vec_env.py:106 in __init__ │
│ │
│ 103 │ │ │ args = (work_remote, remote, CloudpickleWrapper(env_fn)) │
│ 104 │ │ │ # daemon=True: if the main process crashes, we should not cause things │
│ 105 │ │ │ process = ctx.Process(target=_worker, args=args, daemon=True) # pytype │
│ ❱ 106 │ │ │ process.start() │
│ 107 │ │ │ self.processes.append(process) │
│ 108 │ │ │ work_remote.close() │
│ 109 │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/process.py:121 in start │
│ │
│ 118 │ │ assert not _current_process._config.get('daemon'), \ │
│ 119 │ │ │ 'daemonic processes are not allowed to have children' │
│ 120 │ │ _cleanup() │
│ ❱ 121 │ │ self._popen = self._Popen(self) │
│ 122 │ │ self._sentinel = self._popen.sentinel │
│ 123 │ │ # Avoid a refcycle if the target function holds an indirect │
│ 124 │ │ # reference to the process object (see bpo-30775) │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/context.py:291 in _Popen │
│ │
│ 288 │ │ @staticmethod │
│ 289 │ │ def _Popen(process_obj): │
│ 290 │ │ │ from .popen_forkserver import Popen │
│ ❱ 291 │ │ │ return Popen(process_obj) │
│ 292 │ │
│ 293 │ class ForkContext(BaseContext): │
│ 294 │ │ _name = 'fork' │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:35 │
│ in __init__ │
│ │
│ 32 │ │
│ 33 │ def __init__(self, process_obj): │
│ 34 │ │ self._fds = [] │
│ ❱ 35 │ │ super().__init__(process_obj) │
│ 36 │ │
│ 37 │ def duplicate_for_child(self, fd): │
│ 38 │ │ self._fds.append(fd) │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_fork.py:19 in │
│ __init__ │
│ │
│ 16 │ │ util._flush_std_streams() │
│ 17 │ │ self.returncode = None │
│ 18 │ │ self.finalizer = None │
│ ❱ 19 │ │ self._launch(process_obj) │
│ 20 │ │
│ 21 │ def duplicate_for_child(self, fd): │
│ 22 │ │ return fd │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/popen_forkserver.py:47 │
│ in _launch │
│ │
│ 44 │ │ set_spawning_popen(self) │
│ 45 │ │ try: │
│ 46 │ │ │ reduction.dump(prep_data, buf) │
│ ❱ 47 │ │ │ reduction.dump(process_obj, buf) │
│ 48 │ │ finally: │
│ 49 │ │ │ set_spawning_popen(None) │
│ 50 │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/multiprocessing/reduction.py:60 in dump │
│ │
│ 57 │
│ 58 def dump(obj, file, protocol=None): │
│ 59 │ '''Replacement for pickle.dump() using ForkingPickler.''' │
│ ❱ 60 │ ForkingPickler(file, protocol).dump(obj) │
│ 61 │
│ 62 # │
│ 63 # Platform specific definitions │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/stable_baselines3/common/v │
│ ec_env/base_vec_env.py:371 in __getstate__ │
│ │
│ 368 │ │ self.var = var │
│ 369 │ │
│ 370 │ def __getstate__(self) -> Any: │
│ ❱ 371 │ │ return cloudpickle.dumps(self.var) │
│ 372 │ │
│ 373 │ def __setstate__(self, var: Any) -> None: │
│ 374 │ │ self.var = cloudpickle.loads(var) │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:73 in dumps │
│ │
│ 70 │ │ │ cp = CloudPickler( │
│ 71 │ │ │ │ file, protocol=protocol, buffer_callback=buffer_callback │
│ 72 │ │ │ ) │
│ ❱ 73 │ │ │ cp.dump(obj) │
│ 74 │ │ │ return file.getvalue() │
│ 75 │
│ 76 else: │
│ │
│ /home/eublefar/miniconda3/envs/gpt/lib/python3.9/site-packages/cloudpickle/cloudpickle_fa │
│ st.py:632 in dump │
│ │
│ 629 │ │
│ 630 │ def dump(self, obj): │
│ 631 │ │ try: │
│ ❱ 632 │ │ │ return Pickler.dump(self, obj) │
│ 633 │ │ except RuntimeError as e: │
│ 634 │ │ │ if "recursion" in e.args[0]: │
│ 635 │ │ │ │ msg = ( │
╰───────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: cannot pickle '_thread.RLock' object