Deep Java Library — benchmarking

I’ve been exploring DJL (Deep Java Library) and wanted to share a few basic benchmarks I captured while testing the framework and its engine-agnostic approach to deep learning.

I really liked the idea of loading already trained models and running them inside the JVM ecosystem. It was even better when I learned that I can load MXNet, TensorFlow, PyTorch and ONNX models.

The benchmark below takes a few very basic measurements I was interested in capturing: how fast multi-threaded inference is on CPU and on GPU, and what extra complexities come with running a high-throughput production inference service on the JVM.
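
For context, the multi-threaded runs follow the usual DJL pattern: load the model once and give each worker thread its own Predictor, since Predictor instances are not thread-safe. Below is a minimal sketch of that pattern; the model path, input image, translator configuration and thread count are placeholders, and the model directory is assumed to contain the traced model plus a synset.txt for the classification output.

    import ai.djl.inference.Predictor;
    import ai.djl.modality.Classifications;
    import ai.djl.modality.cv.Image;
    import ai.djl.modality.cv.ImageFactory;
    import ai.djl.modality.cv.transform.Resize;
    import ai.djl.modality.cv.transform.ToTensor;
    import ai.djl.modality.cv.translator.ImageClassificationTranslator;
    import ai.djl.repository.zoo.Criteria;
    import ai.djl.repository.zoo.ModelZoo;
    import ai.djl.repository.zoo.ZooModel;

    import java.nio.file.Paths;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;

    public class MultiThreadedInference {

        public static void main(String[] args) throws Exception {
            int numThreads = 8; // placeholder; fewer threads than cores worked best for me

            Criteria<Image, Classifications> criteria = Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelPath(Paths.get("models/traced_resnet18")) // placeholder path
                    .optEngine("PyTorch")
                    .optTranslator(ImageClassificationTranslator.builder()
                            .addTransform(new Resize(224, 224))
                            .addTransform(new ToTensor())
                            .build())
                    .build();

            Image img = ImageFactory.getInstance().fromFile(Paths.get("kitten.jpg")); // placeholder input

            try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
                ExecutorService pool = Executors.newFixedThreadPool(numThreads);
                for (int i = 0; i < numThreads; i++) {
                    pool.submit(() -> {
                        // Predictor is not thread-safe: each thread gets its own instance
                        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
                            for (int j = 0; j < 1000; j++) {
                                predictor.predict(img);
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    });
                }
                pool.shutdown();
                pool.awaitTermination(1, TimeUnit.HOURS);
            }
        }
    }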

Setup:

  • AWS EC2 r5.4xlarge was used for CPU inference.
  • AWS EC2 g4dn.4xlarge was used for GPU inference.
  • PyTorch v1.7.1
  • traced_resnet18 model

TL;DR

Observations:

  • More threads != more throughput: choose the number of inference threads wisely. I found that running DJL predictors on fewer threads than the number of available cores on the system gives better throughput (see the sketch after this list).
  • GPU is a great way to speed up single-threaded inference, but not all engines support it. (MXNet GitHub issue, PyTorch issue describing occasional slowdowns)
  • When it comes to DJL, I tested the PyTorch and MXNet engines’ support for multi-threaded GPU inference. While PyTorch works, the MXNet engine crashes with a “CUDA: an illegal memory access was encountered” error.
  • Tensor allocation and garbage collection behave very differently between deep learning engines. With TensorFlow, DJL uses an eager session, which is an expensive operation, and Keras models can be slower still; I’ve seen this first hand and had to convert a model to ONNX to gain a 10x performance improvement.

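To make the thread-count and GPU observations above concrete, here is roughly how they look in code, reusing the imports and types from the sketch above plus ai.djl.Device. The specific core headroom and device choice are assumptions you would want to verify on your own hardware.

    // Size the inference pool below the core count, leaving headroom for the
    // engine's own intra-op threads; the exact ratio is worth measuring.
    int cores = Runtime.getRuntime().availableProcessors();
    ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, cores - 2));

    // For GPU inference, pin the model to the GPU when building the Criteria.
    // This worked for me with the PyTorch engine; MXNet crashed as noted above.
    Criteria<Image, Classifications> gpuCriteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelPath(Paths.get("models/traced_resnet18")) // placeholder path
            .optEngine("PyTorch")
            .optDevice(Device.gpu())
            .build();
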
Run the benchmark from the root of the djl git repo; add -t N where N > 1 for multi-threaded inference.
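
Roughly, the invocation looks like this; the -e, -c and -s flags here are assumptions based on the djl-bench option names, so check the benchmark documentation in the repo for the exact ones.

    # single-threaded PyTorch inference on a synthetic 1x3x224x224 input
    ./gradlew benchmark --args="-e PyTorch -c 1000 -s 1,3,224,224"

    # multi-threaded: add -t N where N > 1
    ./gradlew benchmark --args="-e PyTorch -c 1000 -s 1,3,224,224 -t 8"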

NVIDIA and CUDA details (driver and CUDA versions) for the GPU instance: