PyTA Project: Converting Function Preconditions to Z3 Constraints

Today's task is to update ExprWrapper, a module that converts a Python expression to the corresponding Z3 expression, so that it supports the container classes list, tuple, and set, along with the in operator. In this article, I will first give a brief overview of the Z3 library and its use in PythonTA, and then explain my implementation.

Introduction to Z3 library

Z3 is a theorem prover library that can, among other things, solve systems of equations and inequalities. A complete introduction can be found in the Z3 documentation. Here I will only cover the features relevant to the task: data types, arithmetic and boolean logic, functions, and the solver.

The Z3 data types relevant here are integer, real number, and boolean value (Z3 offers others that this task does not need). Unlike float values in Python, real numbers in Z3 are exact: rational numbers are represented as quotients, and irrational algebraic numbers are represented as roots of polynomials. For example, Q(1, 3) defines the real value 1/3, which differs from the Python float 1/3, a binary approximation.

We can define variables of different data types and use them in equations. The solve function searches for a solution that satisfies the given constraints and prints it. If the constraints are unsatisfiable, it prints "no solution".

from z3 import *  # the snippets in this section assume this import

x = Int('x')
y = Real('y')
solve(x**2 + y**2 > 3, x**3 + y < 5)  # one possible result: [y = 0, x = -2]

Z3 supports boolean operations including And, Or, Not, Implies, If, and equality (==). Implies represents an "if-then" statement: the second argument must be true whenever the first one is true. If represents an "if-then-else" expression: it evaluates to the second argument if the condition is true, and to the third argument otherwise. The code below shows the difference. Note that the printed result is [] in the second example because the statement is vacuously true when the condition is false, so no variable needs a value.

p = Bool('p')
q = Bool('q')
solve(Implies(True, q))    # [q = True]
solve(Implies(False, q))   # [] 
solve(If(True, p, q))      # [p = True]
solve(If(False, p, q))     # [q = True]

solve(Implies(And(p, q), q))    # [p = True, q = True]
solve(If(Or(p, q), p, q))      # [p = True, q = False]

A Solver is a general-purpose solver that can store and solve systems of constraints. We add a constraint to the solver with the add method. check returns sat if the constraints have a solution, unsat if no solution exists, and unknown if the solver cannot decide. A constraint is called satisfiable if there exists a solution, and valid if it is true for all values; a constraint F is valid exactly when Not(F) is unsatisfiable. A solution to a system is called a model and is retrieved with the model method. Note that model returns only one solution. To get all solutions, we can repeatedly call check and model, each time adding a constraint that excludes the solution just found, until the constraints become unsatisfiable (if there are infinitely many solutions, this loop never terminates).

s = Solver()
x = Int('x')
s.add(Or(x == 1, x == 2, x == 3))
solutions = []

while s.check() == sat:
    model = s.model()
    x_value = model.eval(x)
    solutions.append(x_value)
    s.add(x != x_value)

The Solver maintains a stack of constraints. The push method creates a new scope for additional constraints, and pop discards the constraints added since the matching push. This is loosely analogous to git branches: push is like creating a new branch from the main branch, and pop is like returning to the main branch and discarding the changes.

x = Int('x')
y = Int('y')

s = Solver()
s.add(x > 10, y == x + 2)
print(s)    # [x > 10, y == x + 2]
print(s.check())    # sat

print("Create a new scope...")
s.push()
s.add(y < 11)
print(s)    # [x > 10, y == x + 2, y < 11]
print(s.check())    # unsat

s.pop()
print(s)    # [x > 10, y == x + 2]
print(s.check())    # sat

PythonTA use of Z3

One feature PythonTA plans to incorporate is using the Z3 library to detect logical errors statically. For example, in the code below, the function precondition states x > 0, so the else branch can never be executed when the function is called correctly. The unreachable branch indicates a logical error in the code.

def f(x: int) -> int:
    """Precondition: x > 0"""
    if x > 0:
        return 1
    else:
        return 0  # This is unreachable based on the precondition

The module z3_visitor.py converts the preconditions given in a function docstring to a list of Z3 constraints and stores them in a z3_constraints attribute on the FunctionDef node. It registers the method set_function_def_z3_constraints as a transform on an astroid TransformVisitor, so the method is invoked whenever a FunctionDef node is visited.

class Z3Visitor:
    """
    The class responsible for visiting astroid nodes (currently only FunctionDef nodes),
    parsing preconditions, and converting them to z3 expressions to be appended in the
    z3_constraints attribute of the node.
    """

    def __init__(self):
        """Return a TransformVisitor that sets an environment for every node."""
        visitor = TransformVisitor()
        # Register transforms
        visitor.register_transform(nodes.FunctionDef, self.set_function_def_z3_constraints)
        self.visitor = visitor

    def set_function_def_z3_constraints(self, node: nodes.FunctionDef):
        # Parse types
        types = {}
        annotations = node.args.annotations
        arguments = node.args.args
        for ann, arg in zip(annotations, arguments):
            if ann is None:
                continue
            # TODO: what to do about subscripts ex. Set[int], List[Set[int]], ...
            inferred = ann.inferred()
            if len(inferred) > 0 and inferred[0] is not Uninferable:
                if isinstance(inferred[0], nodes.ClassDef):
                    types[arg.name] = inferred[0].name
        # Parse preconditions
        preconditions = parse_assertions(node, parse_token="Precondition")
        # Get z3 constraints
        z3_constraints = []
        for pre in preconditions:
            pre = astroid.parse(pre).body[0]
            ew = ExprWrapper(pre, types)
            try:
                transformed = ew.reduce()
            except (Z3Exception, Z3ParseException):
                transformed = None
            if transformed is not None:
                z3_constraints.append(transformed)
        # Set z3 constraints
        node.z3_constraints = z3_constraints
        return node

set_function_def_z3_constraints first determines the argument types by inferring them from the arguments' type annotations. Then, it retrieves the part of the function docstring under "Precondition" as the function preconditions. Finally, it passes the argument types and each precondition to ExprWrapper and invokes its reduce() method to get the corresponding Z3 constraint.
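The docstring step can be pictured with a simplified stand-in. The function extract_preconditions below is hypothetical and much cruder than PythonTA's parse_assertions; it uses only the standard library ast module and handles just the two docstring shapes shown in this article.

```python
import ast
from typing import List

def extract_preconditions(func_source: str) -> List[str]:
    """Hypothetical, simplified stand-in for PythonTA's parse_assertions."""
    func = ast.parse(func_source).body[0]
    docstring = ast.get_docstring(func) or ""
    preconditions = []
    in_section = False
    for line in docstring.splitlines():
        stripped = line.strip()
        if stripped.startswith("Precondition"):
            in_section = True
            # Single-line form: "Precondition: x > 0"
            _, _, rest = stripped.partition(":")
            if rest.strip():
                preconditions.append(rest.strip())
        elif in_section and stripped.startswith("- "):
            # List form: one "- <expression>" item per line
            preconditions.append(stripped[2:])
        elif stripped == "":
            in_section = False  # a blank line ends the section
    return preconditions

source = '''
def f(x: int):
    """Return something.

    Preconditions:
        - x in [1, 2, 3]
        - x not in {4, 5}
    """
'''
print(extract_preconditions(source))  # ['x in [1, 2, 3]', 'x not in {4, 5}']
```

Each extracted string is a Python expression, which is why the visitor can hand it to astroid.parse before wrapping it in ExprWrapper.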

ExprWrapper is the class that converts astroid nodes to Z3 constraints, and is the class we will be working on. Currently, the reduce method in ExprWrapper handles the data types int, float, and bool; boolean operations (and, or); the unary operation not; binary operations (+, -, *, /, **, <=, >=, <, >, ==); and variable names. To handle nested expressions, each parsing method recursively calls reduce on the components of the expression, with a single name or constant value as the base case.

    def reduce(self, node: astroid.NodeNG = None) -> z3.ExprRef:
        """
        Convert astroid node to z3 expression and return it.
        If an error is encountered or a case is not considered, return None.
        """
        if node is None:
            node = self.node

        if isinstance(node, nodes.BoolOp):
            node = self.parse_bool_op(node)
        elif isinstance(node, nodes.UnaryOp):
            node = self.parse_unary_op(node)
        elif isinstance(node, nodes.Compare):
            node = self.parse_compare(node)
        elif isinstance(node, nodes.BinOp):
            node = self.parse_bin_op(node)
        elif isinstance(node, nodes.Const):
            node = node.value
        elif isinstance(node, nodes.Name):
            node = self.apply_name(node.name)
        else:
            raise Z3ParseException(f"Unhandled node type {type(node)}.")

        return node

Extending ExprWrapper to support container classes

To complete our task, we need to add a parse method for list, set, and tuple nodes, and extend the method apply_bin_op to support the in and not in operations. Note that not in should be treated as a single operator, not as a combination of not and in, because astroid represents it as a single operator ("not in") inside a Compare node. We can interpret in and not in as follows:

a in lst -> a == lst[0] or a == lst[1] or ... or a == lst[n]
            any(a == e for e in lst) # written in Python expression

a not in lst -> a != lst[0] and a != lst[1] and ... and a != lst[n]
                all(a != e for e in lst) # written in Python expression
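Both points can be checked quickly in plain Python. The sketch below uses the standard library ast module as a stand-in for astroid (an assumption made to keep the example dependency-free), and the helper names expand_in and expand_not_in are hypothetical.

```python
import ast

# "not in" parses to one Compare node with a single NotIn operator.
tree = ast.parse("x not in [1, 2, 3]", mode="eval").body
print(type(tree).__name__)                 # Compare
print(type(tree.ops[0]).__name__)          # NotIn
print(type(tree.comparators[0]).__name__)  # List

def expand_in(a, container):
    """a in container, written as a disjunction of equalities."""
    return any(a == e for e in container)

def expand_not_in(a, container):
    """a not in container, written as a conjunction of inequalities."""
    return all(a != e for e in container)

# The expansions agree with Python's own operators, including for an empty container.
for value in range(6):
    for container in ([1, 2, 3], (1, 2), {4, 5}, []):
        assert expand_in(value, container) == (value in container)
        assert expand_not_in(value, container) == (value not in container)
```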

The code below shows the expanded apply_bin_op method. z3.Or and z3.And accept any number of arguments, so we build the list of comparisons with a comprehension and use * to unpack it into separate arguments.

    def apply_bin_op(
        self, left: z3.ExprRef, op: str, right: Union[z3.ExprRef, List[z3.ExprRef]]
    ) -> z3.ExprRef:
        """Given left, right, op, apply the binary operation."""
        try:
            if op == "+":
                return left + right
            elif op == "-":
                return left - right
            elif op == "*":
                return left * right
            elif op == "/":
                return left / right
            elif op == "**":
                return left**right
            elif op == "==":
                return left == right
            elif op == "<=":
                return left <= right
            elif op == ">=":
                return left >= right
            elif op == "<":
                return left < right
            elif op == ">":
                return left > right
            elif op == "in":
                return z3.Or(*[left == element for element in right])
            elif op == "not in":
                return z3.And(*[left != element for element in right])
            else:
                raise Z3ParseException(f"Unhandled binary operation {op}.")
        except TypeError:
            raise Z3ParseException(f"Operation {op} incompatible with types {left} and {right}.")

We also need a method to convert a List/Set/Tuple AST node to a list of Z3 expressions, one per container element. Here we store the Z3 expressions directly in a regular Python list. Although Z3 provides constructs such as Array for modeling collections, according to its documentation Array should only be used to handle unbounded or very large lists due to its inefficiency, and would be overkill for small containers with a fixed set of elements.

    def parse_container_op(
        self, node: Union[nodes.List, nodes.Set, nodes.Tuple]
    ) -> List[z3.ExprRef]:
        """Convert an astroid List, Set, Tuple node to a list of z3 expressions."""
        return [self.reduce(element) for element in node.elts]

Finally, we need to call parse_container_op from the reduce method.

    def reduce(self, node: astroid.NodeNG = None) -> z3.ExprRef:
        """
        Convert astroid node to z3 expression and return it.
        If an error is encountered or a case is not considered, return None.
        """
        if node is None:
            node = self.node

        if isinstance(node, nodes.BoolOp):
            node = self.parse_bool_op(node)
        elif isinstance(node, nodes.UnaryOp):
            node = self.parse_unary_op(node)
        elif isinstance(node, nodes.Compare):
            node = self.parse_compare(node)
        elif isinstance(node, nodes.BinOp):
            node = self.parse_bin_op(node)
        elif isinstance(node, nodes.Const):
            node = node.value
        elif isinstance(node, nodes.Name):
            node = self.apply_name(node.name)
        elif isinstance(node, (nodes.List, nodes.Tuple, nodes.Set)):
            node = self.parse_container_op(node)
        else:
            raise Z3ParseException(f"Unhandled node type {type(node)}.")

        return node

After completing the update, we add a unit test in test_z3_visitor to check the visitor's behavior on container types.

def test_container_constraints():
    z3v = Z3Visitor()
    code = """
    def f(x: int):
        '''
        Preconditions:
            - x in [1, 2, 3]
            - x not in {4, 5, 6, 7, 8}
            - x in (1, 2)
        '''
        pass
    """
    mod = z3v.visitor.visit(astroid.parse(code))
    function_def = mod.body[0]
    # Declare variables
    x = z3.Int("x")
    # Construct expected
    expected = [
        z3.Or(x == 1, x == 2, x == 3),
        z3.And(x != 4, x != 5, x != 6, x != 7, x != 8),
        z3.Or(x == 1, x == 2),
    ]
    actual = function_def.z3_constraints
    assert actual != []
    for e, a in zip(expected, actual):
        solver = z3.Solver()
        solver.add(e == a)
        assert solver.check() == z3.sat
