dfdx
Add `Tensor::top_k`
When a neural network outputs scores over a set of labels, we often want the K highest-scoring labels rather than just the single best one. This is useful for both validation and deployment.
UX should look something like:

```rust
let t: Tensor<Rank2<32, 10>, f32, _> = ...;
// with dynamic `k`
let _: Tensor<(Const<32>, usize), usize, _> = t.top_k(5);
// with const `k`
let _: Tensor<Rank2<32, 5>, usize, _> = t.top_k(Const::<5>);
```
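Note that both the shape and the dtype change. The last axis shrinks from 10 to `k`, and the elements become `usize` values: the indices of the top `k` entries along that axis, not the values themselves.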
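To pin down the intended semantics, here is a minimal sketch over plain Rust arrays rather than dfdx tensors. The helper names `top_k_indices`, `top_k_dyn`, and `top_k_const` are hypothetical and not part of dfdx. It assumes the indices are returned in descending order of value and that ties go to the lower index; the issue does not specify either. `top_k_dyn` mirrors the dynamic-`k` case (output length known only at runtime), and `top_k_const` mirrors the `Const::<K>` case (output shape fixed at compile time).

```rust
/// Hypothetical helper (not dfdx API): indices of the `k` largest values in
/// `row`, ordered from largest to smallest. Ties keep the lower index first
/// because `sort_by` is stable.
fn top_k_indices(row: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..row.len()).collect();
    // Sort indices by descending value; `total_cmp` is a total order even with NaN.
    idx.sort_by(|&a, &b| row[b].total_cmp(&row[a]));
    idx.truncate(k);
    idx
}

/// Dynamic `k`: shape (M, N) of f32 -> (M, k) of usize, with k a runtime value.
fn top_k_dyn<const M: usize, const N: usize>(t: &[[f32; N]; M], k: usize) -> Vec<Vec<usize>> {
    t.iter().map(|row| top_k_indices(row, k)).collect()
}

/// Const `K`: shape (M, N) of f32 -> (M, K) of usize, with K part of the type.
fn top_k_const<const M: usize, const N: usize, const K: usize>(t: &[[f32; N]; M]) -> [[usize; K]; M] {
    // A runtime check; a real implementation might enforce this at compile time.
    assert!(K <= N, "K must not exceed the length of the last axis");
    let mut out = [[0usize; K]; M];
    for (dst, row) in out.iter_mut().zip(t.iter()) {
        dst.copy_from_slice(&top_k_indices(row, K));
    }
    out
}

fn main() {
    // Two rows of label scores (a tiny stand-in for Rank2<32, 10>).
    let t: [[f32; 4]; 2] = [[0.1, 0.7, 0.05, 0.15], [0.3, 0.2, 0.4, 0.1]];
    let dynamic = top_k_dyn(&t, 2);
    // K = 2 is inferred from the annotated return type.
    let fixed: [[usize; 2]; 2] = top_k_const(&t);
    println!("{:?} {:?}", dynamic, fixed); // prints "[[1, 3], [2, 0]] [[1, 3], [2, 0]]"
}
```

Both variants return the same indices; the difference is only whether `k` lives in the value (`Vec<Vec<usize>>`) or in the type (`[[usize; K]; M]`), which matches the `usize` vs `Const<5>` split in the proposed API.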