amazon-redshift-python-driver
amazon-redshift-python-driver copied to clipboard
Please support sqlalchemy redshift.
Title.
Yes, would like this addressed: https://github.com/sqlalchemy-redshift/sqlalchemy-redshift/issues/264 Thanks!
This code would work, but it doesn't allow you to use metadata.create_all() to create tables properly (because there's no dist key and sort key definition here), and it also cannot use Redshift-specific data types. However, it works in 95% of use cases. That's why I created https://github.com/MacHu-GWU/simple_aws_redshift-project to make it work with sqlalchemy and also handle the authentication part of redshift-connector.
# -*- coding: utf-8 -*-
"""
Example: use sqlalchemy to work with AWS Redshift
"""
import sqlalchemy as sa
import sqlalchemy.orm as orm
import sqlalchemy.dialects
import sqlalchemy.dialects.postgresql.psycopg2
class RedshiftPostgresDialect(
    sqlalchemy.dialects.postgresql.psycopg2.PGDialect_psycopg2,
):
    """A psycopg2 PostgreSQL dialect tweaked so SQLAlchemy works with AWS Redshift.

    Redshift speaks the PostgreSQL wire protocol but does not support every
    PostgreSQL server setting, so two adjustments are made below.
    """

    # Opt in to SQL compilation caching. Without this, SQLAlchemy 1.4+/2.0
    # emits the following warning and disables caching for this dialect:
    #
    #   SAWarning: Dialect postgresql:psycopg2 will not make use of SQL compilation
    #   caching as it does not set the 'supports_statement_cache' attribute to ``True``.
    #   This can have significant performance implications including some
    #   performance degradations in comparison to prior SQLAlchemy versions.
    #   Dialect maintainers should seek to set this attribute to True after
    #   appropriate development and testing for SQLAlchemy 1.4 caching support.
    #   Alternatively, this attribute may be set to False which will disable this warning.
    #   (Background on this warning at: https://sqlalche.me/e/20/cprf)
    supports_statement_cache = True

    # The base-class implementation runs
    #   ``connection.exec_driver_sql("show standard_conforming_strings").scalar()``
    # on connect, but Redshift does not recognize that parameter and raises:
    #
    #   sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedObject)
    #   unrecognized configuration parameter "standard_conforming_strings"
    #
    # Override the hook to skip the probe and hard-code the setting instead.
    def _set_backslash_escapes(self, connection):
        self._backslash_escapes = "off"
# URL scheme name under which the custom dialect is registered, e.g.
# "redshift_custom://user:pass@host:port/db".
driver = "redshift_custom"

# Make the dialect resolvable by name: the registry maps the scheme to the
# "RedshiftPostgresDialect" class defined in this module (``__name__``).
sqlalchemy.dialects.registry.register(
    driver, __name__, "RedshiftPostgresDialect"
)
def create_engine(
    host: str,
    port: int,
    database: str,
    username: str,
    password: str,
    test_connection: bool = True,
) -> sa.Engine:
    """Build a SQLAlchemy engine for Redshift using the custom dialect.

    :param host: Redshift endpoint host name.
    :param port: Redshift endpoint port (typically 5439).
    :param database: database name to connect to.
    :param username: database user.
    :param password: database password.
    :param test_connection: when True, run ``SELECT 1`` once so connection
        problems surface immediately instead of at first real query.
    :return: a configured :class:`sqlalchemy.Engine`.
    """
    # Build the URL with URL.create() instead of an f-string so special
    # characters in the password or username (e.g. "@", "/", ":") are
    # escaped properly instead of corrupting the URL.
    url = sa.URL.create(
        drivername=driver,
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
    )
    engine = sa.create_engine(url)
    if test_connection:
        # Fail fast with a clear error if credentials / networking are wrong.
        with engine.connect() as conn:
            conn.execute(sa.text("SELECT 1;")).fetchall()
    return engine
# ------------------------------------------------------------------------------
# ORM Model
# ------------------------------------------------------------------------------
class Base(orm.DeclarativeBase):
    """Declarative base shared by every ORM model in this example."""
class Department(Base):
    """A department; one department owns many employees."""

    __tablename__ = "department"

    id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
    name: orm.Mapped[str] = orm.mapped_column(nullable=False, unique=True)

    # One-to-many side of the Department <-> Employee relationship; deleting
    # a department also deletes its employees (delete-orphan cascade).
    employees: orm.Mapped[list["Employee"]] = orm.relationship(
        "Employee",
        back_populates="department",
        cascade="all, delete-orphan",
    )
