Update codegen for in-place foreach to return `List[Tensor]`
Fixes #104817
Examples of generated in-place foreach functions -- add and addcmul
// Autograd kernel for the in-place foreach op `_foreach_add_.Scalar`.
// Unlike regular in-place ops, foreach ops build one grad_fn per list element.
// NOTE(review): this variant returns the mutated input list (`self.vec()`)
// instead of void so the in-place foreach op has a usable return value
// (fixes #104817). The underlying redispatch target still returns void.
::std::vector<at::Tensor> _foreach_add__Scalar(c10::DispatchKeySet ks, at::TensorList self, const at::Scalar & scalar) {
  auto self_ = unpack(self, "self", 0);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self );
  // Forward-mode AD is tracked per element of the list.
  std::vector<bool> _any_has_forward_grad_self(self.size());
  for (const auto& i : c10::irange(self.size())) {
    _any_has_forward_grad_self[i] = isFwGradDefined(self[i]);
  }
  check_inplace(self, _any_requires_grad);
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  // One grad_fn per element; nullptr for elements that do not require grad.
  std::vector<std::shared_ptr<AddBackward1>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<AddBackward1> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<AddBackward1>(new AddBackward1(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i] ));
          return grad_fn;
        }
      }());
    }
    if (!grad_fns.empty()) {
      for (const auto& i : c10::irange(grad_fns.size())) {
        auto grad_fn = grad_fns[i];
        if (grad_fn != nullptr) {
          grad_fn->self_scalar_type = self[i].scalar_type();
        }
      }
    }
  }
#ifndef NDEBUG
  // BUGFIX(review): this vector was previously constructed with
  // `self_.size()` default-constructed (nullopt) entries and then
  // push_back'ed into, so the vector held 2*size elements and the
  // assertions below only ever inspected the nullopt half — i.e. they were
  // vacuous. Construct empty and reserve so indices line up with `self_`.
  // (Since this file is generated, the real fix belongs in the codegen.)
  std::vector<c10::optional<Storage>> self__storage_saved;
  self__storage_saved.reserve(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
#endif
  {
    // Run the actual kernel below the autograd dispatch key.
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_add_(ks & c10::after_autograd_keyset, self_, scalar);
  }
#ifndef NDEBUG
  // Debug-only sanity checks: the in-place op must not have swapped out
  // storages or TensorImpls.
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
#endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
  // Forward-mode AD: compute and install the new tangent of each element that
  // had a forward grad. For add(scalar) the tangent is unchanged, so this is a
  // copy/clone of the existing tangent.
  std::vector<c10::optional<at::Tensor>> self_new_fw_grad_opts(self.size(), c10::nullopt);
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    if (_any_has_forward_grad_self[i]) {
      auto self_t_raw = toNonOptFwGrad(self[i]);
      auto self_tensor = toNonOptTensor(self[i]);
      auto self_t = (self_t_raw.defined() || !self_tensor.defined())
        ? self_t_raw : at::zeros(self_tensor.sizes(), self_tensor.options());
      self_t = GradMode::is_enabled() ? self_t.clone() : self_t;
      self_new_fw_grad_opts[i] = self_t_raw.defined() ? self_t_raw.copy_(self_t.clone()) : self_t.clone();
    }
  }
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    auto& self_new_fw_grad_opt = self_new_fw_grad_opts[i];
    if (self_new_fw_grad_opt.has_value() && self_new_fw_grad_opt.value().defined() && self[i].defined()) {
      // The hardcoded 0 here will need to be updated once we support multiple levels.
      self[i]._set_fw_grad(self_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ true);
    }
  }
  return self.vec();
}
...
// Autograd kernel for the in-place foreach op `_foreach_addcmul_.Scalar`:
// self[i] += value * tensor1[i] * tensor2[i], element-wise per list entry.
// NOTE(review): returns the mutated input list (`self.vec()`) instead of void
// so the in-place foreach op has a usable return value (fixes #104817).
::std::vector<at::Tensor> _foreach_addcmul__Scalar(c10::DispatchKeySet ks, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) {
  auto self_ = unpack(self, "self", 0);
  auto tensor1_ = unpack(tensor1, "tensor1", 1);
  auto tensor2_ = unpack(tensor2, "tensor2", 2);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self, tensor1, tensor2 );
  // All three lists must be the same length; the per-element loops below
  // index them in lock-step.
  TORCH_CHECK(
      self.size() == tensor1.size(),
      "Tensor lists must have the same number of tensors, got ",
      self.size(),
      " and ",
      tensor1.size());
  TORCH_CHECK(
      self.size() == tensor2.size(),
      "Tensor lists must have the same number of tensors, got ",
      self.size(),
      " and ",
      tensor2.size());
  // Forward-mode AD is tracked per element of the list.
  std::vector<bool> _any_has_forward_grad_self(self.size());
  for (const auto& i : c10::irange(self.size())) {
    _any_has_forward_grad_self[i] = isFwGradDefined(self[i]) || isFwGradDefined(tensor1[i]) || isFwGradDefined(tensor2[i]);
  }
  check_inplace(self, _any_requires_grad);
  std::vector<c10::optional<at::Tensor>> original_selfs(self.size());
  // One grad_fn per element; nullptr for elements that do not require grad.
  std::vector<std::shared_ptr<AddcmulBackward0>> grad_fns;
  if (_any_requires_grad) {
    for (const auto& i : c10::irange( self.size() )) {
      const auto ith_requires_grad = compute_requires_grad(self[i], tensor1[i], tensor2[i]);
      check_inplace(self[i], ith_requires_grad);
      grad_fns.push_back([&]() -> std::shared_ptr<AddcmulBackward0> {
        if (!ith_requires_grad) {
          return nullptr;
        } else {
          auto grad_fn = std::shared_ptr<AddcmulBackward0>(new AddcmulBackward0(), deleteNode);
          grad_fn->set_next_edges(collect_next_edges( self[i], tensor1[i], tensor2[i] ));
          return grad_fn;
        }
      }());
    }
    if (!grad_fns.empty()) {
      for (const auto& i : c10::irange(grad_fns.size())) {
        auto grad_fn = grad_fns[i];
        if (grad_fn != nullptr) {
          grad_fn->self_scalar_type = self[i].scalar_type();
          // tensor1 is saved only when the grad w.r.t. tensor2 (output 2) is
          // needed, and vice versa — each factor appears in the other's grad.
          if (grad_fn->should_compute_output(2)) {
            grad_fn->tensor1_ = SavedVariable(tensor1[i], false);
          }
          grad_fn->tensor1_scalar_type = tensor1[i].scalar_type();
          if (grad_fn->should_compute_output(1)) {
            grad_fn->tensor2_ = SavedVariable(tensor2[i], false);
          }
          grad_fn->tensor2_scalar_type = tensor2[i].scalar_type();
          grad_fn->value = value;
        }
      }
    }
  }
