py-tree-sitter icon indicating copy to clipboard operation
py-tree-sitter copied to clipboard

Incorret Parsing of `body` field of `class_defintion`

Open PhilippFeO opened this issue 7 months ago • 1 comments

I want to build a call graph using tree-sitter. While doing so, I encountered an issue that building a tree out of the body part of a class_definition doesn't work. I added comments to the code, so I will be short in this prosa.

Any preliminary indication whether is wrong usage of py-tree-sitter or a bug in the implementation is highly appreciated since the task is work related and no spare time project.


The source code:

class A:
    def sub(self): ...
    def add(self): ...


class B:
    def div(self): ...
    def mul(self): ...


def f(): ...


def g(): ...


print()
list()
A()

My tree-sitter code: nodetype.py:

from enum import Enum


class NodeType(Enum):
    CALL = "call"
    CLASS_DEF = "class_definition"
    FUNCTION_DEF = "function_definition"
    IF = "if_statement"
    MODULE = "module"

and main.py:

import itertools
import matplotlib.pyplot as plt
import networkx as nx
from pathlib import Path
from tree_sitter import Language, Parser, Tree, Node
import tree_sitter_python
from nodetype import NodeType

PYTHON_LANGUAGE = Language(tree_sitter_python.language())

PARSER = Parser(PYTHON_LANGUAGE)

source_file = Path("./tscg/cursor_source_code.py")
CODE = source_file.read_text()
TREE = PARSER.parse(CODE.encode())

G = nx.DiGraph()


def name(self) -> str:
    """Get the name of the current node.

    Only possible if there is an `identifier` Node within it's children.
    """

    cap_name = "idf"
    snippet = CODE
    tree_snippet = PARSER.parse(snippet.encode(encoding="utf8"))
    # Class Module overwrites this method, hence no case here
    match self.type:
        case NodeType.FUNCTION_DEF.value | NodeType.CLASS_DEF.value:
            q = f"({self.type} name: (identifier) @{cap_name})"
            query = PYTHON_LANGUAGE.query(q).set_byte_range((self.start_byte, self.end_byte))
            captures = query.captures(tree_snippet.root_node)
            assert len(captures[cap_name]) == 1, f"More than one identifier node for Node {self.tsn} was captured."
            try:
                capture = captures[cap_name][0]
                return snippet[capture.start_byte : capture.end_byte]
            except KeyError:
                msg = f"No `identifier` Node available in {self.tsn}, hence no Name."
                raise Exception(msg)

        case NodeType.CALL.value:
            q = (
                # Ordinary calls: f()
                f"""({self.type}
                     function: (identifier) @{cap_name}
                    )"""
                f"""({self.type}
                        function: (attribute
                            object: (identifier) @obj.{cap_name}
                            attribute: (identifier) @{cap_name}
                        )
                    )
                """
            )
            query = PYTHON_LANGUAGE.query(q).set_byte_range((self.byte_range))
            matches = query.matches(tree_snippet.root_node)
            assert len(matches) == 1, f"{len(matches) = } != 1"
            # Ordninary function call
            if len(matches[0][1]) == 1:
                match = matches[0][1]
                return snippet[match[f"{cap_name}"][0].start_byte : match[f"{cap_name}"][0].end_byte] + "()"
            # Instance method call (len(matches[0][1]) == 1)
            else:
                match = matches[0][1]
                return snippet[match[f"obj.{cap_name}"][0].start_byte : match[f"{cap_name}"][0].end_byte] + "()"
        case NodeType.MODULE.value:
            return "MOD"
        case "block":
            return f"Blk({self.parent})"
        case _:
            msg = f"Processing Node of type {self.type}. Only {NodeType.FUNCTION_DEF.value}, {NodeType.CLASS_DEF.value} and {NodeType.CALL.value} possible."
            raise Exception(msg)


# Makes debugging easier and networkX relies on __str__() for node labels
# I know, this is ...discouraged.
Node.__str__ = name
Node.__repr__ = name
Node.name = property(name)  # type: ignore[reportAttributeAccessIssue]


