numba
Caching a function that takes in a jitclass instance
```python
from numba import njit
from numba.experimental import jitclass

@jitclass([])
class NumbaClass():
    def __init__(self):
        pass

@njit(cache=True)
def calc(numbaObj):
    pass

numbaObject = NumbaClass()
calc(numbaObject)
```
Each time the script above is run, it generates a new cache entry for "calc", likely because NumbaClass.class_type.instance_type changes on every run. This makes the cache flag effectively useless and grows the __pycache__ folder on every run (mine ended up being a few gigabytes). Is there a workaround for this? Numba 0.51.2
Thanks for the report. I think your assessment of the situation is correct. We should probably warn users that this doesn't work and disable caching.
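To check whether a project is hitting this, one option is to watch the cache files on disk. Below is a minimal stdlib sketch, assuming Numba's default cache location (a `__pycache__` directory next to the source file) and its `.nbi` (index) / `.nbc` (data) file suffixes; `cache_report` is a hypothetical helper, not a Numba API.

```python
from pathlib import Path


def cache_report(root="."):
    """Count and size Numba cache files found in __pycache__ dirs under root.

    Assumes Numba's default on-disk layout: an index file ending in .nbi and
    one or more data files ending in .nbc, named like "calc-8.py38.nbi".
    Returns {name_prefix: (file_count, total_bytes)}, where name_prefix is
    the part of the file name before the first "." (e.g. "calc-8").
    """
    report = {}
    for path in Path(root).rglob("__pycache__/*"):
        # Skip directories and ordinary .pyc files; keep only Numba cache files
        if not path.is_file() or path.suffix not in (".nbi", ".nbc"):
            continue
        prefix = path.name.split(".", 1)[0]
        count, size = report.get(prefix, (0, 0))
        report[prefix] = (count + 1, size + path.stat().st_size)
    return report
```

Calling `cache_report()` from the project root before and after rerunning the script, and seeing the `.nbc` count for the same prefix (e.g. `calc-8`) keep climbing, would be consistent with the new-entry-per-run behavior described above.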
Cross-linking from here, thanks @stuartarchibald for the tip. Is this on the schedule for being fixed? In the meantime are there any workarounds?
@DannyWeitekamp has put a bunch of structref examples online and I feel pretty sure there's a comment or two out there from core devs about jitclass being subsumed/replaced by structref even though I can't seem to find the reference at the moment. Never having used structref, I don't know if it suffers the same caching issue or is sufficiently capable to handle jitclass-like functionality. Perhaps now is my chance to investigate :)
Any guidance as to how to think about caching functions with a jitclass-ish argument would be most welcome!
The hard way
My workaround for this issue has been to use structrefs instead of jitclasses. @nelson2005, they do indeed cache just fine. The downside is that they are a bit more cumbersome than jitclasses if you need to do things like define methods on them or access members outside of a jitted function. In both cases, the trick is to define a method on their StructRefProxy class (a @property in the case of member access). Also, if you need to use a method inside a jitted function, then you need to use @overload_method on their type.
Note that I emphasize "type" because when you subclass types.StructRef to define the type for the structref, it isn't actually a fully specified type yet. It's more like a constructor for the actual type, which you make by passing it a set of fields (a list of name+type pairs). Because of this, I usually call it the TypeTemplate in my code. I guess they aren't the same thing, since one might want to have different member type specializations for the same structref.
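A pure-Python toy model of that template/resolved-type split may help (this is not Numba's implementation; `ToyStructRefTemplate` and `PointTemplate` are hypothetical names). The template class is called with a field list to produce a concrete type, and two results compare equal only when their fields match, so one template can yield several member type specializations.

```python
class ToyStructRefTemplate:
    """Toy stand-in for a types.StructRef subclass: instantiating it with a
    list of (name, type) fields produces a concrete, hashable type."""

    def __init__(self, fields):
        self.fields = tuple(fields)

    def __eq__(self, other):
        # Same template class and same fields -> same resolved type
        return type(self) is type(other) and self.fields == other.fields

    def __hash__(self):
        return hash((type(self), self.fields))

    def __repr__(self):
        body = ", ".join(f"{name}: {typ}" for name, typ in self.fields)
        return f"{type(self).__name__}({body})"


class PointTemplate(ToyStructRefTemplate):
    pass


# Two specializations of the same template with different member types
PointI64 = PointTemplate([("x", "int64"), ("y", "int64")])
PointF64 = PointTemplate([("x", "float64"), ("y", "float64")])
```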
In the structref example, the devs use the helper function structref.define_proxy, which does two things: 1) it defines a constructor for the structref (i.e. a function that takes the values of all of the fields as arguments and outputs the structref instance), and 2) it defines boxing/unboxing behavior (basically how the structref gets passed between Python and Numba). But I tend not to use this helper function, because it's often nice to write your own constructor so that you can do things like have optional arguments. You do have to be careful with this, though, because you can get segfaulty weirdness if you don't assign all of the structref's members in the constructor and then try to retrieve them.
Here's a cleaned up dump of some of the code from one of my projects as a reference:
```python
from numba import njit, literally, u1
from numba.core import types
from numba.core.extending import overload, overload_method
from numba.experimental import structref
from numba.experimental.structref import define_boxing, new
# Project-specific names (KnowledgeBaseContext, init_kb_data, declare, retract,
# modify, KnowledgeBaseDataType, KnowledgeBaseContextDataType) are defined elsewhere.

class KnowledgeBase(structref.StructRefProxy):
    def __new__(cls, context=None):
        context_data = KnowledgeBaseContext.get_context(context).context_data
        kb_data = init_kb_data(context_data)
        self = kb_ctor(context_data, kb_data)  # This is defined below
        self.kb_data = kb_data
        self.context_data = context_data
        return self

    def declare(self, fact, name=None):
        return declare(self, fact, name)  # this is a jitted function defined elsewhere

    def retract(self, identifier):
        return retract(self, identifier)  # this is a jitted function defined elsewhere

    def modify(self, fact, attr, val):
        return modify(self, fact, literally(attr), val)  # this is a jitted function defined elsewhere

    @property
    def halt_flag(self):
        return get_halt_flag(self)

    @property
    def backtrack_flag(self):
        return get_backtrack_flag(self)

    def __del__(self):
        pass
        # kb_data_dtor(self.kb_data)

@njit(cache=True)
def get_halt_flag(self):
    return self.halt_flag

@njit(cache=True)
def get_backtrack_flag(self):
    return self.backtrack_flag

@structref.register
class KnowledgeBaseTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)

kb_fields = [
    ("kb_data", KnowledgeBaseDataType),
    ("context_data", KnowledgeBaseContextDataType),
    ("halt_flag", u1),
    ("backtrack_flag", u1),
]

define_boxing(KnowledgeBaseTypeTemplate, KnowledgeBase)

# This is the actual resolved type
KnowledgeBaseType = KnowledgeBaseTypeTemplate(kb_fields)

@njit(cache=True)
def kb_ctor(context_data, kb_data=None):
    st = new(KnowledgeBaseType)  # <- Note that the resolved type is used here
    st.context_data = context_data
    st.kb_data = kb_data if (kb_data is not None) else init_kb_data(context_data)
    st.halt_flag = u1(0)
    st.backtrack_flag = u1(0)
    return st

# Allows you to use KnowledgeBase as a constructor in a jitted function
@overload(KnowledgeBase)
def overload_KnowledgeBase(context_data=None, kb_data=None):
    return kb_ctor

# Allows you to call kb.halt() inside a jitted function
@overload_method(KnowledgeBaseTypeTemplate, "halt")
def kb_halt(self):
    def impl(self):
        self.halt_flag = u1(1)
    return impl
```
The easy way
So obviously all of this is kind of excessive if you only want a struct-like thing to pass around. For those cases I made a helper function here: https://github.com/DannyWeitekamp/Cognitive-Rule-Engine/blob/main/cre/structref.py#L75
There's a lot going on with this ^, and there is some custom caching stuff that goes along with it. Using it is simple, though; just do something like this:
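```python
from numba import i8, njit
# define_structref comes from the cre/structref.py file linked above
BEEP1, BEEP1Type = define_structref("BEEP1", [("A", i8), ("B", i8), ("C", i8)])
```
And boom, you've got your constructor and type ready to go:
```python
@njit(cache=True)
def print_A(beep):
    print(beep.A)

print_A(BEEP1(5, 6, 7))  # prints 5
```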
Thanks, that's very helpful! With the limitations of jitclass spelled out in this issue, it looks like jitclass is a bit of a dead end, and I'll need to switch my implementations to structref :)
I'll move further questions back to Discourse, where you've previously discussed some refcount magic and AOT examples.
This is due to id(self) being part of the jitclass' name: https://github.com/numba/numba/blob/f04f4b164887e5aa5ffeeed1bef317e13ec995f8/numba/core/types/misc.py#L425-L426
I looked through the history but couldn't figure out why it's needed. I saw that at some point early on (when jitclasses were being added) the sig didn't include the fielddesc. But since it does include it now, it seems like that should be enough to fully define the type's layout? Is there some problem I'm not seeing?
I tried removing id(self), which makes caching work, and adding/removing fields works as expected too (just makes new cache entries).
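A toy model of the effect in plain Python (hypothetical helper names; it only mimics the rough shape of the jitclass instance-type name, such as `instance.jitclass.Cusp#7f92907cfc70...` in the error messages later in this thread, and is not Numba's cache code): each simulated run builds a fresh class-type object, so a key that embeds `id()` never repeats, while a key built only from the class name and field description does.

```python
def type_name(classname, fields, class_type_obj, include_id=True):
    """Build a jitclass-like instance-type name; optionally embed id()."""
    fielddesc = ",".join(f"{name}:{typ}" for name, typ in fields)
    if include_id:
        return f"instance.jitclass.{classname}#{id(class_type_obj):x}<{fielddesc}>"
    return f"instance.jitclass.{classname}<{fielddesc}>"


def simulate_runs(n_runs, include_id):
    """Count distinct cache keys for one function after n_runs fresh "runs"."""
    fields = [("neu", "int64"), ("ned", "int64")]
    cache = {}
    alive = []  # keep every stand-in alive so CPython cannot reuse addresses
    for _ in range(n_runs):
        class_type_obj = object()  # stands in for the class type built per process
        alive.append(class_type_obj)
        key = ("calc", type_name("Cusp", fields, class_type_obj, include_id))
        cache.setdefault(key, "compiled code")
    return len(cache)
```

In a real process the address typically differs between runs for the same reason, which matches the per-run cache entries reported at the top of the issue.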
Hello Danny.
> The hard way
> My workaround for this issue has been to use structrefs instead of jitclasses.
I encountered a difficulty along the way.
```python
import numba as nb
from numba.core import types
from numba.experimental import structref
from numba.core.extending import overload_method

cusp_spec = [
    ('neu', nb.int64),
    ('ned', nb.int64),
    ('orbitals_up', nb.int64),
    ('orbitals_down', nb.int64),
]

@nb.experimental.jitclass(cusp_spec)
class Cusp:
    def __init__(self, neu, ned):
        """Scheme for adding electron–nucleus cusps to Gaussian orbitals
        :param neu: number of up electrons
        :param ned: number of down electrons
        """
        self.neu = neu
        self.ned = ned

    def value(self, r_e):
        """Cusp correction for s-part of orbitals.
        :param r_e: electron coordinates
        """
        return self.neu + self.ned

@structref.register
class Slater_class_t(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)

Slater_instance_t = Slater_class_t([
    ('neu', nb.int64),
    ('ned', nb.int64),
    ('cusp', nb.optional(Cusp.class_type.instance_type)),
])

class Slater(structref.StructRefProxy):
    """Slater wavefunction."""
    def __new__(cls, neu, ned, cusp):
        """Slater wavefunction.
        :param neu: number of up electrons
        :param ned: number of down electrons
        """
        return structref.StructRefProxy.__new__(cls, neu, ned, cusp)

@overload_method(Slater_class_t, 'value')
def slater_value(self, r_e):
    """Wave function value.
    :param r_e: electron coordinates
    """
    def impl(self, r_e) -> float:
        return 0.0
    return impl

structref.define_proxy(Slater, Slater_class_t, ['neu', 'ned', 'cusp'])

spec = [
    ('neu', nb.int64),
    ('ned', nb.int64),
    ('slater', Slater_instance_t),
]

@nb.experimental.jitclass(spec)
class Wfn:
    def __init__(self, neu, ned, slater):
        """Wave function in general form.
        :param neu: number of up electrons
        :param ned: number of down electrons
        """
        self.neu = neu
        self.ned = ned
        self.slater = slater

    def value(self, n_vectors) -> float:
        return self.slater.value(n_vectors)

input = dict(neu=2, ned=2, cusp_correction=True)
if input['cusp_correction']:
    cusp = Cusp(input['neu'], input['ned'])
else:
    cusp = None
slater = Slater(input['neu'], input['ned'], cusp)
wfn = Wfn(input['neu'], input['ned'], slater)
wfn.value((0, 0, 0))
```
I have three levels of nested classes: Wfn, Slater, and Cusp. The Wfn class is called from pure Python code; the others only from Numba. The Cusp class is optional, so nb.optional is used, but this results in an error in both cases. When cusp is initialised:
Failed in nopython mode pipeline (step: native lowering) Cannot cast numba.Slater_class_t(('neu', int64), ('ned', int64), ('cusp', instance.jitclass.Cusp#7f92907cfc70neu:int64,ned:int64,orbitals_up:int64,orbitals_down:int64)) to numba.Slater_class_t(('neu', int64), ('ned', int64), ('cusp', OptionalType(instance.jitclass.Cusp#7f92907cfc70neu:int64,ned:int64,orbitals_up:int64,orbitals_down:int64))): %"inserted.meminfo.1" = insertvalue {i8*} undef, i8* %"arg.slater.0", 0
and when it is not:
Failed in nopython mode pipeline (step: native lowering) Cannot cast numba.Slater_class_t(('neu', int64), ('ned', int64), ('cusp', none)) to numba.Slater_class_t(('neu', int64), ('ned', int64), ('cusp', OptionalType(instance.jitclass.Cusp#7fb3cade3c70neu:int64,ned:int64,orbitals_up:int64,orbitals_down:int64))): %"inserted.meminfo.1" = insertvalue {i8*} undef, i8* %"arg.slater.0", 0
If I use Slater as a jitclass (not a structref), everything works fine, but without caching. How can I fix this in the structref case?
Best, Vladimir.
Hey Vladimir (@Konjkov),
Just a heads up that most folks in the numba community are more responsive at https://numba.discourse.group/ if you have more questions down the road. There are also a lot of good discussions on structrefs there.
The issue that you're having is actually pretty common. I think whenever the devs/community get around to revising the structref/jitclass stuff, it should be redesigned in a way that avoids the issue you're having.
What is happening is that when you call the constructor for Slater, Numba builds an implementation for it on the fly based on the arguments you provide. So what is returned is a specialization of Slater_class_t where the 'cusp' slot is assigned to Cusp.class_type.instance_type (let's just call that Cusp_t), or to none when cusp is None. But you defined Slater_instance_t so that 'cusp' is an optional(Cusp_t). Basically, the constructor machinery is giving you a type that you never defined, which is subtly different from the one you did define.
What I recommend people do by default to avoid this is to just build their own constructor instead of relying on StructRefProxy.__new__ to do it for you. For instance, somewhere after the call to define_proxy:
```python
structref.define_proxy(Slater, Slater_class_t, ['neu', 'ned', 'cusp'])

# Define after this ^
@nb.njit(Slater_instance_t(nb.int64, nb.int64, nb.optional(Cusp.class_type.instance_type)), cache=True)
def new_slater(neu, ned, cusp):
    st = structref.new(Slater_instance_t)
    st.neu = neu
    st.ned = ned
    st.cusp = cusp
    return st
```
Then, in the proxy class, swap out StructRefProxy.__new__ for new_slater:
```python
class Slater(structref.StructRefProxy):
    """Slater wavefunction."""
    def __new__(cls, neu, ned, cusp):
        """Slater wavefunction.
        :param neu: number of up electrons
        :param ned: number of down electrons
        """
        return new_slater(neu, ned, cusp)
        # return structref.StructRefProxy.__new__(cls, neu, ned, cusp)
```
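To see why pinning the constructor's signature avoids the cast error, here is a toy model in plain Python (hypothetical names; the string "types" only imitate Numba's, and the exact-match cast rule is a simplification). The inferring constructor returns `cusp: Cusp` or `cusp: none`, neither of which equals the declared `cusp: Optional(Cusp)`, while a constructor pinned to the declared signature returns the declared type for both inputs.

```python
class Cusp:
    """Placeholder for the Cusp jitclass instance."""


# The field types Slater_instance_t was declared with
DECLARED = (("neu", "int64"), ("ned", "int64"), ("cusp", "Optional(Cusp)"))


def toy_typeof(value):
    """Tiny stand-in for Numba's typeof, covering only the values used here."""
    if value is None:
        return "none"
    if isinstance(value, int):
        return "int64"
    return type(value).__name__


def inferred_ctor_type(neu, ned, cusp):
    """Like StructRefProxy.__new__: the result type is specialized on the args."""
    return (("neu", toy_typeof(neu)), ("ned", toy_typeof(ned)), ("cusp", toy_typeof(cusp)))


def explicit_ctor_type(neu, ned, cusp):
    """Like new_slater with an explicit signature: the args are coerced to it,
    so the result type is always DECLARED."""
    return DECLARED


def can_cast(src, dst):
    """Simplified rule: two structref types must match field for field."""
    return src == dst
```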
A couple more thoughts on your approach:
- If you want all of this to be cache-friendly, you're going to need to phase out jitclasses entirely. I'm almost 100% sure that structrefs that reference jitclasses are also not cacheable. The fundamental issue is that jitclass types don't serialize properly because there is a memory address pointer in the definition, or something like that, so they don't look the same between executions of the same code.
- As you've probably noticed by now, using Numba to write things in an object-oriented way can be a bit of a pain, because the structref/jitclass stuff is rough around the edges and requires a lot of boilerplate code. The essential elements of what you're trying to do look fundamentally numerical, so you might consider whether what you want to accomplish can be implemented with just NumPy arrays (which Numba is a lot better at handling). If you don't need collections like Lists/Dicts to be part of your objects, you can often get away with using structured arrays (record arrays), which also come with the benefit of being block-allocated. The fastest possible implementation of anything with Numba usually involves packing things into NumPy arrays whenever possible, because that makes for a better memory access pattern than pointer-chasing code over heap-allocated objects.
Hello DannyWeitekamp
Your solution works great even though I have a big project; only the code using numba-mpi is not cached.
Best, Vladimir.