#ifndef NDEBUG
  // BUGFIX(review): the three storage-saved vectors were previously
  // constructed with `size()` default-constructed (nullopt) entries and then
  // push_back'ed into, so they held 2*size elements and the assertions below
  // only ever inspected the nullopt half — i.e. they were vacuous. Construct
  // empty and reserve so indices line up with the unpacked lists.
  // (Since this file is generated, the real fix belongs in the codegen.)
  std::vector<c10::optional<Storage>> self__storage_saved;
  self__storage_saved.reserve(self_.size());
  for (const Tensor& tensor : self_)
    self__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> self__impl_saved(self_.size());
  for (size_t i=0; i<self_.size(); i++)
    if (self_[i].defined()) self__impl_saved[i] = self_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor1__storage_saved;
  tensor1__storage_saved.reserve(tensor1_.size());
  for (const Tensor& tensor : tensor1_)
    tensor1__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor1__impl_saved(tensor1_.size());
  for (size_t i=0; i<tensor1_.size(); i++)
    if (tensor1_[i].defined()) tensor1__impl_saved[i] = tensor1_[i].getIntrusivePtr();
  std::vector<c10::optional<Storage>> tensor2__storage_saved;
  tensor2__storage_saved.reserve(tensor2_.size());
  for (const Tensor& tensor : tensor2_)
    tensor2__storage_saved.push_back(
      tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
  std::vector<c10::intrusive_ptr<TensorImpl>> tensor2__impl_saved(tensor2_.size());
  for (size_t i=0; i<tensor2_.size(); i++)
    if (tensor2_[i].defined()) tensor2__impl_saved[i] = tensor2_[i].getIntrusivePtr();
#endif
  {
    // Run the actual kernel below the autograd dispatch key.
    at::AutoDispatchBelowAutograd guard;
    at::redispatch::_foreach_addcmul_(ks & c10::after_autograd_keyset, self_, tensor1_, tensor2_, value);
  }
#ifndef NDEBUG
  // Debug-only sanity checks: the in-place op must not have swapped out
  // storages or TensorImpls of any of the three lists.
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__storage_saved[i].value().is_alias_of(self_[i].storage()));
  }
  for (size_t i=0; i<self_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (self__impl_saved[i] && !at::impl::tensorlist_has_dispatch(self_))
      TORCH_INTERNAL_ASSERT(self__impl_saved[i] == self_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__storage_saved[i].value().is_alias_of(tensor1_[i].storage()));
  }
  for (size_t i=0; i<tensor1_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor1__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor1_))
      TORCH_INTERNAL_ASSERT(tensor1__impl_saved[i] == tensor1_[i].getIntrusivePtr());
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__storage_saved[i].value().is_alias_of(tensor2_[i].storage()));
  }
  for (size_t i=0; i<tensor2_.size() && !at::impl::dispatch_mode_enabled(); i++) {
    if (tensor2__impl_saved[i] && !at::impl::tensorlist_has_dispatch(tensor2_))
      TORCH_INTERNAL_ASSERT(tensor2__impl_saved[i] == tensor2_[i].getIntrusivePtr());
  }
