
WIP: experiment with SMJ last buffered batch

Open · comphead opened this issue 1 year ago

Which issue does this PR close?

Related to #11555

Closes #.

Rationale for this change

Experiment with an approach to identify the last buffered batch for a given streamed-row join key.

What changes are included in this PR?

Are these changes tested?

Are there any user-facing changes?

comphead, Aug 20 '24

To create a reproducer, first run the following test in debug mode:

#[tokio::test]
async fn test_anti_join_1k_filtered() {
    // NLJ vs HJ gives wrong result
    // Tracked in https://github.com/apache/datafusion/issues/11537
    // 1000 randomized runs; the `true` argument enables debug mode,
    // which dumps the input data to local disk.
    for _ in 0..1000 {
        JoinFuzzTestCase::new(
            make_staggered_batches(1000),
            make_staggered_batches(1000),
            JoinType::LeftAnti,
            Some(Box::new(col_lt_col_filter)),
        )
        .run_test(&[JoinTestType::HjSmj], true)
        .await;
    }
}

The test dumps the data to local disk, for example to fuzz_test_debug/batch_size_7.

The test below reproduces the case that step 1 dumps (just set the paths):

#[tokio::test]
async fn test1() {
    let left: Vec<RecordBatch> = JoinFuzzTestCase::load_partitioned_batches_from_parquet(
        "fuzz_test_debug/batch_size_7/input1",
    )
    .await
    .unwrap();

    let right: Vec<RecordBatch> =
        JoinFuzzTestCase::load_partitioned_batches_from_parquet(
            "fuzz_test_debug/batch_size_7/input2",
        )
        .await
        .unwrap();

    JoinFuzzTestCase::new(
        left,
        right,
        JoinType::LeftAnti,
        Some(Box::new(col_lt_col_filter)),
    )
    .run_test(&[JoinTestType::HjSmj], false)
    .await;
}

comphead, Aug 20 '24

@korowa @viirya, could you help me understand a scenario with ranges?

If there is a left streamed row with join key 1, then on the right side we get joined buffered batches, where range shows which indices share the same join key.

For example

Streamed data        Buffered data
[1]               -> [0, 1, 1], [1, 1, 2]

The ranges should be [1..3] and [0..2].

What I see now is that, in some edge cases, when freeze_streamed is called I get joined buffered data that doesn't match the join key.

For example, [1..3] for join key 1 and then [0..1] for join key 2, which looks weird to me and seems unexpected. WDYT?
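A minimal sketch of the ranges expected in this example, assuming each buffered batch is modeled as a sorted Vec<i32> of raw join-key values. The helper key_ranges_per_batch is hypothetical and is not DataFusion's BufferedBatch logic; it only computes, per batch, the half-open row range holding a given key.

```rust
use std::ops::Range;

/// For each sorted batch of join keys, return the half-open row range whose
/// keys equal `key`. Batches that do not contain the key are skipped.
/// Hypothetical helper for illustration only.
fn key_ranges_per_batch(batches: &[Vec<i32>], key: i32) -> Vec<(usize, Range<usize>)> {
    batches
        .iter()
        .enumerate()
        .filter_map(|(batch_idx, keys)| {
            // Keys are sorted, so rows equal to `key` form one contiguous run:
            // `start` is the first row >= key, `end` the first row > key.
            let start = keys.partition_point(|k| *k < key);
            let end = keys.partition_point(|k| *k <= key);
            (start < end).then(|| (batch_idx, start..end))
        })
        .collect()
}
```

For batches [0, 1, 1] and [1, 1, 2] and key 1, it returns [(0, 1..3), (1, 0..2)], i.e. the ranges [1..3] and [0..2] from the example.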

comphead, Aug 29 '24


I don't quite understand the question.

You have [0, 1, 1] as buffered indices for the same streamed row? Why do you have the same buffered row id 1 twice?

viirya, Aug 29 '24


Thanks @viirya, these are not indices, it is raw data. Let me rephrase.

If I have a left table

a b
10 20

and a right table

a b
5 20
10 20
10 21
10 21
10 22
15 22

The join key is column a, and the filter is on column b.

In freeze_streamed, I observe the right table arriving as 3 batches:

  • Batch 1: join_array [10], range 1..3. This is correct, as rows 1 and 2 have join key 10.
  • Batch 2: join_array [10], range 0..2. This is correct, as rows 0 and 1 have join key 10.
  • Batch 3: join_array [15], range 0..1. This is weird: why is this batch associated?

comphead, Aug 29 '24

#[tokio::test]
async fn test_ranges() {
    let left: Vec<RecordBatch> = make_staggered_batches(1);

    let left = vec![
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![1])),
                Arc::new(Int32Array::from(vec![10])),
                Arc::new(Int32Array::from(vec![10])),
                Arc::new(Int32Array::from(vec![1000])),
            ],
        ).unwrap()
    ];

    let right = vec![
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![0, 1, 1, 2])),
                Arc::new(Int32Array::from(vec![0, 10, 11, 20])),
                Arc::new(Int32Array::from(vec![0, 1100, 0, 2100])),
                Arc::new(Int32Array::from(vec![0, 11000, 0, 21000])),
            ],
        ).unwrap(),
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![2, 2])),
                Arc::new(Int32Array::from(vec![20, 21])),
                Arc::new(Int32Array::from(vec![2101, 0])),
                Arc::new(Int32Array::from(vec![21001, 0])),
            ],
        ).unwrap(),
    ];

    JoinFuzzTestCase::new(
        left,
        right,
        JoinType::LeftAnti,
        Some(Box::new(col_lt_col_filter)),
    )
    .run_test(&[JoinTestType::HjSmj], false)
    .await;
}

If you debug freeze_streamed, you can see that one of the buffered data batches has range 0..1, but for another join key. Do you think that is correct? We probably need to compare the join array of the first batch with those of the subsequent batches.
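The suggested check can be sketched as follows, under simplifying assumptions: Buffered is a hypothetical stand-in for BufferedBatch holding sorted join keys plus the range, and every range is non-empty. The function reports the first batch whose range points at a different key than the first batch's range. The sample data uses the three-batch example from the earlier comment, with the split [5, 10, 10], [10, 10], [15] inferred from the reported ranges 1..3, 0..2 and 0..1.

```rust
use std::ops::Range;

/// Hypothetical stand-in for a buffered batch: sorted join keys plus the
/// range of rows that share the join key currently being processed.
struct Buffered<K> {
    join_keys: Vec<K>,
    range: Range<usize>,
}

/// Index of the first buffered batch whose `range` points at a different
/// join key than the first batch's `range` does. Assumes non-empty ranges.
fn first_batch_with_other_key<K: PartialEq>(batches: &[Buffered<K>]) -> Option<usize> {
    let first = batches.first()?;
    let first_key = &first.join_keys[first.range.start];
    batches
        .iter()
        .position(|b| &b.join_keys[b.range.start] != first_key)
}
```

On the three-batch example it returns Some(2), the batch with key 15. One caveat of this sketch: on the freeze_streamed debug output for test_ranges, both ranges point at join key (2, 20), so this check passes there even though neither matches the streamed key (1, 10); comparing against the streamed row's key would be needed to flag that case.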

comphead, Aug 29 '24


Regarding the 3-batch example above: how do you cut the 3 batches among the 6 buffered rows?

viirya, Aug 29 '24


I believe it depends on batch_size and output_size. From what I have observed, the buffered batch of 6 rows can be processed differently: 3 + 1 + 1 + 1, 1 + 1 + 1 + 1 + 1 + 1, or a single batch of 6 rows.

I think @korowa mentioned this here:

filtered anti join should return only the records for which buffered-side scanning is completed (as freeze_streamed may be called in the middle of buffered-data scanning, due to output batch size), and for which there were no true filters (from p.1). So maybe we should split filter evaluation and output emission in freeze_streamed (since the filters should be checked for all matched indices, but at the same time, the current streamed index can be filtered out of the output because it has further buffered batches to be joined with)?
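A minimal sketch of the proposed split, assuming the buffered rows matching one streamed row arrive in several chunks (one per freeze_streamed call). AntiJoinRowState is a hypothetical type, not part of DataFusion: filter results are accumulated on every call, and the emit decision is made only once buffered-side scanning for the key is known to be complete.

```rust
/// Hypothetical per-streamed-row state for a filtered LeftAnti join.
struct AntiJoinRowState {
    /// Becomes true once any buffered row with the same key passes the filter.
    any_filter_true: bool,
}

impl AntiJoinRowState {
    fn new() -> Self {
        Self { any_filter_true: false }
    }

    /// Filter evaluation for one chunk of matched buffered rows. This never
    /// emits output, because more chunks for the same key may still follow.
    fn evaluate_chunk(&mut self, filter_results: &[bool]) {
        self.any_filter_true |= filter_results.iter().any(|&passed| passed);
    }

    /// Output emission, to be called only after buffered-side scanning for the
    /// key has finished. Returns true if the streamed row is part of the output.
    fn finish(self) -> bool {
        !self.any_filter_true
    }
}
```

With chunks [false, false] and then [false, true], deciding after the first chunk would wrongly emit the row; deferring the decision to finish gives the correct anti-join result (not emitted).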

For simplicity, let's consider the test in https://github.com/apache/datafusion/pull/12082#issuecomment-2319361185.

When I debug freeze_streamed, I see the buffered data arriving as:

[datafusion/physical-plan/src/joins/sort_merge_join.rs:1500:25] &self.buffered_data.batches = [
    BufferedBatch {
        batch: Some(
            RecordBatch {
                columns: [
                    PrimitiveArray<Int32>
                    [
                      0,
                      1,
                      1,
                      2,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      10,
                      11,
                      20,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      1100,
                      0,
                      2100,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      11000,
                      0,
                      21000,
                    ],
                ],
                row_count: 4,
            },
        ),
        range: 3..4,
        join_arrays: [
            PrimitiveArray<Int32>
            [
              0,
              1,
              1,
              2,
            ],
            PrimitiveArray<Int32>
            [
              0,
              10,
              11,
              20,
            ],
        ],
    },
    BufferedBatch {
        batch: Some(
            RecordBatch { 
                columns: [
                    PrimitiveArray<Int32>
                    [
                      2,
                      2,
                    ],
                    PrimitiveArray<Int32>
                    [
                      20,
                      21,
                    ],
                    PrimitiveArray<Int32>
                    [
                      2101,
                      0,
                    ],
                    PrimitiveArray<Int32>
                    [
                      21001,
                      0,
                    ],
                ],
                row_count: 2,
            },
        ),
        range: 0..1,
        join_arrays: [
            PrimitiveArray<Int32>
            [
              2,
              2,
            ],
            PrimitiveArray<Int32>
            [
              20,
              21,
            ],
        ],
    },
]

What are the ranges here? The doc says:

    /// The range in which the rows share the same join key
    pub range: Range<usize>,

But how do range: 3..4 in the first batch and range: 0..1 in the second match the join key at all? They point to non-matched rows.

comphead, Aug 29 '24


The ranges are the row indices of the batch in the BufferedBatch that share the same join key. They are not related to whether rows match or not.
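To illustrate, here is a hypothetical helper that recomputes such a range from sorted join keys, using the two-column keys from the freeze_streamed debug output above. It is a sketch, not DataFusion code: it returns the run of rows, starting at a given row, whose key equals the key at that row.

```rust
use std::ops::Range;

/// Range of rows, starting at `start`, whose join key equals the key at
/// `start`. Assumes `keys` is sorted and `start < keys.len()`.
/// Hypothetical helper mirroring the documented meaning of `range`.
fn same_key_range<K: PartialEq>(keys: &[K], start: usize) -> Range<usize> {
    let key = &keys[start];
    // Count the contiguous run of rows equal to `key`.
    let run_len = keys[start..].iter().take_while(|k| *k == key).count();
    start..start + run_len
}
```

In this sketch it yields 3..4 for the first batch and 0..1 for the second, both the run of key (2, 20), whereas the streamed row's key (1, 10) corresponds to 1..2 in the first batch. So the ranges describe the buffered side's current key group rather than a match with the streamed row.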

viirya, Aug 30 '24


That matches my understanding of these ranges in buffered batches.


@comphead, I've tried your example, and while debugging I see 3 "versions" of buffered data with the following ranges:

0..1 // join key 0

1..3 // join key 1, first right batch
0..2 // join key 1, second right batch

2..3 // join key 2

I can see them both before and after the join_partial call.

At what point in the code are you able to observe 0..1 for key 2?

korowa, Sep 01 '24


I'm running the test from https://github.com/apache/datafusion/pull/12082#issuecomment-2319361185 and debugging the freeze_streamed function. For batch size 2, I see a batch distribution like the one in https://github.com/apache/datafusion/pull/12082#issuecomment-2319492383

There you can see the buffered batch with join arrays:

        join_arrays: [
            PrimitiveArray<Int32>
            [
              2,
              2,
            ],
            PrimitiveArray<Int32>
            [
              20,
              21,
            ],
        ],

which confuses me: I thought only buffered batches that contain the streamed key would be there, but it looks like that's not the case.

I believe we can do the following:

  • get the current join key from streamed_batch.join_arrays using self.streamed_batch.idx
  • find all batches in buffered_data that contain the join key from step 1
  • if buffered_data.scanning_batch_idx equals the number of batches found in step 2, and that batch's range.end == num_rows, that probably means SMJ has already emitted all the indices from this batch and we are done with this particular key
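The three steps can be sketched roughly as below. Everything here is a hypothetical simplification (a single i32 join key, a Buffered struct standing in for BufferedBatch), and step 3 is read as "the scanning batch is the last batch that contains the key", which is only one possible interpretation of the proposal.

```rust
use std::ops::Range;

/// Hypothetical simplified buffered batch (not DataFusion's BufferedBatch).
struct Buffered {
    join_keys: Vec<i32>,
    range: Range<usize>,
}

/// One reading of the proposed three-step check.
fn key_fully_emitted(
    streamed_keys: &[i32],
    streamed_idx: usize,
    batches: &[Buffered],
    scanning_batch_idx: usize,
) -> bool {
    // Step 1: the current streamed join key.
    let streamed_key = streamed_keys[streamed_idx];
    // Step 2: indices of the buffered batches that contain that key.
    let containing: Vec<usize> = batches
        .iter()
        .enumerate()
        .filter(|(_, b)| b.join_keys.contains(&streamed_key))
        .map(|(i, _)| i)
        .collect();
    // Step 3: scanning the last such batch, and its range reaches the end.
    match containing.last() {
        Some(&last) => {
            let batch = &batches[last];
            scanning_batch_idx == last && batch.range.end == batch.join_keys.len()
        }
        // No buffered batch holds the key, so nothing is left to emit for it.
        None => true,
    }
}
```

Read literally, the range.end == num_rows condition only holds when the key's run ends exactly at a batch boundary; a last batch such as [10, 10, 15] would never satisfy it.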

@viirya @korowa, do you think this would be enough to identify that all rows have been processed for the given join key?

comphead, Sep 02 '24

@comphead I've finally got it: in this case SMJ is trying to produce output for each (streamed, buffered) join key pair. I guess that's how SMJ state management works now: the streamed-side index won't move until all buffered-side data has been processed, since that's required to identify the current ordering.


I'd say that normally you don't need to compare join keys, and you should rely on buffered_data.scanning_finished() (or self.current_ordering == Less), but in your example both of these conditions are either not working or not intended to work (I'm not sure which of the two statements is correct).
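For context, a textbook sort-merge join step compares the current streamed key with the current buffered key and advances based on the result. The sketch below uses std::cmp::Ordering and a hypothetical Step enum; it is a simplified model, not DataFusion's state machine. It shows why Less is the natural "done with this streamed key" signal: it appears only once the buffered side has moved past the streamed key.

```rust
use std::cmp::Ordering;

/// What a simplified merge join does next (hypothetical model).
#[derive(Debug, PartialEq)]
enum Step {
    /// streamed key < buffered key: this streamed key is done, advance streamed.
    AdvanceStreamed,
    /// Keys are equal: join the streamed row with the buffered key group.
    JoinEqualKeys,
    /// streamed key > buffered key: skip buffered rows, advance buffered.
    AdvanceBuffered,
}

fn next_step(streamed_key: i32, buffered_key: i32) -> Step {
    match streamed_key.cmp(&buffered_key) {
        Ordering::Less => Step::AdvanceStreamed,
        Ordering::Equal => Step::JoinEqualKeys,
        Ordering::Greater => Step::AdvanceBuffered,
    }
}
```

For a streamed key 10 and buffered key groups 5, 10, 15, the steps are AdvanceBuffered, JoinEqualKeys, then AdvanceStreamed.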

I also hope to start spending some time on SMJ due to https://github.com/apache/datafusion/issues/12359

korowa, Sep 08 '24


Thanks @korowa. I have been experimenting a lot with different parts of SMJ, and it showed that buffered_data.scanning_finished() is not working. As for self.current_ordering == Less, we cannot rely on it in freeze_streamed, since freeze_streamed is called only if self.current_ordering == Equal. Now I'm trying to work out whether it's possible to predict that the ordering is going to change from Equal to Less.
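One way to frame the prediction, sketched under simplifying assumptions (sorted single-column i32 keys, a hypothetical Buffered struct, relevant batches already loaded): the ordering for the current streamed key turns from Equal to Less exactly when the next buffered key after the current range is greater. The function peeks at that key, looking into later batches if the range ends at a batch boundary, and returns None when no further buffered row is in memory, in which case nothing can be predicted.

```rust
use std::cmp::Ordering;
use std::ops::Range;

/// Hypothetical simplified buffered batch (not DataFusion's BufferedBatch).
struct Buffered {
    join_keys: Vec<i32>,
    range: Range<usize>,
}

/// Compare the streamed key with the first buffered key after the current
/// range. `Some(Less)` predicts the Equal -> Less transition, `Some(Equal)`
/// means more buffered rows share the key, and `None` means no further
/// buffered row is loaded yet.
fn predicted_next_ordering(
    streamed_key: i32,
    batches: &[Buffered],
    scanning_batch_idx: usize,
) -> Option<Ordering> {
    let current = &batches[scanning_batch_idx];
    let next_key = current
        .join_keys
        .get(current.range.end)
        // The range ends at the batch boundary: look at later batches.
        .or_else(|| {
            batches[scanning_batch_idx + 1..]
                .iter()
                .find_map(|b| b.join_keys.first())
        })?;
    Some(streamed_key.cmp(next_key))
}
```

On the three-batch example ([5, 10, 10], [10, 10], [15]) with streamed key 10, scanning batch 0 gives Some(Equal) and scanning batch 1 gives Some(Less).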

And yes, I also tried comparing join arrays, which could give us a clue that everything has been processed, but it might be very expensive.

comphead, Sep 08 '24