py-tree-sitter
py-tree-sitter copied to clipboard
Incorret Parsing of `body` field of `class_defintion`
I want to build a call graph using tree-sitter. While doing so, I encountered an issue that building a tree out of the body part of a class_definition doesn't work. I added comments to the code, so I will be short in this prosa.
Any preliminary indication whether is wrong usage of py-tree-sitter or a bug in the implementation is highly appreciated since the task is work related and no spare time project.
The source code:
class A:
def sub(self): ...
def add(self): ...
class B:
def div(self): ...
def mul(self): ...
def f(): ...
def g(): ...
print()
list()
A()
My tree-sitter code:
nodetype.py:
from enum import Enum
class NodeType(Enum):
CALL = "call"
CLASS_DEF = "class_definition"
FUNCTION_DEF = "function_definition"
IF = "if_statement"
MODULE = "module"
and main.py:
import itertools
import matplotlib.pyplot as plt
import networkx as nx
from pathlib import Path
from tree_sitter import Language, Parser, Tree, Node
import tree_sitter_python
from nodetype import NodeType
PYTHON_LANGUAGE = Language(tree_sitter_python.language())
PARSER = Parser(PYTHON_LANGUAGE)
source_file = Path("./tscg/cursor_source_code.py")
CODE = source_file.read_text()
TREE = PARSER.parse(CODE.encode())
G = nx.DiGraph()
def name(self) -> str:
"""Get the name of the current node.
Only possible if there is an `identifier` Node within it's children.
"""
cap_name = "idf"
snippet = CODE
tree_snippet = PARSER.parse(snippet.encode(encoding="utf8"))
# Class Module overwrites this method, hence no case here
match self.type:
case NodeType.FUNCTION_DEF.value | NodeType.CLASS_DEF.value:
q = f"({self.type} name: (identifier) @{cap_name})"
query = PYTHON_LANGUAGE.query(q).set_byte_range((self.start_byte, self.end_byte))
captures = query.captures(tree_snippet.root_node)
assert len(captures[cap_name]) == 1, f"More than one identifier node for Node {self.tsn} was captured."
try:
capture = captures[cap_name][0]
return snippet[capture.start_byte : capture.end_byte]
except KeyError:
msg = f"No `identifier` Node available in {self.tsn}, hence no Name."
raise Exception(msg)
case NodeType.CALL.value:
q = (
# Ordinary calls: f()
f"""({self.type}
function: (identifier) @{cap_name}
)"""
f"""({self.type}
function: (attribute
object: (identifier) @obj.{cap_name}
attribute: (identifier) @{cap_name}
)
)
"""
)
query = PYTHON_LANGUAGE.query(q).set_byte_range((self.byte_range))
matches = query.matches(tree_snippet.root_node)
assert len(matches) == 1, f"{len(matches) = } != 1"
# Ordninary function call
if len(matches[0][1]) == 1:
match = matches[0][1]
return snippet[match[f"{cap_name}"][0].start_byte : match[f"{cap_name}"][0].end_byte] + "()"
# Instance method call (len(matches[0][1]) == 1)
else:
match = matches[0][1]
return snippet[match[f"obj.{cap_name}"][0].start_byte : match[f"{cap_name}"][0].end_byte] + "()"
case NodeType.MODULE.value:
return "MOD"
case "block":
return f"Blk({self.parent})"
case _:
msg = f"Processing Node of type {self.type}. Only {NodeType.FUNCTION_DEF.value}, {NodeType.CLASS_DEF.value} and {NodeType.CALL.value} possible."
raise Exception(msg)
# Makes debugging easier and networkX relies on __str__() for node labels
# I know, this is ...discouraged.
Node.__str__ = name
Node.__repr__ = name
Node.name = property(name) # type: ignore[reportAttributeAccessIssue]
def traverse_tree(tree: Tree, graph_parent: Node):
"""Traverse `tree` to collect certain nodes (currently function_definition, class_defintion, call).
Parameters
----------
tree: Tree
The `Tree` to traverse. Can be whole module/file or `body`/`block` of a `function_defintion`, `class_definiton`.
graph_parent: Node
The parent `Node` relevant for drawing an edge to.
"""
for c in tree.root_node.children:
if c.type in ("function_definition", "class_definition"):
G.add_edge(
graph_parent,
c,
ts_type=NodeType.FUNCTION_DEF if c.type == "function_definition" else NodeType.CLASS_DEF,
)
# Retrieve body field/block node of function/class_definition
cap = f"{c.type}.body"
class_body_q = f"""({c.type}
body: (block) @{cap})"""
# - set_max_start_depth(1): Ignore class_definition within current class_defintion
# - set_byte_range(c.byte_range): Capture the body of the current class only (because code file has two)
q = PYTHON_LANGUAGE.query(class_body_q).set_max_start_depth(1).set_byte_range(c.byte_range)
# Captures sequentially block of class A and B correctly
body_captures = q.captures(tree.root_node)
for body in body_captures[cap]:
# Build subtree of body (to extract function_defintion in class)
# Slice of code is correct (verified in)
code_body = CODE.encode()[slice(*body.byte_range)]
body_tree = PARSER.parse(code_body)
# prints
# b'def sub(self): ...\n def add(self): ...'
# [sub, add]
# b'def div(self): ...\n def mul(self): ...'
# [sub, add] <-- wrong
# instead of
# b'def sub(self): ...\n def add(self): ...'
# [sub, add]
# b'def div(self): ...\n def mul(self): ...'
# [div, mul] <-- right/expected
if body.name in ("Blk(A)", "Blk(B)"): # type: ignore[reportAttributeAccessIssue]
print(code_body)
print(body_tree.root_node.children)
traverse_tree(body_tree, c)
# Check for call nodes in unknown parent type (expression_statements, assignment)
else:
cap = "calls"
call_q = f"(call) @{cap}"
q = PYTHON_LANGUAGE.query(call_q)
call_captures = q.captures(tree.root_node)
if cap in call_captures and len(calls := call_captures[cap]) > 0:
G.add_edges_from(itertools.pairwise([graph_parent] + sorted(calls, key=lambda n: n.start_byte)), ts_type=NodeType.CALL)
traverse_tree(TREE, TREE.root_node)
# Draw Graph
plt.figure(figsize=(12, 10))
pos = nx.shell_layout(G)
nx.draw_networkx(
G,
pos=pos,
with_labels=True,
)
plt.savefig(".res/graph.png")
plt.show()
The resulted graph:
I expect the two children of B to be div and mul (not sub and add).
Edit
To easen reproducibility, the output of pip freeze:
asttokens==3.0.0
bigtree==0.28.0
contourpy==1.3.2
cycler==0.12.1
decorator==5.2.1
exceptiongroup==1.2.2
executing==2.2.0
fonttools==4.57.0
iniconfig==2.1.0
ipython==8.35.0
jedi==0.19.2
Jinja2==3.1.6
jsonpickle==4.0.5
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
networkx==3.4.2
numpy==2.2.5
packaging==25.0
parso==0.8.4
pexpect==4.9.0
pillow==11.2.1
pluggy==1.5.0
prompt_toolkit==3.0.51
ptyprocess==0.7.0
pure_eval==0.2.3
pydot==3.0.4
Pygments==2.19.1
pyparsing==3.2.3
pytest==8.3.5
python-dateutil==2.9.0.post0
pyvis==0.3.2
six==1.17.0
stack-data==0.6.3
tomli==2.2.1
traitlets==5.14.3
tree-sitter==0.24.0
tree-sitter-python==0.23.6
typing_extensions==4.13.2
wcwidth==0.2.13
Have you tried tree-sitter-graph (in Rust)?