Grokking-Deep-Learning

Small mistake in Chapter 15

Open · moFang222 opened this issue 5 years ago · 1 comment

In the "Homomorphically encrypted federated learning" section, the code provided is as follows:


1. model = Embedding(vocab_size=len(vocab), dim=1)
2. model.weight.data *= 0
3. 
4. # note that in production the n_length should be at least 1024
5. public_key, private_key = phe.generate_paillier_keypair(n_length=128)
6. 
7. def train_and_encrypt(model, input, target, pubkey):
8.     new_model = train(copy.deepcopy(model), input, target, iterations=1)
9. 
10.     encrypted_weights = list()
11.     for val in new_model.weight.data[:,0]:
12.         encrypted_weights.append(public_key.encrypt(val))
13.     ew = np.array(encrypted_weights).reshape(new_model.weight.data.shape)
14. 
15.     return ew
16. 
17. for i in range(3):
18.     print("\nStarting Training Round...")
19.     print("\tStep 1: send the model to Bob")
20.     bob_encrypted_model = train_and_encrypt(copy.deepcopy(model),
21.                                             bob[0], bob[1], public_key)
22. 
23.     print("\n\tStep 2: send the model to Alice")
24.     alice_encrypted_model=train_and_encrypt(copy.deepcopy(model),
25.                                             alice[0],alice[1],public_key)
26. 
27.     print("\n\tStep 3: Send the model to Sue")
28.     sue_encrypted_model = train_and_encrypt(copy.deepcopy(model),
29.                                             sue[0], sue[1], public_key)
30. 
31.     print("\n\tStep 4: Bob, Alice, and Sue send their")
32.     print("\tencrypted models to each other.")
33.     aggregated_model = bob_encrypted_model + \
34.                        alice_encrypted_model + \
35.                        sue_encrypted_model
36. 
37.     print("\n\tStep 5: only the aggregated model")
38.     print("\tis sent back to the model owner who")
39.     print("\t can decrypt it.")
40.     raw_values = list()
41.     for val in sue_encrypted_model.flatten():
42.         raw_values.append(private_key.decrypt(val))
43.     new = np.array(raw_values).reshape(model.weight.data.shape)/3
44.     model.weight.data = new
45. 
46.     print("\t% Correct on Test Set: " + \
47.               str(test(model, test_data, test_target)*100))

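Shouldn't sue_encrypted_model on line 41 be aggregated_model? As written, step 5 decrypts only Sue's encrypted model and divides it by 3, so the sum computed in step 4 is never used and the new model is Sue's weights scaled by 1/3 instead of the average of the three models.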
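To make the difference concrete, here is a small sketch that compares both versions of step 5. It does not use the phe library from the book (whose encrypted numbers can be added directly); instead it assumes a toy, deliberately insecure Paillier implementation with tiny hard-coded primes and a hypothetical fixed-point encoding (SCALE) for float weights. The weight vectors for Bob, Alice, and Sue are made-up values, not output from the book's training code.

```python
import math
import random

# Toy Paillier cryptosystem, for illustration ONLY: the primes are tiny and
# hard-coded, so this is NOT secure. It stands in for the phe library used in
# the book; all names below are hypothetical, not phe's API.
P, Q = 1_000_003, 1_000_033
N = P * Q
N2 = N * N
G = N + 1                                           # common simple generator choice
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)   # lcm(p - 1, q - 1)
MU = pow(LAM, -1, N)                                # valid because G = N + 1
SCALE = 10**6                                       # hypothetical fixed-point precision


def encrypt(m):
    """Encrypt an integer plaintext m in [0, N) with fresh randomness r."""
    while True:
        r = random.randrange(1, N)
        if math.gcd(r, N) == 1:
            break
    return pow(G, m, N2) * pow(r, N, N2) % N2


def decrypt(c):
    """Recover the integer plaintext: L(c^lambda mod N^2) * mu mod N."""
    return (pow(c, LAM, N2) - 1) // N * MU % N


def add_encrypted(c1, c2):
    """Homomorphic addition: multiplying ciphertexts adds plaintexts mod N."""
    return c1 * c2 % N2


def encode(x):
    """Map a float into Z_N with fixed-point scaling; negatives wrap around."""
    return round(x * SCALE) % N


def decode(m):
    """Inverse of encode: values above N/2 represent negative numbers."""
    if m > N // 2:
        m -= N
    return m / SCALE


def encrypt_weights(weights):
    return [encrypt(encode(w)) for w in weights]


def decrypt_average(encrypted, n_parties=3):
    return [decode(decrypt(c)) / n_parties for c in encrypted]


# Hypothetical weights standing in for each party's locally trained model.
bob, alice, sue = [0.3, -0.6], [0.1, 0.2], [0.5, -0.3]
bob_enc, alice_enc, sue_enc = [encrypt_weights(w) for w in (bob, alice, sue)]

# Step 4: sum the three encrypted models element-wise, still encrypted.
aggregated = [add_encrypted(add_encrypted(b, a), s)
              for b, a, s in zip(bob_enc, alice_enc, sue_enc)]

# Step 5 as printed: only Sue's model is decrypted, then divided by 3.
buggy = decrypt_average(sue_enc)
# Step 5 as proposed: the aggregate is decrypted, giving the average model.
fixed = decrypt_average(aggregated)

print("buggy:", buggy)
print("fixed:", fixed)
```

With the printed code, each weight comes out as Sue's value divided by 3; decrypting aggregated_model instead yields the element-wise mean of all three models, which is what federated averaging is meant to produce.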

moFang222, Feb 24 '20 07:02

That's what I think as well.

t-kubrak, May 26 '20 22:05