Created
October 17, 2024 10:04
-
-
Save soraros/026cf24eb7a2acb9a823e8d57fac0e32 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # This is a port of https://github.com/orlp/foldhash | |
| from hashlib._ahash import _folded_multiply | |
| from sys import sizeof | |
| from sys.intrinsics import assume | |
| from testing import assert_equal | |
| from utils import Span | |
# 64-bit constants carried over from foldhash. They appear to be consecutive
# chunks of the fractional hex digits of pi (nothing-up-my-sleeve numbers).
# Only their values matter; ARBITRARY0..ARBITRARY3 reappear below as the
# lanes of `global_seed`.
alias ARBITRARY0: UInt64 = 0x243F6A8885A308D3
alias ARBITRARY1: UInt64 = 0x13198A2E03707344
alias ARBITRARY2: UInt64 = 0xA4093822299F31D0
alias ARBITRARY3: UInt64 = 0x082EFA98EC4E6C89
alias ARBITRARY4: UInt64 = 0x452821E638D01377
alias ARBITRARY5: UInt64 = 0xBE5466CF34E90C6C
alias ARBITRARY6: UInt64 = 0xC0AC29B7C97C50DD
alias ARBITRARY7: UInt64 = 0x3F84D5B5B5470917
alias ARBITRARY8: UInt64 = 0x9216D5D98979FB1B
alias ARBITRARY9: UInt64 = 0xD1310BA698DFB5AC

