mo-sql-parsing
Link back information about position in the original string
To analyse SQL in the editor and provide some syntax checking, it would be great for each branch or leaf in the tree to know where it starts and ends in the original document. Is there a quick way to return e.g. TreeDict(dict) and TreeList(list) classes that would carry start and end locations?
I can try to do it myself if you can point me in the right direction. It seems like the scrub() function should do it, but it's not very obvious.
That sounds like an interesting feature! Looking at the scrub method, it does appear to have all the information you need to get this done.
I am not sure how you would like to include the start/end properties. Maybe you can make scrub emit objects instead of dicts/lists; then properties work as usual, but you have attributes to store start/end:
>>> class A(dict):
...     pass
...
>>> a = A()
>>> setattr(a, "start", 42)
>>> a['start']=43
>>> a['start']
43
>>> a.start
42
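Concretely, the TreeDict and TreeList classes from the question could stay that small; a rough sketch (these names and the wrap helper are illustrative, not an existing API):

class TreeDict(dict):
    start = None
    end = None

class TreeList(list):
    start = None
    end = None

def wrap_with_position(node, start, end):
    # Wrap plain containers so positions ride along as attributes
    # without changing the dict/list content.
    if isinstance(node, dict):
        node = TreeDict(node)
    elif isinstance(node, list):
        node = TreeList(node)
    if isinstance(node, (TreeDict, TreeList)):
        node.start, node.end = start, end
    return node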
Whichever you choose, you can patch the scrub method without making a PR:
>>> def my_new_scrub_method(result):
...     return result
...
>>> from mo_sql_parsing import utils
>>> utils.scrub = my_new_scrub_method
If you do make a PR, be sure to make it an option:
- a parameter in the parser call, e.g. parse(sql, add_positions=True), or
- a method you run to do the override above, e.g. add_positions(); parse(sql), or
- a context manager, e.g. with add_positions(): parse(sql) (see the sketch below)
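For illustration, the context-manager variant could be little more than a temporary override of utils.scrub; a minimal sketch, assuming scrub takes the parse result as its only argument (add_positions and its replacement parameter are hypothetical, not part of the library):

from contextlib import contextmanager
from mo_sql_parsing import utils

@contextmanager
def add_positions(replacement=lambda result: result):
    # Temporarily swap in a scrub replacement that preserves position info,
    # then restore the original on exit.
    original = utils.scrub
    utils.scrub = replacement
    try:
        yield
    finally:
        utils.scrub = original

It would be used as: with add_positions(): result = parse(sql).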
I will leave it with you for now.
Thanks! I don't have time for a proper PR yet, but here is the patch I made in case you or anyone else might need it. There are scenarios where there is no correct start and end (some returns from Call, identifiers, etc.); I try to fix some of those cases with _fix().
from __future__ import annotations

import mo_sql_parsing
import mo_parsing.utils
from mo_future import text, number_types, binary_type
from mo_parsing import *
from mo_sql_parsing.utils import scrub_op, SQL_NULL, Call
from collections.abc import Generator


class SqlTree:
    # Mixin that records where a node starts and ends in the original SQL string
    start: int = None
    end: int = None

    @staticmethod
    def get_meta_keys() -> tuple[str, str]:
        return ("start", "end")

    def set_meta(self, **kwargs) -> SqlTree:
        for k in kwargs:
            assert k in SqlTree.get_meta_keys()
            setattr(self, k, kwargs[k])
        return self

    def get_meta(self) -> dict[str, object]:
        return {k: getattr(self, k, None) for k in SqlTree.get_meta_keys()}

    def copy_meta(self, other: SqlTree) -> SqlTree:
        return self.set_meta(**other.get_meta())


class SqlList(list, SqlTree):
    def items(self) -> Generator[tuple[int, object]]:
        return enumerate(self)


class SqlDict(dict, SqlTree):
    pass


class SqlValue(SqlTree):
    value: str | int = None

    def __init__(self, value: str | int):
        self.value = value

    def __str__(self) -> str:
        return self.value.__str__()

    def __repr__(self) -> str:
        return self.value.__repr__()

    def items(self) -> Generator[tuple[int, str | int]]:
        return enumerate([self.value])


flat_keys = "name", "value", "all_columns"
transpile_map = {list: SqlList, dict: SqlDict, int: SqlValue, str: SqlValue}

# The result of the original `scrub` function is appended to a list, so apply it after `_parse`
mo_sql_parsing.utils.scrub = lambda x: x


def parse(sql: str) -> SqlTree:
    parsed = mo_sql_parsing.parse(sql)
    tree = _transpile(parsed)
    tree = _squash(tree)
    _fix(tree)
    _check(tree, sql)
    return tree


def _check(tree: SqlTree, sql: str) -> None:
    # Debug helper: print each node beside the source slice its positions point at
    meta = tree.get_meta()
    if all(v is not None for v in meta.values()):
        print(str(tree)[0:20], " === ", sql[meta["start"] : meta["end"]])
    if isinstance(tree, (SqlDict, SqlList)):
        for k, v in tree.items():
            _check(v, sql)


def _fix(tree: SqlTree) -> None:
    # TODO: fix parser to always have start and end without these hacks
    if isinstance(tree, SqlList):
        start = (r.start for r in tree if r.start is not None)
        end = (r.end for r in tree if r.end is not None)
        tree.set_meta(start=min(start, default=None), end=max(end, default=None))
    elif isinstance(tree, SqlDict) and len(tree) == 1:
        child = list(tree.keys()).pop()
        if not all(v is not None for v in tree[child].get_meta().values()):
            tree[child].copy_meta(tree)
    if isinstance(tree, SqlTree):
        for k, v in tree.items():
            _fix(v)


def _squash(dirty: SqlTree, parent=None) -> SqlTree:
    global flat_keys
    # Recursively clean up the tree
    if isinstance(dirty, SqlList):
        clean = [_squash(r, dirty) for r in dirty]
        clean = [r for r in clean if r is not None]
        if len(clean) > 1:
            clean = SqlList(clean)
        elif len(clean) == 1 and clean[0] is not None:
            clean = clean[0]
        else:
            clean = None
    elif isinstance(dirty, SqlDict):
        clean = {k: _squash(v, dirty) for k, v in dirty.items()}
        clean = {
            k: v if isinstance(v, list) or k in flat_keys else SqlList([v]).copy_meta(v)
            for k, v in clean.items()
            if v is not None
        }
        clean = SqlDict(clean)
    elif dirty is None or isinstance(dirty, SqlValue):
        clean = dirty
    else:
        raise NotImplementedError(f"Not implemented for {dirty.__class__}")
    # Preserve meta attributes
    if clean is None:
        return clean
    elif all(v is not None for v in clean.get_meta().values()):
        return clean
    elif all(v is not None for v in dirty.get_meta().values()):
        return clean.copy_meta(dirty)
    elif parent is not None and all(v is not None for v in parent.get_meta().values()):
        # parent is None for the top-level call from parse()
        return clean.copy_meta(parent)
    else:
        return clean


def _transpile(dirty: object) -> SqlTree:
    # Convert raw mo_parsing results into Sql* wrappers, capturing start/end where present
    global transpile_map
    loc_attrs, loc = ("start", "end"), {}
    clean = None
    # Parse depending on type
    if dirty is SQL_NULL or dirty is None:
        clean = None
    elif isinstance(dirty, (text, number_types)):
        # TODO: Simple tokens do not have `start` and `end`
        clean = dirty
    elif isinstance(dirty, binary_type):
        clean = dirty.decode("utf8")
    elif isinstance(dirty, list):
        clean = [_transpile(r) for r in dirty]
    elif isinstance(dirty, dict):
        clean = {k: _transpile(v) for k, v in dirty.items()}
    elif isinstance(dirty, Call):
        kwargs = _transpile(dirty.kwargs)
        args = _transpile(dirty.args)
        clean = scrub_op(dirty.op, args, kwargs)
        # TODO: Call object has no `start` and `end`
    elif isinstance(dirty, mo_parsing.results.ForwardResults):
        loc = {a: getattr(dirty, a, None) for a in loc_attrs}
        clean = _transpile(dirty.tokens)
    elif isinstance(dirty, mo_parsing.results.ParseResults):
        loc = {a: getattr(dirty, a, None) for a in loc_attrs}
        tokens = dict(dirty.items()) or dirty.tokens
        clean = _transpile(tokens)
        # TODO: "*" is {all_columns: {}}, while "tbl.*" is {all_columns: "tbl"};
        #       for consistency {all_columns: ""} would be better
        # TODO: ParseResults often has `start=-1` and `end=0`
    else:
        raise NotImplementedError(f"Not implemented for {dirty.__class__}")
    # Transpile to Sql classes
    if clean.__class__ in transpile_map:
        clean = transpile_map[clean.__class__](clean)
    # Update meta attributes if captured
    if loc and all(loc[v] is not None and loc[v] >= 0 for v in loc_attrs):
        clean.set_meta(**loc)
    return clean
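A minimal usage sketch, assuming the patch above has been run in the same session (the query is arbitrary and the exact positions depend on the parser version):

tree = parse("SELECT name FROM people WHERE age > 30")
print(tree.get_meta())            # start/end of the whole statement
for key, node in tree.items():    # start/end of each top-level clause
    print(key, node.get_meta())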
Thank you. I made a branch: https://github.com/klahnakoski/mo-sql-parsing/tree/add-start-end
It will need tests.