Thursday, May 7, 2015

On Reading and Understanding Academic Papers: Batch Normalisation

So, the keras.io documentation has a link. In the section on the "Batch Normalisation" layer, there is a hyperlink to a PDF of an academic paper on the use and effectiveness of the approach. Tim Berners-Lee would be proud.

I followed that link. I am soooo ignorant when it comes to understanding maths properly. Don't get me wrong, it's not completely foreign to me -- I work with numerical data all the time, and have a printout of the unit circle and associated trig functions on my desk. I use it multiple times per year. I remember nothing of it between those times.

Reading a paper requires me to have a degree of easy familiarity with mathematical concepts which I just don't have. Let me quote from the introduction of the paper. I needed to take an image snapshot to deal with the mathematical notation (sorry, Tim!).
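In case the snapshot doesn't come through, the expression in question is roughly this (my own transcription, so treat the exact symbols as approximate rather than a faithful copy of the paper):

```latex
\Theta = \arg\min_{\Theta} \frac{1}{N} \sum_{i=1}^{N} \ell(x_i, \Theta)
```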


Now, as far as I can tell, this paper is a top piece of research and a wonderful presentation of a highly useful and relevant concept. Parts of this I can understand, parts I can't. SGD was already defined as stochastic gradient descent, by the way.

The first symbol is a theta. I know from prior experience, plus the description in the text, that it refers not to a number but to some kind of collection of numbers: the parameters of the network. I'm not sure if it's a matrix, or several numbers, or exactly which parameters it covers. Arg min, I think, means a magical function which returns "the value of theta that minimises the result of the following expression". I'm reading this like a programmer, see.

Okay, then the 1/N multiplied by the sum of a function from 1 to N. This is otherwise known as "the average". I think N refers to the number of training examples, and i is an index identifying the current example.

I have no clue what the 'l' function is. I'm going to guess from the text it means 'loss function'.

So, unpacked, this means "the network parameters which minimise the average loss across the training data".
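Since I'm reading this like a programmer anyway, here's my translation into code. It's a rough sketch: the names are mine, and loss_fn stands in for whatever the paper's little-l loss function actually is.

```python
import numpy as np

def objective(theta, training_data, loss_fn):
    """Mean loss over the whole training set, viewed as a function of theta."""
    return np.mean([loss_fn(x, theta) for x in training_data])

# "arg min" then just means: the theta for which objective(theta, ...) comes
# out smallest. SGD is one iterative, approximate way of hunting for that theta.
```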

What's unclear to me is how the mathematical notation actually helps here. Surely the statement "stochastic gradient descent finds the parameters that minimise the average loss over the training data" is more instructive to both the mathematical and the casual reader than this function notation?

Now, I can eventually unpack most parts of the paper, slowly, and one-by-one. Writing this post genuinely helped me grok what was going on better. I haven't actually gotten to the section on batch normalisation yet, of course. I'll read on ... casual readers can tune out now as the rest of this post is going to be an extended exposition of my confusion, rather than anything of general interest.

The next paragraph refers to the principle behind mini-batching. There is something slightly non-obvious being said here. They state that "The mini-batch is used to approximate the gradient of the loss function with respect to the parameters...". What they are saying is that the mini-batch, if it's a representative sample, approximates the whole training set. Calculating the loss over the mini-batch is an estimator of the loss over the whole training set. It's smaller and easier to work with than the entire training set, but more representative than looking at just one example at a time. The larger the mini-batch, the more representative it is of the whole, but at the same time it is more unwieldy. Picking the optimal batch size isn't covered, although they do mention that the efficiency of mini-batching also relates to being able to parallelise the calculation of the loss for each example.
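Here's the estimator idea as a code sketch. The batch size and the random sampling scheme are my own assumptions for illustration, not anything the paper prescribes.

```python
import numpy as np

def minibatch_loss_estimate(theta, training_data, loss_fn, batch_size=32):
    """Estimate the full-training-set mean loss from one random mini-batch."""
    indices = np.random.choice(len(training_data), size=batch_size, replace=False)
    return np.mean([loss_fn(training_data[i], theta) for i in indices])

# A larger batch_size gives a better estimate of the full mean loss, but each
# estimate costs more to compute -- that's the trade-off they gesture at.
```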

The reason I think it's mentioned is that the purpose of mini-batching is similar to the purpose of batch normalisation. They are saying that mini-batching improves generalisation, because the learning is driven by the average loss across the mini-batch, rather than by each specific example. That is to say -- it makes the network less sensitive to spurious details.

As I understand it, batch normalisation also achieves that end, as well as reducing the sensitivity of the network to the tuning of its meta-parameters (the latter being the prime purpose of batch normalisation).

They make the point that in a deep network, the effect of tuning parameters is magnified. For example, if picking a learning rate of 0.1 has some effect on a single-layer network, it could have double that effect on a two-layer network. Potentially. I think this point is a little shaky myself, because having multiple layers could allow each layer to cancel out overshoots at previous layers. However, this might be a strong case for a more intricate layer design based on capturing effects at different scales. For example, having an architecture with a 'fine detail' and a 'coarse detail' layer might be better than two fine-scale layers. Another approach (for example on images) might be to actually train off smoothed data plus the detail data. Food for thought.
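To convince myself the magnification argument at least makes arithmetic sense, here's a toy calculation. The 1.05 scale factor is entirely made up by me; it just shows how a small per-layer effect compounds with depth.

```python
# If an update effectively scales each layer's output by 1.05, the effect on
# the final output grows with the number of layers it passes through.
for depth in (1, 2, 5, 10):
    print(depth, 1.05 ** depth)
# roughly: 1 -> 1.05, 2 -> 1.10, 5 -> 1.28, 10 -> 1.63
```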

They then move on to what I think is the main game: reducing competition between the layers. As I interpret what they are saying, the learning step affects all layers. As a result, in a sense, layers towards the top of the network are experiencing a 'moving goalposts' situation while the layers underneath them learn and respond differently to the input data. This is basically where I no longer understand things completely. They refer to the shifting nature of layer outputs as "internal covariate shift". I interpret this as meaning that higher layers need to re-adjust to the new outputs of lower layers. I think of it as being like input scaling, except at the layer level, updated through mini-batch training.
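To pin down my 'input scaling at the layer level' reading, here's what I imagine that looks like in code. This is only a sketch of my guess -- I haven't reached the paper's actual algorithm yet, so the epsilon term and the per-unit treatment are my own assumptions.

```python
import numpy as np

def normalise_layer_outputs(activations, eps=1e-5):
    """activations: array of shape (batch_size, num_units) -- one layer's
    outputs for a mini-batch. Shift and scale each unit so that, across the
    mini-batch, its outputs have roughly zero mean and unit variance."""
    mean = activations.mean(axis=0)
    var = activations.var(axis=0)
    return (activations - mean) / np.sqrt(var + eps)
```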

They then point to their big result: a reduction in training time and improvements to accuracy. They took a well-performing network and matched its results in 7% of the iterations. That's a big reduction. They also state that they can improve final accuracy, but don't say by how much.

Now for the details. The paper now talks about their detailed methodology and approach. I'm literally just 1.5 pages into an 8-page document, and my mind is experiencing that pressure you get when you just don't think you can grok any more. I don't think I can reasonably burden my readers any more with my thoughts, nor can I process much more today.

I'm going to have to break up understanding this paper into more sessions, I think, coming back after internalising every page or two. It's probably worth my getting there, because the end of the paper does mention alternative means to the same ends and talks about the limits of the technique. Perhaps I will make further posts on the topic, perhaps I won't. We shall see.

If any readers here are more advanced in their understanding than I am, I would very much appreciate it if you could point out anything I've gotten wrong!!!