astro-sdk
Table naming when expanding a LoadFileOperator
Please describe the feature you'd like to see
The LoadFileOperator accepts output_table and input_file arguments. When expanding over a set of files, you may want to load them all with the exact same table definition but into differently named tables (e.g. 5 files from the same data feed). Expanding over input_file (a list) works, but the loaded tables are given generated _tmp names. To name the tables yourself, you currently have to build a config combining the desired table name, table object, and file object for each file and expand over that. This requires a separate task, or task methods such as map & zip, to prepare the load configurations; a sketch of the map-based approach is shown below.
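As an illustration of the current workaround, here is a minimal sketch using `XComArg.map` (Airflow 2.4+); the bucket, dataset, and connection IDs are placeholders:

```python
import os
from datetime import datetime

from airflow.decorators import dag, task
from astro.files import File
from astro.sql.operators.load_file import LoadFileOperator
from astro.sql.table import Metadata, Table

# Placeholder values for illustration only.
FILE_PATHS = ["gs://my-bucket/feed/part-1.csv", "gs://my-bucket/feed/part-2.csv"]


def to_load_config(file: File) -> dict:
    """Pair each file with a Table that reuses one definition but gets a per-file name."""
    return {
        "input_file": file,
        "output_table": Table(
            name=os.path.basename(file.path).split(".")[0],
            conn_id="bigquery_default",
            metadata=Metadata(schema="my_dataset"),
        ),
    }


@dag(schedule_interval=None, start_date=datetime(2022, 12, 5), catchup=False)
def load_feed():
    @task
    def list_files():
        return [File(path=p, conn_id="google_cloud_default") for p in FILE_PATHS]

    files = list_files()
    # Without this mapping step, each expanded LoadFileOperator would create a _tmp table.
    LoadFileOperator.partial(task_id="load_files").expand_kwargs(
        files.map(to_load_config)
    )


load_feed()
```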
Describe the solution you'd like
I would like the ability to pass a table_name_parsing_function directly. It would take the file name from input_file as an argument and let the user parse it into a table name for that file. By default, we could transform the file name into an ANSI SQL compliant table name; as an example, the default could combine the file name, run ID, and task index.
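A rough sketch of how this could look; `table_name_parsing_function` is the proposed argument and does not exist today, and the connection ID, dataset, and file paths below are placeholders:

```python
import os

from astro.files import File
from astro.sql.operators.load_file import LoadFileOperator
from astro.sql.table import Metadata, Table

# Placeholder input files for illustration only.
file_paths = ["gs://my-bucket/feed/part-1.csv", "gs://my-bucket/feed/part-2.csv"]


def parse_table_name(file_path: str) -> str:
    # e.g. "gs://my-bucket/feed/part-1.csv" -> "part-1"; a real default would
    # also have to sanitize this into an ANSI SQL compliant identifier.
    return os.path.basename(file_path).split(".")[0]


# Inside a @dag-decorated function or `with DAG(...)` block:
LoadFileOperator.partial(
    task_id="load_files",
    output_table=Table(conn_id="bigquery_default", metadata=Metadata(schema="my_dataset")),
    table_name_parsing_function=parse_table_name,  # proposed argument (hypothetical)
).expand(input_file=[File(path=p, conn_id="google_cloud_default") for p in file_paths])
```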
Are there any alternatives to this feature?
Noted above in the description.
Acceptance Criteria
- [ ] All checks and tests in the CI should pass
- [ ] Unit tests (90% code coverage or more, once available)
- [ ] Integration tests (if the feature relates to a new database or external service)
- [ ] Example DAG
- [ ] Docstrings in reStructuredText for each of the methods, classes, functions, and module-level attributes (including an example DAG showing how it should be used)
- [ ] Exception handling in case of errors
- [ ] Logging (are we exposing useful information to the user? e.g. source and destination)
- [ ] Improve the documentation (README, Sphinx, and any other relevant docs)
- [ ] How-to guide for the feature (with an example)
@ReadytoRocc could you share some example DAGs as per the description above?
@sunank200 - please see the example below:
```python
from airflow.decorators import dag, task
from airflow.exceptions import AirflowSkipException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.operators.python import get_current_context
from astro import sql as aql
from astro.sql.table import Table, Metadata
from astro.files import File
from astro.sql.operators.load_file import LoadFileOperator as LoadFile
from datetime import datetime
import pandas as pd


@task(task_id="parse_load_configs")
def parse_load_configs_func(output_table_dataset, output_table_conn_id, file_list):
    import os

    load_configs = []
    for file in file_list:
        table = Table(
            metadata=Metadata(
                schema=output_table_dataset,
            ),
            conn_id=output_table_conn_id,
            temp=False,
        )
        # Name the table after the file (base name without extension).
        table.name = os.path.basename(file.path).split(".")[0]
        load_configs.append({"output_table": table, "input_file": file})
    return load_configs


@task(task_id="scan_gcs")
def gcs_scan_func(
    gcp_conn_id, bucket_name, prefix=None, delimiter=None, regex=None, **kwargs
):
    import re

    gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
    timespan_start = kwargs["data_interval_start"]
    timespan_end = kwargs["data_interval_end"]
    print(f"Scanning between {timespan_start} and {timespan_end}.")
    files = gcs_hook.list_by_timespan(
        bucket_name=bucket_name,
        prefix=prefix,
        delimiter=delimiter,
        timespan_start=timespan_start,
        timespan_end=timespan_end,
    )
    if regex:
        _files = []
        re_com = re.compile(regex)
        for file in files:
            if re_com.fullmatch(file):
                _file = f"gs://{bucket_name}/{file}"
                _files.append(File(path=_file, conn_id=gcp_conn_id))
        files = _files
    if len(files) == 0:
        raise AirflowSkipException("No files found, skipping.")
    return files


# Variables
BIGQUERY_DATASET = ""
GCP_CONN_ID = ""
GCS_BUCKET = ""


@dag(
    schedule_interval="* * * * *",
    start_date=datetime(2022, 12, 5),
    catchup=False,
)
def bq_sdk():
    gcs_scan_task = gcs_scan_func(
        gcp_conn_id=GCP_CONN_ID, bucket_name=GCS_BUCKET, regex=r".*\.csv"
    )
    parse_load_configs_func_task = parse_load_configs_func(
        output_table_dataset=BIGQUERY_DATASET,
        output_table_conn_id=GCP_CONN_ID,
        file_list=gcs_scan_task,
    )
    load_gcs_to_bq = LoadFile.partial(
        task_id="load_gcs_to_bq",
        use_native_support=True,
    ).expand_kwargs(parse_load_configs_func_task)


dag_obj = bq_sdk()
```