dfdx

Add `Tensor::top_k`

Open coreylowman opened this issue 1 year ago • 0 comments

When a neural network outputs scores over labels, we often want the top K labels rather than just the single best one. This is useful for validation and deployment.

UX should look something like:

```rust
let t: Tensor<Rank2<32, 10>, f32, _> = ...;

// with dynamic `k`
let _: Tensor<(Const<32>, usize), usize, _> = t.top_k(5);

// with const `k`
let _: Tensor<Rank2<32, 5>, usize, _> = t.top_k(Const::<5>);
```
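To pin down the intended semantics, here is a minimal sketch in plain Rust that does not use dfdx at all. It assumes `top_k` works along the last axis (as the shapes above imply: 10 columns in, `k` columns out) and returns indices sorted from largest to smallest value; the sort order is an assumption, not something this issue specifies. The names `top_k_row`, `top_k_dyn` and `top_k_const` are hypothetical, and nested arrays/`Vec`s stand in for tensors.

```rust
/// Hypothetical helper: indices of the `k` largest values in `row`,
/// ordered from largest to smallest value.
fn top_k_row(row: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..row.len()).collect();
    // Sort indices by descending value; `total_cmp` is a total order, so NaN cannot panic the sort.
    idx.sort_by(|&a, &b| row[b].total_cmp(&row[a]));
    idx.truncate(k);
    idx
}

/// Dynamic `k`: the output width is only known at runtime,
/// mirroring `Tensor<(Const<32>, usize), usize, _>`.
fn top_k_dyn<const R: usize, const C: usize>(t: &[[f32; C]; R], k: usize) -> Vec<Vec<usize>> {
    t.iter().map(|row| top_k_row(row, k)).collect()
}

/// Const `k`: the output width is part of the type,
/// mirroring `Tensor<Rank2<32, 5>, usize, _>`.
fn top_k_const<const R: usize, const C: usize, const K: usize>(t: &[[f32; C]; R]) -> [[usize; K]; R] {
    assert!(K <= C, "k must not exceed the number of columns");
    let mut out = [[0usize; K]; R];
    for (out_row, row) in out.iter_mut().zip(t.iter()) {
        out_row.copy_from_slice(&top_k_row(row, K));
    }
    out
}

fn main() {
    let t: [[f32; 4]; 2] = [[0.1, 0.7, 0.05, 0.15], [0.4, 0.2, 0.3, 0.1]];
    let dynamic = top_k_dyn(&t, 2);
    // `K = 2` is inferred from the annotated return type, like `Const::<5>` above.
    let fixed: [[usize; 2]; 2] = top_k_const(&t);
    println!("{:?} {:?}", dynamic, fixed); // prints "[[1, 3], [0, 2]] [[1, 3], [0, 2]]"
}
```

The const version can only be called when `K` is known at compile time, which is what lets the proposed `Rank2<32, 5>` output type be checked statically; the dynamic version trades that for a runtime `usize` width.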

Note that both the shape and the dtype change: the last axis shrinks from 10 to `k`, and the output holds the `usize` indices of the top k values rather than the `f32` values themselves.
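Returning indices rather than values is what makes the output directly usable for validation: indices can be compared against integer class labels. As an illustration, here is a hypothetical top-k accuracy helper written in plain Rust over the kind of `usize` index rows the proposed `top_k` would produce. `top_k_accuracy` is not a dfdx function; it is only a sketch of one downstream use.

```rust
/// Hypothetical helper: fraction of rows whose true label appears among
/// that row's top-k indices (as produced by the proposed `top_k`).
fn top_k_accuracy(top_k: &[Vec<usize>], labels: &[usize]) -> f32 {
    assert_eq!(top_k.len(), labels.len(), "expected one label per row");
    if labels.is_empty() {
        return 0.0;
    }
    // A row counts as a hit if its true label is one of its top-k indices.
    let hits = top_k
        .iter()
        .zip(labels)
        .filter(|&(idx, &label)| idx.contains(&label))
        .count();
    hits as f32 / labels.len() as f32
}

fn main() {
    // Top-2 indices for two rows, with true labels 3 and 1.
    let top2 = vec![vec![1, 3], vec![0, 2]];
    let labels = [3, 1];
    println!("{}", top_k_accuracy(&top2, &labels)); // prints "0.5"
}
```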

coreylowman, Apr 10 '23 12:04