Python API
ast-grep's Python API is powered by PyO3. You can write Python to programmatically inspect and change syntax trees.
To try out ast-grep's Python API, you can use the online colab notebook.
Installation
ast-grep's Python library is distributed on PyPI. You can install it with pip.
pip install ast-grep-py
Core Concepts
The core concepts in ast-grep's Python API are:
- SgRoot: a class to parse a string into a syntax tree
- SgNode: a node in the syntax tree
Make AST like an XML/HTML doc!
Using ast-grep's API is like web scraping with Beautiful Soup or pyquery. You can use SgNode to traverse the syntax tree and collect information from the nodes.
A common workflow for using ast-grep's Python API is:
- Parse a string into a syntax tree by using SgRoot
- Get the root node of the syntax tree by calling root.root()
- find relevant nodes by using patterns or rules
- Collect information from the nodes
Example:
from ast_grep_py import SgRoot
root = SgRoot("print('hello world')", "python") # 1. parse
node = root.root() # 2. get root
print_stmt = node.find(pattern="print($A)") # 3. find
print_stmt.get_match('A').text() # 4. collect information
# 'hello world'
SgRoot
The SgRoot class has the following signature:
class SgRoot:
    def __init__(self, src: str, language: str) -> None: ...
    def root(self) -> SgNode: ...
__init__ takes two arguments: the first argument is the source code string, and the second argument is the language name. root returns the root node of the syntax tree, which is an instance of SgNode.
Example:
root = SgRoot("print('hello world')", "python") # 1. parse
node = root.root() # 2. get root
The code above parses the string print('hello world') into a syntax tree, and gets the root node of the syntax tree.
The root node can be used to find other nodes in the syntax tree.
SgNode
SgNode is the most important class in ast-grep's Python API. It provides methods to inspect and traverse the syntax tree. The following sections introduce several methods of SgNode.
Example:
node = root.root()
string = node.find(kind="string")
assert string # assume we can find a string node in the source
print(string.text())
Search
You can use find and find_all to search for nodes in the syntax tree.
- find returns the first node that matches the pattern or rule.
- find_all returns a list of nodes that match the pattern or rule.
# Search
class SgNode:
    @overload
    def find(self, **kwargs: Unpack[Rule]) -> Optional[SgNode]: ...
    @overload
    def find_all(self, **kwargs: Unpack[Rule]) -> List[SgNode]: ...
    @overload
    def find(self, config: Config) -> Optional[SgNode]: ...
    @overload
    def find_all(self, config: Config) -> List[SgNode]: ...
find has two overloads: one takes keyword arguments of Rule, and the other takes a Config object.
Search with Rule
Using keyword arguments is the most straightforward way to search for nodes.
The argument name is the key of a rule, and the argument value is the rule's value. You can pass multiple keyword arguments to find to search for nodes that match all the rules.
root = SgRoot("print('hello world')", "python")
node = root.root()
node.find(pattern="print($A)") # will return the print function call
node.find(kind="string") # will return the string 'hello world'
# below will return print function call because it matches both rules
node.find(pattern="print($A)", kind="call")
# below will return None because the pattern cannot be a string literal
node.find(pattern="print($A)", kind="string")
strings = node.find_all(kind="string") # will return [SgNode("hello world")]
assert len(strings) == 1
Search with Config
You can also use a Config object to search for nodes. This is similar to using YAML directly on the command line.
The main difference between using Config and using Rule is that Config has more options to control the search behavior, like constraints and utils.
# will find the print call, because its argument $A ('hello world') satisfies the regex constraint
root.root().find({
    "rule": {
        "pattern": "print($A)",
    },
    "constraints": {
        "A": { "regex": "hello" }
    }
})
# will return None because constraints are not satisfied
root.root().find({
    "rule": {
        "pattern": "print($A)",
    },
    "constraints": {
        "A": { "regex": "no match" }
    }
})
Match
Once we find a node, we can use the following methods to get meta variables from the search.
The get_match method returns the single node that matches a single meta variable, and get_multiple_matches returns a list of nodes that match a multi meta variable.
class SgNode:
    def get_match(self, meta_var: str) -> Optional[SgNode]: ...
    def get_multiple_matches(self, meta_var: str) -> List[SgNode]: ...
    def __getitem__(self, meta_var: str) -> SgNode: ...
Example:
src = """
print('hello')
logger('hello', 'world', '!')
"""
root = SgRoot(src, "python").root()
node = root.find(pattern="print($A)")
arg = node.get_match("A") # returns SgNode('hello')
assert arg # assert node is found
arg.text() # returns 'hello'
# returns [] because $A and $$$A are different
node.get_multiple_matches("A")
logs = root.find(pattern="logger($$$ARGS)")
# returns [SgNode('hello'), SgNode(','), SgNode('world'), SgNode(','), SgNode('!')]
logs.get_multiple_matches("ARGS")
logs.get_match("A") # returns None
SgNode also supports __getitem__ to get the match of a single meta variable.
It is equivalent to get_match, except that it either returns an SgNode or raises a KeyError if the match is not found.
Use __getitem__ to avoid unnecessary None checks when you are using a type checker.
node = root.find(pattern="print($A)")
# node.get_match("A").text() # error: node.get_match("A") can be None
node["A"].text() # Ok
Inspection
The following methods are used to inspect the node.
# Node Inspection
class SgNode:
    def range(self) -> Range: ...
    def is_leaf(self) -> bool: ...
    def is_named(self) -> bool: ...
    def is_named_leaf(self) -> bool: ...
    def kind(self) -> str: ...
    def text(self) -> str: ...
Example:
root = SgRoot("print('hello world')", "python")
node = root.root()
node.text() # will return "print('hello world')"
Another important method is range, which returns a Range object with two Pos objects, start and end, representing the start and end of the node.
Each Pos contains the line, column, and index (offset) of that position. All of them are 0-indexed.
You can use the range information to locate the node in the source and to modify the source code.
rng = node.range()
pos = rng.start # or rng.end, both are `Pos` objects
pos.line # 0, line starts with 0
pos.column # 0, column starts with 0
rng.end.index # 20, the length of "print('hello world')"; index starts with 0
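As a minimal sketch of locating source text from a range, the snippet below slices a string using start.index and end.index. The Pos and Range dataclasses here are stand-ins that only mirror the fields described above, so the sketch runs without ast-grep installed; with the real library the range would come from node.range(). It assumes that index can be used directly as a Python string index and that the end index is exclusive, which should hold for ASCII source but is an assumption for other text. The helper name slice_by_range is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Pos:
    # Stand-in for ast-grep's Pos: all fields are 0-indexed.
    line: int
    column: int
    index: int


@dataclass
class Range:
    # Stand-in for ast-grep's Range: a start and an end position.
    start: Pos
    end: Pos


def slice_by_range(src: str, rng: Range) -> str:
    """Return the part of src covered by rng (hypothetical helper).

    Assumes rng.end.index is exclusive and counts in str indices.
    """
    return src[rng.start.index:rng.end.index]


src = "print('hello world')"
# Hand-written range covering the string argument, quotes included.
arg_range = Range(start=Pos(0, 6, 6), end=Pos(0, 19, 19))
arg_text = slice_by_range(src, arg_range)
```

With a real node, the same slice on the original source should give the same text as node.text(), under the assumptions stated above.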
Refinement
You can also filter nodes after matching by using the following methods.
This is called "refinement" in the documentation. Note that these refinement methods only support Rule keyword arguments, not Config.
# Search Refinement
class SgNode:
    def matches(self, **rule: Unpack[Rule]) -> bool: ...
    def inside(self, **rule: Unpack[Rule]) -> bool: ...
    def has(self, **rule: Unpack[Rule]) -> bool: ...
    def precedes(self, **rule: Unpack[Rule]) -> bool: ...
    def follows(self, **rule: Unpack[Rule]) -> bool: ...
Example:
node = root.find(pattern="print($A)")
if node["A"].matches(kind="string"):
    print("A is a string")
Traversal
You can traverse the tree using the following methods, like using pyquery.
# Tree Traversal
class SgNode:
    def get_root(self) -> SgRoot: ...
    def field(self, name: str) -> Optional[SgNode]: ...
    def parent(self) -> Optional[SgNode]: ...
    def child(self, nth: int) -> Optional[SgNode]: ...
    def children(self) -> List[SgNode]: ...
    def ancestors(self) -> List[SgNode]: ...
    def next(self) -> Optional[SgNode]: ...
    def next_all(self) -> List[SgNode]: ...
    def prev(self) -> Optional[SgNode]: ...
    def prev_all(self) -> List[SgNode]: ...
Fix code
SgNode is immutable, so it is impossible to change the code directly.
However, SgNode has a replace method to generate an Edit object. You can then use the commit_edits method to apply the changes and generate a new source string.
class Edit:
    # The start position of the edit
    start_pos: int
    # The end position of the edit
    end_pos: int
    # The text to be inserted
    inserted_text: str
class SgNode:
    # Edit
    def replace(self, new_text: str) -> Edit: ...
    def commit_edits(self, edits: List[Edit]) -> str: ...
Example:
root = SgRoot("print('hello world')", "python").root()
node = root.find(pattern="print($A)")
edit = node.replace("logger.log('bye world')")
new_src = node.commit_edits([edit])
# "logger.log('bye world')"
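To make the role of the three Edit fields concrete, here is a conceptual sketch of how a list of edits could be applied to a source string. This is not ast-grep's implementation: the Edit dataclass is a stand-in with the same field names, apply_edits is a hypothetical helper, and the sketch assumes that positions are string indices, that end_pos is exclusive, and that edits do not overlap.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Edit:
    # Stand-in with the same three fields as the Edit class above.
    start_pos: int
    end_pos: int
    inserted_text: str


def apply_edits(src: str, edits: List[Edit]) -> str:
    """Apply non-overlapping edits to src, working left to right."""
    pieces = []
    cursor = 0  # index of the first character not yet copied
    for edit in sorted(edits, key=lambda e: e.start_pos):
        if edit.start_pos < cursor:
            raise ValueError("overlapping edits are not supported")
        pieces.append(src[cursor:edit.start_pos])  # unchanged text
        pieces.append(edit.inserted_text)          # replacement text
        cursor = edit.end_pos                      # skip the replaced span
    pieces.append(src[cursor:])  # copy whatever follows the last edit
    return "".join(pieces)


src = "print('a'); print('b')"
# Edits may be given in any order; they are sorted before being applied.
edits = [Edit(12, 22, "log('b')"), Edit(0, 10, "log('a')")]
new_src = apply_edits(src, edits)
```

Sorting by start position and tracking a cursor keeps every edit expressed in offsets of the original string, so earlier replacements of a different length do not shift later ones.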
Note that, unlike the CLI, the Python API does not substitute meta variables in the replacement text: passing logger.log($A) to replace will not generate logger.log('hello world'). This is because using the host language to generate the replacement string is more flexible.
WARNING
Meta variables are not replaced in the replace method. You need to build the replacement string in Python, using the text of the node returned by get_match(var_name).
See also ast-grep#1172
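A minimal sketch of building the replacement string in Python follows. The helper to_logger_call is hypothetical, and the captured text is hard-coded so that the snippet runs without ast-grep installed; with the real API it would come from node["A"].text(), as noted in the comments.

```python
def to_logger_call(arg_text: str) -> str:
    """Build the replacement source with ordinary string formatting."""
    return f"logger.log({arg_text})"


# With ast-grep this text would be obtained from the matched node:
#   node = root.find(pattern="print($A)")
#   arg_text = node["A"].text()
arg_text = "'hello world'"
replacement = to_logger_call(arg_text)

# The result is then passed to the edit methods described above:
#   edit = node.replace(replacement)
#   new_src = node.commit_edits([edit])
```

Because the replacement is a plain Python string, any string-building tool (f-strings, templates, joins over get_multiple_matches results) can be used to produce it.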