Skip to content

Instantly share code, notes, and snippets.

@soraros
Created October 17, 2024 10:04
Show Gist options
  • Select an option

  • Save soraros/026cf24eb7a2acb9a823e8d57fac0e32 to your computer and use it in GitHub Desktop.

Select an option

Save soraros/026cf24eb7a2acb9a823e8d57fac0e32 to your computer and use it in GitHub Desktop.
# This is a port of https://github.com/orlp/foldhash
from hashlib._ahash import _folded_multiply
from sys import sizeof
from sys.intrinsics import assume
from testing import assert_equal
from utils import Span
# Fixed 64-bit constants. The words appear to be consecutive 64-bit chunks of
# the hexadecimal expansion of pi ("nothing-up-my-sleeve" numbers), presumably
# carried over from upstream foldhash -- NOTE(review): confirm against upstream.
# ARBITRARY0..ARBITRARY3 hold the same four words as the test key
# `global_seed` defined further down in this file.
alias ARBITRARY0: UInt64 = 0x243f6a8885a308d3
alias ARBITRARY1: UInt64 = 0x13198a2e03707344
alias ARBITRARY2: UInt64 = 0xa4093822299f31d0
alias ARBITRARY3: UInt64 = 0x082efa98ec4e6c89
alias ARBITRARY4: UInt64 = 0x452821e638d01377
alias ARBITRARY5: UInt64 = 0xbe5466cf34e90c6c
alias ARBITRARY6: UInt64 = 0xc0ac29b7c97c50dd
alias ARBITRARY7: UInt64 = 0x3f84d5b5b5470917
alias ARBITRARY8: UInt64 = 0x9216d5d98979fb1b
alias ARBITRARY9: UInt64 = 0xd1310ba698dfb5ac
# Wide integers represented as vectors of 64-bit lanes. Lane 0 is treated as
# the low word: `into` stores a value in lane 0 and `finish` reads lane 0 as
# `lo` (this matches a plain bitcast only on little-endian targets).
# U256 carries the four-word global key; U128 is the integer "sponge".
alias U256 = SIMD[DType.uint64, 4]
alias U128 = SIMD[DType.uint64, 2]
fn into[dt: DType, //](x: Scalar[dt]) -> U128:
    """Widens a scalar into a `U128`, with its bits in the low lane.

    The bits are produced by the private `_float_to_bits[DType.uint64]`
    helper; the high lane is always zero.
    NOTE(review): for integer dtypes a plain `.cast[DType.uint64]()` might be
    clearer -- confirm `_float_to_bits` zero-extends narrow integers first.

    Parameters:
        dt: Element type of `x` (inferred).

    Args:
        x: The value to widen.

    Returns:
        `U128(bits_of_x, 0)`.
    """
    var low_lane = x._float_to_bits[DType.uint64]()
    return U128(low_lane, 0)
fn shift_left(x: U128, n: UInt8) -> U128:
    """Shifts a 128-bit value, stored as two 64-bit lanes, left by `n` bits.

    The lane pair is reinterpreted as a single `ui128` scalar, `n` is widened
    to `ui128`, the shift is done with the `pop.shl` MLIR op, and the result
    is reinterpreted back into two lanes.
    NOTE(review): callers presumably keep `n < 128` (write_num only shifts by
    the current sponge fill); behavior for larger shifts is not checked here.

    Args:
        x: The 128-bit value as (low lane, high lane).
        n: Shift amount in bits.

    Returns:
        `x << n` as (low lane, high lane).
    """
    var amount = __mlir_op.`pop.cast`[
        _type = __mlir_type.`!pop.scalar<ui128>`
    ](n.value)
    var wide = __mlir_op.`pop.bitcast`[
        _type = __mlir_type.`!pop.scalar<ui128>`
    ](x.value)
    var moved = __mlir_op.`pop.shl`(wide, amount)
    var lanes = __mlir_op.`pop.bitcast`[
        _type = __mlir_type.`!pop.simd<2, ui64>`
    ](moved)
    return U128(lanes)
struct FoldHasher[key: U256]:
    """Streaming hasher ported from foldhash's `FoldHasher`.

    Scalars written with `write_num` are packed into a 128-bit `sponge` and
    only folded into `accumulator` when the sponge would overflow (or in
    `finish`). Byte slices written with `write` are folded in immediately.

    Parameters:
        key: Global key. Lane 0 becomes `fold_seed`; lanes 1, 2 and 3 become
            `expand_seed`, `expand_seed2` and `expand_seed3`.
    """

    var accumulator: UInt64  # running state, starts as per_hash_key
    var sponge: U128  # pending write_num bits not yet folded in
    var sponge_len: UInt8  # bit offset for the next write_num value
    var fold_seed: UInt64
    var expand_seed: UInt64
    var expand_seed2: UInt64
    var expand_seed3: UInt64

    fn __init__(inout self, per_hash_key: UInt64):
        """Creates a hasher with an empty sponge.

        Args:
            per_hash_key: Initial value of the accumulator.
        """
        self.accumulator = per_hash_key
        self.sponge = 0
        self.sponge_len = 0
        self.fold_seed = key[0]
        self.expand_seed = key[1]
        self.expand_seed2 = key[2]
        self.expand_seed3 = key[3]

    fn write_num[dtype: DType](inout self, x: Scalar[dtype]):
        """Appends a scalar's bits to the sponge, flushing it first if full.

        Parameters:
            dtype: Element type of `x`; its byte size decides how many sponge
                bits the value occupies.

        Args:
            x: The value to hash.
        """
        alias bits: UInt = 8 * sizeof[dtype]()
        if self.sponge_len + bits > 128:
            # No room left: fold the current sponge into the accumulator and
            # start a fresh sponge holding `x`.
            self.write_u128(self.sponge)
            self.sponge = into(x)
            # NOTE(review): `x` now occupies the low `bits` bits of the sponge
            # but the offset is reset to 0, so the next value is OR-ed on top
            # of it and `finish` skips the sponge if nothing else is written.
            # Kept as-is because the last expected hash in test_write_num
            # exercises this path; confirm against the upstream foldhash
            # version the vectors came from before changing it to `bits`.
            self.sponge_len = 0
        else:
            # Place `x` directly above the bits already in the sponge.
            self.sponge |= shift_left(into(x), self.sponge_len)
            self.sponge_len += bits

    fn write_u128(inout self, n: U128):
        """Folds a 128-bit block (low lane, high lane) into the accumulator.

        Args:
            n: The block to fold in.
        """
        self.accumulator = _folded_multiply(
            n[0] ^ self.accumulator, n[1] ^ self.fold_seed
        )

    fn write[o: ImmutableOrigin](inout self, bytes: Span[Byte, o]):
        """Folds a byte slice into the accumulator.

        Inputs of at most 16 bytes are mixed inline. Inputs shorter than 256
        bytes go to `hash_bytes_medium`; longer ones go to `hash_bytes_long`.

        Parameters:
            o: Origin of the borrowed bytes.

        Args:
            bytes: The data to hash.
        """
        var s0 = self.accumulator
        var s1 = self.expand_seed
        var p = bytes.unsafe_ptr()
        var l = len(bytes)
        if l <= 16:
            if l >= 8:
                # Possibly overlapping first/last 8-byte words.
                s0 ^= p.bitcast[UInt64]().load()
                s1 ^= p.offset(l - 8).bitcast[UInt64]().load()
            elif l >= 4:
                # Possibly overlapping first/last 4-byte words.
                s0 ^= p.bitcast[UInt32]().load().cast[DType.uint64]()
                s1 ^= p.offset(l - 4).bitcast[UInt32]().load().cast[
                    DType.uint64
                ]()
            elif l > 0:
                # 1..3 bytes: first, middle and last byte.
                var lo = p.load()
                var mid = p.offset(l // 2).load()
                var hi = p.offset(l - 1).load()
                s0 ^= lo.cast[DType.uint64]()
                s1 ^= (hi.cast[DType.uint64]() << 8) | mid.cast[DType.uint64]()
            self.accumulator = _folded_multiply(s0, s1)
        elif l < 256:
            self.accumulator = hash_bytes_medium(bytes, s0, s1, self.fold_seed)
        else:
            # Pass copies: hash_bytes_long may declare these `inout`, and
            # handing it the fields directly would overwrite this hasher's
            # seeds, changing the result of every later long write.
            var s2 = self.expand_seed2
            var s3 = self.expand_seed3
            self.accumulator = hash_bytes_long(
                bytes, s0, s1, s2, s3, self.fold_seed
            )

    fn finish(self) -> UInt64:
        """Returns the hash of everything written so far.

        Folds in the sponge when `sponge_len` is non-zero. Takes `self`
        immutably, so more data may be written afterwards.

        Returns:
            The 64-bit hash value.
        """
        if self.sponge_len > 0:
            var lo = self.sponge[0]
            var hi = self.sponge[1]
            return _folded_multiply(lo ^ self.accumulator, hi ^ self.fold_seed)
        else:
            return self.accumulator
fn hash_bytes_medium(
    bytes: Span[Byte],
    owned s0: UInt64, owned s1: UInt64,
    fold_seed: UInt64
) -> UInt64:
    """Hashes an input of at least 16 bytes using paired 16-byte chunks.

    Each step folds one 16-byte chunk taken from the front and one taken from
    the back. The cursors move toward each other and may overlap on the last
    step, so every byte of the input is covered.

    Args:
        bytes: The data to hash; must be at least 16 bytes long.
        s0: Initial state of the first lane (taken by value).
        s1: Initial state of the second lane (taken by value).
        fold_seed: Seed mixed into every fold.

    Returns:
        `s0 ^ s1` after all chunk pairs have been folded in.
    """
    assume(len(bytes) >= 16)
    var p = bytes.unsafe_ptr()
    # `lo` walks forward from the start; `hi` walks backward from the last
    # 16 bytes.
    var lo = 0
    var hi = len(bytes) - 16
    # Keep going while the front chunk starts before the back chunk ends.
    # Stopping at `lo <= hi` instead skips the final overlapping pair whenever
    # len(bytes) % 32 is in 1..15, leaving bytes [lo, hi + 16) unhashed.
    while lo < hi + 16:
        var v1 = p.offset(lo).bitcast[UInt64]().load[width=2]()
        var v2 = p.offset(hi).bitcast[UInt64]().load[width=2]()
        s0 = _folded_multiply(v1[0] ^ s0, v2[0] ^ fold_seed)
        s1 = _folded_multiply(v1[1] ^ s1, v2[1] ^ fold_seed)
        lo += 16
        hi -= 16
    return s0 ^ s1
import math
fn hash_bytes_long(
    bytes: Span[Byte],
    owned s0: UInt64, owned s1: UInt64,
    owned s2: UInt64, owned s3: UInt64,
    fold_seed: UInt64
) -> UInt64:
    """Long-input path; FoldHasher.write uses it for inputs of 256+ bytes.

    Folds every full 64-byte chunk into four independent lanes, merges the
    lanes pairwise, then hashes any leftover tail with `hash_bytes_medium`.

    Args:
        bytes: The data to hash; must be at least 16 bytes long.
        s0: Initial state of lane 0 (taken by value).
        s1: Initial state of lane 1 (taken by value).
        s2: Initial state of lane 2. Taken by value so the caller's variable
            (e.g. a hasher's `expand_seed2`) is not overwritten.
        s3: Initial state of lane 3; taken by value for the same reason.
        fold_seed: Seed mixed into every fold.

    Returns:
        The 64-bit hash of `bytes`.
    """
    assume(len(bytes) >= 16)
    var n = len(bytes)
    var end = math.align_down(n, 64)
    var p = bytes.unsafe_ptr()
    for _ in range(0, end, 64):
        # One 64-byte chunk as eight u64 words: words 0..3 are mixed into the
        # lanes, words 4..7 serve as their fold partners.
        var v = p.bitcast[UInt64]().load[width=8]()
        s0 = _folded_multiply(v[0] ^ s0, v[4] ^ fold_seed)
        s1 = _folded_multiply(v[1] ^ s1, v[5] ^ fold_seed)
        s2 = _folded_multiply(v[2] ^ s2, v[6] ^ fold_seed)
        s3 = _folded_multiply(v[3] ^ s3, v[7] ^ fold_seed)
        p += 64
    s0 ^= s2
    s1 ^= s3
    var remainder = n - end
    if remainder == 0:
        return s0 ^ s1
    # Hash only the unconsumed tail. It is widened to at least 16 bytes, which
    # re-reads some already-folded bytes, to satisfy hash_bytes_medium's
    # minimum length. Passing all of `bytes` would fold the whole input a
    # second time.
    return hash_bytes_medium(bytes[n - max(remainder, 16) :], s0, s1, fold_seed)
fn main() raises:
    """Runs both test functions; the first failing assertion is raised out of main."""
    test_write_num()
    test_write_medium()
# Fixed keys used by the tests. `global_seed` is the four-word FoldHasher key,
# built from ARBITRARY0..ARBITRARY3. `per_hasher_seed` is the initial
# accumulator passed to each FoldHasher.
alias global_seed = U256(ARBITRARY0, ARBITRARY1, ARBITRARY2, ARBITRARY3)
alias per_hasher_seed = UInt64(0x7f1b8b2f7e0e3b3d)
fn test_write_num() raises:
    """Pins `write_num`/`finish` outputs for 8-, 16-, 32- and 64-bit values.

    Each section starts from a fresh hasher and checks the digest after every
    write, so both the first-value and the second-value paths are covered.
    """
    alias Hasher = FoldHasher[global_seed]

    # 8-bit: two writes of 1.
    var fh = Hasher(per_hasher_seed)
    fh.write_num[DType.uint8](1)
    assert_equal(fh.finish(), 0xda5d77f425cf5108)
    fh.write_num[DType.uint8](1)
    assert_equal(fh.finish(), 0x9dc8ff7a88b83c2c)

    # 16-bit: all-ones, then 1.
    fh = Hasher(per_hasher_seed)
    fh.write_num[DType.uint16](-1)
    assert_equal(fh.finish(), 0xa529a2fa72dfcd13)
    fh.write_num[DType.uint16](1)
    assert_equal(fh.finish(), 0x0ea0d85769a320d3)

    # 32-bit: all-ones, then 1.
    fh = Hasher(per_hasher_seed)
    fh.write_num[DType.uint32](-1)
    assert_equal(fh.finish(), 0xd4893b4fa665e296)
    fh.write_num[DType.uint32](1)
    assert_equal(fh.finish(), 0x2e2c02a342a4550e)

    # 64-bit: all-ones, then three 1s. The first two writes fill the 128-bit
    # sponge, so the third forces a flush.
    fh = Hasher(per_hasher_seed)
    fh.write_num[DType.uint64](-1)
    assert_equal(fh.finish(), 0xf99ef16adf6a1eb3)
    fh.write_num[DType.uint64](1)
    assert_equal(fh.finish(), 0x78ba841a51645271)
    fh.write_num[DType.uint64](1)
    fh.write_num[DType.uint64](1)
    assert_equal(fh.finish(), 0xfefcf3c7372872d3)
fn test_write_medium() raises:
    """Pins `write`/`finish` outputs for two byte strings.

    `short_msg` is at most 16 bytes (inline path in `write`); `long_msg` is
    longer than 16 but under 256 bytes (the `hash_bytes_medium` path).
    """
    alias short_msg = "hello, world!"
    alias long_msg = "Darth Plagueis... was a Dark Lord of the Sith so powerful and so wise, he could use the Force to influence the midi-chlorians... to create... life."

    var fh = FoldHasher[global_seed](per_hasher_seed)
    fh.write(short_msg.as_bytes())
    assert_equal(fh.finish(), 0xee0d4e3673d3a356)

    fh = FoldHasher[global_seed](per_hasher_seed)
    fh.write(long_msg.as_bytes())
    assert_equal(fh.finish(), 0x2999a4553f8d31cb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment