
[P1] Cannot load trained model anymore - "type must be tuple of ints,but got NoneType"

Open chris-aeviator opened this issue 7 months ago • 9 comments

After updating pyreft recently, I'm encountering errors when loading a trained model. This applies to newly trained models as well as previously trained ones. I'm loading from disk.

The error seems to happen because my config is not being read correctly: `kwargs['low_rank_dimension']` is None. If I set it to my correct value (e.g. 8 or 12) and set the intervention type to my class, the model loads.
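A minimal standalone sketch of the failure mode described above (this is *not* pyreft's actual code; `resolve_rank` and `override` are hypothetical names for illustration): the deserialized config leaves `kwargs["low_rank_dimension"]` as None, and `torch.empty(n, None)` then raises "type must be tuple of ints, but got NoneType". A load-time guard plus an explicit override reproduces the manual fix that makes the model load:

```python
def resolve_rank(config_kwargs, override=None):
    """Return a usable low-rank dimension, preferring an explicit override.

    Hypothetical helper for illustration only; pyreft has no such function.
    """
    rank = override if override is not None else config_kwargs.get("low_rank_dimension")
    if not isinstance(rank, int):
        # Fail with a clear message instead of torch's opaque one.
        raise TypeError(
            f"low_rank_dimension must be an int, got {type(rank).__name__}"
        )
    return rank

# The broken path: the saved config carries no rank.
try:
    resolve_rank({"low_rank_dimension": None})
except TypeError as e:
    print(e)  # low_rank_dimension must be an int, got NoneType

# The workaround from this report: supply the known training value explicitly.
print(resolve_rank({"low_rank_dimension": None}, override=8))  # 8
```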

Name: pyvene
Version: 0.1.2

Name: pyreft
Version: 0.0.6 --> there seems to be only a version 0.0.5?! I can't explain this; maybe it's due to direct code edits in site-packages?
File ~/micromamba/envs/trtf/lib/python3.9/site-packages/pyreft/interventions.py:40, in LoreftIntervention.__init__(self, **kwargs)
---> 40 rotate_layer = LowRankRotateLayer(
     41     self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
     42 self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
     43 self.learned_source = torch.nn.Linear(
     44     self.embed_dim, kwargs["low_rank_dimension"]).to(
     45     kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)

File ~/micromamba/envs/trtf/lib/python3.9/site-packages/pyreft/interventions.py:19, in LowRankRotateLayer.__init__(self, n, m, init_orth)
     17 super().__init__()
     18 # n > m
---> 19 self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)

     21     torch.nn.init.orthogonal_(self.weight)
Sample config file:

{
  "intervention_constant_sources": [
    true
  ],
  "intervention_dimensions": [
    4096
  ],
  "intervention_types": [
    "<class 'transforms.autobrew.trft.raft.subspace.SubloreftIntervention'>"
  ],
  "mode": "parallel",
  "representations": [
    [
      16,
      "block_output",
      "pos",
      1,
      null,
      null,
      null,
      null,
      null,
      null,
      null,
      null,
      null
    ]
  ],
  "sorted_keys": [
    "layer.16.comp.block_output.unit.pos.nunit.1#0"
  ],
  "transformers_version": "4.43.3"
}
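Worth noting: the saved config above contains no `low_rank_dimension` key at all (the `4096` under `intervention_dimensions` is presumably the embedding dimension, not the rank), which is consistent with the kwarg arriving as None at load time. A quick stdlib-only check against an abbreviated copy of that config (the `representations` list is omitted here for brevity):

```python
import json

# Abbreviated copy of the config from this report; field names as shown above.
config_text = """
{
  "intervention_constant_sources": [true],
  "intervention_dimensions": [4096],
  "intervention_types": ["<class 'transforms.autobrew.trft.raft.subspace.SubloreftIntervention'>"],
  "mode": "parallel",
  "sorted_keys": ["layer.16.comp.block_output.unit.pos.nunit.1#0"],
  "transformers_version": "4.43.3"
}
"""
config = json.loads(config_text)

# Nothing in this file can populate kwargs["low_rank_dimension"].
print("low_rank_dimension" in config)  # False
```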

chris-aeviator avatar Aug 02 '24 06:08 chris-aeviator