Support LSP's textDocument/definition
Hi, I'm glad I found a language server for SMT-LIB.
Currently, the most important feature for me is go-to-definition
("textDocument/definition"), so I can quickly navigate my formulas.
This includes mainly names declared by declare-* or define-* statements,
and maybe also let bindings.
I've hacked together my own language server in Python that works pretty well for me. But I'm always happy to get rid of that and move to a real solution. I'm not sure when I'll get to delve into OCaml, but for now I'll leave my Python implementation here, in case it helps anyone. There's some boilerplate - the actual routine to find definitions is really simple (and inefficient). Let me know if I can help with examples/explanation.
(The next logical step would be to implement "textDocument/references", to list all uses of a name)
#!/usr/bin/env python3
# Dependencies:
# pip install python-language-server
import sys
import logging
import threading
from pyls import _utils, uris
from pyls_jsonrpc.dispatchers import MethodDispatcher
from pyls_jsonrpc.endpoint import Endpoint
from pyls_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter
from pyls.workspace import Workspace
"""
Toy language server that implements textDocument/definition
For example, given this file
```smt2
(declare-const x Int)
(assert (= x 123))
```
if the cursor is on the "x" in line 3, textDocument/definition will return
the position of the x in line 1.
"""
# SMTLIBLanguageServer is adapted from pyls (python-language-server)
# Module-level logger shared by the server class.
log = logging.getLogger(__name__)
# Delay, in seconds, passed to threading.Timer between two checks that the
# parent (client) process is still alive.
PARENT_PROCESS_WATCH_INTERVAL = 10 # 10 s
# Passed as max_workers to the JSON-RPC Endpoint.
MAX_WORKERS = 64
class SMTLIBLanguageServer(MethodDispatcher):
    """Language server for SMT-LIB that answers ``textDocument/definition``.

    Implementation of the Microsoft VSCode Language Server Protocol
    https://github.com/Microsoft/language-server-protocol/blob/master/versions/protocol-1-x.md

    Only go-to-definition is advertised; the document synchronisation and
    completion handlers exist but do nothing.
    """

    def __init__(self, rx, tx, check_parent_process=False):
        """Wire a JSON-RPC reader/writer pair to this dispatcher.

        Args:
            rx: stream handed to ``JsonRpcStreamReader`` (incoming messages).
            tx: stream handed to ``JsonRpcStreamWriter`` (outgoing messages).
            check_parent_process: when true, ``initialize`` starts a watcher
                that calls ``m_exit`` once the client process is gone.
        """
        self.workspace = None
        self.root_uri = None
        self.watching_thread = None
        self.workspaces = {}
        self.uri_workspace_mapper = {}

        self._jsonrpc_stream_reader = JsonRpcStreamReader(rx)
        self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx)
        self._check_parent_process = check_parent_process
        self._endpoint = Endpoint(
            self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS)
        self._dispatchers = []
        self._shutdown = False

    def start(self):
        """Entry point for the server."""
        self._jsonrpc_stream_reader.listen(self._endpoint.consume)

    def __getitem__(self, item):
        """Override getitem to fallback through multiple dispatchers.

        Raises:
            KeyError: carrying the method name, when no dispatcher handles
                *item* or when the server is shutting down and *item* is not
                ``exit``.
        """
        if self._shutdown and item != 'exit':
            # exit is the only allowed method during shutdown
            log.debug("Ignoring non-exit method during shutdown: %s", item)
            raise KeyError(item)

        try:
            return super().__getitem__(item)
        except KeyError:
            # Fallback through extra dispatchers
            for dispatcher in self._dispatchers:
                try:
                    return dispatcher[item]
                except KeyError:
                    continue

        raise KeyError(item)

    def m_shutdown(self, **_kwargs):
        """Handle ``shutdown``: from now on only ``exit`` is dispatched."""
        self._shutdown = True
        return None

    def m_exit(self, **_kwargs):
        """Handle ``exit``: stop the endpoint and close both streams."""
        self._endpoint.shutdown()
        self._jsonrpc_stream_reader.close()
        self._jsonrpc_stream_writer.close()

    def _match_uri_to_workspace(self, uri):
        """Return the workspace matched to *uri*, else the root workspace.

        The result is None when no workspace exists yet (``initialize`` has
        not been handled).
        """
        workspace_uri = _utils.match_uri_to_workspace(uri, self.workspaces)
        return self.workspaces.get(workspace_uri, self.workspace)

    def capabilities(self):
        """Return the capabilities advertised in the ``initialize`` reply."""
        server_capabilities = {
            "definitionProvider": True,
        }
        log.info('Server capabilities: %s', server_capabilities)
        return server_capabilities

    def m_initialize(self, processId=None, rootUri=None, rootPath=None, initializationOptions=None, **_kwargs):
        """Handle ``initialize``: create the root workspace and report capabilities.

        When parent-process checking was requested and a ``processId`` is
        given, a daemon thread starts a periodic liveness check of that
        process; the server exits once the process is no longer alive.
        """
        log.debug('Language server initialized with %s %s %s %s',
                  processId, rootUri, rootPath, initializationOptions)
        if rootUri is None:
            rootUri = uris.from_fs_path(
                rootPath) if rootPath is not None else ''

        # Replace any workspace registered by a previous initialize call.
        self.workspaces.pop(self.root_uri, None)
        self.root_uri = rootUri
        self.workspace = Workspace(rootUri, self._endpoint, None)
        self.workspaces[rootUri] = self.workspace

        if self._check_parent_process and processId is not None and self.watching_thread is None:
            def watch_parent_process(pid):
                # exit when the given pid is not alive
                if not _utils.is_process_alive(pid):
                    log.info("parent process %s is not alive, exiting!", pid)
                    self.m_exit()
                else:
                    # Re-arm: check again after the watch interval.
                    threading.Timer(PARENT_PROCESS_WATCH_INTERVAL,
                                    watch_parent_process, args=[pid]).start()

            self.watching_thread = threading.Thread(
                target=watch_parent_process, args=(processId,))
            self.watching_thread.daemon = True
            self.watching_thread.start()

        return {'capabilities': self.capabilities()}

    def m_initialized(self, **_kwargs):
        """Handle the ``initialized`` notification (nothing to do)."""
        pass

    def m_text_document__definition(self, textDocument=None, position=None, **_kwargs):
        """Handle ``textDocument/definition``.

        Returns:
            The location produced by ``smt_definition``, or None when the
            request names no document or position, or when no workspace is
            available to load the document from.
        """
        doc_uri = textDocument.get("uri") if textDocument else None
        if not doc_uri or position is None:
            return None
        workspace = self._match_uri_to_workspace(doc_uri)
        if workspace is None:
            # Request arrived before initialize created a workspace.
            return None
        return smt_definition(workspace.get_document(doc_uri), position)

    def m_text_document__did_close(self, textDocument=None, **_kwargs):
        """Handle ``textDocument/didClose`` (ignored)."""
        pass

    def m_text_document__did_open(self, textDocument=None, **_kwargs):
        """Handle ``textDocument/didOpen`` (ignored)."""
        pass

    def m_text_document__did_change(self, contentChanges=None, textDocument=None, **_kwargs):
        """Handle ``textDocument/didChange`` (ignored)."""
        pass

    def m_text_document__did_save(self, textDocument=None, **_kwargs):
        """Handle ``textDocument/didSave`` (ignored)."""
        pass

    def m_text_document__completion(self, textDocument=None, **_kwargs):
        """Handle ``textDocument/completion`` (ignored, returns None)."""
        pass
