pytorch_open_registration_example
Fails to build and run with PyTorch nightly 2.4
Several issues:
- The `const` qualifier on `allocate` has changed, and the `copy_data` override from the `Allocator` class is missing.
- `torch.register_privateuse1_backend` -> `torch.utils.rename_privateuse1_backend`
However, even with all of these changes, I still get a failure at runtime:
Traceback (most recent call last):
File "/tmp/pytorch_open_registration_example/open_registration_example.py", line 102, in <module>
x = torch.ones(4, 4, device='foo:2')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'torch.foo'
My modifications:
diff --git a/cpp_extensions/open_registration_extension.cpp b/cpp_extensions/open_registration_extension.cpp
index b666253..89d4486 100644
--- a/cpp_extensions/open_registration_extension.cpp
+++ b/cpp_extensions/open_registration_extension.cpp
@@ -43,7 +43,7 @@ at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other,
// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
DummyCustomAllocator() = default;
- at::DataPtr allocate(size_t nbytes) const override {
+ at::DataPtr allocate(size_t nbytes) override {
std::cout << "Custom allocator's allocate() called!" << std::endl;
void* data = c10::alloc_cpu(nbytes);
return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
@@ -57,6 +57,11 @@ struct DummyCustomAllocator final : at::Allocator {
c10::free_cpu(ptr);
}
+ void copy_data(void* dest, const void* src, std::size_t count) const override
+ {
+ default_copy_data(dest,src,count);
+ }
+
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
diff --git a/open_registration_example.py b/open_registration_example.py
index 9b31d69..1c1e048 100644
--- a/open_registration_example.py
+++ b/open_registration_example.py
@@ -88,7 +88,7 @@ def test(x, y):
# Option 1: Use torch.register_privateuse1_backend("foo"), which will allow
# "foo" as a device string to work seamlessly with pytorch's API's.
# You may need a more recent nightly of PyTorch for this.
-torch.register_privateuse1_backend('foo')
+torch.utils.rename_privateuse1_backend('foo')
# Show that in general, passing in a custom device string will fail.
try:
``