machinelearning icon indicating copy to clipboard operation
machinelearning copied to clipboard

Direct Access to the Microsoft.ML.GenAI.LLaMA Model

Open aforoughi1 opened this issue 11 months ago • 5 comments

I would like to convert a LLaMA model into a multiclass classification model and then fine-tune it on my classification labels.

Currently, `LlamaModel` (Microsoft.ML.GenAI.LLaMA/Module/LlamaModel) is internal, so it cannot be used directly.

Step 1: Load the pre-trained LLaMA model

// Settings for loading a local LLaMA checkpoint on the CPU.
string device = "cpu";
string weightFolder = @".\Llama3.1-8B";
// originalWeightFolder and modelFile are not used in this snippet; presumably they are meant
// for loading the tokenizer later — confirm.
string originalWeightFolder = Path.Combine(weightFolder, "original");
string configName = "config.json";
string modelFile = "tokenizer.model";
// Presumably the index file of a sharded safetensors checkpoint — confirm.
string checkPointName = "model.safetensors.index.json";

// Load the Pretrained Model: First, load the pretrained LLaMA model using TorchSharp.
// NOTE(review): `model` is not declared in this snippet; it must be declared elsewhere
// (e.g. `LlamaForCausalLM model;`). `layersOnTargetDevice: -1` presumably places all layers on
// `device` — confirm against the FromPretrained documentation.
model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);

Step 2: Create a classification head

/// <summary>
/// Two-layer MLP head that maps a feature vector to per-class logits:
/// Linear(d_model -> outputSize), then ReLU, then Linear(outputSize -> num_classes).
/// </summary>
public class ClassificationHead : Module<Tensor, Tensor>
{
    // Field names are intentionally unchanged: TorchSharp's RegisterComponents() names the
    // registered submodules (and therefore the state_dict keys) after these fields.
    private readonly Module<Tensor, Tensor> linear1;
    private readonly Module<Tensor, Tensor> relu;
    private readonly Module<Tensor, Tensor> linear2;

    /// <summary>Builds the three layers and registers them as submodules.</summary>
    /// <param name="d_model">Size of the incoming feature vector.</param>
    /// <param name="outputSize">Width of the intermediate (hidden) layer.</param>
    /// <param name="num_classes">Number of logits produced per input.</param>
    public ClassificationHead(int d_model, int outputSize, int num_classes)
        : base(nameof(ClassificationHead))
    {
        linear1 = Linear(d_model, outputSize);
        relu = ReLU();
        linear2 = Linear(outputSize, num_classes);
        RegisterComponents();
    }

    /// <summary>Applies linear1, ReLU and linear2 in sequence and returns the resulting logits.</summary>
    /// <param name="x">Input features whose last dimension is <c>d_model</c>.</param>
    public override Tensor forward(Tensor x) =>
        linear2.forward(relu.forward(linear1.forward(x)));
}

Step 3: Integrate the classification head into the LLaMA model

This step is currently not possible. It requires overriding the forward pass so that the input is passed through the LLaMA model and then through the classification head to produce the output logits, but the underlying LLaMA module is internal.

Step 4: Set up the training loop to optimize the model on my data

I intend to use Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer

aforoughi1 avatar Jan 16 '25 09:01 aforoughi1

The following enhancement requests are similar; perhaps they can be merged into one PR that also addresses my requirement.

Finetune #7287 Finetune #7325

https://huggingface.co/docs/transformers/v4.48.0/en/model_doc/llama2#transformers.LlamaForSequenceClassification

aforoughi1 avatar Jan 20 '25 09:01 aforoughi1

@LittleLittleCloud can you take a look at this?

michaelgsharp avatar Mar 10 '25 07:03 michaelgsharp

@aforoughi1 We can make the LLaMA module classes public. Alternatively, you can use `LastHiddenState` from the model returned by `LlamaForCausalLM.FromPretrained` as the input to the classification layer.

For fine-tuning, you probably only need to train the final linear classification layers, rather than using CasualLMSupervisedFineTuningTrainer to fine-tune the entire 8B model.

LittleLittleCloud avatar Mar 25 '25 18:03 LittleLittleCloud

I agree with your recommendation.

I also implemented a solution based on LlamaForCausalLM, but its forward() method is never invoked.

