pomegranate
Out-of-sample prediction using pseudocount
This pull request should resolve [#835]. In detail:
- `pseudocount` support for the `BernoulliDistribution` was implemented.
- The parameter `alpha` was introduced to smooth only the individual distributions but not the label (like in scikit-learn), leaving the current behaviour of `pseudocount` intact.
- Supporting tests were written.
- No failing tests when running `python3 setup.py test`.
- scikit-learn and pomegranate now give identical predictions:
```python
from numpy import array, testing
from pomegranate import *
from sklearn.naive_bayes import BernoulliNB

X_train = array([
    [1., 0., 1., 0., 1.],
    [1., 0., 1., 0., 1.],
    [1., 0., 1., 0., 1.],
])
y_train = array([0, 0, 1])
X_test = array([[1., 0., 1., 0., 0.]])

pmodel = NaiveBayes.from_samples(
    BernoulliDistribution,
    X_train,
    y_train,
    alpha=1.0,
)
skmodel = BernoulliNB(alpha=1.0).fit(X_train, y_train)

testing.assert_array_almost_equal(
    skmodel.predict_proba(X_test),
    pmodel.predict_proba(X_test),
)
```
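To see why the two models can agree, the same prediction can be reproduced by hand. The sketch below is a plain-NumPy Bernoulli naive Bayes, assuming the smoothing scikit-learn's `BernoulliNB` applies: `alpha` is added to each per-class count of ones and `2 * alpha` to the class count, while the class priors stay unsmoothed label frequencies. The function name `bernoulli_nb_predict_proba` is hypothetical and not part of either library.

```python
import numpy as np


def bernoulli_nb_predict_proba(X_train, y_train, X_test, alpha=1.0):
    """Bernoulli naive Bayes posteriors with Lidstone smoothing (alpha) on the
    per-class feature counts only; class priors are unsmoothed label frequencies.
    Assumes alpha > 0 so that no feature probability is exactly 0 or 1."""
    classes = np.unique(y_train)
    n_total = X_train.shape[0]
    log_joint = np.empty((X_test.shape[0], classes.size))
    for k, c in enumerate(classes):
        X_c = X_train[y_train == c]
        n_c = X_c.shape[0]
        # Smoothed P(x_j = 1 | c) = (number of ones + alpha) / (n_c + 2 * alpha)
        p = (X_c.sum(axis=0) + alpha) / (n_c + 2.0 * alpha)
        # log P(c) + sum_j log P(x_j | c) for every test row
        log_prior = np.log(n_c / n_total)
        log_lik = X_test @ np.log(p) + (1.0 - X_test) @ np.log(1.0 - p)
        log_joint[:, k] = log_prior + log_lik
    # Normalise each row in log space before exponentiating, for stability
    log_joint -= log_joint.max(axis=1, keepdims=True)
    proba = np.exp(log_joint)
    return proba / proba.sum(axis=1, keepdims=True)


X_train = np.array([
    [1., 0., 1., 0., 1.],
    [1., 0., 1., 0., 1.],
    [1., 0., 1., 0., 1.],
])
y_train = np.array([0, 0, 1])
X_test = np.array([[1., 0., 1., 0., 0.]])

proba = bernoulli_nb_predict_proba(X_train, y_train, X_test, alpha=1.0)
print(proba)
```

For this data the posterior works out to 19683/27875 (about 0.706) for class 0 and 8192/27875 (about 0.294) for class 1: the class priors are 2/3 and 1/3, and the smoothed feature probabilities are 3/4 or 1/4 for class 0 and 2/3 or 1/3 for class 1.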
Thanks for your detailed review. I'll get back to you soon. Hylke
On Sun, 13 Dec 2020, 00:23 Jacob Schreiber wrote:
@jmschrei commented on this pull request.
Thanks for the contribution! It looks good, in general. As I mentioned in the comments, would you mind (1) removing typing, (2) making sure the indentation is correct, and (3) slightly reworking the pseudocount parameter? Let me know if you have any questions!
In pomegranate/NaiveBayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798244:
@@ -93,8 +95,9 @@ cdef class NaiveBayes(BayesModel):
@classmethod
def from_samples(cls, distributions, X, y=None, weights=None,
pseudocount=0.0, stop_threshold=0.1, max_iterations=1e8,
callbacks=[], return_history=False, verbose=False, n_jobs=1):
pseudocount=0.0, alpha: Optional[float] = None, stop_threshold=0.1,
Would you mind removing the typing? Regardless of whether typing is good or bad (that's its own debate), partially typing pomegranate seems like a bad idea, and I can't commit to typing every parameter.
In pomegranate/NaiveBayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798318:
@@ -140,6 +143,10 @@ cdef class NaiveBayes(BayesModel):
if they don't happen to occur in the data. Only effects mixture models defined over discrete distributions. Default is 0.
alpha : double, optional
Is this indentation inconsistent with the previous lines?
In pomegranate/bayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798466:
@@ -550,8 +552,8 @@ cdef class BayesModel(Model):
free(r)
def fit(self, X, y=None, weights=None, inertia=0.0, pseudocount=0.0,
stop_threshold=0.1, max_iterations=1e8, callbacks=[],
return_history=False, verbose=False, n_jobs=1):
alpha: Optional[float] = None, stop_threshold=0.1, max_iterations=1e8,
Same here
In pomegranate/bayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798556:
@@ -769,7 +775,13 @@ cdef class BayesModel(Model):
int column_idx, int d) nogil:
return -1
- def from_summaries(self, inertia=0.0, pseudocount=0.0, **kwargs):
- def from_summaries(
self,
I don't think every parameter needs its own line here.
In pomegranate/bayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798563:
@@ -769,7 +775,13 @@ cdef class BayesModel(Model):
int column_idx, int d) nogil:
return -1
- def from_summaries(self, inertia=0.0, pseudocount=0.0, **kwargs):
- def from_summaries(
self,
inertia=0.0,
pseudocount=0.0,
alpha: Optional[float] = None,
Please remove typing
In pomegranate/bayes.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541798846:
@@ -804,8 +820,14 @@ cdef class BayesModel(Model):
summaries /= summaries.sum()
for i, distribution in enumerate(self.distributions):
if isinstance(distribution, DiscreteDistribution):
distribution.from_summaries(inertia, pseudocount)
if isinstance(
distribution,
(DiscreteDistribution, BernoulliDistribution, IndependentComponentsDistribution),
):
distribution_smoothing = pseudocount
if alpha is not None:
I'm not sure I like this. You're saying the parameters "pseudocount" and "alpha" will do the same thing here? I'd suggest having pseudocount consistently refer to the labels and alpha refer to the distributions alone, as you mentioned in the documentation.
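The split proposed here can be sketched without pomegranate. The NumPy helper `fit_bernoulli_nb` below is hypothetical, not pomegranate's API, and it assumes the usual additive-smoothing forms: `pseudocount` is added to each label count when computing the class priors, and `alpha` is added to each per-class count of ones (with `2 * alpha` in the denominator, one per Bernoulli outcome) when computing feature probabilities.

```python
import numpy as np


def fit_bernoulli_nb(X, y, pseudocount=0.0, alpha=0.0):
    """Hypothetical helper: `pseudocount` smooths only the class priors (labels);
    `alpha` smooths only the per-class Bernoulli feature probabilities."""
    classes, counts = np.unique(y, return_counts=True)
    # Class priors: (n_c + pseudocount) / (n + K * pseudocount), K = number of classes
    priors = (counts + pseudocount) / (counts.sum() + classes.size * pseudocount)
    # Feature probabilities P(x_j = 1 | c): (ones in class c + alpha) / (n_c + 2 * alpha)
    feature_p = np.array([
        (X[y == c].sum(axis=0) + alpha) / (n_c + 2.0 * alpha)
        for c, n_c in zip(classes, counts)
    ])
    return priors, feature_p


X = np.array([[1., 0., 1., 0., 1.]] * 3)
y = np.array([0, 0, 1])

# pseudocount moves only the priors: [0.6, 0.4] instead of [2/3, 1/3]
priors, feature_p = fit_bernoulli_nb(X, y, pseudocount=1.0)
# alpha moves only the feature probabilities, e.g. 0.75 / 0.25 for class 0
priors_a, feature_p_a = fit_bernoulli_nb(X, y, alpha=1.0)
```

With the parameters kept separate like this, `pseudocount` has the same meaning whatever the underlying distribution is, because it never touches the feature model.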
In pomegranate/distributions/BernoulliDistribution.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541799015:
"""Update the parameters of the distribution from the summaries."""
if self.summaries[0] < 1e-8 or self.frozen: return
p = self.summaries[1] / self.summaries[0]
n: int = 2
please remove typing
In pomegranate/distributions/BernoulliDistribution.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541799360:
self.p = self.p * inertia + p * (1-inertia)
self.logp[0] = _log(1-p)
self.logp[1] = _log(p)
self.summaries = [0.0, 0.0]
- def fit(
Not sure a new fit function is needed here. The other classes call the from_summaries method, not the fit method, regardless.
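The point is that fitting can be expressed as "summarize, then from_summaries", so smoothing only has to live in `from_summaries`. The class below, `ToyBernoulli`, is a hypothetical pure-Python stand-in, not pomegranate's `BernoulliDistribution`; it assumes the pseudocount is added to both outcomes (matching the `n: int = 2` in the excerpt above) and mirrors the `inertia` update shown there.

```python
import numpy as np


class ToyBernoulli:
    """Hypothetical stand-in for a Bernoulli distribution (not pomegranate's class)."""

    def __init__(self, p=0.5):
        self.p = p
        self.summaries = [0.0, 0.0]  # [total weight, weighted count of ones]

    def summarize(self, X, weights=None):
        X = np.asarray(X, dtype=float)
        w = np.ones_like(X) if weights is None else np.asarray(weights, dtype=float)
        self.summaries[0] += w.sum()
        self.summaries[1] += (w * X).sum()

    def from_summaries(self, inertia=0.0, pseudocount=0.0):
        if self.summaries[0] < 1e-8:
            return
        # Pseudocount added to both outcomes (0 and 1), hence 2 * pseudocount below
        p = (self.summaries[1] + pseudocount) / (self.summaries[0] + 2 * pseudocount)
        self.p = self.p * inertia + p * (1 - inertia)
        self.summaries = [0.0, 0.0]

    def fit(self, X, weights=None, inertia=0.0, pseudocount=0.0):
        # No fitting logic of its own: summarize, then update from the summaries
        self.summarize(X, weights)
        self.from_summaries(inertia, pseudocount)
        return self


d = ToyBernoulli().fit([1.0, 0.0, 1.0, 1.0])                             # p = 3/4
d_smoothed = ToyBernoulli().fit([1.0, 0.0, 1.0, 1.0], pseudocount=1.0)  # p = 4/6
```

Because `fit` adds nothing beyond the two existing steps, callers such as a naive Bayes model can keep calling `summarize` and `from_summaries` directly and get identical results.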
In pomegranate/distributions/IndependentComponentsDistribution.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541799432:
@@ -243,8 +245,14 @@ cdef class IndependentComponentsDistribution(MultivariateDistribution):
}
@classmethod
- def from_samples(cls, X, weights=None, distribution_weights=None,
pseudocount=0.0, distributions=None):
- def from_samples(
see previous comments
In pomegranate/distributions/IndependentComponentsDistribution.pyx https://github.com/jmschrei/pomegranate/pull/839#discussion_r541799774:
@@ -253,10 +261,24 @@ cdef class IndependentComponentsDistribution(MultivariateDistribution):
X, weights = weight_set(X, weights)
n, d = X.shape
if callable(distributions):
distributions = [distributions.from_samples(X[:,i], weights) for i in range(d)]
else:
distributions = [distributions[i].from_samples(X[:,i], weights) for i in range(d)]
initialised_distributions = []
I'm not sure I like all the additions. Maybe just change it to something like distributions = [distributions.from_samples(X[:,i], weights, **params) for i in range(d)] where params is {} normally or {'pseudocount': pseudocount} if that's passed in?
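The suggested `**params` pattern can be sketched in plain Python. `bernoulli_from_samples` and `column_mean` are hypothetical per-column fitters standing in for a distribution's `from_samples`; `fit_columns` is likewise hypothetical. The keyword is forwarded only when the caller passes a pseudocount, so fitters that do not accept it keep working.

```python
import numpy as np


def bernoulli_from_samples(x, weights=None, pseudocount=0.0):
    """Hypothetical per-column fitter; returns P(x = 1) with pseudocount on both outcomes."""
    x = np.asarray(x, dtype=float)
    w = np.ones_like(x) if weights is None else np.asarray(weights, dtype=float)
    return ((w * x).sum() + pseudocount) / (w.sum() + 2 * pseudocount)


def column_mean(x, weights=None):
    """Hypothetical fitter that does not accept a pseudocount at all."""
    return float(np.average(x, weights=weights))


def fit_columns(factory, X, weights=None, pseudocount=None):
    # Forward pseudocount only when the caller passed one, so fitters that
    # do not take the keyword still work unchanged.
    params = {} if pseudocount is None else {'pseudocount': pseudocount}
    d = X.shape[1]
    return [factory(X[:, i], weights, **params) for i in range(d)]


X = np.array([[1., 0.], [1., 0.], [0., 0.]])
print(fit_columns(bernoulli_from_samples, X))                   # no smoothing
print(fit_columns(bernoulli_from_samples, X, pseudocount=1.0))  # smoothed
print(fit_columns(column_mean, X))                              # no pseudocount kwarg needed
```

This keeps the change to a single line in the existing list comprehensions instead of a separate code path per distribution type.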
In tests/test_distributions.py https://github.com/jmschrei/pomegranate/pull/839#discussion_r541799910:
@@ -879,6 +879,19 @@ def test_bernoulli():
assert_equal(f.name, "BernoulliDistribution")
assert_equal(round(f.parameters[0], 4), 0.1667)
+@with_setup(setup, teardown)
+def test_bernoulli_pseudocount():
+    """
+    Test fitting Bernoulli distribution with non-zero pseudocount.
+    """
+    a = [0.0, 0.0, 0.0]
+    d = BernoulliDistribution.from_samples(a, pseudocount=1)
+    assert_equal(d.probability(0), 0.8)
Should probably be assert_almost_equal to handle floating point issues
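The expected 0.8 in this test follows from adding the pseudocount to both outcomes (an assumption consistent with the test): P(X = 1) = (0 + 1) / (3 + 2 * 1) = 0.2, so P(X = 0) = 0.8. A minimal sketch of the arithmetic and of the tolerance-based check suggested here; the log/exp round trip only stands in for a distribution that stores log-probabilities, as the `logp` attribute in the excerpt above suggests.

```python
import numpy as np
from numpy.testing import assert_almost_equal

# Expected value in test_bernoulli_pseudocount, assuming the pseudocount is added
# to both outcomes: P(X = 1) = (0 ones + 1) / (3 samples + 2 * 1) = 0.2
p_one = (0 + 1) / (3 + 2 * 1)
# A model that stores log-probabilities hands back exp(log(.)), which can differ
# from the decimal literal in the last bits, so compare to 7 decimal places.
prob_zero = np.exp(np.log(1 - p_one))
assert_almost_equal(prob_zero, 0.8)
```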
In tests/test_naive_bayes.py https://github.com/jmschrei/pomegranate/pull/839#discussion_r541800622:
[0., 1.],
[0., 1.],
[0., 1.],
- ])
- y_train = np.array([0, 0, 1])
Test that both `alpha` and `pseudocount` are propagated to the `BernoulliDistribution`.
- for kwargs in ({'alpha': 1.0}, {'pseudocount': 1.0}):
pmodel = NaiveBayes.from_samples(
BernoulliDistribution,
X_train,
y_train,
**kwargs
)
# Check fit for y=0 label (2 records + pseudocount=1 per category).
This test will probably need to be changed. We shouldn't have two parameters that do the same thing. pseudocount should be on the labels (and so only affect the NaiveBayes priors) and alpha should be on the counts for the underlying distributions. That way pseudocount works the exact same way for normal distributions as bernoulli.
Hi Jacob,
I have made all the requested changes, plus I updated the BayesClassifier
model, which was missing from the previous commit.
I hope this looks okay to you.
Kind regards,
Hylke
Could the failing CI be related to the new release of joblib? It looks like the same problem as https://github.com/scikit-learn-contrib/hdbscan/issues/436.
Thank you for your contribution. However, pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython, and so this PR is unfortunately out of date and will be closed.