Deep Java Library — benchmarking

I’ve been exploring DJL (Deep Java Library) and want to share a few basic benchmarks I captured while testing the framework and its engine-agnostic approach to deep learning.

I really liked the idea of loading already-trained models and running them inside the JVM ecosystem. It was even better to learn that DJL can load models trained in MXNet, TensorFlow and PyTorch, as well as models in ONNX format.

The benchmark below covers a few basic measurements I wanted to capture: multi-threaded inference on CPU and on GPU, how fast it is, and what extra complexity comes with running a high-throughput production inference service on the JVM.


  • AWS EC2 r5.4xlarge was used for CPU inference.
  • AWS EC2 g4dn.4xlarge was used for GPU inference.
  • PyTorch v1.7.1
  • traced_resnet18 model


│ device │ p50        │ p99        │ throughput │ threads │
│ cpu    │ 115.300 ms │ 267.471 ms │   2.09     │  1      │
│ cpu    │ 242.713 ms │ 553.606 ms │  57.37     │ 16      │
│ cpu    │ 504.148 ms │ 983.214 ms │  59.71     │ 32      │
│ gpu    │  55.965 ms │ 360.626 ms │ 120.19     │ 10      │
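The table reports p50 and p99 latency alongside throughput. Below is a minimal Java sketch of how such figures can be computed from raw per-request latencies. It assumes the nearest-rank percentile definition and defines throughput as completed requests per second of wall-clock time; DJL’s benchmark may compute these differently, and the LatencyStats class is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical helper that summarizes per-request latencies from one benchmark run. */
final class LatencyStats {

    /** Nearest-rank percentile: the smallest sample such that at least p% of samples are <= it. */
    static double percentile(double[] latenciesMs, double p) {
        if (latenciesMs.length == 0 || p < 0 || p > 100) {
            throw new IllegalArgumentException("need samples and 0 <= p <= 100");
        }
        double[] sorted = latenciesMs.clone();
        Arrays.sort(sorted);
        int rank = (int) Math.ceil(p / 100.0 * sorted.length); // 1-based rank
        return sorted[Math.max(rank, 1) - 1];
    }

    /** Completed requests per second of wall-clock time; not derived from summed latencies. */
    static double throughput(int completedRequests, double wallClockMs) {
        return completedRequests / (wallClockMs / 1000.0);
    }

    public static void main(String[] args) {
        // Illustrative samples only, not the measurements from the table above.
        double[] latencies = {100, 110, 120, 130, 140, 150, 160, 170, 180, 500};
        System.out.println("p50 = " + percentile(latencies, 50) + " ms");
        System.out.println("p99 = " + percentile(latencies, 99) + " ms");
        // 10 requests that completed within 400 ms of wall-clock time across several threads.
        System.out.println("throughput = " + throughput(latencies.length, 400) + " req/s");
    }
}
```

With these samples, percentile(latencies, 50) returns 140.0, percentile(latencies, 99) returns 500.0 and throughput(10, 400) returns 25.0. Throughput is measured against wall-clock time because, with several threads, requests overlap and per-request latency alone says little about how many complete per second.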


  • More threads do not mean more throughput, so choose the number of inference threads carefully. I found that running DJL predictors on fewer threads than the system has available cores gave better throughput. On the 16-vCPU r5.4xlarge, going from 16 to 32 threads raised throughput only slightly (57.37 to 59.71) while roughly doubling p50 latency (242.713 ms to 504.148 ms).
  • A GPU is a great way to speed up single-threaded inference, but not all engines handle multi-threaded GPU inference reliably (see the MXNet GitHub issue and a PyTorch issue describing an occasional slowdown).
  • In DJL, I tested the PyTorch and MXNet engines’ support for multi-threaded GPU inference. PyTorch works, while the MXNet engine crashes with a "CUDA: an illegal memory access was encountered" error.
  • Tensor allocation and garbage collection behave very differently across deep learning engines. With TensorFlow, DJL uses an eager session, which is expensive, and Keras models can be slower as well. I saw this first-hand and had to convert a model to ONNX to get a 10x performance improvement.
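Rather than guessing a thread count, one option is to sweep it and keep the fastest setting. The sketch below is a hedged, stdlib-only illustration: the Model interface and the sqrt loop are hypothetical stand-ins for a per-thread DJL predictor, each worker gets its own instance from a factory (an assumption that mirrors running one predictor per thread), and the candidate counts double from 1 up to Runtime.availableProcessors(). It skips warm-up and repeated trials, which a real benchmark should include.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Supplier;

/** Hypothetical sweep: measures throughput at several thread counts and keeps the best one. */
final class ThreadSweep {

    /** Hypothetical stand-in for a per-thread predictor; a real run would call DJL here. */
    interface Model {
        double infer(int input);
    }

    /** Spreads `requests` inferences evenly over `threads` workers and returns requests/second. */
    static double measure(int threads, int requests, Supplier<Model> factory) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            int perThread = requests / threads; // remainder dropped so every worker does equal work
            List<Model> models = new ArrayList<>();
            for (int t = 0; t < threads; t++) {
                models.add(factory.get()); // one instance per worker, created before timing starts
            }
            List<Future<Double>> futures = new ArrayList<>();
            long start = System.nanoTime();
            for (Model model : models) {
                futures.add(pool.submit(() -> {
                    double sink = 0; // returned so the work cannot be optimized away
                    for (int i = 0; i < perThread; i++) {
                        sink += model.infer(i);
                    }
                    return sink;
                }));
            }
            for (Future<Double> f : futures) {
                f.get(); // waits for the worker and rethrows its failure, if any
            }
            double seconds = (System.nanoTime() - start) / 1e9;
            return (double) perThread * threads / seconds;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("benchmark interrupted", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("worker failed", e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    /** Tries 1, 2, 4, ... threads up to the core count and returns the fastest candidate. */
    static int bestThreadCount(int requests, Supplier<Model> factory) {
        int cores = Runtime.getRuntime().availableProcessors();
        int best = 1;
        double bestThroughput = -1;
        for (int threads = 1; threads <= cores; threads *= 2) {
            double throughput = measure(threads, requests, factory);
            System.out.printf("threads=%d throughput=%.1f req/s%n", threads, throughput);
            if (throughput > bestThroughput) {
                bestThroughput = throughput;
                best = threads;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Synthetic CPU-bound "model", used purely for illustration.
        Supplier<Model> factory = () -> input -> {
            double acc = input;
            for (int k = 0; k < 200_000; k++) {
                acc = Math.sqrt(acc + k);
            }
            return acc;
        };
        System.out.println("best thread count: " + bestThreadCount(256, factory));
    }
}
```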

To run the benchmark, execute the following command from the root of the djl git repo. For multi-threaded inference, add -t N (with N > 1) to the benchmark arguments.

./gradlew benchmark -Dai.djl.default_engine=PyTorch -Dai.djl.repository.zoo.location= --args='-c 10 -s 1,3,224,224'
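The -s 1,3,224,224 value looks like an input tensor shape in (batch, channels, height, width) order, the usual layout for a 224x224 RGB image fed to ResNet-18; that reading of the flag is an assumption here. A small sketch that parses such a shape string and works out the size of one input, assuming float32 elements (the InputShape class is hypothetical):

```java
import java.util.Arrays;

/** Hypothetical helper that parses a comma-separated shape such as "1,3,224,224". */
final class InputShape {

    static long[] parse(String spec) {
        return Arrays.stream(spec.split(","))
                .map(String::trim)
                .mapToLong(Long::parseLong)
                .toArray();
    }

    /** The element count is the product of all dimensions (overflow-checked). */
    static long elementCount(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n = Math.multiplyExact(n, d);
        }
        return n;
    }

    public static void main(String[] args) {
        long[] shape = parse("1,3,224,224");
        long elements = elementCount(shape);
        // Assumes float32 elements, which take 4 bytes each.
        System.out.println(Arrays.toString(shape) + " -> " + elements
                + " elements, " + elements * Float.BYTES + " bytes");
    }
}
```

For 1,3,224,224 this gives 3 x 224 x 224 = 150,528 elements, or 602,112 bytes (588 KiB) per float32 input.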

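NVIDIA and CUDA details (nvidia-smi reports the highest CUDA version the driver supports, while nvcc reports the version of the installed CUDA toolkit):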

NVIDIA-SMI 460.32.03
Driver Version: 460.32.03
CUDA Version: 11.2
cuDNN: 7.6.5
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89

Software Engineer at Netflix focusing on AI. Co-founder of Vortle