The MNIST database is a large image database of handwritten digits that is commonly used for training and testing in the field of machine learning / image categorisation. Information regarding this dataset and various results achieved is widely published.
A correctly modelled/trained neural network should be able to achieve a 5% error rate on this dataset. Most/all published results are better than that. The largest, most advanced models have managed 0.35%. That's almost unbelievably good. ojAlgo currently doesn't have all the features required to build that kind of model. The model in the program listed below gets about a 2.2% error rate. Here are some sample digits/images from the data set.
The program below (with its dependency on ojAlgo) can do the following:
- Read/parse the files containing the image data and labels.
- Generate the actual images so that you can inspect them. The example images above are generated with that code.
- Print images to console (to sanity check results)
- Model and train feedforward neural networks:
- Any number of layers
- Any number of input/output nodes per layer
- Choose between 5 different activator functions and 2 different error/loss functions
The main benefit of using ojAlgo is how easy it is to do this and get good results. Download the example code below (you also need ojAlgo v46.1.1 or later) and run it, and start modifying the network structure, learning rate and other things. (You also need to download the data files, and update the various paths in the programs.)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.SIGMOID; | |
import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.SOFTMAX; | |
import static org.ojalgo.function.constant.PrimitiveMath.DIVIDE; | |
import java.io.File; | |
import java.io.IOException; | |
import org.ojalgo.OjAlgoUtils; | |
import org.ojalgo.ann.ArtificialNeuralNetwork; | |
import org.ojalgo.ann.NetworkInvoker; | |
import org.ojalgo.ann.NetworkTrainer; | |
import org.ojalgo.array.ArrayAnyD; | |
import org.ojalgo.data.DataBatch; | |
import org.ojalgo.data.image.ImageData; | |
import org.ojalgo.netio.BasicLogger; | |
import org.ojalgo.netio.IDX; | |
import org.ojalgo.netio.ToFileWriter; | |
import org.ojalgo.structure.AccessAnyD.MatrixView; | |
import org.ojalgo.type.format.NumberStyle; | |
/** | |
* A example of how to build, train and use artificial neural networks with ojAlgo using the MNIST database. | |
* This is an updated version of a previous example. | |
* | |
* @see https://www.ojalgo.org/2021/08/artificial-neural-network-example-v2/ | |
* @see https://www.ojalgo.org/2018/09/introducing-artificial-neural-networks-with-ojalgo/ | |
* @see https://github.com/optimatika/ojAlgo/wiki/Artificial-Neural-Networks | |
*/ | |
public class TrainingANN { | |
static final File OUTPUT_TEST_IMAGES = new File("/Users/apete/Developer/data/images/test/"); | |
static final File OUTPUT_TRAINING_IMAGES = new File("/Users/apete/Developer/data/images/training/"); | |
static final File TEST_IMAGES = new File("/Users/apete/Developer/data/t10k-images-idx3-ubyte"); | |
static final File TEST_LABELS = new File("/Users/apete/Developer/data/t10k-labels-idx1-ubyte"); | |
static final File TRAINING_IMAGES = new File("/Users/apete/Developer/data/train-images-idx3-ubyte"); | |
static final File TRAINING_LABELS = new File("/Users/apete/Developer/data/train-labels-idx1-ubyte"); | |
public static void main(final String[] args) throws IOException { | |
BasicLogger.debug(); | |
BasicLogger.debug(TrainingANN.class); | |
BasicLogger.debug(OjAlgoUtils.getTitle()); | |
BasicLogger.debug(OjAlgoUtils.getDate()); | |
BasicLogger.debug(); | |
int trainingEpochs = 50; | |
int batchSize = 100; | |
int numberToPrint = 10; | |
boolean generateImages = false; | |
ArtificialNeuralNetwork network = ArtificialNeuralNetwork.builder(28 * 28).layer(200, SIGMOID).layer(10, SOFTMAX).get(); | |
NetworkTrainer trainer = network.newTrainer(batchSize).rate(0.01).dropouts(); | |
ArrayAnyD<Double> trainingLabels = IDX.parse(TRAINING_LABELS); | |
ArrayAnyD<Double> trainingImages = IDX.parse(TRAINING_IMAGES); | |
trainingImages.modifyAll(DIVIDE.by(255)); // Normalise the image pixel values that are between 0 and 255 | |
DataBatch inputBatch = trainer.newInputBatch(); | |
DataBatch outputBatch = trainer.newOutputBatch(); | |
for (int e = 0; e < trainingEpochs; e++) { | |
for (MatrixView<Double> imageData : trainingImages.matrices()) { | |
inputBatch.addRow(imageData); | |
long imageIndex = imageData.index(); | |
int label = trainingLabels.intValue(imageIndex); | |
// The label is an integer [0,9] representing the digit in the image | |
// That label is used as the index to set a single 1.0 | |
outputBatch.addRowWithSingleUnit(label); | |
if (inputBatch.isFull()) { | |
trainer.train(inputBatch, outputBatch); | |
inputBatch.reset(); | |
outputBatch.reset(); | |
} | |
if (generateImages && e == 0) { | |
TrainingANN.generateImage(imageData, label, OUTPUT_TRAINING_IMAGES); | |
} | |
} | |
} | |
/* | |
* It is of course possible to invoke/evaluate the network using batched input data. Further more it | |
* is possible to have multiple invokers running in different threads. Here we stick to 1 thread and | |
* simple batch size == 1. | |
*/ | |
NetworkInvoker invoker = network.newInvoker(); | |
ArrayAnyD<Double> testLabels = IDX.parse(TEST_LABELS); | |
ArrayAnyD<Double> testImages = IDX.parse(TEST_IMAGES); | |
testImages.modifyAll(DIVIDE.by(255)); | |
int right = 0; | |
int wrong = 0; | |
for (MatrixView<Double> imageData : testImages.matrices()) { | |
long expected = testLabels.longValue(imageData.index()); | |
long actual = invoker.invoke(imageData).indexOfLargest(); | |
if (actual == expected) { | |
right++; | |
} else { | |
wrong++; | |
} | |
if (imageData.index() < numberToPrint) { | |
BasicLogger.debug(""); | |
BasicLogger.debug("Image {}: {} <=> {}", imageData.index(), expected, actual); | |
IDX.print(imageData, BasicLogger.DEBUG); | |
} | |
if (generateImages) { | |
TrainingANN.generateImage(imageData, expected, OUTPUT_TEST_IMAGES); | |
} | |
} | |
BasicLogger.debug(""); | |
BasicLogger.debug("========================================================="); | |
BasicLogger.debug("Error rate: {}", (double) wrong / (double) (right + wrong)); | |
} | |
private static void generateImage(final MatrixView<Double> imageData, final long imageLabel, final File directory) throws IOException { | |
ToFileWriter.mkdirs(directory); | |
int nbRows = imageData.getRowDim(); | |
int nbCols = imageData.getColDim(); | |
// IDX-files and ojAlgo data structures are indexed differently. | |
// That doesn't matter when we're doing the math, | |
// but need to transpose the data when creating an image to look at. | |
ImageData image = ImageData.newGreyScale(nbCols, nbRows); | |
for (int i = 0; i < nbRows; i++) { | |
for (int j = 0; j < nbCols; j++) { | |
// The colours are stored inverted in the IDX-files (255 means "ink" | |
// and 0 means "no ink". In computer graphics 255 usually means "white" | |
// and 0 "black".) In addition the image data was previously scaled | |
// to be in the range [0,1]. That's why... | |
double grey = 255.0 * (1.0 - imageData.doubleValue(i, j)); | |
image.set(j, i, grey); | |
} | |
} | |
String name = NumberStyle.toUniformString(imageData.index(), 60_000) + "_" + imageLabel + ".png"; | |
File outputfile = new File(directory, name); | |
image.writeTo(outputfile); | |
} | |
} |
No comments:
Post a Comment