# Extending Numba Types for Clean, Fast Code

Updated: Dec 11, 2020

The scientific Python ecosystem relies on compiled code for fast execution. While it is still common to see C++ and C code contributed to the internals of Python libraries, much of the new compiled code written today uses a Python library like Cython or Numba to speed up Python code. Both of these libraries center around working with NumPy arrays, although both, with a little work, can be extended to work with custom data structures. As there are a number of blog posts comparing Cython and Numba, I'd like to draw attention to the extension capabilities of Numba and how it can be used to make clean, readable code. Numba's extensibility has been beautifully demonstrated by the Awkward Array library, which copies the NumPy array interface and can be used inside of jitted functions.

If we teach Numba how to use a custom data structure, instead of passing multiple related arguments, we can pass the whole data structure and use its functionality inside of a Numba-jitted function. This is particularly useful when we want to have access to a data structure in multiple algorithms. A case came up for me when I was writing a sparse matrix multiplication kernel between a sparse matrix a and a dense NumPy array b. Here's a simplified version of the algorithm in plain Python. For every element b[i, j], we iterate through the nonzeros in column i of a.

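```python
import numpy as np

def csc_ndarray_dot(a: csc_matrix, b: np.ndarray):
    # a is an (m, n) matrix in CSC format, b is a dense (n, p) array
    out = np.zeros((a.shape[0], b.shape[1]))
    for j in range(b.shape[1]):  # columns of b
        for i in range(b.shape[0]):  # rows of b, i.e. columns of a
            # nonzeros of column i of a sit at indptr[i]:indptr[i + 1]
            for k in range(a.indptr[i], a.indptr[i + 1]):
                out[a.indices[k], j] += a.data[k] * b[i, j]
    return out
```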
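To see the loop at work, here is a minimal sketch that runs the same triple loop on a tiny 2×3 matrix. It is not the function above verbatim: plain Python lists stand in for the NumPy arrays so that the snippet has no dependencies, and the name `csc_dot_lists` and its unpacked arguments are made up for illustration. The nonzeros of column i sit at positions indptr[i] up to (but not including) indptr[i + 1], and indices[k] is the row of the k-th nonzero.

```python
def csc_dot_lists(shape, data, indices, indptr, b):
    """Multiply a CSC matrix, given by its parts, with a dense list-of-lists b."""
    n_rows, n_cols = shape
    n_out_cols = len(b[0])
    out = [[0.0] * n_out_cols for _ in range(n_rows)]
    for j in range(n_out_cols):          # columns of b
        for i in range(n_cols):          # columns of a (rows of b)
            for k in range(indptr[i], indptr[i + 1]):
                out[indices[k]][j] += data[k] * b[i][j]
    return out

# Hypothetical sample input:
# a = [[1, 0, 2],
#      [0, 3, 0]]
shape = (2, 3)
data = [1.0, 3.0, 2.0]      # nonzero values, column by column
indices = [0, 1, 0]         # row of each nonzero
indptr = [0, 1, 2, 3]       # start of each column in data/indices
b = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]

result = csc_dot_lists(shape, data, indices, indptr, b)
print(result)  # [[11.0, 14.0], [9.0, 12.0]]
```

The result matches the dense product of a and b computed by hand, which is a cheap check to keep around when the kernel is later rewritten for a compiler.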

With all of the nested loops, this is code in desperate need of some acceleration. However, because we can't use sparse matrices inside of Numba or Cython, we have to modify the function. In the case of Numba, we have to separate out the attributes of the sparse matrix.

```python
import numba

@numba.jit(nopython=True)
def numba_csc_ndarray_dot(a_shape, a_data, a_indices, a_indptr, b):
    ...
```

In the case of Cython, we have to separate out the attributes and add typing information. In both cases, the function gets messier.

```cython
%%cython
import cython
import numpy as np
cimport numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def cython_csc_ndarray_dot(tuple a_shape,
                           np.ndarray[double, ndim=1] a_data,
                           np.ndarray[long, ndim=1] a_indices,
                           np.ndarray[long, ndim=1] a_indptr,
                           np.ndarray[double, ndim=2] b):
    cdef np.ndarray[double, ndim=2] out = np.zeros((a_shape[0],
                                                    b.shape[1]))
    cdef int i, j, k
    for j in range(b.shape[1]):
        for i in range(b.shape[0]):
            for k in range(a_indptr[i], a_indptr[i + 1]):
                out[a_indices[k], j] += a_data[k] * b[i, j]
    return out
```

Fortunately, with a little work, we can teach Numba how to use sparse matrices internally. First, we'll define a pure Python class for compressed sparse column matrices. To do this, we need four attributes:

1. data, a NumPy array that stores the nonzero values of the matrix;

2. indices, a NumPy array that stores the row index of each nonzero value;

3. indptr, a NumPy array that stores pointers to the beginning of each column; and

4. shape, a tuple that stores the dimensions of the matrix.

```python
import numpy as np

class csc_matrix:
    def __init__(self, data: np.ndarray,
                 indices: np.ndarray,
                 indptr: np.ndarray,
                 shape: tuple):
        self.data = data
        self.indices = indices
        self.indptr = indptr
        self.shape = shape
```
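To make the four attributes concrete, the sketch below builds them for a small matrix by scanning it column by column. The helper `dense_to_csc_parts` is hypothetical (it is not part of NumPy, SciPy, or Numba), and it returns plain lists to stay dependency-free; in practice each list would be wrapped in a NumPy array before being handed to the class.

```python
def dense_to_csc_parts(dense):
    """Return (data, indices, indptr, shape) for a dense list-of-lists matrix."""
    n_rows, n_cols = len(dense), len(dense[0])
    data, indices, indptr = [], [], [0]
    for col in range(n_cols):
        for row in range(n_rows):
            value = dense[row][col]
            if value != 0:
                data.append(value)    # the nonzero value itself
                indices.append(row)   # the row it lives in
        # everything appended so far belongs to columns 0..col
        indptr.append(len(data))
    return data, indices, indptr, (n_rows, n_cols)

# Made-up sample matrix with five nonzeros
dense = [[1.0, 0.0, 2.0],
         [0.0, 3.0, 0.0],
         [4.0, 0.0, 5.0]]

data, indices, indptr, shape = dense_to_csc_parts(dense)
print(data)     # [1.0, 4.0, 3.0, 2.0, 5.0]
print(indices)  # [0, 2, 1, 0, 2]
print(indptr)   # [0, 2, 3, 5]
```

Note that indptr has one more entry than the matrix has columns, so indptr[i + 1] is always valid for the last column and its final entry equals the number of nonzeros.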

Since Numba doesn't deal with native Python types directly, we need to specify what the csc_matrix class looks like in Numba's types. To teach Numba to recognize the csc_matrix class, we'll define a new class that extends Numba's Type class. Here we specify the types of each of the attributes.

```python
from numba.core import types

class MatrixType(types.Type):
    def __init__(self, dtype):
        self.dtype = dtype
        self.data = types.Array(self.dtype, 1, 'C')
        self.indices = types.Array(types.int64, 1, 'C')
        self.indptr = types.Array(types.int64, 1, 'C')
        self.shape = types.UniTuple(types.int64, 2)
        super(MatrixType, self).__init__('csc_matrix')
```

For Numba to know that the csc_matrix should be typed as a MatrixType, we need to register that relationship:

```python
from numba.extending import typeof_impl

@typeof_impl.register(csc_matrix)
def typeof_matrix(val, c):
    data = typeof_impl(val.data, c)
    return MatrixType(data.dtype)
```

The types that are used in nopython mode rely on data models (Numba-specific representations of the class). In this case, we'll extend the StructModel, which is similar to a struct in C.

```python
from numba.extending import models, register_model

@register_model(MatrixType)
class MatrixModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('data', fe_type.data),
            ('indices', fe_type.indices),
            ('indptr', fe_type.indptr),
            ('shape', fe_type.shape)
        ]
        models.StructModel.__init__(self, dmm, fe_type, members)
```

To access these attributes under the same names inside jitted functions, we expose each struct member with a Numba convenience function, make_attribute_wrapper.

```python
from numba.extending import make_attribute_wrapper

make_attribute_wrapper(MatrixType, 'data', 'data')
make_attribute_wrapper(MatrixType, 'indices', 'indices')
make_attribute_wrapper(MatrixType, 'indptr', 'indptr')
make_attribute_wrapper(MatrixType, 'shape', 'shape')
```

Almost there! All that's left to do is teach Numba how to make a native (Numba) matrix into a Python matrix and vice versa. This is called boxing and unboxing.

```python
from numba.core import cgutils
from numba.extending import unbox

def make_matrix(context, builder, typ, **kwargs):
    return cgutils.create_struct_proxy(typ)(context,
                                            builder, **kwargs)

@unbox(MatrixType)
def unbox_matrix(typ, obj, c):
    # Fetch the Python attributes (each call returns a new reference)
    data = c.pyapi.object_getattr_string(obj, "data")
    indices = c.pyapi.object_getattr_string(obj, "indices")
    indptr = c.pyapi.object_getattr_string(obj, "indptr")
    shape = c.pyapi.object_getattr_string(obj, "shape")
    # Fill the native struct with the unboxed members
    matrix = make_matrix(c.context, c.builder, typ)
    matrix.data = c.unbox(typ.data, data).value
    matrix.indices = c.unbox(typ.indices, indices).value
    matrix.indptr = c.unbox(typ.indptr, indptr).value
    matrix.shape = c.unbox(typ.shape, shape).value
    for att in [data, indices, indptr, shape]:
        c.pyapi.decref(att)
    is_error = cgutils.is_not_null(
        c.builder, c.pyapi.err_occurred()