What is the best way to substitute a function symbol (with another function) in a formula?
Z3py's substitute seems to work only with expressions, so what I do now is guess all possible combinations of consts/vars to which the function could be applied and then substitute those with an application of another function. Is there a better way to do that?
We can implement a simple bottom-up rewriter that, given a term s, a function f, and a term t, replaces every f-application f(r_1, ..., r_n) in s with t[r_1, ..., r_n]. I'm using the notation t[r_1, ..., r_n] to denote the term obtained by replacing the free variables in t with the terms r_1, ..., r_n.
The rewriter can be implemented using the Z3 API. I use an AstMap to cache results, and a todo list to store expressions that still have to be processed.
Here is a simple example that replaces f-applications of the form f(t) with g(t+1) in s.
x = Var(0, IntSort())
print(rewrite(s, f, g(x + 1)))
Here is the code and more examples. Beware, I only tested the code on a small set of examples.
from z3 import *
def update_term(t, args):
    # Update the children of term t with args.
    # len(args) must be equal to the number of children in t.
    # If t is an application, then len(args) == t.num_args()
    # If t is a quantifier, then len(args) == 1
    n = len(args)
    _args = (Ast * n)()
    for i in range(n):
        _args[i] = args[i].as_ast()
    return z3._to_expr_ref(Z3_update_term(t.ctx_ref(), t.as_ast(), n, _args), t.ctx)
def rewrite(s, f, t):
    """
    Replace f-applications f(r_1, ..., r_n) with t[r_1, ..., r_n] in s.
    """
    todo = []  # to-do list
    todo.append(s)
    cache = AstMap(ctx=s.ctx)
    while todo:
        n = todo[-1]
        if is_var(n):
            todo.pop()
            cache[n] = n
        elif is_app(n):
            visited = True
            new_args = []
            for i in range(n.num_args()):
                arg = n.arg(i)
                if arg not in cache:
                    todo.append(arg)
                    visited = False
                else:
                    new_args.append(cache[arg])
            if visited:
                todo.pop()
                g = n.decl()
                if eq(g, f):
                    new_n = substitute_vars(t, *new_args)
                else:
                    new_n = update_term(n, new_args)
                cache[n] = new_n
        else:
            assert is_quantifier(n)
            b = n.body()
            if b in cache:
                todo.pop()
                new_n = update_term(n, [cache[b]])
                cache[n] = new_n
            else:
                todo.append(b)
    return cache[s]
f = Function('f', IntSort(), IntSort())
a, b = Ints('a b')
s = Or(f(a) == 0, f(a) == 1, f(a+a) == 2)
# Example 1: replace all f-applications with b
print(rewrite(s, f, b))
# Example 2: replace all f-applications f(t) with g(t+1)
g = Function('g', IntSort(), IntSort())
x = Var(0, IntSort())
print(rewrite(s, f, g(x + 1)))
# Now, f and g are binary functions.
f = Function('f', IntSort(), IntSort(), IntSort())
g = Function('g', IntSort(), IntSort(), IntSort())
# Example 3: replace all f-applications f(t1, t2) with g(t2, t1)
s = Or(f(a, f(a, b)) == 0, f(b, a) == 1, f(f(1,0), 0) == 2)
# The first argument is variable 0, and the second is variable 1.
y = Var(1, IntSort())
print(rewrite(s, f, g(y, x)))
# Example 4: quantifiers
s = ForAll([a], f(a, b) >= 0)
print(rewrite(s, f, g(y, x + 1)))
I have the following formula and Python code trying to find the largest n satisfying some property P:
x, u, n, n2 = Ints('x u n n2')
def P(u):
    return Implies(And(2 <= x, x <= u), And(x >= 1, x <= 10))
nIsLargest = ForAll(n2, Implies(P(n2), n2 <= n))
exp = ForAll(x, And(P(n), nIsLargest))
s = SolverFor("LIA")
s.reset()
s.add(exp)
print(s.check())
if s.check() == sat:
    print(s.model())
My expectation was that it would return n=10, yet Z3 returns unsat. What am I missing?
You're using the optimization API incorrectly, and your question is a bit confusing since your predicate P has a free variable x: obviously, the value that maximizes it will depend on both x and u.
Here's a simpler example that can get you started, showing how to use the API correctly:
from z3 import *
def P(x):
    return And(x >= 1, x <= 10)
n = Int('n')
opt = Optimize()
opt.add(P(n))
maxN = opt.maximize(n)
r = opt.check()
print(r)
if r == sat:
    print("maxN =", maxN.value())
This prints:
sat
maxN = 10
Hopefully you can take this example and extend it to your use case.
I want to check the values of a, b, and c: if 'a' equals 1, then 'x' is incremented by one, and the same check is repeated for 'b' and 'c'.
So if a=1, b=1, c=1, the result of x should be 3.
If a=1, b=1, c=0, the result of x should be 2.
Is there a way to implement this in z3?
The source code looks like this:
from z3 import *
a, b, c = Ints('a b c')
x, y = Ints('x y')
s = Solver()
s.add(If(a==1, x=x + 1, y = y-1))
s.add(If(b==1, x=x + 1, y = y-1))
s.add(If(c==1, x=x + 1, y = y-1))
s.check()
print s.model()
Any suggestions about what I can do?
This sort of "iterative" processing is usually modeled by unrolling the assignments and creating what's known as SSA (static single assignment) form. In this format, every variable is assigned precisely once, but can be used many times. This is usually done by some underlying tool as it is rather tedious, but you can do it by hand as well. Applied to your problem, it'd look something like:
from z3 import *
s = Solver()
a, b, c = Ints('a b c')
x0, x1, x2, x3 = Ints('x0 x1 x2 x3')
s.add(x0 == 0)
s.add(x1 == If(a == 1, x0+1, x0))
s.add(x2 == If(b == 1, x1+1, x1))
s.add(x3 == If(c == 1, x2+1, x2))
# Following asserts are not part of your problem, but
# they make the output interesting
s.add(b == 1)
s.add(c == 0)
# Find the model
if s.check() == sat:
    m = s.model()
    print("a=%d, b=%d, c=%d, x=%d" % (m[a].as_long(), m[b].as_long(), m[c].as_long(), m[x3].as_long()))
else:
    print("no solution")
SSA transformation is applied to the variable x, creating as many instances as necessary to model the assignments. When run, this program produces:
a=0, b=1, c=0, x=1
Hope that helps!
Note that z3 has many functions. One you could use here is Sum() for the sum of a list. Inside the list you can put simple variables, but also expressions. Here is an example of both a simple and a more complex sum:
from z3 import *
a, b, c = Ints('a b c')
x, y = Ints('x y')
s = Solver()
s.add(a==1, b==0, c==1)
s.add(x==Sum([a,b,c]))
s.add(y==Sum([If(a==1,-1,0),If(b==1,-1,0),If(c==1,-1,0)]))
if s.check() == sat:
    print("solution:", s.model())
else:
    print("no solution possible")
Result:
solution: [y = -2, x = 2, c = 1, b = 0, a = 1]
If your problem is more complex, using BitVecs instead of Ints can make it run a little faster.
edit: Instead of Sum() you could also simply use addition as in
s.add(x==a+b+c)
s.add(y==If(a==1,-1,0)+If(b==1,-1,0)+If(c==1,-1,0))
Sum() helps readability when you have a longer list of variables, or when the variables are already in a list.
I'm using elliptic curves to design a security system. P is a point on an elliptic curve. The receiver must obtain P using the formula k^-1(kP). The receiver does not know P but knows k. I need to compute k^-1(R) where R = kP. How can I do this using point multiplication or point addition?
I suggest first learning a bit more about ECC (for example, read some of Paar's book and listen to his course at http://www.crypto-textbook.com/) before tackling something this complex. For this particular question, ask yourself: "What does the inverse of k mean?"
Very interesting question! I implemented a Python solution for your task from scratch; see the code at the bottom of my answer.
Each elliptic curve has an integer order q. For any point P on the curve it is well known that q * P = Zero; in other words, multiplying any point by the order q gives the zero point (the point at infinity).
Multiplying the zero (infinity) point by any number gives zero again, i.e. j * Zero = Zero for any integer j. Adding any point P to the zero point gives P, i.e. Zero + P = P.
In our task we have some k such that R = k * P. We can compute the modular inverse of k modulo the order q very quickly, for example with the extended Euclidean algorithm.
The inverse of k modulo q is by definition the value k^-1 such that k * k^-1 = 1 (mod q), which by the definition of the modulus means k * k^-1 = j * q + 1 for some integer j.
Then k^-1 * R = k^-1 * k * P = (j * q + 1) * P = j * (q * P) + P = j * Zero + P = Zero + P = P. Thus multiplying R by k^-1 gives P, if k^-1 is inverse of k modulo q.
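The derivation above can be checked by hand on a curve small enough to brute-force. The sketch below is not part of the original answer: the toy curve y^2 = x^3 + 7 over F_17, the point (1, 5), and the scalar k = 5 are my own choices for illustration. It uses the order of the point P itself, which is enough for the argument, and it relies on pow(k, -1, q) from Python 3.8+.

```python
from math import gcd

# Toy curve y^2 = x^3 + 7 over F_17 (illustration only, far too small for real use).
P_MOD, A, B = 17, 0, 7
INF = None  # the zero (infinity) point

def add(P, Q):
    """Add two points, handling the zero point and inverse points."""
    if P is INF:
        return Q
    if Q is INF:
        return P
    (x1, y1), (x2, y2) = P, Q
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return INF  # Q is -P, so P + Q = Zero
    if P == Q:
        s = (3 * x1 * x1 + A) * pow(2 * y1 % P_MOD, -1, P_MOD) % P_MOD
    else:
        s = (y2 - y1) * pow((x2 - x1) % P_MOD, -1, P_MOD) % P_MOD
    x3 = (s * s - x1 - x2) % P_MOD
    return (x3, (s * (x1 - x3) - y1) % P_MOD)

def mul(k, P):
    """Double-and-add scalar multiplication k * P for k >= 0."""
    R = INF
    while k:
        if k & 1:
            R = add(R, P)
        P = add(P, P)
        k >>= 1
    return R

def point_order(P):
    """Smallest n >= 1 with n * P = Zero, found by brute force."""
    n, Q = 1, P
    while Q is not INF:
        Q = add(Q, P)
        n += 1
    return n

P = (1, 5)                  # on the curve: 5^2 = 25 = 8 = 1^3 + 7 (mod 17)
q = point_order(P)
k = 5
assert gcd(k, q) == 1       # k must be invertible modulo q
R = mul(k, P)
k_inv = pow(k, -1, q)       # modular inverse of k modulo q
recovered = mul(k_inv, R)
print(q, R, recovered, recovered == P)
```

Unlike a production implementation, this version handles the zero point explicitly, which keeps the brute-force order computation simple.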
You can read about point addition and multiplication formulas on this Wiki.
Let's now check these formulas in Python. I decided to implement from scratch a simple class ECPoint that implements the curve operations (addition and multiplication); see the code below.
We take a ready-made curve, for example the popular 256-bit curve secp256k1, which is used in Bitcoin. Its parameters can be found here (this doc contains many other popular standard curves); you can also read about this specific curve on the Bitcoin Wiki page.
The following code is a fully self-contained Python script that needs no external dependencies or modules, so you can run it straight away on any computer. The ECPoint class implements all curve arithmetic. The function test() does the following: it takes the standard secp256k1 parameters with base point G, computes a random point P = random * G, generates a random k, computes R = k * P, computes the modular inverse k^-1 (mod q) using modular_inverse() (which uses the extended Euclidean algorithm egcd()), computes found_P = k^-1 * R, checks that it equals P, and prints the result. All random values are 256-bit.
def egcd(a, b):
    # https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
    ro, r, so, s, to, t = a, b, 1, 0, 0, 1
    while r != 0:
        ro, (q, r) = r, divmod(ro, r)
        so, s = s, so - q * s
        to, t = t, to - q * t
    return ro, so, to

def modular_inverse(a, mod):
    # https://en.wikipedia.org/wiki/Modular_multiplicative_inverse
    g, s, t = egcd(a, mod)
    assert g == 1, 'Value not invertible by modulus!'
    return s % mod
class ECPoint:
    @classmethod
    def Int(cls, x):
        return int(x)

    @classmethod
    def std_point(cls, name):
        if name == 'secp256k1':
            # https://en.bitcoin.it/wiki/Secp256k1
            # https://www.secg.org/sec2-v2.pdf
            p = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F
            a = 0
            b = 7
            x = 0x79BE667E_F9DCBBAC_55A06295_CE870B07_029BFCDB_2DCE28D9_59F2815B_16F81798
            y = 0x483ADA77_26A3C465_5DA4FBFC_0E1108A8_FD17B448_A6855419_9C47D08F_FB10D4B8
            q = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141
        else:
            assert False
        return ECPoint(x, y, a, b, p, q)

    def __init__(self, x, y, A, B, N, q, *, prepare = True):
        if prepare:
            N = self.Int(N)
            A, B, x, y, q = [self.Int(e) % N for e in [A, B, x, y, q]]
            assert (4 * A ** 3 + 27 * B ** 2) % N != 0
            assert (y ** 2 - x ** 3 - A * x - B) % N == 0, (
                x, y, A, B, N, (y ** 2 - x ** 3 - A * x) % N)
            assert N % 4 == 3
            assert y == pow(x ** 3 + A * x + B, (N + 1) // 4, N)
        self.A, self.B, self.N, self.x, self.y, self.q = A, B, N, x, y, q

    def __add__(self, other):
        A, N = self.A, self.N
        Px, Py, Qx, Qy = self.x, self.y, other.x, other.y
        if Px == Qx and Py == Qy:
            # Point doubling
            s = ((Px * Px * 3 + A) * self.inv(Py * 2, N)) % N
        else:
            # Addition of two distinct points
            s = ((Py - Qy) * self.inv(Px - Qx, N)) % N
        x = (s * s - Px - Qx) % N
        y = (s * (Px - x) - Py) % N
        return ECPoint(x, y, A, self.B, N, self.q, prepare = False)

    def __rmul__(self, other):
        # Double-and-add; assumes the scalar is an integer >= 2.
        other = self.Int(other - 1)
        r = self
        while True:
            if other & 1:
                r = r + self
            if other == 1:
                return r
            other >>= 1
            self = self + self

    @classmethod
    def inv(cls, a, n):
        return modular_inverse(a, n)

    def __repr__(self):
        return str(dict(x = self.x, y = self.y, A = self.A,
            B = self.B, N = self.N, q = self.q))

    def __eq__(self, other):
        for i, (a, b) in enumerate([
            (self.x, other.x), (self.y, other.y), (self.A, other.A),
            (self.B, other.B), (self.N, other.N), (self.q, other.q)]):
            if a != b:
                return False
        return True
def test():
    import random
    bits = 256
    P = random.randrange(1 << bits) * ECPoint.std_point('secp256k1')
    k = random.randrange(1 << bits)
    R = k * P
    found_P = modular_inverse(k, R.q) * R
    assert found_P == P
    print(found_P)

if __name__ == '__main__':
    test()
Output:
{
'x': 108051465657467150531748691374311160382608428790397210924352716318223953013557,
'y': 4462548165448905789984443302412298811224817997977472205419179335194291964455,
'A': 0,
'B': 7,
'N': 115792089237316195423570985008687907853269984665640564039457584007908834671663,
'q': 115792089237316195423570985008687907852837564279074904382605163141518161494337
}
I posted a related question, but then I think it was not very clear. I would like to rephrase the problem like this:
Two formulas a1 == a + b (1) and a1 == b (2) are equivalent if a == 0. Given these formulas (1) and (2), how can I use Z3 python to find out this required condition (a == 0) so the above formulas become equivalent?
I suppose that a1, a and b are all in the format of BitVecs(32).
Edit: I came up with the code like this:
from z3 import *
a, b = BitVecs('a b', 32)
a1 = BitVec('a1', 32)
s = Solver()
s.add(ForAll(b, a + b == b))
if s.check() == sat:
    print('a =', s.model()[a])
else:
    print('Not Equ')
The output is: a = 0, as expected.
However, when I modified the code a bit to use two formulas, it doesn't work anymore:
from z3 import *
a, b = BitVecs('a b', 32)
a1 = BitVec('a1', 32)
f = True
f = And(f, a1 == a * b)
g = True
g = And(g, a1 == b)
s = Solver()
s.add(ForAll(b, f == g))
if s.check() == sat:
    print('a =', s.model()[a])
else:
    print('Not Equ')
The output now is different: a = 1314914305
So the questions are:
(1) Why does the second code produce a different (wrong) result?
(2) Is there any way to do this without using ForAll (or quantifier) at all?
Thanks
The two codes produce the same correct answer a = 0. You have a typo: you wrote a1 == a * b, and it must be a1 == a + b. Do you agree?
Possible code without using ForAll:
from z3 import *
a, b = BitVecs('a b', 32)
a1 = BitVec('a1', 32)
s = Solver()
s.add(a + b == b)
if s.check() == sat:
    print('a =', s.model()[a])
else:
    print('Not Equ')
s1 = Solver()
s1.add(a==0, Not(a + b == b))
print(s1.check())
Output:
a = 0
unsat
In the following working example, how can I retrieve the matched model?
S, (cl_3, cl_39, cl_11, me_32, me_59, me_81) = EnumSort(
    'S', ['cl_3','cl_39','cl_11','me_32','me_59','me_81'])
h1, h2 = Consts('h1 h2', S)
def fun(h1 , h2):
    conds = [
        (cl_3, me_32),
        (cl_39, me_59),
        (cl_11, me_81),
        # ...
    ]
    and_conds = (And(h1==a, h2==b) for a,b in conds)
    return Or(*and_conds)
For example, with the following solver:
s = Solver()
x1 = Const('x1', S)
x2 = Const('x2', S)
s.add(fun(x1,x2))
print(s.check())
print(s.model())
I'm assuming that you want the value of x1 and x2 in the model produced by Z3. If that is the case, you can retrieve them using:
m = s.model()
print(m[x1])
print(m[x2])
Here is the complete example (also available online here). BTW, note that we don't need h1, h2 = Consts('h1 h2', S).
from z3 import *
S, (cl_3, cl_39, cl_11, me_32, me_59, me_81) = EnumSort(
    'S', ['cl_3','cl_39','cl_11','me_32','me_59','me_81'])
def fun(h1 , h2):
    conds = [
        (cl_3, me_32),
        (cl_39, me_59),
        (cl_11, me_81),
    ]
    and_conds = (And(h1==a, h2==b) for a,b in conds)
    return Or(*and_conds)
s = Solver()
x1 = Const('x1', S)
x2 = Const('x2', S)
s.add(fun(x1,x2))
print(s.check())
m = s.model()
print(m)
print(m[x1])
print(m[x2])