pyreft
pyreft copied to clipboard
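[P1] Saving/loading issues
Reported by @Jemoka: saving a ReFT-wrapped BERT model and then loading it again with `pyreft.ReftModel.load` fails. While reconstructing the `IntervenableModel`, pyvene calls `get_dimension_by_component`, which indexes `type_to_dimension_mapping` by the wrapped model's class. That mapping has no entry for `BertForMaskedLM`, so the lookup raises a `KeyError`, as shown in the traceback below.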
File ~/Documents/Projects/dropval/playground/dropval/trainers/reft.py:213, in ReFTrainer.load(self, path)
210 del model.config.__dict__["use_cache"]
211 model = model.train()
--> 213 self.model = pyreft.ReftModel.load(
214 str(Path(path)/"intervention"),
215 model = model
216 )
File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyreft/reft_model.py:26, in ReftModel.load(*args, **kwargs)
24 @staticmethod
25 def load(*args, **kwargs):
---> 26 model = pv.IntervenableModel.load(*args, **kwargs)
27 return ReftModel._convert_to_reft_model(model)
File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/intervenable_base.py:547, in IntervenableModel.load(load_directory, model, local_directory, from_huggingface_hub)
543 casted_representations += [
544 RepresentationConfig(*representation_opts)
545 ]
546 saving_config.representations = casted_representations
--> 547 intervenable = IntervenableModel(saving_config, model)
549 # load binary files
550 for i, (k, v) in enumerate(intervenable.interventions.items()):
File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/intervenable_base.py:116, in IntervenableModel.__init__(self, config, model, **kwargs)
110 intervention_function = (
111 intervention_type
112 if type(intervention_type) != list
113 else intervention_type[i]
114 )
115 all_metadata = representation._asdict()
--> 116 component_dim = get_dimension_by_component(
117 get_internal_model_type(model), model.config,
118 representation.component
119 )
120 if component_dim is not None:
121 component_dim *= int(representation.max_number_of_units)
File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/modeling_utils.py:100, in get_dimension_by_component(model_type, model_config, component)
97 def get_dimension_by_component(model_type, model_config, component) -> int:
98 """Based on the representation, get the aligning dimension size."""
--> 100 if component not in type_to_dimension_mapping[model_type]:
101 return None
103 dimension_proposals = type_to_dimension_mapping[model_type][component]
KeyError: <class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>