TornadoVM
[proposal] Support of functional interfaces
Many tasks in heterogeneous computing are variations of the same well-known patterns, differing only in the functions applied per element. Take reduction, for example: SoftMax and RMS normalization, in a nutshell, differ only in the functions they apply.
It would be beneficial to support functional interfaces as kernel parameters that are unwound into direct calls (with appropriate restrictions, of course).
That will minimize development time and increase the maintainability of kernels developed in TornadoVM.
Hi @andrii0lomakin , is this a proposal or an issue?
You can build libraries on top of TornadoVM that solve specific functionality for domains of applications such as LLMs, Linear Algebra, Graphics, Physics, etc.
This is a feature request. At least, that is how I have filed it :-) . At the moment, I have to repeat the same boilerplate code over and over, and it would benefit me to be able to pass functional interfaces.
This is a feature request.
OK, I will open this up for discussion with all community members and TornadoVM maintainers.
At the moment, I have to repeat the same boilerplate code over and over, and it would benefit me to be able to pass functional interfaces.
What do you mean by passing functional interfaces? At which level? Tasks within TaskGraph receive functional interfaces already. Do you have any example in mind that you can share?
As I mentioned, one can build libraries on top. For example, an LLM/transformer library that contains softmax, normalization, reductions, and matmul. Is this what you are referring to?
Simplest example.
I have an RMS norm layer that, in a nutshell, consists of reduce kernels that are called in layers as described here: https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf.
Initially, I reduce the squared values, and then I just call the plain sum kernel in subsequent layers. I could create a single kernel that accepts a Function as a parameter and have it inlined during compilation, but instead I need to repeat the same code.
The same is true for SoftMax, which uses a reduction for its denominator.
I suppose there are plenty of other such use cases.
P.S. I am ready to implement this feature myself, but I need your feedback, of course.
Reductions are supported in TornadoVM in two ways:
- Using the combination of @Parallel and @Reduce for loop-parallelism. The TornadoVM JIT Compiler is able to generate efficient reductions for simple test cases (min, max, sum, etc.). Details: https://www.researchgate.net/publication/327871451_Using_Compiler_Snippets_to_Exploit_Parallelism_on_Heterogeneous_Hardware_A_Java_Reduction_Case_Study
Note that the final reduction happens on the device. The TornadoVM JIT compiler generates two kernels from a single reduction kernel (one to run in parallel, and a second to perform the final reduction over the remaining work-groups). An illustrative sketch of this approach follows this list.
- Using the Kernel API and the KernelContext to access local memory and barriers. In this case, the developer needs to implement the whole logic. This might also be the path to follow for more complex reductions.
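To illustrate the first approach, here is roughly what an annotation-based reduction kernel looks like (a minimal sketch modelled on the published TornadoVM reduction examples; package names are assumptions):
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.annotations.Reduce;

public class ReductionExample {
    // The @Reduce annotation marks the output array; the TornadoVM JIT compiler
    // generates the parallel partial reductions and the final on-device reduction.
    public static void reductionAddFloats(float[] input, @Reduce float[] result) {
        result[0] = 0.0f;
        for (@Parallel int i = 0; i < input.length; i++) {
            result[0] += input[i];
        }
    }
}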
@jjfumero it seems like I was not clear enough; I know how to implement reductions and am familiar with barriers, etc.
I mean that if I had the ability to pass the function that needs to be called when performing the reduction, I would not need to repeat the same boilerplate code again and again.
For example, in my code I have two kernels; one does:
localSnippet[context.localIdx] = inputTensor.get(currentInputOffset);
and another one
float value = inputTensor.get(currentInputOffset);
localSnippet[context.localIdx] = value * value;
In general, I could just pass a function to both kernels, but instead I have to repeat the same code again and again.
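To make this concrete, here is a hypothetical sketch (the FloatUnaryOp interface is illustrative, not existing TornadoVM API) of how a single kernel could cover both cases if a function could be passed in:
// Hypothetical per-element operation passed into the reduction kernel.
@FunctionalInterface
interface FloatUnaryOp {
    float apply(float value);
}

class ReductionOps {
    // Two call sites of the same kernel instead of two copies of the kernel:
    static final FloatUnaryOp PLAIN_SUM = value -> value;               // plain sum reduction
    static final FloatUnaryOp SUM_OF_SQUARES = value -> value * value;  // squared values for RMS norm

    // Inside the shared kernel, the load step would then become:
    //   localSnippet[context.localIdx] = op.apply(inputTensor.get(currentInputOffset));
}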
P.S. In general, when I raise an issue, I have the impression that I always receive information about how to implement the basics of the functionality instead of an in-depth discussion of the concrete issue; probably I need to change something in my conversational style to avoid this :-)
As for the TornadoVM annotations, I find them a good entry point for developers who want to learn about heterogeneous computing, but because all those "complications" are done for nothing but performance, at least at this concrete stage they are not quite suitable for production usage, at least in commercial applications or in libraries that are created to support commercial tools.
at least at this concrete stage they are not quite suitable for production usage
TornadoVM is an academic project fully developed and maintained by Master and PhD students, researchers and staff at The University of Manchester. It is not a product, at least not yet. Feedback and contributions are very welcome.
In general, when I raise an issue, I have the impression that I always receive information about how to implement the basics of the functionality instead of an in-depth discussion of the concrete issue; probably I need to change something in my conversational style to avoid this :-)
More concrete questions with test cases will be useful.
If I sketch what you want (pseudocode):
public static void sampleReduction(KernelContext context,
FloatArray a,
FloatArray b,
FunctionalInterface f) {
int globalIdx = context.globalIdx;
int localIdx = context.localIdx;
int localGroupSize = context.localGroupSizeX;
int groupID = context.groupIdx; // Expose Group ID
float[] localA = context.allocateFloatLocalArray(256);
localA[localIdx] = a.get(globalIdx);
for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (localIdx < stride) {
localA[localIdx] = f.apply(localA[localIdx], localA[localIdx + stride]); // Use of a functional interface
}
}
if (localIdx == 0) {
b.set(groupID, localA[0]);
}
}
The functional interface could potentially be used at any level and in different scenarios, not just for the compute. Is this a better approximation?
@jjfumero
First of all, I truly believe that TornadoVM has great potential, and I am doing a lot of advertising for this project. Hopefully with some positive outcome. If all goes well, I hope there will be a lot of contributions from my side.
Do not get me wrong, my last observation was intended only to improve the quality of the discussion and nothing more than that.
The functional interface could potentially be used at any level and in different scenarios, not just for the compute. Is this a better approximation?
Absolutely, thank you for the summary.
As I mentioned above, I am ready to implement this feature myself, but I am not sure whether it fits the project design.
Comments and feedback are very valuable to us, so we really appreciate your input. Hopefully, with the help of community members like you, TornadoVM can improve in many aspects.
This feature looks like a great addition. If you want to implement these cases, feel free to open a PR.
@jjfumero Cool, I am on it then. I will provide a PR in a few weeks.
@jjfumero, here is a sketch of the steps that I will follow to implement the given issue.
First of all, only TaskX functional interfaces will be accepted as arguments.
I will also rewrite uk.ac.manchester.tornado.runtime.analyzer.TaskUtils#resolveMethodHandle to use ASM. It looks more robust and maintainable to me. I will likely add more checks to ensure that only lambdas of the correct form can be passed.
The algorithm will be similar to the one already implemented in uk.ac.manchester.tornado.runtime.analyzer.TaskUtils#resolveMethodHandle:
- Find a static method in the task code passed to resolveMethodHandle.
- Find functional interface arguments in the list of arguments to the static method. Only TaskX instances will be allowed.
- If there are no instances of TaskX passed, just return the original method.
- Otherwise, use the ASM method visitor and class writer to navigate over the passed-in static method and replace all calls to functional interfaces with calls to a static method.
As a result of the last step, the bytecode of a new class will be generated, and the class will be defined using the jdk.internal.misc.Unsafe#defineClass method.
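As a rough illustration of the rewriting step (a hypothetical sketch only; the class names, the targeted interface, and the assumption that the lambda argument is loaded with a single ALOAD of a known parameter slot are all illustrative, not the actual implementation):
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

final class LambdaCallRewriter extends ClassVisitor {

    private final String ifaceInternalName; // e.g. "java/util/function/BiFunction"
    private final int ifaceParamSlot;       // local-variable slot holding the lambda argument
    private final String targetOwner;       // class hosting the generated static method
    private final String targetName;        // name of the generated static method
    private final String targetDesc;        // descriptor of the generated static method

    LambdaCallRewriter(ClassVisitor cv, String ifaceInternalName, int ifaceParamSlot,
                       String targetOwner, String targetName, String targetDesc) {
        super(Opcodes.ASM9, cv);
        this.ifaceInternalName = ifaceInternalName;
        this.ifaceParamSlot = ifaceParamSlot;
        this.targetOwner = targetOwner;
        this.targetName = targetName;
        this.targetDesc = targetDesc;
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String desc,
                                     String signature, String[] exceptions) {
        MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
        return new MethodVisitor(Opcodes.ASM9, mv) {
            @Override
            public void visitVarInsn(int opcode, int var) {
                // Drop the load of the functional-interface parameter: the static call
                // that replaces the interface call does not need a receiver on the stack.
                if (opcode == Opcodes.ALOAD && var == ifaceParamSlot) {
                    return;
                }
                super.visitVarInsn(opcode, var);
            }

            @Override
            public void visitMethodInsn(int opcode, String owner, String methodName,
                                        String methodDesc, boolean isInterface) {
                // Replace the polymorphic interface call with a direct static call.
                if (opcode == Opcodes.INVOKEINTERFACE && owner.equals(ifaceInternalName)) {
                    super.visitMethodInsn(Opcodes.INVOKESTATIC, targetOwner, targetName, targetDesc, false);
                    return;
                }
                super.visitMethodInsn(opcode, owner, methodName, methodDesc, isInterface);
            }
        };
    }

    // The rewritten bytecode could then be produced along these lines and handed to
    // Unsafe#defineClass:
    //   ClassReader reader = new ClassReader(originalBytes);
    //   ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
    //   reader.accept(new LambdaCallRewriter(writer, ...), 0);
    //   byte[] rewritten = writer.toByteArray();
}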
P.S. With such an approach, I am thinking about loosening the requirements of the lambda code passed in the task.
For me, it is enough for a TaskX to be stateless and not use the this pointer of the callee. In such cases, we can always generate a class with a static method that will represent a given task and use it in TornadoVM.
P.S.2 In later versions, I am going to validate the passed-in tasks using ASM in order to throw more meaningful exceptions than now, when, in many cases, cryptic errors are thrown, leaving users puzzled about what is really going on.
IMHO, the GraalVM JIT can eventually be replaced by the upcoming HAT project, but it seems like it will take a while before the first version is provided by the OpenJDK team.
Follow-up questions:
only TaskX functional interfaces will be accepted as arguments.
- Does this proposal change the TornadoVM Task-Graph API?
- Can you share test cases showing what the proposal would look like?
- Does this proposal plan to change the TornadoVM Runtime component?
- Are you planning to adapt/extend the codegen (backends)?
With such an approach, I am thinking about loosening the requirements of the lambda code passed in the task.
Which requirements are you referring to?
In such cases, we can always generate a class with a static method that will represent a given task and use it in TornadoVM.
If we have the reference to the method already, why do we need to generate host code to be able to compile with TornadoVM? I did not get this part.
Hi @jjfumero
Does this proposal change the TornadoVM Task-Graph API?
No. API stays the same.
Can you share test cases showing what the proposal would look like?
I cannot provide a test case right now, but here is the conceptual usage:
Kernel:
public static void sampleReduction(KernelContext context,
FloatArray a,
FloatArray b,
BiFunction<Float, Float, Float> f) {
int globalIdx = context.globalIdx;
int localIdx = context.localIdx;
int localGroupSize = context.localGroupSizeX;
int groupID = context.groupIdx; // Expose Group ID
float[] localA = context.allocateFloatLocalArray(256);
localA[localIdx] = a.get(globalIdx);
for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (localIdx < stride) {
localA[localIdx] = f.apply(localA[localIdx], localA[localIdx + stride]); // Use of a functional interface
}
}
if (localIdx == 0) {
b.set(groupID, localA[0]);
}
}
Usage:
taskGraph.task("sampleReduction", C::sampleReduction, context, a, b, (BiFunction<Float, Float, Float>) (a1, a2) -> a1 + a2)
Does this proposal plan to change the TornadoVM Runtime component?
As I can see now, only TaskUtils#resolveMethodHandle will be changed.
Are you planning to adapt/extend the codegen (backends)?
No. But I will probably need to add support for handling primitive wrappers, we will see.
If we already have the reference to the method, why do we need to generate host code to be able to compile it with TornadoVM? I did not get this part.
AFAIK, most backends do not support polymorphic calls, so passing a lambda essentially means implicit generation of a new kernel with the passed function inlined. Kotlin works in exactly the same way, inlining passed-in lambdas and generating synthetic functions to decrease object allocations.
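To make the inlining idea concrete, here is a hypothetical sketch of the specialization that could be generated for the (a1, a2) -> a1 + a2 call site from the example above (nothing like this is generated today, and the KernelContext/FloatArray package paths may differ between TornadoVM versions):
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public final class GeneratedKernels {

    // Synthetic static method generated from the lambda (a1, a2) -> a1 + a2.
    private static float lambda$sum(float a1, float a2) {
        return a1 + a2;
    }

    // sampleReduction specialized for that lambda: the f.apply(...) call is replaced
    // by a direct static call, so the backend never sees a polymorphic call site.
    public static void sampleReduction$sum(KernelContext context, FloatArray a, FloatArray b) {
        int globalIdx = context.globalIdx;
        int localIdx = context.localIdx;
        int localGroupSize = context.localGroupSizeX;
        int groupID = context.groupIdx;
        float[] localA = context.allocateFloatLocalArray(256);
        localA[localIdx] = a.get(globalIdx);
        for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
            context.localBarrier();
            if (localIdx < stride) {
                localA[localIdx] = lambda$sum(localA[localIdx], localA[localIdx + stride]);
            }
        }
        if (localIdx == 0) {
            b.set(groupID, localA[0]);
        }
    }
}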
Only TaskX instances will be allowed.
I do not think this requirement is actually needed. I will check during the concrete implementation.
Which requirements are you referring to?
Let us skip that for a moment; from my experience, the only robust way to pass kernels as of now is to pass static methods only. For example, this code fails:
taskGraph.task("copyVector",
(source, sourceOffset, destination, destinationOffset, length) -> {
for (@Parallel int i = 0; i < length; i++) {
destination.set(destinationOffset + i, source[sourceOffset + i]);
}
}, arrayToCopy, 0, resultArray, 0, arrayToCopy.length);
The same code in a static method works like a charm, though; I need to investigate the reasons more deeply. It seems to all boil down to the handling of (un)boxing, which I suppose I will need to deal with anyway during the implementation.
In general, I want to do the following with primitive wrappers:
Allow only the following operations on wrappers in the passed-in code:
- Passing them as parameters to functions.
- Assigning them to another variable.
- Performing arithmetic operations.
- Unboxing (the xValue method that corresponds to the passed type, according to the JLS).
Those checks will be performed while resolving the method handles and will essentially allow replacing the wrappers with primitives. But I will be ready to discuss this in more detail once I am further along in the implementation.
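For illustration, here are hypothetical examples of what such checks would accept and reject under the rules above (the accepted form can be rewritten to operate purely on primitives):
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

public class WrapperRulesExamples {

    // Accepted: the wrappers are only assigned, used in arithmetic, and unboxed,
    // so every Float can be replaced by a plain float during the rewrite.
    static final BiFunction<Float, Float, Float> ACCEPTED = (a1, a2) -> {
        Float tmp = a1;                 // assignment to another variable
        return tmp + a2.floatValue();   // arithmetic plus explicit unboxing (floatValue)
    };

    // Rejected: the wrapper escapes as an object into a collection, so it cannot
    // be replaced by a primitive.
    static final BiFunction<Float, Float, Float> REJECTED = (a1, a2) -> {
        List<Float> seen = new ArrayList<>();
        seen.add(a1);
        return seen.get(0) + a2;
    };
}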
@jjfumero Did I answer your questions?
Yes. I think we have different views on implementation, which is normal. If you want to work on this, I suggest reviewing the proposal when you have a PoC, and we can iterate on this.
I think what you will see is a new parameter in the Graal IR that corresponds to your new lambda function. Then, from my view, you can use Graal to get access to that lambda. To me, the bulk of the work is in the code gen. But again, it could be just another way to implement the same functionality.
Yes. I think we have different views on implementation, which is normal. If you want to work on this, I suggest reviewing the proposal when you have a PoC, and we can iterate on this.
Sounds like a plan, thank you.