peewee-async

Manager.create_or_get does not respect composite keys

Open bobsteinke opened this issue 4 years ago • 0 comments

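When `create_or_get` fails to insert a row, it tries to fetch the existing row that blocked the insert. It does this by selecting on the model's unique columns. However, in a composite key the individual columns are not unique; only their combination is. `create_or_get` does not include the composite-key columns in that select. As a result, the lookup can end up with no filter at all and return an unrelated row. In the test below it returns (1, 'a', '13') instead of (2, 'b', '15').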

Test that fails:

import asynctest
from peewee import CompositeKey, IntegerField, Model, TextField
from peewee_async import Manager


class CompKeyModel(Model):
    """Model whose primary key is the (task_id, product_type) pair, not a single column."""

    task_id = IntegerField(null=False)
    product_type = TextField(null=False)
    misc = TextField(null=True)

    class Meta:
        table_name = 'generated_image'
        # Neither column is unique on its own; only the combination is.
        primary_key = CompositeKey('task_id', 'product_type')


class AsyncPeeweeManagerTests(asynctest.TestCase):
    """Reproduce Manager.create_or_get returning the wrong row for composite keys."""

    def setUp(self):
        """Create the database and wrap it in a peewee_async Manager."""
        self.db = # code to set up database... (placeholder: fill in before running)
        self.db_mgr = Manager(self.db)

    async def test_create_or_get_composite_key(self):
        """create_or_get must return the row matching the full composite key."""
        await self.db_mgr.create(CompKeyModel, task_id=1, product_type='a', misc='13')
        await self.db_mgr.create(CompKeyModel, task_id=2, product_type='b', misc='15')
        # (2, 'b') already exists, so the insert should hit an integrity error
        # and the existing row should be returned instead.
        entry, was_created = await self.db_mgr.create_or_get(CompKeyModel, task_id=2, product_type='b')
        self.assertFalse(was_created)

        # Currently fails: the lookup ignores the composite-key columns, so
        # `entry` can be the unrelated row (1, 'a', '13').
        self.assertEqual(2, entry.task_id)
        self.assertEqual('b', entry.product_type)
        self.assertEqual('15', entry.misc)

Here is a suggested fix:

    async def create_or_get(self, model_, **kwargs):
        """Create a ``model_`` row from ``kwargs``, or fetch the row that blocked it.

        An insert is attempted first. If it raises one of ``IntegrityErrors``,
        the conflicting row is looked up using only those ``kwargs`` values
        that take part in a uniqueness constraint:

        - fields declared ``unique``,
        - a single-column primary key,
        - every column of a ``CompositeKey`` primary key.

        Composite-key columns are not unique individually. If they were left
        out, the lookup could match an unrelated row.

        :param model_: model class to create or fetch.
        :param kwargs: field values for the new row.
        :return: ``(instance, created)``. ``created`` is ``True`` when a new
            row was inserted, and ``False`` when an existing row was returned.
        :raises: the original integrity error when ``kwargs`` contains no
            unique or primary-key value to look the existing row up by.
            Otherwise, whatever ``self.get`` raises if no row matches.
        """
        try:
            return (await self.create(model_, **kwargs)), True
        except IntegrityErrors:
            meta = model_._meta
            composite_key_field_names = (
                set(meta.primary_key.field_names) if meta.composite_key else set()
            )
            query = []
            for field_name, value in kwargs.items():
                field = getattr(model_, field_name)
                if (field.unique or field.primary_key
                        or field_name in composite_key_field_names):
                    query.append(field == value)
            if not query:
                # With nothing to filter on, ``get`` would return an arbitrary
                # row and misreport it as the conflicting one. In that case the
                # integrity error did not come from a key we can look up (e.g.
                # NOT NULL or foreign key), so surface the real error instead.
                raise
            return (await self.get(model_, *query)), False

bobsteinke avatar Oct 27 '21 21:10 bobsteinke