
Problems with multithreading and ORAM

Open HuiZ-W opened this issue 9 months ago • 5 comments

Hi, I am trying to implement the Bellman-Ford algorithm using MP-SPDZ. I tried it with and without ORAM, and there seem to be problems with both. First, the version without ORAM:

S2B_DIS = Array(party1_b_node + party2_b_node, sint)
Edges = MultiArray([party1_edge + party2_edge, 3], sint)
#Bellman
@for_range(Loop)
def _(i):
    #use party1_edge_to_relax
    @for_range_multithread(None, 10, party1_edge_num)
    #@for_range(0, party1_edge + party2_edge)
    def _(j):
        start = Edges[j][0].reveal()
        end = Edges[j][1].reveal()
        weight = Edges[j][2]
        flag = S2B_DIS[end] < (S2B_DIS[start] + weight)
        S2B_DIS[end] = ternary_operator(flag, S2B_DIS[end], S2B_DIS[start] + weight)

S2B_DIS holds the distances to be relaxed, similar to the dis array in Dijkstra's algorithm, and Edges holds the edges used for relaxation. ternary_operator is a ternary selection implemented as follows:

def ternary_operator(c, if_true, if_false):
    return c * (if_true - if_false) + if_false
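
(For reference, MP-SPDZ also has a built-in selection of this form; if I'm not mistaken, Compiler.util.if_else(c, if_true, if_false) computes the same thing, and comparison results have an if_else method, so the helper above could likely be replaced by something like the following. Treat the exact spelling as an assumption on my part.)

from Compiler import util
S2B_DIS[end] = util.if_else(flag, S2B_DIS[end], S2B_DIS[start] + weight)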

I'm trying to speed things up with multithreading, but this seems to lead to errors in the final result. I know this implementation has data races, but logically the result should still be a sensible value, yet when I look at the arrays I often see a lot of inexplicably large numbers. Is there a more reasonable way to implement multithreading, or can I use some kind of thread lock?
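
One pattern I'm considering to avoid the concurrent writes entirely (a rough, untested sketch reusing the names from the snippet above): let the parallel phase only read S2B_DIS and write candidates into per-edge scratch slots, then fold the candidates in with a sequential pass. The comparisons in the fold remain sequential, so this mainly protects correctness rather than giving the full speedup:

num_edges = party1_edge + party2_edge
# scratch arrays, allocated once outside the Bellman-Ford loop
cand_end = Array(num_edges, cint)   # revealed end point of edge j
cand_val = Array(num_edges, sint)   # candidate distance via edge j

@for_range_multithread(None, 10, num_edges)
def _(j):
    # parallel phase: reads S2B_DIS only, each iteration writes its own slot j
    cand_end[j] = Edges[j][1].reveal()
    cand_val[j] = S2B_DIS[Edges[j][0].reveal()] + Edges[j][2]

@for_range(num_edges)
def _(j):
    # sequential merge: later edges see earlier updates, so there are no races
    end = cand_end[j]
    flag = S2B_DIS[end] < cand_val[j]
    S2B_DIS[end] = ternary_operator(flag, S2B_DIS[end], cand_val[j])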

The ORAM version has some problems too:

S2B_DIS = OptimalORAM(party1_b_node + party2_b_node)
Edges = MultiArray([party1_edge + party2_edge, 3], sint)
#Bellman
@for_range(Loop)
def _(i):
    @for_range_multithread(None, 10, party1_edge + party2_edge)
    def _(j):
        start = Edges[j][0]
        end = Edges[j][1]
        weight = Edges[j][2]
        s_dis = S2B_DIS[start]
        e_dis = S2B_DIS[end]
        flag = s_dis < (e_dis + weight)
        S2B_DIS[start] = ternary_operator(flag, s_dis, e_dis + weight)

I'm trying to use ORAM to protect the topology information of both sides. When I use multithreading, I get the error "can't directly write memory in threads". Without multithreading, the Bellman-Ford relaxation is very slow, taking up to 10 s per edge. I understand that ORAM leads to performance degradation, but this seems abnormally slow.

HuiZ-W avatar Apr 03 '25 07:04 HuiZ-W

Are you sure that the algorithm is suitable for parallelization? If end is the same for several edges, I would expect some race condition there.

ORAM is inherently unsuitable for multi-threading at the higher level. There is some internal multithreading by default, which you can influence by changing oram.n_threads and oram.n_threads_for_tree.
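
Roughly along these lines, set before constructing the ORAM (I assume that is when they take effect; the values are just examples):

from Compiler import oram
# module-level knobs mentioned above (defaults live in Compiler/oram.py)
oram.n_threads = 8
oram.n_threads_for_tree = 8
S2B_DIS = OptimalORAM(party1_b_node + party2_b_node)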

mkskeller avatar Apr 04 '25 01:04 mkskeller

Yes, you're right: if multiple edges point to the same destination at the same time, it can lead to conflicts. My intention was to parallelise in order to minimise the communication time of the comparisons as much as possible. I think some preprocessing could help, ensuring that the edges processed in parallel in each batch come from the same starting point and point to different end points; a rough sketch of that grouping is below.
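
For example, each party could run a hypothetical helper like this in plain Python on its own cleartext edge list before providing input, so that within one batch no two edges share an end point (just a sketch of the idea, done outside the MPC program):

def conflict_free_batches(edges):
    """Greedily split a public edge list into batches with pairwise-distinct end points."""
    batches = []      # list of batches, each a list of (start, end, weight) tuples
    batch_ends = []   # set of end vertices already used in each batch
    for edge in edges:
        end = edge[1]
        for batch, ends in zip(batches, batch_ends):
            if end not in ends:
                batch.append(edge)
                ends.add(end)
                break
        else:
            # no existing batch can take this edge without a conflict
            batches.append([edge])
            batch_ends.append({end})
    return batches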

I tried modifying the parameters you mentioned. I found that OptimalORAM returns a LinearORAM when the size is small, and in that case changing n_threads does have some effect (with 300 nodes, n_threads=None takes 1400 s and n_threads=20 takes 540 s). For larger numbers of nodes, n_threads in LinearORAM doesn't seem to help: at 7000 nodes it takes about 6 s per edge regardless. Is that a reasonable amount of time?

If I use TreeORAM, it requires the size to be a power of 2, which forces me to add some meaningless extra nodes. Also, after the packed ORAM init I can't get to the edge-processing logic for a long time; is it still building the ORAM at that point?

HuiZ-W avatar Apr 04 '25 14:04 HuiZ-W

Can you post the complete code? ORAM shouldn't be limited to power-of-two sizes. LinearORAM doesn't scale well, but tree-based ORAM is indeed relatively expensive to initialize.

mkskeller avatar Apr 07 '25 00:04 mkskeller

Of course, here's the full code:

party1_b_node = 5001
party1_edge = 300962
party2_b_node = 4898
party2_edge = 200962
boundary_edge = 22969
inf = sint(1000000)
inner_Min = sint(1000)
P1_Edge = MultiArray([party1_edge, 3], sint)
P2_Edge = MultiArray([party2_edge, 3], sint)
B_Edge = MultiArray([boundary_edge,2], cint)
#get party1_inner_edge
party1_edge_num = sint.get_input_from(0).reveal()
@for_range(0, party1_edge_num)
def _(i):
    P1_Edge[i][0] = sint.get_input_from(0)
    P1_Edge[i][1] = sint.get_input_from(0)
    P1_Edge[i][2] = sint.get_input_from(0)
print_ln("finish party1 edge")
#get party2_inner_edge
party2_edge_num = sint.get_input_from(1).reveal()
@for_range(0, party2_edge_num)
def _(i):
    P2_Edge[i][0] = sint.get_input_from(1) + party1_b_node
    P2_Edge[i][1] = sint.get_input_from(1) + party1_b_node
    P2_Edge[i][2] = sint.get_input_from(1)
print_ln("finish party2 edge")
#get boundary_edge
@for_range(0, boundary_edge)
def _(i):
    start = sint.get_input_from(0).reveal()
    end = sint.get_input_from(0).reveal() + party1_b_node
    weight = sint.get_input_from(0)
    B_Edge[i][0] = start
    B_Edge[i][1] = end
print_ln("finish boundary edge")
S2B = Array(party1_b_node + party2_b_node, sint)
S2B.assign_all(9999)
B2D = Array(party1_b_node, sint)
@for_range(0, party1_b_node)
def _(i):
    S2B[i] = sint.get_input_from(0)
@for_range(0, party1_b_node)
def _(i):
    B2D[i] = sint.get_input_from(0)
S2B_DIS = OptimalORAM(party1_b_node + party2_b_node)
S2B_DIS.batch_init(S2B)
B2D_DIS = OptimalORAM(party1_b_node)
B2D_DIS.batch_init(B2D)

#get inner shortest distance
Min = sint.get_input_from(0)
print_ln("finish input")
#Bellman
@for_range(Loop)
def _(i):
    #use boundary_edge_to_relax
    @for_range_multithread(None, 10, boundary_edge)
    #@for_range_opt_multithread(10, boundary_edge)
    #@for_range(0, boundary_edge)
    def _(j):
        start = B_Edge[j][0]
        end = B_Edge[j][1]
        s_dis = S2B_DIS[start]
        e_dis = S2B_DIS[end]
        weight = 1
        flag = s_dis < (e_dis + weight)
        S2B_DIS[start] = ternary_operator(flag, s_dis, e_dis + weight)
        flag = e_dis < (s_dis + weight)
        S2B_DIS[end] = ternary_operator(flag, e_dis, s_dis + weight)
        print_ln('finish %s',j)
    print_ln('finish boundary')
    #use party1_edge_to_relax
    @for_range_multithread(None, 10, party1_edge_num)
    #@for_range_opt_multithread(10, party1_edge_num)
    #@for_range(0, party1_edge_num)
    def _(j):
        start = P1_Edge[j][0]
        end = P1_Edge[j][1]
        s_dis = S2B_DIS[start]
        e_dis = S2B_DIS[end]
        weight = P1_Edge[j][2]
        flag = s_dis < (e_dis + weight)
        S2B_DIS[start] = ternary_operator(flag, s_dis, e_dis + weight)
    print_ln('finish party1')
    #use party2_edge_to_relax
    @for_range_multithread(None, 10, party2_edge_num)
    #@for_range_opt_multithread(10, party2_edge_num)
    #@for_range(0, 4)
    def _(j):
        start = P2_Edge[j][0]
        end = P2_Edge[j][1]
        s_dis = S2B_DIS[start]
        e_dis = S2B_DIS[end]
        weight = P2_Edge[j][2]
        flag = s_dis < (e_dis + weight)
        S2B_DIS[start] = ternary_operator(flag, s_dis, e_dis + weight)
    print_ln('finish party2')
    print_ln("Loop %s",i)

And I found that the problem is that TreeORAM requires the size to be a power of 2 when doing batch_init:

File "/home/huizhong/MPSPDZ/MP-SPDZ/Programs/Source/mayi/third/oram.mpc", line 61, in <module>
    S2B_DIS.batch_init(S2B)
File "/home/huizhong/MPSPDZ/MP-SPDZ/Compiler/oram.py", line 1251, in batch_init
    raise CompilerError('Batch size must a power of 2.')
Compiler.exceptions.CompilerError: Batch size must a power of 2.

I'm trying to use batch_init because directly writing the initial S2B values into the ORAM with for_range takes far too long: around 3 h for 5000 values. Is there a better solution?
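
The workaround I have in mind for the power-of-two restriction is simply padding before batch_init (untested sketch; it assumes Array.assign copies the shorter array into the front of the longer one):

n = party1_b_node + party2_b_node
padded_n = 1 << (n - 1).bit_length()   # next power of two >= n
S2B_padded = Array(padded_n, sint)
S2B_padded.assign_all(9999)            # padding entries keep the "infinity" value
S2B_padded.assign(S2B)                 # real initial distances go at the front
S2B_DIS = OptimalORAM(padded_n)
S2B_DIS.batch_init(S2B_padded)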

I previously thought that processing an edge takes 6 s mainly because of the secure comparison, but after I commented out the comparison it still takes that long, so it looks like accessing the ORAM itself is the main overhead. Is there any solution to this?

HuiZ-W avatar Apr 07 '25 06:04 HuiZ-W

Thank you for raising this; the restriction isn't necessary anymore. 85f2d094a9d fixes it along with some issues in batch initialization.

I previously thought that processing an edge takes 6 s mainly because of the secure comparison, but after I commented out the comparison it still takes that long, so it looks like accessing the ORAM itself is the main overhead. Is there any solution to this?

ORAM is known for its heavy cost. You can experiment with the different variants (LinearORAM, TreeORAM, PathORAM, SqrtOram). However, I would expect ORAM to remain the biggest cost factor in any algorithm that requires it.
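
For example, forcing a specific variant instead of letting OptimalORAM choose might look roughly like this; I'm assuming here that these constructors take the size the same way OptimalORAM does, and I haven't checked the exact import path:

# assumption: PathORAM is importable like OptimalORAM and takes the size as first argument
S2B_DIS = PathORAM(party1_b_node + party2_b_node)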

mkskeller avatar Apr 10 '25 06:04 mkskeller