mpyc: `mpc.xor` produces an incorrect value
import sys
from mpyc import mpctools
from mpyc.runtime import mpc
async def main():
    """Reproduce the reported mpc.xor behaviour on every party's hex input.

    Each party parses its own input from sys.argv[1] as hexadecimal,
    secret-shares it as a 256-bit secure integer with mpc.input, folds all
    parties' values together with mpc.xor and prints the opened result in hex.
    """
    my_input = int(sys.argv[1], 16)
    secint = mpc.SecInt(256)
    await mpc.start()
    all_inputs = mpc.input(secint(my_input))
    # NOTE(review): mpc.xor is implemented as ``return a + b``, so for secint
    # values this reduction yields the arithmetic sum of the inputs, not their
    # bitwise xor -- which is why the sample output is 257 bits wide
    # (0x6bf5... + 0x96dc... carries out of 256 bits).
    combined_inputs = mpctools.reduce(mpc.xor, all_inputs)
    entropy = await mpc.output(combined_inputs)
    await mpc.shutdown()
    # hex()[2:] drops leading zero nibbles, and for a negative value it would
    # strip only "-0" from "-0x...", leaving a stray "x".
    key = hex(entropy)[2:]
    print(key)
mpc.run(main())
Run with two parties, one command per party:
➜ python src/compute.py -M2 -I0 --no-log 6bf558aeeb81970a5d82e63f8c785f6a03b0f5355f276ef986b8b8b1cff1c6cb
➜ python src/compute.py -M2 -I1 --no-log 96dc2474ff71bab6001c1b81c0d1ff13a8e64264e68c031666cc521a662dbe14
102d17d23eaf351c05d9f01c14d4a5e7dac97379a45b3720fed850acc361f84df
There are two problems. First, the result overflows to 257 bits instead of 256. Second, the value itself is wrong. The correct result can be computed without MPC:
# Reference computation: plain (non-secure) bitwise xor of the two example
# inputs, used to check the MPC result.
hex1 = "6bf558aeeb81970a5d82e63f8c785f6a03b0f5355f276ef986b8b8b1cff1c6cb"
hex2 = "96dc2474ff71bab6001c1b81c0d1ff13a8e64264e68c031666cc521a662dbe14"
num1 = int(hex1, 16)
num2 = int(hex2, 16)
final = num1 ^ num2
# Zero-pad to the 64 hex digits of a 256-bit key: hex(final)[2:] would drop
# leading zero nibbles whenever the xor happens to start with them, making
# the key string shorter than 64 characters.
key = format(final, "064x")
print(key)
Running this prints fd297cda14f02dbc5d9efdbe4ca9a079ab56b751b9ab6defe074eaaba9dc78df, whereas the mpc.xor version prints 102d17d23eaf351c05d9f01c14d4a5e7dac97379a45b3720fed850acc361f84df, which is the arithmetic sum num1 + num2 rather than the xor.
I think I found the fix: `mpc.xor` is implemented incorrectly for secure integers. Computing the xor manually at the bit level, looping over the inputs, produces the correct result:
import sys
from mpyc import mpctools
from mpyc.runtime import mpc
def _xor_bits(a_bits, b_bits):
    """Return the secure bitwise xor of two equal-length secure bit vectors.

    For bits x, y in {0, 1}: x xor y == x + y - 2*x*y, so the xor is obtained
    from the elementwise (Schur) product of the two vectors plus local
    additions and subtractions.
    """
    and_bits = mpc.schur_prod(a_bits, b_bits)
    return mpc.vector_sub(
        mpc.vector_add(a_bits, b_bits), mpc.vector_add(and_bits, and_bits)
    )


async def main():
    """Xor every party's 256-bit hex input together and print the result.

    Each party's own input is read from sys.argv[1] as hexadecimal
    (presumably mpyc strips its own options such as -M/-I/--no-log from
    sys.argv before this runs -- confirm).
    """
    my_input = int(sys.argv[1], 16)
    # NOTE(review): if SecInt(l) is a signed l-bit type, inputs >= 2**255
    # (such as the 0x96dc... example) lie outside its nominal range -- confirm
    # against the mpyc documentation.
    secint = mpc.SecInt(256)
    await mpc.start()
    all_inputs = mpc.input(secint(my_input))
    # mpc.xor only computes a + b, which is not bitwise xor for secure
    # integers, so xor explicitly at the bit level. Each input is decomposed
    # into bits once and the result is recombined once at the end, instead of
    # converting bits -> int -> bits on every step. Xor is associative, so
    # the grouping chosen by mpctools.reduce does not affect the result.
    all_bits = [mpc.to_bits(value) for value in all_inputs]
    combined_inputs = mpc.from_bits(mpctools.reduce(_xor_bits, all_bits))
    entropy = await mpc.output(combined_inputs)
    await mpc.shutdown()
    # Fixed width: hex()[2:] would drop leading zero nibbles of the key.
    key = format(entropy, "064x")
    print(key)


mpc.run(main())
The current implementation of `mpc.xor` is:
def xor(self, a, b):
    """Secure xor of a and b, computed as the sum a + b.

    The sum equals bitwise xor only for types whose addition is carry-less
    (presumably secure field elements of characteristic 2 -- confirm).
    For secure integers (mpc.SecInt) this returns the ordinary arithmetic
    sum, which is NOT their bitwise xor and may need one bit more than the
    operands. To xor secure integers, work at the bit level instead:
    decompose with to_bits, combine bits x, y as x + y - 2*x*y (using
    schur_prod for x*y), and recombine with from_bits.
    """
    return a + b
I think the implementation should be something like:
def xor(self, a, b):
    """Secure bitwise xor of a and b, computed bit by bit.

    Both operands are decomposed with to_bits. Each result bit is
    x + y - 2*x*y, which equals x xor y for bits x, y in {0, 1}. The bits
    are recombined with from_bits.
    """
    x_bits = self.to_bits(a)
    y_bits = self.to_bits(b)
    # Elementwise product of 0/1 bits is their bitwise AND.
    both = self.schur_prod(x_bits, y_bits)
    sums = self.vector_add(x_bits, y_bits)
    doubled = self.vector_add(both, both)
    return self.from_bits(self.vector_sub(sums, doubled))
Please see issue #36. `mpc.xor()` is not meant to cover the general case, because a bitwise xor of secure integers is very costly.
It would be great if this were documented, together with a code sample showing how to compute the bitwise xor of secure integers.