
Refactor: Abstract Knowledge Distillation operations to support generic TOutput types

ooples opened this issue 2 months ago · 0 comments

Problem Statement

The Knowledge Distillation (KD) framework's interfaces are correctly generic with IKnowledgeDistillationTrainer<T, TInput, TOutput> and IDistillationStrategy<T, TOutput>, but concrete implementations are hardcoded to Vector<T>, forcing conversions and defeating the purpose of the generic architecture.

Current limitations:

// Interfaces are generic ✅
public interface IDistillationStrategy<T, TOutput>
{
    T ComputeLoss(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
    TOutput ComputeGradient(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
}

// But implementations are hardcoded to Vector<T> ❌
public class DistillationLoss<T> : DistillationStrategyBase<T, Vector<T>>
{
    // Softmax, KLDivergence, CrossEntropy all work only on Vector<T>
    private Vector<T> Softmax(Vector<T> logits, double temperature) { ... }
    private T KLDivergence(Vector<T> p, Vector<T> q) { ... }
}

Consequences:

  1. PredictionModelBuilder must convert Matrix<T> → Vector<T> → Matrix<T> on every forward/backward pass
  2. The conversion overhead defeats the performance benefits of the generic architecture
  3. Batch processing with Matrix<T> or Tensor<T> cannot be leveraged directly
  4. Not aligned with industry practice (PyTorch and TensorFlow operate on tensors of any shape)

Current workaround in PR #437: adapter functions bridge PredictionModelBuilder's TInput/TOutput (dataset types such as Matrix<T>) with the KD trainer's Vector<T> (sample types). This works, but it adds unnecessary conversion overhead.

Proposed Solution

Introduce IOutputOperations<T, TOutput> interface to abstract all output-specific operations, following the same pattern as INumericOperations<T> for numeric type abstraction.

1. Define the abstraction

public interface IOutputOperations<T, TOutput>
{
    TOutput Softmax(TOutput logits, double temperature);
    T KLDivergence(TOutput p, TOutput q);
    T CrossEntropy(TOutput predictions, TOutput trueLabels);
    int GetDimension(TOutput output);
    int ArgMax(TOutput output);
    TOutput Subtract(TOutput a, TOutput b);
    TOutput Multiply(TOutput output, T scalar);
    TOutput CreateOutput(int dimension);
}
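For reference, the temperature-scaled softmax that the `Softmax(TOutput logits, double temperature)` contract implies can be sketched language-agnostically. The Python below is illustrative only (not AiDotNet code); the max-subtraction trick is a standard stability measure, not something stated in this issue:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperature yields a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

probs_sharp = softmax([2.0, 1.0, 0.1], temperature=1.0)
probs_soft = softmax([2.0, 1.0, 0.1], temperature=3.0)
# both sum to 1; the T=3 distribution is flatter than the T=1 one
```

A Matrix<T> or Tensor<T> implementation would apply the same computation row-wise, which is exactly what the abstraction makes possible without converting to Vector<T> first.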

2. Implement for Vector<T>

public class VectorOutputOperations<T> : IOutputOperations<T, Vector<T>>
{
    private readonly INumericOperations<T> _numOps;

    public VectorOutputOperations(INumericOperations<T> numOps)
    {
        _numOps = numOps ?? throw new ArgumentNullException(nameof(numOps));
    }

    public Vector<T> Softmax(Vector<T> logits, double temperature) { /* Move existing logic here */ }
    public T KLDivergence(Vector<T> p, Vector<T> q) { /* Move existing logic here */ }
    public T CrossEntropy(Vector<T> predictions, Vector<T> trueLabels) { /* Move existing logic here */ }
    public int GetDimension(Vector<T> output) => output.Length;
    public int ArgMax(Vector<T> output) { /* Move existing logic here */ }
    public Vector<T> Subtract(Vector<T> a, Vector<T> b) { /* Element-wise subtraction */ }
    public Vector<T> Multiply(Vector<T> output, T scalar) { /* Scalar multiplication */ }
    public Vector<T> CreateOutput(int dimension) => new Vector<T>(dimension);
}
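The KL-divergence and cross-entropy logic being moved into VectorOutputOperations<T> reduces to the standard formulas. A minimal Python sketch (illustrative, not the library's code; the `eps` guard against log(0) is an assumption):

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def cross_entropy(predictions, true_labels, eps=1e-12):
    """-sum(t * log(p)); true_labels is typically a one-hot vector."""
    return -sum(t * math.log(p + eps) for t, p in zip(true_labels, predictions))

# KL of a distribution with itself is 0 and grows as q diverges from p;
# cross-entropy is lower when the prediction puts mass on the true class.
```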

3. Refactor DistillationStrategyBase

public abstract class DistillationStrategyBase<T, TOutput> : IDistillationStrategy<T, TOutput>
{
    protected readonly INumericOperations<T> NumOps;
    protected readonly IOutputOperations<T, TOutput> OutputOps;

    protected double Temperature { get; }
    protected double Alpha { get; }

    protected DistillationStrategyBase(
        IOutputOperations<T, TOutput> outputOps,
        double temperature = 3.0,
        double alpha = 0.3)
    {
        OutputOps = outputOps ?? throw new ArgumentNullException(nameof(outputOps));
        NumOps = MathHelper.GetNumericOperations<T>();
        Temperature = temperature;
        Alpha = alpha;
    }

    // Concrete strategies implement the IDistillationStrategy<T, TOutput> members
    public abstract T ComputeLoss(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);
    public abstract TOutput ComputeGradient(TOutput studentOutput, TOutput teacherOutput, TOutput? trueLabels = default);

    protected void ValidateOutputDimensions(TOutput studentOutput, TOutput teacherOutput)
    {
        int studentDim = OutputOps.GetDimension(studentOutput);
        int teacherDim = OutputOps.GetDimension(teacherOutput);
        if (studentDim != teacherDim)
            throw new ArgumentException($"Output dimensions must match. Student: {studentDim}, Teacher: {teacherDim}");
    }
}
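The payoff of injecting IOutputOperations is that one strategy base works unchanged for any output shape. This Python sketch (hypothetical names, not the C# code) shows the same validation logic served by a vector-style and a matrix-style ops object:

```python
class VectorOps:
    def get_dimension(self, v):
        return len(v)                     # v is a flat list of class logits

class MatrixOps:
    def get_dimension(self, m):
        return len(m[0]) if m else 0      # m is a batch (list of rows); dimension = row width

class StrategyBase:
    def __init__(self, output_ops):
        self.ops = output_ops             # injected, like IOutputOperations<T, TOutput>

    def validate(self, student, teacher):
        s, t = self.ops.get_dimension(student), self.ops.get_dimension(teacher)
        if s != t:
            raise ValueError(f"Output dimensions must match. Student: {s}, Teacher: {t}")
        return True
```

The strategy never inspects the output type itself, so a batch never has to be decomposed into per-sample vectors.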

4. Update concrete strategies

public class DistillationLoss<T> : DistillationStrategyBase<T, Vector<T>>
{
    public DistillationLoss(IOutputOperations<T, Vector<T>> outputOps, double temperature = 3.0, double alpha = 0.3)
        : base(outputOps, temperature, alpha) { }

    public override T ComputeLoss(Vector<T> studentLogits, Vector<T> teacherLogits, Vector<T>? trueLabels = null)
    {
        ValidateOutputDimensions(studentLogits, teacherLogits);
        var studentSoft = OutputOps.Softmax(studentLogits, Temperature);
        var teacherSoft = OutputOps.Softmax(teacherLogits, Temperature);
        var softLoss = OutputOps.KLDivergence(teacherSoft, studentSoft);
        // ... rest of logic using OutputOps
    }
}
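The elided `// ... rest of logic` is conventionally the Hinton-style combination of a temperature-scaled soft loss and a hard-label loss. The exact weighting AiDotNet uses is not shown in this issue, so the Python sketch below only illustrates one common formulation (the T² scaling factor and the role of `alpha` are assumptions):

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q, eps=1e-12):
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def distillation_loss(student_logits, teacher_logits, true_labels,
                      temperature=3.0, alpha=0.3):
    """alpha-weighted KD loss: soft term (teacher vs. student at temperature T,
    scaled by T^2 to keep gradient magnitudes comparable) plus hard cross-entropy."""
    student_soft = softmax(student_logits, temperature)
    teacher_soft = softmax(teacher_logits, temperature)
    soft_loss = kl(teacher_soft, student_soft) * temperature ** 2
    student_probs = softmax(student_logits, 1.0)
    hard_loss = -sum(t * math.log(p + 1e-12)
                     for t, p in zip(true_labels, student_probs))
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

A student that matches the teacher and the true label incurs a much smaller loss than one that disagrees with both, which is the behavior ComputeLoss must preserve regardless of the TOutput type behind it.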

Benefits

  1. True generic support: no conversion overhead; works natively with Matrix<T>, Tensor<T>, etc.
  2. Industry-standard architecture: matches the PyTorch/TensorFlow pattern
  3. Performance: eliminates the Vector<T> conversion bottleneck
  4. Extensibility: new output types can be added easily
  5. Consistency: follows the existing INumericOperations<T> pattern

Implementation Checklist

  • [ ] Create IOutputOperations<T, TOutput> interface
  • [ ] Implement VectorOutputOperations<T>
  • [ ] Refactor DistillationStrategyBase<T, TOutput> to use IOutputOperations
  • [ ] Update all concrete strategies
  • [ ] Update KnowledgeDistillationTrainerBase<T, TInput, TOutput>
  • [ ] Update TeacherModelBase<TInput, TOutput, T>
  • [ ] Update factories to inject operations
  • [ ] Remove adapter functions from PredictionModelBuilder
  • [ ] Add unit tests
  • [ ] Update documentation

Related Issues

  • PR #437: Knowledge Distillation features (currently uses adapter workaround)
  • This refactor should be done in a separate PR after #437 is merged

References

  • PyTorch KD: Works with torch.Tensor of any shape
  • TensorFlow/Keras KD: Works with tf.Tensor of any shape
  • Gemini analysis: Recommended IOutputOperations pattern for true generic support

ooples · Nov 13 '25 04:11