Handwritten digit recognition with TensorFlow.js
Handwriting recognition enables us to convert handwritten documents into digital form. This technology is now used in numerous ways: reading postal addresses, processing bank check amounts, and digitizing historical literature.
TensorFlow.js brings this powerful technology into the browser. In this article, we are going to build a web application that can predict the digit you draw on a canvas.
Handwritten digit recognition demo
Draw on the black canvas below with your mouse (on desktop) or your finger (on mobile), click “Predict” to get the prediction for the handwritten digit, and click “Clear” to wipe the canvas and start drawing again.
GitHub repository
You can download the complete code of the above demo in the link below:
Implementation
This application first uses a Python script to train and save the model, then uses the JavaScript library TensorFlow.js to load the model in the browser and predict which digit was drawn. Follow along below to explore how it is built.
Folder Structure
Let’s start by setting up the project with a proper folder structure:
- js – contains the JavaScript files
  - chart.min.js – Chart.js display library
  - digit-recognition.js – main application JavaScript
- models – contains the saved model and weights
- style – contains the CSS style file
- index.html – main HTML file
- MNIST.py – Python script to train and save the model
# Step 1: Train and save the model
To begin our journey, we will write a Python script to train a CNN (convolutional neural network) model on the famous MNIST dataset.
MNIST is a computer vision database consisting of handwritten digits, with labels identifying the digits. Every MNIST data point has two parts: an image of a handwritten digit and a corresponding label.
Import libraries
import tensorflow as tf
import tensorflowjs as tfjs
from tensorflow import keras
Load MNIST dataset
The MNIST dataset consists of 70,000 examples, which we split into 60,000 training and 10,000 testing examples. After that, the data requires some pre-processing before it can be fed into the CNN.
# split the mnist data into train and test
(train_img,train_label),(test_img,test_label) = keras.datasets.mnist.load_data()
# reshape and scale the data
train_img = train_img.reshape([-1, 28, 28, 1])
test_img = test_img.reshape([-1, 28, 28, 1])
train_img = train_img/255.0
test_img = test_img/255.0
# convert class vectors to binary class matrices --> one-hot encoding
train_label = keras.utils.to_categorical(train_label)
test_label = keras.utils.to_categorical(test_label)
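To see what one-hot encoding produces, here is a small illustration, sketched with NumPy rather than Keras (the output is equivalent to what to_categorical returns):

```python
import numpy as np

# A label such as 3 becomes a length-10 vector with a single 1 at index 3
labels = np.array([0, 3, 9])
one_hot = np.eye(10)[labels]
print(one_hot[1])  # → [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```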
Define the model architecture
model = keras.Sequential([
keras.layers.Conv2D(32, (5, 5), padding="same", input_shape=[28, 28, 1]),
keras.layers.MaxPool2D((2,2)),
keras.layers.Conv2D(64, (5, 5), padding="same"),
keras.layers.MaxPool2D((2,2)),
keras.layers.Flatten(),
keras.layers.Dense(1024, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
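As a sanity check on the architecture, the parameter count can be worked out by hand (a quick back-of-the-envelope sketch; model.summary() should report the same total):

```python
# Conv2D parameters = kernel_h * kernel_w * in_channels * filters + biases
conv1 = 5 * 5 * 1 * 32 + 32        # 832
conv2 = 5 * 5 * 32 * 64 + 64       # 51,264
# two 2x2 max-pools shrink 28 -> 14 -> 7, so Flatten outputs 7 * 7 * 64 = 3136
dense1 = 7 * 7 * 64 * 1024 + 1024  # 3,212,288
dense2 = 1024 * 10 + 10            # 10,250
total = conv1 + conv2 + dense1 + dense2
print(total)  # → 3274634
```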
Train the model
model.fit(train_img,train_label, validation_data=(test_img,test_label), epochs=10)
test_loss,test_acc = model.evaluate(test_img, test_label)
print('Test accuracy:', test_acc)
The training could take a couple of minutes, and we get a pretty good result of about 98.5% accuracy on the test set.
Save model as tfjs format
Now that we have a trained model, we need to save it in a format that TensorFlow.js can load in the browser.
tfjs.converters.save_keras_model(model, 'models')
The model is saved into the ‘models’ folder, which contains a model.json file and some binary weight files.
# Step 2: Include TensorFlow.js
Simply include the script for tfjs in the <head> section of the HTML file. I also include the jQuery and Chart.js libraries.
<html>
<head>
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" >
<link rel="stylesheet" type="text/css" href="style/digit.css">
<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
<script src="js/chart.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
</head>
# Step 3: Set up the canvas
For the user to draw a digit with a mouse on desktop or a finger on mobile devices, we need to create an HTML5 element called canvas. The user draws the digit inside the canvas, and we feed the drawing into the deep neural network we created to make predictions.
HTML – index.html
Add a placeholder <div> to contain the canvas that you can draw digits on:
<div id="canvas_box" class="canvas-box"></div>
Add a “Predict” button to get the handwritten digit prediction and a “Clear” button to wipe the canvas and start drawing again:
<button id="clear-button" class="btn btn-dark">Clear</button>
<button id="predict-button" class="btn btn-dark">Predict</button>
At the end of the <body>, include the main JavaScript file, digit-recognition.js:
<script src="js/digit-recognition.js"></script>
</body>
</html>
Javascript – digit-recognition.js
Initialize the variables
let model;
var canvasWidth = 150;
var canvasHeight = 150;
var canvasStrokeStyle = "white";
var canvasLineJoin = "round";
var canvasLineWidth = 10;
var canvasBackgroundColor = "black";
var canvasId = "canvas";
var clickX = new Array();
var clickY = new Array();
var clickD = new Array();
var drawing;
Create the canvas and append it to the placeholder so it is displayed:
var canvasBox = document.getElementById('canvas_box');
var canvas = document.createElement("canvas");
canvas.setAttribute("width", canvasWidth);
canvas.setAttribute("height", canvasHeight);
canvas.setAttribute("id", canvasId);
canvas.style.backgroundColor = canvasBackgroundColor;
canvasBox.appendChild(canvas);
if(typeof G_vmlCanvasManager != 'undefined') {
canvas = G_vmlCanvasManager.initElement(canvas);
}
var ctx = canvas.getContext("2d");
Drawing inside a canvas is a little tricky across mobile and desktop. We need to handle both the mouse events (via jQuery) and the touch events (via addEventListener). Below are the events we will be using.
- mousedown
- mousemove
- mouseup
- mouseleave
- touchstart
- touchmove
- touchend
- touchleave
//---------------------
// MOUSE DOWN function
//---------------------
$("#canvas").mousedown(function(e) {
var rect = canvas.getBoundingClientRect();
var mouseX = e.clientX - rect.left;
var mouseY = e.clientY - rect.top;
drawing = true;
addUserGesture(mouseX, mouseY);
drawOnCanvas();
});
//-----------------------
// TOUCH START function
//-----------------------
canvas.addEventListener("touchstart", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
var rect = canvas.getBoundingClientRect();
var touch = e.touches[0];
var mouseX = touch.clientX - rect.left;
var mouseY = touch.clientY - rect.top;
drawing = true;
addUserGesture(mouseX, mouseY);
drawOnCanvas();
}, false);
//---------------------
// MOUSE MOVE function
//---------------------
$("#canvas").mousemove(function(e) {
if(drawing) {
var rect = canvas.getBoundingClientRect();
var mouseX = e.clientX - rect.left;
var mouseY = e.clientY - rect.top;
addUserGesture(mouseX, mouseY, true);
drawOnCanvas();
}
});
//---------------------
// TOUCH MOVE function
//---------------------
canvas.addEventListener("touchmove", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
if(drawing) {
var rect = canvas.getBoundingClientRect();
var touch = e.touches[0];
var mouseX = touch.clientX - rect.left;
var mouseY = touch.clientY - rect.top;
addUserGesture(mouseX, mouseY, true);
drawOnCanvas();
}
}, false);
//-------------------
// MOUSE UP function
//-------------------
$("#canvas").mouseup(function(e) {
drawing = false;
});
//---------------------
// TOUCH END function
//---------------------
canvas.addEventListener("touchend", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
drawing = false;
}, false);
//----------------------
// MOUSE LEAVE function
//----------------------
$("#canvas").mouseleave(function(e) {
drawing = false;
});
//-----------------------
// TOUCH LEAVE function
//-----------------------
canvas.addEventListener("touchleave", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
drawing = false;
}, false);
//--------------------
// ADD CLICK function
//--------------------
function addUserGesture(x, y, dragging) {
clickX.push(x);
clickY.push(y);
clickD.push(dragging);
}
//-------------------
// RE DRAW function
//-------------------
function drawOnCanvas() {
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
ctx.strokeStyle = canvasStrokeStyle;
ctx.lineJoin = canvasLineJoin;
ctx.lineWidth = canvasLineWidth;
for (var i = 0; i < clickX.length; i++) {
ctx.beginPath();
if(clickD[i] && i) {
ctx.moveTo(clickX[i-1], clickY[i-1]);
} else {
ctx.moveTo(clickX[i]-1, clickY[i]);
}
ctx.lineTo(clickX[i], clickY[i]);
ctx.closePath();
ctx.stroke();
}
}
//------------------------
// CLEAR CANVAS function
//------------------------
$("#clear-button").click(async function () {
ctx.clearRect(0, 0, canvasWidth, canvasHeight);
clickX = new Array();
clickY = new Array();
clickD = new Array();
$(".prediction-text").empty();
$("#result_box").addClass('d-none');
});
# Step 4: Load the model and predict with TensorFlow.js
Now we need to use TensorFlow.js to load the model that we trained earlier in Python, and use it to predict the digit drawn on the canvas.
Load Model
The function loadModel() calls the TensorFlow.js API tf.loadLayersModel:
//-------------------------------------
// loader for cnn model
//-------------------------------------
async function loadModel() {
// clear the model variable
model = undefined;
// load the model over HTTP(S) from where you have stored your model files
model = await tf.loadLayersModel("models/model.json");
}
loadModel();
Pre-process canvas
The function preprocessCanvas() pre-processes the user’s canvas drawing before feeding it to the CNN model:
//-----------------------------------------------
// preprocess the canvas
//-----------------------------------------------
function preprocessCanvas(image) {
// resize the input image to a tensor of shape (1, 28, 28, 1)
let tensor = tf.browser.fromPixels(image)
.resizeNearestNeighbor([28, 28])
.mean(2)
.expandDims(2)
.expandDims()
.toFloat();
return tensor.div(255.0);
}
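To make the tensor shapes explicit, the same pipeline can be sketched in NumPy (a hypothetical illustration, not part of the app; the nearest-neighbour index mapping only approximates TensorFlow’s resizeNearestNeighbor):

```python
import numpy as np

def preprocess(image):
    # image: (H, W, 3) uint8 canvas pixels
    h, w, _ = image.shape
    # nearest-neighbour resize to 28x28 (approximates resizeNearestNeighbor)
    rows = np.arange(28) * h // 28
    cols = np.arange(28) * w // 28
    small = image[rows][:, cols]
    gray = small.mean(axis=2)                     # average the RGB channels, like .mean(2)
    tensor = gray[np.newaxis, :, :, np.newaxis]   # shape (1, 28, 28, 1), like the two expandDims
    return tensor.astype(np.float32) / 255.0      # scale to [0, 1], like .div(255.0)

x = preprocess(np.zeros((150, 150, 3), dtype=np.uint8))
print(x.shape)  # → (1, 28, 28, 1)
```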
Prediction
When the “Predict” button is clicked, we get the image data from the canvas, pre-process it into a tensor, and then feed it into model.predict to get the prediction result.
//--------------------------------------------
// predict function
//--------------------------------------------
$("#predict-button").click(async function () {
// preprocess canvas
let tensor = preprocessCanvas(canvas);
// make predictions on the preprocessed image tensor
let predictions = await model.predict(tensor).data();
// get the model's prediction results
let results = Array.from(predictions);
// display the predictions in chart
$("#result_box").removeClass('d-none');
displayChart(results);
displayLabel(results);
});
Display result
The function loadChart() uses the Chart.js library to display the prediction results as a bar chart:
//------------------------------
// Chart to display predictions
//------------------------------
var chart = "";
var firstTime = 0;
function loadChart(label, data, modelSelected) {
var ctx = document.getElementById('chart_box').getContext('2d');
chart = new Chart(ctx, {
// The type of chart we want to create
type: 'bar',
// The data for our dataset
data: {
labels: label,
datasets: [{
label: modelSelected + " prediction",
backgroundColor: '#f50057',
borderColor: 'rgb(255, 99, 132)',
data: data,
}]
},
// Configuration options go here
options: {}
});
}
//----------------------------
// display chart with updated
// drawing from canvas
//----------------------------
function displayChart(data) {
var select_option = "CNN";
label = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"];
if (firstTime == 0) {
loadChart(label, data, select_option);
firstTime = 1;
} else {
chart.destroy();
loadChart(label, data, select_option);
}
document.getElementById('chart_box').style.display = "block";
}
function displayLabel(data) {
var max = data[0];
var maxIndex = 0;
for (var i = 1; i < data.length; i++) {
if (data[i] > max) {
maxIndex = i;
max = data[i];
}
}
$(".prediction-text").html("Predicting you drew <b>" + maxIndex + "</b> with <b>" + Math.trunc(max * 100) + "%</b> confidence");
}
Finally, testing
That’s it! As you can see, we get a pretty good result: it predicts that I have drawn the digit 5 with 99% confidence.
Without doubt, a great deal of research has been conducted in the field of handwriting recognition, and TensorFlow.js makes these pre-trained deep models accessible in the browser. I hope you had fun with this article, and I encourage you to discover more about this library. I can’t wait to see more creative ideas that put this cutting-edge technology to practical use.
Thank you for reading. If you like this article, please share it on Facebook or Twitter. Let me know in the comments if you have any questions. Follow me on Medium, GitHub and LinkedIn. Support me on Ko-fi.
15 Comments
Very helpful, well explained, easy to follow and understood.
The owner is very proactive and replies to any questions immediately.
Sir, the project is very easy to understand. Thanks for this project.
hey bro my project is not running what should i do
Congratulations on the work, amazing. I tested your project. What I would like to know is how to train it: could the saving be replaced to save in JSON mode, or better, in a database to compare against later? My idea is to do signature authentication: a registration screen where the user registers his signature as if it were a username and password, and on the login page he only signs to enter.
Hi Rodrigo, from what you describe, you are looking for image similarity comparison: comparing how similar two signatures are. I found this GitHub repo that might be useful for you: https://github.com/ftlabs/Hancock
Will you please tell me which tensorflow.js API version I should use? When I was using the latest tensorflow.js API I was getting so many exceptions, and when I used tensorflow.js API 1.2.6 I got a ConnectionResetError.
Hi, I am using the latest version https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest
Thank you for your response, but I want to know: as we know, before writing the code we need to install some libraries. Which versions of the tensorflow.js and chart.min libraries do I need to install? When I install their latest versions I get so many exceptions. It would be very helpful if you could tell me the version numbers of the libraries.
Below are the libraries I used for the demo:
jquery – https://code.jquery.com/jquery-3.3.1.min.js
tensorflowjs – https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest
Chart.js – https://cdn.jsdelivr.net/npm/chart.js@2.8.0
Hi Sir, thank you for sharing such a great tutorial with us! I was keen to try out this application and run it on a local server, but I ran into an issue; I have googled the solution and tried several ways, but the problem still cannot be solved. Therefore, I hope to get your advice on how to solve it.
The problem that I encountered is as follows:
Fetch API cannot load file:///D:/Hand-Written-Digit-Recognition-master/models/model.json. URL scheme must be “http” or “https” for CORS request.
Hi, after checking out the code from the GitHub repo, you can’t just open the HTML file; you need to set up your localhost to point to the root directory of the code and access it through http://localhost/index.html
How to set up your localhost to point to the root directory of the code, and access through http://localhost/index.html?
1. Install a web server such as Apache or Nginx on your local machine.
2. Configure your web server to point to the root directory of your code.
3. Start the web server.
4. Open a web browser and go to http://localhost/index.html.
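The steps above can also be sketched with Python’s built-in web server (a minimal illustration; for this demo, simply running `python -m http.server 8000` from the project root and opening http://localhost:8000/index.html achieves the same thing):

```python
import threading
import urllib.request
from http.server import HTTPServer, SimpleHTTPRequestHandler

# Serve the current directory over HTTP; port 0 picks a free port
# (use a fixed port such as 8000 for the actual demo).
server = HTTPServer(("127.0.0.1", 0), SimpleHTTPRequestHandler)
port = server.server_address[1]
threading.Thread(target=server.serve_forever, daemon=True).start()

# The app is now reachable at http://127.0.0.1:<port>/index.html
status = urllib.request.urlopen(f"http://127.0.0.1:{port}/").status
print(status)  # → 200
```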
What library do we use if we want to detect alphabets along with numbers?
Hi Antara,
you can find some machine learning models trained to recognize alphabets at https://github.com/topics/alphabet-recognition