def flatten(list_of_lists):
    """Concatenate the items of every inner iterable into one flat list."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat
def merge(list_of_dicts):
    """Combine several dicts into one; later dicts win on duplicate keys."""
    merged = {}
    for mapping in list_of_dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged
def smt_definition(document, position):
    """Build the LSP location of the definition of the symbol at *position*.

    Args:
        document: object exposing ``source`` (the full text), ``lines``
            (the text split into lines, indexed by line number) and ``uri``.
            May be None.
        position: LSP position dict with ``line`` and ``character`` keys.

    Returns:
        A dict with ``uri`` and ``range`` keys whose range covers the defining
        symbol, or None when there is no document, no position, or no
        definition was found.
    """
    if document is None or not position:
        return None
    found = definition(document.source, position["line"], position["character"])
    if found is None:
        return None
    line, col, token = found
    lines = document.lines

    # The parser reports where it stopped reading a token, not where the token
    # starts, and the marker is irregular:
    #   * a token ended by a newline is reported as (next line, 0);
    #   * columns are 0-based on line 0 but 1-based on every later line;
    #   * a whitespace terminator is consumed and therefore counted.
    # Recover `last`, the 0-based index of the token's final character.
    if line > 0 and col == 0:
        line -= 1
        text = lines[line] if line < len(lines) else ""
        if text.endswith("\n"):
            text = text[:-1]
        last = len(text) - 1
    else:
        text = lines[line] if 0 <= line < len(lines) else ""
        last = col if line == 0 else col - 1
        if 0 <= last < len(text) and text[last].isspace():
            # The marker sits on the consumed whitespace terminator.
            last -= 1

    # A tester such as `is-cons` resolves to the constructor `cons`, so the
    # text at the definition site may be the name without its `is-` prefix.
    names = [token]
    if token.startswith("is-"):
        names.append(token[len("is-"):])
    for name in names:
        start = last - len(name) + 1
        if start >= 0 and text[start:last + 1] == name:
            break
    else:
        # Text does not match (e.g. a symbol spanning lines): best effort.
        name = token
        start = max(last - len(token) + 1, 0)

    return {
        'uri': document.uri,
        'range': {
            'start': {'line': line, 'character': start},
            'end': {'line': line, 'character': start + len(name)},
        }
    }
def definition(source, cursor_line, cursor_character):
    """Locate the definition of the symbol under the cursor in *source*.

    Args:
        source: complete SMT-LIB text.
        cursor_line: zero-based line of the cursor.
        cursor_character: zero-based character offset of the cursor.

    Returns:
        ``(line, col, symbol)`` where ``line``/``col`` are the parser's end
        marker of the defining occurrence and ``symbol`` is its text, or None
        when there is no token under the cursor or no definition for it.
    """
    nodes = list(parser().parse_smtlib(source))
    needle = find_leaf_node_at(cursor_line, cursor_character, nodes)
    if needle is None:
        # Cursor is not on a token; there is nothing to look up.
        return None
    line, col, symbol = find_definition_for(needle, nodes)
    if symbol is None:
        return None
    # Report the text found at the definition site. It differs from the token
    # under the cursor for testers (`is-cons` resolves to `cons`), and callers
    # use its length to compute the start column.
    return line, col, symbol
def find_leaf_node_at(line, col, nodes):
    """Return the text of the token at the cursor, or None if there is none.

    Args:
        line: zero-based cursor line.
        col: zero-based cursor character.
        nodes: iterable of ``(end_line, end_col, node)`` tuples in source
            order, where ``node`` is a token string or a nested list of such
            tuples.

    Returns:
        The first token (searching depth first, in order) whose end marker
        lies after the cursor, or None when the cursor is past every token of
        the list that encloses it (for example on ``()`` or on a closing
        parenthesis).
    """
    cursor = (line, col)
    for end_line, end_col, node in nodes:
        # The parser counts columns from 0 on the first line but from 1 on all
        # later lines; shift line-0 markers so every line is treated alike.
        if end_line == 0:
            end_col += 1
        # Nodes arrive in source order, so the first one ending after the
        # cursor is the one containing (or following) it.
        if cursor >= (end_line, end_col):
            continue
        if isinstance(node, str):
            return node
        if node and not isinstance(node[0], tuple):
            # Quoted literal stored as a list of characters.
            return "".join(node)
        # Nested s-expression: the answer, if any, is among its children.
        return find_leaf_node_at(line, col, node)
    return None
def stripprefix(x, prefix):
    """Return *x* without a leading *prefix*; *x* is unchanged if it has none."""
    return x[len(prefix):] if x.startswith(prefix) else x
def find_definition_for(needle, nodes):
    """Find the declaration or definition that introduces *needle*.

    Args:
        needle: symbol text to look up.
        nodes: top-level parse nodes, each ``(end_line, end_col, node)`` where
            ``node`` is a token string or a list of such tuples.

    Returns:
        The ``(end_line, end_col, text)`` tuple of the defining occurrence, or
        ``(-1, -1, None)`` when nothing defines *needle*.

    Top-level nodes that are not of the shape ``(head symbol ...)`` and
    ``declare-*``/``define-*`` forms that are not handled are skipped.
    """
    # Forms whose second element is the declared name itself.
    single_forms = (
        "declare-const", "define-const", "declare-fun", "define-fun",
        "define-fun-rec", "declare-datatype", "declare-sort", "define-sort",
    )
    # Forms whose second element is a list of declarations, each led by a name.
    grouped_forms = ("declare-datatypes", "define-funs-rec")
    # Forms that additionally introduce constructors and selectors.
    datatype_forms = ("declare-datatype", "declare-datatypes")
    not_found = (-1, -1, None)

    if not isinstance(needle, str):
        return not_found

    for node in nodes:
        _, _, form = node
        # Need at least a head and a symbol; skip atoms, `()` and `(head)`.
        if isinstance(form, str) or len(form) < 2:
            continue
        _, _, head = form[0]
        if not isinstance(head, str):
            continue
        if not head.startswith(("declare-", "define-")):
            continue
        _, _, symbol = form[1]

        if head in single_forms:
            if symbol == needle:
                return form[1]
        elif head in grouped_forms:
            for entry in symbol:
                if not isinstance(entry, tuple):
                    continue
                _, _, declaration = entry
                # Each declaration is expected to be a list led by its name.
                if (isinstance(declaration, str) or not declaration
                        or not isinstance(declaration[0], tuple)):
                    continue
                _, _, name = declaration[0]
                if name == needle:
                    return declaration[0]
        else:
            # Unsupported declare-*/define-* form: nothing to look up here.
            continue

        if head in datatype_forms:
            # Constructors and selectors live inside the declaration; a
            # tester `is-C` refers to the constructor `C`.
            found = dfs(needle, node)
            if found is None:
                found = dfs(stripprefix(needle, "is-"), node)
            if found is not None:
                return found
    return not_found
def dfs(needle, node):
    """Depth-first search for the first token equal to *needle*.

    Args:
        needle: token text to find.
        node: ``(end_line, end_col, payload)`` tuple whose payload is a token
            string or a list of child nodes.

    Returns:
        The first matching ``(end_line, end_col, text)`` tuple in pre-order,
        or None when there is no match. Anything that is not a three-element
        tuple is treated as containing no match instead of aborting the
        search.
    """
    if not isinstance(node, tuple) or len(node) != 3:
        return None
    payload = node[2]
    if isinstance(payload, str):
        return node if payload == needle else None
    for child in payload:
        found = dfs(needle, child)
        if found is not None:
            return found
    return None
class parser:
    """Single-use tokenizer for SMT-LIB text that tracks source positions.

    ``parse_smtlib`` yields one ``(line, col, node)`` tuple per top-level
    item. ``node`` is a string for a token or quoted literal and a list of
    such tuples for an s-expression.

    ``line``/``col`` are the counters at the moment the item was completed,
    i.e. an END marker, and consumers rely on its exact values:
      * ``line`` is zero-based;
      * ``col`` is the zero-based index of the last character read on the
        first line, but one more than that index on every later line, because
        the counter starts at -1 and is reset to 0 by a newline;
      * a token ended by whitespace has that whitespace already consumed, so a
        token ended by a newline is reported as ``(next line, 0)``.
    """

    # SMT-LIB white space characters.
    _WHITESPACE = (' ', '\t', '\n', '\r')

    def __init__(self):
        self.pos = 0       # index of the next unread character of self.text
        self.line = 0      # zero-based line of the last character read
        self.col = -1      # column counter, see the class docstring
        self.text = None   # text being parsed; set once by parse_smtlib

    def nextch(self):
        """Consume and return the next character, updating line and column."""
        char = self.text[self.pos]
        self.pos += 1
        self.col += 1
        if char == "\n":
            self.line += 1
            self.col = 0
        return char

    def parse_smtlib(self, text):
        """Return a generator over the top-level items of *text*.

        A parser instance can be used for one text only.
        """
        assert self.text is None
        self.text = text
        return self.parse_smtlib_aux()

    def parse_smtlib_aux(self):
        """Generate ``(line, col, node)`` for each top-level item.

        Not fully SMT-LIB compliant. An unterminated quoted literal ends the
        parse, an unmatched ``)`` is ignored, and s-expressions still open at
        the end of the text are not yielded.
        """
        open_lists = []   # lists of the s-expressions currently open
        current = None    # innermost open list, None at top level
        size = len(self.text)
        while self.pos < size:
            char = self.nextch()

            # Stolen from ddSMT's parser. Not fully SMT-LIB compliant but good enough.
            # String literals/quoted symbols
            if char in ('"', '|'):
                quote = char
                literal = [char]
                # Read until terminating " or |
                while True:
                    if self.pos >= size:
                        return
                    char = self.nextch()
                    literal.append(char)
                    if char == quote:
                        # A doubled quote is an escape: "a "" b" is one
                        # string literal.
                        if char == '"' and self.pos < size and self.text[self.pos] == '"':
                            literal.append(self.nextch())
                            continue
                        break
                # Store the literal as a string, like every other token.
                entry = (self.line, self.col, ''.join(literal))
                if current is not None:
                    current.append(entry)
                else:
                    yield entry
                continue

            # Comments
            if char == ';':
                # Read until newline
                while self.pos < size:
                    char = self.nextch()
                    if char == '\n':
                        break
                continue

            # Open s-expression
            if char == '(':
                current = []
                open_lists.append(current)
                continue

            # Close s-expression
            if char == ')':
                if not open_lists:
                    # Unmatched ')': ignore it rather than abort the parse.
                    continue
                finished = open_lists.pop()
                # Do we have nested s-expressions?
                if open_lists:
                    open_lists[-1].append((self.line, self.col, finished))
                    current = open_lists[-1]
                else:
                    yield self.line, self.col, finished
                    current = None
                continue

            # Identifier
            if char not in self._WHITESPACE:
                token = [char]
                while self.pos < size:
                    char = self.text[self.pos]
                    # Delimiters end the token but are left for the main loop.
                    if char in ('(', ')', ';'):
                        break
                    self.nextch()
                    if char in self._WHITESPACE:
                        break
                    token.append(char)
                token = ''.join(token)
                # Append to current s-expression
                if current is not None:
                    current.append((self.line, self.col, token))
                else:
                    yield self.line, self.col, token
def serve():
    """Run the language server over this process's binary stdin and stdout."""
    SMTLIBLanguageServer(sys.stdin.buffer, sys.stdout.buffer).start()
if __name__ == "__main__":
    # `definition LINE COLUMN` runs a one-shot lookup on the text read from
    # stdin and prints the result; any other invocation starts the server.
    if len(sys.argv) >= 2 and sys.argv[1] == "definition":
        if len(sys.argv) < 4:
            sys.exit("usage: %s definition LINE COLUMN < file.smt2" % sys.argv[0])
        try:
            line = int(sys.argv[2])
            col = int(sys.argv[3])
        except ValueError:
            sys.exit("LINE and COLUMN must be integers")
        print(definition(sys.stdin.read(), line, col))
    else:
        serve()
Hi, that's very interesting to know! It's very useful for me to know which features to prioritize, so don't hesitate to report what you need/want Dolmen (and the LSP server) to do ^^
After thinking about it this week, I now have a fairly good idea of how I want to implement the goto definition in Dolmen, and I'll get on that as soon as I have the time.