amazon-redshift-python-driver icon indicating copy to clipboard operation
amazon-redshift-python-driver copied to clipboard

Please support sqlalchemy redshift.

Open nlm4145 opened this issue 1 year ago • 2 comments

Title.

nlm4145 avatar Sep 26 '24 02:09 nlm4145

Yes, I would like this addressed too; see https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/issues/264. Thanks!

stefanks avatar Oct 14 '24 14:10 stefanks

The code below works, but it does not let you use metadata.create_all() to create tables properly (there is no dist key or sort key definition here), and it cannot use Redshift-specific data types. However, it covers about 95% of use cases. That's why I created https://github.com/MacHu-GWU/simple_aws_redshift-project, which works with sqlalchemy and also handles the authentication part of redshift-connector.

# -*- coding: utf-8 -*-

"""
Example: use sqlalchemy to work with AWS Redshift
"""

import sqlalchemy as sa
import sqlalchemy.orm as orm
import sqlalchemy.dialects
import sqlalchemy.dialects.postgresql.psycopg2


class RedshiftPostgresDialect(
    sqlalchemy.dialects.postgresql.psycopg2.PGDialect_psycopg2,
):
    """
    psycopg2-based PostgreSQL dialect with the tweaks needed to talk to
    AWS Redshift.
    """

    # Turn on statement caching. Without this attribute SQLAlchemy emits:
    #
    # SAWarning: Dialect postgresql:psycopg2 will not make use of SQL compilation
    # caching as it does not set the 'supports_statement_cache' attribute to ``True``.
    # This can have significant performance implications including some
    # performance degradations in comparison to prior SQLAlchemy versions.
    # Dialect maintainers should seek to set this attribute to True after
    # appropriate development and testing for SQLAlchemy 1.4 caching support.
    # Alternatively, this attribute may be set to False which will disable this warning.
    # (Background on this warning at: https://sqlalche.me/e/20/cprf)
    supports_statement_cache = True

    def _set_backslash_escapes(self, connection):
        """
        Set the backslash-escape flag without querying the server.

        The base-class implementation runs
        ``connection.exec_driver_sql("show standard_conforming_strings").scalar()``,
        which fails on Redshift with::

            sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedObject)
            unrecognized configuration parameter "standard_conforming_strings"

        so the query is skipped and the flag is assigned directly.

        :param connection: unused; kept to match the base-class hook signature.
        """
        # The flag is consumed as a boolean. This override used to store the
        # non-empty string "off", which is truthy, so ``True`` keeps the same
        # effective behaviour while making the intent explicit.
        self._backslash_escapes = True


# Name under which the custom dialect is registered; it doubles as the URL
# scheme used when building connection URLs in this module.
driver = "redshift_custom"

# Make ``RedshiftPostgresDialect`` (defined in this module) resolvable by name.
sqlalchemy.dialects.registry.register(driver, __name__, "RedshiftPostgresDialect")


def create_engine(
    host: str,
    port: int,
    database: str,
    username: str,
    password: str,
    test_connection: bool = True,
) -> sa.Engine:
    """
    Create a SQLAlchemy engine that uses the custom Redshift dialect.

    :param host: database endpoint host name.
    :param port: database endpoint port.
    :param database: name of the database to connect to.
    :param username: database user name.
    :param password: database password.
    :param test_connection: when True, run ``SELECT 1;`` once so that
        connection problems surface here instead of on first use.

    :return: the newly created engine.
    """
    # Build the URL from its components instead of formatting a string, so
    # that credentials containing characters such as "@", ":", "/" or "%"
    # are handled correctly.
    url = sa.URL.create(
        drivername=driver,
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
    )
    engine = sa.create_engine(url)
    if test_connection:
        with engine.connect() as conn:
            # The rows are fetched only to force a full round trip.
            conn.execute(sa.text("SELECT 1;")).fetchall()
    return engine


# ------------------------------------------------------------------------------
# ORM Model
# ------------------------------------------------------------------------------
class Base(orm.DeclarativeBase):
    """Declarative base class shared by the ORM models in this example."""

    pass


class Department(Base):
    """
    ORM model mapped to the ``department`` table.

    A department owns a collection of ``Employee`` objects, exposed through
    the ``employees`` relationship.
    """

    __tablename__ = "department"

    # String primary key.
    id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
    # Required department name, declared unique.
    # NOTE(review): Redshift is understood to treat UNIQUE / PRIMARY KEY /
    # FOREIGN KEY constraints as informational only - confirm before relying
    # on the database to enforce them.
    name: orm.Mapped[str] = orm.mapped_column(nullable=False, unique=True)

    # Employees of this department; the other side of ``Employee.department``.
    # The "all, delete-orphan" cascade propagates session operations to the
    # employees and deletes an employee that is removed from this collection.
    employees: orm.Mapped[list["Employee"]] = orm.relationship(
        "Employee",
        back_populates="department",
        cascade="all, delete-orphan",
    )


class Employee(Base):
    """
    ORM model mapped to the ``employee`` table.

    Every employee references exactly one ``Department`` through
    ``department_id``.
    """

    __tablename__ = "employee"

    # String primary key.
    id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
    # Required employee name.
    name: orm.Mapped[str] = orm.mapped_column(nullable=False)
    # Required foreign key pointing at ``department.id``.
    department_id: orm.Mapped[str] = orm.mapped_column(
        sa.ForeignKey("department.id"),
        nullable=False,
    )

    # The owning department; the other side of ``Department.employees``.
    department: orm.Mapped[Department] = orm.relationship(
        "Department",
        back_populates="employees",
    )


