By Tom Augspurger
This work is supported by Anaconda, Inc.
This post describes a recent improvement made to TPOT. TPOT is an automated machine learning library for Python. It does some feature engineering and hyper-parameter optimization for you. TPOT uses genetic algorithms to evaluate which models are performing well and to choose new models to try out in the next generation.
In TPOT-730, we made some modifications to TPOT to support distributed training. As a TPOT user, the only changes you need to make to your code are:
- Connect a client to your Dask cluster
- Specify the use_dask=True argument to your TPOT estimator
From there, all the training will use your cluster of machines. This screencast shows an example on an 80-core Dask cluster.
Fitting a TPOT estimator consists of several stages. The bulk of the time is spent evaluating individual scikit-learn pipelines. Dask-ML already had code for splitting apart a scikit-learn Pipeline.fit call into individual tasks. This is used in Dask-ML’s hyper-parameter optimization to avoid repeating work. We were able to drop in Dask-ML’s fit and scoring method in place of the one already used in TPOT. That small change allows the many individual models in a generation to be fit on a cluster.
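To make the "avoid repeating work" idea concrete, here is a toy sketch using dask.delayed rather than Dask-ML's actual implementation. The stage functions (scale, fit_mean, fit_max) are hypothetical stand-ins for pipeline steps. Two "pipelines" share the same first stage, and because both live in one task graph, that stage runs only once.

```python
import dask
from dask import delayed

calls = []  # records every stage that actually executes


@delayed
def scale(data):
    # Hypothetical shared preprocessing stage.
    calls.append("scale")
    top = max(data)
    return [x / top for x in data]


@delayed
def fit_mean(data):
    # Hypothetical final stage of the first pipeline.
    calls.append("fit_mean")
    return sum(data) / len(data)


@delayed
def fit_max(data):
    # Hypothetical final stage of the second pipeline.
    calls.append("fit_max")
    return max(data)


raw = [2.0, 4.0, 8.0]
scaled = scale(raw)          # lazy: nothing has run yet
model_a = fit_mean(scaled)   # pipeline A = scale -> fit_mean
model_b = fit_max(scaled)    # pipeline B = scale -> fit_max

# Computing both results together merges their graphs, so the shared
# `scaled` task is executed a single time. The synchronous scheduler is
# used here only to keep the toy example deterministic.
result_a, result_b = dask.compute(model_a, model_b, scheduler="synchronous")

print(calls.count("scale"))  # 1
```

With real scikit-learn pipelines the shared stages are things like fitted transformers, and the saving grows with the number of candidate pipelines that start with the same steps.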
There’s still some room for improvement. Internal to TPOT, some time is spent determining the next set of models to try out (this is the “mutation and crossover phase”). That’s not (yet) been parallelized with Dask, so you’ll notice some periods of inactivity on the cluster.
This will be available in the next release of TPOT. You can try out a small example now on the dask-examples binder.
Stepping back a bit, I think this is a good example of how libraries can use Dask internally to parallelize workloads for their users. Deep down in TPOT there was a single method for fitting many scikit-learn models on some data and collecting the results. Dask-ML has code for building a task graph that does the same thing. We were able to swap out the eager TPOT code for the lazy Dask version, and get things distributed on a cluster. Projects like xarray have been able to do a similar thing with Dask arrays in place of NumPy arrays. If Dask-ML hadn’t already had that code, dask.delayed could have been used instead.
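As a rough illustration of that last point, the sketch below swaps an eager "fit many models and collect the scores" loop for a lazy one built with dask.delayed. It is not TPOT's or Dask-ML's code. The helper fit_and_score, the candidate models, and the iris dataset are assumptions chosen only to keep the example small.

```python
import dask
from dask import delayed
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier


def fit_and_score(model, X, y):
    # Hypothetical helper: cross-validate one model, return its mean accuracy.
    return cross_val_score(model, X, y, cv=3).mean()


X, y = load_iris(return_X_y=True)

candidates = {
    "logistic": LogisticRegression(max_iter=1000),
    "tree": DecisionTreeClassifier(random_state=0),
    "knn": KNeighborsClassifier(),
}

# Eager version (runs one model after another):
#   scores = {name: fit_and_score(m, X, y) for name, m in candidates.items()}

# Lazy version: build one task per model without running anything yet.
lazy_scores = {
    name: delayed(fit_and_score)(model, X, y)
    for name, model in candidates.items()
}

# Execute all tasks. With a dask.distributed Client active, the same call
# would run the tasks on the cluster instead of local threads.
(scores,) = dask.compute(lazy_scores)

for name, score in scores.items():
    print(f"{name}: {score:.3f}")
```

The calling code keeps its shape: the only change is wrapping the function call in delayed and adding a single compute step at the end.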
If you have a library that you think could take advantage of Dask, please reach out!