pybind11
Issue with a method returning std::unique_ptr combined with a trampoline class
Hello, I ran into a problem with a method returning std::unique_ptr. When the class is exposed without a trampoline class, everything works. However, when a trampoline class is involved, the code does not compile and the compiler complains about a missing type_caster. The behavior is illustrated in the minimal example below.
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;
/// Minimal payload type produced by Factory::GetStatus and exposed to
/// Python as "Status".
class Status {
public:
    Status();
};

// User-provided (non-trivial) default constructor, defined out of line.
inline Status::Status() {}
/// C++ base class whose virtual GetStatus can be overridden from Python
/// through the PyFactory trampoline.
class Factory {
public:
    Factory() {}

    // Polymorphic base: objects may be destroyed through a Factory*
    // (e.g. by pybind11's default std::unique_ptr<Factory> holder when the
    // object is really a PyFactory), so the destructor must be virtual to
    // avoid undefined behavior.
    virtual ~Factory() = default;

    /// Returns a newly allocated Status; the caller takes ownership.
    virtual std::unique_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status);
    }
};
// Trampoline ("alias") class that lets Python subclasses override
// GetStatus. FactoryBase defaults to Factory; presumably the template
// parameter exists so the same trampoline can be reused for classes derived
// from Factory.
//
// NOTE: this is the minimal reproduction from the report and, as reported,
// it does not compile. PYBIND11_OVERLOAD has to convert the Python
// override's return value into std::unique_ptr<Status>, and pybind11 has no
// type_caster that can load a std::unique_ptr from a Python object (Python
// cannot give up ownership of an object that may still be referenced).
template <class FactoryBase = Factory> class PyFactory: public FactoryBase {
public:
// Inherit FactoryBase's constructors.
using FactoryBase::FactoryBase;
// Dispatches to a Python override of GetStatus if one exists, otherwise to
// FactoryBase::GetStatus. The trailing comma passes an empty argument list.
std::unique_ptr<Status> GetStatus() const override {
PYBIND11_OVERLOAD(std::unique_ptr<Status> , FactoryBase, GetStatus, );
}
};
// Extension module "demo2": exposes Status and the Python-overridable
// Factory (backed by the PyFactory<> trampoline).
PYBIND11_MODULE(demo2, m) {
    py::class_<Status> status_class(m, "Status");
    status_class.def(py::init<>());

    py::class_<Factory, PyFactory<>> factory_class(m, "Factory");
    factory_class.def(py::init<>());
    factory_class.def("GetStatus", &Factory::GetStatus);
}
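I recently ran into the same problem and found the following solutions; I hope they help. The root of the problem is that pybind11 wraps Python's C API, which uses reference counting. An object returned by a Python override may still be referenced elsewhere, so pybind11 cannot transfer its ownership into a C++ std::unique_ptr and provides no type_caster that loads one.
So the first solution is to replace std::unique_ptr with std::shared_ptr: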
class Factory {
public:
Factory() {}
virtual ~Factory() = default;
//see note 1
virtual std::shared_ptr<Status> GetStatus() const {
return std::unique_ptr<Status>(new Status);
}
};
class PyFactory : public Factory {
public:
using Factory::Factory;
//see note 2
std::shared_ptr<Status> GetStatus() const override {
PYBIND11_OVERLOAD(std::shared_ptr<Status>, Factory, GetStatus);
}
};
// see note 3: test driver for a Factory derived in Python. Calls GetStatus
// through a Factory pointer so that a Python subclass is reached via the
// PyFactory trampoline. A null pointer is ignored.
void Test(Factory* factory) {
    if (factory == nullptr) {
        return;
    }
    auto status = factory->GetStatus();
    if (status) {
        // Nothing is done with the result.
    }
}
// Extension module "TestModule" for the std::shared_ptr-based solution.
PYBIND11_MODULE(TestModule, m) {
    m.def("Test", Test);

    // see note 4: Status is registered with a std::shared_ptr holder because
    // Factory::GetStatus returns std::shared_ptr<Status>; the default
    // holder leads to a crash.
    py::class_<Status, std::shared_ptr<Status>> status_class(m, "Status");
    status_class.def(py::init<>());

    py::class_<Factory, PyFactory> factory_class(m, "Factory");
    factory_class.def(py::init<>());
    factory_class.def("GetStatus", &Factory::GetStatus);
}
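Note 1, note 2: change `std::unique_ptr` to `std::shared_ptr`. Note 3: tests a `Factory` derived in Python. Note 4: without declaring `std::shared_ptr<Status>` as the holder type, the module crashes, because a class that is returned through `std::shared_ptr` must use `std::shared_ptr` as its pybind11 holder. Check out the link for more information.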
However, this approach requires modifying the C++ library code, which may not be acceptable, so here is a second way:
/// Base class whose GetStatus keeps its original std::unique_ptr signature;
/// Python overrides are bridged by PyFactory.
class Factory {
public:
    Factory() {}
    virtual ~Factory() = default;

    /// Default implementation: hands the caller a freshly allocated Status.
    virtual std::unique_ptr<Status> GetStatus() const {
        std::unique_ptr<Status> status(new Status);
        return status;
    }
};
/// Trampoline class for Factory. Python subclasses of Factory are backed by
/// PyFactory, so C++ callers reach the GetStatus override below.
class PyFactory : public Factory {
public:
    using Factory::Factory;

    // see note 1: wraps the raw pointer produced by GetStatusWrapper; the
    // returned std::unique_ptr owns the Status from here on.
    std::unique_ptr<Status> GetStatus() const override {
        return std::unique_ptr<Status>(GetStatusWrapper());
    }

    // see note 2: returns an owning raw pointer (the caller must delete it or
    // wrap it). If a Python override of "GetStatus" exists, its result is
    // returned; otherwise the C++ default Factory::GetStatus is used and
    // released from its std::unique_ptr.
    //
    // NOTE(review): a Status produced by a Python override is still wrapped
    // by a Python object. This relies on Status being registered with a
    // py::nodelete holder so that pybind11 never frees it; if Python code
    // keeps its own reference to that object, the reference dangles once the
    // C++ owner deletes the Status.
    Status* GetStatusWrapper() const {
        // The trailing comma supplies the (empty) variadic argument list:
        // ISO C++ before C++20 requires at least one argument for a macro's
        // "...", and -Wpedantic warns without it.
        PYBIND11_OVERLOAD_INT(Status*, Factory, "GetStatus", );
        return Factory::GetStatus().release();
    }
};
// If the factory class was derived on the Python side, this call goes
// through PyFactory::GetStatus. A null pointer is ignored.
void Test(Factory* factory) {
    if (factory == nullptr) {
        return;
    }
    auto status = factory->GetStatus();
    if (status) {
        // Nothing is done with the result.
    }
}
// Extension module "TestModule" for the raw-pointer (nodelete) solution.
PYBIND11_MODULE(TestModule, m) {
    m.def("Test", Test);
    // see note 3: pybind11 must never delete a Status, because Status
    // objects returned by Python overrides are handed over to a C++
    // std::unique_ptr (PyFactory::GetStatus). The module crashes without
    // this holder.
    // NOTE(review): as a consequence, a Status that stays on the Python side
    // (e.g. the result of calling GetStatus from Python) is never freed.
    py::class_<Status, std::unique_ptr<Status, py::nodelete>>(m, "Status")
        .def(py::init<>())
        ;
    py::class_<Factory, PyFactory>(m, "Factory")
        .def(py::init<>())
        // see note 4: when a Python class derived from Factory calls the base
        // implementation (super().GetStatus()), this lambda is invoked.
        .def("GetStatus", [](const Factory* factory) -> Status* {
            // By default pybind11 only constructs the PyFactory alias for
            // Python subclasses, so an instance of Factory itself is a plain
            // Factory and the dynamic_cast yields nullptr. Dereferencing it
            // would crash, so fall back to the virtual C++ implementation.
            if (const auto* py_factory = dynamic_cast<const PyFactory*>(factory)) {
                return py_factory->GetStatusWrapper();
            }
            return factory->GetStatus().release();
        })
        ;
}
Note 1: all Python-derived classes inherit from `PyFactory`, so `Test` calls `PyFactory::GetStatus()`. Note 2: use `PYBIND11_OVERLOAD_INT` to dispatch to the derived Python class's implementation, and `Factory::GetStatus().release()` for the default case. Note 3: the module crashes without this; check out the link for more information. Note 4: `super().GetStatus()` in a Python-derived class calls this.
Test.py:
from TestModule import Factory,Test
class TestFactory(Factory):
    """Python subclass of Factory whose GetStatus defers to the C++ base."""

    def __init__(self):
        # Initialise the C++ part of the object before anything else.
        super().__init__()

    def GetStatus(self):
        """Return the base-class Status after logging the call."""
        status = super().GetStatus()
        print("TestFactory::GetStatus")
        return status
# C++ -> Python path: Test calls GetStatus through a Factory pointer, which
# reaches TestFactory.GetStatus via the PyFactory trampoline.
o1 = TestFactory()
Test(o1)
# Python -> C++ path: TestFactory.GetStatus calls super().GetStatus(), i.e.
# the C++ binding registered for Factory.GetStatus.
o1.GetStatus()
print("Finish")
You can follow its implementation logic by stepping through it in a debugger.
Inspired by #673.