# Refactor: Abstract Knowledge Distillation operations to support generic `TOutput` types
## Problem Statement

The Knowledge Distillation (KD) framework's interfaces (`IKnowledgeDistillationTrainer<T, TInput, TOutput>` and `IDistillationStrategy<T, TOutput>`) are correctly generic, but the concrete implementations are hardcoded to `Vector<T>`, forcing conversions and defeating the purpose of the generic architecture.

Current limitations:
```csharp
// Interfaces are generic ✅
public interface IDistillationStrategy<T, TOutput>
{
    T ComputeLoss(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
    TOutput ComputeGradient(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
}

// But implementations are hardcoded to Vector<T> ❌
public class DistillationLoss<T> : DistillationStrategyBase<T, Vector<T>>
{
    // Softmax, KLDivergence, CrossEntropy all work only on Vector<T>
    private Vector<T> Softmax(Vector<T> logits, double temperature) { ... }
    private T KLDivergence(Vector<T> p, Vector<T> q) { ... }
}
```
Consequences:
- `PredictionModelBuilder` must convert `Matrix<T>` → `Vector<T>` → `Matrix<T>` on every forward/backward pass
- The conversion overhead defeats the performance benefits of the generic architecture
- Batch processing with `Matrix<T>` or `Tensor<T>` cannot be leveraged directly
- Not aligned with industry standards (PyTorch and TensorFlow work with tensors of any shape)

Current workaround in PR #437: adapter functions bridge `PredictionModelBuilder`'s `TInput`/`TOutput` (dataset types such as `Matrix<T>`) with the KD trainer's `Vector<T>` (sample types). This works, but it adds unnecessary conversion overhead, roughly as sketched below.
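A hypothetical sketch of that adapter shape (illustrative names only, not the literal PR #437 code; it assumes `Matrix<T>` exposes `Rows` and `GetRow`, and that `INumericOperations<T>` provides `Zero`, `Add`, `Divide`, and `FromDouble`):

```csharp
// Illustrative adapter sketch: every forward pass copies each Matrix<T> row
// into a Vector<T> just so the Vector<T>-only strategy can consume it.
private T ComputeBatchLoss(Matrix<T> studentBatch, Matrix<T> teacherBatch)
{
    T total = NumOps.Zero;
    for (int i = 0; i < studentBatch.Rows; i++)
    {
        // Matrix<T> -> Vector<T>: one allocation + copy per sample, per pass
        Vector<T> studentRow = studentBatch.GetRow(i);
        Vector<T> teacherRow = teacherBatch.GetRow(i);
        total = NumOps.Add(total, _strategy.ComputeLoss(studentRow, teacherRow));
    }
    // Mean loss over the batch
    return NumOps.Divide(total, NumOps.FromDouble(studentBatch.Rows));
}
```

And the backward pass pays the reverse `Vector<T>` → `Matrix<T>` copy for gradients. Operating on `Matrix<T>` directly would remove both the loop and the copies.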
## Proposed Solution

Introduce an `IOutputOperations<T, TOutput>` interface to abstract all output-specific operations, following the same pattern `INumericOperations<T>` uses for numeric type abstraction.

### 1. Define the abstraction
```csharp
public interface IOutputOperations<T, TOutput>
{
    // Temperature-scaled softmax over the class scores
    TOutput Softmax(TOutput logits, double temperature);

    // Loss primitives
    T KLDivergence(TOutput p, TOutput q);
    T CrossEntropy(TOutput predictions, TOutput trueLabels);

    // Shape and inspection helpers
    int GetDimension(TOutput output);
    int ArgMax(TOutput output);

    // Element-wise arithmetic needed for gradients
    TOutput Subtract(TOutput a, TOutput b);
    TOutput Multiply(TOutput output, T scalar);

    // Factory for a new output of the given dimension
    TOutput CreateOutput(int dimension);
}
```
### 2. Implement for `Vector<T>`
```csharp
public class VectorOutputOperations<T> : IOutputOperations<T, Vector<T>>
{
    private readonly INumericOperations<T> _numOps;

    public VectorOutputOperations(INumericOperations<T> numOps)
    {
        _numOps = numOps ?? throw new ArgumentNullException(nameof(numOps));
    }

    public Vector<T> Softmax(Vector<T> logits, double temperature) { /* Move existing logic here */ }
    public T KLDivergence(Vector<T> p, Vector<T> q) { /* Move existing logic here */ }
    public T CrossEntropy(Vector<T> predictions, Vector<T> trueLabels) { /* Move existing logic here */ }
    public int GetDimension(Vector<T> output) => output.Length;
    public int ArgMax(Vector<T> output) { /* Move existing logic here */ }
    public Vector<T> Subtract(Vector<T> a, Vector<T> b) { /* Element-wise subtraction */ }
    public Vector<T> Multiply(Vector<T> output, T scalar) { /* Scalar multiplication */ }
    public Vector<T> CreateOutput(int dimension) => new Vector<T>(dimension);
}
```
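The same interface then admits batched implementations. A hypothetical, abridged sketch follows (no `Matrix<T>` implementation exists yet; it assumes `Matrix<T>` exposes `Rows`, `Columns`, `GetRow`, and `SetRow`, and it simply reuses the vector ops row-wise; a later version could swap in SIMD/BLAS kernels without touching any strategy code):

```csharp
// Hypothetical, abridged sketch: Matrix<T> as a batch of per-sample rows.
// Remaining members follow the same row-wise pattern and are omitted for brevity.
public class MatrixOutputOperations<T> : IOutputOperations<T, Matrix<T>>
{
    private readonly INumericOperations<T> _numOps;
    private readonly VectorOutputOperations<T> _vectorOps;

    public MatrixOutputOperations(INumericOperations<T> numOps)
    {
        _numOps = numOps ?? throw new ArgumentNullException(nameof(numOps));
        _vectorOps = new VectorOutputOperations<T>(numOps);
    }

    public Matrix<T> Softmax(Matrix<T> logits, double temperature)
    {
        // Row-wise softmax: each row is one sample's logit vector
        var result = new Matrix<T>(logits.Rows, logits.Columns);
        for (int i = 0; i < logits.Rows; i++)
            result.SetRow(i, _vectorOps.Softmax(logits.GetRow(i), temperature));
        return result;
    }

    public T KLDivergence(Matrix<T> p, Matrix<T> q)
    {
        // Mean KL divergence across the batch
        T total = _numOps.Zero;
        for (int i = 0; i < p.Rows; i++)
            total = _numOps.Add(total, _vectorOps.KLDivergence(p.GetRow(i), q.GetRow(i)));
        return _numOps.Divide(total, _numOps.FromDouble(p.Rows));
    }

    // CrossEntropy, GetDimension, ArgMax, Subtract, Multiply, CreateOutput: same pattern
}
```

Treating each row as one sample's output keeps the loss semantics identical to the per-sample `Vector<T>` path while amortizing the work over the whole batch.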
### 3. Refactor `DistillationStrategyBase`
```csharp
public abstract class DistillationStrategyBase<T, TOutput> : IDistillationStrategy<T, TOutput>
{
    protected readonly INumericOperations<T> NumOps;
    protected readonly IOutputOperations<T, TOutput> OutputOps;

    protected double Temperature { get; }
    protected double Alpha { get; }

    protected DistillationStrategyBase(
        IOutputOperations<T, TOutput> outputOps,
        double temperature = 3.0,
        double alpha = 0.3)
    {
        OutputOps = outputOps ?? throw new ArgumentNullException(nameof(outputOps));
        NumOps = MathHelper.GetNumericOperations<T>();
        Temperature = temperature;
        Alpha = alpha;
    }

    // Concrete strategies implement these via OutputOps
    public abstract T ComputeLoss(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
    public abstract TOutput ComputeGradient(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);

    protected void ValidateOutputDimensions(TOutput studentOutput, TOutput teacherOutput)
    {
        int studentDim = OutputOps.GetDimension(studentOutput);
        int teacherDim = OutputOps.GetDimension(teacherOutput);
        if (studentDim != teacherDim)
            throw new ArgumentException($"Output dimensions must match. Student: {studentDim}, Teacher: {teacherDim}");
    }
}
```
### 4. Update concrete strategies
```csharp
public class DistillationLoss<T> : DistillationStrategyBase<T, Vector<T>>
{
    public DistillationLoss(IOutputOperations<T, Vector<T>> outputOps, double temperature = 3.0, double alpha = 0.3)
        : base(outputOps, temperature, alpha) { }

    public override T ComputeLoss(Vector<T> studentLogits, Vector<T> teacherLogits, Vector<T>? trueLabels = null)
    {
        ValidateOutputDimensions(studentLogits, teacherLogits);
        var studentSoft = OutputOps.Softmax(studentLogits, Temperature);
        var teacherSoft = OutputOps.Softmax(teacherLogits, Temperature);
        var softLoss = OutputOps.KLDivergence(teacherSoft, studentSoft);
        // ... rest of logic using OutputOps
    }
}
```
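The `Subtract`/`Multiply`/`CreateOutput` members exist so gradients can be expressed through `OutputOps` as well. A minimal sketch of the soft-target gradient, assuming the standard result that the softened KL term's gradient w.r.t. the student logits is `softmax(student/T) - softmax(teacher/T)`, and that `INumericOperations<T>` provides `FromDouble` (the exact `T²`/`Alpha` weighting should match whatever the current `Vector<T>` implementation does):

```csharp
public override Vector<T> ComputeGradient(Vector<T> studentLogits, Vector<T> teacherLogits, Vector<T>? trueLabels = null)
{
    ValidateOutputDimensions(studentLogits, teacherLogits);
    var studentSoft = OutputOps.Softmax(studentLogits, Temperature);
    var teacherSoft = OutputOps.Softmax(teacherLogits, Temperature);

    // Gradient of the softened KL term w.r.t. student logits: p_student - p_teacher,
    // conventionally rescaled by T^2 so magnitudes stay comparable across temperatures.
    var grad = OutputOps.Subtract(studentSoft, teacherSoft);
    return OutputOps.Multiply(grad, NumOps.FromDouble(Temperature * Temperature));
}
```

Construction-time injection then stays explicit, e.g. `new DistillationLoss<double>(new VectorOutputOperations<double>(numOps))`; wiring this through the existing factories is the "Update factories to inject operations" checklist item below.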
## Benefits

- True generic support: no conversion overhead; works natively with `Matrix<T>`, `Tensor<T>`, etc.
- Industry-standard architecture: matches the PyTorch/TensorFlow pattern
- Performance: eliminates the `Vector<T>` conversion bottleneck
- Extensibility: easy to add support for new output types
- Consistency: follows the existing `INumericOperations<T>` pattern
## Implementation Checklist

- [ ] Create `IOutputOperations<T, TOutput>` interface
- [ ] Implement `VectorOutputOperations<T>`
- [ ] Refactor `DistillationStrategyBase<T, TOutput>` to use `IOutputOperations`
- [ ] Update all concrete strategies
- [ ] Update `KnowledgeDistillationTrainerBase<T, TInput, TOutput>`
- [ ] Update `TeacherModelBase<TInput, TOutput, T>`
- [ ] Update factories to inject operations
- [ ] Remove adapter functions from PredictionModelBuilder
- [ ] Add unit tests
- [ ] Update documentation
## Related Issues
- PR #437: Knowledge Distillation features (currently uses adapter workaround)
- This refactor should be done in a separate PR after #437 is merged
## References

- PyTorch KD: works with `torch.Tensor` of any shape
- TensorFlow/Keras KD: works with `tf.Tensor` of any shape
- Gemini analysis: recommended the `IOutputOperations` pattern for true generic support