I am using Microsoft.ML.GenAI.LLaMA version 0.23.0-preview.25124.1, and I think the problem is in the TrainAsync method.

My package source is https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-libraries/nuget/v3/index.json

I am not sure where to find the symbol files needed to debug it.

See the code below:

using System;
using System.IO;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch.nn;
using static TorchSharp.torch;
using TorchSharp.PyBridge;
using TorchSharp.Modules;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.GenAI.Core.Trainer;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;

namespace Test
{
    /// <summary>
    /// Attempted LLaMA-based sequence classifier. It loads a pretrained <c>LlamaForCausalLM</c>,
    /// freezes its parameters, adds a linear classification head, and then (from the constructor)
    /// tries to train through <c>CasualLMSupervisedFineTuningTrainer</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): as written, the classification head never takes part in training. The trainer
    /// pipeline built in <c>RunTrain</c> wraps <c>lamaModel</c>, not this module, so the
    /// <c>forward</c> override below is never called by it. See the notes in <c>RunTrain</c> and
    /// <c>CreateDataSet</c>.
    /// </remarks>
    public class LLamaForSequenceClassification : nn.Module<CausalLMModelInput, CausalLMModelOutput>, IDisposable
    {
        private readonly Tokenizer tokenizer;
        private readonly LlamaForCausalLM lamaModel;
        private readonly Linear classifier;
        // NOTE(review): the constructor assigns `new LlamaConfig()` (library defaults), not the config
        // read from `configName`. Its HiddenSize sizes the classifier, so confirm that it matches
        // the checkpoint's hidden size.
        private readonly LlamaConfig config;
        private readonly string device = "cpu";
        private readonly string configName = "config.json";
        private readonly string modelFile = "tokenizer.model";
        private readonly string weightFolder = @"C:\Llama3.1-8B";
        private string checkPointName = "model.safetensors.index.json";

        /// <summary>
        /// Runs the whole workflow: it builds the dataset, loads and freezes the base model, creates
        /// and registers the classification head, prints the state dict, and calls <c>RunTrain</c>.
        /// </summary>
        /// <param name="num_classes">Number of output classes of the classification head.</param>
        public LLamaForSequenceClassification(int num_classes) : base(nameof(LLamaForSequenceClassification))
        {
            // Prepare the Data for classification, which includes tokenizing the input text and labeling the data.
            tokenizer =  LlamaTokenizerHelper.FromPretrained(Path.Combine(weightFolder, "original"), modelFile);
            var trainDataSet = CreateDataSet();

            // Load the Pretrained LLaMA model
            config = new LlamaConfig();
            lamaModel = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);

            // If you only want to train the classification head,
            // freeze the LLaMA model's parameters to prevent updates during training:
            // NOTE(review): after this loop no lamaModel parameter requires a gradient. A trainer
            // that optimizes lamaModel's parameters (as RunTrain's does, presumably) has nothing
            // to update.
            foreach (var param in lamaModel.parameters())
            {
                param.requires_grad = false;
            }


            // Create a new classification head
            // NOTE(review): sized from the default LlamaConfig's HiddenSize (see the note on `config`).
            classifier = Linear(config.HiddenSize, num_classes);
            
            //register the classifier layer using the RegisterComponents() method
            RegisterComponents();

            PrintRegisteredComponents();

            var option = new CasualLMSupervisedFineTuningTrainer.Option
            {
                BatchSize = 1,
                Device = device,
                Epoch = 100,
                LearningRate = 5e-5f,
            };

            // Define the Training Loop: Set up the training loop to fine - tune the model with your classification data.
            // NOTE(review): only the CausalLMDataset item is passed; the classifier is not involved in RunTrain.
            RunTrain(trainDataSet._CasualDataset, option);
           
        }

