code-transformer
TypeError: 'float' object cannot be interpreted as an integer on modern PyTorch
When running on PyTorch '1.10.2+cu102', I get the following error:
Traceback (most recent call last):
  File "/code_transformer/preprocessing/graph/transform.py", line 30, in __call__
    distance_matrix = distance_metric(adj)
  File "/code_transformer/preprocessing/graph/distances.py", line 76, in __call__
    sp_length = all_pairs_shortest_paths(G=G)
  File "/code_transformer/preprocessing/graph/alg.py", line 45, in all_pairs_shortest_paths
    values = torch.tensor([(dct[0], key, value) for dct in sps for key, value in dct[1].items()],
TypeError: 'float' object cannot be interpreted as an integer
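The following diff seems to solve the issue. Apparently this PyTorch version no longer accepts float values when building a tensor with dtype=torch.long, so the diff casts the offending values explicitly with int() :)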
Index: code_transformer/preprocessing/graph/alg.py
===================================================================
diff --git a/code_transformer/preprocessing/graph/alg.py b/code_transformer/preprocessing/graph/alg.py
--- a/code_transformer/preprocessing/graph/alg.py (revision 362ec5300e94c6901566b38e10cb3b93440e7c52)
+++ b/code_transformer/preprocessing/graph/alg.py (date 1663341487624)
@@ -42,7 +42,7 @@
create_using = nx.Graph
G = nx.from_edgelist(edges, create_using=create_using)
sps = nx.all_pairs_dijkstra_path_length(G, cutoff=cutoff)
- values = torch.tensor([(dct[0], key, value) for dct in sps for key, value in dct[1].items()],
+ values = torch.tensor([(dct[0], key, int(value)) for dct in sps for key, value in dct[1].items()],
dtype=torch.long)
return values
@@ -95,7 +95,7 @@
sibling_edges = next_sibling_edges(tree_edges).numpy()
G_siblings = nx.from_edgelist(sibling_edges, create_using=nx.DiGraph)
sps = list(nx.all_pairs_dijkstra_path_length(G_siblings))
- sibling_sp_edgelist = torch.tensor([(from_node, to_node, dist)
+ sibling_sp_edgelist = torch.tensor([(int(from_node), int(to_node), dist)
for from_node, dct in sps
for to_node, dist in dct.items()],
dtype=torch.long)
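Note that the two hunks cast different fields: the first casts only the distance, the second only the node ids. Below is a minimal sketch of a helper that casts all three fields in one place. The name `to_long_triples` and the `strict` flag are hypothetical (not part of code_transformer), and it assumes `sps` has the `(source, {target: distance})` shape that `nx.all_pairs_dijkstra_path_length` yields. It is plain Python, so it does not depend on the PyTorch version.

```python
from typing import Dict, Hashable, Iterable, List, Tuple


def _as_int(x, strict: bool) -> int:
    # Plain int() truncates, so 1.5 would silently become 1.
    # With strict=True, reject non-integral float values instead.
    if strict and isinstance(x, float) and not x.is_integer():
        raise ValueError(f"non-integral value {x!r} cannot be stored as torch.long")
    return int(x)


def to_long_triples(
    sps: Iterable[Tuple[Hashable, Dict[Hashable, float]]],
    strict: bool = True,
) -> List[Tuple[int, int, int]]:
    """Flatten (source, {target: distance}) pairs into int (source, target, distance) triples.

    Hypothetical helper: every field is cast with int(), so the result can be
    passed to torch.tensor(..., dtype=torch.long) without float values.
    """
    triples = []
    for source, lengths in sps:
        for target, dist in lengths.items():
            triples.append((_as_int(source, strict), _as_int(target, strict), _as_int(dist, strict)))
    return triples
```

Usage, under the same assumptions, would be something like `values = torch.tensor(to_long_triples(nx.all_pairs_dijkstra_path_length(G, cutoff=cutoff)), dtype=torch.long)`. The `strict` check is a design choice: shortest-path lengths on an unweighted graph should always be whole numbers, so a fractional value more likely points to a bug than to something that should be truncated.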