[FEA] Use `edge_ids` directly in uniform sampling call to prevent cost of edge_id lookup

Open · VibhuJawa opened this issue

Describe the solution you'd like and any additional context

Currently, 74% of the runtime of `sample_neighbors` in `gnn/graph_store.py` is spent looking up `edge_ids`. If the uniform sampling call returned the edge IDs directly, we could skip this lookup entirely and make the function substantially faster.
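A minimal sketch of the idea, in plain Python rather than cuGraph: if the sampler keeps an edge ID in an array parallel to the CSR neighbor indices, it can emit each sampled edge's ID at the moment it picks that edge, so no separate lookup is needed afterwards. `CSRGraph` and `uniform_sample_with_edge_ids` are hypothetical names for illustration only, not cuGraph APIs.

```python
import random
from typing import List, Optional, Tuple


class CSRGraph:
    """Hypothetical CSR graph that stores an edge ID next to every neighbor entry."""

    def __init__(self, edges: List[Tuple[int, int, int]], num_vertices: int):
        # Group (dst, edge_id) pairs by source vertex.
        buckets: List[List[Tuple[int, int]]] = [[] for _ in range(num_vertices)]
        for src, dst, eid in edges:
            buckets[src].append((dst, eid))
        # offsets[v]..offsets[v + 1] is vertex v's slice of the two parallel arrays.
        self.offsets: List[int] = [0]
        self.indices: List[int] = []   # neighbor vertex IDs
        self.edge_ids: List[int] = []  # edge ID stored at the same position
        for bucket in buckets:
            for dst, eid in bucket:
                self.indices.append(dst)
                self.edge_ids.append(eid)
            self.offsets.append(len(self.indices))


def uniform_sample_with_edge_ids(
    g: CSRGraph,
    seeds: List[int],
    fanout: int,
    replace: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """One hop of uniform neighbor sampling returning (sources, destinations, edge_ids).

    The sampler chooses positions in the CSR arrays, so the edge ID at each chosen
    position is read directly; no join against an edge table is required.
    """
    if rng is None:
        rng = random.Random()
    srcs: List[int] = []
    dsts: List[int] = []
    eids: List[int] = []
    for s in seeds:
        positions = list(range(g.offsets[s], g.offsets[s + 1]))
        if fanout == -1:
            chosen = positions  # take every neighbor
        elif replace:
            chosen = [rng.choice(positions) for _ in range(fanout)] if positions else []
        else:
            chosen = rng.sample(positions, min(fanout, len(positions)))
        for p in chosen:
            srcs.append(s)
            dsts.append(g.indices[p])
            eids.append(g.edge_ids[p])
    return srcs, dsts, eids


edges = [(0, 1, 10), (0, 2, 11), (1, 2, 12), (2, 0, 13)]  # (src, dst, edge_id)
g = CSRGraph(edges, num_vertices=3)
print(uniform_sample_with_edge_ids(g, seeds=[0, 2], fanout=-1))
# ([0, 0, 2], [1, 2, 0], [10, 11, 13])
```

The extra cost of this design is one additional array of edge IDs stored alongside the graph structure; in exchange, the per-sample work stays proportional to the number of sampled edges.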

Additional context: fixing this is critical for reaching acceptable performance when upstreaming the DGL work.

Benchmarks

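The line-profiler output below shows that the `merge` used to look up edge IDs (line 259) accounts for 74% of the function's runtime:

```
   259         1      73755.0  73755.0     74.0          sampled_df = edge_df.merge(sampled_df)
```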
Full profile of `sample_neighbors`:

```
Line #      Hits         Time  Per Hit   % Time  Line Contents
   181                                               def sample_neighbors(
   182                                                   self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False
   183                                               ):
   184                                                   """
   185                                                   Sample neighboring edges of the given nodes and return the subgraph.
   186
   187                                                   Parameters
   188                                                   ----------
   189                                                   nodes : DLPack capsule of node IDs (single dimension)
   190                                                       Node IDs to sample neighbors from.
   191                                                   fanout : int
   192                                                       The number of edges to be sampled for each node on each edge type.
   193                                                       If -1 is given, all the neighboring edges for each node on
   194                                                       each edge type will be selected.
   195                                                   edge_dir : str {"in" or "out"}
   196                                                       Determines whether to sample inbound or outbound edges.
   197                                                       Can take either "in" for inbound edges or "out" for outbound edges.
   198                                                   prob : str
   199                                                       Feature name used as the (unnormalized) probabilities associated
   200                                                       with each neighboring edge of a node. Each feature must be a
   201                                                       scalar. The features must be non-negative floats, and the sum of
   202                                                       the features of inbound/outbound edges for every node must be
   203                                                       positive (though they don't have to sum up to one). Otherwise,
   204                                                       the result will be undefined. If not specified, sample uniformly.
   205                                                   replace : bool
   206                                                       If True, sample with replacement.
   207
   208                                                   Returns
   209                                                   -------
   210                                                   DLPack capsule
   211                                                       The src nodes for the sampled bipartite graph.
   212                                                   DLPack capsule
   213                                                       The dst nodes for the sampled bipartite graph.
   214                                                   DLPack capsule
   215                                                       The corresponding eids for the sampled bipartite graph.
   216                                                   """
   217
   218         1          2.0      2.0      0.0          if edge_dir not in ["in", "out"]:
   219                                                       raise ValueError(
   220                                                           f"edge_dir must be either 'in' or 'out', got {edge_dir} instead"
   221                                                       )
   222
   223         1          1.0      1.0      0.0          if edge_dir == "in":
   224         1          2.0      2.0      0.0              sg = self.extracted_reverse_subgraph_without_renumbering
   225                                                   else:
   226                                                       sg = self.extracted_subgraph_without_renumbering
   227
   228         1          1.0      1.0      0.0          if not hasattr(self, '_sg_node_dtype'):
   229                                                       self._sg_node_dtype = sg.edgelist.edgelist_df['src'].dtype
   230
   231                                                   # Uniform sampling fails if the seed dtype
   232                                                   # is not the same as the node dtype
   233         1        413.0    413.0      0.4          nodes = cudf.from_dlpack(nodes).astype(self._sg_node_dtype)
   234
   235         2      21310.0  10655.0     21.4          sampled_df = uniform_neighbor_sample(
   236         1          1.0      1.0      0.0              sg, start_list=nodes, fanout_vals=[fanout],
   237         1          0.0      0.0      0.0              with_replacement=replace
   238                                                   )
   239
   240         1        379.0    379.0      0.4          sampled_df.drop(columns=["indices"], inplace=True)
   241
   242                                                   # handle empty graph case
   243         1         13.0     13.0      0.0          if len(sampled_df) == 0:
   244                                                       return None, None, None
   245
   246                                                   # reverse directions when edge_dir == "in"
   247         1          1.0      1.0      0.0          if edge_dir == "in":
   248         2        177.0     88.5      0.2              sampled_df.rename(
   249         1          1.0      1.0      0.0                  columns={"destinations": src_n, "sources": dst_n}, inplace=True
   250                                                       )
   251                                                   else:
   252                                                       sampled_df.rename(
   253                                                           columns={"sources": src_n, "destinations": dst_n}, inplace=True
   254                                                       )
   255
   256                                                   # FIXME: Remove once the issue below is resolved
   257                                                   # https://github.com/rapidsai/cugraph/issues/2444
   258         1       1226.0   1226.0      1.2          edge_df = self.gdata._edge_prop_dataframe[[src_n, dst_n, eid_n]]
   259         1      73755.0  73755.0     74.0          sampled_df = edge_df.merge(sampled_df)
   260
   261         1          2.0      2.0      0.0          return (
   262         1        929.0    929.0      0.9              sampled_df[src_n].to_dlpack(),
   263         1        714.0    714.0      0.7              sampled_df[dst_n].to_dlpack(),
   264         1        688.0    688.0      0.7              sampled_df[eid_n].to_dlpack(),
   265                                                   )
```
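For comparison, the lookup on profiler line 259 amounts to an inner join of the sampled `(src, dst)` pairs against the whole edge-property table. The plain-Python sketch below is not the cuDF code path, and `lookup_edge_ids_by_join` is a hypothetical helper; it only illustrates why this step scales with the total number of edges in the graph rather than with the number of sampled edges.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int, int]  # (src, dst, edge_id)


def lookup_edge_ids_by_join(
    edge_table: List[Edge], sampled_pairs: List[Tuple[int, int]]
) -> List[Edge]:
    """Recover edge IDs for sampled (src, dst) pairs with an inner hash join.

    The hash table is built over every row of the edge table, so this step costs
    O(E + S) for E total edges and S sampled pairs, even when S is tiny.
    """
    index: Dict[Tuple[int, int], List[int]] = {}
    for src, dst, eid in edge_table:  # full pass over all E edges
        index.setdefault((src, dst), []).append(eid)
    joined: List[Edge] = []
    for src, dst in sampled_pairs:  # one probe per sampled pair
        for eid in index.get((src, dst), []):
            joined.append((src, dst, eid))
    return joined


edge_table = [(0, 1, 10), (0, 2, 11), (1, 2, 12), (2, 0, 13)]
print(lookup_edge_ids_by_join(edge_table, [(0, 2), (2, 0)]))
# [(0, 2, 11), (2, 0, 13)]
```

If the sampler already returned the ID of each edge it picked, this full pass over the edge table would disappear. Note also that the inner join in this sketch returns every parallel edge between the same `(src, dst)` pair, whereas an ID carried through sampling identifies exactly the edge that was sampled.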

VibhuJawa, Aug 09 '22 03:08