def traverse_tree(tree: Tree, graph_parent: Node):
    """Traverse `tree` to collect certain nodes (currently function_definition, class_defintion, call).

    Parameters
    ----------
    tree: Tree
        The `Tree` to traverse. Can be whole module/file or `body`/`block` of a `function_defintion`, `class_definiton`.
    graph_parent: Node
        The parent `Node` relevant for drawing an edge to.
    """
    for c in tree.root_node.children:
        if c.type in ("function_definition", "class_definition"):
            G.add_edge(
                graph_parent,
                c,
                ts_type=NodeType.FUNCTION_DEF if c.type == "function_definition" else NodeType.CLASS_DEF,
            )
            # Retrieve body field/block node of function/class_definition
            cap = f"{c.type}.body"
            class_body_q = f"""({c.type}
                                body: (block) @{cap})"""
            # - set_max_start_depth(1): Ignore class_definition within current class_defintion
            # - set_byte_range(c.byte_range): Capture the body of the current class only (because code file has two)
            q = PYTHON_LANGUAGE.query(class_body_q).set_max_start_depth(1).set_byte_range(c.byte_range)
            # Captures sequentially block of class A and B correctly
            body_captures = q.captures(tree.root_node)
            for body in body_captures[cap]:
                # Build subtree of body (to extract function_defintion in class)
                # Slice of code is correct (verified in)
                code_body = CODE.encode()[slice(*body.byte_range)]
                body_tree = PARSER.parse(code_body)
                # prints
                #   b'def sub(self): ...\n    def add(self): ...'
                #   [sub, add]
                #   b'def div(self): ...\n    def mul(self): ...'
                #   [sub, add]  <-- wrong
                # instead of
                #   b'def sub(self): ...\n    def add(self): ...'
                #   [sub, add]
                #   b'def div(self): ...\n    def mul(self): ...'
                #   [div, mul]  <-- right/expected
                if body.name in ("Blk(A)", "Blk(B)"):  # type: ignore[reportAttributeAccessIssue]
                    print(code_body)
                    print(body_tree.root_node.children)
                traverse_tree(body_tree, c)
        # Check for call nodes in unknown parent type (expression_statements, assignment)
        else:
            cap = "calls"
            call_q = f"(call) @{cap}"
            q = PYTHON_LANGUAGE.query(call_q)
            call_captures = q.captures(tree.root_node)
            if cap in call_captures and len(calls := call_captures[cap]) > 0:
                G.add_edges_from(itertools.pairwise([graph_parent] + sorted(calls, key=lambda n: n.start_byte)), ts_type=NodeType.CALL)


traverse_tree(TREE, TREE.root_node)

# Draw Graph
plt.figure(figsize=(12, 10))
pos = nx.shell_layout(G)
nx.draw_networkx(
    G,
    pos=pos,
    with_labels=True,
)
plt.savefig(".res/graph.png")
plt.show()

The resulted graph: Image

I expect the two children of B to be div and mul (not sub and add).


Edit

To easen reproducibility, the output of pip freeze:

asttokens==3.0.0
bigtree==0.28.0
contourpy==1.3.2
cycler==0.12.1
decorator==5.2.1
exceptiongroup==1.2.2
executing==2.2.0
fonttools==4.57.0
iniconfig==2.1.0
ipython==8.35.0
jedi==0.19.2
Jinja2==3.1.6
jsonpickle==4.0.5
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
networkx==3.4.2
numpy==2.2.5
packaging==25.0
parso==0.8.4
pexpect==4.9.0
pillow==11.2.1
pluggy==1.5.0
prompt_toolkit==3.0.51
ptyprocess==0.7.0
pure_eval==0.2.3
pydot==3.0.4
Pygments==2.19.1
pyparsing==3.2.3
pytest==8.3.5
python-dateutil==2.9.0.post0
pyvis==0.3.2
six==1.17.0
stack-data==0.6.3
tomli==2.2.1
traitlets==5.14.3
tree-sitter==0.24.0
tree-sitter-python==0.23.6
typing_extensions==4.13.2
wcwidth==0.2.13

PhilippFeO avatar May 09 '25 10:05 PhilippFeO

Have you tried tree-sitter-graph (in Rust)?

ObserverOfTime avatar Jun 03 '25 07:06 ObserverOfTime