The Hello World intro to machine learning is usually by way of Iris
flower classification or MNIST handwritten digit recognition. In
this post, I describe training a neural network in Pharo to perform
handwritten digit recognition. Instead of the MNIST dataset, I'll use the
smaller UCI optdigits (Optical Recognition of Handwritten Digits) dataset.
According to the dataset's description, it consists of two files:
optdigits.tra, 3823 records
optdigits.tes, 1797 records
Each record consists of 64 input attributes and 1 class attribute.
Input attributes are integers in the range 0..16.
The class attribute is the class code 0..9, i.e., it denotes the digit that the 64 input attributes encode.
The files are in CSV format. Let's use the excellent NeoCSV package to
read the data:
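A minimal sketch of such a reader, assuming the training file has been
saved to the hypothetical path /tmp/optdigits.tra and that every record
holds 65 comma-separated integers (64 inputs followed by the class code):

```smalltalk
| readRecords trainingRecords first inputs digit |
"Parse every line of a stream into an Array of 65 integers."
readRecords := [ :stream |
	| reader |
	reader := NeoCSVReader on: stream.
	65 timesRepeat: [ reader addIntegerField ].
	reader upToEnd ].

"Hypothetical location; adjust to wherever the file was downloaded."
trainingRecords := '/tmp/optdigits.tra' asFileReference
	readStreamDo: readRecords.

"Split the first record into its 64 inputs and its class code."
first := trainingRecords first.
inputs := first allButLast.
digit := first last.
```

Declaring the field types up front lets NeoCSV convert each value to an
Integer while parsing, so no separate conversion pass is needed.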
Next, install MLNeuralNetwork by Oleksandr Zaytsev:
MLNeuralNetwork operates on MLDataset instances, so modify the CSV reader
to produce one. Note MLMnistReader>>onehot:, which creates a 'one-hot'
vector for each digit. A one-hot vector gives the network a separate
target output for each of the ten digits. They are easy to understand
"pictorially":
For the digit 0, [1 0 0 0 0 0 0 0 0 0].
For the digit 1, [0 1 0 0 0 0 0 0 0 0].
For the digit 3, [0 0 0 1 0 0 0 0 0 0].
For the digit 9, [0 0 0 0 0 0 0 0 0 1].
Since there are over 5,000 records, we precompute the one-hot vectors and
reuse them, instead of creating one vector per record.
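The precomputation can be sketched with plain Arrays. This is only an
illustration of the idea, not the actual MLMnistReader>>onehot: method,
and the variable names are made up:

```smalltalk
| onehots encode |
"Build the ten one-hot vectors once, for the digits 0 to 9."
onehots := (0 to: 9) collect: [ :digit |
	| vector |
	vector := Array new: 10 withAll: 0.
	vector at: digit + 1 put: 1.	"Smalltalk arrays are 1-indexed"
	vector ].

"Look up the shared vector for a digit instead of creating a new one."
encode := [ :digit | onehots at: digit + 1 ].

encode value: 3.	"#(0 0 0 1 0 0 0 0 0 0)"
```

Because every record with the same digit shares one vector, the vectors
must be treated as read-only.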
Now create a 3-layer neural network of 64 input, 96 hidden, and 10 output
neurons, set it to learn from the training data for 500 epochs, and test
it on the testing set.
From the inspector, we can see that the network got records 3, 6 and 20
wrong.
In the inspector's code pane, the following snippet shows that the
network's accuracy is about 92%.
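One way to compute such a figure, sketched here under assumptions: the
network's outputs are taken to be 10-element vectors, the predicted digit
is the position of the largest output, and the sample outputs and labels
below are made up for illustration:

```smalltalk
| decode outputs actual predicted correct accuracy |
"Predicted digit = position of the largest output, shifted to 0..9."
decode := [ :vector | (vector indexOf: vector max) - 1 ].

"Made-up network outputs for four records, and their true digits."
outputs := {
	#(0.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.1 0.0).
	#(0.0 0.1 0.0 0.8 0.0 0.0 0.0 0.0 0.1 0.0).
	#(0.0 0.7 0.0 0.0 0.0 0.0 0.0 0.3 0.0 0.0).
	#(0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.2 0.8) }.
actual := #(0 3 7 9).

predicted := outputs collect: decode.	"#(0 3 1 9)"
correct := (1 to: actual size) count: [ :i |
	(predicted at: i) = (actual at: i) ].
accuracy := correct / actual size * 100.	"75"
```

Here the third record is decoded as 1 while its label is 7, so three of
the four made-up records count as correct.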
Not bad for a first attempt. However, the data set's source states that
the K-Nearest Neighbours algorithm achieved up to 98% accuracy on the
testing set, so there's plenty of room for improvement for the network.
Here's a screenshot showing some of the 8x8 digits with their predicted and
actual values. I don't know about you, but the top right "digit" looks
more like a smudge to me than any number.