`insert_dp` for adding additional pipes (similar to `replace_dp` and `remove_dp`)
🚀 The feature
Building on `dataloader2.graph.replace_dp`, add an `insert_dp` function (and possibly an `insert_dpg` function for inserting a sub-graph) to insert DataPipes into an existing template.
Motivation, pitch
I want to be able to add a single or a graph of datapipes without replacing the existing datapipes.
My practical example is turning a DQN into a multiprocessing-friendly async DQN.
A default DQN agent has the following pipeline:
agent = AgentBase(model)
agent = StepFieldSelector(agent,field='state')
agent = SimpleModelRunner(agent,device=device)
agent = ArgMaxer(agent)
agent = EpsilonSelector(agent,min_epsilon=min_epsilon,max_epsilon=max_epsilon,max_steps=max_steps,device=device)
if logger_bases is not None: agent = EpsilonCollector(agent,logger_bases)
agent = ArgMaxer(agent,only_idx=True)
agent = NumpyConverter(agent)
agent = PyPrimativeConverter(agent)
agent = AgentHead(agent)
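To make the graph discussion below concrete, here is a minimal pure-Python model of how a chain like this becomes a nested `{id: (datapipe, sub_graph)}` dict, the same shape `traverse` prints later in this thread. The `Stage` class, the toy `traverse` and the `names` helper are hypothetical stand-ins, not torchdata code; they only follow a `source_datapipe` attribute, and the stage names are labels, not real models.

```python
from typing import Any, Dict, Iterable, Iterator, Tuple

# Hypothetical nested-graph type mirroring the {id: (pipe, sub_graph)} shape
Graph = Dict[int, Tuple[Any, "Graph"]]


class Stage:
    """Hypothetical stand-in for an IterDataPipe that wraps one upstream source."""

    def __init__(self, source_datapipe: Iterable, name: str):
        self.source_datapipe = source_datapipe
        self.name = name

    def __iter__(self) -> Iterator:
        return iter(self.source_datapipe)


def traverse(pipe: Any) -> Graph:
    """Toy traversal: follow `source_datapipe` links and build the nested dict."""
    source = getattr(pipe, "source_datapipe", None)
    sub_graph = traverse(source) if isinstance(source, Stage) else {}
    return {id(pipe): (pipe, sub_graph)}


def names(graph: Graph) -> list:
    """Flatten a single-chain graph into stage names, head first."""
    out = []
    while graph:
        (pipe, graph), = graph.values()
        out.append(pipe.name)
    return out


# A shortened version of the agent chain above (labels only)
agent = Stage([0, 1, 2], "AgentBase")
agent = Stage(agent, "StepFieldSelector")
agent = Stage(agent, "SimpleModelRunner")
agent = Stage(agent, "AgentHead")
print(names(traverse(agent)))
# ['AgentHead', 'SimpleModelRunner', 'StepFieldSelector', 'AgentBase']
```

The head of the pipeline is the single top-level key, and every upstream DataPipe is nested one level deeper, which is why the graph helpers in this thread recurse into `send_graph`.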
I want to make the template (the base DQN agent) capable of syncing a model across spawned processes, so we insert a DataPipe that syncs the model:
agent = AgentBase(model)
agent = StepFieldSelector(agent,field='state')
agent = ModelSubscriber(agent,device=device)  # <- inserted before the `SimpleModelRunner` pipe
agent = SimpleModelRunner(agent,device=device)
agent = ArgMaxer(agent)
agent = EpsilonSelector(agent,min_epsilon=min_epsilon,max_epsilon=max_epsilon,max_steps=max_steps,device=device)
if logger_bases is not None: agent = EpsilonCollector(agent,logger_bases)
agent = ArgMaxer(agent,only_idx=True)
agent = NumpyConverter(agent)
agent = PyPrimativeConverter(agent)
agent = AgentHead(agent)
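Because each stage only reads from its upstream attribute, "insert a pipe before `SimpleModelRunner`" amounts to wrapping that stage's current source and assigning the wrapper back. A minimal sketch on toy stages, assuming the upstream is stored in `source_datapipe`; `Stage`, `Selector`, `Runner`, `Subscriber`, `find_first` and `insert_before` are hypothetical names standing in for the real pipes. Real DataPipes may keep their upstream under other attribute names, which is part of why the graph-based `replace_dp` route discussed in this thread is more general.

```python
from typing import Callable, Iterable, Iterator, Optional


class Stage:
    """Hypothetical pass-through stage that reads from `source_datapipe`."""

    def __init__(self, source_datapipe: Iterable):
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator:
        yield from self.source_datapipe


class Selector(Stage):
    """Stands in for StepFieldSelector (passes items through unchanged)."""


class Runner(Stage):
    """Stands in for SimpleModelRunner: doubles each item."""

    def __iter__(self) -> Iterator:
        for x in self.source_datapipe:
            yield x * 2


class Subscriber(Stage):
    """Stands in for ModelSubscriber: adds 100 so its effect is visible."""

    def __iter__(self) -> Iterator:
        for x in self.source_datapipe:
            yield x + 100


def find_first(head: Stage, cls: type) -> Optional[Stage]:
    """Walk the single-source chain from the head until a `cls` instance is found."""
    node = head
    while isinstance(node, Stage):
        if isinstance(node, cls):
            return node
        node = node.source_datapipe
    return None


def insert_before(head: Stage, target_cls: type, make_pipe: Callable[[Iterable], Stage]) -> Stage:
    """Make the target read from `make_pipe(old_source)` instead of `old_source`."""
    target = find_first(head, target_cls)
    if target is None:
        raise LookupError(f"No {target_cls.__name__} in the chain")
    target.source_datapipe = make_pipe(target.source_datapipe)
    return head


agent = Runner(Selector([1, 2, 3]))
print(list(agent))  # [2, 4, 6]
insert_before(agent, Runner, Subscriber)
print(list(agent))  # [202, 204, 206]
```

The output changes from `[2, 4, 6]` to `[202, 204, 206]` because `Subscriber` now runs between `Selector` and `Runner`, while the head object `agent` is unchanged.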
Alternatives
Option 1
Add `if` statements to (or otherwise modify) the template so that it covers most extensions, similar to the `EpsilonCollector` line above.
Option 2
Modify `replace_dp` to support replacing a DataPipe with a DataPipeGraph, so we would do something like:
graph = traverse(agent,only_datapipe=True)
agent_sub = ModelSubscriber(find_dps(graph,StepFieldSelector)[0],device=device)
agent_sub = SimpleModelRunner(agent_sub,device=device)
replace_dp(graph,find_dps(graph,SimpleModelRunner)[0],agent_sub)
Additional context
This is not thoroughly tested, but I think the implementation below allows inserting either a single DataPipe or an entire isolated DataPipeGraph:
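import torchdata.datapipes as dp
from typing import Type
from warnings import warn
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse
# `_assign_attr` is a private helper; its module path may differ between torchdata versions
from torchdata.dataloader2.graph import _assign_attr, find_dps

# PassThroughIterPipe is a placeholder: the insert code can always identify it
# and knows it is safe to reassign its position to another DataPipe.
class PassThroughIterPipe(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe): self.source_datapipe = source_datapipe
    def __iter__(self): return (o for o in self.source_datapipe)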
def find_dp(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> DataPipe:
    pipes = find_dps(graph,dp_type)
    if len(pipes)==1: return pipes[0]
    elif len(pipes)>1:
        found_ids = set([id(pipe) for pipe in pipes])
        if len(found_ids)>1:
            warn(f"There are {len(pipes)} pipes of type {dp_type}. If this is intended, "
                 "please use `find_dps` directly. Returning first instance.")
        return pipes[0]
    else:
        raise LookupError(f'Unable to find {dp_type} starting at {graph}')
find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+find_dps.__doc__
def _insert_dp(recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe) -> None:
    old_dp_id = id(old_dp)
    for send_id in send_graph:
        if send_id == old_dp_id:
            # Same as replace_dp: point recv_dp at new_dp instead of old_dp
            _assign_attr(recv_dp, old_dp, new_dp, inner_dp=True)
            # Locate the PassThroughIterPipe placeholder inside new_dp
            final_datapipe = find_dp(traverse(new_dp),PassThroughIterPipe)
            # Swap the placeholder for old_dp so the chain is not broken.
            # Not yet tested with a whole sub-graph as new_dp.
            _assign_attr(new_dp, final_datapipe, old_dp, inner_dp=True)
        else:
            send_dp, sub_send_graph = send_graph[send_id]
            _insert_dp(send_dp, sub_send_graph, old_dp, new_dp)
def insert_dp(graph: DataPipeGraph, on_datapipe: DataPipe, insert_datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Given the graph of DataPipe generated by ``traverse`` function and the ``on_datapipe`` DataPipe to be reconnected and
    the new ``insert_datapipe`` DataPipe to be inserted after ``on_datapipe``,
    return the new graph of DataPipe.
    """
    assert len(graph) == 1
    # If `on_datapipe` is the head of the graph, `insert_datapipe` becomes the new head:
    # swap its PassThroughIterPipe placeholder for `on_datapipe` and return its graph
    if id(on_datapipe) in graph:
        placeholder = find_dp(traverse(insert_datapipe, only_datapipe=True), PassThroughIterPipe)
        _assign_attr(insert_datapipe, placeholder, on_datapipe, inner_dp=True)
        return traverse(insert_datapipe, only_datapipe=True)
    final_datapipe = list(graph.values())[0][0]
    for recv_dp, send_graph in graph.values():
        _insert_dp(recv_dp, send_graph, on_datapipe, insert_datapipe)
    return traverse(final_datapipe, only_datapipe=True)
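The lookup rules of `find_dp` above (one match is returned, several distinct matches warn and return the first, no match raises `LookupError`) can be checked without torchdata on a plain nested dict of the same `{id: (pipe, sub_graph)}` shape. `find_dps_toy`, `find_dp_toy` and the classes `A` to `D` are hypothetical; the toy matches on exact type, which is an assumption of this sketch rather than a statement about the real `find_dps`.

```python
import warnings
from typing import Any, Dict, List, Tuple

# Same nested shape as the DataPipeGraph dicts printed in this thread
Graph = Dict[int, Tuple[Any, "Graph"]]


def find_dps_toy(graph: Graph, dp_type: type) -> List[Any]:
    """Depth-first collection of every node whose exact type is `dp_type`."""
    found: List[Any] = []
    for pipe, sub_graph in graph.values():
        if type(pipe) is dp_type:
            found.append(pipe)
        found.extend(find_dps_toy(sub_graph, dp_type))
    return found


def find_dp_toy(graph: Graph, dp_type: type) -> Any:
    """Same decision rules as `find_dp` above, applied to a plain nested dict."""
    pipes = find_dps_toy(graph, dp_type)
    if not pipes:
        raise LookupError(f"Unable to find {dp_type} starting at {graph}")
    if len({id(p) for p in pipes}) > 1:
        # Several *distinct* objects match: warn, then fall back to the first one
        warnings.warn(f"There are {len(pipes)} pipes of type {dp_type}. Returning first instance.")
    return pipes[0]


class A: ...
class B: ...
class C: ...
class D: ...


a1, a2, b, c1, c2 = A(), A(), B(), C(), C()
# B reads from two different A instances (e.g. a zip-like stage): lookup warns
two_sources = {id(b): (b, {id(a1): (a1, {}), id(a2): (a2, {})})}
# Two branches share one A instance (e.g. after a fork): same object, no warning
shared = {id(b): (b, {id(c1): (c1, {id(a1): (a1, {})}), id(c2): (c2, {id(a1): (a1, {})})})}

print(find_dp_toy(shared, A) is a1)  # True
```

The `id` deduplication is what keeps a DataPipe reached through two branches from triggering the "multiple pipes" warning.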
With the test being:
it_pipe = dp.iter.IterableWrapper([1,2,3,4,5,6])
pipe = it_pipe.cycle(count=2)
pipe = pipe.batch(batch_size=2)
new_dp = insert_dp(
traverse(pipe,only_datapipe=True),
find_dp(traverse(pipe,only_datapipe=True),dp.iter.Cycler),
dp.iter.Header(PassThroughIterPipe([]),limit=4)
)
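For intuition about what this test should produce if the splice works, here is the same scenario as a self-contained toy: a source of 1 to 6 cycled twice and batched by 2, with a "take 4" stage inserted after the cycle stage through a placeholder. Every class and helper (`Pipe`, `PassThrough`, `Cycle`, `Take`, `Batch`, `_walk`, `_swap`, `insert_after`) is hypothetical plain Python, not torchdata; the attribute-scanning `_swap` is a simplified take on the role `_assign_attr` plays above.

```python
from itertools import islice
from typing import Any, Iterable, Iterator, List


class Pipe:
    """Hypothetical single-source pipe; subclasses override __iter__."""

    def __init__(self, source_datapipe: Iterable):
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator:
        yield from self.source_datapipe


class PassThrough(Pipe):
    """Placeholder marking where the upstream pipe should be spliced in."""


class Cycle(Pipe):
    def __init__(self, source_datapipe: Iterable, count: int):
        super().__init__(source_datapipe)
        self.count = count

    def __iter__(self) -> Iterator:
        for _ in range(self.count):
            yield from self.source_datapipe


class Take(Pipe):
    def __init__(self, source_datapipe: Iterable, limit: int):
        super().__init__(source_datapipe)
        self.limit = limit

    def __iter__(self) -> Iterator:
        yield from islice(self.source_datapipe, self.limit)


class Batch(Pipe):
    def __init__(self, source_datapipe: Iterable, size: int):
        super().__init__(source_datapipe)
        self.size = size

    def __iter__(self) -> Iterator:
        batch: List[Any] = []
        for x in self.source_datapipe:
            batch.append(x)
            if len(batch) == self.size:
                yield batch
                batch = []
        if batch:  # keep a short final batch
            yield batch


def _walk(head: Any) -> Iterator[Pipe]:
    """Yield every Pipe reachable through `source_datapipe`, head first."""
    while isinstance(head, Pipe):
        yield head
        head = head.source_datapipe


def _swap(obj: Pipe, old: Pipe, new: Any) -> bool:
    """Point every attribute of `obj` that is `old` at `new`; True if any matched."""
    hit = False
    for name, value in list(vars(obj).items()):
        if value is old:
            setattr(obj, name, new)
            hit = True
    return hit


def insert_after(head: Pipe, on: Pipe, new: Pipe) -> Pipe:
    """Splice `new` (built on a PassThrough) between `on` and its consumer."""
    # 1. Inside `new`, replace the placeholder with `on`
    placeholder = next(p for p in _walk(new) if type(p) is PassThrough)
    consumer = next(p for p in _walk(new) if p.source_datapipe is placeholder)
    _swap(consumer, placeholder, on)
    # 2. If `on` was the head, `new` is the new head; otherwise rewire its consumer
    if head is on:
        return new
    for node in _walk(head):
        if _swap(node, on, new):
            break
    return head


source = Pipe([1, 2, 3, 4, 5, 6])
cycled = Cycle(source, count=2)
pipe = Batch(cycled, size=2)
print(list(pipe))  # [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]
pipe = insert_after(pipe, cycled, Take(PassThrough([]), limit=4))
print(list(pipe))  # [[1, 2], [3, 4]]
```

Step 1 must run before step 2: once the placeholder is swapped out, `new` already ends at `on`, so rewiring the original consumer completes the chain source -> cycle -> take -> batch.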

Thanks for exploring torchdata.
IIUC, you want to insert a DataPipe between two DataPipes. Let's say you want to insert DataPipe C between A and B.
Original Graph:
A -> B
Expected New Graph:
A -> C -> B
Does DataPipe C accept a DataPipe as an argument? The PassThroughDataPipe is used as a placeholder to construct DataPipe C; I guess the intention is to avoid having to get DataPipe A first. However, since we need to find DataPipe A anyway for the insert_dp function, I don't see the benefit of having a PassThroughDataPipe. Please correct me if I am wrong.
Your proposed approach:
- First, get the instance of DataPipe A.
- Second, construct DataPipe C with PassThrough (graph: PassThrough -> C).
- Third, insert C after A (replace PassThrough with A, then replace A with C).
For option 2:
- First, get the instance of DataPipe A.
- Second, construct DataPipe C with A (graph: A -> C).
- Third, replace A with C in the original graph (A -> B ===> A -> C -> B).
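Option 2 needs no placeholder: C is built directly on top of A, and then the edge B -> A is rewired to B -> C. A minimal sketch with hypothetical toy names (`Node`, `chain`, `replace_source`), assuming each stage keeps its single upstream in `source_datapipe`; if A itself is the head, C simply becomes the new head.

```python
from typing import Any, List


class Node:
    """Hypothetical single-source stage, identified only by its name."""

    def __init__(self, name: str, source_datapipe: Any = None):
        self.name = name
        self.source_datapipe = source_datapipe


def chain(head: Any) -> List[str]:
    """Stage names from the head back to the root (B, C, A for A -> C -> B)."""
    out: List[str] = []
    while isinstance(head, Node):
        out.append(head.name)
        head = head.source_datapipe
    return out


def replace_source(head: Node, old: Node, new: Node) -> Node:
    """Rewire the stage that reads from `old` to read from `new`; return the head."""
    if head is old:  # replacing the head itself: `new` becomes the head
        return new
    node = head
    while isinstance(node, Node):
        if node.source_datapipe is old:
            node.source_datapipe = new
            return head
        node = node.source_datapipe
    raise LookupError(f"{old.name} is not in the chain")


a = Node("A")
b = Node("B", a)              # original graph: A -> B
c = Node("C", a)              # second step: construct C with A (A -> C)
b = replace_source(b, a, c)   # third step: replace A with C in B (A -> C -> B)
print(chain(b))  # ['B', 'C', 'A']
```

Because C already points at A when the swap happens, the walk stops at B and never needs to look inside C.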
@ejguan I think you understand what I'm trying to do.
I went ahead and implemented a practical example, which I'll attach below.
Does this DataPipe accept an argument of DataPipe? And, the PassThroughDataPipe
is used as a placeholder to construct DataPipe C. I guess the intention is not to get
DataPipe A. However, we need to find DataPipe A anyways for the insert_dp function,
I don't see the benefit of having a PassThroughDataPipe. Please correct me if I am wrong.
I think you're actually right, but then the new pipe needs to be built by a function (like `sub_graph` below) to avoid the need for a PassThrough placeholder.
I'm realizing that I can get the same behavior using only `replace_dp`:
pipe = A(range(10))
pipe = B(pipe)
test_eq(list(pipe),range(10))
traverse(pipe)
# {140090589776784: (B, {140090589776848: (A, {})})}
new_dp = replace_dp(
traverse(pipe,only_datapipe=True),
find_dp(traverse(pipe,only_datapipe=True),A),
C(find_dp(traverse(pipe,only_datapipe=True),A))
)
new_dp
# {140090589776784: (B, {140090589778448: (C, {140090589776848: (A, {})})})}
pipe = A(range(10))
pipe = B(pipe)
test_eq(list(pipe),range(10))
traverse(pipe)
# {140090589780944: (B, {140090589779664: (A, {})})}
def sub_graph(pipe):
pipe = C(pipe)
pipe = D(pipe)
pipe = E(pipe)
return pipe
new_dp = replace_dp(
traverse(pipe,only_datapipe=True),
find_dp(traverse(pipe,only_datapipe=True),A),
sub_graph(find_dp(traverse(pipe,only_datapipe=True),A)) # I can even insert sections of pipes
)
new_dp
# {140090589780944: (B,
# {140090589796432: (E,
# {140090589796368: (D,
# {140090589796112: (C, {140090589779664: (A, {})})})})})}
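The `sub_graph` trick works because the factory receives the existing A, so the stages it stacks already end at A; rewiring the single B -> A edge then splices in C -> D -> E. In the printed output above, B and A keep their original ids, consistent with only that edge changing. A minimal sketch with a hypothetical `Tag` stage that appends a letter to every string, so the final strings record the order in which stages ran; the direct `b.source_datapipe = ...` assignment stands in for `replace_dp` and only works here because we know which attribute holds B's upstream.

```python
from typing import Iterable, Iterator


class Tag:
    """Hypothetical stage that appends its own letter to every string it passes on."""

    def __init__(self, source_datapipe: Iterable[str], letter: str):
        self.source_datapipe = source_datapipe
        self.letter = letter

    def __iter__(self) -> Iterator[str]:
        for s in self.source_datapipe:
            yield s + self.letter


def sub_graph(pipe: Iterable[str]) -> Tag:
    """Mirrors the thread's `sub_graph`: stack C, D and E on top of `pipe`."""
    pipe = Tag(pipe, "c")
    pipe = Tag(pipe, "d")
    pipe = Tag(pipe, "e")
    return pipe


a = Tag(["x"], "a")  # A
b = Tag(a, "b")      # B reads from A
print(list(b))  # ['xab']

# Equivalent of replace_dp(graph, A, sub_graph(A)) for this single known edge
b.source_datapipe = sub_graph(a)
print(list(b))  # ['xacdeb']
```

The `'xacdeb'` string shows data flowing A -> C -> D -> E -> B, and `a` and `b` are the same objects before and after the swap.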
So maybe, instead of a new function in torchdata, this should be a documentation item. I think a lot of users will be interested in "insert" behaviors.
Sounds like a reasonable request. Do you want to open a PR to append it to the in-line doc here?
I will open another issue regarding our online doc.
@ejguan I'm happy to! I'll try to get a PR up later this week.