xla
Implement GetCompiledMemoryStats for GPU AOT executables
This implements GetCompiledMemoryStats for ahead-of-time compiled executables. With this patch,
one can estimate memory consumption of a JAX function even without access to a GPU.
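
As a rough illustration of the kind of summary such stats provide, the following is a minimal, self-contained sketch and not the code from this patch. `ToyAllocation`, `ToyMemoryStats`, and `SummarizeAllocations` are hypothetical names, and the three-way classification into argument, output, and temporary bytes is an assumed simplification; the real logic may treat aliased or thread-local buffers differently.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one buffer allocation (not XLA's BufferAllocation).
struct ToyAllocation {
  int64_t size_bytes;
  bool is_entry_parameter;  // holds an argument of the entry computation
  bool is_live_out;         // holds (part of) the result
};

// Hypothetical stand-in for a subset of the compiled memory stats.
struct ToyMemoryStats {
  int64_t argument_size_in_bytes = 0;
  int64_t output_size_in_bytes = 0;
  int64_t temp_size_in_bytes = 0;
};

// Sums allocation sizes into argument, output, and temporary buckets.
// Assumption: each allocation is counted in exactly one bucket.
ToyMemoryStats SummarizeAllocations(
    const std::vector<ToyAllocation>& allocations) {
  ToyMemoryStats stats;
  for (const ToyAllocation& a : allocations) {
    assert(a.size_bytes >= 0);
    if (a.is_entry_parameter) {
      stats.argument_size_in_bytes += a.size_bytes;
    } else if (a.is_live_out) {
      stats.output_size_in_bytes += a.size_bytes;
    } else {
      stats.temp_size_in_bytes += a.size_bytes;
    }
  }
  return stats;
}
```

Nothing in this sketch touches a device: it only needs the allocation list, which is why the same summary can be produced from an unloaded executable.
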
Unfortunately, the patch duplicates code between unloaded and loaded GPU executables
and between GpuThunkAotCompilationResult::GetBufferAssignment() and Compiler::BufferSizeBytesFunction()+GpuCompiler::ShapeSizeBytesFunction().
This could perhaps be improved by exposing the relevant compiler code as static methods,
but that does not seem worth the extra complexity.
The patch also threads pointer_size from GpuCompiler to GpuThunkAotCompilationResult so that we can
get buffer allocation sizes without direct access to the compiler. Another option would be embedding
pointer_size within CompilationResultProto.
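
To show why a buffer size can depend on pointer_size at all, here is a small toy model, not XLA's implementation. `ToyShape`, `MakeArray`, `MakeTuple`, and `ToyShapeSizeBytes` are hypothetical names. The sketch assumes that a tuple-shaped buffer is stored as a table of pointers to its elements, so that its size is the pointer size times the element count, while an array buffer's size is independent of the pointer size.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical toy shape: either a dense array or a tuple of shapes.
struct ToyShape {
  bool is_tuple = false;
  int64_t element_count = 0;      // used for arrays only
  int64_t bytes_per_element = 0;  // used for arrays only
  std::vector<ToyShape> tuple_elements;  // used for tuples only
};

ToyShape MakeArray(int64_t element_count, int64_t bytes_per_element) {
  ToyShape s;
  s.element_count = element_count;
  s.bytes_per_element = bytes_per_element;
  return s;
}

ToyShape MakeTuple(std::vector<ToyShape> elements) {
  ToyShape s;
  s.is_tuple = true;
  s.tuple_elements = std::move(elements);
  return s;
}

// Returns the size of the top-level buffer for `shape`.
// Assumption: a tuple buffer is a table with one pointer per element.
int64_t ToyShapeSizeBytes(const ToyShape& shape, int64_t pointer_size) {
  assert(pointer_size > 0);
  if (shape.is_tuple) {
    return pointer_size * static_cast<int64_t>(shape.tuple_elements.size());
  }
  return shape.element_count * shape.bytes_per_element;
}
```

Under this assumption, a size function of this kind cannot be evaluated without knowing the pointer size, so the value has to be carried alongside the compilation result in some form, whether as a threaded parameter or as a serialized field.
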
Also note that this still does not set generated_code_size_in_bytes correctly; that would require
duplicating some code from GpuExecutable::SizeOfGeneratedCodeInBytes().