Skip to content

Instantly share code, notes, and snippets.

@dbalabka
Last active April 22, 2021 11:14
Show Gist options
  • Select an option

  • Save dbalabka/9c0a3d88312f6d51af8949e85758aa7a to your computer and use it in GitHub Desktop.

Select an option

Save dbalabka/9c0a3d88312f6d51af8949e85758aa7a to your computer and use it in GitHub Desktop.

Revisions

  1. dbalabka revised this gist Apr 22, 2021. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion bootstrap_numba.py
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,4 @@
    import ecentria.abtest.bootstrap as bootstrap
    import scikits.bootstrap as bootstrap
    import numpy as np
    import time
    import numba
  2. dbalabka revised this gist Apr 22, 2021. No changes.
  3. dbalabka created this gist Apr 22, 2021.
    48 changes: 48 additions & 0 deletions bootstrap_numba.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,48 @@
    import ecentria.abtest.bootstrap as bootstrap
    import numpy as np
    import time
    import numba

    @numba.njit(parallel=True, fastmath=True)
    def _calculate_boostrap_mean_stat(data: np.ndarray, n_samples: int) -> np.ndarray:
    n = data.shape[0]
    stat = np.zeros(n_samples)
    for i in numba.prange(n_samples):
    stat[i] = np.random.choice(data, n).mean()
    return stat

    tdata = (np.random.randint(0, 5, 100_000), )
    n_samples = 10_000

    start = time.time()
    bootindices = bootstrap.bootstrap_indices(tdata[0], n_samples)
    stat_old = np.array([np.mean(*(x[indices] for x in tdata))
    for indices in bootindices])
    end = time.time()
    print(end - start)


    start = time.time()
    rng = np.random.default_rng()
    stat_new1 = np.array([np.mean(*(rng.choice(x, tdata[0].shape[0]) for x in tdata)) for _ in range(n_samples)])
    end = time.time()
    print(end - start)

    start = time.time()
    stat_new2 = _calculate_boostrap_mean_stat(tdata[0], n_samples)
    end = time.time()
    print(end - start)

    print(f'{stat_old.shape} == {stat_new1.shape}')
    print(f'{stat_old.shape} == {stat_new2.shape}')
    print(f'{stat_old.mean()} == {stat_new1.mean()}')
    print(f'{stat_old.mean()} == {stat_new2.mean()}')

    assert stat_old.shape == stat_new1.shape
    assert stat_old.shape == stat_new2.shape
    assert round(stat_old.mean(), 3) == round(stat_new1.mean(), 3)
    assert round(stat_old.mean(), 3) == round(stat_new2.mean(), 3)

    # Numba debug
    # bootstrap._calculate_boostrap_mean_stat.parallel_diagnostics(level=4)
    # bootstrap._calculate_boostrap_mean_stat.inspect_types()