Skip to content

Instantly share code, notes, and snippets.

@Parr0sky
Forked from bellbind/ecc.py
Last active October 16, 2022 07:08
Show Gist options
  • Select an option

  • Save Parr0sky/4ed4df8266ce1ce970e5ecab6bb6df06 to your computer and use it in GitHub Desktop.

Select an option

Save Parr0sky/4ed4df8266ce1ce970e5ecab6bb6df06 to your computer and use it in GitHub Desktop.
[python]basics of elliptic curve cryptography
# Basics of Elliptic Curve Cryptography, implemented in Python
import collections
import decimal
# Sets the precision of the current decimal context to 5 significant digits
# as an import-time side effect.
# NOTE(review): no Decimal arithmetic is visible in this file -- confirm this
# setting is still needed before relying on it (or removing it).
decimal.getcontext().prec = 5
def inv(n, q):
    """Modular inverse: div on PN modulo a/b mod q as a * inv(b, q) mod q

    Uses the extended Euclidean algorithm, so it runs in O(log q) steps and
    stays usable for cryptographic-size moduli.

    - n: value to invert (any int; it is reduced mod q first)
    - q: modulus, q >= 2
    - returns: i in range(q) such that n * i % q == 1
    - raises ValueError: if q < 2 or n has no inverse modulo q
    >>> assert n * inv(n, q) % q == 1
    """
    if q < 2:
        raise ValueError("modulus must be >= 2, got %r" % (q,))
    # Invariant: rem == coef * n (mod q) and nxt_rem == nxt_coef * n (mod q).
    rem, nxt_rem = n % q, q
    coef, nxt_coef = 1, 0
    while nxt_rem:
        k = rem // nxt_rem
        rem, nxt_rem = nxt_rem, rem - k * nxt_rem
        coef, nxt_coef = nxt_coef, coef - k * nxt_coef
    # rem is now gcd(n, q); an inverse exists only when it is 1.
    if rem != 1:
        raise ValueError("%r has no inverse modulo %r" % (n, q))
    return coef % q
