Custom operator support

Open samuelcolvin opened this issue 1 year ago • 9 comments

I think this is ready to review.

~~WIP to decide if we progress cc @alamb.~~

~~This is based on #11132 since I need Operator to be non-copyable. If we decide to progress I can rebase this, or cherry pick and create a new branch.~~

~~You can see this without the churn of #11132 at https://github.com/samuelcolvin/datafusion/pull/1.~~

High-level question

Is all the effort here worth it, or should we just add the 20 or so extra operators from the SQL library to Operator?

Advantages of this route:

  • more flexibility in how operators behave: what their signatures are, precedence, negation, etc.
  • easier to use custom operators, or new operators added to the SQL library, without waiting for datafusion to support them

Disadvantages:

  • ~~a lot of faff here~~ it's not that bad
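
To make the flexibility listed under the advantages concrete, here is a minimal sketch of how a trait (as opposed to a closed enum) lets each operator decide its own precedence and negation. The `OperatorBehavior` trait, its methods and the precedence numbers are all hypothetical and chosen for illustration; this is not the `CustomOperator` trait from this PR.

```rust
/// Hypothetical trait for illustration; NOT the PR's `CustomOperator` definition.
trait OperatorBehavior {
    /// Symbol as written in SQL.
    fn symbol(&self) -> &'static str;

    /// Binding strength (higher binds tighter). The default is an arbitrary placeholder.
    fn precedence(&self) -> u8 {
        0
    }

    /// The operator that negates this one, if it has one.
    fn negate(&self) -> Option<Box<dyn OperatorBehavior>> {
        None
    }
}

/// Postgres-style regex match, `~`.
struct RegexMatch;
/// Postgres-style negated regex match, `!~`.
struct RegexNotMatch;

impl OperatorBehavior for RegexMatch {
    fn symbol(&self) -> &'static str {
        "~"
    }
    fn precedence(&self) -> u8 {
        30 // arbitrary value for the sketch
    }
    fn negate(&self) -> Option<Box<dyn OperatorBehavior>> {
        Some(Box::new(RegexNotMatch))
    }
}

impl OperatorBehavior for RegexNotMatch {
    fn symbol(&self) -> &'static str {
        "!~"
    }
    fn precedence(&self) -> u8 {
        30 // arbitrary value for the sketch
    }
    fn negate(&self) -> Option<Box<dyn OperatorBehavior>> {
        Some(Box::new(RegexMatch))
    }
}

/// An operator that relies on the defaults: placeholder precedence, no negation.
struct Question;

impl OperatorBehavior for Question {
    fn symbol(&self) -> &'static str {
        "?"
    }
}

/// Formats what a planner could learn about an operator through the trait.
fn describe(op: &dyn OperatorBehavior) -> String {
    match op.negate() {
        Some(neg) => format!(
            "{} (precedence {}, negated by {})",
            op.symbol(),
            op.precedence(),
            neg.symbol()
        ),
        None => format!("{} (precedence {}, no negation)", op.symbol(), op.precedence()),
    }
}
```

With an enum, adding `?` would mean touching every `match` over the operator type; with a trait like the one sketched here, a new operator only overrides the methods it cares about.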

To do

  • [x] How should ParseCustomOperator be passed into the SQL parser? It definitely shouldn't be as it is now; perhaps we should have an equivalent of register_function_rewrite, e.g. register_custom_operators? DONE
  • [x] ~~is the hack with WrapCustomOperator necessary and acceptable, maybe someone else's Rust-fu could avoid this?~~ I think what we have is good
  • [x] ~~should CustomOperator provide a get_udf method, then we rewrite to that function automatically, rather than relying on register_function_rewrite?~~ I don't think so, the current rewrite logic is more powerful
  • [x] what should we do about from_proto_binary_op? We can't keep the same signature and support custom operators; this might be easy to fix. Done by adding the register methods to FunctionRegistry
  • [x] ~~is it okay to rely on name() of the operator to implement equality, ordering and hashing?~~ I think it's good
  • [x] Needs tests - basic tests are done, LMK what more is needed
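
As a rough illustration of the `name()`-based equality, ordering and hashing question in the list above, the sketch below wraps a trait object in a newtype and derives all three from the name. `NamedOperator` and `OpHandle` are hypothetical stand-ins (for `CustomOperator` and `WrapCustomOperator` respectively); the real definitions in this PR may differ.

```rust
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Hypothetical stand-in for the PR's `CustomOperator`; only `name()` matters here.
trait NamedOperator: Debug + Send + Sync {
    fn name(&self) -> &'static str;
}

/// Hypothetical newtype, playing the role `WrapCustomOperator` plays in the PR.
/// Trait objects cannot derive `Eq`/`Ord`/`Hash`, so the wrapper implements them by hand.
#[derive(Debug, Clone)]
struct OpHandle(Arc<dyn NamedOperator>);

impl PartialEq for OpHandle {
    fn eq(&self, other: &Self) -> bool {
        self.0.name() == other.0.name()
    }
}

impl Eq for OpHandle {}

impl PartialOrd for OpHandle {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OpHandle {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.name().cmp(other.0.name())
    }
}

impl Hash for OpHandle {
    // Hashing the same key used by `eq` keeps `Hash` and `Eq` consistent.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.name().hash(state);
    }
}

#[derive(Debug)]
struct Arrow;

impl NamedOperator for Arrow {
    fn name(&self) -> &'static str {
        "Arrow"
    }
}

#[derive(Debug)]
struct LongArrow;

impl NamedOperator for LongArrow {
    fn name(&self) -> &'static str {
        "LongArrow"
    }
}

/// Counts distinct operators by collecting them into a `HashSet`.
fn distinct_ops(ops: &[OpHandle]) -> usize {
    ops.iter().cloned().collect::<HashSet<_>>().len()
}
```

The trade-off of this approach is that two different operator implementations returning the same name compare equal, so names have to be unique across everything that gets registered.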

Example Usage:

Here's a trivial usage example that rewrites the "->" operator to string concat (and, just to exercise the mechanism, "->>" to plus and "?" to minus):
```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::Transformed;
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
use datafusion::logical_expr::{
    CustomOperator, Operator, ParseCustomOperator, WrapCustomOperator,
};
use datafusion::prelude::*;
use datafusion::sql::sqlparser::ast::BinaryOperator;

#[derive(Debug)]
enum MyCustomOperator {
    /// Question, like `?`
    Question,
    /// Arrow, like `->`
    Arrow,
    /// Long arrow, like `->>`
    LongArrow,
}

impl std::fmt::Display for MyCustomOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MyCustomOperator::Question => write!(f, "?"),
            MyCustomOperator::Arrow => write!(f, "->"),
            MyCustomOperator::LongArrow => write!(f, "->>"),
        }
    }
}

impl CustomOperator for MyCustomOperator {
    fn binary_signature(
        &self,
        lhs: &DataType,
        rhs: &DataType,
    ) -> Result<(DataType, DataType, DataType)> {
        Ok((lhs.clone(), rhs.clone(), lhs.clone()))
    }

    fn op_to_sql(&self) -> Result<BinaryOperator> {
        match self {
            MyCustomOperator::Question => Ok(BinaryOperator::Question),
            MyCustomOperator::Arrow => Ok(BinaryOperator::Arrow),
            MyCustomOperator::LongArrow => Ok(BinaryOperator::LongArrow),
        }
    }

    fn name(&self) -> &'static str {
        match self {
            MyCustomOperator::Question => "Question",
            MyCustomOperator::Arrow => "Arrow",
            MyCustomOperator::LongArrow => "LongArrow",
        }
    }
}

impl TryFrom<&str> for MyCustomOperator {
    type Error = ();

    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
        match value {
            "Question" => Ok(MyCustomOperator::Question),
            "Arrow" => Ok(MyCustomOperator::Arrow),
            "LongArrow" => Ok(MyCustomOperator::LongArrow),
            _ => Err(()),
        }
    }
}

#[derive(Debug)]
struct CustomOperatorParser;

impl ParseCustomOperator for CustomOperatorParser {
    fn name(&self) -> &str {
        "PostgresParseCustomOperator"
    }

    fn op_from_ast(&self, op: &BinaryOperator) -> Result<Option<Operator>> {
        match op {
            BinaryOperator::Question => Ok(Some(MyCustomOperator::Question.into())),
            BinaryOperator::Arrow => Ok(Some(MyCustomOperator::Arrow.into())),
            BinaryOperator::LongArrow => Ok(Some(MyCustomOperator::LongArrow.into())),
            _ => Ok(None),
        }
    }

    fn op_from_name(&self, raw_op: &str) -> Result<Option<Operator>> {
        if let Ok(op) = MyCustomOperator::try_from(raw_op) {
            Ok(Some(op.into()))
        } else {
            Ok(None)
        }
    }
}

impl FunctionRewrite for CustomOperatorParser {
    fn name(&self) -> &str {
        "PostgresParseCustomOperator"
    }

    fn rewrite(
        &self,
        expr: Expr,
        _schema: &DFSchema,
        _config: &ConfigOptions,
    ) -> Result<Transformed<Expr>> {
        if let Expr::BinaryExpr(bin_expr) = &expr {
            if let Operator::Custom(WrapCustomOperator(op)) = &bin_expr.op {
                if let Ok(pg_op) = MyCustomOperator::try_from(op.name()) {
                    // return BinaryExpr with a different operator
                    let mut bin_expr = bin_expr.clone();
                    bin_expr.op = match pg_op {
                        MyCustomOperator::Arrow => Operator::StringConcat,
                        MyCustomOperator::LongArrow => Operator::Plus,
                        MyCustomOperator::Question => Operator::Minus,
                    };
                    return Ok(Transformed::yes(Expr::BinaryExpr(bin_expr)));
                }
            }
        }
        Ok(Transformed::no(expr))
    }
}

async fn run(sql: &str) -> Result<()> {
    let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
    let mut ctx = SessionContext::new_with_config(config);
    ctx.register_function_rewrite(Arc::new(CustomOperatorParser))?;
    ctx.register_parse_custom_operator(Arc::new(CustomOperatorParser))?;
    let df = ctx.sql(sql).await?;
    df.show().await
}

#[tokio::main]
async fn main() {
    run("select 'foo'->'bar';").await.unwrap();
    run("select 1->>2;").await.unwrap();
    run("select 1 ? 2;").await.unwrap();
}
```
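
In the example, `op_from_name` returns `Ok(None)` for names it does not recognise, which suggests a lookup that tries each registered parser in turn. The sketch below shows, using only the standard library, how such a name-based lookup could work when decoding an operator from its serialized name (the kind of thing `from_proto_binary_op` would need). The `Op` enum, `OpParser` trait and `resolve` function are hypothetical simplifications, not DataFusion APIs.

```rust
use std::sync::Arc;

/// Hypothetical, heavily simplified operator type: two built-ins plus a custom slot.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Plus,
    Minus,
    Custom(String),
}

/// Hypothetical simplification of `ParseCustomOperator`: only the name lookup.
trait OpParser {
    /// Returns `Some` if this parser recognises `name`, or `None` to let the next one try.
    fn op_from_name(&self, name: &str) -> Option<Op>;
}

/// Recognises the three operator names used in the example above.
struct PgParser;

impl OpParser for PgParser {
    fn op_from_name(&self, name: &str) -> Option<Op> {
        match name {
            "Question" | "Arrow" | "LongArrow" => Some(Op::Custom(name.to_string())),
            _ => None,
        }
    }
}

/// Builds the parser list a registry might hold.
fn default_parsers() -> Vec<Arc<dyn OpParser>> {
    vec![Arc::new(PgParser) as Arc<dyn OpParser>]
}

/// Resolves a serialized operator name: built-ins first, then each registered parser.
fn resolve(name: &str, parsers: &[Arc<dyn OpParser>]) -> Result<Op, String> {
    match name {
        "Plus" => return Ok(Op::Plus),
        "Minus" => return Ok(Op::Minus),
        _ => {}
    }
    parsers
        .iter()
        // The first parser that returns `Some` wins.
        .find_map(|parser| parser.op_from_name(name))
        .ok_or_else(|| format!("unsupported operator: {name}"))
}
```

Because the lookup needs access to the registered parsers, a decode function with no registry argument cannot resolve custom names, which is consistent with the to-do item about the `from_proto_binary_op` signature.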

samuelcolvin • Jun 26 '24 23:06