Currently, artificial neural networks are trained through back-propagation by gradient descent. Their output layer is compared to the ground truth by a ‘cost function’, and each place where the output differs from the truth experiences a ‘tug’ (the gradient of that cost). These ‘tugs’ propagate backwards, down through the neural network, changing the weights of synaptic connections. Yet, because the ‘tug’ begins at the output layer and must propagate through the entire depth of the network, it often diffuses to almost nothing by the time it reaches the earliest layers. This is called the vanishing gradient problem.
I offer a method for propagating training signals which operates on every neuron directly. Instead of a vanishing gradient, the ‘tug’ on each neuron is strong, and in direct proportion to its predictive power. This method can be considered a generalization of the ‘cost function’ over the entire network.
The key insight is this: some neurons function as error-detectors. That is, the neuron’s activity is correlated with the network’s success at classification. The value of error-detectors can be expressed as a gambling problem: “If I can only see this one neuron’s activity, and I am trying to predict if the network mis-classifies, can I make a solid bet based on this one neuron?” By listening to the error-detecting neurons, the network learns to avoid those errors!
“How do you find error-detectors? How do you measure the reliability of the error-detectors? What do you do with them, when you find them?”
Suppose that you have an image data set that you wish to teach to a neural network. You could start by taking all the images of cats in your data set, and feeding them into the neural network. (This process must be done for each type of image in your data set.) As each image is fed through the network, you record the activation level of each neuron. Some images are correctly classified as cats — place the record of their neurons’ activation levels in a ‘correctly classified’ bucket. Other images are mis-classified — place the record of their neurons’ activation levels in a ‘mis-classified’ bucket.
When you have fed your network all the cat images, you look at each neuron in turn. Each neuron now has two records of its activation: one from the ‘correctly classified’ bucket, and one from the ‘mis-classified’ bucket. For each particular neuron, you compare the distribution of activations between the two buckets.
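The bucketing procedure above can be sketched in a few lines. Here, `network` is a hypothetical stand-in that returns one activation per neuron plus a predicted label; both names are illustrative assumptions, not part of the proposal:

```python
import numpy as np

def bucket_activations(images, network, true_label):
    """Split per-neuron activation records into 'correctly classified'
    and 'mis-classified' buckets, for one category of images."""
    correct, mistaken = [], []
    for image in images:
        activations, prediction = network(image)  # one activation per neuron
        if prediction == true_label:
            correct.append(activations)
        else:
            mistaken.append(activations)
    # Each bucket: one row per image, one column per neuron.
    return np.array(correct), np.array(mistaken)
```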
A cheap trick for comparing these distributions: sample randomly (with replacement) from both buckets, and compute their difference (i.e. the neuron’s activation during the ‘correctly classified’ image, minus activation during the ‘mis-classified’ image). This forms a new distribution, the distance distribution. The statistical mean of the distance distribution may be far from zero (indicating that the neuron’s activation during ‘correctly classified’ and ‘mis-classified’ images were consistently different). Additionally, the distribution’s standard deviation may be large or small (resulting from large and small variances in each bucket’s activations, respectively).
To compensate for each bucket having a large or small standard deviation, you can divide the distance distribution’s mean by the geometric mean of the two buckets’ standard deviations (the square root of their product). This tells us how much the neuron’s activation levels differ between ‘correctly classified’ and ‘mis-classified’ images, measured in units of its typical spread.
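A minimal sketch of this distance measure, assuming each bucket is a 1-D array of one neuron’s activations (the function name, sample count, and seed are my own choices):

```python
import numpy as np

def error_detector_score(correct_acts, mistaken_acts, n_samples=10_000, seed=0):
    """Sample the 'distance distribution' and normalize its mean by the
    geometric mean of the two buckets' standard deviations."""
    rng = np.random.default_rng(seed)
    a = rng.choice(correct_acts, size=n_samples, replace=True)
    b = rng.choice(mistaken_acts, size=n_samples, replace=True)
    distances = a - b  # the 'distance distribution'
    scale = np.sqrt(np.std(correct_acts) * np.std(mistaken_acts))
    return np.mean(distances) / scale
```

A score far above zero marks a neuron that fires during correct classification; a score far below zero marks one that fires during mistakes.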
If this measure is far from zero, then the neuron is firing differently for ‘correctly classified’ images than it does for ‘mis-classified’ images. The neuron has activation levels that are distinct, depending upon whether or not the network ‘mis-classifies’ an image. It is hard to over-emphasize the importance of this measure. What it means for our gambling problem: if we can only see that neuron, and we have to guess whether or not the neural network made a mistake, we could reliably use that neuron’s activation level to predict ‘mis-classification’. That neuron acts as an error-detector.
(Side-Note: The Earth Mover’s Distance (or Wasserstein metric) measures the distance between two distributions by treating the distributions like piles of dirt. The EMD is the total distance that chunks of dirt must travel, to turn one pile into the other. It is the more traditional distance metric for two probability distributions. My sampling technique is not the EMD: the mean of the sampled distance distribution simply converges to the difference between the means of the two distributions, a cruder but far cheaper measure.)
When a Neuron Detects Errors
For a cat image to be mis-classified, there must have been a difference in neuron activations somewhere. Either a neuron that should fire didn’t (e.g. an image of a cat where you cannot see its ears completely, which leaves the ‘cat ear’ neuron silent), or a neuron fired when it shouldn’t (e.g. an image of a cat, which triggers the ‘fox ear’ neuron by mistake). Let’s consider what behavior we would prefer, in each of those circumstances:
The ‘cat ear’ neuron didn’t fire — We would hope that the neural network increases the strength of the signals feeding into the ‘cat ear’ neuron, so that it is more likely to be active, instead of silent.
The ‘fox ear’ neuron fired by accident — We would hope that the neural network decreases the strength of the signals feeding into the ‘fox ear’ neuron, so that it does not become active inappropriately.
So, if the distance measure described earlier is positive, then that neuron is usually active when the network correctly classifies a cat, and inactive during mis-classifications. In that case, we want to increase the sensitivity of the connections feeding into that neuron, so that it will activate even when other neurons were unable (e.g. boosting the likelihood of the ‘cat ear’ neuron firing, even on scant evidence, so that the image is still classified as a ‘cat’). And, if the distance measure is negative, then the neuron is usually active during mis-classification, while silent during correct classification. We seek to decrease the sensitivity of the connections feeding into that neuron, in the hope that silencing it will lower the activation of the mis-classifying neurons (e.g. silencing the ‘fox ear’ signal which caused a ‘cat’ to be mis-identified as a ‘fox’). These increases and decreases would propagate down the network, using the same machinery as back-propagation by gradient descent. Together, the changes in sensitivity prune away neurons which contribute to errors and reinforce neurons which contribute to correct classification.
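The text above does not fix an exact update rule; one naive sketch (the multiplicative form and learning rate are my assumptions) would scale a neuron’s incoming weights by its score:

```python
import numpy as np

def adjust_incoming_weights(weights, score, learning_rate=0.01):
    """Naive sketch: a positive score (neuron predicts correct
    classification) boosts incoming sensitivity; a negative score
    (neuron predicts mis-classification) dampens it."""
    return weights * (1.0 + learning_rate * score)
```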
Additionally, we would prefer a forward propagation — if a neuron is usually active during mis-classification and silent for correct classification, then we can safely increase its connectivity to the ‘cat’ output neuron. That way, should the rest of the network fail to sufficiently activate the ‘cat’ neuron, this error-detecting neuron will tip the scales in favor of ‘cat’ classification. (e.g. the ‘fox ear’ neuron might fire alongside many ‘cat-features’, in which case the ‘fox ear’ actually helps activate the ‘cat’ neuron in the output layer.)
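A matching sketch of the forward wiring (again, the exact rule here is my assumption): a neuron with a negative score, i.e. one usually active during mis-classification, gets a stronger excitatory connection to the correct output neuron.

```python
def adjust_output_connection(weight_to_output, score, learning_rate=0.01):
    """Strengthen the error-detector's connection to the correct output
    neuron in proportion to how strongly it predicts mistakes; leave
    neurons with non-negative scores unchanged."""
    return weight_to_output + learning_rate * max(0.0, -score)
```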
Meanwhile, if a neuron is usually silent (or has a negative-valued activation) during mis-classification and is active during correct classification, we hope that its silence is wired to cause increased activity on the ‘cat’ output neuron. As a result of such a wiring, the ‘cat’ classification is similarly encouraged by the error-detecting neuron. (For silent neurons, this requires a special activation function that maps 0 to some positive constant and small positive values to some fraction of that constant; for negative-valued activations, this requires a negative synaptic weight.)
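The constraint above only requires that the special activation map 0 to a positive constant and small positive inputs to a fraction of that constant. A smooth exponential decay is one function satisfying this (the exponential form is an arbitrary choice for illustration):

```python
import math

def silence_activation(x, constant=1.0):
    """Maps an input of 0 to `constant`, and small positive inputs to a
    fraction of it, decaying smoothly toward zero as the input grows."""
    return constant * math.exp(-x)
```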
With both forward and backward propagation of reinforcement, the neural network tends to learn features which are encoded at the lowest detectable layer of the network, and it focuses on features which have minimal overlap between categories. As a result, ‘cat’ is distinguished from ‘fox’ by the features which are most dissimilar between the two.
I suggest that this backward and forward propagation of learning be computed once per epoch, so that there are enough data points in each of the ‘correctly classified’ and ‘mis-classified’ buckets. In addition to comparing ‘cats that were correctly classified as cats’ to ‘cats that were mis-classified as something else’, you can compare ‘cats that were correctly classified as cats’ to ‘other things that were mis-classified as cats’. These two buckets create additional distance measures, and allow similar propagation of learning.
Yes, this may be slower than traditional back-propagation from the output layer when training a shallow network. Yet, for very deep networks, this method may be a large improvement. (In very deep networks, features are often so broadly diffused that learning slows or stalls. ResNet is usually credited with overcoming the problems of depth. However, ResNet achieved peak performance at around 100 layers, and performance declined when its authors tried over 1,000 layers: the 1,202-layer ResNet performed worse on CIFAR-10 than the 32-layer version! Back-propagation from the output layer still does not allow training of very deep networks.)
This method also enables training a network piecemeal — a new cluster of neurons can be inserted at any depth in the network, and these can be trained while the rest of the network’s synaptic weights are frozen. Because this method is a ‘cost function’ defined over the entire network, it is able to train the inserted cluster of neurons directly. That cluster’s rate of learning is not slowed, regardless of total network depth. That is the most valuable property of this method, and I believe it is the key to a neural network that can learn new information without forgetting what it learned in the past. As new data comes along, new clusters can be inserted and trained without compromising the rest of the network.