Skip to content

Instantly share code, notes, and snippets.

@jan-glx
Created August 26, 2015 21:02
Show Gist options
  • Select an option

  • Save jan-glx/e39f2b0ef23b10ee7e13 to your computer and use it in GitHub Desktop.

Select an option

Save jan-glx/e39f2b0ef23b10ee7e13 to your computer and use it in GitHub Desktop.

Revisions

  1. jan-glx created this gist Aug 26, 2015.
    71 changes: 71 additions & 0 deletions speed_dist2.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,71 @@
    __author__ = 'jan'


    import numpy as np
    import scipy.weave

    def grad_dist2(ls, x1, x2=None):
    if x2 is None:
    x2 = x1

    # Rescale.
    x1 = x1 / ls
    x2 = x2 / ls

    N = x1.shape[0]
    M = x2.shape[0]
    D = x1.shape[1]
    gX = np.zeros((x1.shape[0],x2.shape[0],x1.shape[1]))

    code = \
    """
    for (int i=0; i<N; i++)
    for (int j=0; j<M; j++)
    for (int d=0; d<D; d++)
    gX(i,j,d) = (2/ls(d))*(x1(i,d) - x2(j,d));
    """
    try:
    scipy.weave.inline(code, ['x1','x2','gX','ls','M','N','D'], \
    type_converters=scipy.weave.converters.blitz, \
    compiler='gcc')
    except:
    # The C code weave above is 10x faster than this:
    for i in xrange(0,x1.shape[0]):
    gX[i,:,:] = 2*(x1[i,:] - x2[:,:])*(1/ls)

    return gX

    def grad_dist3(ls, x1, x2=None):
    if x2 is None:
    x2 = x1

    # Rescale.
    x1 = x1 / ls
    x2 = x2 / ls

    N = x1.shape[0]
    M = x2.shape[0]
    D = x1.shape[1]
    gX = np.zeros((x1.shape[0],x2.shape[0],x1.shape[1]))


    # The C code weave above is 10x faster than this:
    for i in xrange(0,x1.shape[0]):
    gX[i,:,:] = 2*(x1[i,:] - x2[:,:])*(1/ls)

    return gX

    x1=np.random.randn(400,300)
    x2=np.random.randn(500,300)
    ls=3.0

    gX=grad_dist2(ls, x1, x2)

    gX2=((x1*2/ls**2)[:,None,:]-(x2*2/ls**2)[None,:,:])


    import timeit
    print(timeit.timeit("gX=grad_dist2(ls, x1, x2)","from __main__ import *",number=10))
    print(timeit.timeit("gX=grad_dist3(ls, x1, x2)","from __main__ import *",number=10))
    print(timeit.timeit("gX2=((x1*2/ls**2)[:,None,:]-(x2*2/ls**2)[None,:,:])","from __main__ import *",number=10))
    print(np.allclose(gX,gX2))