This evening, I saw a Tweet about using `numba`, and I thought, it's about time I give it a proper shot. I had been solving some dynamic programming problems just for fun, and I thought this would be a good test case for `numba`'s capabilities.

The DP problem I was trying to solve was that of collecting apples on a grid. Here's how the problem is posed:

I have a number of apples distributed randomly on a grid. I start at the top-left corner, and I'm only allowed to move downwards or to the right, picking up apples as I go. What's the maximum number of apples I can pick up by the time I reach the bottom-right corner?
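To make the problem concrete, here is a minimal sketch of the recurrence on a tiny grid: the best total at a cell is that cell's apples plus the better of the best totals from the cell above and the cell to the left. The names `max_apples` and `grid`, and the 3x3 values, are made up for this illustration and are not the code used for the benchmark in this post.

```python
from functools import lru_cache


def max_apples(grid):
    """Return the largest apple total on any down/right path from the
    top-left cell to the bottom-right cell of `grid` (a list of rows)."""

    @lru_cache(maxsize=None)
    def best(row, col):
        # Best total over all paths that end at (row, col).
        if row == 0 and col == 0:
            return grid[0][0]
        candidates = []
        if row > 0:
            candidates.append(best(row - 1, col))  # arrive from above
        if col > 0:
            candidates.append(best(row, col - 1))  # arrive from the left
        return grid[row][col] + max(candidates)

    return best(len(grid) - 1, len(grid[0]) - 1)


# Hypothetical 3x3 grid, used only for illustration.
grid = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
print(max_apples(grid))  # prints 29 (path 1 -> 4 -> 7 -> 8 -> 9)
```

The memoized recursion is fine for a grid this small; for larger grids, filling a table with two nested loops avoids deep recursion.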

This is a classic 2-dimensional DP problem. I simulated some random integers:

    import numpy as np

    n = 200
    arr = np.random.randint(low=0, high=100, size=n**2).reshape(n, n)

I then wrote out my solution as two otherwise identical functions: one in plain Python and one numba-JIT'd.

    # Let's collect apples.
    import numpy as np
    from numba import jit


    @jit
    def collect_apples(arr):
        # sum_apples[row, col] holds the best total over all paths ending there;
        # the answer to the problem is sum_apples[-1, -1].
        sum_apples = np.zeros(shape=arr.shape)
        for row in range(arr.shape[0]):
            for col in range(arr.shape[1]):
                # Use the running totals of the neighbours, not their raw counts.
                if col != 0:
                    val_left = sum_apples[row, col - 1]
                else:
                    val_left = 0.0
                if row != 0:
                    val_up = sum_apples[row - 1, col]
                else:
                    val_up = 0.0
                sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
        return sum_apples


    def collect_apples_nonjit(arr):
        # Same algorithm as above, without the @jit decorator.
        sum_apples = np.zeros(shape=arr.shape)
        for row in range(arr.shape[0]):
            for col in range(arr.shape[1]):
                if col != 0:
                    val_left = sum_apples[row, col - 1]
                else:
                    val_left = 0.0
                if row != 0:
                    val_up = sum_apples[row - 1, col]
                else:
                    val_up = 0.0
                sum_apples[row, col] = arr[row, col] + max(val_left, val_up)
        return sum_apples

Here are the performance results:

    %%timeit
    collect_apples(arr)

*The slowest run took 4.27 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 99.7 µs per loop*

    %%timeit
    collect_apples_nonjit(arr)

*10 loops, best of 3: 50.3 ms per loop*

Wow! Over 500-fold speedup! All obtained for free using the `@jit` decorator.
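A note on the "slowest run" warning in the JIT'd timing: it most likely comes from the first call, which is when numba compiles the function. Outside of IPython, one way to get a fairer number is to call the function once as a warm-up and then time it with the standard-library `timeit`. The sketch below shows the idea under those assumptions; the helper `best_time` and the stand-in workload `slow_sum` are hypothetical names, and any callable (including a JIT'd one) could be passed in instead.

```python
import timeit


def best_time(fn, *args, repeat=3, number=10):
    """Return the best per-call time in seconds for fn(*args).

    One untimed warm-up call runs first, so any one-time cost
    (such as JIT compilation) is excluded from the measurement.
    """
    fn(*args)  # warm-up call, not timed
    totals = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    return min(totals) / number


def slow_sum(n):
    # Hypothetical stand-in workload: a plain Python loop.
    total = 0
    for i in range(n):
        total += i
    return total


seconds = best_time(slow_sum, 10_000)
print(f"{seconds * 1e6:.1f} microseconds per call")
```

Taking the minimum over several repeats mirrors the "best of 3" figure that `%%timeit` reports above.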