Z3Py - Get all formulas containing an expression - z3

In Z3Py, is it possible to get a list of all formulas where some variable/expression occurs?
An example:
s.add(x < 0)
s.add(x > y)
print s[y]
>>> x > y

The Z3 API does not contain this functionality. However, it can be implemented on top of the Z3 API. Here is a simple implementation:
from z3 import *
# Wrapper for allowing Z3 ASTs to be stored into Python Hashtables.
class AstRefKey:
    """Hashable key object wrapping a single AST node.

    Hashing is delegated to the wrapped node's ``hash()`` method and
    equality to its ``eq()`` method, so two keys compare equal exactly
    when the wrapped nodes report ``eq``.

    Attributes:
        n: the wrapped node; it must provide ``hash()`` and ``eq(other)``.
    """

    def __init__(self, n):
        self.n = n

    def __hash__(self):
        return self.n.hash()

    def __eq__(self, other):
        # Comparing a key with an unrelated object used to fail with
        # AttributeError (no ``.n``); let Python fall back to "not equal".
        if not isinstance(other, AstRefKey):
            return NotImplemented
        return self.n.eq(other.n)

    def __repr__(self):
        return str(self.n)
def askey(n):
    """Return an AstRefKey wrapping ``n`` so it can go into sets and dicts.

    ``n`` must be an ``AstRef``; anything else trips the assertion.
    """
    assert isinstance(n, AstRef)
    key = AstRefKey(n)
    return key
# Return True iff f contains a.
def contains(f, a):
    """Report whether the term ``a`` occurs anywhere inside ``f``.

    Args:
        f: the formula to search (``f`` itself counts as an occurrence).
        a: the term to look for; matched with ``eq``.

    Returns:
        True if some sub-term of ``f`` is ``eq`` to ``a``, else False.
    """
    # Z3 represents formulas as DAGs, so remember what was already
    # examined to avoid re-walking shared sub-terms.
    visited = set()
    # Explicit stack instead of recursion: deeply nested formulas would
    # otherwise run into Python's recursion limit.
    stack = [f]
    while stack:
        node = stack.pop()
        key = askey(node)
        if key in visited:
            continue
        visited.add(key)
        if node.eq(a):
            return True
        # Push in reverse so children are examined left to right.
        stack.extend(reversed(node.children()))
    return False
# Return the list of assertions in s that contain x
def assertions_containing(s, x):
    """Collect, in assertion order, every assertion of ``s`` mentioning ``x``.

    Each assertion returned by ``s.assertions()`` is kept when
    ``contains(assertion, x)`` is true. The result is a list.
    """
    return [assertion for assertion in s.assertions() if contains(assertion, x)]
# Demo: two assertions, then list the ones mentioning y and the ones
# mentioning x.
x = Real('x')
y = Real('y')
s = Solver()
s.add(x < 0)
s.add(x > y)
# print() with parentheses works on Python 3 (and on Python 2 for a
# single argument); the bare print statement is a SyntaxError on Python 3.
print(assertions_containing(s, y))
print(assertions_containing(s, x))

Related

Why is Z3 giving me unsat for the following formula?

I have the following formula and Python code trying to find the largest n satisfying some property P:
# Question code, kept as asked: tries to express "n is the largest value
# satisfying P" with quantifiers.
x, u, n, n2 = Ints('x u n n2')
def P(u):
    # NOTE(review): x is not a parameter here; it is the Int constant
    # created on the first line of this snippet.
    return Implies(And(2 <= x, x <= u), And(x >= 1, x <= 10))
# Every n2 for which P(n2) holds must be <= n.
nIsLargest = ForAll(n2, Implies(P(n2), n2 <= n))
# x is universally quantified around both conjuncts.
exp = ForAll(x, And(P(n), nIsLargest))
s = SolverFor("LIA")
s.reset()
s.add(exp)
print(s.check())
# check() is invoked a second time here rather than reusing the result above.
if s.check() == sat:
    print(s.model())
