
Refactor: Move LossFunction from constructor parameter to NeuralNetworkArchitecture property

Open ooples opened this issue 2 months ago • 0 comments

Problem Statement

Currently, neural network classes accept ILossFunction<T> as a separate constructor parameter, even though the default loss function is already determined by architecture.TaskType via NeuralNetworkHelper<T>.GetDefaultLossFunction().

Current pattern:

public SiameseNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null)
    : base(architecture, lossFunction ?? NeuralNetworkHelper<T>.GetDefaultLossFunction(architecture.TaskType))
{
}

This creates redundancy where:

  1. The architecture already knows what task type it is (classification, regression, etc.)
  2. The task type already determines the default loss function
  3. But the loss function is passed as a separate parameter
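To make the redundancy concrete, the task-type-to-loss mapping already lives in NeuralNetworkHelper<T>.GetDefaultLossFunction. A simplified sketch of the kind of mapping such a helper performs (the case labels and loss class names below are illustrative assumptions, not the actual AiDotNet implementation):

```csharp
// Illustrative sketch only: the real task types and loss classes in
// AiDotNet may differ. The point is that the task type alone already
// determines a sensible default loss, so passing the loss separately
// duplicates information the architecture carries.
public static ILossFunction<T> GetDefaultLossFunction(TaskType taskType) =>
    taskType switch
    {
        TaskType.Regression => new MeanSquaredErrorLoss<T>(),             // assumed name
        TaskType.BinaryClassification => new BinaryCrossEntropyLoss<T>(), // assumed name
        TaskType.MultiClassClassification => new CrossEntropyLoss<T>(),   // assumed name
        _ => new MeanSquaredErrorLoss<T>()                                // assumed fallback
    };
```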

Proposed Solution

Move ILossFunction<T> into NeuralNetworkArchitecture<T> as a property, making the architecture a complete specification of the model.

Proposed pattern:

public class NeuralNetworkArchitecture<T>
{
    public TaskType TaskType { get; set; }
    public ILossFunction<T> LossFunction { get; set; }
    // ... other properties
    
    // Constructor sets default loss function based on task type
    public NeuralNetworkArchitecture(TaskType taskType, ILossFunction<T>? lossFunction = null)
    {
        TaskType = taskType;
        LossFunction = lossFunction ?? NeuralNetworkHelper<T>.GetDefaultLossFunction(taskType);
    }
}

// Neural network constructors become simpler
public SiameseNetwork(NeuralNetworkArchitecture<T> architecture) :
    base(architecture)
{
    // Loss function comes from architecture.LossFunction
}
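Under the proposed pattern, callers choose the loss once, at architecture-construction time. A hypothetical usage sketch (loss class names beyond those in this issue are assumptions):

```csharp
// Default loss: the architecture derives it from the task type.
var arch = new NeuralNetworkArchitecture<double>(TaskType.BinaryClassification);
var network = new SiameseNetwork<double>(arch);

// Custom loss: supplied on the architecture, not on the network constructor.
var customArch = new NeuralNetworkArchitecture<double>(
    TaskType.BinaryClassification,
    new ContrastiveLoss<double>()); // assumed loss class name
var customNetwork = new SiameseNetwork<double>(customArch);
```

Either way, the network constructor takes a single argument, and the loss is always discoverable from architecture.LossFunction.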

Benefits

  1. Single Source of Truth: Architecture fully describes the model configuration
  2. Simpler Constructors: Removes redundant loss function parameter from all neural network constructors
  3. Better Encapsulation: Architecture encapsulates both structure AND training configuration
  4. Consistency: Loss function always matches the architecture's task type by default
  5. Serialization: When saving/loading architectures, loss function is included automatically
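Benefit 5 only holds if the loss function participates in (de)serialization. One minimal way that could work, assuming loss functions are stateless and can be round-tripped by type name (the helper name and JSON shape here are assumptions, not existing AiDotNet API):

```csharp
// Assumption: loss functions are stateless, so serializing the type name
// is enough to reconstruct them on load. Losses with configurable state
// would need dedicated serializers.
public static string SerializeArchitecture<T>(NeuralNetworkArchitecture<T> arch) =>
    System.Text.Json.JsonSerializer.Serialize(new
    {
        TaskType = arch.TaskType.ToString(),
        LossFunction = arch.LossFunction.GetType().Name
    });
```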

Implementation Considerations

  1. Breaking Change: This will affect all neural network constructors and instantiation code
  2. Migration Path: All calls like new SomeNetwork(arch, lossFunc) need to change to one of:
    • new SomeNetwork(arch) (when the default loss is acceptable)
    • new SomeNetwork(archWithCustomLoss), where the architecture was built with a custom loss: new Architecture(..., customLoss)
  3. Backward Compatibility: Consider adding deprecated overloads that accept loss function for gradual migration
  4. Testing: All neural network tests need updating
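Consideration 3 above could be handled with an [Obsolete] bridge overload that forwards into the new single-parameter constructor. A sketch (the copy-versus-mutate behavior on the architecture, and the ApplyLoss helper, are assumptions for illustration):

```csharp
// New canonical constructor: loss comes from the architecture.
public SiameseNetwork(NeuralNetworkArchitecture<T> architecture)
    : base(architecture)
{
}

// Deprecated bridge for existing callers; removable after the migration window.
[Obsolete("Pass the loss function via NeuralNetworkArchitecture instead.")]
public SiameseNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction)
    : this(ApplyLoss(architecture, lossFunction))
{
}

private static NeuralNetworkArchitecture<T> ApplyLoss(
    NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction)
{
    // Assumption: LossFunction has a setter, so the old parameter can
    // override the architecture's default during the migration period.
    if (lossFunction is not null)
        architecture.LossFunction = lossFunction;
    return architecture;
}
```

Compiler warnings from the [Obsolete] attribute then give downstream code a clear signal without an immediate hard break.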

Files Affected

Core Files

  • src/Models/NeuralNetworkArchitecture.cs - Add LossFunction property
  • src/NeuralNetworks/NeuralNetworkBase.cs - Update base constructor

Neural Network Classes

  • src/NeuralNetworks/ConvolutionalNeuralNetwork.cs
  • src/NeuralNetworks/RecurrentNeuralNetwork.cs
  • src/NeuralNetworks/SiameseNetwork.cs
  • src/NeuralNetworks/ResidualNeuralNetwork.cs
  • src/NeuralNetworks/Transformer.cs
  • src/NeuralNetworks/VariationalAutoencoder.cs
  • src/NeuralNetworks/GenerativeAdversarialNetwork.cs
  • All other neural network implementations

Tests

  • All unit tests that instantiate neural networks
  • Integration tests

Priority

Medium - A quality-of-life refactor that improves architecture consistency but doesn't block any functionality.

Labels

refactoring, architecture, breaking-change

ooples avatar Nov 10 '25 17:11 ooples