#endif
  if (!grad_fns.empty()) {
    auto differentiable_outputs = flatten_tensor_args( self );
    TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());
    for (const auto& i : c10::irange(grad_fns.size())) {
      auto grad_fn = grad_fns[i];
      if (grad_fn != nullptr) {
        rebase_history(differentiable_outputs[i], grad_fns[i]);
      }
    }
  }
  // Forward-mode AD: new tangent is
  //   self_t + value * (tensor1_t * tensor2_p) + value * (tensor2_t * tensor1_p)
  // computed per element and installed back onto self[i].
  std::vector<c10::optional<at::Tensor>> self_new_fw_grad_opts(self.size(), c10::nullopt);
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    if (_any_has_forward_grad_self[i]) {
      auto self_t_raw = toNonOptFwGrad(self[i]);
      auto self_tensor = toNonOptTensor(self[i]);
      auto self_t = (self_t_raw.defined() || !self_tensor.defined())
        ? self_t_raw : at::zeros(self_tensor.sizes(), self_tensor.options());
      auto tensor1_t_raw = toNonOptFwGrad(tensor1[i]);
      auto tensor1_tensor = toNonOptTensor(tensor1[i]);
      auto tensor1_t = (tensor1_t_raw.defined() || !tensor1_tensor.defined())
        ? tensor1_t_raw : at::_efficientzerotensor(tensor1_tensor.sizes(), tensor1_tensor.options());
      auto tensor1_p = toNonOptPrimal(tensor1[i]);
      auto tensor2_t_raw = toNonOptFwGrad(tensor2[i]);
      auto tensor2_tensor = toNonOptTensor(tensor2[i]);
      auto tensor2_t = (tensor2_t_raw.defined() || !tensor2_tensor.defined())
        ? tensor2_t_raw : at::_efficientzerotensor(tensor2_tensor.sizes(), tensor2_tensor.options());
      auto tensor2_p = toNonOptPrimal(tensor2[i]);
      self_t = GradMode::is_enabled() ? self_t.clone() : self_t;
      self_new_fw_grad_opts[i] = self_t_raw.defined() ? self_t_raw.copy_(self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value)) : self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value);
    }
  }
  for (const auto& i : c10::irange(self_new_fw_grad_opts.size())) {
    auto& self_new_fw_grad_opt = self_new_fw_grad_opts[i];
    if (self_new_fw_grad_opt.has_value() && self_new_fw_grad_opt.value().defined() && self[i].defined()) {
      // The hardcoded 0 here will need to be updated once we support multiple levels.
      self[i]._set_fw_grad(self_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ true);
    }
  }
  return self.vec();
}
The full generated VariableType_?.cpp files are available at https://gist.github.com/crcrpar/0b439725692bcafca21da1d3e0780d3c.
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/118622
- :page_facing_up: Preview Python docs built from this PR
- :page_facing_up: Preview C++ docs built from this PR
- :question: Need help or want to give feedback on the CI? Visit the bot commands wiki or our office hours
Note: Links to docs will display an error until the docs builds have been completed.
:x: 1 New Failure, 41 Pending
As of commit bad86238df1ece47d4e7f96fe7bf964a1da3e16a with merge base 9c597ff137ead9f7f7ec8fdcbf473de2d328e61b ():
NEW FAILURE - The following job has failed:
-
Lint / lintrunner-noclang / linux-job (gh)
>>> Lint for tools/autograd/gen_python_functions.py:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased inplace-foreach-with-return onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout inplace-foreach-with-return && git pull --rebase)
@albanD as an original participant in the convo at https://github.com/pytorch/pytorch/pull/104780#discussion_r1256289461 I do think you should be involved in this review :P
The simplest fix here is at the Python or Python-binding level.
I'm not sure how I'm supposed to avoid the current diff in aten.
IIUC this would mean I'd make a change into tools/autograd/gen_variable_type so that in-place foreach functions in VariableType_?.cpp have return type and return self.vec() (or just self). One diff would be as follows:
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 4fa0b98c51..5f9773b645 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -958,12 +958,20 @@ def gen_variable_type_func(
result["type_derived_method_definitions_Default"] = [type_definition]
result["wrapper_registrations_Default"] = [wrapper_registration]
else:
+ return_type = cpp.returns_type(f.func.returns, symint=True).cpp_type()
+ if (is_inplace_foreach_with_return := (
+ f.func.kind() == SchemaKind.inplace
+ and name.startswith("_foreach")
+ and is_tensor_list_type(f.func.arguments.self_arg.argument.type)
+ )):
+ return_type = tensorListT
if not fn.info:
key = "Default"
type_definition = METHOD_DEFINITION.substitute(
- return_type=cpp.returns_type(
- f.func.returns, symint=True
- ).cpp_type(),
+ return_type=return_type,
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
@@ -974,9 +982,7 @@ def gen_variable_type_func(
else:
for key in fn.info.keys():
type_definition = METHOD_DEFINITION.substitute(
- return_type=cpp.returns_type(
- f.func.returns, symint=True
- ).cpp_type(),
+ return_type=return_type,
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
@@ -2159,4 +2165,6 @@ def emit_body(
body.append("reset_grad_accumulator(self);")
if not returns_void:
body.append(f"return {get_return_value(f)};")
+ if is_inplace_foreach and is_tensor_list_type(f.func.arguments.self_arg.argument.type):
+ body.append(f"return self;")
This gives me the expected in-place definitions, e.g.
at::TensorList _foreach_acos_(c10::DispatchKeySet ks, at::TensorList self) {
auto self_ = unpack(self, "self", 0);
...
return self;
}
but it seems that this would violate a rule:
terminate called after throwing an instance of 'c10::Error'
what():
Mismatch in kernel C++ signatures
operator: aten::_foreach_acos_(Tensor(a!)[] self) -> ()
registered at /home/mkozuki/ghq/github.com/crcrpar/torch-1/build/aten/src/ATen/RegisterSchema.cpp:6
kernel 1: void (c10::ArrayRef<at::Tensor>)
dispatch key: CPU
registered at /home/mkozuki/ghq/github.com/crcrpar/torch-1/build/aten/src/ATen/RegisterCPU.cpp:31375
kernel 2: c10::ArrayRef<at::Tensor> (c10::ArrayRef<at::Tensor>)
dispatch key: Autograd
registered at /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/csrc/autograd/generated/VariableType_0.cpp:17451
Exception raised from registerKernel at /home/mkozuki/ghq/github.com/crcrpar/torch-1/aten/src/ATen/core/dispatch/OperatorEntry.cpp:130 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x8e (0x7fde1f7a6b6e in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xca (0x7fde1f7a503a in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libc10.so)
frame #2: c10::impl::OperatorEntry::registerKernel(c10::Dispatcher const&, std::optional<c10::DispatchKey>, c10::KernelFunction, std::optional<c10::impl::CppSignature>, std::unique_ptr<c10::FunctionSchema, std::default_delete<c10::FunctionSchema> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9ae (0x7fddfa0c9a3e in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #3: c10::Dispatcher::registerImpl(c10::OperatorName, std::optional<c10::DispatchKey>, c10::KernelFunction, std::optional<c10::impl::CppSignature>, std::unique_ptr<c10::FunctionSchema, std::default_delete<c10::FunctionSchema> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x130 (0x7fddfa0c01b0 in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #4: torch::Library::_impl(char const*, torch::CppFunction&&, torch::_RegisterOrVerify) & + 0x328 (0x7fddfa0f9b58 in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x366b8fb (0x7fddfc86b8fb in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xc2c2a9 (0x7fddf9e2c2a9 in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xbb493f (0x7fddf9db493f in /home/mkozuki/ghq/github.com/crcrpar/torch-1/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x647e (0x7fde211ea47e in /lib64/ld-linux-x86-64.so.2)
frame #9: <unknown function> + 0x6568 (0x7fde211ea568 in /lib64/ld-linux-x86-64.so.2)
frame #10: _dl_catch_exception + 0xe5 (0x7fde20974af5 in /lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0xdff6 (0x7fde211f1ff6 in /lib64/ld-linux-x86-64.so.2)
frame #12: _dl_catch_exception + 0x88 (0x7fde20974a98 in /lib/x86_64-linux-gnu/libc.so.6)
frame #13: <unknown function> + 0xe34e (0x7fde211f234e in /lib64/ld-linux-x86-64.so.2)
frame #14: <unknown function> + 0x9063c (0x7fde2089063c in /lib/x86_64-linux-gnu/libc.so.6)
frame #15: _dl_catch_exception + 0x88 (0x7fde20974a98 in /lib/x86_64-linux-gnu/libc.so.6)
frame #16: _dl_catch_error + 0x33 (0x7fde20974b63 in /lib/x86_64-linux-gnu/libc.so.6)
frame #17: <unknown function> + 0x9012e (0x7fde2089012e in /lib/x86_64-linux-gnu/libc.so.6)
frame #18: dlopen + 0x48 (0x7fde208906c8 in /lib/x86_64-linux-gnu/libc.so.6)
<omitting python frames>
frame #44: <unknown function> + 0x29d90 (0x7fde20829d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #45: __libc_start_main + 0x80 (0x7fde20829e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #46: _start + 0x25 (0x56496402e095 in /home/mkozuki/.pyenv/versions/torchdev1-3.11/bin/python)
Modifying the generated autograd kernel's signature would mean that the dispatcher would need to understand mutable tensor-list returns, which we want to avoid.
Python binding level should refer to something like (See python_torch_function_*.cpp):
// _foreach_abs_
// Python binding for the in-place foreach op `_foreach_abs_`.
// Parses the TensorList argument, defers to any __torch_function__ override,
// then dispatches the void-returning aten op and returns None to Python.
static PyObject * THPVariable__foreach_abs_(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_foreach_abs_(TensorList self)",
  }, /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  // aten::_foreach_abs_(Tensor(a!)[] self) -> ()
  auto dispatch__foreach_abs_ = [](at::TensorList self) -> void {
    // Release the GIL while the (potentially long-running) kernel executes.
    pybind11::gil_scoped_release no_gil;
    at::_foreach_abs_(self);
  };
  dispatch__foreach_abs_(_r.tensorlist(0));
  // BUGFIX(review): removed a duplicated, unreachable second `Py_RETURN_NONE;`.
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
I'd start by looking at tools/autograd/gen_python_functions.py
@crcrpar you should have write access to the repo. Due to the recent security incident, CI won't run automatically unless your PR is on a branch in the pytorch/pytorch repo, so consider rehoming it there so that you don't have to get us to approve and run CI.
@crcrpar you should have write access to the repo, due to the recent security incident CI won't run automatically unless your PR is on a branch in pytorch/pytorch repo, so consider rehoming it there so that you don't have to get us to approve and run ci
@ezyang I'm afraid I don't have that access:
~/ghq/github.com/crcrpar/torch-1
% git remote get-url upstream
[email protected]:pytorch/pytorch.git
~/ghq/github.com/crcrpar/torch-1
% git push upstream inplace-foreach-with-return
ERROR: Permission to pytorch/pytorch.git denied to crcrpar.
fatal: Could not read from remote repository.
Please make sure you have the correct access rights
and the repository exists.
@crcrpar you should be added now, sorry about the delay
@ezyang thank you! I'll reopen this pull request after pushing this branch to pytorch/pytorch