datafusion
Custom operator support
I think this is ready for review.
~~WIP to decide if we progress cc @alamb.~~
~~This is based on #11132 since I need Operator to be non-copyable. If we decide to progress I can rebase this, or cherry pick and create a new branch.~~
~~You can see this without the churn of #11132 at https://github.com/samuelcolvin/datafusion/pull/1.~~
High-level question
Is all the effort here worth it, or should we just add the 20 or so extra operators from the SQL library to `Operator`?
Advantages of this route:
- more flexibility in how operators behave: what their signatures are, precedence, negation etc.
- easier to use custom operators, or new operators added to the SQL library, without waiting for DataFusion to support them
Disadvantages:
- ~~a lot of faff here~~ it's not that bad
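To make that trade-off concrete, here is a minimal sketch in plain Rust. It is not DataFusion's real code: `Op`, `ExtOp`, `Ty` and `Arrow` are hypothetical names. Built-in operators live in a closed enum, so adding one means changing the core crate and every exhaustive `match`; a single open `Custom(Arc<dyn ExtOp>)` variant lets a downstream crate supply its own symbol and typing rule instead.

```rust
use std::fmt;
use std::sync::Arc;

/// Hypothetical stand-in for a data type.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    Int,
    Utf8,
}

/// Hypothetical extension point: downstream crates implement this trait
/// to describe operators the core crate does not know about.
trait ExtOp: fmt::Debug + Send + Sync {
    /// Symbol used when displaying the operator.
    fn symbol(&self) -> &'static str;
    /// Result type for the given argument types, or `None` if unsupported.
    fn return_type(&self, lhs: &Ty, rhs: &Ty) -> Option<Ty>;
}

/// Hypothetical operator enum: a few built-ins plus one open variant.
#[derive(Debug, Clone)]
enum Op {
    Plus,
    StringConcat,
    Custom(Arc<dyn ExtOp>),
}

impl Op {
    fn symbol(&self) -> &'static str {
        match self {
            Op::Plus => "+",
            Op::StringConcat => "||",
            // The core crate never needs to know which custom operator this is.
            Op::Custom(op) => op.symbol(),
        }
    }

    fn return_type(&self, lhs: &Ty, rhs: &Ty) -> Option<Ty> {
        match self {
            Op::Plus if *lhs == Ty::Int && *rhs == Ty::Int => Some(Ty::Int),
            Op::StringConcat if *lhs == Ty::Utf8 && *rhs == Ty::Utf8 => Some(Ty::Utf8),
            Op::Plus | Op::StringConcat => None,
            Op::Custom(op) => op.return_type(lhs, rhs),
        }
    }
}

/// A downstream operator, defined without touching `Op`.
#[derive(Debug)]
struct Arrow;

impl ExtOp for Arrow {
    fn symbol(&self) -> &'static str {
        "->"
    }
    fn return_type(&self, lhs: &Ty, _rhs: &Ty) -> Option<Ty> {
        // Result takes the left-hand type, as in the PR's example signature.
        Some(lhs.clone())
    }
}

fn main() {
    let ops = vec![Op::Plus, Op::StringConcat, Op::Custom(Arc::new(Arrow))];
    for op in &ops {
        println!("{} => {:?}", op.symbol(), op.return_type(&Ty::Utf8, &Ty::Utf8));
    }
}
```

Dispatching through `Op::Custom` leaves the built-in `match` arms untouched when a downstream operator is added, which is the flexibility the advantages list refers to; the cost is the extra trait and wrapper machinery.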
To do
- [x] How should `ParseCustomOperator` be passed into the SQL parser? It definitely shouldn't be as it is now; perhaps we should have an equivalent of `register_function_rewrite`, e.g. `register_custom_operators`? DONE
- [x] ~~Is the hack with `WrapCustomOperator` necessary and acceptable? Maybe someone else's Rust-fu could avoid this?~~ I think what we have is good.
- [x] ~~Should `CustomOperator` provide a `get_udf` method, so that we rewrite to that function automatically rather than relying on `register_function_rewrite`?~~ I don't think so; the current rewrite logic is more powerful.
- [x] What should we do about `from_proto_binary_op`? We can't keep the same signature and support custom operators; this might be easy to fix. Done by adding the register methods to `FunctionRegistry`.
- [x] ~~Is it okay to rely on `name()` of the operator to implement equality, ordering and hashing?~~ I think it's good.
- [x] Needs tests: basic tests are done, LMK what more is needed.
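The question about relying on `name()` comes up because an `Arc<dyn Trait>` cannot derive `PartialEq`, `Ord` or `Hash`, while an operator enum that holds one usually needs all three. Here is a minimal sketch of the name-based approach in plain Rust; `NamedOp`, `Wrapped`, `Question` and `Arrow` are hypothetical names, not the PR's actual `CustomOperator`/`WrapCustomOperator` definitions. It encodes one assumption explicitly: two operators with the same name are treated as identical, so names must be unique.

```rust
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashSet};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Hypothetical custom-operator trait; only `name()` matters here.
trait NamedOp: Debug + Send + Sync {
    fn name(&self) -> &'static str;
}

/// Hypothetical wrapper that gives a trait object value semantics by
/// delegating equality, ordering and hashing to `name()`.
#[derive(Debug, Clone)]
struct Wrapped(Arc<dyn NamedOp>);

impl PartialEq for Wrapped {
    fn eq(&self, other: &Self) -> bool {
        self.0.name() == other.0.name()
    }
}

impl Eq for Wrapped {}

impl Hash for Wrapped {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Consistent with `eq`: equal names always produce equal hashes.
        self.0.name().hash(state);
    }
}

impl PartialOrd for Wrapped {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Wrapped {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.name().cmp(other.0.name())
    }
}

#[derive(Debug)]
struct Question;
impl NamedOp for Question {
    fn name(&self) -> &'static str {
        "Question"
    }
}

#[derive(Debug)]
struct Arrow;
impl NamedOp for Arrow {
    fn name(&self) -> &'static str {
        "Arrow"
    }
}

fn main() {
    // Two separate allocations compare equal because their names match.
    let a = Wrapped(Arc::new(Question));
    let b = Wrapped(Arc::new(Question));
    let c = Wrapped(Arc::new(Arrow));
    assert_eq!(a, b);
    assert_ne!(a, c);

    // Deduplicated by hash + eq, and sorted by name.
    let set: HashSet<Wrapped> = [a.clone(), b.clone(), c.clone()].into_iter().collect();
    assert_eq!(set.len(), 2);
    let sorted: BTreeSet<Wrapped> = [a, b, c].into_iter().collect();
    let names: Vec<_> = sorted.iter().map(|w| w.0.name()).collect();
    println!("{:?}", names); // ["Arrow", "Question"]
}
```

The important invariant is that `Hash` and `Ord` look at exactly the same field as `PartialEq`; if they diverged, hash maps and sorted sets of operators would behave inconsistently.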
Example Usage:
Here's a trivial example that rewrites the `->` operator to string concatenation (and, purely for demonstration, `->>` to `+` and `?` to `-`):
```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::Transformed;
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
use datafusion::logical_expr::{
    CustomOperator, Operator, ParseCustomOperator, WrapCustomOperator,
};
use datafusion::prelude::*;
use datafusion::sql::sqlparser::ast::BinaryOperator;

/// The operators this example adds on top of DataFusion's built-in ones.
#[derive(Debug)]
enum MyCustomOperator {
    /// Question, like `?`
    Question,
    /// Arrow, like `->`
    Arrow,
    /// Long arrow, like `->>`
    LongArrow,
}

impl std::fmt::Display for MyCustomOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MyCustomOperator::Question => write!(f, "?"),
            MyCustomOperator::Arrow => write!(f, "->"),
            MyCustomOperator::LongArrow => write!(f, "->>"),
        }
    }
}

impl CustomOperator for MyCustomOperator {
    // Returns a (lhs, rhs, result) type triple: both inputs are accepted
    // unchanged and the result takes the left-hand type.
    fn binary_signature(
        &self,
        lhs: &DataType,
        rhs: &DataType,
    ) -> Result<(DataType, DataType, DataType)> {
        Ok((lhs.clone(), rhs.clone(), lhs.clone()))
    }

    // Maps the operator back to the corresponding sqlparser AST operator.
    fn op_to_sql(&self) -> Result<BinaryOperator> {
        match self {
            MyCustomOperator::Question => Ok(BinaryOperator::Question),
            MyCustomOperator::Arrow => Ok(BinaryOperator::Arrow),
            MyCustomOperator::LongArrow => Ok(BinaryOperator::LongArrow),
        }
    }

    // Stable identifier; `TryFrom<&str>` below parses it back.
    fn name(&self) -> &'static str {
        match self {
            MyCustomOperator::Question => "Question",
            MyCustomOperator::Arrow => "Arrow",
            MyCustomOperator::LongArrow => "LongArrow",
        }
    }
}

// Inverse of `CustomOperator::name`.
impl TryFrom<&str> for MyCustomOperator {
    type Error = ();

    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
        match value {
            "Question" => Ok(MyCustomOperator::Question),
            "Arrow" => Ok(MyCustomOperator::Arrow),
            "LongArrow" => Ok(MyCustomOperator::LongArrow),
            _ => Err(()),
        }
    }
}

/// Parses the custom operators from SQL and rewrites them to built-in ones.
#[derive(Debug)]
struct CustomOperatorParser;

impl ParseCustomOperator for CustomOperatorParser {
    fn name(&self) -> &str {
        "PostgresParseCustomOperator"
    }

    // Maps a SQL AST operator to a custom operator; `None` means "not handled here".
    fn op_from_ast(&self, op: &BinaryOperator) -> Result<Option<Operator>> {
        match op {
            BinaryOperator::Question => Ok(Some(MyCustomOperator::Question.into())),
            BinaryOperator::Arrow => Ok(Some(MyCustomOperator::Arrow.into())),
            BinaryOperator::LongArrow => Ok(Some(MyCustomOperator::LongArrow.into())),
            _ => Ok(None),
        }
    }

    // Rebuilds a custom operator from its `name()`.
    fn op_from_name(&self, raw_op: &str) -> Result<Option<Operator>> {
        if let Ok(op) = MyCustomOperator::try_from(raw_op) {
            Ok(Some(op.into()))
        } else {
            Ok(None)
        }
    }
}

impl FunctionRewrite for CustomOperatorParser {
    fn name(&self) -> &str {
        "PostgresParseCustomOperator"
    }

    // Replaces each custom operator with a built-in one so the plan can execute.
    fn rewrite(
        &self,
        expr: Expr,
        _schema: &DFSchema,
        _config: &ConfigOptions,
    ) -> Result<Transformed<Expr>> {
        if let Expr::BinaryExpr(bin_expr) = &expr {
            if let Operator::Custom(WrapCustomOperator(op)) = &bin_expr.op {
                if let Ok(pg_op) = MyCustomOperator::try_from(op.name()) {
                    // return BinaryExpr with a different operator
                    let mut bin_expr = bin_expr.clone();
                    bin_expr.op = match pg_op {
                        MyCustomOperator::Arrow => Operator::StringConcat,
                        MyCustomOperator::LongArrow => Operator::Plus,
                        MyCustomOperator::Question => Operator::Minus,
                    };
                    return Ok(Transformed::yes(Expr::BinaryExpr(bin_expr)));
                }
            }
        }
        Ok(Transformed::no(expr))
    }
}

async fn run(sql: &str) -> Result<()> {
    // The example relies on the postgres SQL dialect to parse `->`, `->>` and `?`.
    let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
    let mut ctx = SessionContext::new_with_config(config);
    // Register the same object both as a rewrite and as an operator parser.
    ctx.register_function_rewrite(Arc::new(CustomOperatorParser))?;
    ctx.register_parse_custom_operator(Arc::new(CustomOperatorParser))?;
    let df = ctx.sql(sql).await?;
    df.show().await
}

#[tokio::main]
async fn main() {
    run("select 'foo'->'bar';").await.unwrap();
    run("select 1->>2;").await.unwrap();
    run("select 1 ? 2;").await.unwrap();
}
```
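`op_from_name` is what lets a custom operator survive serialization: the operator is written out by `name()` and, on the way back in (for example in `from_proto_binary_op`), the registered parsers are asked to rebuild it. The sketch below models that round trip in plain Rust. `Op`, `OpParser`, `PgOps` and `Registry` are hypothetical simplifications, and the "built-ins first, then the first registered parser that returns `Some` wins" lookup order is an assumption, not DataFusion's documented behaviour.

```rust
use std::sync::Arc;

/// Hypothetical simplified operator: a built-in plus custom ops known by name.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Plus,
    Custom(&'static str),
}

/// Hypothetical counterpart of `ParseCustomOperator::op_from_name`.
trait OpParser: Send + Sync {
    fn op_from_name(&self, raw_op: &str) -> Option<Op>;
}

/// Hypothetical parser that knows the three operators from the example.
struct PgOps;

impl OpParser for PgOps {
    fn op_from_name(&self, raw_op: &str) -> Option<Op> {
        match raw_op {
            "Question" => Some(Op::Custom("Question")),
            "Arrow" => Some(Op::Custom("Arrow")),
            "LongArrow" => Some(Op::Custom("LongArrow")),
            _ => None,
        }
    }
}

/// Hypothetical registry holding the registered parsers.
#[derive(Default)]
struct Registry {
    parsers: Vec<Arc<dyn OpParser>>,
}

impl Registry {
    fn register(&mut self, parser: Arc<dyn OpParser>) {
        self.parsers.push(parser);
    }

    /// Serialize an operator to the string stored in the plan.
    fn to_name(op: &Op) -> String {
        match op {
            Op::Plus => "Plus".to_string(),
            Op::Custom(name) => name.to_string(),
        }
    }

    /// Rebuild an operator: built-ins first, then each registered parser
    /// in registration order; the first `Some` wins.
    fn from_name(&self, name: &str) -> Option<Op> {
        if name == "Plus" {
            return Some(Op::Plus);
        }
        self.parsers.iter().find_map(|p| p.op_from_name(name))
    }
}

fn main() {
    let mut registry = Registry::default();
    registry.register(Arc::new(PgOps));

    for op in [Op::Plus, Op::Custom("Arrow")] {
        let name = Registry::to_name(&op);
        assert_eq!(registry.from_name(&name), Some(op));
    }
    // Unknown names are reported as missing rather than silently mapped.
    assert_eq!(registry.from_name("Spaceship"), None);
    println!("round trip ok");
}
```

This is why the signature of `from_proto_binary_op` has to change: without access to the registered parsers, a bare name like `"Arrow"` cannot be turned back into an operator.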