[FCT] Functional approach to graph models and listeners
Idea from @jieguangzhou: create models inside a context manager.
# Create model instances
model1 = Model()
model2 = Model()
model3 = Model()
model4 = Model()
# Using the Graph context manager
with Graph.creating() as G:
    o_m1 = model1(G.input)
    o_m2 = model2(y=o_m1)
    o_m3 = model3(x=o_m1, y=o_m2)
G.predict_one()  # OK: the graph is complete
G(model4)  # raises NoBuildContextError: the build context has been exited
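To make the build-context idea concrete, here is a minimal, self-contained sketch in plain Python. `Graph`, `Model`, `Node` and `NoBuildContextError` are hypothetical stand-ins written for illustration, not superduperdb classes, and each toy model just wraps a Python function. Calling a model inside `Graph.creating()` records a node instead of predicting; once the block exits, adding nodes raises `NoBuildContextError`, while `predict_one` still works on the finished graph.

```python
import contextlib
import typing as t


class NoBuildContextError(RuntimeError):
    """Raised when a model is wired into a graph outside `Graph.creating()`."""


class Node:
    """One recorded call: a model plus the upstream nodes that feed it."""

    def __init__(self, model, args=(), kwargs=None):
        self.model = model
        self.args = args
        self.kwargs = kwargs or {}


class Graph:
    _active: t.Optional["Graph"] = None  # graph currently being built, if any

    def __init__(self):
        self.input = Node(model=None)  # placeholder for the graph's input value
        self.nodes: t.List[Node] = []

    @classmethod
    @contextlib.contextmanager
    def creating(cls):
        graph = cls()
        cls._active = graph
        try:
            yield graph
        finally:
            cls._active = None  # leaving the block closes the build context

    def __call__(self, model, *args, **kwargs):
        # Nodes may only be added while this graph is being built
        if Graph._active is not self:
            raise NoBuildContextError("this graph is no longer being built")
        node = Node(model, args, kwargs)
        self.nodes.append(node)
        return node

    def predict_one(self, x):
        # Call order is already a valid topological order: a node can only
        # consume nodes that were created before it.
        values = {id(self.input): x}
        for node in self.nodes:
            args = [values[id(a)] for a in node.args]
            kwargs = {k: values[id(v)] for k, v in node.kwargs.items()}
            values[id(node)] = node.model.fn(*args, **kwargs)
        return values[id(self.nodes[-1])]


class Model:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        # Inside a build context, calling a model records a node instead of predicting
        if Graph._active is None:
            raise NoBuildContextError("call models inside `with Graph.creating()`")
        return Graph._active(self, *args, **kwargs)


model1 = Model(lambda x: x + 1)
model2 = Model(lambda y: y * 2)
model3 = Model(lambda x, y: x + y)
model4 = Model(lambda x: x)

with Graph.creating() as G:
    o_m1 = model1(G.input)
    o_m2 = model2(y=o_m1)
    o_m3 = model3(x=o_m1, y=o_m2)

print(G.predict_one(3))  # (3 + 1) + (3 + 1) * 2 = 12
```

Recording nodes in call order keeps the sketch small; a real implementation would probably also validate the wiring when the block exits.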
Idea: use the same functional approach for creating `Listener` instances:
with db.listener(coll.find()) as listeners:
    # Equivalent to
    # `Listener(key={'a': 'txt', 'b': 'brand'}, select=coll.find(), model=model1)`
    o_m1 = model1(a='txt', b='brand')
    # A downstream listener, equivalent to
    # `Listener(key={'x': 'brand', 'y': '_outputs....{model1.identifier}.{model1.version}'}, select=coll.find(), model=model2)`
    o_m2 = model2(x='brand', y=o_m1)
    ...
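The listener idea can be sketched the same way. In the plain-Python sketch below, `listener`, `ListenerSpec`, `Upstream` and `Model` are hypothetical and are not superduperdb's API: each model call inside the block records one listener spec that shares the block's `select`. Because the exact `_outputs...` key format is elided in the proposal, the sketch does not guess it; a downstream input is recorded as an `Upstream` marker naming the model whose outputs it consumes.

```python
import contextlib
import dataclasses as dc
import typing as t


@dc.dataclass(frozen=True)
class Upstream:
    """Marks 'the outputs of the listener for <model>' (hypothetical)."""
    model: str


@dc.dataclass
class ListenerSpec:
    """Hypothetical stand-in for `Listener(key=..., select=..., model=...)`."""
    key: t.Dict[str, t.Union[str, Upstream]]
    select: t.Any
    model: str


class ListenerBuilder:
    active: t.Optional["ListenerBuilder"] = None  # builder of the current `with` block

    def __init__(self, select):
        self.select = select
        self.specs: t.List[ListenerSpec] = []


@contextlib.contextmanager
def listener(select):
    builder = ListenerBuilder(select)
    ListenerBuilder.active = builder
    try:
        yield builder
    finally:
        ListenerBuilder.active = None


class Model:
    def __init__(self, identifier: str):
        self.identifier = identifier

    def __call__(self, **key):
        builder = ListenerBuilder.active
        if builder is None:
            raise RuntimeError("call models inside `with listener(...)`")
        # Record one listener per call; every listener shares the block's select
        builder.specs.append(
            ListenerSpec(key=dict(key), select=builder.select, model=self.identifier)
        )
        # The return value only means "outputs of this model", so passing it
        # to another model makes that model a downstream listener.
        return Upstream(self.identifier)


model1, model2 = Model("model1"), Model("model2")

with listener("coll.find()") as listeners:  # a string stands in for the query
    o_m1 = model1(a="txt", b="brand")
    o_m2 = model2(x="brand", y=o_m1)

for spec in listeners.specs:
    print(spec)
```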
Suggestion to make this easier to debug, an eager mode that runs the predictions immediately:
with db.listener(coll.find(), eager=True):
    # Same as `model1.predict(list(db.execute(coll.find())))`
    o_m1 = model1(a='txt', b='brand')
    # This is a downstream prediction
    # Same as `model2.predict(...)`
    o_m2 = model2(x='brand', y=o_m1)
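An eager mode could reuse the same call syntax but run predictions on the spot. This sketch assumes an in-memory list of records in place of `db.execute(coll.find())`; `listener`, `EagerContext` and `Model.predict` here are hypothetical. Each call routes record fields to parameters, predicts immediately, and returns the list of outputs, so a downstream call can consume them record by record. The sample records are the two documents shown in the routing example of this proposal.

```python
import contextlib
import typing as t


class EagerContext:
    records: t.Optional[t.List[dict]] = None  # records selected by the active block


@contextlib.contextmanager
def listener(records: t.List[dict], eager: bool = True):
    # Only the eager branch is sketched; lazy mode would record listeners instead
    if not eager:
        raise NotImplementedError("lazy mode is not part of this sketch")
    EagerContext.records = list(records)
    try:
        yield
    finally:
        EagerContext.records = None


class Model:
    def __init__(self, fn: t.Callable[..., t.Any]):
        self.fn = fn

    def predict(self, batch: t.List[dict]) -> t.List[t.Any]:
        # One prediction per set of already-routed keyword arguments
        return [self.fn(**kwargs) for kwargs in batch]

    def __call__(self, **key):
        records = EagerContext.records
        if records is None:
            raise RuntimeError("call models inside `with listener(..., eager=True)`")
        batch = []
        for i, record in enumerate(records):
            # A string names a record field; a list holds an upstream model's
            # outputs (one per record), so take the i-th item.
            batch.append({p: record[v] if isinstance(v, str) else v[i] for p, v in key.items()})
        return self.predict(batch)


data = [{'txt': 'sdsds', 'brand': '12324hd'}, {'txt': 'sdsd1s', 'brand': '212324hd'}]
model1 = Model(lambda a, b: len(a) + len(b))
model2 = Model(lambda x, y: f"{x}:{y}")

with listener(data, eager=True):
    o_m1 = model1(a='txt', b='brand')  # computed right away: [12, 14]
    o_m2 = model2(x='brand', y=o_m1)   # ['12324hd:12', '212324hd:14']
```

Because every intermediate result is a plain list, it can be inspected directly inside the block, which is the debugging benefit the suggestion is after.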
Proposed solution for multiple inputs and models: a routing model node.
import dataclasses as dc
import typing as t
from superduperdb.components.model import _Predictor, objectmodel
db = ...
coll = ...
@objectmodel
def f(a, b, c=2):
    ...
# Data in the database looks like this:
# {'txt': 'sdsds', 'brand': '12324hd'}
# {'txt': 'sdsd1s', 'brand': '212324hd'}
# Dict form: map each record field to a parameter of `f`
f.predict_in_db(X={'txt': 'b', 'brand': 'a', 'bla': 'c'}, db=db, select=coll.find())
# Tuple form: (fields passed positionally, {field: parameter} passed by keyword)
f.predict_in_db(
    X=(['brand', 'txt'], {'other': 'c'}),
    db=db,
    select=coll.find(),
)
# Open question: if we have a GraphModel, what are its input parameters?
# With one input node it's easy: the same as the input node's parameters.
# With multiple input nodes ... ?
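@dc.dataclass
class InputNode(_Predictor):
    inputs: t.List[str]
    signature: ...
    # NOTE: this property shadows the `inputs` field above (the dataclass would
    # take the property object as the field's default); one needs another name
    @property
    def inputs(self):
        ...
    def predict_one(self, *args, **kwargs) -> int:
        return ...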
with Graph.create() as g:
    ...
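The two `X` forms passed to `predict_in_db` above can be read as a routing rule from record fields to the parameters of `f`. The helpers `route` and `predict_records` below are hypothetical, in-memory stand-ins (no database involved). In particular, skipping fields that are absent from a record, such as `'bla'` and `'other'` in the sample data, so that `c` falls back to its default, is an assumption of this sketch and not specified behavior.

```python
import typing as t

# Either {field: parameter} or ([positional fields], {field: parameter})
Routing = t.Union[t.Dict[str, str], t.Tuple[t.List[str], t.Dict[str, str]]]


def route(record: dict, X: Routing) -> t.Tuple[list, dict]:
    """Turn one record into (args, kwargs) for a model (hypothetical helper).

    Assumption: keyword fields missing from a record are skipped, so the
    corresponding parameters keep their defaults.
    """
    if isinstance(X, dict):
        positional, keyword = [], X
    else:
        positional, keyword = X
    args = [record[field] for field in positional]
    kwargs = {param: record[field] for field, param in keyword.items() if field in record}
    return args, kwargs


def predict_records(fn, records: t.List[dict], X: Routing) -> list:
    # In-memory stand-in for `predict_in_db`: apply `fn` to each routed record
    results = []
    for record in records:
        args, kwargs = route(record, X)
        results.append(fn(*args, **kwargs))
    return results


def f(a, b, c=2):
    return (a, b, c)


records = [
    {'txt': 'sdsds', 'brand': '12324hd'},
    {'txt': 'sdsd1s', 'brand': '212324hd'},
]
# Both forms route 'brand' -> a and 'txt' -> b, so both print
# [('12324hd', 'sdsds', 2), ('212324hd', 'sdsd1s', 2)]
print(predict_records(f, records, {'txt': 'b', 'brand': 'a', 'bla': 'c'}))
print(predict_records(f, records, (['brand', 'txt'], {'other': 'c'})))
```

The dict form is explicit about every parameter, while the tuple form mirrors an ordinary Python call signature; a routing node could accept both and normalize them to `(args, kwargs)` as above.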
Suggestion: create a `GraphModel` `G` and then call `G.listen(db=db, select=coll.find())`.
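One possible shape for such a `GraphModel`, assuming the routing-input-node approach proposed above: every graph input is served by a single routing node, so the graph's input parameters are simply the set of input keys its nodes use, which is one way to answer the multiple-input-nodes question. `GraphModel`, `Input` and the `listen` stand-in are hypothetical and run over an in-memory list instead of a database; the evaluation order comes from the standard library's `graphlib.TopologicalSorter` (Python 3.9+).

```python
import dataclasses as dc
import graphlib
import typing as t


@dc.dataclass(frozen=True)
class Input:
    """Routes one named graph input to a node parameter (hypothetical)."""
    key: str


class GraphModel:
    """Hypothetical GraphModel: named nodes wired into a DAG.

    Every `Input(key)` is served by one routing input node, i.e. the graph's
    own keyword arguments; any other parameter value names an upstream node.
    """

    def __init__(self):
        self.fns: t.Dict[str, t.Callable[..., t.Any]] = {}
        self.params: t.Dict[str, t.Dict[str, t.Union[str, Input]]] = {}

    def add(self, name: str, fn: t.Callable[..., t.Any], **params) -> str:
        self.fns[name] = fn
        self.params[name] = params
        return name

    @property
    def inputs(self) -> t.Set[str]:
        # With a routing input node, the graph's parameters are all Input keys,
        # however many nodes consume them.
        return {v.key for ps in self.params.values() for v in ps.values() if isinstance(v, Input)}

    def predict_one(self, **kwargs) -> t.Dict[str, t.Any]:
        # Map each node to its upstream nodes, then evaluate in topological order
        deps = {n: {v for v in ps.values() if isinstance(v, str)} for n, ps in self.params.items()}
        outputs: t.Dict[str, t.Any] = {}
        for name in graphlib.TopologicalSorter(deps).static_order():
            call = {p: kwargs[v.key] if isinstance(v, Input) else outputs[v]
                    for p, v in self.params[name].items()}
            outputs[name] = self.fns[name](**call)
        return outputs

    def listen(self, records: t.List[dict]) -> t.List[t.Dict[str, t.Any]]:
        # Stand-in for `G.listen(db=db, select=coll.find())`: no database here,
        # the graph is simply applied to each in-memory record.
        return [self.predict_one(**{k: r[k] for k in self.inputs}) for r in records]


G = GraphModel()
G.add("model1", lambda a, b: len(a) + len(b), a=Input("txt"), b=Input("brand"))
G.add("model2", lambda x, y: f"{x}:{y}", x=Input("brand"), y="model1")
G.add("model3", lambda x, y: (x, y), x="model1", y="model2")

data = [{'txt': 'sdsds', 'brand': '12324hd'}, {'txt': 'sdsd1s', 'brand': '212324hd'}]
results = G.listen(data)
print(results[0]["model3"])  # (12, '12324hd:12')
```

Returning every node's output (not just the last one) mirrors how listeners would store intermediate outputs per model.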