unzip-bot icon indicating copy to clipboard operation
unzip-bot copied to clipboard

[FEATURE REQUEST] Interfaces

Open EDM115 opened this issue 8 months ago • 1 comment

This would let us write the code once and switch between providers just by adding files: for example, one interface for database connections, implemented behind the scenes by MongoDB, SQLite, or any other backend.

EDM115 avatar Mar 16 '25 19:03 EDM115

Potential duplicates:

  • [#289] [FEATURE REQUEST] Autoban (73.33%)

  • [#288] [FEATURE REQUEST] Whitelist (74.07%)

  • [#144] [FEATURE REQUEST] Auto-save password (61.67%)

github-actions[bot] avatar Mar 16 '25 19:03 github-actions[bot]

There are basically two ways to achieve this:

1. A low-level implementation

Here, we create an abstract interface that declares the generic operations every database backend must implement:

from abc import ABC, abstractmethod

class DatabaseInterface(ABC):
    """
    Backend-agnostic contract for asynchronous database access.

    Every operation addresses a named "table": SQL backends map it to a real
    table, document stores (e.g. MongoDB) map it to a collection. This lets the
    same calling code target either kind of backend.
    """

    @abstractmethod
    async def open(self):
        """Establish the connection to the backing store."""

    @abstractmethod
    async def close(self):
        """Release the connection to the backing store."""

    @abstractmethod
    async def create_table(self, table_name: str, schema: dict = None):
        """
        Make sure a table (or collection) named table_name exists.

        Args:
            table_name: Name of the table or collection.
            schema: For SQL backends, a mapping of column name -> SQL type
                declaration. Document stores may ignore it.
        """

    @abstractmethod
    async def count(self, table_name: str, filter: dict = None) -> int:
        """
        Count the records of a table, optionally restricted by a filter.

        Args:
            table_name: Name of the table or collection.
            filter: Conditions records must satisfy; None counts every record.

        Returns:
            The number of matching records.
        """

    @abstractmethod
    async def get_all(self, table_name: str) -> list:
        """
        Fetch every record stored in a table.

        Args:
            table_name: Name of the table or collection.

        Returns:
            A list containing all records.
        """

    @abstractmethod
    async def get(self, table_name: str, query: dict) -> dict:
        """
        Fetch one record matching a query.

        Args:
            table_name: Name of the table or collection.
            query: Conditions the record must satisfy.

        Returns:
            The first matching record, or None when nothing matches.
        """

    @abstractmethod
    async def insert(self, table_name: str, document: dict):
        """
        Store a new record.

        Args:
            table_name: Name of the table or collection.
            document: Field name -> value mapping for the new record.
        """

    @abstractmethod
    async def update(self, table_name: str, query: dict, update: dict):
        """
        Modify the record(s) matching a query.

        Args:
            table_name: Name of the table or collection.
            query: Conditions records must satisfy to be modified.
            update: Field name -> new value mapping to apply.
        """

    @abstractmethod
    async def delete(self, table_name: str, query: dict):
        """
        Remove record(s) matching a query.

        Backends differ in whether only the first match or every match is
        removed, so callers should not rely on either behavior.

        Args:
            table_name: Name of the table or collection.
            query: Conditions records must satisfy to be removed.
        """

    @abstractmethod
    async def delete_all(self, table_name: str):
        """
        Empty a table (or collection), keeping the table itself.

        Args:
            table_name: Name of the table or collection.
        """

    @abstractmethod
    async def get_all_database(self) -> dict:
        """
        Dump the whole database.

        Returns:
            A dict mapping each table/collection name to the list of its records.
        """

We then implement it for each backend we want:

from motor.motor_asyncio import AsyncIOMotorClient
from database_interface import DatabaseInterface

class MongoDBDatabase(DatabaseInterface):
    """
    MongoDB implementation of DatabaseInterface built on Motor.

    "Tables" map to MongoDB collections. update() and delete() act on the
    first matching document only (update_one / delete_one).
    """

    def __init__(self, connection_str: str, db_name: str):
        """
        Initialize with the connection string and database name.

        The client is not created until open() is awaited.
        """
        self.connection_str = connection_str
        self.db_name = db_name
        self.client = None
        self.db = None

    async def open(self):
        """
        Open a connection to MongoDB and select the configured database.
        """
        self.client = AsyncIOMotorClient(self.connection_str)
        self.db = self.client[self.db_name]

    async def close(self):
        """
        Close the MongoDB connection.

        Safe to call before open() or more than once. The client and database
        handles are cleared so a closed instance is not silently reused.
        """
        if self.client is None:
            return
        self.client.close()
        self.client = None
        self.db = None

    async def create_table(self, table_name: str, schema: dict = None):
        """
        Create a collection in MongoDB if it does not exist yet.

        Although collections are created automatically on first insert,
        you can explicitly create one if desired. schema is ignored.
        """
        collections = await self.db.list_collection_names()
        if table_name not in collections:
            await self.db.create_collection(table_name)

    async def count(self, table_name: str, filter: dict = None) -> int:
        """
        Count documents in a collection matching filter (None counts all).
        """
        if filter is None:
            filter = {}
        return await self.db[table_name].count_documents(filter)

    async def get_all(self, table_name: str) -> list:
        """
        Retrieve all documents from a collection as a list.
        """
        cursor = self.db[table_name].find({})
        return [doc async for doc in cursor]

    async def get(self, table_name: str, query: dict) -> dict:
        """
        Retrieve the first document matching the query, or None.
        """
        return await self.db[table_name].find_one(query)

    async def insert(self, table_name: str, document: dict):
        """
        Insert a document into a collection and return its inserted_id.
        """
        result = await self.db[table_name].insert_one(document)
        return result.inserted_id

    async def update(self, table_name: str, query: dict, update: dict):
        """
        Update the first document matching query, setting the given fields via $set.

        Returns:
            int: modified_count, which is 0 when nothing matched or the matched
            document already held these values.
        """
        result = await self.db[table_name].update_one(query, {"$set": update})
        return result.modified_count

    async def delete(self, table_name: str, query: dict):
        """
        Delete the first document matching query and return deleted_count (0 or 1).
        """
        result = await self.db[table_name].delete_one(query)
        return result.deleted_count

    async def delete_all(self, table_name: str):
        """
        Delete all documents from a collection and return how many were removed.
        """
        result = await self.db[table_name].delete_many({})
        return result.deleted_count

    async def get_all_database(self) -> dict:
        """
        Retrieve data from all collections as {collection name: [documents]}.
        """
        data = {}
        collections = await self.db.list_collection_names()
        for coll in collections:
            cursor = self.db[coll].find({})
            data[coll] = [doc async for doc in cursor]
        return data
import aiosqlite
from database_interface import DatabaseInterface

class SQLiteDatabase(DatabaseInterface):
    """
    SQLite implementation of DatabaseInterface built on aiosqlite.

    Only values are bound as SQL parameters. Table and column names are
    interpolated into the SQL text, so they must come from trusted code and
    never from user input. Every write is committed immediately.
    """

    def __init__(self, db_path: str):
        """
        Initialize with the path to the SQLite database file.

        The connection is not opened until open() is awaited.
        """
        self.db_path = db_path
        self.conn = None

    async def open(self):
        """
        Open a connection to the SQLite database.

        Rows are produced as aiosqlite.Row so they can be converted to dicts.
        """
        self.conn = await aiosqlite.connect(self.db_path)
        self.conn.row_factory = aiosqlite.Row

    async def close(self):
        """
        Close the SQLite database connection.

        Safe to call before open() or more than once.
        """
        if self.conn is None:
            return
        await self.conn.close()
        self.conn = None

    async def create_table(self, table_name: str, schema: dict = None):
        """
        Create a table using the provided schema, unless it already exists.

        The schema must be a non-empty dict mapping column names to SQL types
        (e.g., {"id": "INTEGER PRIMARY KEY", "name": "TEXT"}).

        Raises:
            ValueError: if schema is None or empty.
        """
        # An empty schema would produce "CREATE TABLE t ()", which is invalid SQL.
        if not schema:
            raise ValueError("Schema must be provided for SQLite table creation.")
        columns_def = ", ".join(f"{col} {dtype}" for col, dtype in schema.items())
        sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_def})"
        await self.conn.execute(sql)
        await self.conn.commit()

    async def count(self, table_name: str, filter: dict = None) -> int:
        """
        Count rows in a table matching filter (None or {} counts every row).
        """
        where_clause, params = self._build_where_clause(filter)
        sql = f"SELECT COUNT(*) FROM {table_name} {where_clause}"
        async with self.conn.execute(sql, params) as cursor:
            row = await cursor.fetchone()
        return row[0]

    async def get_all(self, table_name: str) -> list:
        """
        Retrieve all rows from a table as a list of dicts.
        """
        sql = f"SELECT * FROM {table_name}"
        async with self.conn.execute(sql) as cursor:
            rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def get(self, table_name: str, query: dict) -> dict:
        """
        Retrieve the first row matching the query as a dict, or None.
        """
        where_clause, params = self._build_where_clause(query)
        sql = f"SELECT * FROM {table_name} {where_clause} LIMIT 1"
        async with self.conn.execute(sql, params) as cursor:
            row = await cursor.fetchone()
        return dict(row) if row else None

    async def insert(self, table_name: str, document: dict):
        """
        Insert a new row into the table.

        An empty document inserts a row made of column defaults.

        Returns:
            The rowid of the inserted row. This mirrors MongoDBDatabase.insert,
            which returns the new document's inserted_id.
        """
        if document:
            columns = ", ".join(document.keys())
            placeholders = ", ".join(["?"] * len(document))
            sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        else:
            # "INSERT INTO t () VALUES ()" is invalid; DEFAULT VALUES is SQLite's form.
            sql = f"INSERT INTO {table_name} DEFAULT VALUES"
        async with self.conn.execute(sql, tuple(document.values())) as cursor:
            row_id = cursor.lastrowid
        await self.conn.commit()
        return row_id

    async def update(self, table_name: str, query: dict, update: dict):
        """
        Update every row in the table that matches the query.

        All matching rows are changed (not only the first), and an empty or None
        query matches every row.

        Returns:
            int: the number of rows changed (0 when update is empty).
        """
        # "UPDATE t SET  WHERE ..." is invalid SQL, so an empty update is a no-op.
        if not update:
            return 0
        set_clause, set_params = self._build_set_clause(update)
        where_clause, where_params = self._build_where_clause(query)
        sql = f"UPDATE {table_name} SET {set_clause} {where_clause}"
        async with self.conn.execute(sql, tuple(set_params + where_params)) as cursor:
            changed = cursor.rowcount
        await self.conn.commit()
        return changed

    async def delete(self, table_name: str, query: dict):
        """
        Delete every row that matches the query.

        An empty or None query matches (and therefore deletes) every row.

        Returns:
            int: the number of rows deleted.
        """
        where_clause, params = self._build_where_clause(query)
        sql = f"DELETE FROM {table_name} {where_clause}"
        async with self.conn.execute(sql, params) as cursor:
            removed = cursor.rowcount
        await self.conn.commit()
        return removed

    async def delete_all(self, table_name: str):
        """
        Delete all rows from a table and return how many were removed.
        """
        sql = f"DELETE FROM {table_name}"
        async with self.conn.execute(sql) as cursor:
            removed = cursor.rowcount
        await self.conn.commit()
        return removed

    async def get_all_database(self) -> dict:
        """
        Retrieve the entire database content as a dict mapping table names to lists of rows.
        """
        data = {}
        # Retrieve list of tables from the sqlite_master table.
        sql = "SELECT name FROM sqlite_master WHERE type='table'"
        async with self.conn.execute(sql) as cursor:
            tables = await cursor.fetchall()
        for table in tables:
            table_name = table[0]
            data[table_name] = await self.get_all(table_name)
        return data

    def _build_where_clause(self, filter: dict) -> "tuple[str, list]":
        """
        Convert an equality filter dict into a SQL WHERE clause and its parameters.

        Conditions are ANDed together. A None value becomes "col IS NULL",
        because "col = NULL" is never true in SQL.
        """
        if not filter:
            return "", []
        conditions = []
        params = []
        for key, value in filter.items():
            if value is None:
                conditions.append(f"{key} IS NULL")
            else:
                conditions.append(f"{key} = ?")
                params.append(value)
        return "WHERE " + " AND ".join(conditions), params

    def _build_set_clause(self, update: dict) -> "tuple[str, list]":
        """
        Build a SQL SET clause ("col = ?, ...") and its parameters for update queries.
        """
        clause = ", ".join(f"{key} = ?" for key in update.keys())
        return clause, list(update.values())

Finally, we can use them easily, although the calling code still differs between the two backends (for example, SQLite needs an explicit schema to create a table):

import asyncio
from mongodb_db import MongoDBDatabase

async def main():
    """Demonstrate the low-level interface against a MongoDB backend."""
    # replace these with your actual configuration values
    mongo = MongoDBDatabase(connection_str="your_mongodb_uri", db_name="your_db_name")
    await mongo.open()
    try:
        # Create a "users" collection (if needed)
        await mongo.create_table("users")

        # Insert a record
        user = {"user_id": 123, "name": "Alice"}
        inserted_id = await mongo.insert("users", user)
        print(f"Inserted user with id: {inserted_id}")

        # Count users
        count = await mongo.count("users")
        print("Number of users:", count)

        # Retrieve all users
        users = await mongo.get_all("users")
        print("Users:", users)
    finally:
        # Close the client even when one of the calls above raises.
        await mongo.close()


if __name__ == "__main__":
    asyncio.run(main())
import asyncio
from sqlite_db import SQLiteDatabase

async def main():
    """Demonstrate the low-level interface against a SQLite backend."""
    sqlite = SQLiteDatabase(db_path="mydatabase.db")
    await sqlite.open()
    try:
        # Define a table schema for SQLite (for example, a users table)
        users_schema = {
            "user_id": "INTEGER PRIMARY KEY",
            "name": "TEXT",
        }
        await sqlite.create_table("users", schema=users_schema)

        # mydatabase.db persists across runs and user_id is the PRIMARY KEY, so a
        # second unconditional insert of user 123 would fail with a UNIQUE
        # constraint error. Only insert when the row is missing.
        if await sqlite.get("users", {"user_id": 123}) is None:
            await sqlite.insert("users", {"user_id": 123, "name": "Alice"})

        # Count rows in the users table
        count = await sqlite.count("users")
        print("Number of users:", count)

        # Retrieve all rows from the users table
        users = await sqlite.get_all("users")
        print("Users:", users)
    finally:
        # Close the connection even when one of the calls above raises.
        await sqlite.close()


if __name__ == "__main__":
    asyncio.run(main())

2. A high-level implementation

The goal here is to provide an ORM-like experience (think Prisma), where none of our queries change and only the implementation differs:

from abc import ABC, abstractmethod

class TableInterface(ABC):
    """
    Record-level operations on one table (SQL) or collection (document store).

    Handles are obtained from a database object's table() method; every
    operation is a coroutine.
    """

    @abstractmethod
    async def count(self, filter: dict = None) -> int:
        """Return how many records match filter (None means every record)."""

    @abstractmethod
    async def find_one(self, query: dict) -> dict:
        """Return the first record matching query, or None if nothing matches."""

    @abstractmethod
    async def get_all(self) -> list:
        """Return every record held by this table as a list."""

    @abstractmethod
    async def insert(self, document: dict):
        """Store document as a new record."""

    @abstractmethod
    async def update(self, query: dict, update: dict):
        """Apply the field values in update to the record(s) matching query."""

    @abstractmethod
    async def delete(self, query: dict):
        """Remove the record(s) matching query."""

    @abstractmethod
    async def delete_all(self):
        """Remove every record, leaving the table itself in place."""


class DatabaseInterface(ABC):
    """
    Connection-level contract for a backend-agnostic database.

    Record operations live on the handles returned by table(), so the calling
    code is the same whichever backend is plugged in.
    """

    @abstractmethod
    async def open(self):
        """Connect to the backing store."""

    @abstractmethod
    async def close(self):
        """Disconnect from the backing store."""

    @abstractmethod
    def table(self, table_name: str) -> TableInterface:
        """
        Hand out a TableInterface for the named table (a collection on document stores).

        This accessor is synchronous; the methods of the returned handle are async.
        """

    @abstractmethod
    async def get_all_database(self) -> dict:
        """Dump every table/collection of the database, keyed by its name."""
from motor.motor_asyncio import AsyncIOMotorClient
from database_interface import DatabaseInterface, TableInterface

class MongoTable(TableInterface):
    """
    TableInterface adapter over a single Motor collection.

    update() and delete() act on the first matching document only
    (update_one / delete_one).
    """

    def __init__(self, collection):
        # collection: the Motor collection object, e.g. db["users"]
        self.collection = collection

    async def count(self, filter: dict = None) -> int:
        """Count documents matching filter; a falsy filter counts everything."""
        criteria = filter if filter else {}
        return await self.collection.count_documents(criteria)

    async def find_one(self, query: dict) -> dict:
        """Return the first document matching query, or None."""
        return await self.collection.find_one(query)

    async def get_all(self) -> list:
        """Materialize every document of the collection into a list."""
        return [document async for document in self.collection.find({})]

    async def insert(self, document: dict):
        """Insert document and return its inserted_id."""
        outcome = await self.collection.insert_one(document)
        return outcome.inserted_id

    async def update(self, query: dict, update: dict):
        """$set the given fields on the first match; return modified_count."""
        outcome = await self.collection.update_one(query, {"$set": update})
        return outcome.modified_count

    async def delete(self, query: dict):
        """Delete the first document matching query; return deleted_count."""
        outcome = await self.collection.delete_one(query)
        return outcome.deleted_count

    async def delete_all(self):
        """Delete every document in the collection; return deleted_count."""
        outcome = await self.collection.delete_many({})
        return outcome.deleted_count


class MongoDBDatabase(DatabaseInterface):
    """
    MongoDB implementation of DatabaseInterface built on Motor; tables map to collections.
    """

    def __init__(self, connection_str: str, db_name: str):
        # The client is created by open(), not here.
        self.connection_str = connection_str
        self.db_name = db_name
        self.client = None
        self.db = None

    async def open(self):
        """Create the Motor client and select the configured database."""
        self.client = AsyncIOMotorClient(self.connection_str)
        self.db = self.client[self.db_name]

    async def close(self):
        """
        Close the MongoDB client.

        Safe to call before open() or more than once. The handles are cleared
        so a closed instance is not silently reused.
        """
        if self.client is None:
            return
        self.client.close()
        self.client = None
        self.db = None

    def table(self, table_name: str) -> TableInterface:
        """
        Return a MongoTable handle for the named collection.

        Must be called after open(), since self.db is None before that.
        """
        # Treat the collection as a table
        return MongoTable(self.db[table_name])

    async def get_all_database(self) -> dict:
        """Return {collection name: [documents]} for every collection."""
        data = {}
        collections = await self.db.list_collection_names()
        for coll in collections:
            cursor = self.db[coll].find({})
            data[coll] = [doc async for doc in cursor]
        return data
import aiosqlite
from database_interface import DatabaseInterface, TableInterface

class SQLiteTable(TableInterface):
    """
    TableInterface implementation for one SQLite table.

    Only values are bound as SQL parameters. The table name and column names
    are interpolated into the SQL text, so they must come from trusted code.
    Every write is committed immediately.
    """

    def __init__(self, connection, table_name: str):
        # connection: an open aiosqlite connection (SQLiteDatabase.table passes its own)
        self.conn = connection
        self.table_name = table_name

    async def count(self, filter: dict = None) -> int:
        """Count rows matching filter (None or {} counts every row)."""
        where_clause, params = self._build_where_clause(filter)
        sql = f"SELECT COUNT(*) FROM {self.table_name} {where_clause}"
        async with self.conn.execute(sql, params) as cursor:
            row = await cursor.fetchone()
        return row[0]

    async def find_one(self, query: dict) -> dict:
        """Return the first row matching query as a dict, or None."""
        where_clause, params = self._build_where_clause(query)
        sql = f"SELECT * FROM {self.table_name} {where_clause} LIMIT 1"
        async with self.conn.execute(sql, params) as cursor:
            row = await cursor.fetchone()
        return dict(row) if row else None

    async def get_all(self) -> list:
        """Return every row of the table as a list of dicts."""
        sql = f"SELECT * FROM {self.table_name}"
        async with self.conn.execute(sql) as cursor:
            rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def insert(self, document: dict):
        """
        Insert document as a new row and return its rowid.

        Returning the rowid mirrors MongoTable.insert, which returns inserted_id.
        An empty document inserts a row made of column defaults.
        """
        if document:
            columns = ", ".join(document.keys())
            placeholders = ", ".join(["?"] * len(document))
            sql = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
        else:
            # "INSERT INTO t () VALUES ()" is invalid; DEFAULT VALUES is SQLite's form.
            sql = f"INSERT INTO {self.table_name} DEFAULT VALUES"
        async with self.conn.execute(sql, tuple(document.values())) as cursor:
            row_id = cursor.lastrowid
        await self.conn.commit()
        return row_id

    async def update(self, query: dict, update: dict):
        """
        Update every row matching query and return how many rows changed.

        All matching rows are updated (MongoTable updates only the first), and an
        empty or None query matches every row. An empty update is a no-op that
        returns 0.
        """
        # "UPDATE t SET  WHERE ..." is invalid SQL, so skip the statement entirely.
        if not update:
            return 0
        set_clause, set_params = self._build_set_clause(update)
        where_clause, where_params = self._build_where_clause(query)
        sql = f"UPDATE {self.table_name} SET {set_clause} {where_clause}"
        async with self.conn.execute(sql, tuple(set_params + where_params)) as cursor:
            changed = cursor.rowcount
        await self.conn.commit()
        return changed

    async def delete(self, query: dict):
        """
        Delete every row matching query and return how many were removed.

        An empty or None query matches (and therefore deletes) every row.
        """
        where_clause, params = self._build_where_clause(query)
        sql = f"DELETE FROM {self.table_name} {where_clause}"
        async with self.conn.execute(sql, params) as cursor:
            removed = cursor.rowcount
        await self.conn.commit()
        return removed

    async def delete_all(self):
        """Delete every row and return how many were removed."""
        sql = f"DELETE FROM {self.table_name}"
        async with self.conn.execute(sql) as cursor:
            removed = cursor.rowcount
        await self.conn.commit()
        return removed

    def _build_where_clause(self, filter: dict) -> "tuple[str, list]":
        """
        Convert an equality filter dict into a WHERE clause and its parameters.

        Conditions are ANDed together. A None value becomes "col IS NULL",
        because "col = NULL" is never true in SQL.
        """
        if not filter:
            return "", []
        conditions = []
        params = []
        for key, value in filter.items():
            if value is None:
                conditions.append(f"{key} IS NULL")
            else:
                conditions.append(f"{key} = ?")
                params.append(value)
        return "WHERE " + " AND ".join(conditions), params

    def _build_set_clause(self, update: dict) -> "tuple[str, list]":
        """Build a SET clause ("col = ?, ...") and its parameters for update."""
        clause = ", ".join(f"{key} = ?" for key in update.keys())
        return clause, list(update.values())


class SQLiteDatabase(DatabaseInterface):
    """
    SQLite implementation of DatabaseInterface built on aiosqlite.

    All table handles returned by table() share this instance's single connection.
    """

    def __init__(self, db_path: str):
        # The connection is opened by open(), not here.
        self.db_path = db_path
        self.conn = None

    async def open(self):
        """Open the connection; rows come back as aiosqlite.Row so they convert to dicts."""
        self.conn = await aiosqlite.connect(self.db_path)
        self.conn.row_factory = aiosqlite.Row

    async def close(self):
        """Close the connection. Safe to call before open() or more than once."""
        if self.conn is None:
            return
        await self.conn.close()
        self.conn = None

    def table(self, table_name: str) -> TableInterface:
        """
        Return an SQLiteTable handle for table_name.

        Call this after open(): the handle captures the connection at creation time.
        """
        return SQLiteTable(self.conn, table_name)

    async def get_all_database(self) -> dict:
        """Return {table name: [rows as dicts]} for every table in sqlite_master."""
        data = {}
        # Query sqlite_master to get table names
        sql = "SELECT name FROM sqlite_master WHERE type='table'"
        async with self.conn.execute(sql) as cursor:
            tables = await cursor.fetchall()
        for table in tables:
            table_name = table[0]
            table_obj = self.table(table_name)
            data[table_name] = await table_obj.get_all()
        return data

    async def create_table(self, table_name: str, schema: dict):
        """
        Create a table from a schema dict unless it already exists.

        schema maps column names to SQL types,
        e.g. {"id": "INTEGER PRIMARY KEY", "name": "TEXT"}.

        Raises:
            ValueError: if schema is None or empty.
        """
        # An empty schema would produce "CREATE TABLE t ()", which is invalid SQL.
        if not schema:
            raise ValueError("Schema must be provided for SQLite table creation.")
        columns_def = ", ".join(f"{col} {dtype}" for col, dtype in schema.items())
        sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_def})"
        await self.conn.execute(sql)
        await self.conn.commit()

