pyreft icon indicating copy to clipboard operation
pyreft copied to clipboard

[P1] Saving/loading issues

Status: Open. aryamanarora opened this issue 6 months ago • 0 comments

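Reported by @Jemoka when trying to save/load a BERT model (`BertForMaskedLM`): `pyreft.ReftModel.load` fails with the traceback below, where the lookup `type_to_dimension_mapping[model_type]` in pyvene raises a `KeyError` for that model class.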

File ~/Documents/Projects/dropval/playground/dropval/trainers/reft.py:213, in ReFTrainer.load(self, path)
    210 del model.config.__dict__["use_cache"]
    211 model = model.train()
--> 213 self.model = pyreft.ReftModel.load(
    214     str(Path(path)/"intervention"),
    215     model = model
    216 )

File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyreft/reft_model.py:26, in ReftModel.load(*args, **kwargs)
     24 @staticmethod
     25 def load(*args, **kwargs):
---> 26     model = pv.IntervenableModel.load(*args, **kwargs)
     27     return ReftModel._convert_to_reft_model(model)

File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/intervenable_base.py:547, in IntervenableModel.load(load_directory, model, local_directory, from_huggingface_hub)
    543     casted_representations += [
    544         RepresentationConfig(*representation_opts)
    545     ]
    546 saving_config.representations = casted_representations
--> 547 intervenable = IntervenableModel(saving_config, model)
    549 # load binary files
    550 for i, (k, v) in enumerate(intervenable.interventions.items()):

File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/intervenable_base.py:116, in IntervenableModel.__init__(self, config, model, **kwargs)
    110 intervention_function = (
    111     intervention_type
    112     if type(intervention_type) != list
    113     else intervention_type[i]
    114 )
    115 all_metadata = representation._asdict()
--> 116 component_dim = get_dimension_by_component(
    117     get_internal_model_type(model), model.config, 
    118     representation.component
    119 )
    120 if component_dim is not None:
    121     component_dim *= int(representation.max_number_of_units)

File /opt/homebrew/Caskroom/miniforge/base/envs/generic/lib/python3.11/site-packages/pyvene/models/modeling_utils.py:100, in get_dimension_by_component(model_type, model_config, component)
     97 def get_dimension_by_component(model_type, model_config, component) -> int:
     98     """Based on the representation, get the aligning dimension size."""
--> 100     if component not in type_to_dimension_mapping[model_type]:
    101         return None
    103     dimension_proposals = type_to_dimension_mapping[model_type][component]

KeyError: <class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>

aryamanarora avatar Aug 05 '24 20:08 aryamanarora