Aggregator and generator inputs should be Iterators, not lists
Description
Here is a sample generator from the LLM tutorial:
from typing import List

class Dialog(BaseModel):
    id: int
    text: str

def text_block(id: List[int], sender: List[str], text: List[str]) -> Dialog:
    columns = zip(text, sender)
    conversation = ""
    for text, sender in columns:
        conversation = "\n ".join([conversation, f"{sender}: {text}"])
    yield Dialog(id=id[0], text=conversation)

chain = DataChain.from_csv("gs://datachain-demo/chatbot-csv/").agg(text_block, output={"dialog": Dialog}, partition_by="id").save()
This syntax has a number of issues:
- Input column names are implicitly reused as list names. This is awkward because the argument "sender" is a list and would be better named "senders".
- Passing lists from SQL limits out-of-memory operations, because every list must be fully materialized before the function is called.
- The aggregation key, when passed as a parameter, does not have to be a list because it is identical in every record of the partition.
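To make the memory point concrete, here is a minimal plain-Python sketch, not DataChain code. `fake_column` and `preview` are hypothetical names invented for this illustration: the first stands in for a column streamed from the database, the second for an aggregator that only needs the first few rows of a partition. With an iterator input, only the rows actually consumed are ever produced.

```python
from itertools import islice
from typing import Iterator

# Counts how many rows the fake source has produced so far.
stats = {"rows_read": 0}


def fake_column(n: int) -> Iterator[str]:
    """Hypothetical stand-in for a column streamed lazily from storage."""
    for i in range(n):
        stats["rows_read"] += 1
        yield f"message {i}"


def preview(text: Iterator[str], limit: int) -> str:
    """Hypothetical aggregator that needs only the first `limit` rows."""
    return "\n".join(islice(text, limit))


# The source could produce a million rows, but only three are pulled.
result = preview(fake_column(1_000_000), limit=3)
print(result)
print(stats["rows_read"])
```

With a `List[str]` parameter, all one million values would have to be built before `preview` is even called.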
Here is a proposed updated signature:
def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> dict[str, str]:
    columns = zip(text, sender)
    conversation = ""
    for text, sender in columns:
        conversation = "\n ".join([conversation, f"{sender}: {text}"])
    # id is now a scalar key, so it is used directly instead of id[0]
    yield {"id": id, "conversation": conversation}

chain = DataChain.from_csv('gs://datachain-demo/chatbot-csv/').agg(text_block, partition_by='id').save()
That's a great idea!
It seems you are also proposing to return a dict and use its keys as the names of the output signals. I recommend creating a separate issue for that: the two changes are unrelated, and a dict as an output might be a challenging issue since we already have a built-in dict.
Without this, the API should look like the one below. @volkfox please correct me if I'm missing anything.
def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> tuple[int, str]:
    columns = zip(text, sender)
    conversation = ""
    for text, sender in columns:
        conversation = "\n ".join([conversation, f"{sender}: {text}"])
    yield id, conversation

chain = (
    DataChain.from_csv('gs://datachain-demo/chatbot-csv/')
    .agg(res=text_block, partition_by='id', output={"id": int, "conversation": str})
    .save()
)
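For discussion, here is a rough plain-Python sketch of how a caller could feed such a function: a scalar key plus one iterator per column. `run_agg` is a hypothetical helper written only for this illustration and says nothing about how DataChain is or would be implemented. It assumes the rows arrive already sorted by the partition key, and the `text_block` here is a simplified variant of the one above that annotates its return type as an iterator because it yields.

```python
from itertools import groupby, tee
from operator import itemgetter
from typing import Callable, Iterable, Iterator


def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> Iterator[tuple[int, str]]:
    # Consume both columns in lockstep and emit one record per partition.
    lines = [f"{s}: {t}" for s, t in zip(sender, text)]
    yield id, "\n".join(lines)


def run_agg(rows: Iterable[dict], func: Callable, partition_by: str,
            columns: list[str]) -> Iterator[tuple]:
    """Hypothetical dispatcher; assumes rows are sorted by partition_by."""
    for key, group in groupby(rows, key=itemgetter(partition_by)):
        # tee gives every column its own iterator over the same group.
        streams = tee(group, len(columns))
        # map + itemgetter binds each column name immediately.
        kwargs = {col: map(itemgetter(col), stream)
                  for col, stream in zip(columns, streams)}
        # The key is passed once as a scalar, not as a list.
        yield from func(key, **kwargs)


rows = [
    {"id": 1, "sender": "user", "text": "hi"},
    {"id": 1, "sender": "bot", "text": "hello"},
    {"id": 2, "sender": "user", "text": "bye"},
]
results = list(run_agg(rows, text_block, "id", ["sender", "text"]))
print(results)
```

One caveat of this sketch: `tee` buffers whatever one column iterator has consumed ahead of the others, so memory stays low only when the function reads the columns roughly in lockstep, as `zip` does here.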