RustaCUDA
Figure out a way to make the context API truly safe
Maybe someone can find a way to make the Context API safer than it is. I haven't been able to think of anything so far.
My advice would be to check out some of the CUDA wrapper implementations for other languages. PyCUDA seems to be the most popular non-C/C++ binding set, and I'm sure whatever the Haskell binding is doing will be too safe.
Plus, PyCUDA has some really nifty wrappers that simplify the API and could be used for inspiration.
Hello,
I have an idea for how to make a better Context API.
This is inspired by the Python keyword `with`.
with_context!(ctx, {
    // Code here uses the context ...
});
This macro would expand to:
{
    // Push the context onto the thread-local context stack.
    let ret_val = evaluate_given_code_block();
    // Pop the context off the thread-local context stack.
    ret_val
}
Inside of the block, the user can use the global context. The context will always be cleaned up at the end.
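To make the proposed expansion concrete, here is a minimal sketch of such a macro. It does not touch CUDA at all: `MockContext`, `CONTEXT_STACK`, `PopOnDrop` and `current_context` are hypothetical stand-ins invented for this example, and the thread-local `Vec` only imitates the driver's per-thread context stack. The pop is done by a small guard value rather than a trailing statement, so it also runs if the block returns early or panics.

```rust
use std::cell::RefCell;

// Hypothetical stand-in for a CUDA context handle; a real implementation
// would wrap the driver's context type instead.
#[derive(Clone, Debug, PartialEq)]
pub struct MockContext(pub u32);

thread_local! {
    // Stand-in for the per-thread context stack kept by the CUDA driver.
    static CONTEXT_STACK: RefCell<Vec<MockContext>> = RefCell::new(Vec::new());
}

/// Pops the context pushed by `enter` when dropped, including during unwinding.
pub struct PopOnDrop;

impl PopOnDrop {
    /// Pushes `ctx` onto the thread-local stack and returns the guard.
    pub fn enter(ctx: &MockContext) -> PopOnDrop {
        CONTEXT_STACK.with(|s| s.borrow_mut().push(ctx.clone()));
        PopOnDrop
    }
}

impl Drop for PopOnDrop {
    fn drop(&mut self) {
        CONTEXT_STACK.with(|s| {
            s.borrow_mut().pop();
        });
    }
}

/// Returns the context on top of the thread-local stack, if any.
pub fn current_context() -> Option<MockContext> {
    CONTEXT_STACK.with(|s| s.borrow().last().cloned())
}

// The macro only refers to public items, since it expands in the caller's crate.
macro_rules! with_context {
    ($ctx:expr, $body:block) => {{
        let _guard = PopOnDrop::enter(&$ctx);
        $body
    }};
}

/// Nests two blocks and reports which context each one saw,
/// plus what is current after both blocks have ended.
pub fn demo() -> (Option<MockContext>, Option<MockContext>, Option<MockContext>) {
    let outer = MockContext(1);
    let inner = MockContext(2);
    let (seen_outer, seen_inner) = with_context!(outer, {
        let seen_inner = with_context!(inner, { current_context() });
        (current_context(), seen_inner)
    });
    (seen_outer, seen_inner, current_context())
}
```

Calling `demo()` from `main` shows that the inner block sees the inner context, the outer block sees the outer one again once the inner block ends, and nothing is current afterwards.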
The same technique also applies to `Device`s.
Another thing that could help manage contexts: CUDA keeps a ref-count for every context, exposed through `cuCtxAttach` and `cuCtxDetach`.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#context
Could using this ref-count make your `UnownedContext` obsolete?
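As a rough illustration of the ref-count idea, the sketch below models it with `Arc` instead of calling the driver: cloning a handle plays the role of "attach", dropping one plays the role of "detach", and the underlying context is released only when the last handle goes away. `RawContext`, `SharedContext` and `DESTROYED` are hypothetical names made up for this example; whether the driver's own count could be used this way from RustaCUDA is an open question, not something this code demonstrates.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Counts how many mock contexts have been released, to make the sketch observable.
pub static DESTROYED: AtomicUsize = AtomicUsize::new(0);

// Hypothetical stand-in for the raw driver context.
struct RawContext;

impl Drop for RawContext {
    fn drop(&mut self) {
        // A real wrapper would release the driver context here.
        DESTROYED.fetch_add(1, Ordering::SeqCst);
    }
}

/// A shared handle: cloning acts as "attach", dropping acts as "detach".
#[derive(Clone)]
pub struct SharedContext(Arc<RawContext>);

impl SharedContext {
    pub fn new() -> Self {
        SharedContext(Arc::new(RawContext))
    }

    /// Number of handles currently referring to this context.
    pub fn ref_count(&self) -> usize {
        Arc::strong_count(&self.0)
    }
}

/// Returns (count while two handles exist,
///          contexts released after the first handle is dropped,
///          contexts released after the second handle is dropped).
pub fn demo_ref_count() -> (usize, usize, usize) {
    let before = DESTROYED.load(Ordering::SeqCst);
    let owner = SharedContext::new();
    let second = owner.clone();
    let while_shared = owner.ref_count();
    drop(owner);
    let after_first = DESTROYED.load(Ordering::SeqCst) - before;
    drop(second);
    let after_second = DESTROYED.load(Ordering::SeqCst) - before;
    (while_shared, after_first, after_second)
}
```

With this shape there is no separate "unowned" type: every handle is an owner, and the context outlives whichever handle happens to be dropped first.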
I like what ctrl-z is thinking with `with_context!`, but a macro can only expand to calls to public methods. What about using a special guard object to do the pre/post actions?
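use std::ops::Deref;
use std::sync::{Arc, LockResult, Mutex, MutexGuard, PoisonError};
use rustacuda::context::{Context, ContextStack, CurrentContext};
use rustacuda::error::CudaResult;
use rustacuda::memory::{DeviceBuffer, DeviceSlice};

/// Wrapper so that `Context` can implement `Send`.
#[doc(hidden)]
pub struct SendContext(Context);
// Assumption: moving the context between threads is sound as long as every access goes through the `Mutex`.
unsafe impl Send for SendContext {}
impl SendContext {
    fn as_context(&self) -> &Context {
        &self.0
    }
}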
pub struct SyncContext {
    context: Mutex<SendContext>,
}
impl SyncContext {
    // Maybe have our own lock result that includes the CUDA result, to avoid panicking?
    pub fn lock(&self) -> LockResult<ContextGuard<'_>> {
        match self.context.lock() {
            Ok(guard) => {
                CurrentContext::set_current(guard.as_context()).unwrap();
                Ok(ContextGuard { context: guard })
            }
            Err(poison) => {
                // The mutex was poisoned, but the context itself is still usable.
                let guard = poison.into_inner();
                CurrentContext::set_current(guard.as_context()).unwrap();
                Err(PoisonError::new(ContextGuard { context: guard }))
            }
        }
    }
}
pub struct ContextGuard<'a> {
    context: MutexGuard<'a, SendContext>,
}
impl<'a> Drop for ContextGuard<'a> {
    fn drop(&mut self) {
        let current = CurrentContext::get_current().unwrap();
        // Sketch: assumes the current context handle can be compared with the one we hold.
        if current == *self.context.as_context() {
            ContextStack::pop().unwrap();
        } else {
            // either panic or do nothing
        }
    }
}
impl<'a> Deref for ContextGuard<'a> {
    type Target = Context;
    fn deref(&self) -> &Context {
        self.context.as_context()
    }
}
struct SyncDeviceBuffer<T> {
    buffer: DeviceBuffer<T>,
    context: Arc<SyncContext>,
}
impl<T> SyncDeviceBuffer<T> {
    pub unsafe fn uninitialized(context: Arc<SyncContext>, size: usize) -> CudaResult<Self> {
        let buffer = {
            // Hold the guard only while allocating; recover it even if the mutex was poisoned.
            let _c = context.lock().unwrap_or_else(PoisonError::into_inner);
            DeviceBuffer::uninitialized(size)?
        };
        Ok(Self { buffer, context })
    }
}
impl<T> Deref for SyncDeviceBuffer<T> {
    type Target = DeviceSlice<T>;
    fn deref(&self) -> &Self::Target {
        &*self.buffer
    }
}
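For anyone who wants to try the guard pattern without a GPU, here is a small self-contained sketch of the same shape. It is only a model: `FakeContext`, `SharedFakeContext`, `FakeGuard`, `STACK` and `current_id` are hypothetical names for this example, and a thread-local `Vec<u32>` stands in for the driver's context stack. Unlike the proposal above it pushes on lock rather than setting the current context, purely so that the pop in `Drop` is easy to follow.

```rust
use std::cell::RefCell;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

thread_local! {
    // Stand-in for the driver's per-thread context stack; holds context ids.
    static STACK: RefCell<Vec<u32>> = RefCell::new(Vec::new());
}

// Hypothetical stand-in for a CUDA context.
pub struct FakeContext {
    pub id: u32,
}

/// A context that can be shared between threads behind a mutex.
pub struct SharedFakeContext {
    inner: Mutex<FakeContext>,
}

/// While this guard lives, the context is locked and current on this thread.
pub struct FakeGuard<'a> {
    ctx: MutexGuard<'a, FakeContext>,
}

impl SharedFakeContext {
    pub fn new(id: u32) -> Self {
        SharedFakeContext { inner: Mutex::new(FakeContext { id }) }
    }

    /// Locks the context and makes it current on the calling thread.
    pub fn lock(&self) -> FakeGuard<'_> {
        // Recover the inner guard if another thread panicked while holding the lock.
        let ctx = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
        STACK.with(|s| s.borrow_mut().push(ctx.id));
        FakeGuard { ctx }
    }
}

impl Drop for FakeGuard<'_> {
    fn drop(&mut self) {
        STACK.with(|s| {
            let mut stack = s.borrow_mut();
            // Only pop if our context is still on top of the stack.
            if stack.last() == Some(&self.ctx.id) {
                stack.pop();
            }
        });
    }
}

/// Id of the context that is current on the calling thread, if any.
pub fn current_id() -> Option<u32> {
    STACK.with(|s| s.borrow().last().copied())
}

/// Locks the shared context on a worker thread and reports what was current
/// there, and what is current on the calling thread afterwards.
pub fn demo_guard() -> (Option<u32>, Option<u32>) {
    let shared = Arc::new(SharedFakeContext::new(7));
    let worker = {
        let shared = Arc::clone(&shared);
        std::thread::spawn(move || {
            let _guard = shared.lock();
            current_id()
        })
    };
    let in_worker = worker.join().unwrap();
    (in_worker, current_id())
}
```

The worker thread sees the context as current only while it holds the guard, and the calling thread's stack is never touched.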