astro-sdk
astro-sdk copied to clipboard
run_raw_sql returns a list of lists
For some reason, even though run_raw_sql returns a List[LegacyRow] object, when the result is picked up by a downstream task we get a Tuple[List[LegacyRow]] object instead. We should investigate why this is happening and fix it.
Below is an example test that demonstrates the bug. I have tested this with simpler lists, and the TaskFlow API does not seem to have the same issue.
@pytest.mark.parametrize(
    "database_table_fixture",
    [
        {"database": Database.SNOWFLAKE},
        {"database": Database.BIGQUERY},
        {"database": Database.POSTGRES},
        {"database": Database.SQLITE},
        {"database": Database.REDSHIFT},
    ],
    indirect=True,
    ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift"],
)
def test_raw_sql(database_table_fixture, sample_dag):
    """Check that rows produced by ``run_raw_sql`` reach a downstream task.

    A CSV file is loaded into the fixture's table, a raw ``SELECT`` is run
    against it with a ``fetchall()`` handler, and a downstream task asserts
    that every element it receives is a ``LegacyRow``.

    Args:
        database_table_fixture: Indirectly parametrized fixture; its second
            element is used here as the table to load into and query.
        sample_dag: DAG used as the context manager for the tasks and then
            executed through ``test_utils.run_dag``.
    """
    _, test_table = database_table_fixture

    @aql.run_raw_sql
    def raw_sql_query(my_input_table: Table, created_table: Table, num_rows: int):
        return "SELECT * FROM {{my_input_table}} LIMIT {{num_rows}}"

    @task
    def validate_raw_sql(cur: list):
        # Imported locally, as in the original snippet, so the module does not
        # need SQLAlchemy's row type at import time.
        from sqlalchemy.engine.row import LegacyRow

        # `cur` is the handler's result (``cursor.fetchall()``), i.e. the rows
        # themselves; no extra level of nesting needs to be unwrapped.
        for c in cur:
            assert isinstance(c, LegacyRow)
        print(cur)

    with sample_dag:
        homes_file = aql.load_file(
            input_file=File(path=str(cwd) + "/../../../data/homes.csv"),
            output_table=test_table,
        )
        # Assign the call result directly. Wrapping the call in parentheses
        # with a trailing comma -- ``(raw_sql_query(...),)`` -- builds a
        # one-element tuple, so the downstream task would receive the rows
        # nested one level deeper than the operator returned them.
        raw_sql_result = raw_sql_query(
            my_input_table=homes_file,
            created_table=test_table,
            num_rows=5,
            handler=lambda cur: cur.fetchall(),
        )
        validate_raw_sql(raw_sql_result)
    test_utils.run_dag(sample_dag)