Handling of streams for new API layer
Moving over from gh-175, I think stream handling needs to be discussed separately. I had thought quite a bit about it a while ago. My take-away (very summarized) was that the CAI approach is actually the sane one. Its caveat is the stream-lifetime-guarantee question: you really need to keep the stream alive, which means attaching it to the lifetime of the exported object in some cases.
A very rough sketch of reasoning. There are two main cases for exchange:
- The kernel use case (e.g. numba, FFI). In this case the lifetime of everything is effectively "borrowed": by the time the kernel returns, all tensors (and the stream) are released.
- Long-term export by the user. In this case, stream-order guarantees should be up to the user IMO (besides at the start and end of the lifetime, likely).
  - End of lifetime is the interesting point here, because if we need the stream, then the stream lifetime has to be tied to the lifetime of the export (e.g. due to asynchronous deallocation?).
- There are probably some mixed cases, but I think it is helpful to focus mainly on these two.
There are three solutions right now that we see:
- CAI (cuda array interface): producer provides the stream.
- DLPack: consumer provides the stream.
- Arrow: Use of an event.
Of these, CAI is the only one that cleanly supports the first, kernel use case. Both others seem designed mainly for the second use case and have a severe limitation here IMO. Now, for Arrow this may actually not be a problem, because Arrow does not support mutability. As long as there is no async deallocation, the kernel cannot have the data passed in mutated before it finishes (immutable data cannot be mutated, except by deallocation, which effectively allows mutating it)!
(For DLPack, mutation is a thing and I think that rules out a single event design. There may be ways to use multiple events, but it is tricky. The single event is maybe cleaner (but possibly slower) than the DLPack design (because the DLPack design does not provide additional functionality).)
To explain why CAI is the only one, consider the following:
def library.kernel(arr):  # via protocol
    return arr + 1

arr = producer.create_array()
res1 = library.kernel(arr)
arr += 2 # mutates arr while kernel is still running
CAI allows the kernel to sync back the stream on which arr += 2 should be happening; a single event or the consumer stream does not. (Leo correctly stated that if the kernel uses xp.from_dlpack(arr + 1) to return the result or even just converts arr via dlpack a second time, a kernel could hack to enforce an order.)
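For concreteness, here is a minimal sketch of that argument, assuming CuPy as the consumer library and reading the producer-provided stream from __cuda_array_interface__ v3 (the kernel name and the + 1 work are just placeholders):

import cupy as cp

def kernel_via_cai(arr):
    # arr is any object exposing __cuda_array_interface__ v3
    stream_ptr = arr.__cuda_array_interface__.get("stream")  # producer's stream, may be None
    my_stream = cp.cuda.Stream(non_blocking=True)             # the kernel's own stream
    if stream_ptr is not None:
        producer_stream = cp.cuda.ExternalStream(stream_ptr)
        # 1) order the kernel after the producer's pending work on its stream
        my_stream.wait_event(producer_stream.record())
    with my_stream:
        res = cp.asarray(arr) + 1                              # the actual kernel work
    if stream_ptr is not None:
        # 2) sync back: make the producer's stream wait on the kernel, so a later
        #    arr += 2 enqueued by the producer cannot overtake it
        producer_stream.wait_event(my_stream.record())
    return res

With only a consumer-provided stream or a single producer event, step 2 has no producer stream to attach to, which is exactly the gap described above.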
If async deallocation is something to expect, then things might get even more complicated (it's often possible, but maybe uncommon in practice). Memory pools are also effectively asynchronous, and maybe even worse in this respect?! It is honestly a problem I had missed here!
Another approach to the stream lifetime issue could be to provide a get_stream() which returns the currently active stream being used. I am not quite sure how that would work, but if lifetime management isn't a problem it seems easier to me.
(I had written down some more problematic cases, but I think the mutating one is the interesting one. The escalation of the mutating one might be mainly the case where the "mutation" is really just the deallocation itself -- I would have to read my own stuff again to be sure I am not missing things.)
One other thing might be to actually distinguish the two use-cases more explicitly.
In the kernel use-case the consumer could just use the stream provided (maybe unless the user passes stream= explicitly making it their problem). Of course if there are multiple streams that would be a problem.
If the producer knows whether this is a long-term or just a "borrowed" export things might be a bit easier. (Up to refusing a long term export without doing a copy.)
For those of us with terrible memories, could you spell out or link to CAI?
Sorry yeah, always confused me too until I looked at it more often: the cuda array interface (version 3: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html).
Of these, CAI is the only one that cleanly supports the first, kernel use case. Both others seem designed mainly for the second use case and have a severe limitation here IMO. Now, for Arrow this may actually not be a problem, because Arrow does not support mutability. As long as there is no async deallocation, the kernel cannot have the data passed in mutated before it finishes (immutable data cannot be mutated, except by deallocation, which effectively allows mutating it)! (For DLPack, mutation is a thing and I think that rules out a single-event design. There may be ways to use multiple events, but it is tricky. The single event is maybe cleaner (but possibly slower) than the DLPack design (because the DLPack design does not provide additional functionality).)
Is the issue that there's no way to provide a Stream / Event to the dlpack deleter function, and so there's no way to enforce a stream-ordered deallocation if you mutate the data on a different stream than the one on which the data was shared? I agree that is a large performance problem and we should try to address it if possible. You can record to an event multiple times, so if we wanted to use events, we could make the semantics that anyone mutating the data is required to record their work to the producer-provided event, and the deleter is then required to properly wait / synchronize on the event as part of the deallocation path.
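A rough Python sketch of those semantics, assuming CuPy for the event/stream handles (the ExportedBuffer wrapper and its method names are made up for illustration):

import cupy as cp

class ExportedBuffer:
    # one producer-provided event, re-recorded by anyone who touches the data,
    # and a deleter that waits on it before actually releasing the memory
    def __init__(self, data, producer_stream):
        self.data = data
        self.event = cp.cuda.Event(disable_timing=True)
        self.event.record(producer_stream)        # covers the producer's own pending work

    def consumer_touch(self, consumer_stream):
        with consumer_stream:
            self.data += 1                         # consumer mutates on its own stream ...
        self.event.record(consumer_stream)         # ... and must re-record the shared event

    def deleter(self, dealloc_stream):
        # stream-ordered deallocation: order the free after the last recorded work
        dealloc_stream.wait_event(self.event)
        self.data = None                           # e.g. returned to a stream-ordered pool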
To explain why CAI is the only one, consider the following:
def library.kernel(arr):  # via protocol
    return arr + 1

arr = producer.create_array()
res1 = library.kernel(arr)
arr += 2  # mutates arr while kernel is still running

CAI allows the kernel to sync back the stream on which arr += 2 should be happening; a single event or the consumer stream does not. (Leo correctly stated that if the kernel uses xp.from_dlpack(arr + 1) to return the result or even just converts arr via dlpack a second time, a kernel could hack to enforce an order.)
In your example here, I think it makes the assumption that both library.kernel(...) and arr += 2 would run on the stream of arr, and I think that's a bad assumption. What if we have two arrays and something that looks like:
def library.kernel(arr1, arr2):  # via protocol
    return arr1 + arr2

arr1 = producer.create_array(stream=X)
arr2 = producer.create_array(stream=Y)
res1 = library.kernel(arr1, arr2) # Should this run on Stream X, Y, or an entirely different stream Z?
arr1 += 2
arr2 += 2
In this case library.kernel would need to either synchronize Streams X and Y in some way with each other, which is probably unexpected behavior, or use its own Stream Z and synchronize it with respect to Streams X and Y, which I think is the more common case. In this situation, implementation-wise, you'd probably want to create two events, record them on Streams X and Y respectively, and then wait on the two events in Stream Z.
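For concreteness, a minimal CuPy sketch of that two-event pattern (the arrays and the arithmetic are placeholders standing in for the producer/library operations above):

import cupy as cp

stream_x = cp.cuda.Stream(non_blocking=True)  # producer stream for arr1
stream_y = cp.cuda.Stream(non_blocking=True)  # producer stream for arr2
stream_z = cp.cuda.Stream(non_blocking=True)  # library.kernel's own stream

with stream_x:
    arr1 = cp.arange(1024)                    # stand-in for producer.create_array(stream=X)
with stream_y:
    arr2 = cp.arange(1024)                    # stand-in for producer.create_array(stream=Y)

# record an event on each producer stream and wait on both in Stream Z
stream_z.wait_event(stream_x.record())
stream_z.wait_event(stream_y.record())
with stream_z:
    res1 = arr1 + arr2                        # library.kernel, now ordered after X and Y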
This is a rudimentary example, and in it you could argue that a user should manage the synchronization across streams in some way, but this pattern persists and has become quite common today with CUDA Graphs and especially graph compilers. It's not feasible for a user to manage this directly. Additionally, with CUDA Stream Capture for CUDA Graphs, you are actually required to use events as the synchronization primitive, and calling cudaStreamSynchronize is not allowed.
Right, I should have been clear that I thought library.kernel() may be using its own stream and that complicates correct stream order especially after the call. Adding the streams (and maybe two arrays) makes that clearer!
you could argue that a user should manage the synchronization across streams in some kind of way
Which honestly may yet be a reasonable approach. If our need is mainly kernels that can use the stream from the inputs, the kernel could ask the user to provide a stream explicitly if they mismatch. So long as this is explicitly required, I am pretty comfortable with saying that it is the user's problem.
You can record to an event multiple times, so if we wanted to use events, we could make the semantics that anyone mutating the data is required to record their work
This may actually be a nice solution (compared to a stream, the downside is mainly that streams are fast when the library is happy to just use them).
You would have to update the event for the kernel basically when you are done with the data (for long-term export that may be only at freeing time). I had some weird reason for avoiding it, but I suspect it wasn't meaningful as such. I think the reason was that if you have:
arr = producer.create_array(stream=X)
view = consumer.from_dlpack(arr)
library1.kernel(view) # also goes via dlpack
library2.kernel(view)
del arr, view # or do things
then consumer (i.e. view) has to create a new event on some stream and propagate that back to arr (and producer) when view is deleted.
But that isn't a problem as such, every time we create a view we must create a new event, that is all.
(I had thought of it in a slightly different context where I thought it is neat if creating a new view can just pass on a view of itself with refcounting basically.)
I think there are two kinds of design patterns in libraries we see so far:
- D0: Stream life-cycle management is tied to array life-cycle management, something discussed in this post
- D1: Stream life-cycle is managed independently of the arrays themselves, usually through a context manager; the user manages stream synchronization manually and ensures the del lifecycle (through a record_stream call). PyTorch is an example:
def my_lib_foo(x, out):
    x_view = library.from_dlpack(x)
    out_view = library.from_dlpack(out)
    # can be done through optional_out_env_stream
    stream = query_torch_stream(x.device)
    library.kernel(x_view, out_view, stream)

x = torch.empty(shape, dtype=dtype, device="cuda:0")
stream = torch.cuda.Stream()
# explicitly insert a dependency on the default stream
stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(stream):
    # all kernels should now run on stream
    out = torch.empty(shape, dtype=dtype, device="cuda:0")
    my_lib_foo(x, out)
# ensures that x does not get deallocated before the stream's compute finishes
# (or use wait_stream to sync back)
x.record_stream(stream)
Compared to the "kernel" use case mentioned in the initial post, the difference in D1 is that the stream is passed as a view (and not borrowed), whose liveness is only guaranteed until the surrounding torch stream context switches.
D1 is the lower-level approach, since such management and correctness fall onto the shoulders of users to use the stream API and context correctly. While it is debatable whether D0 or D1 is better, from DLPack's POV there is a need to take the common denominator, which means considering the lower-level use case (D1) for now and separating stream ownership from the tensor data structure. In that case, we acknowledge that stream-related life-cycle management is hard, and we do not attempt to extend the stream lifecycle beyond the overall local scope of the my_lib_foo-related compute for now.
In the meantime, I agree it is good to also separately think about the design space of stream ownership that goes beyond the local scope, which does not have to be immediately tied to the local proposal (as long as they are compatible).
I agree that frameworks / libraries do not wish to tie stream lifetimes to dlpack and that that probably doesn't make sense. Using stream context managers is a good example of why that doesn't make sense. This is why we originally used consumer provided streams for the Python protocol.
What I am now proposing is that we should use producer provided events as a synchronization mechanism. In your D1 pattern above, while that works for PyTorch, CuPy, and other libraries that expose the ability to interact with CUDA streams, other frameworks, like JAX / Tensorflow (I believe...) do not expose CUDA streams directly, but use CUDA streams internally. I believe historical conversations indicated it wasn't possible for them to tie the lifetime of these internal CUDA streams into some kind of dlpack deleter, so in this case what would be the correct behavior? Should JAX / Tensorflow pass an internally used stream through the optional_out_env_stream even though there's no ability for the dlpack deleter or the user to control its lifetime?
It's still a bit unclear to me what we can and cannot do (not that we will):
- The proposed dlpack_from_py_object(self, void **stream) API to pass out a stream for self that is temporarily valid is OK for everyone, including JAX?
- Is guaranteeing the lifetime of the stream really so bad? (Maybe I shouldn't ask; it seems likely one can avoid it either way. I agree it is weird, but to me it seemed like a "you are on your own" thing if you export a view beyond with-statement boundaries.)
- I am not clear whether events have a general advantage over streams beyond avoiding lifetime issues?
  - Events are more precise, but is there an important scenario where waiting for the event rather than a stream is very helpful? (I realize that without a StreamWaitStream there will be an event involved either way, but if any producer always creates the event via RecordEvent(stream) at export time, then I think they might as well pass the stream -- without a lifetime guarantee.)
For kernels, the stream lifetime being temporary is no problem: the stream is guaranteed to be valid until the lifetime of the view ends. If you were to expose that to Python you would have to use a context manager, something like:
def my_kernel_written_in_cupy_but_accepts_all_dlpack_arrays(arr):
    with cupy.from_dlpack_ctx(arr) as view:
        ...  # work with view
Unless the user has a view.get_current_stream() and can sync manually, I think a with is conceptually the right thing in Python either way.
(That is, without it the user would need @tqchen's query_torch_stream, but it would have to be available via the view or the exchange API, while with a context you only need what is currently proposed.)
For long-term export (without a context), I think it is totally fine to tell users that it's all up to them. I would say that x = cupy.from_dlpack(y) must make sure that x is synchronized to start with (i.e. in the case where y is immutable we are golden). But that is the only guarantee we should care about.
The problem then is just deletion, I believe. Maybe that could be solved by enabling explicit synchronization in the deleter when needed, e.g. via dltensor->delete_with_stream(stream) (or an event). That means the producer can do what is necessary to ensure the right lifetime of the memory (in the simplest case, they synchronize the stream or don't use any async deallocation) -- there should be a way to pass something special to signal that deletion is OK at any time.
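As a hedged illustration only (delete_with_stream is not an existing DLPack API; the Python below, using CuPy for the stream/event handles, just mimics what a producer's deleter could do with such a hook):

import cupy as cp

class Export:
    # hypothetical producer-side wrapper around an exported allocation
    def __init__(self, data, dealloc_stream):
        self.data = data
        self.dealloc_stream = dealloc_stream       # the stream the producer frees on

    def delete_with_stream(self, consumer_stream=None):
        if consumer_stream is None:
            # special value: the consumer signals deletion is OK at any time
            self.data = None
            return
        # order the (asynchronous) free after the consumer's outstanding work,
        # instead of forcing a full device synchronization
        self.dealloc_stream.wait_event(consumer_stream.record())
        self.data = None                            # e.g. returned to a stream-ordered pool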
Thanks for the further discussions
Should JAX / Tensorflow pass an internally used stream through the optional_out_env_stream even though there's no ability for the dlpack deleter or the user to control its lifetime
Here is my understanding of the overall state:
For systems like JAX/XLA, a more canonical way is perhaps to have a "graph callback op" that is called during graph execution, at which point it should expose its internal stream; the same API could still be used for handling such a graph callback at the C layer related to the XLA FFI (and not interact with Python). The main reason is that JAX is not eagerly executing (based on my understanding), so we need to bring the op to the graph executor, at which point the stream is used.
Then, in the specific case of the callback context, assuming it is implemented via a Python C extension, passing the internal stream via optional_out_env_stream still makes sense, and the liveness of the stream is only guaranteed during the period of the callback.
In other cases, when the graph is not executing (and only being constructed), passing the default stream and syncing to that is likely still the safest option.
Likely the C extension can be used in the former case, or we can detect the context and go for the safe option.
My understanding is that the current state of JAX / XLA is that when trying to export via DLPack it effectively triggers execution and blocks until the data being exported is ready.
What you described as far as pulling it into the XLA graph sounds more optimal, but I think it's an unrealistic expectation for users and even other library developers to work with that constraint. I.E. if I'm writing a library primarily with CuPy using streams and want to be able to use an arbitrary user provided object that supports DLPack, how could this work with JAX? It sounds like the options would be:
- JAX is required to make sure the data is synchronized before returning the dlpack tensor
- JAX passes the pointer to its internally managed stream via the output stream pointer, but there's 0 guarantees about lifetime and we could end up with a dangling pointer
- JAX creates a separate "external" stream and guarantees the data is ordered on that stream and exposes that "external" stream via some kind of API to allow user control of its lifetime outside of dlpack
- I write my CuPy-based library to fit into an XLA graph paradigm instead of using the standard CUDA streams / events / graphs paradigm
None of these really feel like viable options to me outside of the status quo where JAX synchronizes before returning the dlpack tensor.
I feel we likely need to go for the conservative recommendation (maintaining the status quo):
- JAX makes sure the data is synchronized to a known stream (the default stream)
- JAX passes the default stream via out_env_stream
This holds until JAX provides some form of "callback option", which may be able to do more in that context; JAX may then be able to pass its execution stream as out_env_stream, with that stream's liveness only guaranteed locally (i.e. during the lifetime of that callback).
Why not use an event instead of a stream then? The event can be properly recorded whether in an XLA or CUDA Graph, can be gracefully used with streams, or can be synchronized to with respect to host code. As far as I can tell it handles all of the situations we care about.
The downside is that we add some complexity to the dlpack protocol implementation in that pushing Event lifetime management to the user isn't really feasible / desirable, so we'd probably need to define the semantics and incorporate it into the deleter function?
On events, you are right that one issue is the implementation complexity; indeed, the deleter lifecycle would be one issue (whereas in the case of torch it would be added complexity).
From the user-facing API POV, it seems the most common APIs are still streamlined through streams, e.g. the canonical torch example of CUDA graph capture, so preserving streams should be the common case we would like to optimize for.
s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        mylib_tensor = mylib.from_dlpack(x)
        mylib_kernel(mylib_tensor)
So going through events (and departing from streams) seems to make it less straightforward to work with such a canonical use case, in an effort to retrofit the other cases.
I am also not 100% sure how exactly XLA might implement the event-based recording in such a case, or what possible speedup we could get through it (versus a default stream sync).
Since the speed-of-light (SoL, the best possible performance) pattern is still to capture everything as a CUDA Graph when possible and execute it, ideally we should instead try to work with XLA to offer some form of Python callback regions where its execution stream can be passed through, and maybe for JAX to optimally rewrite part of the code into such a callback (as a C++ extension or in Python), so the JAX path can also have such SoL execution (through streams).
@tqchen Unfortunately, CUDA leaves us a few different options as far as how operations can be enqueued:
A. Directly via a Stream
B. Directly via a CUDA Graph, which is then launched on a Stream
C. Indirectly via a CUDA Graph via Stream Capture, which is then launched on a Stream
If we want to handle passing data around that allows for leveraging each of these mechanisms it is difficult and we have different tradeoffs available to us.
For Option A:
- If we pass around Streams via dlpack tensors, it gives us the choice of either scheduling mylib's operations on the passed stream, or mylib recording an Event on the passed stream and then waiting on that event in its own stream.
  - Pros: no event creation / destruction overhead, more flexibility in where work is scheduled
  - Cons: risk of dangling Stream pointer, risk of libraries stepping on each other's toes in work Stream scheduling
- If we pass around Events via dlpack tensors, it requires mylib to wait on that event in its own stream.
  - Pros: no risk of dangling pointers, no risk of libraries stepping on each other's toes in work Stream scheduling
  - Cons: event creation / destruction overhead, less flexibility in where work is scheduled
For Option B:
- If we pass around Streams via dlpack tensors, we can't schedule mylib's operations into a Graph without a user explicitly handling the graph across library boundaries. There are no streams internal to a graph to allow importing / exporting data to a graph asynchronously without using events.
  - Pros: no risk of libraries stepping on each other's toes in work Graph scheduling
  - Cons: risk of dangling Stream pointer, can't schedule work onto a graph across libraries, can't import / export data to a graph
- If we pass around Events via dlpack tensors, we can't schedule mylib's operations into a Graph without a user explicitly handling the graph across library boundaries. There are CUDA event nodes as part of the CUDA Graph APIs, with which you can both import and export data to a graph asynchronously and utilize events to control ordering properly.
  - Pros: can import / export data to a graph, no risk of libraries stepping on each other's toes in work Graph scheduling
  - Cons: can't schedule work onto a graph across libraries
For Option C (matches your above example, assuming stream s is used for stream capture onto graph g):
- If we pass around Streams via dlpack tensors, it gives us the choice of scheduling mylib's operations on the passed capturing stream.
  - Pros: can schedule work on the same capturing stream across libraries
  - Cons: risk of dangling Stream pointer, risk of libraries stepping on each other's toes in work Graph scheduling
- If we pass around Events via dlpack tensors, my_lib can't schedule work on the capturing stream or, if a graph is being captured to, on the capturing graph.
  - Pros: no risk of libraries stepping on each other's toes in work Graph scheduling
  - Cons: can't schedule work onto a graph across libraries
Based on writing this up, Events are definitely less flexible than Streams, but the cons of using Streams are very non-trivial, particularly the risk of dangling Stream pointers. Additionally, neither Streams nor Events allow us to fully utilize CUDA Graphs. If we ignore CUDA Graphs, then I would argue that the pros of using Events outweigh the cons.
In your example above, I would argue that mylib scheduling its work on the torch stream / graph is an implementation detail of mylib and will not universally hold true. I.E. say I had something that looked like:
x = torch.randn(8, device="cuda")
with torch.cuda.stream():
    x += 1
    mylib.some_kernel(x)  # Do we expect mylib to run on the pytorch stream via it being passed through the dlpack tensor of x?

x2 = mylib.rand_tensor()
with my_lib.cuda_stream():  # Assume this is doing stream capture on s2
    x2 += 1
    torch.some_function(x2)  # Do we expect torch to run on the mylib stream via it being passed through the dlpack tensor of x2?

with torch.cuda.stream():
    torch.some_function(x2)  # What about now?

with new_lib.cuda_stream():
    new_lib.some_kernel(x, x2)  # What do we expect to happen here? What would be in the dlpack tensors of x and x2?
In the above example, things get confusing as far as what work is scheduled where if we're passing streams around. If we're passing events around, the work scheduling semantics are clear but it makes context managers not as likely to work across library boundaries. Should this be a goal of dlpack?
Thanks @kkraus14, I think the main spirit is that Option C is kind of the common use case that we would like to support.
In the above example, we kind of depend on the user to make use of the torch stream. Here are example APIs that I can come up with (which I think are also still common in today's DLPack usage):
A0: Manual stream context syncing
x = torch.randn(8, device="cuda")
with torch.cuda.stream():
    x += 1
    # sync my lib stream with torch stream
    with mylib.use_torch_stream():
        mylib.some_kernel(mylib.from_dlpack(x))
    # alternatively come with explicit stream passing
    mylib.some_kernel(mylib.from_dlpack(x), stream=torch.cuda.current_stream())
This is actually the working mechanism of many libraries today. In the most common execution patterns, we usually have one framework (torch) that handles the primary stream context, and a library (mylib) that invokes kernels which interop with framework operations. Note that in most cases there is no stream ownership being passed around; instead, we use the current stream as a temporary thing.
A1: Auto Stream context passing
With the fast C API proposal, we can actually do something more organic here for such a case (written in Python, but it works at the C level):
def mylib_generic_kernel(x):
    data = mylib.from_dlpack(x)
    # queried via DLPackPyObjectToManagedTensor function (opt_stream_out)
    stream = x.__env_stream__()
    with mylib.use_stream(stream):
        call_kernel(data)

x = torch.randn(8, device="cuda")
with torch.cuda.stream():
    x += 1
    # mylib_kernel now automatically uses the torch stream
    mylib_generic_kernel(x)

y = new_lib.tensor()
with new_lib.cuda_stream():
    # also works here, uses the newlib stream
    mylib_generic_kernel(y)
Note that this does require mylib to structure its kernels in a certain way (mylib_generic_kernel), and DLPack only plays part of the picture here (one caveat is that torch does not structure kernels in this way, so the other direction still requires manual switching). But ideally DLPack should facilitate such a pattern, which is a common and useful one.
Thanks @kkraus14, I think the main spirit is that Option C is kind of the common use case that we would like to support.
I originally come from the DataFrame world, and my work generally brings me to places with data-dependent shapes / sizes, where I see Option A much more commonly than Option C, which I believe is more common than Option B.
For your above examples, A0 only works with tightly coupled libraries. I.E. mylib needs to speak torch for it to work. If you instead wanted to use the stream from cupy, or cuda.core, or some_lib_xyz, it presumably wouldn't work and it isn't feasible for every library to speak to every other library. This is additionally a Python only solution and doesn't suffice for C++ codebases and presumably performance critical code. We could possibly use and extend the __cuda_stream__ protocol (https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol) in this direction, but we still ultimately want to solve this problem in C/C++.
For A1, would __env_stream__ be a new protocol that we standardize? If so, this sounds like a promising direction from my perspective, especially if we could generalize it to allow passing around Streams, Graphs (and needed node dependency information), or Events.
Thanks @kkraus14 , I agree A1 would be the more preferred ideal flow.
__env_stream__ is something that already resembles part of the DLPackManagedTensorFromPyObject proposal; it is effectively optional_out_env_stream. I agree that perhaps we could also try to add this protocol on the Python end.
Updated after some discussions, I think it is clear there are generally two kinds of consumer needs for DLPack exchange:
- N0: library support, where consumer.kernel(x, y, z) would like to run a kernel with the data from x, y, z. The consumer is also expected to run the kernel in the same stream context as the producer. For example, when x, y, z are torch.Tensors, consumer.kernel should run on torch.cuda.current_stream(device) to enable synchronization-free kernel launch and maximum compatibility with CUDA graph capture in the producer. This is the desirable behavior for library extension support for frameworks like PyTorch.
- N1: data ingestion and retention. In such a case, the consumer is interested in obtaining the data from the producer and running further computation on its own stream. The consumer would like to record a dependency on the producer and likely wants to use its own stream.
Different libraries/frameworks may have different needs. Note that N0 is currently the most important canonical use case of DLPack for torch extensions, while N1 can be useful for some data-processing needs.
We have updated the proposal with two distinct APIs:
- DLPackManagedTensorFromPyObject now optionally allows querying for out_last_active_stream, so consumers who want N1 behavior can record an event dependency on it or sync to it
- A new API, DLPackCurrentWorkStream, is now optionally available for the needs of N0 (a rough consumer-side sketch follows below)
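To illustrate the two consumer behaviors concretely, here is a rough sketch with torch as the producer and CuPy as the consumer; the direct torch stream query stands in for what DLPackCurrentWorkStream / out_last_active_stream would provide, it is not the actual C API:

import cupy as cp
import torch

s = torch.cuda.Stream()                    # the producer's work stream
with torch.cuda.stream(s):
    x = torch.randn(8, device="cuda")

# N0-style consumer: run in the same stream context as the producer
# (standing in for a DLPackCurrentWorkStream query).
producer_stream = cp.cuda.ExternalStream(s.cuda_stream)
with producer_stream:
    y = cp.from_dlpack(x) + 1              # launched on the producer's stream, no extra sync

# N1-style consumer: record a dependency on the producer's last active stream,
# then retain the data and keep computing on the consumer's own stream
# (standing in for what out_last_active_stream enables).
my_stream = cp.cuda.Stream(non_blocking=True)
my_stream.wait_event(producer_stream.record())
with my_stream:
    data = cp.from_dlpack(x)               # retained for long-term use
    z = data * 2                           # all further work stays on my_stream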
I think it should resolve the issues brought up so far. Please check out the updated proposal here: https://github.com/dmlc/dlpack/issues/175#issue-3411508155; the PR is also now updated to reflect the new proposal.
That sounds reasonable to me (the prev_version change over there also seems good; as long as there is a solution all is good, and this also versions the slots).
I am a bit unsure whether the long-term querying (i.e. DLPackCurrentWorkStream) is particularly important, since it is also available via re-export, but it is simple and good.
Does this mean that we consider "long term" export safety w.r.t. asynchronous deallocation a user problem, or would there be a requirement to synchronize back using N1:
arr = producer.create_array(stream=stream1)
view = cp.from_dlpack(arr) # uses stream2
del arr
res = view + 2
del view
producer.do_more_work(stream1) # in theory could re-use `view` memory
To be clear, I am happy with either (i.e. if the producer may deallocate asynchronously, then it is the user's problem to deal with this here -- presumably most producers only do that when users do manual stream stuff, and then it is OK if things are up to them).
I think it is harder to guarantee "long term" safety on a stream in general, so for the case of N0, very likely we only launch on the DLPackCurrentWorkStream and rerun the query every time.
For the case of N1, we likely need to strictly sync back, so recording an event and running on a different consumer stream would be better for follow-up long-term jobs outside the specific exchange consumer context.
Sending a follow-up note: we have updated the proposal to focus on the use cases of N0; please see the latest updates here: https://github.com/dmlc/dlpack/issues/175