Usage is much simpler:

import asyncio

from config import Config

if Config.DB_TYPE == "mongodb":
    from mongodb_db import MongoDBDatabase
    db = MongoDBDatabase(connection_str=Config.MONGODB_URL, db_name=Config.MONGODB_DBNAME)
elif Config.DB_TYPE == "sqlite":
    from sqlite_db import SQLiteDatabase
    db = SQLiteDatabase(db_path="mydatabase.db")
else:
    raise ValueError("Unsupported DB_TYPE in configuration")


# Example usage of the interchangeable API
async def main():
    """Exercise the backend-agnostic table API against the configured backend."""
    await db.open()
    try:
        # SQLite tables must exist before they are queried. MongoDB creates
        # collections on first insert, so only SQLite needs this step.
        if Config.DB_TYPE == "sqlite":
            await db.create_table(
                "users_db", {"user_id": "INTEGER PRIMARY KEY", "user_name": "TEXT"}
            )

        # Get a table handle; note that in MongoDB this is a collection.
        users_table = db.table("users_db")

        # Count users (remember these are async operations)
        count = await users_table.count()
        print("Users count:", count)

        # Insert the demo user only once. Re-running the script then neither
        # trips SQLite's primary key nor piles up duplicate MongoDB documents.
        if await users_table.find_one({"user_id": 123456}) is None:
            await users_table.insert({"user_id": 123456, "user_name": "Test user"})
            print("Inserted user 123456")
        else:
            print("User 123456 already present")

        # Find a specific user
        user = await users_table.find_one({"user_id": 666})
        print("User 666 found:" if user else "User 666 not found.", user)
    finally:
        # Close the connection even when one of the calls above raises.
        await db.close()


if __name__ == "__main__":
    asyncio.run(main())

EDM115 avatar Apr 11 '25 18:04 EDM115