class Employee(Base):
    """An employee; each employee belongs to exactly one department."""

    __tablename__ = "employee"

    id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
    name: orm.Mapped[str] = orm.mapped_column(nullable=False)

    # Foreign key to the owning department; every employee must have one.
    department_id: orm.Mapped[str] = orm.mapped_column(
        sa.ForeignKey("department.id"),
        nullable=False,
    )

    # Many-to-one side of the Department <-> Employee relationship.
    department: orm.Mapped[Department] = orm.relationship(
        "Department",
        back_populates="employees",
    )
# A Core Table mapped onto the "department_summary" view, which counts the
# number of employees in each department. SQLAlchemy can SELECT from it, but
# a view cannot be a target of ORM inserts.
department_summary = sa.Table(
    "department_summary",
    Base.metadata,
    sa.Column("department_id", sa.String, primary_key=True),
    sa.Column("employee_count", sa.Integer),
    # No reflection (autoload) here: the columns are declared explicitly
    # above, and the view itself is created manually in create_all().
)
def drop_all(engine: sa.Engine):
    """Tear down the demo schema: the summary view, then both tables.

    The view depends on the tables, so it is dropped first with raw SQL;
    the tables are then removed via the shared metadata.
    """
    with engine.connect() as conn:
        conn.execute(sa.text("DROP VIEW IF EXISTS department_summary"))
        conn.commit()
    Base.metadata.drop_all(
        engine,
        tables=[Department.__table__, Employee.__table__],
        checkfirst=True,
    )
def create_all(engine: sa.Engine):
    """Create the two demo tables, then the summary view on top of them.

    The view is not part of the metadata-driven DDL, so it is created with a
    raw ``CREATE OR REPLACE VIEW`` statement after the tables exist.
    """
    Base.metadata.create_all(
        engine,
        tables=[Department.__table__, Employee.__table__],
        checkfirst=True,
    )
    with engine.connect() as conn:
        sql = """
        CREATE OR REPLACE VIEW department_summary AS
        SELECT
            d.id AS department_id,
            COUNT(e.id) AS employee_count
        FROM department d
        LEFT JOIN employee e ON e.department_id = d.id
        GROUP BY d.id
        """.strip()
        conn.execute(sa.text(sql))
        conn.commit()
def insert_sample_data(engine: sa.Engine):
    """Insert two departments and three employees as demo rows."""
    with orm.Session(engine) as session:
        hr = Department(id="dept-1", name="HR")
        eng = Department(id="dept-2", name="Engineering")
        session.add_all([hr, eng])
        # Commit so the department rows exist before employees reference them.
        session.commit()
        session.add_all(
            [
                Employee(id="e-1", name="Alice", department_id=hr.id),
                Employee(id="e-2", name="Bob", department_id=hr.id),
                Employee(id="e-3", name="Carol", department_id=eng.id),
            ]
        )
        session.commit()
def select_sample_data(engine: sa.Engine):
    """Print each department with its employees, then the view's counts."""
    with orm.Session(engine) as ses:
        # ORM query: load departments and walk the relationship.
        for dept in ses.query(Department).all():
            members = [emp.name for emp in dept.employees]
            print(f"{dept.name} employees: {members}")
        # Core query against the view (read-only).
        stmt = sa.select(
            department_summary.c.department_id,
            department_summary.c.employee_count,
        ).order_by(department_summary.c.department_id)
        for row in ses.execute(stmt):
            print(f"Dept {row.department_id} has {row.employee_count} employees")
def run_demo(engine: sa.Engine):
    """Reflect and print the current schema; optionally run the full round trip.

    Uncomment the drop / create / insert / select calls below to exercise the
    whole lifecycle against a live cluster.
    """
    metadata = sa.MetaData()
    metadata.reflect(bind=engine)
    for table in metadata.sorted_tables:
        print(f"{table = }")
        for col in table.columns.values():
            print(f"    {col = }")
    # drop_all(engine)
    # create_all(engine)
    # insert_sample_data(engine)
    # select_sample_data(engine)
    # NOTE: the original trailing ``engine.connect()`` was removed — it opened
    # a connection that was never used or closed (a pooled-connection leak).
if __name__ == "__main__":
    from simple_aws_redshift.api import RedshiftServerlessConnectionParams
    from simple_aws_redshift.tests.settings import get_settings

    # Resolve the Redshift Serverless endpoint and admin credentials from
    # the project's test settings.
    settings = get_settings()
    params = RedshiftServerlessConnectionParams.new(
        redshift_serverless_client=settings.bsm.redshiftserverless_client,
        namespace_name=settings.namespace_name,
        workgroup_name=settings.workgroup_name,
    )
    engine = create_engine(
        host=params.host,
        port=params.port,
        database=params.database,
        username=settings.admin_username,
        password=settings.admin_password,
        test_connection=True,
    )
    run_demo(engine)