
Memory leak at prediction time?

fracpete opened this issue on May 13, 2025 (status: Open) · 2 comments

I'm using DJL 0.32.0 for tabular regression with PyTorch as the engine in a multi-threaded application on Linux, and I'm encountering a memory leak: over time, more and more memory outside the JVM gets gobbled up, until the machine becomes unresponsive.

Below are two classes, one exhibiting the memory leak and one that doesn't.

In this example, I'm using a slightly modified version of djl-zero's TabularRegression class (original here).

The code below exhibits the memory leak: models are trained in a thread pool, and predictions are then made via Callable objects submitted to a thread pool. This simulates the behavior of my application by separating training from inference. Training is an offline process that saves the models to disk; at inference time, these models get loaded into memory on demand and are then applied until the process finishes (in theory only when new models have been built and the application requires a restart):

import ai.djl.Model;
import ai.djl.basicdataset.tabular.AmesRandomAccess;
import ai.djl.basicdataset.tabular.ListFeatures;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.TabNetRegressionLoss;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;

import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Process
 * 1. Models get trained via thread pool and then stored in static variables.
 * 2. Predictions on new data are made by submitting jobs to thread pool (using models/translators/predictors from static context).
 *
 * Massive memory leak.
 */
public class AmesThreadpoolLeak {

  public static Model[] models;

  public static Translator<ListFeatures, Float>[] translators;

  public static Predictor<ListFeatures, Float>[] predictors;

  public static void waitForFinish(ExecutorService service) {
    service.shutdown();
    while (!service.isTerminated()) {
      try {
        service.awaitTermination(100, TimeUnit.MILLISECONDS);
      }
      catch (InterruptedException e) {
        // ignored
      }
      catch (Exception e) {
        e.printStackTrace();
      }
    }
  }

  public static void main(String[] args) throws Exception {
    ExecutorService service;
    int numModels = 4;

    // load data
    AmesRandomAccess dataset = AmesRandomAccess.builder()
        .setSampling(32, true)
        .addNumericFeature("lotarea")
        .addNumericFeature("miscval")
        .addNumericFeature("overallqual")
        .addNumericLabel("saleprice")
        .build();
    Dataset[] splitDataset = dataset.randomSplit(8, 2);
    Dataset trainDataset = splitDataset[0];
    Dataset validateDataset = splitDataset[1];

    // get translators
    translators = new Translator[numModels];
    for (int n = 0; n < numModels; n++) {
      System.out.println("translator: " + n);
      translators[n] = dataset.matchingTranslatorOptions().option(ListFeatures.class, Float.class);
    }

    // build models
    models = new Model[numModels];
    service = Executors.newFixedThreadPool(numModels);
    for (int n = 0; n < numModels; n++) {
      final int index = n;
      Callable<String> job = new Callable<>() {
        @Override
        public String call() throws Exception {
          System.out.println("train: " + index);

          Block block = TabularRegression.createBlock(Performance.FAST, dataset.getFeatureSize(), dataset.getLabelSize());
          Model model = Model.newInstance("tabular");
          model.setBlock(block);

          TrainingConfig trainingConfig =
              new DefaultTrainingConfig(new TabNetRegressionLoss())
                  .addTrainingListeners(TrainingListener.Defaults.basic());

          try (Trainer trainer = model.newTrainer(trainingConfig)) {
            trainer.initialize(new Shape(1, dataset.getFeatureSize()));
            EasyTrain.fit(trainer, 5, trainDataset, validateDataset);
          }
          models[index] = model;
          return null;
        }
      };
      service.submit(job);
    }
    waitForFinish(service);

    // instantiate predictors
    predictors = new Predictor[numModels];
    for (int n = 0; n < numModels; n++)
      predictors[n] = models[n].newPredictor(translators[n]);

    // predictions via thread pool
    Random rnd = new Random(1);
    for (int i = 0; i < 1000; i++) {
      service = Executors.newFixedThreadPool(numModels);
      for (int n = 0; n < numModels; n++) {
        final int index = n;
        Callable<String> job = new Callable<>() {
          @Override
          public String call() throws Exception {
            ListFeatures features = new ListFeatures();
            features.add("" + rnd.nextDouble() * 1000);
            features.add("" + rnd.nextDouble() * 100);
            features.add("" + rnd.nextInt(10));
            Float pred = predictors[index].predict(features);
            System.out.println(index + ": " + pred);
            return null;
          }
        };
        // memory leak happens even with delayed job submission!
        TimeUnit.MILLISECONDS.sleep(250);
        service.submit(job);
      }
      waitForFinish(service);
    }
  }
}

This code, however, does not exhibit the memory leak (training and prediction happen in the same thread):

import ai.djl.Model;
import ai.djl.basicdataset.tabular.AmesRandomAccess;
import ai.djl.basicdataset.tabular.ListFeatures;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.TabNetRegressionLoss;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;

import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Runs 4 separate pipelines via a thread pool:
 * train model, loop for making predictions on new data
 *
 * No memory leak.
 */
public class AmesThreadpoolNoLeak {

  public static void waitForFinish(ExecutorService service) {
    service.shutdown();
    while (!service.isTerminated()) {
      try {
        service.awaitTermination(100, TimeUnit.MILLISECONDS);
      }
      catch (InterruptedException e) {
        // ignored
      }
      catch (Exception e) {
        e.printStackTrace();
      }
    }
  }

  public static void main(String[] args) throws Exception {
    ExecutorService service;
    int numModels = 4;

    // load data
    AmesRandomAccess dataset = AmesRandomAccess.builder()
        .setSampling(32, true)
        .addNumericFeature("lotarea")
        .addNumericFeature("miscval")
        .addNumericFeature("overallqual")
        .addNumericLabel("saleprice")
        .build();
    Dataset[] splitDataset = dataset.randomSplit(8, 2);
    Dataset trainDataset = splitDataset[0];
    Dataset validateDataset = splitDataset[1];

    // separate thread for each model build/prediction pipeline
    service = Executors.newFixedThreadPool(numModels);
    for (int n = 0; n < numModels; n++) {
      final int index = n;
      Callable<String> job = new Callable<>() {
        @Override
        public String call() throws Exception {
          // translator
          System.out.println("translator: " + index);
          Translator<ListFeatures, Float> translator = dataset.matchingTranslatorOptions().option(ListFeatures.class, Float.class);

          // train
          System.out.println("train: " + index);
          Block block = TabularRegression.createBlock(Performance.FAST, dataset.getFeatureSize(), dataset.getLabelSize());
          Model model = Model.newInstance("tabular");
          model.setBlock(block);

          TrainingConfig trainingConfig =
              new DefaultTrainingConfig(new TabNetRegressionLoss())
                  .addTrainingListeners(TrainingListener.Defaults.basic());

          try (Trainer trainer = model.newTrainer(trainingConfig)) {
            trainer.initialize(new Shape(1, dataset.getFeatureSize()));
            EasyTrain.fit(trainer, 5, trainDataset, validateDataset);
          }

          // predict
          System.out.println("predict: " + index);
          Random rnd = new Random(1);
          Predictor<ListFeatures, Float> predictor = model.newPredictor(translator);
          for (int i = 0; i < 10000; i++) {
            ListFeatures features = new ListFeatures();
            features.add("" + rnd.nextDouble() * 1000);
            features.add("" + rnd.nextDouble() * 100);
            features.add("" + rnd.nextInt(10));
            Float pred = predictor.predict(features);
            System.out.println(index + ": " + pred);
          }

          return null;
        }
      };
      service.submit(job);
    }
    waitForFinish(service);
  }
}

I'm using the following dependencies for DJL 0.32.0 in my pom.xml, with the djl.version property set accordingly:

  <properties>
    <djl.version>0.32.0</djl.version>
  </properties>

  <dependencies>
    <dependency>
      <groupId>commons-io</groupId>
      <artifactId>commons-io</artifactId>
      <version>2.19.0</version>
    </dependency>

    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>${djl.version}</version>
    </dependency>

    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>basicdataset</artifactId>
      <version>${djl.version}</version>
    </dependency>

    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>djl-zero</artifactId>
      <version>${djl.version}</version>
    </dependency>

    <dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-engine</artifactId>
      <version>${djl.version}</version>
    </dependency>
  </dependencies>

The full example Maven project is available here:

https://github.com/fracpete/djl-test

Any idea what could be the reason for the memory leak?

— fracpete, May 13, 2025

The models and the predictors both need to be manually closed, or run within a try-with-resources block.
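A minimal sketch of both options (loadModel() is a hypothetical stand-in for however the model gets created or loaded; translator and features are as in the code above):

// option 1: close explicitly once you are done predicting
predictor.close();
model.close();

// option 2: tie the lifetime to a scope; Model and Predictor both
// implement AutoCloseable
try (Model model = loadModel(); // hypothetical helper: create or load the model
     Predictor<ListFeatures, Float> predictor = model.newPredictor(translator)) {
  Float pred = predictor.predict(features);
  System.out.println(pred);
}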

— zachgk, May 13, 2025

That is true, but it's not the issue here.

In both cases, only 4 (four) predictors and models get created in total and are then reused.

However, the first class gobbles up memory like crazy while making predictions, even though it is reusing the same Predictor instances.
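Condensed, the prediction phases of the two classes differ only in thread lifetime. A sketch reusing the fields and helpers from the classes above (nextFeatures() is a hypothetical stand-in for the random ListFeatures construction):

// first class (leaks): predictors are created once, but every round of
// predictions runs on a freshly created pool, i.e. on brand-new threads
for (int i = 0; i < 1000; i++) {
  ExecutorService perRound = Executors.newFixedThreadPool(numModels);
  for (int n = 0; n < numModels; n++) {
    final int index = n;
    perRound.submit(() -> predictors[index].predict(nextFeatures()));
  }
  waitForFinish(perRound); // worker threads terminate, yet native memory keeps growing
}

// second class (no leak): each pipeline creates its predictor and then calls
// predict() in a loop on a single long-lived thread for the entire run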

— fracpete, May 13, 2025