Numba: My first attempt at being serious with it

written by Eric J. Ma on 2017-02-08

This evening, I saw a Tweet about using numba, and I thought, it's about time I give it a proper shot. I had been solving some dynamic programming problems just for fun, and I thought this would be a good test case for numba's capabilities.

The DP problem I was trying to solve was that of collecting apples on a grid. Here's how the problem is posed:

I have a number of apples distributed randomly on a grid. I start at the top-left corner, and I'm only allowed to move downwards or to the right. Along the way, I pick up apples. What's the maximum number of apples I can pick up along the way?
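One way to make the problem concrete is the recurrence best[r][c] = grid[r][c] + max(best[r-1][c], best[r][c-1]), where a neighbour that falls outside the grid counts as zero. The sketch below applies it to a tiny hand-made grid. The function name `max_apples` and the grid values are illustrative assumptions, not something from the original problem statement.

```python
def max_apples(grid):
    """Return the most apples collectable when moving only down or right."""
    n_rows, n_cols = len(grid), len(grid[0])
    # best[r][c] is the best running total on arrival at cell (r, c).
    best = [[0] * n_cols for _ in range(n_rows)]
    for r in range(n_rows):
        for c in range(n_cols):
            from_up = best[r - 1][c] if r > 0 else 0
            from_left = best[r][c - 1] if c > 0 else 0
            best[r][c] = grid[r][c] + max(from_up, from_left)
    return best[-1][-1]


grid = [
    [3, 1, 0],
    [2, 9, 4],
    [0, 1, 5],
]
print(max_apples(grid))  # 23, via the path 3 -> 2 -> 9 -> 4 -> 5
```

Each cell is visited once, so the work grows with the number of cells rather than with the number of possible paths.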

This is a classic 2-dimensional DP problem. I simulated some random integers:

import numpy as np

n = 200
arr = np.random.randint(low=0, high=100, size=n**2).reshape(n, n)

I then wrote out my solution as two versions of the same function: one plain Python and one numba-JIT'd.

# Let's collect apples.

from numba import jit

@jit
def collect_apples(arr):
    # sum_apples[row, col] holds the best running total on reaching that cell.
    sum_apples = np.zeros(shape=arr.shape)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            # Use the neighbours' running totals, not their raw apple counts.
            if col != 0:
                val_left = sum_apples[row, col - 1]
            else:
                val_left = 0
            if row != 0:
                val_up = sum_apples[row - 1, col]
            else:
                val_up = 0
            sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
    # The answer to the problem is the bottom-right entry, sum_apples[-1, -1].
    return sum_apples

def collect_apples_nonjit(arr):
    # Same algorithm as above, without the @jit decorator.
    sum_apples = np.zeros(shape=arr.shape)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            if col != 0:
                val_left = sum_apples[row, col - 1]
            else:
                val_left = 0
            if row != 0:
                val_up = sum_apples[row - 1, col]
            else:
                val_up = 0
            sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
    return sum_apples
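Before timing anything, it can be worth checking that a DP like this gives the right answer. A minimal sketch, assuming a grid small enough to enumerate every down/right path: compare the DP result with brute force. The names `dp_max_apples` and `brute_force_max_apples` are hypothetical helpers written for this check.

```python
from itertools import combinations

import numpy as np


def dp_max_apples(arr):
    """DP answer: best running total at the bottom-right cell."""
    best = np.zeros(arr.shape, dtype=np.int64)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            up = best[row - 1, col] if row > 0 else 0
            left = best[row, col - 1] if col > 0 else 0
            best[row, col] = arr[row, col] + max(up, left)
    return int(best[-1, -1])


def brute_force_max_apples(arr):
    """Try every down/right path; only feasible for tiny grids."""
    n_rows, n_cols = arr.shape
    n_moves = (n_rows - 1) + (n_cols - 1)
    best = 0
    # Choose which of the moves are "down"; the rest are "right".
    for down_steps in combinations(range(n_moves), n_rows - 1):
        down_steps = set(down_steps)
        row = col = 0
        total = int(arr[0, 0])
        for step in range(n_moves):
            if step in down_steps:
                row += 1
            else:
                col += 1
            total += int(arr[row, col])
        best = max(best, total)
    return best


rng = np.random.default_rng(0)  # fixed seed so the check is repeatable
small = rng.integers(low=0, high=100, size=(5, 5))
assert dp_max_apples(small) == brute_force_max_apples(small)
```

A 5x5 grid has only 70 such paths, so the brute-force side stays cheap. The number of paths grows very quickly with grid size, which is why the DP is the version worth speeding up.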

Here are the performance results:

%%timeit
collect_apples(arr)

The slowest run took 4.27 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 99.7 µs per loop

%%timeit
collect_apples_nonjit(arr)

10 loops, best of 3: 50.3 ms per loop

Wow! Over a 500-fold speedup (50.3 ms versus 99.7 µs per loop), all obtained for free using the @jit decorator.
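A comparison like this can also be run outside a notebook. The sketch below makes two choices that are mine, not the original experiment's. It wraps a single function with `njit` instead of keeping two copies of the body. It also makes one warm-up call first, since numba compiles on the first call and that cost would otherwise land inside the timed region. The name `collect_total`, the 50x50 grid, and the fallback when numba is not installed are all assumptions for illustration.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: keep the sketch runnable without numba
    def njit(func):
        return func


def collect_total(arr):
    """Plain-Python DP; returns the best total at the bottom-right cell."""
    best = np.zeros(arr.shape, dtype=np.int64)
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            up = best[row - 1, col] if row > 0 else 0
            left = best[row, col - 1] if col > 0 else 0
            best[row, col] = arr[row, col] + max(up, left)
    return best[-1, -1]


# Wrap the same function instead of copying its body.
collect_total_jit = njit(collect_total)

arr = np.random.randint(low=0, high=100, size=(50, 50))

# The first call triggers compilation, so keep it out of the timed region.
warm_up = collect_total_jit(arr)
assert warm_up == collect_total(arr)

n_calls = 5
t_plain = timeit.timeit(lambda: collect_total(arr), number=n_calls)
t_jit = timeit.timeit(lambda: collect_total_jit(arr), number=n_calls)
print(f"plain: {t_plain / n_calls:.6f} s per call")
print(f"jit:   {t_jit / n_calls:.6f} s per call")
```

The actual numbers will depend on the machine, the grid size, and the numba version, so the ratio printed here should not be expected to match the figures above.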