Direct Access to the Microsoft.ML.GenAI.LLaMA Model
I would like to convert a LLaMA model into a multiclass classification model and then fine-tune it on my classification labels.
Currently, the Microsoft.ML.GenAI.LLaMA/Module/LlamaModel class is internal.
Step 1 Load the pre-trained LLaMA model
string device = "cpu";
string weightFolder = @".\Llama3.1-8B";
string originalWeightFolder = Path.Combine(weightFolder, "original");
string configName = "config.json";
string modelFile = "tokenizer.model";
string checkPointName = "model.safetensors.index.json";
// Load the Pretrained Model: First, load the pretrained LLaMA model using TorchSharp.
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);
Step 2 Create a classification head
public class ClassificationHead : Module<Tensor, Tensor>
{
private readonly Module<Tensor, Tensor> linear1;
private readonly Module<Tensor, Tensor> relu;
private readonly Module<Tensor, Tensor> linear2;
public ClassificationHead(int d_model, int outputSize, int num_classes) : base(nameof(ClassificationHead))
{
linear1 = Linear(d_model, outputSize); // Intermediate layer
relu = ReLU(); // Activation
linear2 = Linear(outputSize, num_classes); // Output layer
RegisterComponents();
}
public override Tensor forward(Tensor x)
{
var output = linear1.forward(x);
output = relu.forward(output);
output = linear2.forward(output);
return output;
}
}
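For example, assuming the Llama 3.1-8B hidden size of 4096, an arbitrary intermediate size of 1024, and three sentiment classes (all values here are illustrative), the head could be constructed as:
// 4096 = Llama 3.1-8B hidden size; 1024 and 3 are illustrative values
var classificationHead = new ClassificationHead(d_model: 4096, outputSize: 1024, num_classes: 3);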
Step 3 Integrate the classification head into the LLaMA model
This step is currently not possible: there is no way to override the forward pass so that the input is passed through the LLaMA model and then through the classification head to produce the output logits.
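A rough sketch of the forward pass I have in mind, assuming the modules were public and that CausalLMModelOutput exposes LastHiddenState (llamaModel and classificationHead are placeholder names):
public override CausalLMModelOutput forward(CausalLMModelInput input)
{
    // Run the input through the LLaMA backbone
    var modelOutput = llamaModel.forward(input);
    // Take the last token's hidden state as the sequence representation
    var lastTokenHiddenState = modelOutput.LastHiddenState[0, ^1];
    // Project it to class logits with the classification head
    var classLogits = classificationHead.forward(lastTokenHiddenState);
    return new CausalLMModelOutput(classLogits);
}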
Step 4 Set up the training loop to optimize the model using my data
I intend to use Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer
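A rough sketch of the setup I am aiming for, based on the trainer usage shown later in this thread (option values and names such as trainDataset, logger, and cancellationToken are placeholders):
// Sketch only: configure the supervised fine-tuning options (placeholder values)
var option = new CasualLMSupervisedFineTuningTrainer.Option
{
    BatchSize = 1,
    Device = device,
    Epoch = 100,
    LearningRate = 5e-5f,
};
// Wrap the tokenizer and model in a pipeline and hand it to the trainer
var pipeline = new CausalLMPipeline<Tokenizer, LlamaForCausalLM>(tokenizer, model, device);
var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger);
// TrainAsync returns an IAsyncEnumerable<ICausalLMPipeline>; enumerate it with await foreach
await foreach (var trainedPipeline in sftTrainer.TrainAsync(trainDataset, option, cancellationToken))
{
    // use or checkpoint trainedPipeline here
}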
The following enhancements are similar; perhaps they could be merged into one PR to address my requirement.
Finetune #7287 Finetune #7325
https://huggingface.co/docs/transformers/v4.48.0/en/model_doc/llama2#transformers.LlamaForSequenceClassification
@LittleLittleCloud can you take a look at this?
@aforoughi1 We can make the llama modules public classes. Maybe you can also use the LastHiddenState from the model returned by LlamaForCausalLM.FromPretrained as the input for the classification layer.
For fine-tuning, you probably just need to fine-tune the last two linear classification layers rather than using CasualLMSupervisedFineTuningTrainer to fine-tune the entire 8B model.
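A minimal sketch of that suggestion, assuming the backbone is frozen and only the classifier parameters are passed to the optimizer (hiddenSize and numClasses are placeholders):
// Freeze the LLaMA backbone so only the classification head is trained (sketch only)
foreach (var p in llamaModel.parameters())
{
    p.requires_grad = false;
}
var classifier = Linear(hiddenSize, numClasses);
// The optimizer only sees the classifier parameters
var optimizer = torch.optim.Adam(classifier.parameters(), lr: 5e-5);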
I agree with your recommendation.
I also implemented a solution based on LlamaForCausalLM, but the forward() method is not getting invoked.
I am using Microsoft.ML.GenAI.LLaMA version 0.23.0-preview.25124.1, but I think the problem is in the TrainAsync method.
My package source is https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-libraries/nuget/v3/index.json
I am not sure where the symbol files are located so I can debug it.
See the code below:
using System;
using System.IO;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch.nn;
using static TorchSharp.torch;
using TorchSharp.PyBridge;
using TorchSharp.Modules;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.GenAI.Core.Trainer;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;
namespace Test
{
public class LLamaForSequenceClassification : nn.Module<CausalLMModelInput, CausalLMModelOutput>, IDisposable
{
private readonly Tokenizer tokenizer;
private readonly LlamaForCausalLM lamaModel;
private readonly Linear classifier;
private readonly LlamaConfig config;
private readonly string device = "cpu";
private readonly string configName = "config.json";
private readonly string modelFile = "tokenizer.model";
private readonly string weightFolder = @"C:\Llama3.1-8B";
private string checkPointName = "model.safetensors.index.json";
public LLamaForSequenceClassification(int num_classes) : base(nameof(LLamaForSequenceClassification))
{
// Prepare the Data for classification, which includes tokenizing the input text and labeling the data.
tokenizer = LlamaTokenizerHelper.FromPretrained(Path.Combine(weightFolder, "original"), modelFile);
var trainDataSet = CreateDataSet();
// Load the Pretrained LLaMA model
config = new LlamaConfig();
lamaModel = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);
// If you only want to train the classification head,
// freeze the LLaMA model's parameters to prevent updates during training:
foreach (var param in lamaModel.parameters())
{
param.requires_grad = false;
}
// Create a new classification head
classifier = Linear(config.HiddenSize, num_classes);
//register the classifier layer using the RegisterComponents() method
RegisterComponents();
PrintRegisteredComponents();
var option = new CasualLMSupervisedFineTuningTrainer.Option
{
BatchSize = 1,
Device = device,
Epoch = 100,
LearningRate = 5e-5f,
};
// Define the Training Loop: Set up the training loop to fine-tune the model with your classification data.
RunTrain(trainDataSet._CasualDataset, option);
}
// This method overrides the forward function to process the input through the model,
// extract the last token's hidden state, use it for classification,
// and then return the logits.
public override CausalLMModelOutput forward(CausalLMModelInput input)
{
// Forward pass through the base model
var modelOutput = lamaModel.forward(input);
// Extracts the last token hidden state for classification
var lastTokenHiddenState = modelOutput.LastHiddenState[0, ^1];
// Forward pass through the classifier
var classifierOutput = classifier.forward(lastTokenHiddenState);
// Wraps the classifier output in a CausalLMModelOutput and returns it
var logits = new CausalLMModelOutput(classifierOutput);
return logits;
}
public void PrintRegisteredComponents()
{
var dict = base.state_dict();
foreach (var component in dict)
{
Console.WriteLine($"{component.Key}: {component.Value}");
}
}
public (IEnumerable<IReadOnlyList<int>> inputIds, IEnumerable<IReadOnlyList<int>> labelIds, CausalLMDataset _CasualDataset) CreateDataSet()
{
//Sentence is input text
//Sentiment is a string: 'positive', 'negative' or 'neutral'
//Label = {0:'neutral', 1:'positive',-1:'negative'}
string filePath = @"C:\FinancialPhraseBank.txt";
var lines = File.ReadAllLines(filePath);
var count = lines.Count();
IEnumerable<Tuple<string, int>> result = lines
.Skip(1) // skip header
.Select(line => line.Split('\t'))
.Select(columns => new Tuple<string, int>(
columns[0],
int.Parse(columns[2])));
// Assign the results
var inputIds = result.Select(tuple => tuple.Item1)
.Select(t => tokenizer.EncodeToIds(t));
var labelIds = result.Select(tuple => tuple.Item2) // Convert each label to an IReadOnlyList<int>
.Select((item, index) => (IReadOnlyList<int>)new List<int> { item }.AsReadOnly());
CausalLMDataset casualDataset = CausalLMDataset.Create(inputIds, labelIds);
return (inputIds, labelIds, casualDataset);
}
public void RunTrain(CausalLMDataset trainDataset, CasualLMSupervisedFineTuningTrainer.Option trainOption)
{
var loggerFactory = LoggerFactory.Create(builder => builder
.AddConsole()
.AddSimpleConsole(options =>
{
options.IncludeScopes = true;
options.SingleLine = true;
options.TimestampFormat = "yyyy-MM-dd HH:mm:ss ";
options.UseUtcTimestamp = true;
})
.SetMinimumLevel(LogLevel.Information));
var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();
var pipeline = new CausalLMPipeline<Tokenizer, LlamaForCausalLM>(tokenizer, lamaModel, device);
var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger);
var ct = new CancellationTokenSource().Token;
logger.LogInformation("RunTrain...");
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
stopWatch.Start();
IAsyncEnumerable<ICausalLMPipeline> results = sftTrainer.TrainAsync(trainDataset,trainOption,ct);
stopWatch.Stop();
var runtimeMS = stopWatch.ElapsedMilliseconds;
logger.LogInformation($"trained in {runtimeMS / 1000} s");
// Evaluate the Model: After training, you should evaluate the model to measure its performance.
// Save the fine-tuned model
logger.LogInformation("Saving the fine-tuned model ...");
string fileName = Path.Combine(@"C:\", @"LLamaForSequenceClassification.safetensors");
////var stateDict = pipeline.TypedModel.state_dict();
var stateDict = pipeline.Model.state_dict();
//Safetensors.SaveStateDict(fileName, stateDict);
loggerFactory.Dispose();
}
}
}
Here's the code which just fine-tunes the classifier and freezes the llama model layers.
One thing to keep in mind: the labels must be non-negative for cross-entropy, which is why the negative class is mapped to 2 instead of -1.
using System;
using System.IO;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch.nn;
using static TorchSharp.torch;
using TorchSharp.PyBridge;
using TorchSharp.Modules;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.GenAI.Core.Trainer;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;
using System.Threading.Tasks;
using TorchSharp;
namespace Test
{
public class LLamaForSequenceClassification : nn.Module<CausalLMModelInput, CausalLMModelOutput>, IDisposable
{
private readonly Tokenizer tokenizer;
private readonly LlamaForCausalLM lamaModel;
private readonly Linear classifier;
private readonly LlamaConfig config;
private readonly string device = "cuda";
private readonly string configName = "config.json";
private readonly string modelFile = "tokenizer.model";
private readonly string weightFolder = @"C:\Users\\source\repos\Llama-3.2-1B-Instruct";
private string checkPointName = "model.safetensors";
public LLamaForSequenceClassification(int num_classes) : base(nameof(LLamaForSequenceClassification))
{
// Prepare the Data for classification, which includes tokenizing the input text and labeling the data.
tokenizer = LlamaTokenizerHelper.FromPretrained(Path.Combine(weightFolder, "original"), modelFile);
var trainDataSet = CreateDataSet();
// Load the Pretrained LLaMA model
config = new LlamaConfig();
lamaModel = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);
// Create a new classification head
classifier = Linear(2048, num_classes, device: device);
//register the classifier layer using the RegisterComponents() method
RegisterComponents();
PrintRegisteredComponents();
// Run the training loop
RunTrain(trainDataSet.input, trainDataSet.labelIds);
}
// This method overrides the forward function to process the input through the model,
// extract the last token's hidden state, use it for classification,
// and then return the logits.
public override CausalLMModelOutput forward(CausalLMModelInput input)
{
// Forward pass through the base model
var modelOutput = lamaModel.forward(input);
// Extracts the last token hidden state for classification
var lastTokenHiddenState = modelOutput.LastHiddenState[0, ^1];
// Forward pass through the classifier
var classifierOutput = classifier.forward(lastTokenHiddenState);
// Wraps the classifier output in a CausalLMModelOutput and returns it
var logits = new CausalLMModelOutput(classifierOutput);
return logits;
}
public void PrintRegisteredComponents()
{
var dict = base.state_dict();
foreach (var component in dict)
{
Console.WriteLine($"{component.Key}: {component.Value}");
}
}
public (IEnumerable<string> input, IEnumerable<int> labelIds) CreateDataSet()
{
//Sentence is input text
//Sentiment is a string: 'positive', 'negative' or 'neutral'
//Label = {0:'neutral', 1:'positive',2:'negative'}
//string filePath = @"C:\FinancialPhraseBank.txt";
string dummyFileContent = @"Sentence Sentiment
The company's revenue increased significantly this quarter. 1
The stock price dropped due to market conditions. 2
The new product launch was well-received by customers. 1
There are concerns about the company's future growth. 2
The financial report showed stable performance. 0
Investors are optimistic about the upcoming merger. 1
The CEO's resignation caused uncertainty in the market. 2
The quarterly earnings exceeded analysts' expectations. 1
The company is facing legal challenges that may impact its operations. 2
The dividend payout remained unchanged from the previous year. 0";
var lines = dummyFileContent.Split(['\n']);
var count = lines.Count();
IEnumerable<Tuple<string, int>> result = lines
.Skip(1) // skip header
.Select(line => line.Split('\t'))
.Select(columns => new Tuple<string, int>(
columns[0],
int.Parse(columns[1])));
return (result.Select(x => x.Item1), result.Select(x => x.Item2));
}
public void RunTrain(
IEnumerable<string> dataset,
IEnumerable<int> label,
float learningRate = 5e-5f,
int epoch = 5,
int batch = 10)
{
var loggerFactory = LoggerFactory.Create(builder => builder
.AddConsole()
.AddSimpleConsole(options =>
{
options.IncludeScopes = true;
options.SingleLine = true;
options.TimestampFormat = "yyyy-MM-dd HH:mm:ss ";
options.UseUtcTimestamp = true;
})
.SetMinimumLevel(LogLevel.Information));
var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();
ICausalLMPipeline pipeline = new CausalLMPipeline<Tokenizer, LlamaForCausalLM>(tokenizer, lamaModel, device);
var ct = new CancellationTokenSource().Token;
// generate embeddings
// because we won't tune the weight of llama model
// we can generate the embeddings in one round and
// use it for the following training
var embeddings = dataset.Select(x =>
{
var embeddings = pipeline.GenerateEmbeddingFromLastTokenPool(x);
return torch.tensor(embeddings, device: device);
});
logger.LogInformation("RunTrain...");
// train classifier
var optimizer = new Adam(classifier.parameters(), lr: learningRate);
// loss: cross entropy
for (int i = 0; i != epoch; i++)
{
logger.LogInformation($"Epoch {i + 1}/{epoch}");
// forward
var lossesForEachBatch = new List<float>();
for (int j = 0; j != dataset.Count(); j += batch)
{
var batchEmbeddings = embeddings.Skip(j).Take(batch);
var batchLabel = label.Skip(j).Take(batch);
var logits = torch.vstack(batchEmbeddings.ToArray());
var target = torch.tensor(batchLabel.ToArray(), device: device, dtype: ScalarType.Int64);
var output = classifier.forward(logits);
var batchLoss = torch.nn.functional.cross_entropy(output, target);
// backward
optimizer.zero_grad();
batchLoss.backward();
optimizer.step();
lossesForEachBatch.Add(batchLoss.item<float>());
}
logger.LogInformation($"Loss: {lossesForEachBatch.Average()}");
}
}
}
}