Beating TensorFlow Training in-VRAM
In this post, I’d like to introduce a technique that I’ve found helps accelerate mini-batch SGD training in my use case. I suppose this post could also be read as a public grievance directed at the TensorFlow Dataset API for optimizing for the large vision deep learning use case, but maybe I’m just not hitting the right incantation to get tf.Dataset working (in which case, drop me a line). The solution is to TensorFlow harder anyway, so this shouldn’t really be read as a complaint.
Nonetheless, if you are working with a new-ish GPU that has enough memory to hold a decent portion of your data alongside your neural network, you may find the final training approach I present here useful. The experiments I’ve run fall exactly in line with this “in-VRAM” use case (in particular, I’m training deep reinforcement learning value and policy networks on semi-toy environments, whose training profile is many iterations of training on a small replay buffer of examples). For some more context, you may want to check out an article on the TensorForce blog, which suggests that RL people should be building more of their TF graphs like this.
Briefly, if you have a dataset that fits into a GPU’s memory, you’re giving away a lot of speed with the usual TensorFlow pipelining or data-feeding approach, where the CPU delivers mini-batches whose forward/backward passes are computed on the GPU. This gets worse as you move to pricier GPUs, where the ratio of CPU-GPU bandwidth to GPU compute speed drops. It’s a pretty easy change for a 2x speedup.
Let’s get to it. With numbers similar to my use case, 5 epochs of training take about 16 seconds with the standard feed_dict approach, 12-20 seconds with the TensorFlow Dataset API, and 8 seconds with a custom TensorFlow control-flow construct.
This was tested on an Nvidia Tesla P100 with a compiled TensorFlow 1.4.1 (CUDA 9, cuDNN 7), Python 3.5. Here is the test script. I didn’t test it too many times (exec trace). Feel free to change the data sizes to see if the proposed approach would still help in your setting.
Let’s fix the toy supervised benchmark task we’re looking at. Its full definition is in the test script linked above.
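The exact task is defined in the linked test script and isn’t reproduced here. As a purely hypothetical stand-in, assume a noisy linear regression. The function name make_toy_regression and the sizes (n, d, noise) are my own assumptions, kept small so the sketch runs quickly. The loss numbers quoted in this post won’t carry over to it.

```python
import numpy as np

def make_toy_regression(n=50_000, d=16, noise=0.1, seed=0):
    """Hypothetical stand-in for the benchmark task: y = X @ w_true + Gaussian noise."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d)).astype(np.float32)
    w_true = rng.standard_normal((d, 1)).astype(np.float32)
    eps = noise * rng.standard_normal((n, 1)).astype(np.float32)
    y = X @ w_true + eps
    return X, y, w_true

X, y, w_true = make_toy_regression()
print(X.shape, y.shape)  # (50000, 16) (50000, 1)
```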
This is the feed_dict approach (discouraged by the docs) that everyone actually uses for training: prepare a mini-batch on the CPU, then ship it off to the GPU. Note that the code here and below is excerpted (see the test script link above for the full code), so it won’t work if you just copy it.
This drops the whole-dataset loss from around 4500 to around 4 and takes around 16 seconds to train. You might worry that random-number generation takes a while, but excluding it doesn’t reduce the time by more than 0.5 seconds.
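The cost structure of the feed_dict pattern is easiest to see in its control flow. Below is a NumPy sketch of that control flow, not TensorFlow code. The host shuffles the data, slices out every mini-batch, and makes one dispatch per step. The hypothetical train_step stands in for a sess.run(train_op, feed_dict=...) call. In the real setup, each such call copies the batch to the GPU and blocks until the step finishes. That per-step round trip is what dominates when the data is small. The batch size, learning rate, and synthetic data are assumptions.

```python
import numpy as np

def train_step(w, xb, yb, lr):
    """Hypothetical stand-in for one sess.run(train_op, feed_dict=...): an SGD step on MSE."""
    err = xb @ w - yb
    grad = 2.0 * xb.T @ err / len(xb)
    return w - lr * grad

def train_feed_dict_style(X, y, epochs=5, batch_size=256, lr=0.05, seed=0):
    """Host-driven loop: the host picks every batch and dispatches one step at a time."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros((d, 1), dtype=X.dtype)
    dispatches = 0
    for _ in range(epochs):
        perm = rng.permutation(n)  # shuffling happens on the host
        # Step through full batches only; the ragged tail of each epoch is dropped.
        for start in range(0, n - batch_size + 1, batch_size):
            idx = perm[start:start + batch_size]
            # In TF, X[idx] and y[idx] would be copied host -> GPU on every call.
            w = train_step(w, X[idx], y[idx], lr)
            dispatches += 1
    return w, dispatches

# Usage on hypothetical noise-free synthetic data.
rng = np.random.default_rng(1)
X = rng.standard_normal((10_000, 8)).astype(np.float32)
w_true = rng.standard_normal((8, 1)).astype(np.float32)
y = X @ w_true
w, dispatches = train_feed_dict_style(X, y)
print(dispatches)  # 195, i.e. 5 epochs * (10000 // 256) steps
```

The number to watch is dispatches. Each one is a separate host-to-device handoff, so it grows with both epochs and dataset size.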
Dataset API Approach
With the Dataset API, we set up a pipeline where TensorFlow orchestrates some dataflow by synergizing more buzzwords on its worker threads. This should keep the GPU constantly fed by staging the next mini-batch while the current one is being processed on the GPU. That might be the case when there’s a lot of data, but it doesn’t seem to work very well when the data is small and GPU-CPU latency, not throughput, is the bottleneck.
Another unpleasant thing to deal with is that all those orchestrated workers and staging areas and buffers and shuffle queues need magic constants to work well. I tried my best, but performance seems very sensitive to those constants in this use case. This could be fixed if Dataset detected (or could be told) that the data fits on the GPU and then placed it there.
For a small buffer size of 1000, this trains in around 12 seconds. But then it’s not actually shuffling the data very well (since data points can only move about 1000 positions). Still, the loss drops from around 4500 to around 4, as in the feed_dict case. A large buffer size of 1000000, which you’d think should effectively move the dataset onto the GPU entirely, performs worse than feed_dict, at around 20 seconds.
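A small buffer barely shuffles because of how a buffer-based shuffle works. The tf.data documentation describes shuffle as filling a buffer with buffer_size elements, then randomly sampling from that buffer and replacing each sampled element with the next input element. Below is a pure-Python simulation of that description. It is my own reimplementation, not TensorFlow’s code, and the name buffer_shuffle is hypothetical. Strictly, the bound is one-sided: an element can be emitted at most buffer_size - 1 positions earlier than its input position, but it can linger in the buffer and drift much later.

```python
import random

def buffer_shuffle(items, buffer_size, seed=0):
    """Simulate a fill-sample-replace shuffle buffer over a stream (buffer_size >= 1)."""
    rng = random.Random(seed)
    it = iter(items)
    buf = []
    for x in it:  # fill the buffer with the first buffer_size elements
        buf.append(x)
        if len(buf) == buffer_size:
            break
    out = []
    for x in it:  # steady state: emit a random buffered element, refill its slot
        j = rng.randrange(len(buf))
        out.append(buf[j])
        buf[j] = x
    rng.shuffle(buf)  # input exhausted: drain the remainder in random order
    out.extend(buf)
    return out

n, B = 100_000, 1000
out = buffer_shuffle(range(n), B)
# Input index minus output position: how far toward the front an element jumped.
max_early = max(i - k for k, i in enumerate(out))
# Output position minus input index: how far toward the back an element drifted.
max_late = max(k - i for k, i in enumerate(out))
print(max_early < B)  # True
```

With B = 1000 and 100000 elements, nothing in the first slice of the output can come from deep inside the dataset. An epoch therefore stays close to its input order, however many steps it runs.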
I don’t think it’s unfair to count the it.initializer time in my benchmark (which isn’t that much of a toy, either, since it’s similar in size to my RL use case). All the training methods need to load the data onto the GPU, and the data isn’t available until run time.
Using a TensorFlow Loop
This post isn’t a tutorial on tf.while_loop and friends, but this code does what was promised: just feed everything once into the GPU and do all your epochs without asking for permission to continue from the CPU.
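NumPy has no notion of a GPU, so the following is not the post’s TensorFlow code. It only mirrors the control structure of the tf.while_loop approach, under the same hypothetical linear-regression setup, and the name train_in_loop_style is hypothetical. The arrays are handed over once. Each step gathers a mini-batch by index from the resident data, which is the role tf.gather on a GPU-resident tensor would play. Every epoch runs inside a single call, so there is one dispatch in total rather than one per step. Sampling batch indices with replacement is my assumption; the linked script may select batches differently.

```python
import numpy as np

def train_in_loop_style(X, y, epochs=5, batch_size=256, lr=0.05, seed=0):
    """Single call runs every step; X and y stay put and batches are gathered by index."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros((d, 1), dtype=X.dtype)
    total_steps = epochs * (n // batch_size)
    step = 0
    while step < total_steps:  # in TF, this loop would be a tf.while_loop in the graph
        idx = rng.integers(0, n, size=batch_size)  # sample batch indices with replacement
        xb, yb = X[idx], y[idx]  # gather from the resident arrays (tf.gather's role in TF)
        err = xb @ w - yb
        w = w - lr * (2.0 * xb.T @ err / batch_size)
        step += 1
    return w, step

# Usage on hypothetical noise-free synthetic data: one call covers all steps.
rng = np.random.default_rng(1)
X = rng.standard_normal((10_000, 8)).astype(np.float32)
w_true = rng.standard_normal((8, 1)).astype(np.float32)
y = X @ w_true
w, steps = train_in_loop_style(X, y)
print(steps)  # 195
```

In the TensorFlow version, the loop condition, the index sampling, and the parameter update are all graph ops, so a single sess.run covers the whole training run.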
This one crushes it at around 8 seconds, again dropping the loss from around 4500 to around 4.
It’s pretty clear Dataset isn’t feeding as aggressively as it can, and its many widgets and knobs don’t help (well, they do, but only after making me do more work). But, if TF wants to invalidate this blog post, I suppose it could add yet another option that plops the dataset into the GPU.