Keras.NET

BaseModel Compile Multiple Loss Confusion

Open • adstep opened this issue 5 years ago • 4 comments

I'm confused how to call the Compile method with multiple losses.

The method takes a string as a parameter for the loss.

From the method documentation I can see:

String (name of objective function) or objective function. See losses. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.
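For reference, the objective this documentation describes can be written out explicitly. This assumes the standard Keras semantics for the optional `loss_weights` argument of `compile`: each coefficient scales the loss of the matching output, and every coefficient is 1 when no weights are given. Here $K$ is the number of model outputs, $\mathcal{L}_k$ is the loss chosen for output $k$, and $w_k$ is its weight.

$$
\mathcal{L}_{\text{total}} = \sum_{k=1}^{K} w_k \, \mathcal{L}_k\left(y_k, \hat{y}_k\right), \qquad w_k = 1 \text{ when no loss weights are given}
$$

For example, with two outputs trained with `mse` and `binary_crossentropy` and `loss_weights = [1.0, 0.5]`, the model minimizes $\mathcal{L}_{\text{mse}} + 0.5\,\mathcal{L}_{\text{bce}}$.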

How are we expected to pass a dictionary or a list? Is there a custom format for the string?

adstep • Oct 06 '20 22:10

This is a significant difference between the .NET API and the Python one: there does not appear to be any way to provide multiple loss functions.

mnelsonwhite • Dec 29 '20 16:12

Should be easy to add. The framework is already in place; it just needs to be extended with more methods. If you have time, you can follow how the existing Compile method is implemented, create a PR, and we can merge it.

deepakkumar1984 • Dec 29 '20 22:12

Will do. In the meantime, this extension method should work.

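using System.Collections.Generic;
using Keras;          // StringOrInstance
using Keras.Models;   // BaseModel
using Numpy;          // NDarray

public static class KerasBaseModelExtensions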
{
    public static void Compile(
        this BaseModel model,
        StringOrInstance optimizer,
        string[] loss,
        string[] metrics = null,
        float[] loss_weights = null,
        string sample_weight_mode = null,
        string[] weighted_metrics = null,
        NDarray[] target_tensors = null)
    {
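        // Keyword arguments forwarded to Python's Model.compile();
        // loss is an array here: one loss per model output.
        var args = new Dictionary<string, object>();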
        args["optimizer"] = optimizer;
        args["loss"] = loss;
        args["metrics"] = metrics;
        args["loss_weights"] = loss_weights;
        args["sample_weight_mode"] = sample_weight_mode;
        args["weighted_metrics"] = weighted_metrics;
        args["target_tensors"] = target_tensors;

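        // Call the underlying Python compile() directly, bypassing the
        // built-in overload that only accepts a single loss string.
        model.InvokeMethod("compile", args);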
    }
}

mnelsonwhite • Jan 08 '21 06:01

#183 Add support for loss function array

mnelsonwhite • Jan 08 '21 06:01

Stale issue message

github-actions[bot] • Oct 03 '23 00:10