
Proposal: Pluggable RNG in MLContext (enable deterministic, portable RNGs like MT19937)

Open asp2286 opened this issue 3 months ago • 0 comments

Summary

I propose adding a minimal extension point that lets users inject a custom RNG into MLContext, without changing defaults or breaking back-compat. This enables deterministic, portable random sequences across platforms and languages (e.g., matching C++ std::mt19937), and lets advanced users choose an RNG that matches their reproducibility requirements.

Motivation

  • Reproducibility across ecosystems: MT19937 is widely available (e.g., C++ std::mt19937, Python's random module, NumPy's legacy RandomState, which is backed by MT19937). Having the same PRNG in ML.NET makes it easier to reproduce and compare experiments across these stacks.
  • Zero impact by default: Default behavior remains unchanged and backwards compatible.
  • Testability: Easier to write bitwise-stable tests that don’t depend on System.Random, whose algorithm is not guaranteed to stay the same across major .NET versions.

Design

Add a small interface and an optional parameter to MLContext:

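#nullable enable
using System;

/// <summary>
/// Source of random values for ML.NET components. Ranges mirror System.Random:
/// upper bounds are exclusive and NextDouble/NextSingle return values in [0, 1).
/// </summary>
public interface IRandomSource
{
    int Next();                                      // [0, int.MaxValue)
    int Next(int maxValue);                          // [0, maxValue)
    int Next(int minValue, int maxValue);            // [minValue, maxValue)
    long NextInt64();                                // [0, long.MaxValue)
    long NextInt64(long maxValue);                   // [0, maxValue)
    long NextInt64(long minValue, long maxValue);    // [minValue, maxValue)
    double NextDouble();                             // [0.0, 1.0)
    float NextSingle();                              // [0.0f, 1.0f)
    void NextBytes(Span<byte> buffer);               // fills the whole buffer
}

// The existing constructor signature is kept; the new overload adds an optional RNG.
public sealed class MLContext
{
    public MLContext(int? seed = null) : this(seed, rng: null) { }

    // If rng is supplied it is used as-is and seed is ignored (the caller seeds rng).
    public MLContext(int? seed, IRandomSource? rng)
    {
        // Default path: wrap System.Random (seeded, or the shared instance when seed is null).
        _rng = rng ?? new RandomSourceAdapter(seed is null ? Random.Shared : new Random(seed.Value));
        // ... existing initialization
    }

    internal IRandomSource RandomSource => _rng;
    private readonly IRandomSource _rng;
}

// Forwards every call to System.Random. Random.Shared, NextInt64 and NextSingle
// require .NET 6 or later; NextBytes(Span<byte>) requires .NET Core 2.1 / .NET Standard 2.1.
internal sealed class RandomSourceAdapter : IRandomSource
{
    private readonly Random _rand;
    public RandomSourceAdapter(Random rand) => _rand = rand;
    public int Next() => _rand.Next();
    public int Next(int maxValue) => _rand.Next(maxValue);
    public int Next(int minValue, int maxValue) => _rand.Next(minValue, maxValue);
    public long NextInt64() => _rand.NextInt64();
    public long NextInt64(long maxValue) => _rand.NextInt64(maxValue);
    public long NextInt64(long minValue, long maxValue) => _rand.NextInt64(minValue, maxValue);
    public double NextDouble() => _rand.NextDouble();
    public float NextSingle() => _rand.NextSingle();
    public void NextBytes(Span<byte> buffer) => _rand.NextBytes(buffer);
}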
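To check that the interface is enough for the MT19937 use case from the motivation, here is a minimal sketch of an MT19937-backed IRandomSource. `Mt19937RandomSource` and its `NextUInt32` method are hypothetical names, not part of ML.NET. The core engine follows the reference MT19937 algorithm (init_genrand seeding, in-place twist, standard tempering), so for the same 32-bit seed `NextUInt32` should yield the same stream as C++ std::mt19937. How the bounded, 64-bit, floating-point and byte methods are derived from that stream is an assumption of this sketch; other libraries may map the raw stream differently. The interface is repeated so the block compiles on its own.

```csharp
using System;

// Copy of the proposed interface so this sketch compiles on its own.
public interface IRandomSource
{
    int Next();
    int Next(int maxValue);
    int Next(int minValue, int maxValue);
    long NextInt64();
    long NextInt64(long maxValue);
    long NextInt64(long minValue, long maxValue);
    double NextDouble();
    float NextSingle();
    void NextBytes(Span<byte> buffer);
}

// Hypothetical MT19937-backed source (not part of ML.NET). Not thread-safe.
public sealed class Mt19937RandomSource : IRandomSource
{
    private const int N = 624, M = 397;
    private readonly uint[] _mt = new uint[N];
    private int _index = N; // forces a twist before the first output

    // init_genrand seeding, the same algorithm as std::mt19937's seed constructor.
    // 5489 is std::mt19937's default seed.
    public Mt19937RandomSource(uint seed = 5489u)
    {
        _mt[0] = seed;
        for (int i = 1; i < N; i++)
            _mt[i] = unchecked(1812433253u * (_mt[i - 1] ^ (_mt[i - 1] >> 30)) + (uint)i);
    }

    // Raw 32-bit output with the standard MT19937 tempering.
    public uint NextUInt32()
    {
        if (_index >= N) Twist();
        uint y = _mt[_index++];
        y ^= y >> 11;
        y ^= (y << 7) & 0x9D2C5680u;
        y ^= (y << 15) & 0xEFC60000u;
        y ^= y >> 18;
        return y;
    }

    // Regenerates all 624 state words in place.
    private void Twist()
    {
        for (int i = 0; i < N; i++)
        {
            uint y = (_mt[i] & 0x80000000u) | (_mt[(i + 1) % N] & 0x7FFFFFFFu);
            _mt[i] = _mt[(i + M) % N] ^ (y >> 1) ^ ((y & 1u) != 0 ? 0x9908B0DFu : 0u);
        }
        _index = 0;
    }

    // Two 32-bit draws, high word first (a convention chosen for this sketch).
    private ulong NextUInt64() => ((ulong)NextUInt32() << 32) | NextUInt32();

    // Uniform value in [minValue, maxValue) via rejection sampling (no modulo bias).
    public long NextInt64(long minValue, long maxValue)
    {
        if (minValue > maxValue) throw new ArgumentOutOfRangeException(nameof(minValue));
        ulong range = unchecked((ulong)maxValue - (ulong)minValue);
        if (range == 0) return minValue;
        ulong threshold = unchecked(0UL - range) % range; // 2^64 mod range
        ulong x;
        do { x = NextUInt64(); } while (x < threshold);
        return unchecked(minValue + (long)(x % range));
    }

    public long NextInt64(long maxValue) =>
        maxValue < 0 ? throw new ArgumentOutOfRangeException(nameof(maxValue)) : NextInt64(0, maxValue);
    public long NextInt64() => NextInt64(0, long.MaxValue);
    public int Next(int minValue, int maxValue) => (int)NextInt64(minValue, maxValue);
    public int Next(int maxValue) => (int)NextInt64(maxValue);
    public int Next() => (int)NextInt64(0, int.MaxValue);

    // 53-bit double in [0, 1), the genrand_res53 construction from the reference code.
    public double NextDouble()
    {
        uint a = NextUInt32() >> 5, b = NextUInt32() >> 6;
        return (a * 67108864.0 + b) * (1.0 / 9007199254740992.0);
    }

    // 24-bit float in [0, 1).
    public float NextSingle() => (NextUInt32() >> 8) * (1.0f / 16777216.0f);

    // Fills the buffer four bytes per draw, least significant byte first.
    public void NextBytes(Span<byte> buffer)
    {
        for (int i = 0; i < buffer.Length; i += 4)
        {
            uint r = NextUInt32();
            for (int j = 0; j < 4 && i + j < buffer.Length; j++)
                buffer[i + j] = (byte)(r >> (8 * j));
        }
    }
}
```

With the proposed overload, opting in would look like `new MLContext(seed: null, rng: new Mt19937RandomSource(42))`, while existing `new MLContext(seed: 42)` callers are unaffected. Matching another library bit for bit requires its seeding and its mapping from raw 32-bit outputs to ranges and floats to match as well, not just the engine.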

asp2286 · Sep 13 '25 11:09