# Table object describing the ``department_summary`` view, which counts the
# number of employees in each department.
# SQLAlchemy cannot map a view for inserts, only for queries.
# No autoload/reflection is requested here because the view is created
# manually with raw SQL.
department_summary = sa.Table(
    "department_summary",
    Base.metadata,
    sa.Column("department_id", sa.String, primary_key=True),
    sa.Column("employee_count", sa.Integer),
)


def drop_all(engine: sa.Engine):
    """
    Drop the ``department_summary`` view, then the ORM tables.

    The view is dropped first, and only if it exists. The ``department`` and
    ``employee`` tables are then dropped with ``checkfirst=True``.

    :param engine: engine connected to the target database.
    """
    drop_view_stmt = sa.text("DROP VIEW IF EXISTS department_summary")
    # ``engine.begin()`` commits when the block exits without an error.
    with engine.begin() as conn:
        conn.execute(drop_view_stmt)

    orm_tables = [Department.__table__, Employee.__table__]
    Base.metadata.drop_all(engine, tables=orm_tables, checkfirst=True)


def create_all(engine: sa.Engine):
    """
    Create the ORM tables, then (re)create the ``department_summary`` view.

    The ``department`` and ``employee`` tables are created with
    ``checkfirst=True``; the view is defined with raw SQL because it is not
    created through the metadata here.

    :param engine: engine connected to the target database.
    """
    orm_tables = [Department.__table__, Employee.__table__]
    Base.metadata.create_all(engine, tables=orm_tables, checkfirst=True)

    view_sql = (
        "CREATE OR REPLACE VIEW department_summary AS\n"
        "SELECT\n"
        "    d.id AS department_id,\n"
        "    COUNT(e.id) AS employee_count\n"
        "FROM department d\n"
        "    LEFT JOIN employee e ON e.department_id = d.id\n"
        "GROUP BY d.id"
    )
    # ``engine.begin()`` commits when the block exits without an error.
    with engine.begin() as conn:
        conn.execute(sa.text(view_sql))


def insert_sample_data(engine: sa.Engine):
    """
    Insert two departments and three employees as sample data.

    The departments are committed first; the employees are added and
    committed in a second transaction.

    :param engine: engine connected to the target database.
    """
    with orm.Session(engine) as session:
        hr_dept = Department(id="dept-1", name="HR")
        eng_dept = Department(id="dept-2", name="Engineering")
        session.add_all([hr_dept, eng_dept])
        session.commit()

        # (employee id, employee name, owning department id)
        employee_rows = [
            ("e-1", "Alice", hr_dept.id),
            ("e-2", "Bob", hr_dept.id),
            ("e-3", "Carol", eng_dept.id),
        ]
        session.add_all(
            [
                Employee(id=emp_id, name=emp_name, department_id=dept_id)
                for emp_id, emp_name, dept_id in employee_rows
            ]
        )
        session.commit()


def select_sample_data(engine: sa.Engine):
    """
    Print each department with its employees, then the per-department counts.

    The first part goes through the ORM models; the second part queries the
    ``department_summary`` view through its ``Table`` object.

    :param engine: engine connected to the target database.
    """
    with orm.Session(engine) as session:
        # ORM models: one line per department listing its employee names.
        for dept in session.query(Department).all():
            print(f"{dept.name} employees: {[emp.name for emp in dept.employees]}")

        # View: one line per department with its employee count.
        view_cols = department_summary.c
        summary_stmt = sa.select(view_cols.department_id, view_cols.employee_count)
        summary_stmt = summary_stmt.order_by(view_cols.department_id)
        for summary in session.execute(summary_stmt):
            print(f"Dept {summary.department_id} has {summary.employee_count} employees")


def run_demo(engine: sa.Engine):
    """
    Reflect the database behind ``engine`` and print its tables and columns.

    Tables are printed in ``MetaData.sorted_tables`` order, each followed by
    its columns. The create / insert / select demo steps are left commented
    out and can be enabled as needed.

    :param engine: engine connected to the target database.
    """
    metadata = sa.MetaData()
    metadata.reflect(bind=engine)
    for table in metadata.sorted_tables:
        print(f"{table = }")
        for col in table.columns.values():
            print(f"  {col = }")
    # Optional end-to-end demo steps; uncomment to run them.
    # drop_all(engine)
    # create_all(engine)
    # insert_sample_data(engine)
    # select_sample_data(engine)
    #
    # A bare ``engine.connect()`` used to follow here. It checked out a
    # connection that was never closed, and ``metadata.reflect`` above already
    # exercises the connection, so the call was removed.

if __name__ == "__main__":
    from simple_aws_redshift.api import RedshiftServerlessConnectionParams
    from simple_aws_redshift.tests.settings import get_settings

    settings = get_settings()

    # Look up host / port / database for the configured Redshift Serverless
    # namespace and workgroup.
    params = RedshiftServerlessConnectionParams.new(
        redshift_serverless_client=settings.bsm.redshiftserverless_client,
        namespace_name=settings.namespace_name,
        workgroup_name=settings.workgroup_name,
    )

    # Connect with the admin credentials from the settings and run the demo.
    engine = create_engine(
        host=params.host,
        port=params.port,
        database=params.database,
        username=settings.admin_username,
        password=settings.admin_password,
        test_connection=True,
    )
    run_demo(engine)

MacHu-GWU avatar Jun 16 '25 01:06 MacHu-GWU