My expectation was that it would return n=10, yet Z3 returns unsat. What am I missing?
You're using the optimization API incorrectly; and your question is a bit confusing since your predicate P has a free variable x: Obviously, the value that maximizes it will depend on both x and u.
Here's a simpler example that can get you started, showing how to use the API correctly:
from z3 import *
def P(x):
    """Build the constraint 1 <= x <= 10 as a conjunction of two bounds."""
    lower_bound = x >= 1
    upper_bound = x <= 10
    return And(lower_bound, upper_bound)
n = Int('n')
opt = Optimize()
# Constrain n with P, then register n as the objective to maximize.
opt.add(P(n))
maxN = opt.maximize(n)
# Solve once and keep the result so it can be both printed and tested.
r = opt.check()
print(r)
if r == sat:
    # maxN is the handle returned by maximize(); its value() is printed as
    # the maximum.
    print("maxN =", maxN.value())
This prints:
sat
maxN = 10
Hopefully you can take this example and extend it to your use case.

Use Z3 to find counterexamples for a 'guess solution' to a particular CHC system?

Suppose I have the following CHC system (I(x) is an unknown predicate):
(x = 0) -> I(x)
(x < 5) /\ I(x) /\ (x' = x + 1) -> I(x')
I(x) /\ (x >= 5) -> x = 5
Now suppose I guess a solution for I(x) as x < 2 (actually an incorrect guess). Can I write code for Z3Py to (i) check if it is a valid solution and (ii) if it's incorrect, find a counterexample, i.e. a value which satisfies x < 2 but does not satisfy at least one of the above 3 equations? (For example, x = 1 is a counterexample, as it does not satisfy the 2nd equation.)
Sure. The way to do this sort of reasoning is to assert the negation of the conjunction of all your constraints, and ask z3 if it can satisfy it. If the negation comes back satisfiable, then you have a counter-example. If it is unsat, then you know that your invariant is good.
Here's one way to code this idea, in a generic way, parameterized by the constraint generator and the guessed invariant:
from z3 import *
def Check(mkConstraints, I):
    """Test a guessed invariant by searching for a violating assignment.

    Args:
        mkConstraints: callable that receives ``I`` and returns the
            conjunction of constraints to validate.
        I: the guessed invariant; passed through unchanged to
            ``mkConstraints``.

    Prints the verdict (and the model when a counter-example exists);
    returns nothing.
    """
    solver = Solver()
    # Add the negation of the conjunction of constraints
    negated = Not(mkConstraints(I))
    solver.add(negated)
    verdict = solver.check()
    if verdict == sat:
        # The negation is satisfiable: the model is a counter-example.
        print("Not a valid invariant. Counter-example:")
        print(solver.model())
        return
    if verdict == unsat:
        print("Invariant is valid")
        return
    # Neither sat nor unsat: report whatever the solver answered.
    print("Solver said: %s" % verdict)
Given this, we can code up your particular case in a function:
def System(I):
    """Return the conjunction of the three clauses, instantiated with ``I``.

    ``I`` is called on the integer constants ``x`` and ``xp`` to produce
    the invariant's formula at the current and next state.
    """
    x, xp = Ints('x xp')
    # (x = 0) -> I(x)
    init_clause = Implies(x == 0, I(x))
    # (x < 5) /\ I(x) /\ (x' = x + 1) -> I(x')
    step_premise = And(x < 5, I(x), xp == x+1)
    step_clause = Implies(step_premise, I(xp))
    # I(x) /\ (x >= 5) -> x = 5
    exit_premise = And(I(x), x >= 5)
    exit_clause = Implies(exit_premise, x == 5)
    return And(init_clause, step_clause, exit_clause)
Now we can query it:
# Try the guess I(x) := x < 2 against the system.
Check(System, lambda x: x < 2)
The above prints:
Not a valid invariant. Counter-example:
[xp = 2, x = 1]
showing that x=1 violates the constraints. (You can code so that it tells you exactly which constraint is violated as well, but I digress.)
What happens if you provide a valid solution? Let's see:
# Try the guess I(x) := x <= 5 against the system.
Check(System, lambda x: x <= 5)
and this prints:
Invariant is valid
Note that we didn't need any quantifiers, as top-level variables act as existentials in z3 and all we needed to do was to find if there's an assignment that violated the constraints.

Understanding quantifier traversing in Z3

