AiDotNet
Refactor: Move LossFunction from constructor parameter to NeuralNetworkArchitecture property
Problem Statement
Currently, neural network classes accept ILossFunction<T> as a separate constructor parameter, even though the default loss function is already determined by architecture.TaskType via NeuralNetworkHelper<T>.GetDefaultLossFunction().
Current pattern:
```csharp
public SiameseNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null) :
    base(architecture, lossFunction ?? NeuralNetworkHelper<T>.GetDefaultLossFunction(architecture.TaskType))
```
This creates redundancy where:
- The architecture already knows what task type it is (classification, regression, etc.)
- The task type already determines the default loss function
- But the loss function is passed as a separate parameter
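To make the redundancy concrete, here is a minimal, self-contained sketch of the mapping that already exists today. The member names (`CrossEntropyLoss<T>`, `MeanSquaredErrorLoss<T>`, the `Name` property, and the two `TaskType` values) are illustrative stand-ins, not necessarily the library's exact API:

```csharp
#nullable enable
using System;

public enum TaskType { Classification, Regression }

public interface ILossFunction<T> { string Name { get; } }

// Illustrative stand-ins for the library's loss implementations.
public class CrossEntropyLoss<T> : ILossFunction<T> { public string Name => "CrossEntropy"; }
public class MeanSquaredErrorLoss<T> : ILossFunction<T> { public string Name => "MSE"; }

public static class NeuralNetworkHelper<T>
{
    // The task type alone already determines the default loss function...
    public static ILossFunction<T> GetDefaultLossFunction(TaskType taskType) =>
        taskType switch
        {
            TaskType.Classification => new CrossEntropyLoss<T>(),
            TaskType.Regression    => new MeanSquaredErrorLoss<T>(),
            _ => throw new ArgumentOutOfRangeException(nameof(taskType))
        };
}

public static class Demo
{
    public static void Main()
    {
        // ...yet every network constructor still threads the loss through a second parameter.
        var loss = NeuralNetworkHelper<double>.GetDefaultLossFunction(TaskType.Classification);
        Console.WriteLine(loss.Name); // prints "CrossEntropy"
    }
}
```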
Proposed Solution
Move ILossFunction<T> into NeuralNetworkArchitecture<T> as a property, making the architecture a complete specification of the model.
Proposed pattern:
```csharp
public class NeuralNetworkArchitecture<T>
{
    public TaskType TaskType { get; set; }
    public ILossFunction<T> LossFunction { get; set; }
    // ... other properties

    // Constructor sets the default loss function based on the task type
    public NeuralNetworkArchitecture(TaskType taskType, ILossFunction<T>? lossFunction = null)
    {
        TaskType = taskType;
        LossFunction = lossFunction ?? NeuralNetworkHelper<T>.GetDefaultLossFunction(taskType);
    }
}

// Neural network constructors become simpler
public SiameseNetwork(NeuralNetworkArchitecture<T> architecture) :
    base(architecture)
{
    // Loss function comes from architecture.LossFunction
}
```
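A runnable sketch of how the proposed pattern would look from the caller's side. The loss classes and the hard-coded default below (`CrossEntropyLoss<T>`, `HuberLoss<T>`) are placeholders standing in for `NeuralNetworkHelper<T>.GetDefaultLossFunction`:

```csharp
#nullable enable
using System;

public enum TaskType { Classification, Regression }
public interface ILossFunction<T> { string Name { get; } }

// Illustrative stand-ins; the real loss classes live elsewhere in the library.
public class CrossEntropyLoss<T> : ILossFunction<T> { public string Name => "CrossEntropy"; }
public class HuberLoss<T> : ILossFunction<T> { public string Name => "Huber"; }

public class NeuralNetworkArchitecture<T>
{
    public TaskType TaskType { get; }
    public ILossFunction<T> LossFunction { get; }

    public NeuralNetworkArchitecture(TaskType taskType, ILossFunction<T>? lossFunction = null)
    {
        TaskType = taskType;
        // Fall back to the task-type default when no loss is supplied
        // (stand-in for NeuralNetworkHelper<T>.GetDefaultLossFunction).
        LossFunction = lossFunction ?? new CrossEntropyLoss<T>();
    }
}

public static class Demo
{
    public static void Main()
    {
        // Default loss is derived from the task type:
        var a = new NeuralNetworkArchitecture<double>(TaskType.Classification);
        Console.WriteLine(a.LossFunction.Name); // prints "CrossEntropy"

        // A custom loss is supplied once, on the architecture,
        // instead of on every network constructor:
        var b = new NeuralNetworkArchitecture<double>(TaskType.Regression, new HuberLoss<double>());
        Console.WriteLine(b.LossFunction.Name); // prints "Huber"
    }
}
```

The architecture becomes the single object a network constructor needs, which is the "complete specification" goal stated above.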
Benefits
- Single Source of Truth: Architecture fully describes the model configuration
- Simpler Constructors: Removes redundant loss function parameter from all neural network constructors
- Better Encapsulation: Architecture encapsulates both structure AND training configuration
- Consistency: Loss function always matches the architecture's task type by default
- Serialization: When saving/loading architectures, loss function is included automatically
Implementation Considerations
- Breaking Change: This will affect all neural network constructors and instantiation code
- Migration Path: All calls like `new SomeNetwork(arch, lossFunc)` need to change to either:
  - `new SomeNetwork(arch)` (if using the default loss)
  - Create the architecture with a custom loss: `new Architecture(..., customLoss)`
- Backward Compatibility: Consider adding deprecated overloads that accept loss function for gradual migration
- Testing: All neural network tests need updating
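One way the deprecated overload could look, sketched here with minimal stub types so it compiles standalone (`SiameseNetwork<T>` and the stub members are illustrative; only the `[Obsolete]` mechanism is the point):

```csharp
#nullable enable
using System;

public enum TaskType { Classification, Regression }
public interface ILossFunction<T> { }
public class CrossEntropyLoss<T> : ILossFunction<T> { }

public class NeuralNetworkArchitecture<T>
{
    public TaskType TaskType { get; }
    public ILossFunction<T> LossFunction { get; }

    public NeuralNetworkArchitecture(TaskType taskType, ILossFunction<T>? lossFunction = null)
    {
        TaskType = taskType;
        LossFunction = lossFunction ?? new CrossEntropyLoss<T>(); // stand-in for the helper default
    }
}

public abstract class NeuralNetworkBase<T>
{
    protected NeuralNetworkBase(NeuralNetworkArchitecture<T> architecture) { }
}

public class SiameseNetwork<T> : NeuralNetworkBase<T>
{
    // New, preferred constructor: the loss comes from the architecture.
    public SiameseNetwork(NeuralNetworkArchitecture<T> architecture) : base(architecture) { }

    // Deprecated overload kept for a release cycle so existing callers
    // still compile, with a compiler warning pointing at the new pattern.
    [Obsolete("Pass the loss function via NeuralNetworkArchitecture instead.")]
    public SiameseNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction)
        : base(lossFunction is null
            ? architecture
            : new NeuralNetworkArchitecture<T>(architecture.TaskType, lossFunction))
    {
    }
}
```

Callers on the old overload get a CS0618 warning rather than a build break, which gives downstream code a migration window before the overload is removed.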
Files Affected
Core Files
- `src/Models/NeuralNetworkArchitecture.cs` - Add `LossFunction` property
- `src/NeuralNetworks/NeuralNetworkBase.cs` - Update base constructor
Neural Network Classes
- `src/NeuralNetworks/ConvolutionalNeuralNetwork.cs`
- `src/NeuralNetworks/RecurrentNeuralNetwork.cs`
- `src/NeuralNetworks/SiameseNetwork.cs`
- `src/NeuralNetworks/ResidualNeuralNetwork.cs`
- `src/NeuralNetworks/Transformer.cs`
- `src/NeuralNetworks/VariationalAutoencoder.cs`
- `src/NeuralNetworks/GenerativeAdversarialNetwork.cs`
- All other neural network implementations
Tests
- All unit tests that instantiate neural networks
- Integration tests
Priority
Medium - This is a quality-of-life change that improves architecture consistency but doesn't block functionality.
Labels
refactoring, architecture, breaking-change