coremltools icon indicating copy to clipboard operation
coremltools copied to clipboard

Multifunction consumes linearly increasing memory in the number of functions, even with all the weights shared

Open 0seba opened this issue 5 months ago • 1 comments

🌱 Describe your Feature Request

Hi, when creating a multifunction model with many functions that all reuse the same weights, memory usage grows with the number of functions. Reducing this memory usage would improve speed and make such models more accessible to use.

Example code to reproduce the problem

This code creates a simple model with about 1 billion parameters, then creates one function per input size for that same model and merges them into a single multifunction package.

import os
import shutil
import numpy as np
import coremltools as ct
from coremltools.converters.mil import Builder as mb
import coremltools.converters.mil as mil


# Four independently sampled fp16 weight tensors of shape (16_384, 16_384, 1),
# i.e. 268,435,456 values (512 MiB at 2 bytes each) per tensor. They are module
# globals so every model built below is constructed from the same arrays.
w1, w2, w3, w4 = (
    np.random.normal(loc=0.01, size=(16_384, 16_384, 1)).astype(np.float16)
    for _ in range(4)
)


def make_model(length):
    """Build, convert and save the four-convolution model for one input length.

    The MIL program takes a single fp16 input of shape ``(1, 16_384, length)``
    and applies four convolutions in sequence, using the module-level weights
    ``w1``, ``w2``, ``w3`` and ``w4``. The converted model is written to
    ``./model_{length}``.

    Args:
        length: Size of the last input dimension for this model variant.
    """
    input_spec = mb.TensorSpec(
        (1, 16_384, length),
        dtype=mil.input_types.types.fp16,
    )

    @mb.program(
        input_specs=[input_spec],
        opset_version=mil.builder.AvailableTarget.iOS18,
    )
    def program(x):
        # Chain the convolutions in order: w1 -> w2 -> w3 -> w4.
        out = x
        for kernel in (w1, w2, w3, w4):
            out = mb.conv(x=out, weight=kernel)
        return out

    conversion_options = {
        "compute_units": ct.ComputeUnit.CPU_AND_NE,
        "compute_precision": ct.precision.FLOAT16,
        "minimum_deployment_target": ct.target.iOS18,
        "skip_model_load": True,
    }
    mlmodel = ct.convert(program, **conversion_options)

    # NOTE(review): the save path has no extension, while the caller later
    # opens "model_{length}.mlpackage" -- presumably coremltools appends the
    # ".mlpackage" suffix on save; confirm against the installed version.
    mlmodel.save(f"./model_{length}")


def merge_mfs(
    mf_filename: str,
    new_model: str,
    length: int,
    default_function_name: str = "length_1",
) -> None:
    """Add one single-function model to a multifunction package on disk.

    If ``mf_filename`` already exists as a directory, the descriptor is
    created from it so the new function is added alongside whatever it
    already holds; otherwise an empty descriptor is created. The "main"
    function of ``new_model`` is registered as ``length_{length}``, the
    package is saved to ``mf_filename``, and ``new_model`` is then deleted
    from disk.

    Args:
        mf_filename: Path of the multifunction package to create or extend.
        new_model: Path of the single-function model package to merge in.
            It is removed after the merged package has been saved.
        length: Input length the new model was built for; used to name the
            target function.
        default_function_name: Function to mark as the package default.
            Defaults to ``"length_1"``, the value that was previously
            hard-coded. NOTE(review): presumably this must name a function
            that is present in the package when it is saved -- confirm
            against the coremltools documentation.
    """
    existing_package = mf_filename if os.path.isdir(mf_filename) else None
    desc = ct.utils.MultiFunctionDescriptor(existing_package)

    # NOTE(review): `_functions()` is a private accessor, used here only for
    # progress output.
    print(f"Adding length {length}, already created lengths: {desc._functions()}")
    desc.add_function(
        new_model,
        src_function_name="main",
        target_function_name=f"length_{length}",
    )
    desc.default_function_name = default_function_name
    ct.utils.save_multifunction(desc, mf_filename)

    # The single-function package is no longer needed once it has been merged.
    shutil.rmtree(new_model)


if __name__ == "__main__":
    # Build one model per input length and fold each into the same
    # multifunction package, one at a time.
    mf_name = "mf.mlpackage"
    for input_length in (1, 2, 4, 6, 8):
        make_model(input_length)
        single_model_path = f"model_{input_length}.mlpackage"
        merge_mfs(mf_name, single_model_path, input_length)

Included a video showing the RAM usage.

  • Minute 1:15: peak RAM usage of 4.74 GB for 1 function
  • Minute 2:06: peak RAM usage of 7.74 GB for 2 functions
  • Minute 4:07: peak RAM usage of 9.74 GB for 3 functions
  • Minute 6:31: peak RAM usage of 11.74 GB for 4 functions

0seba avatar Sep 27 '24 21:09 0seba