xla

Implement GetCompiledMemoryStats for GPU AOT executables

Open jaro-sevcik opened this issue 9 months ago • 3 comments

This implements GetCompiledMemoryStats for ahead-of-time compiled executables. With this patch, one can estimate the memory consumption of a JAX function even without access to a GPU.
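For context, a compiled memory report is essentially a summary of the executable's buffer assignment: each allocation's size is added to one or more buckets according to its role. The sketch below is a simplified, hypothetical model of that aggregation. Allocation, MemoryStats and memory_stats are illustrative names, not XLA's API. The field names are modeled on the ..._size_in_bytes naming of CompiledMemoryStats, but the bucketing rules here are assumptions. Constants, thread-local allocations and host memory spaces are ignored.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Allocation:
    """Hypothetical stand-in for one allocation in a buffer assignment."""
    size: int
    is_entry_parameter: bool = False
    maybe_live_out: bool = False
    aliased_with_output: bool = False


@dataclass
class MemoryStats:
    argument_size_in_bytes: int = 0
    output_size_in_bytes: int = 0
    alias_size_in_bytes: int = 0
    temp_size_in_bytes: int = 0


def memory_stats(allocations):
    """Sums allocation sizes by role; one allocation may land in several buckets."""
    stats = MemoryStats()
    for alloc in allocations:
        if alloc.is_entry_parameter:
            stats.argument_size_in_bytes += alloc.size
            # A parameter buffer reused for an output is also counted as aliased.
            if alloc.aliased_with_output:
                stats.alias_size_in_bytes += alloc.size
        if alloc.maybe_live_out:
            stats.output_size_in_bytes += alloc.size
        if not alloc.is_entry_parameter and not alloc.maybe_live_out:
            # Simplification: everything else is treated as scratch (temp) space.
            stats.temp_size_in_bytes += alloc.size
    return stats


# Example: two parameters (one donated to the output), a scratch buffer, an output.
allocs = [
    Allocation(1024, is_entry_parameter=True),
    Allocation(1024, is_entry_parameter=True, maybe_live_out=True,
               aliased_with_output=True),
    Allocation(4096),
    Allocation(512, maybe_live_out=True),
]
stats = memory_stats(allocs)
```

Because none of this needs a device, only the buffer assignment and the buffer sizes, the same numbers can in principle be produced from an AOT compilation result.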

Unfortunately, the patch duplicates code between the unloaded and loaded GPU executables, and between GpuThunkAotCompilationResult::GetBufferAssignment() and Compiler::BufferSizeBytesFunction() + GpuCompiler::ShapeSizeBytesFunction(). This could perhaps be improved by exposing the relevant compiler code as static methods, but that does not seem worth the extra complexity.

The patch also threads pointer_size from GpuCompiler to GpuThunkAotCompilationResult so that buffer allocation sizes can be computed without direct access to the compiler. Another option would be to embed pointer_size in CompilationResultProto.
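A rough model of why pointer_size is needed: an array buffer's size depends only on its element type and dimensions, but a tuple buffer holds a table of pointers to its elements, so its size is the element count times the pointer size. The sketch below is a simplified, hypothetical illustration of a shape-size function built around pointer_size, in the spirit of ShapeSizeBytesFunction but not its actual code. Shape, ELEMENT_BYTES, shape_size_bytes and make_shape_size_fn are illustrative names. Dynamic-shape metadata, layouts and sub-byte types are ignored.

```python
import math
from dataclasses import dataclass

# Byte widths for a few primitive types (an illustrative subset).
ELEMENT_BYTES = {"pred": 1, "s8": 1, "f16": 2, "bf16": 2,
                 "f32": 4, "s32": 4, "f64": 8, "s64": 8}


@dataclass(frozen=True)
class Shape:
    """Hypothetical, simplified shape: an array type or a tuple of shapes."""
    element_type: str
    dims: tuple = ()
    tuple_shapes: tuple = ()

    @property
    def is_tuple(self):
        return self.element_type == "tuple"


def shape_size_bytes(shape, pointer_size):
    if shape.is_tuple:
        # A tuple buffer stores only one pointer per element.
        return len(shape.tuple_shapes) * pointer_size
    # Dense array: element count times element width (a scalar has dims == ()).
    return math.prod(shape.dims) * ELEMENT_BYTES[shape.element_type]


def make_shape_size_fn(pointer_size):
    """Binds pointer_size once, so callers need no access to the compiler."""
    return lambda shape: shape_size_bytes(shape, pointer_size)


size_fn = make_shape_size_fn(pointer_size=8)
```

Binding pointer_size into a closure is what lets a compilation result compute sizes on its own once it has been handed that single value, which is the role the patch gives to the threaded-through pointer_size.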

Also note that this still does not set generated_code_size_in_bytes correctly; doing so would require duplicating some code from GpuExecutable::SizeOfGeneratedCodeInBytes().

jaro-sevcik · Feb 17 '25 19:02