# Extending Numba Types for Clean, Fast Code

Updated: Dec 11, 2020

The scientific Python ecosystem relies on compiled code for fast execution. While it is still common to see C++ and C code contributed to the internals of Python libraries, much of the new compiled code written today uses a Python library like Cython or Numba to speed up Python code. Both of these libraries center on working with NumPy arrays, although both, with a little work, can be extended to work with custom data structures. Since there are already a number of blog posts comparing Cython and Numba, I'd like to draw attention to the extension capabilities of Numba and how they can be used to write clean, readable code. Numba's extensibility has been beautifully demonstrated by the Awkward Array library, which copies the NumPy array interface and can be used inside of jitted functions.

If we teach Numba how to use a custom data structure, then instead of passing multiple related arguments we can pass the whole data structure and use its functionality inside of a Numba-jitted function. This is particularly useful when we want access to a data structure in multiple algorithms. A case came up for me when I was writing a sparse matrix multiplication kernel between a sparse matrix `a` and a dense NumPy array `b`. Here's a simplified version of the algorithm in plain Python. We iterate through all of the elements of `b` and, for each element `b[i, j]`, through the nonzeros in column `i` of `a`.

    def csc_ndarray_dot(a: csc_matrix, b: np.ndarray):
        out = np.zeros((a.shape[0], b.shape[1]))
        for j in range(b.shape[1]):
            for i in range(b.shape[0]):
                for k in range(a.indptr[i], a.indptr[i + 1]):
                    out[a.indices[k], j] += a.data[k] * b[i, j]
        return out
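
To make the loop structure concrete, here is a small worked example that uses plain Python lists instead of NumPy arrays, so it runs with the standard library alone. The 3×3 matrix and the helper names `csc_dot_lists` and `dense_dot` are illustrative assumptions, not part of the original kernel.

```python
def csc_dot_lists(shape, data, indices, indptr, b):
    """Multiply a CSC matrix (given by its parts) by a dense list-of-lists b."""
    n_rows, n_cols_b = shape[0], len(b[0])
    out = [[0.0] * n_cols_b for _ in range(n_rows)]
    for j in range(n_cols_b):        # columns of b
        for i in range(len(b)):      # rows of b, i.e. columns of a
            # The nonzeros of column i of a sit at positions
            # indptr[i] .. indptr[i + 1] - 1 of data and indices.
            for k in range(indptr[i], indptr[i + 1]):
                out[indices[k]][j] += data[k] * b[i][j]
    return out


def dense_dot(a, b):
    """Reference dense product, used only to check the sparse result."""
    return [[sum(a[r][i] * b[i][c] for i in range(len(b)))
             for c in range(len(b[0]))]
            for r in range(len(a))]


# Hypothetical example matrix and its hand-written CSC representation.
a_dense = [[1, 0, 2],
           [0, 0, 3],
           [4, 5, 0]]
data = [1, 4, 5, 2, 3]       # nonzero values, column by column
indices = [0, 2, 2, 0, 1]    # row index of each nonzero value
indptr = [0, 2, 3, 5]        # start of each column in data/indices
b = [[1, 2], [3, 4], [5, 6]]

result = csc_dot_lists((3, 3), data, indices, indptr, b)
print(result)  # [[11.0, 14.0], [15.0, 18.0], [19.0, 28.0]]
```

The sparse result agrees with `dense_dot(a_dense, b)`, while the inner loop only ever touches the five stored values.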

With all of the nested loops, this is code in desperate need of some acceleration. However, because we can't use sparse matrices inside of Numba or Cython, we have to modify the function. In the case of Numba, we have to separate out the attributes of the sparse matrix.

    @numba.jit(nopython=True)
    def numba_csc_ndarray_dot(a_shape, a_data, a_indices, a_indptr, b):
        ...

In the case of Cython, we have to separate out the attributes and add typing information. In both cases, the function gets messier.

    %%cython
    import cython
    import numpy as np
    cimport numpy as np
    @cython.boundscheck(False)
    @cython.wraparound(False)
    def cython_csc_ndarray_dot(tuple a_shape,
                               np.ndarray[double, ndim=1] a_data,
                               np.ndarray[long, ndim=1] a_indices,
                               np.ndarray[long, ndim=1] a_indptr,
                               np.ndarray[double, ndim=2] b):
        cdef np.ndarray[double, ndim=2] out = np.zeros((a_shape[0],
                                                        b.shape[1]))
        cdef int i, j, k
        for j in range(b.shape[1]):
            for i in range(b.shape[0]):
                for k in range(a_indptr[i], a_indptr[i + 1]):
                    out[a_indices[k], j] += a_data[k] * b[i, j]
        return out

Fortunately, with a little work, we can teach Numba how to use sparse matrices internally. First, we'll define a pure Python class for compressed sparse column matrices. To do this, we need four attributes:

- `data`, a NumPy array that stores the nonzero values of the matrix;
- `indices`, a NumPy array that stores the row index of each nonzero value;
- `indptr`, a NumPy array that stores pointers to the beginning of each column; and
- `shape`, a tuple that stores the dimensions of the matrix.
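
To see how these four attributes fit together, the following sketch builds them from a small dense matrix using only plain Python lists. The function name `dense_to_csc` and the example matrix are assumptions made for illustration; a real application would more likely obtain these arrays from an existing sparse matrix library.

```python
def dense_to_csc(dense):
    """Return (data, indices, indptr, shape) for a dense list-of-lists."""
    n_rows, n_cols = len(dense), len(dense[0])
    data, indices, indptr = [], [], [0]
    for col in range(n_cols):
        for row in range(n_rows):
            value = dense[row][col]
            if value != 0:
                data.append(value)    # the nonzero value itself
                indices.append(row)   # the row it lives in
        # After each column, record where the next column will start.
        indptr.append(len(data))
    return data, indices, indptr, (n_rows, n_cols)


# Hypothetical example with an empty second column.
dense = [[1, 0, 0, 2],
         [0, 0, 3, 0],
         [4, 0, 5, 6]]
data, indices, indptr, shape = dense_to_csc(dense)

# Column 2 occupies the slice indptr[2]:indptr[3] of data and indices.
col2_rows = indices[indptr[2]:indptr[3]]
col2_values = data[indptr[2]:indptr[3]]
print(data, indices, indptr, shape)
```

Note that `indptr` has one more entry than there are columns, and an empty column shows up as two equal consecutive entries.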

    class csc_matrix:
        def __init__(self, data: np.ndarray,
                     indices: np.ndarray,
                     indptr: np.ndarray,
                     shape: tuple):
            self.data = data
            self.indices = indices
            self.indptr = indptr
            self.shape = shape

Since Numba doesn't deal with native Python types directly, we need to specify what the `csc_matrix` class looks like in Numba's types. To teach Numba to recognize the `csc_matrix` class, we'll define a new class that extends Numba's `Type` class. Here we specify the type of each attribute.

    from numba.core import types
    class MatrixType(types.Type):
        def __init__(self, dtype):
            self.dtype = dtype
            self.data = types.Array(self.dtype, 1, 'C')
            self.indices = types.Array(types.int64, 1, 'C')
            self.indptr = types.Array(types.int64, 1, 'C')
            self.shape = types.UniTuple(types.int64, 2)
            super(MatrixType, self).__init__('csc_matrix')

For Numba to know that the `csc_matrix` should be typed as a `MatrixType`, we need to register that relationship:

    from numba.extending import typeof_impl
    @typeof_impl.register(csc_matrix)
    def typeof_matrix(val, c):
        data = typeof_impl(val.data, c)
        return MatrixType(data.dtype)

The types that are used in nopython mode use data models (Numba-specific representations of the class). In this case, we'll extend the `StructModel`, which is similar to a struct in C.

    from numba.extending import models, register_model
    @register_model(MatrixType)
    class MatrixModel(models.StructModel):
        def __init__(self, dmm, fe_type):
            members = [
                ('data', fe_type.data),
                ('indices', fe_type.indices),
                ('indptr', fe_type.indptr),
                ('shape', fe_type.shape)
            ]
            models.StructModel.__init__(self, dmm, fe_type, members)

To access these attributes under the same names inside jitted functions, we register each one with a Numba convenience function, `make_attribute_wrapper`.

    from numba.extending import make_attribute_wrapper
    make_attribute_wrapper(MatrixType, 'data', 'data')
    make_attribute_wrapper(MatrixType, 'indices', 'indices')
    make_attribute_wrapper(MatrixType, 'indptr', 'indptr')
    make_attribute_wrapper(MatrixType, 'shape', 'shape')

Almost there! All that's left to do is teach Numba how to turn a native (Numba) matrix into a Python matrix and vice versa. These conversions are called boxing and unboxing, respectively.

    from numba.core import cgutils
    from numba.extending import unbox

    def make_matrix(context, builder, typ, **kwargs):
        return cgutils.create_struct_proxy(typ)(context,
                                                builder, **kwargs)
    @unbox(MatrixType)
    def unbox_matrix(typ, obj, c):
        data = c.pyapi.object_getattr_string(obj, "data")
        indices = c.pyapi.object_getattr_string(obj, "indices")
        indptr = c.pyapi.object_getattr_string(obj, "indptr")
        shape = c.pyapi.object_getattr_string(obj, "shape")
        matrix = make_matrix(c.context, c.builder, typ)
        matrix.data = c.unbox(typ.data, data).value
        matrix.indices = c.unbox(typ.indices, indices).value
        matrix.indptr = c.unbox(typ.indptr, indptr).value
        matrix.shape = c.unbox(typ.shape, shape).value
        for att in [data, indices, indptr, shape]:
            c.pyapi.decref(att)
        is_error = cgutils.is_not_null(
            c.builder, c.pyapi.err_occurred()