
llama.cpp on spark

Open · lslslslslslslslslslsls opened this issue 10 months ago · 5 comments

Hi DJL developers, I'd like to combine the image-classification-on-Spark demo and the chatbot demo to run llama.cpp on Spark, following the design and conventions of the DJL Spark extension.

In both the image-classification-on-Spark demo and the DJL Spark extension code, I find the core part looks like df.mapPartitions(transformRowsFunc), where transformRowsFunc loads the model via something like modelLoader.newPredictor(). It seems the model is loaded for one partition and then released, then loaded and released again for the next partition, as sketched below.
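Roughly, the core pattern looks like this (a simplified sketch in my own words; `modelLoader`, `toInput`, and `toOutputRow` are placeholder names, not the actual djl-spark API):

```scala
// Simplified sketch of the per-partition pattern described above.
// `modelLoader`, `toInput`, and `toOutputRow` are illustrative placeholders.
val predictions = df.rdd.mapPartitions { rows =>
  // the model is loaded here, once for this partition, and released
  // when the partition has been processed
  val predictor = modelLoader.newPredictor()
  rows.map(row => toOutputRow(row, predictor.predict(toInput(row))))
}
```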

I tested loading ResNet-50, which takes about 0.3 s and is totally acceptable when predicting a batch of images in a partition. However, for a llama model (e.g., llama-2-7b-chat.Q5_K_S.gguf, 4.65 GB), each load takes about 13 s, which seems too long and somewhat redundant.

Is there any way to reduce or avoid this repeated model loading time on Spark?

lslslslslslslslslslsls · Apr 10 '24

The model can be re-used within the machine. We typically recommend creating a Predictor per thread. So if you are re-loading the model multiple times, you can avoid that by storing the model globally.
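For example, something along these lines (a sketch only; the model URL and input/output types are placeholders you would replace for your llama.cpp setup):

```scala
import ai.djl.inference.Predictor
import ai.djl.repository.zoo.{Criteria, ZooModel}

// Sketch: a Scala object is a per-JVM singleton, so the expensive loadModel()
// runs at most once per executor JVM; each task/thread then creates its own
// (cheap) Predictor.
object GlobalModel {
  lazy val model: ZooModel[String, String] =
    Criteria.builder
      .setTypes(classOf[String], classOf[String])                  // placeholder types
      .optModelUrls("file:///models/llama-2-7b-chat.Q5_K_S.gguf")  // placeholder path
      .build
      .loadModel()

  def newPredictor(): Predictor[String, String] = model.newPredictor()
}
```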

zachgk · Apr 10 '24

@oreo-yum

In Spark, each partition may run on a different machine or JVM. A DJL model is not serializable because it needs to load native libraries, so you have to load the model in each partition.

frankfliu · Apr 10 '24

> The model can be re-used within the machine. We typically recommend creating a Predictor per thread. So if you are re-loading the model multiple times, you can avoid that by storing the model globally.

Hi zachgk, thanks for your quick reply. I came to DJL to get away from the PySpark working model, which loads the model in each Python worker process corresponding to each thread of the JVM process. That forces me to always request executors with a high memory/CPU ratio (e.g., 2 cores, 30 GB memory), which makes the system admin think I have a bad design 😭. Creating a Predictor per thread takes me back to a solution similar to PySpark.

lslslslslslslslslslsls · Apr 11 '24

@oreo-yum

The Predictor for PyTorch is pretty lightweight; only the model takes significant memory. You can use a global Predictor (assuming you use PyTorch) per JVM if you want to.

frankfliu · Apr 11 '24

> @oreo-yum
>
> The Predictor for PyTorch is pretty lightweight; only the model takes significant memory. You can use a global Predictor (assuming you use PyTorch) per JVM if you want to.

Thanks @frankfliu. It seems I need to separate criteria.loadModel from model.newPredictor and make sure criteria.loadModel is called only once. BTW, I tested the code below and found that the if (model == null) check does not work: model is null every time, so the model is reloaded for each partition. ModelLoader:

```scala
import ai.djl.inference.Predictor
import ai.djl.repository.zoo.{Criteria, ZooModel}
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslatorFactory

class ModelLoader[A, B](val engine: String, val url: String, val inputClass: Class[A], val outputClass: Class[B],
                        var translatorFactory: TranslatorFactory, val arguments: java.util.Map[String, AnyRef])
  extends Serializable {

  // not serialized with the ModelLoader; intended to cache the loaded model
  @transient private var model: ZooModel[A, B] = _

  /**
   * Creates a new Predictor, loading the model first if it has not been loaded yet.
   *
   * @return an instance of `Predictor`
   */
  def newPredictor(): Predictor[A, B] = {
    if (model == null) {
      val criteria = Criteria.builder
        .setTypes(inputClass, outputClass)
        .optEngine(engine)
        .optModelUrls(url)
        .optTranslatorFactory(translatorFactory)
        .optProgress(new ProgressBar)
        .optArguments(arguments)
        .build
      model = criteria.loadModel
    }
    model.newPredictor
  }
}
```

Is it possible to share the model when multiple partitions are running in parallel? My understanding is that the model weights are immutable data, so they should be safe to share among threads in the same executor. However, I'm not sure how to manage the model life cycle. Is there any way to manage loading and releasing models in DJL?
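For what it's worth, here is roughly what I have in mind (a rough sketch, not something DJL provides; the object name, types, and model path are made up): load the model once per executor JVM via a lazy singleton, give each task its own short-lived Predictor, and release the native model in a JVM shutdown hook.

```scala
import ai.djl.inference.Predictor
import ai.djl.repository.zoo.{Criteria, ZooModel}

// Rough life-cycle sketch (not a DJL API): one shared model per executor JVM,
// one Predictor per task, native resources released on JVM shutdown.
object SharedLlama {
  lazy val model: ZooModel[String, String] = {
    val m = Criteria.builder
      .setTypes(classOf[String], classOf[String])                  // placeholder types
      .optModelUrls("file:///models/llama-2-7b-chat.Q5_K_S.gguf")  // placeholder path
      .build
      .loadModel()
    sys.addShutdownHook(m.close())  // release the native model when the JVM exits
    m
  }

  // each Spark task (thread) uses its own Predictor and closes it when done
  def withPredictor[T](f: Predictor[String, String] => T): T = {
    val p = model.newPredictor()
    try f(p) finally p.close()
  }
}
```

Inside mapPartitions a task would then call something like SharedLlama.withPredictor(p => ...), but whether this is the intended way to manage the life cycle in DJL is exactly my question.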

lslslslslslslslslslsls · Apr 11 '24