I'm trying to understand how to traverse a quantified formula in Z3 (I'm using Z3Py), and I have no idea how to pick up the quantified variables. For example, in the code shown below I'm trying to print the same formula, and I'm getting an error.
from z3 import *
# Question code, kept as asked (this is the version that raises an error).
def traverse(e):
    if is_quantifier(e):
        var_list = []
        if e.is_forall():
            for i in range(e.num_vars()):
                # var_name(i) is a plain string, so var_list ends up being a
                # list of strings rather than Z3 constants.
                var_list.append(e.var_name(i))
            return ForAll (var_list, traverse(e.body()))
    # No return on any other path: a non-quantifier (or a quantifier that
    # is not a forall) falls through and yields None.
# Driver for the question's traverse(): build a nested forall and try to
# rebuild it.
x, y = Bools('x y')
fml = ForAll(x, ForAll (y, And(x,y)))
same_formula = traverse( fml )
# print() call instead of the Python 2 print statement, which is a
# SyntaxError on Python 3.
print(same_formula)
With little search i got to know that z3 uses De Bruijn index and i have to get something like Var(1, BoolSort()). I can think of using var_sort() but how to get the formula to return the variable correctly. Stuck here for some time.
var_list is a list of strings, but ForAll expects a list of constants. Also, traverse should return e when it's not a quantifier. Here's a modified example:
from z3 import *
def traverse(e):
    """Rebuild universally quantified formulas with fresh bound constants.

    For a forall quantifier, one constant per bound variable is created
    (named after the variable plus the suffix "-traversed", with the
    variable's sort), the body's bound variables are replaced by those
    constants, the body is traversed recursively and the result is
    re-quantified with ``ForAll``.

    Anything else is returned unchanged: non-quantifier terms, and also
    quantifiers that are not foralls (these previously fell through and
    produced ``None``).
    """
    if not is_quantifier(e):
        return e
    if not e.is_forall():
        # Only universal quantifiers are rebuilt; hand other quantifier
        # kinds back untouched instead of silently returning None.
        return e
    var_list = [
        Const(e.var_name(i) + "-traversed", e.var_sort(i))
        for i in range(e.num_vars())
    ]
    # The body refers to bound variables by de Bruijn index, where index 0
    # is the LAST variable of the quantifier, hence the reversed list.
    body = substitute_vars(e.body(), *reversed(var_list))
    return ForAll(var_list, traverse(body))
x, y = Bools('x y')
# Two nested universal quantifiers over Boolean constants.
fml = ForAll(x, ForAll (y, And(x,y)))
# Rebuild the formula through traverse() and show the result.
same_formula = traverse( fml )
print(same_formula)

Divide a Point in Elliptic Curve Cryptography