def sqrt(n, q):
    """sqrt on PN modulo: returns two numbers or exception if not exist

    Brute-force search, so only suitable for small q.

    - n: value whose square root is wanted, n < q
    - q: modulus
    - returns: (i, q - i) where i is the smallest positive root, or (0, 0)
      when n == 0
    - raises Exception("not found"): if no i satisfies i * i % q == n
    >>> assert (sqrt(n, q)[0] ** 2) % q == n
    >>> assert (sqrt(n, q)[1] ** 2) % q == n
    """
    assert n < q
    if n == 0:
        # 0 is its own (only) square root.
        return (0, 0)
    # Roots come in pairs (i, q - i), so the smaller one is at most q // 2;
    # only i itself is tested so the returned pair is always a real root.
    for i in range(1, q // 2 + 1):
        if (i * i) % q == n:
            return (i, q - i)
    raise Exception("not found")
# Immutable (x, y) pair used for points on the curve; compares equal by value.
Coord = collections.namedtuple("Coord", "x y")
# Adapted from https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python
def sqrt_mod(a, p):
    """Find the square roots of 'a' modulo p. p must be an odd prime.

    Solve the congruence of the form:
        x^2 = a (mod p)
    and return both roots as the tuple (x, p - x).
    (0, 0) is returned if no non-zero square root exists for these a and p
    (this includes a == 0 (mod p) and p == 2).

    The Tonelli-Shanks algorithm is used (except for some simple cases in
    which the solution is known from an identity). This algorithm runs in
    polynomial time (unless the generalized Riemann hypothesis is false).

    All exponents are computed with integer (floor) division so that the
    three-argument pow() receives ints.
    """
    def _legendre(v):
        # Euler's criterion: v^((p-1)/2) mod p is 1 for residues, p - 1
        # for non-residues and 0 when p divides v.
        ls = pow(v, (p - 1) // 2, p)
        return -1 if ls == p - 1 else ls

    # Simple cases
    #
    if p == 2 or a % p == 0 or _legendre(a) != 1:
        return 0, 0
    if p % 4 == 3:
        # For p = 3 (mod 4) a root is given directly by a^((p+1)/4).
        x = pow(a, (p + 1) // 4, p)
        return x, p - x

    # Partition p-1 to s * 2^e for an odd s (i.e.
    # reduce all the powers of 2 from p-1)
    #
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1

    # Find some 'n' with a legendre symbol n|p = -1.
    # Shouldn't take long.
    #
    n = 2
    while _legendre(n) != -1:
        n += 1

    # Here be dragons!
    # Read the paper "Square roots from 1; 24, 51,
    # 10 to Dan Shanks" by Ezra Brown for more
    # information
    #
    # x is a guess of the square root that gets better
    # with each iteration.
    # b is the "fudge factor" - by how much we're off
    # with the guess. The invariant x^2 = ab (mod p)
    # is maintained throughout the loop.
    # g is used for successive powers of n to update
    # both a and b
    # r is the exponent - decreases with each update
    #
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e

    while True:
        # Find the smallest m with b^(2^m) == 1 (mod p).
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)

        if m == 0:
            # b == 1, so x^2 == a (mod p): done.
            return x, p - x

        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m
def legendre_symbol(a, p):
    """Compute the Legendre symbol a|p using Euler's criterion.

    p is a prime, a is relatively prime to p (if p divides a, then
    a|p = 0).

    Returns 1 if a has a square root modulo p, -1 otherwise.
    """
    # Floor division keeps the exponent an int, which three-argument pow()
    # requires.
    ls = pow(a, (p - 1) // 2, p)
    # Euler's criterion yields p - 1 (i.e. -1 mod p) for non-residues.
    return -1 if ls == p - 1 else ls
class EC(object):
    """System of Elliptic Curve"""

    def __init__(self, a, b, q):
        """elliptic curve as: (y**2 = x**3 + a * x + b) mod q
        - a, b: params of curve formula
        - q: prime number
        """
        # b != 0 guarantees that (0, 0) is not on the curve, which is what
        # lets it stand in for the point at infinity below.
        assert 0 <= a and a < q and 0 < b and b < q and q > 2
        # Non-zero discriminant: the curve must not be singular.
        assert (4 * (a ** 3) + 27 * (b ** 2)) % q != 0
        self.a = a
        self.b = b
        self.q = q
        # just as unique ZERO value representation for "add": (not on curve)
        self.zero = Coord(0, 0)

    def is_valid(self, p):
        """True if p is the zero element or satisfies the curve equation."""
        if p == self.zero:
            return True
        lhs = (p.y ** 2) % self.q
        rhs = ((p.x ** 3) + self.a * p.x + self.b) % self.q
        return lhs == rhs

    def at(self, x):
        """find points on curve at x
        - x: int < q
        - returns: ((x, y), (x,-y)) or not found exception
        >>> a, ma = ec.at(x)
        >>> assert a.x == ma.x and a.x == x
        >>> assert ec.neg(a) == ma
        >>> assert ec.is_valid(a) and ec.is_valid(ma)
        """
        assert x < self.q
        ysq = (x ** 3 + self.a * x + self.b) % self.q
        y, my = sqrt(ysq, self.q)
        return Coord(x, y), Coord(x, my)

    def neg(self, p):
        """negate p
        >>> assert ec.is_valid(ec.neg(p))
        """
        return Coord(p.x, -p.y % self.q)

    def add(self, p1, p2):
        """<add> of elliptic curve: negate of 3rd cross point of (p1,p2) line
        >>> d = ec.add(a, b)
        >>> assert ec.is_valid(d)
        >>> assert ec.add(d, ec.neg(b)) == a
        >>> assert ec.add(a, ec.neg(a)) == ec.zero
        >>> assert ec.add(a, b) == ec.add(b, a)
        >>> assert ec.add(a, ec.add(b, c)) == ec.add(ec.add(a, b), c)
        """
        if p1 == self.zero:
            return p2
        if p2 == self.zero:
            return p1
        if p1.x == p2.x and (p1.y != p2.y or p1.y == 0):
            # p1 + -p1 == 0 (also covers doubling a point with y == 0,
            # whose tangent is vertical)
            return self.zero
        if p1.x == p2.x:
            # p1 + p1: use tangent line of p1 as (p1,p1) line
            l = (3 * p1.x * p1.x + self.a) * inv(2 * p1.y, self.q) % self.q
        else:
            # slope of the chord through p1 and p2
            l = (p2.y - p1.y) * inv(p2.x - p1.x, self.q) % self.q
        x = (l * l - p1.x - p2.x) % self.q
        y = (l * (p1.x - x) - p1.y) % self.q
        return Coord(x, y)

    def mul(self, p, n):
        """n times <mul> of elliptic curve
        - p: point on the curve
        - n: int multiplier; a negative n yields the negation of (-n) * p
        >>> m = ec.mul(p, n)
        >>> assert ec.is_valid(m)
        >>> assert ec.mul(p, 0) == ec.zero
        """
        if n < 0:
            return self.neg(self.mul(p, -n))
        r = self.zero
        m2 = p
        # O(log2(n)) add: double-and-add over the bits of n
        while 0 < n:
            if n & 1 == 1:
                r = self.add(r, m2)
            n >>= 1
            if n:
                # only double while there are bits left to consume
                m2 = self.add(m2, m2)
        return r

    def order(self, g):
        """order of point g
        - returns: the smallest i > 0 with i * g == zero
        - raises Exception("Invalid order"): if no such i is found within
          the search bound

        Brute-force search, so only suitable for small q.
        >>> o = ec.order(g)
        >>> assert ec.is_valid(a) and ec.mul(a, o) == ec.zero
        """
        assert self.is_valid(g) and g != self.zero
        # By Hasse's theorem the group has at most q + 1 + 2*sqrt(q) points,
        # which never exceeds 2*q + 2, so a point's order can be larger
        # than q and the search must go past it.
        acc = g  # acc == i * g at the top of each iteration
        for i in range(1, 2 * self.q + 3):
            if acc == self.zero:
                return i
            acc = self.add(acc, g)
        raise Exception("Invalid order")
class ElGamal(object):
    """ElGamal Encryption
    pub key encryption as replacing (mulmod, powmod) to (ec.add, ec.mul)
    - ec: elliptic curve
    - g: (random) a point on ec
    """

    def __init__(self, ec, g):
        assert ec.is_valid(g)
        self.ec = ec
        self.g = g
        self.n = ec.order(g)

    def gen(self, priv):
        """generate pub key
        - priv: priv key as (random) int < ec.q
        - returns: pub key as points on ec
        """
        # Use this instance's base point, not a module-level global.
        return self.ec.mul(self.g, priv)

    def enc(self, plain, pub, r):
        """encrypt
        - plain: data as a point on ec
        - pub: pub key as points on ec
        - r: random int < ec.q
        - returns: (cipher1, cipher2) as points on ec
        """
        assert self.ec.is_valid(plain)
        assert self.ec.is_valid(pub)
        masked = self.ec.add(plain, self.ec.mul(pub, r))
        return (self.ec.mul(self.g, r), masked)

    def dec(self, cipher, priv):
        """decrypt
        - cipher: (cipher1, cipher2) as points on ec
        - priv: private key as int < ec.q
        - returns: plain as a point on ec
        """
        c1, c2 = cipher
        assert self.ec.is_valid(c1) and self.ec.is_valid(c2)
        # plain = c2 - priv * c1
        return self.ec.add(c2, self.ec.neg(self.ec.mul(c1, priv)))
class DiffieHellman(object):
    """Elliptic Curve Diffie Hellman (Key Agreement)
    - ec: elliptic curve
    - g: a point on ec
    """

    def __init__(self, ec, g):
        """Store the curve and base point and cache the base point's order."""
        self.ec, self.g = ec, g
        self.n = ec.order(g)

    def gen(self, priv):
        """generate pub key
        - priv: private key as int, 0 < priv < order of g
        - returns: pub key as a point on ec
        """
        assert 0 < priv < self.n
        return self.ec.mul(self.g, priv)

    def secret(self, priv, pub):
        """calc shared secret key for the pair
        - priv: my private key as int
        - pub: partner pub key as a point on ec
        - returns: shared secret as a point on ec
        """
        curve = self.ec
        # The partner key must be on the curve and be annihilated by n.
        assert curve.is_valid(pub)
        assert curve.mul(pub, self.n) == curve.zero
        return curve.mul(pub, priv)
class DSA(object):
    """ECDSA
    - ec: elliptic curve
    - g: a point on ec
    """

    def __init__(self, ec, g):
        self.ec = ec
        self.g = g
        self.n = ec.order(g)

    def gen(self, priv):
        """generate pub key
        - priv: private key as int, 0 < priv < order of g
        - returns: pub key as a point on ec
        """
        assert 0 < priv and priv < self.n
        return self.ec.mul(self.g, priv)

    def sign(self, hashval, priv, r):
        """generate signature
        - hashval: hash value of message as int
        - priv: priv key as int
        - r: random int, 0 < r < order of g
        - returns: signature as (int, int), both reduced modulo the order
        """
        assert 0 < r and r < self.n
        m = self.ec.mul(self.g, r)
        # The first component is x mod n: x lives in [0, q) while validate()
        # compares modulo n, so an unreduced x >= n could never verify.
        sig_r = m.x % self.n
        sig_s = inv(r, self.n) * (hashval + sig_r * priv) % self.n
        return (sig_r, sig_s)

    def validate(self, hashval, sig, pub):
        """validate signature
        - hashval: hash value of message as int
        - sig: signature as (int, int)
        - pub: pub key as a point on ec
        - returns: True if sig is a valid signature of hashval under pub
        """
        assert self.ec.is_valid(pub)
        assert self.ec.mul(pub, self.n) == self.ec.zero
        sig_r, sig_s = sig[0], sig[1]
        # Both components must lie in [1, n - 1]; in particular sig_s == 0
        # has no modular inverse.
        if not (0 < sig_r < self.n and 0 < sig_s < self.n):
            return False
        w = inv(sig_s, self.n)
        u1, u2 = hashval * w % self.n, sig_r * w % self.n
        p = self.ec.add(self.ec.mul(self.g, u1), self.ec.mul(pub, u2))
        if p == self.ec.zero:
            # The zero element has no meaningful x coordinate.
            return False
        return p.x % self.n == sig_r
if __name__ == "__main__":
    # shared elliptic curve system of examples
    # NOTE: a toy curve is used on purpose. EC.order() searches for the
    # order by brute force, which cannot finish on a 256-bit curve such as
    # secp256k1 (a=0, b=7,
    # q=115792089237316195423570985008687907853269984665640564039457584007908834671663).
    ec = EC(1, 18, 19)
    g, _ = ec.at(7)
    assert ec.order(g) <= ec.q

    # ElGamal enc/dec usage
    eg = ElGamal(ec, g)
    # mapping value to ec point
    # "masking": value k to point ec.mul(g, k)
    # ("imbedding" on proper n:use a point of x as 0 <= n*v <= x < n*(v+1) < q)
    # (built after eg exists, because it needs eg.n)
    mapping = [ec.mul(g, i) for i in range(eg.n)]
    plain = mapping[7]

    priv = 5
    pub = eg.gen(priv)

    cipher = eg.enc(plain, pub, 15)
    decoded = eg.dec(cipher, priv)
    assert decoded == plain
    assert cipher != pub

    # ECDH usage
    dh = DiffieHellman(ec, g)

    apriv = 11
    apub = dh.gen(apriv)

    bpriv = 3
    bpub = dh.gen(bpriv)

    cpriv = 7
    cpub = dh.gen(cpriv)
    # same secret on each pair
    assert dh.secret(apriv, bpub) == dh.secret(bpriv, apub)
    assert dh.secret(apriv, cpub) == dh.secret(cpriv, apub)
    assert dh.secret(bpriv, cpub) == dh.secret(cpriv, bpub)

    # not same secret on other pair
    assert dh.secret(apriv, cpub) != dh.secret(apriv, bpub)
    assert dh.secret(bpriv, apub) != dh.secret(bpriv, cpub)
    assert dh.secret(cpriv, bpub) != dh.secret(cpriv, apub)

    # ECDSA usage
    dsa = DSA(ec, g)

    priv = 11
    # derive the public key from the DSA instance itself
    pub = dsa.gen(priv)
    hashval = 128
    r = 7

    sig = dsa.sign(hashval, priv, r)
    assert dsa.validate(hashval, sig, pub)
# Appendix: improved modulo calc
def inv(n, q):
    """div on PN modulo a/b mod q as a * inv(b, q) mod q

    - n: value to invert
    - q: modulus
    - returns: i in range(q) such that n * i % q == 1
    - raises ValueError: if n and q are not coprime (no inverse exists)
    >>> assert n * inv(n, q) % q == 1
    """
    # n*inv % q = 1 => n*inv = q*m + 1 => n*inv + q*-m = 1
    # => egcd(n, q) = (inv, -m, 1) => inv = egcd(n, q)[0] (mod q)
    # This gives the same answer as trying every i in range(q) until
    # n * i % q == 1, without the O(q) search.
    coef, _, gcd = egcd(n, q)
    if gcd != 1:
        # Without this check a non-invertible n would silently yield a
        # value that is not an inverse.
        raise ValueError("%r has no inverse modulo %r" % (n, q))
    return coef % q
def egcd(a, b):
    """Extended Euclidean algorithm.

    returns: (s, t, gcd) such that a*s + b*t == gcd
    >>> s, t, gcd = egcd(a, b)
    >>> assert a % gcd == 0 and b % gcd == 0
    >>> assert a * s + b * t == gcd
    """
    # (coef_a, coef_b) are the Bezout coefficients of the current `a`;
    # (next_a, next_b) are those of the current `b`.
    coef_a, next_a = 1, 0
    coef_b, next_b = 0, 1
    while b > 0:
        quotient = a // b
        a, b = b, a % b
        coef_a, next_a = next_a, coef_a - quotient * next_a
        coef_b, next_b = next_b, coef_b - quotient * next_b
    return coef_a, coef_b, a
def inv2(n, q):
    """another PN invmod: via Fermat's little theorem (q prime)
    - n ** (q - 1) % q = 1 => n ** (q - 2) % q = n ** -1 % q
    """
    assert q > 2
    # Built-in modular exponentiation performs the same square-and-multiply.
    return pow(n, q - 2, q)
def sqrt(n, q):
    """Modular square root by exhaustive search (q a prime number).

    Returns the pair (i, q - i) for the smallest i in 1..q-1 whose square
    is n modulo q; raises Exception("not found") when there is none.
    >>> assert (sqrt(n, q)[0] ** 2) % q == n
    >>> assert (sqrt(n, q)[1] ** 2) % q == n
    """
    assert n < q
    candidates = (i for i in range(1, q) if pow(i, 2, q) == n)
    root = next(candidates, None)
    if root is None:
        raise Exception("not found")
    return (root, q - root)
def sqrt2(n, q):
    """sqrtmod for bigint
    - Algorithm 3.34 of http://www.cacr.math.uwaterloo.ca/hac/about/chap3.pdf
    - n: value whose square root is wanted (reduced mod q first)
    - q: prime modulus
    - returns: (r, q - r) with r * r % q == n; (0, 0) when n == 0 (mod q)
    - raises Exception("not found"): if n is not a quadratic residue mod q
    """
    import random
    n %= q
    if q == 2:
        # Every value is its own square root mod 2, and there is no
        # non-residue for the search below to find.
        return (n, n)
    if n == 0:
        return (0, 0)
    if jacobi(n, q) != 1:
        # The iteration below only converges for quadratic residues.
        raise Exception("not found")
    # b: some non-quadratic-residue (about half of all values qualify)
    while True:
        b = random.randint(1, q - 1)
        if jacobi(b, q) == -1:
            break
    # q = t * 2^s + 1, t is odd
    t, s = q - 1, 0
    while t & 1 == 0:
        t, s = t >> 1, s + 1
    assert q == t * pow(2, s) + 1 and t % 2 == 1
    ni = inv(n, q)
    c = pow(b, t, q)
    r = pow(n, (t + 1) // 2, q)
    for i in range(1, s):
        # The exponent 2^(s-i-1) is a plain integer power (always < q).
        d = pow(pow(r, 2, q) * ni % q, 2 ** (s - i - 1), q)
        if d == q - 1:
            r = r * c % q
        c = pow(c, 2, q)
    return (r, q - r)
def jacobi(a, q):
    """jacobi symbol: judge existing sqrtmod (1: exist, -1: not exist)

    - a: any int (reduced mod q first, so negative values are accepted)
    - q: odd positive int
    - returns: 1, -1, or 0 when a and q share a common factor

    Identities used:
    - j(a*b,q) = j(a,q)*j(b,q)
    - j(a*q+b, q) = j(b, q)
    - j(a, 1) = 1
    - j(0, q) = 0
    - j(2, q) = -1 ** (q^2 - 1)/8
    - j(p, q) = -1 ^ {(p - 1)/2 * (q - 1)/2} * j(q, p)
    """
    a %= q
    sign = 1
    # Iterative form: no recursion-depth limit for big integers.
    while a != 0:
        # Pull out factors of two: j(2, q) is -1 exactly when q = 3, 5 (mod 8).
        while a % 2 == 0:
            a //= 2
            if q % 8 in (3, 5):
                sign = -sign
        # Quadratic reciprocity: swapping flips the sign only when both
        # values are 3 (mod 4).
        a, q = q, a
        if a % 4 == 3 and q % 4 == 3:
            sign = -sign
        a %= q
    # q now holds gcd of the inputs; the symbol is 0 unless they were coprime.
    return sign if q == 1 else 0
def jacobi2(a, q):
    """quick jacobi symbol
    - algorithm 2.149 of http://www.cacr.math.uwaterloo.ca/hac/about/chap2.pdf
    - a: any int (reduced mod q first)
    - q: odd positive int
    - returns: 1, -1, or 0 when a and q share a common factor
    """
    if q == 1:
        return 1
    a %= q
    if a == 0:
        return 0
    if a == 1:
        return 1
    # a = 2^e * a1 with a1 odd
    a1, e = a, 0
    while a1 & 1 == 0:
        a1, e = a1 >> 1, e + 1
    # j(2, q)^e: only an odd power of two can contribute a sign, and then
    # only when q = 3 or 5 (mod 8).
    s = -1 if (e & 1 == 1 and q % 8 in (3, 5)) else 1
    # Quadratic reciprocity sign for swapping a1 and q.
    if q % 4 == 3 and a1 % 4 == 3:
        s = -s
    if a1 == 1:
        return s
    return s * jacobi2(q % a1, a1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment