Issue with method returning std::unique_ptr combined with trampoline class

Open bpatzak opened this issue 6 years ago • 1 comments

Hello, I ran into a problem with a method that returns std::unique_ptr. When the class is exposed without a trampoline class, everything works. As soon as a trampoline class is involved, the code no longer compiles and the compiler complains about a missing type_caster. The minimal example below illustrates the behavior.

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;

class Status {
public:
  Status() {}
};

class Factory {
public:
  Factory() {}
  virtual std::unique_ptr<Status> GetStatus() const {
    return std::unique_ptr<Status> (new Status);
  }
};

template <class FactoryBase = Factory> class PyFactory: public FactoryBase {
public:
  using FactoryBase::FactoryBase;
  std::unique_ptr<Status> GetStatus() const override {
    PYBIND11_OVERLOAD(std::unique_ptr<Status> , FactoryBase, GetStatus, );
  }
};


PYBIND11_MODULE(demo2, m) {

  py::class_<Status>(m, "Status")
    .def(py::init<>())
    ;

  py::class_<Factory, PyFactory<>>(m, "Factory")
    .def(py::init<>())
    .def("GetStatus", &Factory::GetStatus)
    ;
}

bpatzak, Oct 21 '19 06:10

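I ran into the same problem recently and found the following workarounds; I hope they help. The root of the problem is that pybind11 wraps Python's C API, which uses reference counting: an object returned from Python is shared with the interpreter, so as far as I can tell pybind11 cannot hand exclusive ownership of it to a std::unique_ptr when the trampoline converts the return value back to C++.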

So the first solution is to replace std::unique_ptr with std::shared_ptr:

class Factory {
public:
    Factory() {}
    virtual ~Factory() = default;
    //see note 1
    virtual std::shared_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status);
    }
};

class PyFactory : public Factory {
public:
    using Factory::Factory;
	
    //see note 2
    std::shared_ptr<Status> GetStatus() const override {
        // Trailing comma: the macro is variadic and some compilers need it when there are no arguments
        PYBIND11_OVERLOAD(std::shared_ptr<Status>, Factory, GetStatus, );
    }
};

void Test(Factory* factory) {
    // see note 3
    // Exercises a Factory subclass that is implemented in Python
    if (factory) {
        auto result = factory->GetStatus();
        if (result) {

        }
    }
}

PYBIND11_MODULE(TestModule, m) {

    m.def("Test", Test);
    // see note 4
    // Avoid a crash by changing the holder type to std::shared_ptr
    py::class_<Status, std::shared_ptr<Status>>(m, "Status")
        .def(py::init<>())
        ;

    py::class_<Factory, PyFactory>(m, "Factory")
        .def(py::init<>())
        .def("GetStatus", &Factory::GetStatus)
        ;
}
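
One detail of the listing above is easy to miss: the body of Factory::GetStatus still builds a std::unique_ptr even though the declared return type is now std::shared_ptr. That compiles because std::shared_ptr has a converting constructor that takes an rvalue std::unique_ptr. Here is a minimal, pybind11-free sketch of just that conversion; MakeStatus and ShareStatus are hypothetical helper names used for illustration, not part of the code above:

```cpp
#include <cassert>
#include <memory>

class Status {
public:
    Status() {}
};

// Hypothetical stand-in for existing library code that produces a unique_ptr.
std::unique_ptr<Status> MakeStatus() {
    return std::unique_ptr<Status>(new Status);
}

// The function is declared to return shared_ptr, so the returned unique_ptr
// is converted implicitly and ownership moves into the shared_ptr.
std::shared_ptr<Status> ShareStatus() {
    return MakeStatus();
}
```

So only the return types in the signatures have to change; existing function bodies that create a std::unique_ptr can probably stay as they are.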

However, this approach requires changing the C++ library code, which may not be acceptable, so here is a second way that keeps std::unique_ptr in the library interface:

class Factory {
public:
    Factory() {}
    virtual ~Factory() = default;
    virtual std::unique_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status);
    }
};

class PyFactory : public Factory {
public:
    using Factory::Factory;
	
    //see note 1
    std::unique_ptr<Status> GetStatus() const override {
        return std::unique_ptr<Status>(GetStatusWrapper());
    }

    //see note 2
    Status* GetStatusWrapper() const {
        // Trailing comma: the macro is variadic and some compilers need it when there are no arguments
        PYBIND11_OVERLOAD_INT(Status*, Factory, "GetStatus", );
        return Factory::GetStatus().release();
    }
};

void Test(Factory* factory) {
    // If the factory class is derived on the Python side, this calls PyFactory::GetStatus
    if (factory) {
        auto result = factory->GetStatus();
        if (result) {

        }
    }
}

PYBIND11_MODULE(TestModule, m) {

    m.def("Test", Test);
    // see note 3
    py::class_<Status, std::unique_ptr<Status, py::nodelete>>(m, "Status")
        .def(py::init<>())
        ;

    py::class_<Factory, PyFactory>(m, "Factory")
        .def(py::init<>())
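        //see note 4
        // When a Python class derives from Factory and calls the base-class GetStatus, this implementation is used
        .def("GetStatus", [](const Factory* factory) -> Status* {
            // factory may not be a PyFactory (e.g. a Factory created directly), so check the cast
            if (auto* py_factory = dynamic_cast<const PyFactory*>(factory)) {
                return py_factory->GetStatusWrapper();
            }
            return factory->GetStatus().release();
        })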
    ;
}

  • note 1: every Python-derived class inherits from PyFactory, so Test will call PyFactory::GetStatus();
  • note 2: PYBIND11_OVERLOAD_INT dispatches to the implementation in the Python-derived class if there is one; Factory::GetStatus().release() is the default case;
  • note 3: the program crashes without this holder type; check out the link for more information;
  • note 4: a Python-derived class calling super().GetStatus() ends up here.
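
The core of the second approach does not depend on pybind11: a hook returns a raw owning pointer, and the C++ override immediately re-wraps it in a std::unique_ptr. A minimal sketch of that pattern, assuming a plain virtual function in place of the PYBIND11_OVERLOAD_INT dispatch (RawHookFactory and CountingFactory are hypothetical names for illustration; CountingFactory plays the part of the Python subclass):

```cpp
#include <cassert>
#include <memory>

class Status {
public:
    Status() {}
};

class Factory {
public:
    virtual ~Factory() = default;
    virtual std::unique_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status);
    }
};

// Plays the role of PyFactory: the unique_ptr-returning override delegates
// to a hook that returns a raw pointer, then takes ownership of the result.
class RawHookFactory : public Factory {
public:
    std::unique_ptr<Status> GetStatus() const override {
        return std::unique_ptr<Status>(GetStatusWrapper());
    }

    // Default case: use the base implementation and release ownership to the caller.
    virtual Status* GetStatusWrapper() const {
        return Factory::GetStatus().release();
    }
};

// Stand-in for a subclass written in Python that overrides the hook.
class CountingFactory : public RawHookFactory {
public:
    mutable int calls = 0;

    Status* GetStatusWrapper() const override {
        ++calls;
        return new Status;
    }
};
```

Between the hook's return and the std::unique_ptr constructor nobody owns the raw pointer, so the hook must always return an object that the caller is allowed to delete.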

Test.py:

from TestModule import Factory, Test

class TestFactory(Factory):
    def __init__(self):
        super().__init__()

    def GetStatus(self):
        r = super().GetStatus()
        print("TestFactory::GetStatus")
        return r

o1 = TestFactory()
Test(o1)
o1.GetStatus()
print("Finish")

Stepping through the example in a debugger is a good way to understand the dispatch logic.

Inspired by #673.

liff-engineer, Apr 29 '22 12:04