ImplementShiftNetwork: Reduce the number of conflicts when there are multiple sources that map to a target slot
Consider the following IR:
#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#layout1 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : exists (e0, e1, e2: i0 = 0 and ct = 0 and 16e2 = -i1 + slot + 16e0 and 0 <= i1 <= 1023 and 0 <= slot <= 1023 and 0 <= e1 <= 3 and -3 + i1 - 16e0 <= 4e1 <= i1 - 16e0) }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %0 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %1 = tensor_ext.remap %input0 {permutation = #layout1} : tensor<1x1024xf32>
      secret.yield %1 : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %0 : !secret.secret<tensor<1x1024xf32>>
  }
}
In this case we are remapping the input according to a permutation generated by extracting a tensor<4x4> slice from a tensor<1x1x4x4>. Since the ciphertext has 1024 slots, the original 16-element (4x4) tensor was repeated 1024/16 = 64 times in the source. The slice extraction preserved the layout, so the result had the same layout (and was replicated the same way in the ciphertext). When the remap was lowered to a shift network, the Mapping generated many conflicts: because the 16-element tensor is replicated, multiple identical sources can map to the same target. For example, if the mapping de-duped the sources mapping to each target by choosing the first one to appear in the point-pair collector, it would generate the following IR:
#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %c16 = arith.constant 16 : index
    %cst = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c32 = arith.constant 32 : index
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c64 = arith.constant 64 : index
    %cst_1 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c128 = arith.constant 128 : index
    %cst_2 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c256 = arith.constant 256 : index
    %cst_3 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c512 = arith.constant 512 : index
    %cst_4 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %0 = tensor.empty() : tensor<1x1024xf32>
    %1 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %extracted_slice = tensor.extract_slice %input0[0, 0] [1, 1024] [1, 1] : tensor<1x1024xf32> to tensor<1024xf32>
      %2 = arith.mulf %extracted_slice, %cst_4 : tensor<1024xf32>
      %3 = tensor_ext.rotate %2, %c16 : tensor<1024xf32>, index
      %4 = arith.addf %2, %3 : tensor<1024xf32>
      %5 = arith.mulf %4, %cst : tensor<1024xf32>
      %6 = tensor_ext.rotate %5, %c32 : tensor<1024xf32>, index
      %7 = arith.addf %5, %6 : tensor<1024xf32>
      %8 = arith.mulf %7, %cst_0 : tensor<1024xf32>
      %9 = tensor_ext.rotate %8, %c64 : tensor<1024xf32>, index
      %10 = arith.addf %8, %9 : tensor<1024xf32>
      %11 = arith.mulf %10, %cst_1 : tensor<1024xf32>
      %12 = tensor_ext.rotate %11, %c128 : tensor<1024xf32>, index
      %13 = arith.addf %11, %12 : tensor<1024xf32>
      %14 = arith.mulf %13, %cst_2 : tensor<1024xf32>
      %15 = tensor_ext.rotate %14, %c256 : tensor<1024xf32>, index
      %16 = arith.addf %14, %15 : tensor<1024xf32>
      %17 = arith.mulf %16, %cst_3 : tensor<1024xf32>
      %18 = tensor_ext.rotate %17, %c512 : tensor<1024xf32>, index
      %19 = arith.addf %17, %18 : tensor<1024xf32>
      %inserted_slice = tensor.insert_slice %19 into %0[0, 0] [1, 1024] [1, 1] : tensor<1024xf32> into tensor<1x1024xf32>
      secret.yield %inserted_slice : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %1 : !secret.secret<tensor<1x1024xf32>>
  }
}
where the constants are plaintext masks that each select a group of 16 floating-point elements.
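To make the first-match behavior concrete, here is an illustrative Python sketch (not the actual HEIR implementation); the ciphertext size and mapping below are toy stand-ins for the 1024-slot example:

```python
def dedupe_first_match(mapping):
    """For each target, keep whichever candidate source appears first in
    the point-pair collection (the arbitrary choice described above)."""
    chosen = {}
    for src, dst in mapping:
        chosen.setdefault(dst, src)  # first occurrence wins
    return chosen

# Toy version of the example: a 16-element block replicated across a
# 64-slot ciphertext (the issue uses 1024 slots).
N, BLOCK = 64, 16
mapping = [((0, s), (0, t)) for t in range(N) for s in range(t % BLOCK, N, BLOCK)]
chosen = dedupe_first_match(mapping)
# Every target t is assigned the source in the *first* replicated block
# (slot t % 16), so the network needs many distinct nonzero shifts; in the
# 1024-slot IR these are realized by the log-depth chain of rotations by
# 16, 32, ..., 512.
shifts = {(t - chosen[(0, t)][1]) % N for t in range(N)}
```

With sources enumerated block-first, the shift set contains every multiple of the block size, which is what forces the rotation chain in the IR above.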
Instead of arbitrarily choosing which source maps to a target, we can select the source with the smallest distance to the target. This picks the source point in the repeated block nearest the target, and de-dupes in a way that tends to make the source-to-target assignment unique. Distance can be computed as the "virtual distance": ciphertext_diff * ciphertextSize + slot_diff. With that change, the resulting IR would be
#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %0 = tensor.empty() : tensor<1x1024xf32>
    %1 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %extracted_slice = tensor.extract_slice %input0[0, 0] [1, 1024] [1, 1] : tensor<1x1024xf32> to tensor<1024xf32>
      %inserted_slice = tensor.insert_slice %extracted_slice into %0[0, 0] [1, 1024] [1, 1] : tensor<1024xf32> into tensor<1x1024xf32>
      secret.yield %inserted_slice : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %1 : !secret.secret<tensor<1x1024xf32>>
  }
}
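The closest-source selection that produces this identity network can be sketched in Python (a toy model, not the HEIR code; `virtual_distance` follows the formula above):

```python
def virtual_distance(src, dst, ciphertext_size):
    # Flatten (ct, slot) coordinates into a single index and compare:
    # ciphertext_diff * ciphertextSize + slot_diff.
    (s_ct, s_slot), (d_ct, d_slot) = src, dst
    return abs((d_ct - s_ct) * ciphertext_size + (d_slot - s_slot))

def dedupe_by_closest(mapping, ciphertext_size):
    """For each target, keep only the candidate source with the smallest
    virtual distance to it."""
    chosen = {}
    for src, dst in mapping:
        prev = chosen.get(dst)
        if prev is None or (virtual_distance(src, dst, ciphertext_size)
                            < virtual_distance(prev, dst, ciphertext_size)):
            chosen[dst] = src
    return chosen

# Same toy replication pattern as before: a 16-element block replicated
# across a 64-slot ciphertext. With the closest source chosen, every
# target slot is served by itself, all shifts are zero, and the network
# degenerates to the identity seen in the IR above.
N, BLOCK = 64, 16
mapping = [((0, s), (0, t)) for t in range(N) for s in range(t % BLOCK, N, BLOCK)]
chosen = dedupe_by_closest(mapping, N)
```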
Obviously this is just one way to minimize conflicts in the graph, and it only works for patterns like this one, where the source and target are repeated an equal number of times. If we hold on to all the sources that may map to each target, there is probably a better way to prune the conflict graph.
There is a slightly more general graph problem here, a solution to which would solve this problem:
Given a graph G, you'd like to produce an induced subgraph G' with $V(G') \subseteq V(G)$. G comes with a partition of the vertices into subsets $V_1, V_2, \ldots, V_k$ with the property that, for each subset $V_i$, at least one vertex in $V_i$ must remain in the final graph. The goal is to make G' colorable with as few colors as possible.
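For tiny instances the problem can be stated executably. Below is a brute-force Python sketch (exponential, purely illustrative). Since an induced subgraph of a k-colorable graph is k-colorable, keeping exactly one vertex per part is never worse than keeping more, so the search only considers one-per-part selections:

```python
from itertools import product

def chromatic_number(vertices, edges):
    """Exact chromatic number by brute force over colorings (tiny graphs only)."""
    verts = list(vertices)
    adj = {v: set() for v in verts}
    for u, w in edges:
        if u in adj and w in adj:
            adj[u].add(w)
            adj[w].add(u)
    if not verts:
        return 0
    for k in range(1, len(verts) + 1):
        for coloring in product(range(k), repeat=len(verts)):
            colors = dict(zip(verts, coloring))
            if all(colors[u] != colors[w] for u in verts for w in adj[u]):
                return k
    return len(verts)

def best_selection(parts, edges):
    """Choose one vertex from each part so the induced subgraph needs the
    fewest colors. Exhaustive search; returns (num_colors, chosen_vertices)."""
    best = None
    for choice in product(*parts):
        chosen = set(choice)
        sub_edges = [(u, w) for u, w in edges if u in chosen and w in chosen]
        k = chromatic_number(chosen, sub_edges)
        if best is None or k < best[0]:
            best = (k, chosen)
    return best
```

For example, with parts `[[0, 1], [2, 3]]` and edges `[(0, 2), (1, 3)]`, the selection `{0, 3}` induces no edges and is 1-colorable, while `{0, 2}` needs two colors.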
Tying this back to shift networks: given a mapping $[(v_i, s_i) \to (w_i, t_i)]$, consider the set of candidate sources that map to the same target $(w, t)$. Since each target slot can hold at most one value, the choice of which source actually lands in the target slot is arbitrary (usually this arises from replication, where every source landing in the same slot carries the same value). The sources which do NOT end up landing in the target slot can be omitted from the shift network entirely, but if kept they may add conflicts in the Vos-Vos-Erkin conflict graph, where each extra color costs an extra set of power-of-two rotations. Pruning those vertices from the conflict graph can reduce the number of colors needed, which can allow rotation groups to merge.
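A small Python sketch of the grouping just described (helper names are hypothetical, not HEIR's): for each target we gather its candidate sources and the left-rotation amount each would require; all but one candidate per target can then be pruned from the network.

```python
from collections import defaultdict

def group_sources_by_target(mapping):
    """For each target (ct, slot), collect the set of candidate sources
    that map to it -- the per-target-slot groups described above."""
    groups = defaultdict(set)
    for src, dst in mapping:
        groups[dst].add(src)
    return dict(groups)

def candidate_shifts(groups, ciphertext_size):
    """Left-rotation amount each candidate source would need to land in
    its target slot; keeping any single candidate per target suffices."""
    return {
        dst: {(dst[1] - s_slot) % ciphertext_size for (_s_ct, s_slot) in srcs}
        for dst, srcs in groups.items()
    }
```

Each distinct shift amount kept in the network is a vertex that can conflict with others, so shrinking these per-target sets to singletons is exactly the pruning opportunity.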
@asraa the resource you shared by @j2kun was great context sharing here https://www.jeremykun.com/2024/09/02/shift-networks/
I suspect the general graph problem will be NP-hard, so some ideas are:
- Use the special structure of this graph (constructed by evaluating these power-of-two shifts in some order) to come up with a heuristic to prune it. Asra's initial heuristic was to select only the source from a group that is closest to its destination.
- Use a greedy heuristic like: start with $S$ empty. Pick a random order of the $V_i$, and for each $i$ sort the vertices of $V_i$ by the primary key of number of edges into $S$ and the secondary key of degree in G, breaking remaining ties randomly. Add the minimal vertex to $S$ and repeat for all $i$. Picking the best result after a few tries with different random seeds should give a decently good outcome.
- Solve a quadratic program that has a binary variable $B_v$ for each vertex indicating whether it is chosen, constrained by $\sum_{v \in V_i} B_v = 1$ and minimizing $\sum_{(v, w) \in E(G)} B_v B_w$; here having the smallest number of edges is a proxy for low chromatic number. (Does or-tools have a solver that supports this? Probably will be slow...) The benefit here is that, if there is an independent set solution, it would be guaranteed to be found.
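The greedy heuristic from the second bullet might look like this in Python (a sketch with toy part/edge structures standing in for the conflict graph; it picks exactly one vertex per part):

```python
import random

def greedy_select(parts, edges, seed=0):
    """Greedy selection: visit the parts in random order; from each part
    pick the vertex with the fewest edges into the set chosen so far,
    breaking ties by total degree in G, then randomly."""
    rng = random.Random(seed)
    adj = {}
    for u, w in edges:
        adj.setdefault(u, set()).add(w)
        adj.setdefault(w, set()).add(u)
    chosen = set()
    order = list(parts)
    rng.shuffle(order)
    for part in order:
        def key(v):
            nbrs = adj.get(v, set())
            # (edges into S, degree in G, random tiebreak)
            return (len(nbrs & chosen), len(nbrs), rng.random())
        chosen.add(min(part, key=key))
    return chosen

def conflict_edges(chosen, edges):
    """Edges of G that survive in the induced subgraph on `chosen`."""
    return [(u, w) for u, w in edges if u in chosen and w in chosen]
```

Running it over a handful of seeds and keeping the selection with the fewest surviving conflict edges implements the "pick the best after a few tries" step.
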
Ha, social media came through. This problem is called "selective graph coloring" and has been studied by a few authors.
On some applications of the selective graph coloring problem https://www.sciencedirect.com/science/article/abs/pii/S0377221714004184 (pdf link)
That paper has some references for constructive algorithms.
This issue has 1 outstanding TODO:
- lib/Dialect/TensorExt/Transforms/ShiftScheme.h:100: Consider a better way to handle multiple valid sources.
This comment was autogenerated by todo-backlinks