# A 128-bit quantity stored as two 64-bit lanes (used for the sponge).
alias U128 = SIMD[DType.uint64, 2]
# Four 64-bit seeds (fold seed + three expand seeds) for `FoldHasher`.
alias U256 = SIMD[DType.uint64, 4]
fn into[dt: DType, //](x: Scalar[dt]) -> U128:
    """Places the raw bits of `x`, zero-extended, in the low lane of a `U128`.

    The high lane is always 0. Floats are converted via their bit pattern
    (`_float_to_bits`). Values narrower than 64 bits are masked to their own
    width, so that signed integers (or float bit patterns with the top bit
    set) are zero-extended rather than sign-extended. This matches foldhash,
    where `write_iN(i)` hashes `i as uN`. Without the mask, the extended high
    bits would be ORed over the slots of values packed later into the sponge.

    Parameters:
        dt: Element type of `x` (inferred).

    Args:
        x: The scalar to convert.

    Returns:
        `U128(bits_of_x, 0)`.
    """
    v = x._float_to_bits[DType.uint64]()
    alias bits = 8 * sizeof[dt]()

    @parameter
    if bits < 64:
        # The cast to uint64 above may sign-extend; keep only x's own bits.
        alias mask: UInt64 = (1 << bits) - 1
        v &= mask
    return U128(v, 0)
fn shift_left(x: U128, n: UInt8) -> U128:
    """Shifts the two 64-bit lanes of `x`, viewed as one 128-bit integer, left by `n`.

    The lanes are reinterpreted as a single `ui128` via a bitcast, shifted,
    and bitcast back. Presumably lane 0 is the low half (little-endian
    layout) — NOTE(review): confirm for other targets. The only caller in
    this file (`write_num`) passes `n` <= 120.

    Args:
        x: Value to shift.
        n: Shift amount in bits.

    Returns:
        The shifted value, again as two 64-bit lanes.
    """
    wide = __mlir_op.`pop.bitcast`[_type=__mlir_type.`!pop.scalar<ui128>`](
        x.value
    )
    amount = __mlir_op.`pop.cast`[_type=__mlir_type.`!pop.scalar<ui128>`](
        n.value
    )
    return U128(
        __mlir_op.`pop.bitcast`[_type=__mlir_type.`!pop.simd<2, ui64>`](
            __mlir_op.`pop.shl`(wide, amount)
        )
    )
struct FoldHasher[key: U256]:
    """Incremental hasher ported from foldhash's fast `FoldHasher`.

    Parameters:
        key: Four 64-bit seeds. Lane 0 is the fold seed; lanes 1-3 are the
            expand seeds used by `write` for byte strings.

    Integers given to `write_num` are packed into a 128-bit `sponge`. They
    are folded into `accumulator` only when the sponge would overflow or in
    `finish`. `write` updates `accumulator` directly without touching the
    sponge.
    """

    var accumulator: UInt64
    # Buffered integer bits not yet folded into `accumulator`.
    var sponge: U128
    # Number of valid low bits currently held in `sponge` (0..=128).
    var sponge_len: UInt8
    var fold_seed: UInt64
    var expand_seed: UInt64
    var expand_seed2: UInt64
    var expand_seed3: UInt64

    fn __init__(inout self, per_hash_key: UInt64):
        """Creates a hasher whose accumulator starts at `per_hash_key` and whose seeds come from `key`.
        """
        self.accumulator = per_hash_key
        self.sponge = 0
        self.sponge_len = 0
        self.fold_seed = key[0]
        self.expand_seed = key[1]
        self.expand_seed2 = key[2]
        self.expand_seed3 = key[3]

    fn write_num[dtype: DType](inout self, x: Scalar[dtype]):
        """Buffers the bits of `x` in the sponge, flushing the sponge first if `x` would not fit.
        """
        alias bits: UInt = 8 * sizeof[dtype]()
        if self.sponge_len + bits > 128:
            # Fold the full sponge into the accumulator and start a fresh one
            # that already holds `x`. Its bits must be counted; resetting the
            # length to 0 would let the next value overwrite `x`, or make
            # `finish` ignore it.
            self.write_u128(self.sponge)
            self.sponge = into(x)
            self.sponge_len = bits
        else:
            # self.sponge |= into(x) << self.sponge_len
            self.sponge |= shift_left(into(x), self.sponge_len)
            self.sponge_len += bits

    fn write_u128(inout self, n: U128):
        """Folds a 128-bit value (lane 0 low, lane 1 high) into the accumulator."""
        self.accumulator = _folded_multiply(
            n[0] ^ self.accumulator, n[1] ^ self.fold_seed
        )

    fn write[o: ImmutableOrigin](inout self, bytes: Span[Byte, o]):
        """Folds a byte string into the accumulator.

        - Up to 16 bytes: handled inline.
        - 17..255 bytes: handled by `hash_bytes_medium`.
        - 256+ bytes: handled by `hash_bytes_long`.
        """
        s0 = self.accumulator
        s1 = self.expand_seed
        p = bytes.unsafe_ptr()
        l = len(bytes)
        if l <= 16:
            # Two (possibly overlapping) reads from the front and the back
            # cover every byte for each of these length classes.
            if l >= 8:
                s0 ^= p.bitcast[UInt64]().load()
                s1 ^= p.offset(l - 8).bitcast[UInt64]().load()
            elif l >= 4:
                s0 ^= p.bitcast[UInt32]().load().cast[DType.uint64]()
                s1 ^= p.offset(l - 4).bitcast[UInt32]().load().cast[DType.uint64]()
            elif l > 0:
                lo = p.load()
                mid = p.offset(l // 2).load()
                hi = p.offset(l - 1).load()
                s0 ^= lo.cast[DType.uint64]()
                s1 ^= (hi.cast[DType.uint64]() << 8) | mid.cast[DType.uint64]()
            self.accumulator = _folded_multiply(s0, s1)
        elif l < 256:
            self.accumulator = hash_bytes_medium(bytes, s0, s1, self.fold_seed)
        else:
            # hash_bytes_long takes its seeds `inout` and mutates them. Pass
            # copies, otherwise the hasher's own expand seeds get overwritten.
            s2 = self.expand_seed2
            s3 = self.expand_seed3
            self.accumulator = hash_bytes_long(
                bytes, s0, s1, s2, s3, self.fold_seed
            )

    fn finish(self) -> UInt64:
        """Returns the hash of everything written so far without modifying the hasher.

        If the sponge holds any bits, they get one final fold; otherwise the
        accumulator is returned as-is.
        """
        if self.sponge_len > 0:
            lo = self.sponge[0]
            hi = self.sponge[1]
            return _folded_multiply(lo ^ self.accumulator, hi ^ self.fold_seed)
        else:
            return self.accumulator
fn hash_bytes_medium(
    bytes: Span[Byte],
    inout s0: UInt64, inout s1: UInt64,
    fold_seed: UInt64
) -> UInt64:
    """Hashes `bytes` (at least 16 long) by consuming 16-byte chunks from both ends.

    Each step reads 16 bytes at offset `lo` and 16 bytes at offset `hi`. It
    folds their first halves into `s0` and their second halves into `s1`.
    The cursors then move toward each other by 16 bytes. As in foldhash, the
    loop runs until the front cursor reaches the end of the back chunk. The
    last pair of chunks may therefore overlap, but no byte in the middle is
    ever skipped. (Stopping when `lo > hi` would leave bytes unhashed
    whenever `len % 32` is 1..15.)

    Args:
        bytes: Input; the caller must guarantee `len(bytes) >= 16`.
        s0: First lane of state; updated in place.
        s1: Second lane of state; updated in place.
        fold_seed: Seed mixed into every multiply.

    Returns:
        `s0 ^ s1` after all chunks have been folded in.
    """
    assume(len(bytes) >= 16)
    p = bytes.unsafe_ptr()
    # Integer offsets (rather than moving pointers) so the back cursor never
    # points before the start of the buffer.
    lo = 0
    hi = len(bytes) - 16
    while lo < hi + 16:
        v1 = p.offset(lo).bitcast[UInt64]().load[width=2]()
        v2 = p.offset(hi).bitcast[UInt64]().load[width=2]()
        s0 = _folded_multiply(v1[0] ^ s0, v2[0] ^ fold_seed)
        s1 = _folded_multiply(v1[1] ^ s1, v2[1] ^ fold_seed)
        lo += 16
        hi -= 16
    return s0 ^ s1
| import math | |
fn hash_bytes_long(
    bytes: Span[Byte],
    inout s0: UInt64, inout s1: UInt64,
    inout s2: UInt64, inout s3: UInt64,
    fold_seed: UInt64
) -> UInt64:
    """Hashes `bytes` in 64-byte blocks with four state lanes, then finishes the tail.

    Each full 64-byte block is read as eight UInt64 words. Words 0-3 are
    folded into `s0`..`s3`, each paired with words 4-7 respectively. The
    lanes are then merged into `s0`/`s1`. If bytes remain, only the trailing
    `max(remainder, 16)` bytes are passed to `hash_bytes_medium`, matching
    foldhash. Widening to 16 re-reads some already-absorbed bytes so the
    medium routine's length precondition holds.

    Args:
        bytes: Input; at least 16 bytes (`write` only calls this for 256+).
        s0: State lane, updated in place.
        s1: State lane, updated in place.
        s2: State lane, updated in place.
        s3: State lane, updated in place.
        fold_seed: Seed mixed into every multiply.

    Returns:
        The 64-bit hash of `bytes` given the starting state.
    """
    assume(len(bytes) >= 16)
    n = len(bytes)
    end = math.align_down(n, 64)
    p = bytes.unsafe_ptr()
    for _ in range(0, end, 64):
        v = p.bitcast[UInt64]().load[width=8]()
        s0 = _folded_multiply(v[0] ^ s0, v[4] ^ fold_seed)
        s1 = _folded_multiply(v[1] ^ s1, v[5] ^ fold_seed)
        s2 = _folded_multiply(v[2] ^ s2, v[6] ^ fold_seed)
        s3 = _folded_multiply(v[3] ^ s3, v[7] ^ fold_seed)
        p += 64
    s0 ^= s2
    s1 ^= s3
    remainder = n - end
    if remainder > 0:
        # Hash only the unconsumed tail, not the whole buffer again.
        tail = max(remainder, 16)
        return hash_bytes_medium(bytes[n - tail:], s0, s1, fold_seed)
    else:
        return s0 ^ s1
fn main() raises:
    """Runs the foldhash-port self-checks; raises on the first mismatching digest."""
    # Integer (sponge) path first, then the byte-string path.
    test_write_num()
    test_write_medium()
# Fixed seeds for the self-checks below. The four global lanes have the same
# values as ARBITRARY0..ARBITRARY3. Presumably these seeds reproduce the
# reference implementation's digests — NOTE(review): confirm.
alias global_seed = SIMD[DType.uint64, 4](
    0x243F6A8885A308D3, 0x13198A2E03707344, 0xA4093822299F31D0, 0x082EFA98EC4E6C89
)
alias per_hasher_seed: UInt64 = 0x7F1B8B2F7E0E3B3D
fn test_write_num() raises:
    """Checks `write_num` + `finish` digests for each unsigned width against fixed expected values.

    Each width uses a fresh hasher. `finish` is checked after every write,
    since it does not modify the hasher.
    """
    # uint8: both one-byte values fit in the same sponge.
    var h8 = FoldHasher[global_seed](per_hasher_seed)
    h8.write_num[DType.uint8](1)
    assert_equal(h8.finish(), 0xda5d77f425cf5108)
    h8.write_num[DType.uint8](1)
    assert_equal(h8.finish(), 0x9dc8ff7a88b83c2c)

    # uint16
    var h16 = FoldHasher[global_seed](per_hasher_seed)
    h16.write_num[DType.uint16](-1)
    assert_equal(h16.finish(), 0xa529a2fa72dfcd13)
    h16.write_num[DType.uint16](1)
    assert_equal(h16.finish(), 0x0ea0d85769a320d3)

    # uint32
    var h32 = FoldHasher[global_seed](per_hasher_seed)
    h32.write_num[DType.uint32](-1)
    assert_equal(h32.finish(), 0xd4893b4fa665e296)
    h32.write_num[DType.uint32](1)
    assert_equal(h32.finish(), 0x2e2c02a342a4550e)

    # uint64: the first two writes fill the 128-bit sponge exactly. The third
    # write overflows it, which exercises write_num's fold-and-restart branch.
    var h64 = FoldHasher[global_seed](per_hasher_seed)
    h64.write_num[DType.uint64](-1)
    assert_equal(h64.finish(), 0xf99ef16adf6a1eb3)
    h64.write_num[DType.uint64](1)
    assert_equal(h64.finish(), 0x78ba841a51645271)
    h64.write_num[DType.uint64](1)
    h64.write_num[DType.uint64](1)
    assert_equal(h64.finish(), 0xfefcf3c7372872d3)
fn test_write_medium() raises:
    """Checks `write` + `finish` digests for two byte strings against fixed expected values.

    The first string is 13 bytes and takes `write`'s inline (<= 16 byte)
    path. The second is well over 16 but under 256 bytes, so it goes through
    `hash_bytes_medium`.
    """
    alias short_text = "hello, world!"
    var short_hasher = FoldHasher[global_seed](per_hasher_seed)
    short_hasher.write(short_text.as_bytes())
    assert_equal(short_hasher.finish(), 0xee0d4e3673d3a356)

    alias medium_text = "Darth Plagueis... was a Dark Lord of the Sith so powerful and so wise, he could use the Force to influence the midi-chlorians... to create... life."
    var medium_hasher = FoldHasher[global_seed](per_hasher_seed)
    medium_hasher.write(medium_text.as_bytes())
    assert_equal(medium_hasher.finish(), 0x2999a4553f8d31cb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment