Improve efficiency of TextSplitter.split_documents, iterate once
Improve TextSplitter.split_documents: collect page_content and metadata in a single iteration.
Who can review?
Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:
@eyurtsev When documents is a generator that can only be iterated once, this change is a huge help. Without it, a silent issue occurs: metadata ends up empty for every document, because the second pass over the generator finds it already exhausted (demonstrated below, after the candidate implementations). So we widen the argument type from List[Document] to Union[Iterable[Document], Sequence[Document]].
Generate ten million mock documents; here are the first few:
from dataclasses import dataclass

@dataclass
class Document:
    page_content: str
    metadata: str

docs = []
for i in range(int(1e7)):
    doc = Document(page_content=f'text_{i}', metadata=f'metadata_{i}')
    docs.append(doc)

docs[:3]
[Document(page_content='text_0', metadata='metadata_0'),
Document(page_content='text_1', metadata='metadata_1'),
Document(page_content='text_2', metadata='metadata_2')]
Four proposed methods (the last two for completeness):
def function_1(documents):
    # Current implementation: two separate passes over `documents`.
    texts = [doc.page_content for doc in documents]
    metadatas = [doc.metadata for doc in documents]
    return texts, metadatas

def function_2(documents):
    # Proposed implementation: collect both fields in a single pass.
    texts, metadatas = [], []
    for doc in documents:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)
    return texts, metadatas

def function_3(documents):
    # Single pass via zip-transpose; note it returns tuples, not lists.
    return tuple(zip(*([d.page_content, d.metadata] for d in documents)))

def function_4(documents):
    # Single pass that builds (text, metadata) pairs, then zip-transposes.
    list_of_pairs = []
    for d in documents:
        list_of_pairs.append([d.page_content, d.metadata])
    return tuple(zip(*list_of_pairs))
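Before the timings, a quick demonstration of the silent failure described above (an illustrative session reusing the mock Document class and the functions defined here): function_1 iterates twice, so a one-shot generator comes back empty on the second pass, while the single-pass function_2 is unaffected.

doc_gen = (Document(page_content=f'text_{i}', metadata=f'metadata_{i}') for i in range(3))
function_1(doc_gen)
(['text_0', 'text_1', 'text_2'], [])

doc_gen = (Document(page_content=f'text_{i}', metadata=f'metadata_{i}') for i in range(3))
function_2(doc_gen)
(['text_0', 'text_1', 'text_2'], ['metadata_0', 'metadata_1', 'metadata_2'])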
%%timeit
function_1(docs)
331 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
function_2(docs)
319 ms ± 770 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
function_3(docs)
1.09 s ± 9.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
function_4(docs)
1.08 s ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Comparison across various document-list lengths
import matplotlib.pyplot as plt
import numpy as np
import timeit

list_lengths = [1, 10, 100, 1000, 10000, 100000, 1000000, 10000000]
timings = {'original_function': [], 'proposed_function': [], 'alternate_function': [], 'alternate_function_2': []}
functions = {'original_function': function_1, 'proposed_function': function_2, 'alternate_function': function_3, 'alternate_function_2': function_4}

for length in list_lengths:
    documents = docs[:length]
    for function_name, function in functions.items():
        start_time = timeit.default_timer()
        function(documents)
        end_time = timeit.default_timer()
        timings[function_name].append(end_time - start_time)

for function_name, times in timings.items():
    plt.plot(list_lengths, times, label=function_name)
plt.xlabel('List Length')
plt.ylabel('Time (seconds)')
plt.legend()
plt.loglog()
plt.show()
@eyurtsev Here is the summary of changes for this review.
- Add tests for split_documents.
- Add a single-iteration implementation.
- Change the argument type from List[Document] to Iterable[Document] (a sketch of the resulting method follows this list).
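For reference, a minimal sketch of what split_documents looks like after this change; it mirrors the single-pass function_2 above and assumes create_documents keeps its existing signature (see the diff for the exact merged code):

def split_documents(self, documents: Iterable[Document]) -> List[Document]:
    # One pass over `documents`, so a one-shot iterable (e.g. a
    # generator) is never exhausted before metadata is collected.
    texts, metadatas = [], []
    for doc in documents:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)
    return self.create_documents(texts, metadatas=metadatas)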
@startakovsky Apologies for making you benchmark here. I didn't read the summary carefully and thought the change was to speed up the code. The change was actually to avoid potentially exhausting a generator, in which case benchmarking considerations are irrelevant. But thank you for taking the time to do it!
Addressing lint issues here: https://github.com/hwchase17/langchain/pull/5111
@eyurtsev Good to know anyway whether there is a significant performance dip. Thanks for merging this!