Skip to content

Instantly share code, notes, and snippets.

@k06a
Created January 21, 2022 12:23
Show Gist options
  • Select an option

  • Save k06a/fd96f05d38bd25b232d13a027822de0c to your computer and use it in GitHub Desktop.

Select an option

Save k06a/fd96f05d38bd25b232d13a027822de0c to your computer and use it in GitHub Desktop.

Revisions

  1. k06a created this gist Jan 21, 2022.
    213 changes: 213 additions & 0 deletions lohi.h
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,213 @@
    #ifndef Lohi_h
    #define Lohi_h

    namespace lohi {
    // Forward declaration of the two-half integer wrapper: the helper
    // overloads that follow name Lohi<T> in their signatures before the
    // struct body is defined further down in this header.
    template<typename T>
    struct Lohi;

    // Zero value for a scalar type T.
    // The argument only drives template deduction; its value is never read.
    template<typename T>
    T zero(thread const T &) {
        return static_cast<T>(0);
    }

    // Zero value for a Lohi<T>: both halves are zeroed recursively, so this
    // works for arbitrarily nested Lohi types. Only the type of `a` matters.
    template<typename T>
    Lohi<T> zero(thread const Lohi<T> &a) {
        Lohi<T> cleared;
        cleared.lo = zero(a.lo);
        cleared.hi = zero(a.hi);
        return cleared;
    }

    // True when the scalar `t` compares equal to zero.
    template<typename T>
    bool is_zero(thread const T &t) {
        const bool non_zero = (t != 0);
        return !non_zero;
    }

    // True when both halves of `a` are zero. The high half is only inspected
    // once the low half is known to be zero.
    template<typename T>
    bool is_zero(thread const Lohi<T> &a) {
        if (!is_zero(a.lo)) {
            return false;
        }
        return is_zero(a.hi);
    }

    // Tests bit `b` of the scalar `t` (bit 0 is the least significant).
    // `b` must be smaller than the bit width of the shifted operand.
    template<typename T>
    bool bit(thread const T &t, unsigned b) {
        // After the shift the requested bit sits in position 0; the mask
        // leaves either 0 or 1.
        return ((t >> b) & 1) == 1;
    }

    // Tests bit `b` of a Lohi<T>. Indices below the width of T address the
    // low half; higher indices are rebased and forwarded to the high half.
    template<typename T>
    bool bit(thread const Lohi<T> &lohi, unsigned b) {
        const unsigned half_bits = sizeof(T)*8;
        if (b < half_bits) {
            return bit(lohi.lo, b);
        }
        return bit(lohi.hi, b - half_bits);
    }

    // Index of the most significant set bit of `a` (bit 0 = least
    // significant). When `a` is zero there is no such bit and the bit width
    // of T is returned as a sentinel.
    template<typename T>
    int msb(T a) {
        const int total_bits = sizeof(T)*8;
        if (is_zero(a)) {
            return total_bits;
        }

        // Binary search over the bit positions: try dropping the lower `step`
        // bits; if anything is left, the top bit lives in the upper part, so
        // keep the shifted value and account for the skipped positions.
        int index = 0;
        int step = sizeof(T)*4;
        while (step > 0) {
            const T upper = a >> step;
            if (!is_zero(upper)) {
                index += step;
                a = upper;
            }
            step /= 2;
        }
        return index;
    }

    template<typename T>
    int msb(thread const Lohi<T> &a) {
    int res = msb(a.hi);
    if (res == sizeof(T)*8) {
    res += msb(a.lo);
    }
    return res;
    }

    //

    // Unsigned integer built from two halves of type T.
    //
    // `lo` is the least significant half and `hi` the most significant one
    // (comparisons look at `hi` first; carries and borrows flow from `lo`
    // into `hi`). T may be a built-in unsigned integer or another Lohi, so
    // the type nests. The struct is kept an aggregate so it can be
    // brace-initialised as { lo, hi }.
    template<typename T>
    struct Lohi {
        T lo;  // least significant half
        T hi;  // most significant half

        // Pre-increment. When `lo` wraps around to zero the carry is
        // propagated into `hi`.
        thread Lohi<T> & operator++() {
            if (is_zero(++lo)) {
                ++hi;
            }
            return *this;
        }

        // Pre-decrement. When `lo` is zero it is about to wrap, so a borrow
        // is taken from `hi` first.
        thread Lohi<T> & operator--() {
            if (is_zero(lo)) {
                --hi;
            }
            --lo;
            return *this;
        }

        bool operator == (thread const Lohi<T> & rhs) const {
            return this->hi == rhs.hi && this->lo == rhs.lo;
        }

        bool operator != (thread const Lohi<T> & rhs) const {
            return this->hi != rhs.hi || this->lo != rhs.lo;
        }

        // Ordering compares the high halves first and falls back to the low
        // halves only when the high halves are equal.
        bool operator < (thread const Lohi<T> & rhs) const {
            return this->hi < rhs.hi || (this->hi == rhs.hi && this->lo < rhs.lo);
        }

        bool operator <= (thread const Lohi<T> & rhs) const {
            return *this < rhs || (*this == rhs);
        }

        bool operator > (thread const Lohi<T> & rhs) const {
            return this->hi > rhs.hi || (this->hi == rhs.hi && this->lo > rhs.lo);
        }

        bool operator >= (thread const Lohi<T> & rhs) const {
            return *this > rhs || (*this == rhs);
        }

        // Bitwise OR, applied half by half.
        Lohi<T> operator | (thread const Lohi<T> & rhs) const {
            Lohi<T> result;
            result.lo = this->lo | rhs.lo;
            result.hi = this->hi | rhs.hi;
            return result;
        }

        // Logical right shift by `offset` bits.
        //
        // offset <= 0 returns the value unchanged; an offset of at least the
        // full width yields zero. Each half is only ever shifted by an amount
        // strictly between 0 and its own width, because shifting a built-in
        // integer by its full width (or more) is not defined.
        Lohi<T> operator >> (int offset) const {
            const int half_bits = static_cast<int>(sizeof(T) * 8);
            if (offset <= 0) {
                return *this;
            }

            // x - x is zero for built-in integers and for nested Lohi alike.
            const T none = this->hi - this->hi;

            Lohi<T> result;
            if (offset >= 2 * half_bits) {
                result.lo = none;
                result.hi = none;
            } else if (offset == half_bits) {
                result.lo = this->hi;
                result.hi = none;
            } else if (offset > half_bits) {
                // Everything that survives comes from the high half.
                result.lo = this->hi >> (offset - half_bits);
                result.hi = none;
            } else {
                // The bits leaving `hi` at the bottom enter `lo` at the top.
                result.lo = (this->lo >> offset) | (this->hi << (half_bits - offset));
                result.hi = this->hi >> offset;
            }
            return result;
        }

        // Left shift by `offset` bits; same offset rules as operator>>.
        Lohi<T> operator << (int offset) const {
            const int half_bits = static_cast<int>(sizeof(T) * 8);
            if (offset <= 0) {
                return *this;
            }

            // x - x is zero for built-in integers and for nested Lohi alike.
            const T none = this->lo - this->lo;

            Lohi<T> result;
            if (offset >= 2 * half_bits) {
                result.lo = none;
                result.hi = none;
            } else if (offset == half_bits) {
                result.hi = this->lo;
                result.lo = none;
            } else if (offset > half_bits) {
                // Everything that survives comes from the low half.
                result.hi = this->lo << (offset - half_bits);
                result.lo = none;
            } else {
                // The bits leaving `lo` at the top enter `hi` at the bottom.
                result.hi = (this->hi << offset) | (this->lo >> (half_bits - offset));
                result.lo = this->lo << offset;
            }
            return result;
        }

        // Wrapping addition. A low-half sum smaller than an operand means the
        // low half overflowed, so one is carried into the high half.
        Lohi<T> operator + (thread const Lohi<T> & rhs) const {
            Lohi<T> result;
            result.lo = this->lo + rhs.lo;
            result.hi = this->hi + rhs.hi;
            if (result.lo < this->lo) {
                ++result.hi;
            }
            return result;
        }

        // Wrapping subtraction. A low-half difference larger than the minuend
        // means the low half underflowed, so one is borrowed from the high half.
        Lohi<T> operator - (thread const Lohi<T> & rhs) const {
            Lohi<T> result;
            result.hi = this->hi - rhs.hi;
            result.lo = this->lo - rhs.lo;
            if (result.lo > this->lo) {
                --result.hi;
            }
            return result;
        }

        // Widening multiplication: the result is twice as wide as the
        // operands. Delegates to the free function _lohi_mul.
        Lohi<Lohi<T>> operator * (thread const Lohi<T> & rhs) const {
            return _lohi_mul(*this, rhs);
        }
    };

    // Fixed-width aliases; each level pairs two values of the previous width.
    using uint128_t = Lohi<uint64_t>;
    using uint256_t = Lohi<uint128_t>;
    using uint512_t = Lohi<uint256_t>;

    // Builds a uint256_t from an int.
    //
    // `t` is converted to uint64_t and stored in the lowest 64-bit limb; the
    // three upper limbs are zero. Note that a negative `t` therefore wraps
    // inside the lowest limb only and is not sign-extended across all 256 bits.
    //
    // Declared inline because this is a function definition in a header.
    inline uint256_t u256(int t) {
        return { { uint64_t(t), 0 }, { 0, 0 } };
    }

    // Widening multiplication: multiplies two values of type T and returns
    // the full double-width product as Lohi<T>.
    template<typename T>
    Lohi<T> _lohi_mul(thread const T & lhs, thread const T & rhs);

    // Base case: 64 x 64 -> 128 bit multiplication using 32-bit limbs, so
    // every intermediate product and sum fits in a uint64_t.
    //
    // With lhs = a1*2^32 + a0 and rhs = b1*2^32 + b0 the partial products
    // a0*b0, a1*b0, a0*b1 and a1*b1 are accumulated limb by limb, carrying
    // the upper 32 bits of each step into the next one.
    //
    // Declared inline because a full specialization defined in a header is
    // an ordinary function definition.
    template<>
    inline Lohi<uint64_t> _lohi_mul<uint64_t>(thread const uint64_t & lhs, thread const uint64_t & rhs) {
        const uint64_t mask32 = 0xffffffff;

        uint64_t op1 = lhs;
        uint64_t op2 = rhs;

        // a0 * b0: lowest result limb (w3) and its carry.
        const uint64_t u1 = (op1 & mask32);
        const uint64_t v1 = (op2 & mask32);
        uint64_t t = (u1 * v1);
        const uint64_t w3 = (t & mask32);
        uint64_t k = (t >> 32);

        // a1 * b0 plus the carry from the previous step.
        op1 >>= 32;
        t = (op1 * v1) + k;
        k = (t & mask32);
        const uint64_t w1 = (t >> 32);

        // a0 * b1 plus the low limb kept from the previous step.
        op2 >>= 32;
        t = (u1 * op2) + k;
        k = (t >> 32);

        // Assign by field name so the halves cannot be mixed up with the
        // { lo, hi } member order of Lohi.
        Lohi<uint64_t> product;
        product.lo = (t << 32) + w3;
        product.hi = (op1 * op2) + w1 + k;
        return product;
    }

    // Widening multiplication of two Lohi<T> values; the result is twice as
    // wide as the operands.
    //
    // Treating each operand as two limbs of type T, the product is
    //     ll + (lh + hl) * B + hh * B^2        (B = 2^bits(T))
    // where ll, lh, hl and hh are the four half-by-half products, each
    // itself two limbs wide. The four result limbs are assembled column by
    // column with explicit carry tracking.
    //
    // The four-product schoolbook form is used instead of a three-product
    // Karatsuba split because the half sums (lo + hi) that Karatsuba
    // multiplies need one bit more than T can hold.
    template<typename T>
    Lohi<Lohi<T>> _lohi_mul(thread const Lohi<T> & lhs, thread const Lohi<T> & rhs) {
        const Lohi<T> ll = _lohi_mul(lhs.lo, rhs.lo);
        const Lohi<T> lh = _lohi_mul(lhs.lo, rhs.hi);
        const Lohi<T> hl = _lohi_mul(lhs.hi, rhs.lo);
        const Lohi<T> hh = _lohi_mul(lhs.hi, rhs.hi);

        Lohi<Lohi<T>> result;

        // Limb 0: only the low half of lo*lo contributes.
        result.lo.lo = ll.lo;

        // Limb 1: ll.hi + lh.lo + hl.lo. A wrapped sum is smaller than the
        // operand just added, which is how each carry is detected.
        T carry1 = zero(ll.lo);
        T limb1 = ll.hi + lh.lo;
        if (limb1 < lh.lo) {
            ++carry1;
        }
        limb1 = limb1 + hl.lo;
        if (limb1 < hl.lo) {
            ++carry1;
        }
        result.lo.hi = limb1;

        // Limb 2: hh.lo + lh.hi + hl.hi + carry out of limb 1.
        T carry2 = zero(ll.lo);
        T limb2 = hh.lo + lh.hi;
        if (limb2 < lh.hi) {
            ++carry2;
        }
        limb2 = limb2 + hl.hi;
        if (limb2 < hl.hi) {
            ++carry2;
        }
        limb2 = limb2 + carry1;
        if (limb2 < carry1) {
            ++carry2;
        }
        result.hi.lo = limb2;

        // Limb 3: the full product always fits in four limbs, so this final
        // addition cannot overflow.
        result.hi.hi = hh.hi + carry2;

        return result;
    }
    }

    #endif /* Lohi_h */