        // This method overrides the forward function to process the input through the model,
        // extract the last token's hidden state, use it for classification,
        // and then return the logits.
        // NOTE(review): nothing in this class calls this override; RunTrain's pipeline wraps lamaModel.
        public override CausalLMModelOutput forward(CausalLMModelInput input)
        {
            // Forward pass through the base model
            var modelOutput = lamaModel.forward(input);

            // Extracts the last token hidden state for classification
            // NOTE(review): `[0, ^1]` keeps only batch item 0. This assumes LastHiddenState is
            // (batch, seq, hidden) — TODO confirm. If so, other batch items are dropped.
            var lastTokenHiddenState = modelOutput.LastHiddenState[0, ^1];

            // Forward pass through the classifier
            var classifierOutput = classifier.forward(lastTokenHiddenState);

            // Wraps the classifier output in a CausalLMModelOutput and returns it
            // NOTE(review): confirm which CausalLMModelOutput property this constructor argument
            // populates (logits vs. last hidden state) before relying on it downstream.
            var logits = new CausalLMModelOutput(classifierOutput);
            return logits;
        }

        /// <summary>
        /// Writes each state_dict entry of this module (its key and the tensor's string form) to the console.
        /// </summary>
        public void PrintRegisteredComponents()
        {
            var dict = base.state_dict();
            foreach (var component in dict)
            {
                Console.WriteLine($"{component.Key}: {component.Value}");
            }
        }

        /// <summary>
        /// Reads a tab-separated file and skips its header row. Column 0 is tokenized and column 2 is
        /// parsed as an integer label; both are then wrapped in a <c>CausalLMDataset</c>.
        /// </summary>
        /// <returns>The token-id sequences, one single-element label list per row, and the combined dataset.</returns>
        public (IEnumerable<IReadOnlyList<int>> inputIds, IEnumerable<IReadOnlyList<int>> labelIds, CausalLMDataset _CasualDataset) CreateDataSet()
        {
            //Sentence is input text
            //Sentiment is a string: 'positive', 'negative' or 'neutral'
            //Label = {0:'neutral', 1:'positive',-1:'negative'}
            // NOTE(review): a -1 label is not a valid class index for cross-entropy-style losses,
            // which expect values in [0, num_classes); remap it (e.g. -1 -> 2).

            string filePath = @"C:\FinancialPhraseBank.txt";
            var lines = File.ReadAllLines(filePath);
            // NOTE(review): `count` is never used.
            var count = lines.Count();
            // NOTE(review): `result` and everything derived from it are deferred LINQ queries. Each
            // enumeration re-splits the lines, and enumerating `inputIds` re-tokenizes every sentence.
            // Call ToList() if these sequences are enumerated more than once. int.Parse uses the current
            // culture here; CultureInfo.InvariantCulture is safer for data files.
            IEnumerable<Tuple<string, int>> result = lines
                        .Skip(1) // skip header
                        .Select(line => line.Split('\t'))
                        .Select(columns => new Tuple<string, int>(
                                           columns[0],
                                           int.Parse(columns[2])));

            // Assign the results
            var inputIds = result.Select(tuple => tuple.Item1)
                                 .Select(t => tokenizer.EncodeToIds(t));

            // Each label becomes a one-element read-only list (the `index` lambda parameter is unused).
            var labelIds = result.Select(tuple => tuple.Item2)  // Convert each item to a IReadOnlyList<int>
                                 .Select((item, index) => (IReadOnlyList<int>)new List<int> { item }.AsReadOnly());

            // NOTE(review): presumably CausalLMDataset treats labelIds as target token sequences for
            // causal-LM training, not as class labels — confirm against CausalLMDataset.Create.
            CausalLMDataset casualDataset = CausalLMDataset.Create(inputIds, labelIds);

            return (inputIds, labelIds, casualDataset);
        }

        /// <summary>
        /// Builds a causal-LM pipeline around <c>lamaModel</c>, requests training from
        /// <c>CasualLMSupervisedFineTuningTrainer</c>, logs the elapsed time, and takes a state dict.
        /// </summary>
        /// <param name="trainDataset">Dataset passed to the trainer.</param>
        /// <param name="trainOption">Trainer options (batch size, device, epochs, learning rate).</param>
        public void RunTrain(CausalLMDataset trainDataset, CasualLMSupervisedFineTuningTrainer.Option trainOption)
        {
            var loggerFactory = LoggerFactory.Create(builder => builder
                                                                    .AddConsole()
                                                                    .AddSimpleConsole(options =>
                                                                    {
                                                                        options.IncludeScopes = true;
                                                                        options.SingleLine = true;
                                                                        options.TimestampFormat = "yyyy-MM-dd HH:mm:ss ";
                                                                        options.UseUtcTimestamp = true;
                                                                    })
                                                                    .SetMinimumLevel(LogLevel.Information));
                                        
            var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();
            // NOTE(review): the pipeline wraps lamaModel, not this module. The trainer therefore never
            // invokes this.forward() and never sees `classifier`.
            var pipeline = new CausalLMPipeline<Tokenizer, LlamaForCausalLM>(tokenizer, lamaModel, device);
            var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline,logger);

            // NOTE(review): this CancellationTokenSource is never cancelled or disposed, so the token
            // behaves like CancellationToken.None.
            var ct = new CancellationTokenSource().Token;

            logger.LogInformation("RunTrain...");

            // NOTE(review): StartNew() already starts the stopwatch, so the following Start() is redundant.
            var stopWatch = System.Diagnostics.Stopwatch.StartNew();
            stopWatch.Start();
            // NOTE(review): `results` is never enumerated. If TrainAsync is an async iterator (its
            // IAsyncEnumerable return type suggests so), no training runs until it is consumed
            // (e.g. with `await foreach`). The elapsed time logged below would then cover only this call.
            IAsyncEnumerable<ICausalLMPipeline> results = sftTrainer.TrainAsync(trainDataset,trainOption,ct);
            stopWatch.Stop();
            var runtimeMS = stopWatch.ElapsedMilliseconds;
            
            // Integer division: the logged value is truncated to whole seconds.
            logger.LogInformation($"trained in {runtimeMS / 1000} s");

            // Evaluate the Model: After training, you should evaluate the model to measure its performance.

            // Save the fine-tuned model
            // NOTE(review): the save call below is commented out, so nothing is written to `fileName`
            // despite the log message. `pipeline.Model` is presumably the (frozen) lamaModel.
            logger.LogInformation("Saving the fine-tuned model ...");
            string fileName = Path.Combine(@"C:\", @"LLamaForSequenceClassification.safetensors");
            ////var stateDict = pipeline.TypedModel.state_dict();
            var stateDict = pipeline.Model.state_dict();
            //Safetensors.SaveStateDict(fileName, stateDict);
            // NOTE(review): skipped if anything above throws; a `using` declaration would guarantee disposal.
            loggerFactory.Dispose(); 
            
        }

    }
}

aforoughi1 avatar Mar 25 '25 20:03 aforoughi1

Here's code that fine-tunes only the classifier and keeps the LLaMA model layers frozen.

One thing to keep in mind: cross-entropy expects class-index labels in the range [0, num_classes), so labels must be non-negative (for example, remap -1 to 2).

using System;
using System.IO;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch.nn;
using static TorchSharp.torch;
using TorchSharp.PyBridge;
using TorchSharp.Modules;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.GenAI.Core.Trainer;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;
using System.Threading.Tasks;
using TorchSharp;

namespace Test
{
    /// <summary>
    /// LLaMA-based sequence classifier. A pretrained <c>LlamaForCausalLM</c> serves as a fixed
    /// feature extractor (last-token pooled embeddings), and only a linear classification head
    /// is trained on top of it.
    /// </summary>
    /// <remarks>
    /// The constructor runs the whole workflow: it loads the tokenizer, builds the in-memory
    /// dataset, loads the base model, creates the head, and calls <c>RunTrain</c>.
    /// </remarks>
    public class LLamaForSequenceClassification : nn.Module<CausalLMModelInput, CausalLMModelOutput>, IDisposable
    {
        // Module-typed field names (lamaModel, classifier) are picked up by RegisterComponents()
        // and become state_dict key prefixes, so rename them with care.
        private readonly Tokenizer tokenizer;
        private readonly LlamaForCausalLM lamaModel;
        private readonly Linear classifier;
        private readonly LlamaConfig config;
        private readonly string device = "cuda";
        private readonly string configName = "config.json";
        private readonly string modelFile = "tokenizer.model";
        // Adjust to the local checkpoint folder.
        private readonly string weightFolder = @"C:\Users\\source\repos\Llama-3.2-1B-Instruct";
        private string checkPointName = "model.safetensors";

        // Hidden size of Llama-3.2-1B. It must match the checkpoint in weightFolder, because the
        // classifier consumes that model's last-token embeddings.
        private const int HiddenSize = 2048;

        /// <summary>Loads data and model, creates the classification head, and trains it.</summary>
        /// <param name="num_classes">Number of output classes of the classification head.</param>
        public LLamaForSequenceClassification(int num_classes) : base(nameof(LLamaForSequenceClassification))
        {
            // Prepare the data for classification: input sentences and their integer labels.
            tokenizer = LlamaTokenizerHelper.FromPretrained(Path.Combine(weightFolder, "original"), modelFile);
            var trainDataSet = CreateDataSet();

            // Load the pretrained LLaMA model.
            config = new LlamaConfig();
            lamaModel = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);

            // Create a new classification head on the same device as the embeddings.
            classifier = Linear(HiddenSize, num_classes, device: device);

            // Register the module-typed fields (including the classifier) as submodules.
            RegisterComponents();

            PrintRegisteredComponents();

            // Run the training loop.
            RunTrain(trainDataSet.input, trainDataSet.labelIds);
        }

        /// <summary>
        /// Runs the base model, takes the hidden state of the last token, and returns the
        /// classifier logits wrapped in a <c>CausalLMModelOutput</c>.
        /// </summary>
        public override CausalLMModelOutput forward(CausalLMModelInput input)
        {
            // Forward pass through the base model.
            var modelOutput = lamaModel.forward(input);

            // `[0, ^1]` keeps only batch item 0. This assumes LastHiddenState is (batch, seq, hidden),
            // so it is only meaningful for single-sequence batches — TODO confirm.
            var lastTokenHiddenState = modelOutput.LastHiddenState[0, ^1];

            // Forward pass through the classifier.
            var classifierOutput = classifier.forward(lastTokenHiddenState);

            // Wrap the classifier output in a CausalLMModelOutput and return it.
            var logits = new CausalLMModelOutput(classifierOutput);
            return logits;
        }

        /// <summary>Writes each state_dict entry of this module (key and tensor description) to the console.</summary>
        public void PrintRegisteredComponents()
        {
            var dict = base.state_dict();
            foreach (var component in dict)
            {
                Console.WriteLine($"{component.Key}: {component.Value}");
            }
        }

        /// <summary>
        /// Builds a small in-memory training set of tab-separated "sentence, label" rows.
        /// </summary>
        /// <returns>
        /// Materialized sequences of input sentences and their labels, in the same order.
        /// Labels: 0 = neutral, 1 = positive, 2 = negative (non-negative, as cross-entropy requires).
        /// </returns>
        public (IEnumerable<string> input, IEnumerable<int> labelIds) CreateDataSet()
        {
            // To train on the real data, read the same format from a file such as C:\FinancialPhraseBank.txt.
            string dummyFileContent = @"Sentence	Sentiment
The company's revenue increased significantly this quarter.	1
The stock price dropped due to market conditions.	2
The new product launch was well-received by customers.	1
There are concerns about the company's future growth.	2
The financial report showed stable performance.	0
Investors are optimistic about the upcoming merger.	1
The CEO's resignation caused uncertainty in the market.	2
The quarterly earnings exceeded analysts' expectations.	1
The company is facing legal challenges that may impact its operations.	2
The dividend payout remained unchanged from the previous year.	0";

            // A verbatim literal keeps the source file's own line endings. Splitting on both '\r' and
            // '\n' avoids a stray '\r' on every row of a CRLF checkout, and blank lines are dropped.
            var rows = dummyFileContent
                .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries)
                .Where(line => !string.IsNullOrWhiteSpace(line))
                .Skip(1) // skip header
                .Select(line => line.Split('\t'))
                .Select(columns => (
                    Text: columns[0],
                    Label: int.Parse(columns[1], System.Globalization.CultureInfo.InvariantCulture)))
                .ToList(); // parse once; callers may enumerate the results several times

            return (rows.Select(row => row.Text).ToList(), rows.Select(row => row.Label).ToList());
        }

        /// <summary>
        /// Trains only the classification head. Each input is embedded once with the (untrained)
        /// LLaMA model, and the classifier is then optimized with Adam on cross-entropy loss.
        /// </summary>
        /// <param name="dataset">Input sentences.</param>
        /// <param name="label">Class index for each sentence, in [0, num_classes).</param>
        /// <param name="learningRate">Adam learning rate.</param>
        /// <param name="epoch">Number of passes over the data.</param>
        /// <param name="batch">Mini-batch size; must be positive.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataset"/> or <paramref name="label"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batch"/> is not positive.</exception>
        /// <exception cref="ArgumentException">Input and label counts differ, or a label is negative.</exception>
        public void RunTrain(
            IEnumerable<string> dataset,
            IEnumerable<int> label,
            float learningRate = 5e-5f,
            int epoch = 5,
            int batch = 10)
        {
            if (dataset is null) throw new ArgumentNullException(nameof(dataset));
            if (label is null) throw new ArgumentNullException(nameof(label));
            if (batch <= 0) throw new ArgumentOutOfRangeException(nameof(batch), batch, "Batch size must be positive.");

            // Materialize once, so the inputs are not re-evaluated for every batch.
            var texts = dataset.ToList();
            var labels = label.ToList();
            if (texts.Count != labels.Count)
            {
                throw new ArgumentException($"Expected one label per input, got {labels.Count} labels for {texts.Count} inputs.", nameof(label));
            }
            if (labels.Any(l => l < 0))
            {
                throw new ArgumentException("Labels must be non-negative class indices for cross-entropy.", nameof(label));
            }

            using var loggerFactory = LoggerFactory.Create(builder => builder
                    .AddConsole()
                    .AddSimpleConsole(options =>
                    {
                        options.IncludeScopes = true;
                        options.SingleLine = true;
                        options.TimestampFormat = "yyyy-MM-dd HH:mm:ss ";
                        options.UseUtcTimestamp = true;
                    })
                    .SetMinimumLevel(LogLevel.Information));

            var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

            if (texts.Count == 0)
            {
                logger.LogWarning("RunTrain: dataset is empty, nothing to train.");
                return;
            }

            ICausalLMPipeline pipeline = new CausalLMPipeline<Tokenizer, LlamaForCausalLM>(tokenizer, lamaModel, device);

            // The LLaMA weights are not tuned, so embeddings are generated once and reused for every
            // epoch. ToList() is essential: a lazy Select would re-run the LLaMA forward pass every time
            // a batch is enumerated, including for items that Skip() passes over.
            var embeddings = texts
                .Select(text => torch.tensor(pipeline.GenerateEmbeddingFromLastTokenPool(text), device: device))
                .ToList();

            try
            {
                logger.LogInformation("RunTrain...");

                // Only the classifier's parameters are optimized.
                var optimizer = new Adam(classifier.parameters(), lr: learningRate);

                for (int epochIndex = 0; epochIndex < epoch; epochIndex++)
                {
                    logger.LogInformation("Epoch {Epoch}/{EpochCount}", epochIndex + 1, epoch);

                    var lossesForEachBatch = new List<float>();

                    // `<` (not `!=`) so that a final partial batch terminates the loop.
                    for (int start = 0; start < texts.Count; start += batch)
                    {
                        // Dispose this batch's intermediate tensors (stacked inputs, targets, logits, loss).
                        using var scope = torch.NewDisposeScope();

                        int size = Math.Min(batch, texts.Count - start);
                        var inputs = torch.vstack(embeddings.GetRange(start, size).ToArray());
                        var target = torch.tensor(labels.GetRange(start, size).ToArray(), device: device, dtype: ScalarType.Int64);

                        // forward + cross-entropy loss
                        var output = classifier.forward(inputs);
                        var batchLoss = torch.nn.functional.cross_entropy(output, target);

                        // backward
                        optimizer.zero_grad();
                        batchLoss.backward();
                        optimizer.step();

                        lossesForEachBatch.Add(batchLoss.item<float>());
                    }

                    logger.LogInformation("Loss: {Loss}", lossesForEachBatch.Average());
                }
            }
            finally
            {
                // The cached embeddings may live on the GPU; release them deterministically.
                foreach (var embedding in embeddings)
                {
                    embedding.Dispose();
                }
            }
        }
    }
}

LittleLittleCloud avatar Mar 28 '25 06:03 LittleLittleCloud