EasyEdit

Suggestions for PMET Method on Llama

Open · Lut-hub opened this issue 1 year ago · 1 comment

Hello!

I ran into a small issue while editing Llama2 with the PMET method that may be worth mentioning. Line 44 of `easyeditor/models/pmet/compute_zs.py` reads:

```python
# With LlamaTokenizer, target_ids[0] is the BOS token "<s>"
target_ids = tok(request["target_new"], return_tensors="pt").to("cuda")["input_ids"][0]
```

LlamaTokenizer prepends a BOS token, `<s>`, when tokenizing `request["target_new"]`. In the subsequent processing this extra `<s>` ends up appended to the end of the query (see the lookup sentence in the log below), which makes it difficult for PMET to optimize `zs` effectively.
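The mechanism can be reproduced without downloading Llama2 using a toy tokenizer. `ToyTokenizer` and `build_query` are hypothetical: the tokenizer only mimics the one behaviour that matters here (like LlamaTokenizer, it prepends `<s>` unless `add_special_tokens=False`), and `build_query` assumes a ROME/MEMIT-style construction in which every target token except the last is decoded and appended to the prompt. That `compute_zs.py` builds its query exactly this way is an assumption, though it is consistent with the log below.

```python
from typing import Dict, List


class ToyTokenizer:
    """Hypothetical stand-in for LlamaTokenizer: splits on whitespace and,
    like LlamaTokenizer, prepends "<s>" unless add_special_tokens=False."""

    def __init__(self) -> None:
        self.bos_token_id = 1
        self.id_to_token: Dict[int, str] = {self.bos_token_id: "<s>"}
        self.token_to_id: Dict[str, int] = {"<s>": self.bos_token_id}

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [self.bos_token_id] if add_special_tokens else []
        for word in text.split():
            if word not in self.token_to_id:
                new_id = len(self.token_to_id) + 1
                self.token_to_id[word] = new_id
                self.id_to_token[new_id] = word
            ids.append(self.token_to_id[word])
        return ids

    def decode(self, ids: List[int]) -> str:
        return " ".join(self.id_to_token[i] for i in ids)


def build_query(tok: ToyTokenizer, prompt: str, target: str,
                add_special_tokens: bool = True) -> str:
    # Assumed ROME/MEMIT-style construction: the prompt is extended with every
    # target token except the last one, which the model is trained to predict.
    target_ids = tok.encode(target, add_special_tokens=add_special_tokens)
    return prompt + tok.decode(target_ids[:-1])


prompt = "What is the native language of Christiane Cohendy?"
tok = ToyTokenizer()
print(build_query(tok, prompt, "German"))
# What is the native language of Christiane Cohendy?<s>
print(build_query(tok, prompt, "German", add_special_tokens=False))
# What is the native language of Christiane Cohendy?
```

For a single-token target, the only token left after dropping the last one is the BOS, so the query gains a trailing `<s>` exactly as in the PMET lookup sentence.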

For example, when editing the object of "What is the native language of Christiane Cohendy?" to "German", PMET produces:

```
Lookup index found: 12 | Sentence: What is the native language of Christiane Cohendy?<s> | Token: y
Rewrite layer is 8
Tying optimization objective to 31
Recording initial value of v* in attn
Recording initial value of v* in mlp
loss 11.094 = 11.094 + 0.0 + 0.0 avg prob of [ German] 2.2592015739064664e-05
loss 10.079 = 10.078 + 0.0 + 0.001 avg prob of [ German] 6.582784408237785e-05
loss 9.424 = 9.423 + 0.0 + 0.001 avg prob of [ German] 0.0001342837349511683
loss 7.538 = 7.538 + 0.0 + 0.001 avg prob of [ German] 0.0012688480783253908
...
```

For comparison, MEMIT on the same edit produces the expected output, with no trailing `<s>`:

```
Lookup index found: 12 | Sentence: What is the native language of Christiane Cohendy? | Token: y
Rewrite layer is 8
Tying optimization objective to 31
Recording initial value of v*
loss 6.46 = 6.46 + 0.0 + 0.0 avg prob of [ German] 0.002235208638012409
loss 4.531 = 4.399 + 0.099 + 0.033 avg prob of [ German] 0.013420548290014267
loss 3.512 = 3.249 + 0.231 + 0.033 avg prob of [ German] 0.04390619695186615
loss 2.939 = 2.666 + 0.24 + 0.033 avg prob of [ German] 0.07955510914325714
...
```

MEMIT avoids the problem because line 43 of `easyeditor/models/memit/compute_z.py` drops the leading token:

```python
# Drop a leading BOS/UNK token (e.g. the "<s>" that LlamaTokenizer prepends)
if target_ids[0] == tok.bos_token_id or target_ids[0] == tok.unk_token_id:
    target_ids = target_ids[1:]
```

So I'd suggest adding the same check at line 47 of `easyeditor/models/pmet/compute_zs.py`. 😊
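A minimal sketch of the suggested patch, pulled out into a standalone helper so it can be checked in isolation. `strip_leading_special` is a hypothetical name, not part of EasyEdit; in `compute_zs.py` the condition would simply be inlined right after `target_ids` is computed, as in MEMIT. A plain Python list stands in for the 1-D tensor, and slicing with `[1:]` behaves the same on both.

```python
from typing import List, Optional, Sequence


def strip_leading_special(
    target_ids: Sequence[int],
    bos_token_id: Optional[int],
    unk_token_id: Optional[int],
) -> List[int]:
    """Drop one leading BOS or UNK token, mirroring the check in memit/compute_z.py.

    Hypothetical helper: in compute_zs.py the same condition would simply be
    inlined right after target_ids is computed.
    """
    if len(target_ids) > 0 and target_ids[0] in (bos_token_id, unk_token_id):
        return list(target_ids[1:])
    return list(target_ids)


# Illustrative ids only: 1 = "<s>" and 0 = "<unk>" (the Llama-2 convention), 42 = some target token.
print(strip_leading_special([1, 42], bos_token_id=1, unk_token_id=0))  # [42]
print(strip_leading_special([42], bos_token_id=1, unk_token_id=0))     # [42]
```

Like the MEMIT lines, it removes at most one token; the only difference is the extra guard against an empty sequence.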

Lut-hub · May 02 '24 09:05

Thank you very much for your interest in EasyEdit, and apologies for our limited availability: we are currently busy with the NeurIPS submission deadline. We will look into this optimization once the deadline has passed.

XeeKee · May 04 '24 08:05

Thank you for the suggestion. I will update the entire codebase to use `tok.encode(xx, add_special_tokens=False)` so that no unnecessary special tokens are added.

pengzju · May 26 '24 07:05

Thanks for your reply 😊

Lut-hub · May 28 '24 03:05