machinelearning-samples icon indicating copy to clipboard operation
machinelearning-samples copied to clipboard

How to get feature importance from a regression model?

Open sam-wheat opened this issue 3 years ago • 0 comments

References:

https://devblogs.microsoft.com/premier-developer/permutation-implementation-with-ml-net/

https://docs.microsoft.com/en-us/dotnet/machine-learning/how-to-guides/explain-machine-learning-model-permutation-feature-importance-ml-net

The code below is largely copied from the Taxi sample. I am trying to get relative importance of my inputs. I referred to the pages shown above but I am obviously doing something wrong. I have ten inputs to the model but the PermutationFeatureImportance method returns 16 elements. Also - my second attempt - my own PrintFeatureImportance method winds up with the same problem which results in an "index out of bounds" error.

public class RegressionModelService
{
    private readonly DataService dataService;

    public RegressionModelService(DataService dataService) => this.dataService = dataService;

    public ITransformer BuildTrainEvaluateAndSaveModel()
    {
        // STEP 1: Common data loading configuration
        MLContext mlContext = new MLContext(seed: 0);
        Console.WriteLine($"Start data load at {DateTime.Now.ToLongTimeString()}.");
        var sw = Stopwatch.StartNew();
        IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(dataService.GetChunk(0, Program.SetSize));
        IDataView testDataView = mlContext.Data.LoadFromEnumerable(dataService.GetChunk(Program.SetSize, Program.SetSize));
        sw.Stop();
        string elapsed = sw.Elapsed.TotalMinutes.ToString();
        Console.WriteLine($"Data load completed in {elapsed} seconds at {DateTime.Now.ToLongTimeString()}.");

        string[] featureColumnNames = new[] {
            nameof(BatteryHistory.LastReadVoltage),
            nameof(BatteryHistory.HistoricVoltage1),
            nameof(BatteryHistory.HistoricVoltage2),
            nameof(BatteryHistory.CurrentTemperature),
            nameof(BatteryHistory.LastReadTemperature),
            nameof(BatteryHistory.HistoricTemperature1),
            nameof(BatteryHistory.HistoricTemperature2),
            "CD_COMM_FW_VER_Encoded",
            "CD_COMM_HW_VER_Encoded",
            nameof(BatteryHistory.ModuleAgeInDays)
        };

        // STEP 2: Common data process configuration with pipeline data transformations
        var pipeline = mlContext.Transforms
            .CopyColumns("Label", nameof(BatteryHistory.CurrentVoltage))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "CD_COMM_FW_VER_Encoded", inputColumnName: nameof(BatteryHistory.CD_COMM_FW_VER)))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "CD_COMM_HW_VER_Encoded", inputColumnName: nameof(BatteryHistory.CD_COMM_HW_VER)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.LastReadVoltage)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.HistoricVoltage1)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.HistoricVoltage2)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.LastReadVoltage)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.CurrentTemperature)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.LastReadTemperature)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.HistoricTemperature1)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.HistoricTemperature2)))
            .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(BatteryHistory.ModuleAgeInDays)))
            .Append(mlContext.Transforms.Concatenate("Features", featureColumnNames));


        // (OPTIONAL) Peek data (such as 5 records) in training DataView after applying the ProcessPipeline's transformations into "Features" 
        ConsoleHelper.PeekDataViewInConsole(mlContext, trainingDataView, pipeline, 5);
        ConsoleHelper.PeekVectorColumnDataInConsole(mlContext, "Features", trainingDataView, pipeline, 5);

        // STEP 3: Set the training algorithm, then create and config the modelBuilder - Selected Trainer (SDCA Regression algorithm)                            
        var trainer = mlContext.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features");
        var trainingPipeline = pipeline.Append(trainer);


        // STEP 4: Train the model fitting to the DataSet
        //The pipeline is trained on the dataset that has been loaded and transformed.
        Console.WriteLine("=============== Training the model ===============");
        var trainedModel = trainingPipeline.Fit(trainingDataView);

        // STEP 5: Evaluate the model and show accuracy stats
        Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");

        IDataView predictions = trainedModel.Transform(testDataView);
        var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Label", scoreColumnName: "Score");

        ConsoleHelper.PrintRegressionMetrics(trainer.ToString(), metrics);

        // STEP 6: Save/persist the trained model to a .ZIP file
        mlContext.Model.Save(trainedModel, trainingDataView.Schema, Program.ModelPath);

        Console.WriteLine("The model is saved to {0}", Program.ModelPath);


        // Feature Importance -------------------

		// ======================================  PROBLEM HERE ===========================================================
        var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(trainedModel.LastTransformer, predictions, permutationCount: 100); // 16 elements
		
		// Trying a different way, same problem as previous line
        PrintFeatureImportance(mlContext, predictions, featureColumnNames);
        
        // Feature Importance -------------------

        return trainedModel;

    }

    private static void PrintFeatureImportance(MLContext mlContext, IDataView data, string[] featureColumnNames)
    {

        // 2. Define estimator with data  pre-processing steps
        IEstimator<ITransformer> dataPrepEstimator =
            mlContext.Transforms.Concatenate("Features", featureColumnNames)
            .Append(mlContext.Transforms.NormalizeMinMax("Features"));

        // 3. Create transformer using the data pre-processing estimator
        ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

        // 4. Pre-process the training data
        IDataView preprocessedTrainData = dataPrepTransformer.Transform(data);

        // 5. Define Stochastic Dual Coordinate Ascent machine learning estimator
        var sdcaEstimator = mlContext.Regression.Trainers.Sdca();

        // 6. Train machine learning model
        var sdcaModel = sdcaEstimator.Fit(preprocessedTrainData);

        ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
            mlContext.Regression.PermutationFeatureImportance(sdcaModel, preprocessedTrainData, permutationCount: featureColumnNames.Count());

        var featureImportanceMetrics = permutationFeatureImportance.Select((metric, index) => new { index, metric.RSquared }).OrderByDescending(x => Math.Abs(x.RSquared.Mean));

        Console.Clear();
        Console.WriteLine("Feature Importance");

        foreach (var feature in featureImportanceMetrics)
        {
            Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}"); // Index out of bounds error here
        }
        Console.WriteLine("Press any key...");
        Console.ReadKey();
    }
}

sam-wheat avatar Jan 11 '22 16:01 sam-wheat