I'm using Elliptic Curve to design a security system. P is a point on elliptic curve. The receiver must obtain P using formula k^-1(kP). The receiver does not know P but knows k. I need to compute k^-1(R) where R=kP. How can I do this using Point Multiplication or Point Addition.
I suggest first learning a bit more about ECC (for example, read some of Paar's book and listen to his course at http://www.crypto-textbook.com/) before tackling something this complex. For this particular question, ask yourself: "What does the inverse of k mean?"
Very interesting question you have! I was happy to implement from scratch Python solution for your task, see code at the bottom of my answer.
Each elliptic curve has an integer order q. If we have any point P on curve then it is well known that q * P = Zero, in other words multiplying any point by order q gives zero-point (infinity point).
Multiplying zero (infinity) point by any number gives zero again, i.e. j * Zero = Zero for any integer j. Adding any point P to zero-point gives P, i.e. Zero + P = P.
In our task we have some k such that R = k * P. We can very easily (very fast) compute Modular Inverse of k modulo order q, using for example Extended Euclidean Algorithm.
Inverse of k modulo q by definition is such that k * k^-1 = 1 (mod q), which by definition of modulus is equal k * k^-1 = j * q + 1 for some integer j.
Then k^-1 * R = k^-1 * k * P = (j * q + 1) * P = j * (q * P) + P = j * Zero + P = Zero + P = P. Thus multiplying R by k^-1 gives P, if k^-1 is inverse of k modulo q.
You can read about point addition and multiplication formulas on this Wiki.
Let's now check our formulas in Python. I decided to implement from scratch a simple class, ECPoint, which implements all curve operations (addition and multiplication); see the code below.
We take any ready-made curve, for example most popular 256-bit curve secp256k1, which is used in Bitcoin. Its parameters can be found here (this doc contains many other popular standard curves), also you can read about this specific curve on Bitcoin Wiki Page.
Following code is fully self-contained Python script, doesn't need any external dependencies and modules. You can run it straight away on any computer. ECPoint class implements all curve arithmetics. Function test() does following operations: we take standard secp256k1 params with some base point G, we compute any random point P = random * G, then we generate random k, compute R = k * P, compute modular inverse k^-1 (mod q) by using function modular_inverse() (which uses extended Euclidean algorithm egcd()), compute found_P = k^-1 * R and check that it is equal to P, i.e. check that k^-1 * R == P, print resulting k^-1 * R. All random values are 256-bit.
Try it online!
def egcd(a, b):
    """Extended Euclidean algorithm.

    Args:
        a, b: integers (either may be zero or negative).

    Returns:
        A tuple ``(g, s, t)`` such that ``a * s + b * t == g``, where
        ``g`` is the greatest common divisor of ``a`` and ``b`` and is
        never negative.
    """
    # https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        q, rem = divmod(old_r, r)
        old_r, r = r, rem
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    # Python's divmod gives remainders the sign of the divisor, so negative
    # inputs can leave a negative "gcd"; flip all three values, which keeps
    # the identity a*s + b*t == g intact.
    if old_r < 0:
        old_r, old_s, old_t = -old_r, -old_s, -old_t
    return old_r, old_s, old_t
def modular_inverse(a, mod):
    """Return the multiplicative inverse of ``a`` modulo ``mod``.

    Fails an assertion when the gcd reported by ``egcd(a, mod)`` is not 1,
    i.e. when no inverse exists.
    """
    # https://en.wikipedia.org/wiki/Modular_multiplicative_inverse
    gcd, coeff_a, _coeff_mod = egcd(a, mod)
    assert gcd == 1, 'Value not invertible by modulus!'
    # coeff_a may be negative; reduce it into the modulus' residue range.
    return coeff_a % mod
class ECPoint:
    """Affine point on the curve y^2 = x^3 + A*x + B over integers mod N.

    Attributes:
        x, y: coordinates of the point.
        A, B: curve coefficients.
        N: modulus of the coordinate field.
        q: order used for scalar inversion (callers invert scalars mod q).

    There is no representation of the point at infinity, so operations
    whose result would be that point (e.g. adding a point to its negation,
    or multiplying by zero) fail instead of returning a value.
    """

    @classmethod
    def Int(cls, x):
        """Coerce ``x`` to the integer type used for all arithmetic."""
        return int(x)

    @classmethod
    def std_point(cls, name):
        """Return the base point of the named standard curve.

        Only 'secp256k1' is known; any other name raises AssertionError.
        """
        if name == 'secp256k1':
            # https://en.bitcoin.it/wiki/Secp256k1
            # https://www.secg.org/sec2-v2.pdf
            p = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F
            a = 0
            b = 7
            x = 0x79BE667E_F9DCBBAC_55A06295_CE870B07_029BFCDB_2DCE28D9_59F2815B_16F81798
            y = 0x483ADA77_26A3C465_5DA4FBFC_0E1108A8_FD17B448_A6855419_9C47D08F_FB10D4B8
            q = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141
        else:
            # Explicit raise (same exception type as the former
            # ``assert False``) so the check also survives ``python -O``.
            raise AssertionError('Unknown standard curve: %r' % (name,))
        return cls(x, y, a, b, p, q)

    def __init__(self, x, y, A, B, N, q, *, prepare = True):
        """Create a point; with ``prepare`` the inputs are normalised and checked.

        Args:
            x, y: point coordinates.
            A, B: curve coefficients.
            N: field modulus.
            q: order used for scalar inversion.
            prepare: when True, coerce everything to integers, reduce
                x, y, A, B modulo N and run the sanity assertions below.
                Internal arithmetic passes False because its values are
                already reduced.
        """
        if prepare:
            N = self.Int(N)
            A, B, x, y = [self.Int(e) % N for e in [A, B, x, y]]
            # q is an order, not a field element: reducing it modulo N
            # would corrupt it whenever q > N, so it is only coerced.
            q = self.Int(q)
            # The curve discriminant must be non-zero.
            assert (4 * A ** 3 + 27 * B ** 2) % N != 0
            # The point must satisfy the curve equation.
            assert (y ** 2 - x ** 3 - A * x - B) % N == 0, (
                x, y, A, B, N, (y ** 2 - x ** 3 - A * x) % N)
            # y must be exactly the square root produced by the
            # (N + 1) // 4 exponentiation, which requires N % 4 == 3.
            assert N % 4 == 3
            assert y == pow(x ** 3 + A * x + B, (N + 1) // 4, N)
        self.A, self.B, self.N, self.x, self.y, self.q = A, B, N, x, y, q

    def __add__(self, other):
        """Return ``self + other`` (doubling when both points coincide).

        If the two points share x but differ in y, the slope needs the
        inverse of 0, which does not exist, so the call fails.
        """
        A, N = self.A, self.N
        Px, Py, Qx, Qy = self.x, self.y, other.x, other.y
        if Px == Qx and Py == Qy:
            # Tangent slope for doubling.
            s = ((Px * Px * 3 + A) * self.inv(Py * 2, N)) % N
        else:
            # Chord slope through two distinct points.
            s = ((Py - Qy) * self.inv(Px - Qx, N)) % N
        x = (s * s - Px - Qx) % N
        y = (s * (Px - x) - Py) % N
        return ECPoint(x, y, A, self.B, N, self.q, prepare = False)

    def __rmul__(self, other):
        """Return ``other * self`` by double-and-add.

        Args:
            other: a positive integer scalar.

        Raises:
            ValueError: if the scalar is smaller than 1 (the result would
                be the unrepresentable point at infinity; such scalars
                previously made this method loop forever, as did 1).
        """
        k = self.Int(other)
        if k < 1:
            raise ValueError(
                'Scalar must be a positive integer, got %r' % (other,))
        result = None  # running sum; None until the first set bit is seen
        addend = self  # self * 2**i for the bit currently examined
        while True:
            if k & 1:
                result = addend if result is None else result + addend
            k >>= 1
            if k == 0:
                return result
            addend = addend + addend

    @classmethod
    def inv(cls, a, n):
        """Return the modular inverse of ``a`` modulo ``n``."""
        return modular_inverse(a, n)

    def __repr__(self):
        return str(dict(x = self.x, y = self.y, A = self.A,
            B = self.B, N = self.N, q = self.q))

    def _key(self):
        # All fields that define the point and the curve it lives on.
        return (self.x, self.y, self.A, self.B, self.N, self.q)

    def __eq__(self, other):
        """Points are equal when coordinates and all curve fields match."""
        if not isinstance(other, ECPoint):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defined together with __eq__ so points stay usable in sets/dicts.
        return hash(self._key())
def test():
    """Round-trip check: k^-1 * (k * P) must give back P.

    A random point P and a random scalar k are generated on secp256k1,
    R = k * P is computed, and multiplying R by the inverse of k modulo
    the order q is asserted to reproduce P, which is then printed.
    """
    import random
    G = ECPoint.std_point('secp256k1')
    # Draw scalars from [1, q): 0 has no usable multiple of a point, and
    # values outside this range are not guaranteed to be invertible mod q.
    P = random.randrange(1, G.q) * G
    k = random.randrange(1, G.q)
    R = k * P
    found_P = modular_inverse(k, R.q) * R
    assert found_P == P
    print(found_P)
# Run the round-trip check only when this file is executed as a script.
if __name__ == '__main__':
    test()
Output:
{
'x': 108051465657467150531748691374311160382608428790397210924352716318223953013557,
'y': 4462548165448905789984443302412298811224817997977472205419179335194291964455,
'A': 0,
'B': 7,
'N': 115792089237316195423570985008687907853269984665640564039457584007908834671663,
'q': 115792089237316195423570985008687907852837564279074904382605163141518161494337
}

Substituting function symbols in z3 formulas

What is the best way to substitute a function symbol (with another function) in a formula?
Z3py's substitute seems to only work with expressions, and what I do now is I try to guess all possible combinations of consts/vars to which the function could be applied and then substitute those with an application of another function. Is there a better way to do that?
We can implement a simple bottom-up rewriter that given a term s, a function f and term t will replace every f-application f(r_1, ..., r_n) in s with t[r_1, ..., r_n]. I'm using the notation t[r_1, ..., r_n] to denote the term obtained by replacing the free-variables in t with the terms r_1, ..., r_n.
The rewriter can be implemented using the Z3 API. I use an AstMap to cache results, and a todo list to store expressions that still have to be processed.
Here is a simple example that replaces f-applications of the form f(t) with g(t+1) in s.
# Var(0) stands for the first argument of the f-application being replaced.
x = Var(0, IntSort())
# print() call instead of the Python 2 print statement.
print(rewrite(s, f, g(x + 1)))
Here is the code and more examples. Beware, I only tested the code in a small set of examples.
from z3 import *
def update_term(t, args):
    """Return a copy of term ``t`` whose children are replaced by ``args``."""
    # len(args) must be equal to the number of children in t.
    # If t is an application, then len(args) == t.num_args()
    # If t is a quantifier, then len(args) == 1
    count = len(args)
    # Pack the replacement children into a C array of raw AST handles.
    raw_children = (Ast * count)()
    for position, child in enumerate(args):
        raw_children[position] = child.as_ast()
    updated = Z3_update_term(t.ctx_ref(), t.as_ast(), count, raw_children)
    return z3._to_expr_ref(updated, t.ctx)
def rewrite(s, f, t):
    """
    Rewrite s bottom-up so that every f-application f(r_1, ..., r_n)
    becomes t[r_1, ..., r_n], i.e. t with its free variables replaced
    by the (already rewritten) arguments r_1, ..., r_n.
    """
    cache = AstMap(ctx=s.ctx)  # maps an original term to its rewritten form
    todo = [s]                 # work stack; the top is the term in progress
    while todo:
        n = todo[-1]
        if is_var(n):
            # Variables rewrite to themselves.
            todo.pop()
            cache[n] = n
            continue
        if is_app(n):
            children = [n.arg(i) for i in range(n.num_args())]
            pending = [child for child in children if child not in cache]
            if pending:
                # Some arguments are not rewritten yet: process them first
                # and come back to n afterwards.
                todo.extend(pending)
                continue
            new_args = [cache[child] for child in children]
            todo.pop()
            if eq(n.decl(), f):
                cache[n] = substitute_vars(t, *new_args)
            else:
                cache[n] = update_term(n, new_args)
            continue
        assert(is_quantifier(n))
        b = n.body()
        if b in cache:
            todo.pop()
            cache[n] = update_term(n, [cache[b]])
        else:
            todo.append(b)
    return cache[s]
# The bare Python 2 print statements were replaced by print() calls,
# which also run on Python 3.
f = Function('f', IntSort(), IntSort())
a, b = Ints('a b')
s = Or(f(a) == 0, f(a) == 1, f(a+a) == 2)
# Example 1: replace all f-applications with b
print(rewrite(s, f, b))
# Example 2: replace all f-applications f(t) with g(t+1)
g = Function('g', IntSort(), IntSort())
x = Var(0, IntSort())
print(rewrite(s, f, g(x + 1)))
# Now, f and g are binary functions.
f = Function('f', IntSort(), IntSort(), IntSort())
g = Function('g', IntSort(), IntSort(), IntSort())
# Example 3: replace all f-applications f(t1, t2) with g(t2, t1)
s = Or(f(a, f(a, b)) == 0, f(b, a) == 1, f(f(1,0), 0) == 2)
# The first argument is variable 0, and the second is variable 1.
y = Var(1, IntSort())
print(rewrite(s, f, g(y, x)))
# Example 4: quantifiers
s = ForAll([a], f(a, b) >= 0)
print(rewrite(s, f, g(y, x + 1)))

Resources