Handwritten digit recognition with TensorFlow.js

Handwriting recognition enables us to convert handwritten documents into digital form. The technology is already used in many ways: reading postal addresses, reading bank check amounts, and digitizing historical literature.

TensorFlow.js brings this powerful technology into the browser. In this article, we are going to build a web application that predicts the digit you draw on a canvas.


Handwritten digit recognition demo

Draw on the black canvas below with your mouse on desktop or your finger on mobile. Click “Predict” to get the handwritten digit prediction, and click “Clear” to start drawing again.


GitHub repository

You can download the complete code of the above demo from the link below:


Implementation

The application first uses a Python script to train and save the model, then uses the JavaScript library TensorFlow.js to load the model into the browser and predict which digit has been drawn. Follow along below to see how it is built.

Folder Structure

Let’s start by setting up the project with a proper folder structure

handwritten digit recognition folder structure
  • js – contains the JavaScript files
    • chart.min.js – Chart.js library used to display the prediction chart
    • digit-recognition.js – main application JavaScript
  • models – contains the saved model and its weights
  • style – contains the CSS style file
  • index.html – main HTML file
  • MNIST.py – Python script to train and save the model

# Step 1 : Train and save model

To begin our journey, we will write a Python script to train a CNN (Convolutional Neural Network) model on the famous MNIST dataset.

MNIST is a computer vision database of handwritten digits, with labels identifying the digits. Every MNIST data point has two parts: a 28 × 28 grayscale image of a handwritten digit and its corresponding label (0 to 9).

Import libraries

import tensorflow as tf
import tensorflowjs as tfjs
from tensorflow import keras

Load MNIST dataset

The MNIST dataset consists of 70,000 examples: 60,000 training images and 10,000 test images. keras.datasets.mnist.load_data() returns them already split into training and testing sets. After loading, the data needs some pre-processing before it can be fed into the CNN.

# load the MNIST data, which comes already split into train and test sets
(train_img,train_label),(test_img,test_label) = keras.datasets.mnist.load_data()

# reshape and scale the data
train_img = train_img.reshape([-1, 28, 28, 1])
test_img = test_img.reshape([-1, 28, 28, 1])
train_img = train_img/255.0
test_img = test_img/255.0

# convert class vectors to binary class matrices --> one-hot encoding
train_label = keras.utils.to_categorical(train_label)
test_label = keras.utils.to_categorical(test_label)
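To make the two pre-processing steps concrete, here is a small sketch in JavaScript, the language used on the browser side of this project. The helpers scalePixels and toCategorical are hypothetical plain-array stand-ins for train_img/255.0 and keras.utils.to_categorical with 10 classes; they are not part of Keras or TensorFlow.js.

```javascript
// Hypothetical plain-array equivalents of the Python pre-processing above.

// Scale 0-255 pixel intensities into the [0, 1] range (like train_img / 255.0).
function scalePixels(pixels) {
  return pixels.map((p) => p / 255);
}

// One-hot encode integer labels 0..numClasses-1 (like keras.utils.to_categorical
// with 10 classes): label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
function toCategorical(labels, numClasses = 10) {
  return labels.map((label) => {
    const row = new Array(numClasses).fill(0);
    row[label] = 1;
    return row;
  });
}

console.log(scalePixels([0, 51, 255])); // [ 0, 0.2, 1 ]
console.log(toCategorical([3, 0])[0].join(", ")); // 0, 0, 0, 1, 0, 0, 0, 0, 0, 0
```

The same division by 255 has to happen in the browser before prediction; otherwise the model would see inputs on a different scale from the ones it was trained on.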

Define the model architecture

model = keras.Sequential([
    keras.layers.Conv2D(32, (5, 5), padding="same", activation='relu', input_shape=[28, 28, 1]),
    keras.layers.MaxPool2D((2,2)),
    keras.layers.Conv2D(64, (5, 5), padding="same", activation='relu'),
    keras.layers.MaxPool2D((2,2)),
    keras.layers.Flatten(),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
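It helps to work out the shapes and parameter counts of this architecture by hand. The sketch below (summarizeModel is a hypothetical name) applies the standard formulas: a Conv2D layer has kernel height × kernel width × input channels × filters weights plus one bias per filter, a Dense layer has inputs × units weights plus one bias per unit, and pooling, flattening and dropout add no parameters. The totals should match what model.summary() reports for this model.

```javascript
// Walks the Sequential model layer by layer, tracking output shapes and
// trainable parameter counts (weights + biases).
function summarizeModel() {
  const layers = [];
  let h = 28, w = 28, c = 1; // input_shape=[28, 28, 1]
  let units = 0;             // length of the 1-D activation after Flatten

  const conv2d = (filters, k) => {
    // "same" padding keeps h and w; each filter has k*k*c weights plus 1 bias
    const params = k * k * c * filters + filters;
    c = filters;
    layers.push({ name: "Conv2D", shape: [h, w, c], params });
  };
  const maxPool = () => {
    // a 2x2 pool with the default stride of 2 halves h and w
    h = Math.floor(h / 2);
    w = Math.floor(w / 2);
    layers.push({ name: "MaxPool2D", shape: [h, w, c], params: 0 });
  };
  const flatten = () => {
    units = h * w * c;
    layers.push({ name: "Flatten", shape: [units], params: 0 });
  };
  const dense = (out) => {
    // one weight per input/output pair plus one bias per output
    layers.push({ name: "Dense", shape: [out], params: units * out + out });
    units = out;
  };
  const dropout = () => layers.push({ name: "Dropout", shape: [units], params: 0 });

  conv2d(32, 5);
  maxPool();
  conv2d(64, 5);
  maxPool();
  flatten();
  dense(1024);
  dropout();
  dense(10);

  const total = layers.reduce((sum, layer) => sum + layer.params, 0);
  return { layers, total };
}

const summary = summarizeModel();
for (const layer of summary.layers) {
  console.log(layer.name, layer.shape, layer.params);
}
console.log("total params:", summary.total);      // total params: 3274634
console.log("float32 bytes:", summary.total * 4); // float32 bytes: 13098536
```

Almost all of the roughly 3.27 million parameters sit in the first Dense layer (7 × 7 × 64 = 3136 inputs feeding 1024 units). Assuming the weights are stored as 32-bit floats, which is what the converter writes when no quantization is requested, that is about 13 MB of weight files for the browser to download.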

Train the model

model.fit(train_img,train_label, validation_data=(test_img,test_label), epochs=10)
test_loss,test_acc = model.evaluate(test_img, test_label)
print('Test accuracy:', test_acc)

Training could take a couple of minutes, and we get a pretty good result of about 98.5% accuracy on the test set.

Save the model in TensorFlow.js format

Now that we have a trained model, we need to save it in a format that TensorFlow.js can load in the browser.

tfjs.converters.save_keras_model(model, 'models')

The model is saved into the ‘models’ folder, which contains a model.json file (the model topology plus a manifest of the weights) and one or more binary weight files.

# Step 2 : Include TensorFlow.js

Simply include the TensorFlow.js script in the <head> section of the HTML file. I also include the jQuery library and the Chart.js library, along with the Bootstrap stylesheet and the page’s own CSS file.

<html>
<head>
	<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" >    
	<link rel="stylesheet" type="text/css" href="style/digit.css">
	<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
	<script src="js/chart.min.js"></script>
	<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
</head>

# Step 3 : Set up canvas

To let the user draw a digit with a mouse on desktop or a finger on mobile devices, we need an HTML5 canvas element. The user draws the digit inside the canvas, and we feed that drawing into the neural network we trained earlier to make a prediction.

HTML – index.html

Add a placeholder <div> to contain the canvas that the digit is drawn on

<div id="canvas_box" class="canvas-box"></div>

Add a “Predict” button to get the handwritten digit prediction, and a “Clear” button to wipe the canvas and start drawing again

<button id="clear-button" class="btn btn-dark">Clear</button>
<button id="predict-button" class="btn btn-dark">Predict</button>

At the end of the <body>, include the main JavaScript file digit-recognition.js

		<script src="js/digit-recognition.js"></script>
	</body>
</html>

JavaScript – digit-recognition.js

Initialize the variables

let model;
var canvasWidth           = 150;
var canvasHeight          = 150;
var canvasStrokeStyle     = "white";
var canvasLineJoin        = "round";
var canvasLineWidth       = 10;
var canvasBackgroundColor = "black";
var canvasId              = "canvas";
var clickX = [];
var clickY = [];
var clickD = [];
var drawing;

Create the canvas and append it to the placeholder <div> so that it is displayed

var canvasBox = document.getElementById('canvas_box');
var canvas    = document.createElement("canvas");

canvas.setAttribute("width", canvasWidth);
canvas.setAttribute("height", canvasHeight);
canvas.setAttribute("id", canvasId);
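// the black background is CSS only: the canvas pixels stay transparent (0, 0, 0, 0),
// which read as black once the alpha channel is dropped during preprocessing
canvas.style.backgroundColor = canvasBackgroundColor;
canvasBox.appendChild(canvas);
// excanvas fallback for very old versions of Internet Explorer (used only if that library is loaded)
if(typeof G_vmlCanvasManager != 'undefined') {
  canvas = G_vmlCanvasManager.initElement(canvas);
}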

ctx = canvas.getContext("2d");
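Both the mouse and the touch handlers convert a pointer position from viewport coordinates (clientX/clientY) into canvas coordinates by subtracting the canvas’s bounding rectangle. Here is a minimal sketch of that conversion as a pure function. The name toCanvasCoords is hypothetical, and the extra scaling step is an assumption that is not in the original code: it only matters if CSS displays the canvas at a size other than its 150 × 150 attributes.

```javascript
// Hypothetical helper: converts a pointer position in viewport coordinates
// (event.clientX / event.clientY) into canvas pixel coordinates.
// rect has the shape returned by canvas.getBoundingClientRect().
function toCanvasCoords(clientX, clientY, rect, canvasWidth, canvasHeight) {
  // Assumption: rescale by attribute size / displayed size in case CSS stretches
  // the canvas. At the normal 150 x 150 display size both factors are 1, and this
  // reduces to the plain subtraction used in the handlers.
  const scaleX = canvasWidth / rect.width;
  const scaleY = canvasHeight / rect.height;
  return {
    x: (clientX - rect.left) * scaleX,
    y: (clientY - rect.top) * scaleY,
  };
}

// The same call works for a MouseEvent (e.clientX) and a Touch (e.touches[0].clientX).
const rect = { left: 100, top: 50, width: 150, height: 150 };
console.log(toCanvasCoords(175, 125, rect, 150, 150)); // { x: 75, y: 75 }
```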

Drawing inside a canvas is a little tricky because desktop and mobile browsers report pointer input through different events, so we need to handle both mouse and touch. The mouse events are bound with jQuery handlers, while the touch events use the native addEventListener, and each touch handler calls preventDefault() so the page does not scroll while the user draws. Below are the events we will be handling.

For desktops & laptops
  • mousedown
  • mousemove
  • mouseup
  • mouseleave
For mobile devices
  • touchstart
  • touchmove
  • touchend
  • touchcancel
//---------------------
// MOUSE DOWN function
//---------------------
$("#canvas").mousedown(function(e) {
	var rect = canvas.getBoundingClientRect();
	var mouseX = e.clientX - rect.left;
	var mouseY = e.clientY - rect.top;
	drawing = true;
	addUserGesture(mouseX, mouseY);
	drawOnCanvas();
});

//-----------------------
// TOUCH START function
//-----------------------
canvas.addEventListener("touchstart", function (e) {
	if (e.target == canvas) {
		e.preventDefault();
	}

	var rect = canvas.getBoundingClientRect();
	var touch = e.touches[0];

	var mouseX = touch.clientX - rect.left;
	var mouseY = touch.clientY - rect.top;

	drawing = true;
	addUserGesture(mouseX, mouseY);
	drawOnCanvas();

}, false);

//---------------------
// MOUSE MOVE function
//---------------------
$("#canvas").mousemove(function(e) {
	if(drawing) {
		var rect = canvas.getBoundingClientRect();
		var mouseX = e.clientX - rect.left;
		var mouseY = e.clientY - rect.top;
		addUserGesture(mouseX, mouseY, true);
		drawOnCanvas();
	}
});

//---------------------
// TOUCH MOVE function
//---------------------
canvas.addEventListener("touchmove", function (e) {
	if (e.target == canvas) {
		e.preventDefault();
	}
	if(drawing) {
		var rect = canvas.getBoundingClientRect();
		var touch = e.touches[0];

		var mouseX = touch.clientX - rect.left;
		var mouseY = touch.clientY - rect.top;

		addUserGesture(mouseX, mouseY, true);
		drawOnCanvas();
	}
}, false);

//-------------------
// MOUSE UP function
//-------------------
$("#canvas").mouseup(function(e) {
	drawing = false;
});

//---------------------
// TOUCH END function
//---------------------
canvas.addEventListener("touchend", function (e) {
	if (e.target == canvas) {
		e.preventDefault();
	}
	drawing = false;
}, false);

//----------------------
// MOUSE LEAVE function
//----------------------
$("#canvas").mouseleave(function(e) {
	drawing = false;
});

//------------------------
// TOUCH CANCEL function
//------------------------
// stop drawing if the browser cancels the touch
canvas.addEventListener("touchcancel", function (e) {
	if (e.target == canvas) {
		e.preventDefault();
	}
	drawing = false;
}, false);

//--------------------
// ADD CLICK function
//--------------------
function addUserGesture(x, y, dragging) {
	clickX.push(x);
	clickY.push(y);
	clickD.push(dragging);
}

//-------------------
// RE DRAW function
//-------------------
function drawOnCanvas() {
	ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

	ctx.strokeStyle = canvasStrokeStyle;
	ctx.lineJoin    = canvasLineJoin;
	ctx.lineWidth   = canvasLineWidth;

	for (var i = 0; i < clickX.length; i++) {
		ctx.beginPath();
		if(clickD[i] && i) {
			ctx.moveTo(clickX[i-1], clickY[i-1]);
		} else {
			ctx.moveTo(clickX[i]-1, clickY[i]);
		}
		ctx.lineTo(clickX[i], clickY[i]);
		ctx.closePath();
		ctx.stroke();
	}
}

//------------------------
// CLEAR CANVAS function
//------------------------
$("#clear-button").click(async function () {
    ctx.clearRect(0, 0, canvasWidth, canvasHeight);
	clickX = new Array();
	clickY = new Array();
	clickD = new Array();
	$(".prediction-text").empty();
	$("#result_box").addClass('d-none');
});
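To make the replay logic in drawOnCanvas() easier to follow, here is a sketch that returns the segments its loop would stroke instead of drawing them. The function name strokeSegments is hypothetical. A point recorded while dragging is joined to the previous point; any other point (the first point of a stroke) gets a 1-pixel segment starting just to its left, so a single click still leaves a visible dot.

```javascript
// Hypothetical helper: mirrors the loop in drawOnCanvas(), but returns the
// segments it would stroke instead of drawing them on a canvas.
function strokeSegments(clickX, clickY, clickD) {
  const segments = [];
  for (let i = 0; i < clickX.length; i++) {
    // A point recorded while dragging continues from the previous point;
    // the first point of a stroke starts 1px to its left so a tap leaves a dot.
    const from = clickD[i] && i
      ? [clickX[i - 1], clickY[i - 1]]
      : [clickX[i] - 1, clickY[i]];
    segments.push({ from, to: [clickX[i], clickY[i]] });
  }
  return segments;
}

// mousedown at (10, 10), then two mousemoves while the button is held,
// recorded the way addUserGesture() does (dragging is undefined on mousedown).
const segs = strokeSegments([10, 12, 15], [10, 11, 15], [undefined, true, true]);
for (const s of segs) {
  console.log(s.from, "->", s.to);
}
// [ 9, 10 ] -> [ 10, 10 ]
// [ 10, 10 ] -> [ 12, 11 ]
// [ 12, 11 ] -> [ 15, 15 ]
```

Because every recorded point is redrawn on each mouse or touch move, the cost of drawOnCanvas() grows with the length of the drawing, which is fine for a single digit.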

# Step 4 : Load the model and predict with TensorFlow.js

Now we use TensorFlow.js to load the model that we trained earlier in Python, and use it to predict the digit drawn on the canvas.

Load Model

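The function loadModel() calls the TensorFlow.js API tf.loadLayersModel, which fetches models/model.json and the weight files it lists. Because the files are fetched over HTTP, the page has to be served from a web server; opening index.html directly as a file:// URL fails.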

//-------------------------------------
// loader for cnn model
//-------------------------------------
async function loadModel() {
  // clear the model variable
  model = undefined; 
  // load the model using a HTTPS request (where you have stored your model files)
  model = await tf.loadLayersModel("models/model.json");
}

loadModel();
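tf.loadLayersModel downloads model.json and the weight files asynchronously, so model stays undefined until the download finishes and a Predict click during that window has no model to use. One way to handle this, sketched below under that assumption, is to cache the promise returned by the loader so every caller awaits the same download. createModelLoader and getModel are hypothetical names, and a fake loader stands in for tf.loadLayersModel so the sketch runs on its own.

```javascript
// Hypothetical pattern: remember the promise returned by the loader so every
// caller awaits the same download. loadFn stands in for
// () => tf.loadLayersModel("models/model.json").
function createModelLoader(loadFn) {
  let pending = null;
  return function getModel() {
    if (pending === null) {
      pending = loadFn().catch((err) => {
        pending = null; // allow a retry after a failed download
        throw err;
      });
    }
    return pending;
  };
}

// Usage with a fake loader in place of TensorFlow.js.
let loadCalls = 0;
const getModel = createModelLoader(() => {
  loadCalls += 1;
  return Promise.resolve({ name: "fake model" });
});

getModel(); // e.g. started when the page loads
getModel().then((m) => console.log(m.name, loadCalls)); // fake model 1
```

In the click handler this would become `const model = await getModel();` before calling model.predict(...), so an early click waits for the download to finish rather than finding model still undefined.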

Pre-process canvas

The function preprocessCanvas pre-processes the user’s drawing before it is fed to the CNN model, turning the canvas into a tensor with the same shape and scaling as the MNIST training images.

//-----------------------------------------------
// preprocess the canvas
//-----------------------------------------------
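function preprocessCanvas(image) {
	// convert the canvas into a [1, 28, 28, 1] tensor that matches the model input
	let tensor = tf.browser.fromPixels(image)  // [height, width, 3] RGB pixel values
		.resizeNearestNeighbor([28, 28])       // [28, 28, 3]
		.mean(2)                               // average R, G, B -> [28, 28] grayscale
		.expandDims(2)                         // [28, 28, 1] add the channel axis
		.expandDims()                          // [1, 28, 28, 1] add the batch axis
		.cast("float32");
	// scale to [0, 1], the same scaling used on the training images
	return tensor.div(255.0);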
}
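To see what the tensor chain in preprocessCanvas() does to the pixels, here is a plain-JavaScript sketch that works on the RGBA byte array returned by ctx.getImageData(). The function preprocessPixels is hypothetical, and its nearest-neighbour index rule is an assumption meant to approximate tf.image.resizeNearestNeighbor with default options rather than reproduce it bit for bit. Untouched canvas pixels are transparent black (0, 0, 0, 0), so with the alpha channel ignored they become 0, matching the black background of the MNIST images, while the white strokes become 1.

```javascript
// Hypothetical plain-JavaScript version of preprocessCanvas(), working on the
// RGBA byte array that ctx.getImageData(0, 0, width, height).data returns.
// Returns a Float32Array of size*size values in [0, 1], row by row.
function preprocessPixels(rgba, width, height, size = 28) {
  const out = new Float32Array(size * size);
  for (let r = 0; r < size; r++) {
    // nearest-neighbour sampling: map the output row back to a source row
    const srcRow = Math.min(height - 1, Math.floor((r * height) / size));
    for (let c = 0; c < size; c++) {
      const srcCol = Math.min(width - 1, Math.floor((c * width) / size));
      const i = (srcRow * width + srcCol) * 4;
      // mean of R, G, B; alpha is ignored, like fromPixels with 3 channels
      const gray = (rgba[i] + rgba[i + 1] + rgba[i + 2]) / 3;
      out[r * size + c] = gray / 255;
    }
  }
  return out;
}

// Usage: a 4 x 4 transparent canvas with a white 2 x 2 block in the top-left,
// reduced to 2 x 2.
const demoWidth = 4;
const demoHeight = 4;
const rgba = new Uint8ClampedArray(demoWidth * demoHeight * 4); // all zeros
for (const [x, y] of [[0, 0], [1, 0], [0, 1], [1, 1]]) {
  rgba.set([255, 255, 255, 255], (y * demoWidth + x) * 4);
}
console.log(Array.from(preprocessPixels(rgba, demoWidth, demoHeight, 2))); // [ 1, 0, 0, 0 ]
```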

Prediction

When the “Predict” button is clicked, we pre-process the canvas into a tensor, then feed it into model.predict to get the prediction result.

//--------------------------------------------
// predict function 
//--------------------------------------------
$("#predict-button").click(async function () {
	// the model is loaded asynchronously, so ignore clicks until it is ready
	if (!model) {
		return;
	}

	// preprocess the canvas and run the model inside tf.tidy(),
	// which frees the intermediate tensors and keeps only the returned one
	let output = tf.tidy(function () {
		return model.predict(preprocessCanvas(canvas));
	});

	// read the 10 class probabilities, then free the output tensor
	let results = Array.from(await output.data());
	output.dispose();

	// display the predictions in chart
	$("#result_box").removeClass('d-none');
	displayChart(results);
	displayLabel(results);
});

Display result

The function loadChart uses the Chart.js library to display the prediction result as a bar chart, and displayLabel shows the most likely digit with its confidence.

//------------------------------
// Chart to display predictions
//------------------------------
var chart = "";
var firstTime = 0;
function loadChart(label, data, modelSelected) {
	var ctx = document.getElementById('chart_box').getContext('2d');
	chart = new Chart(ctx, {
	    // The type of chart we want to create
	    type: 'bar',

	    // The data for our dataset
	    data: {
	        labels: label,
	        datasets: [{
	            label: modelSelected + " prediction",
	            backgroundColor: '#f50057',
	            borderColor: 'rgb(255, 99, 132)',
	            data: data,
	        }]
	    },

	    // Configuration options go here
	    options: {}
	});
}

//----------------------------
// display chart with updated
// drawing from canvas
//----------------------------
function displayChart(data) {
	var select_option = "CNN";
	var label = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"];

	if (firstTime == 0) {
		loadChart(label, data, select_option);
		firstTime = 1;
	} else {
		// destroy the previous chart before drawing a new one on the same canvas
		chart.destroy();
		loadChart(label, data, select_option);
	}
	document.getElementById('chart_box').style.display = "block";
}

//----------------------------
// show the most likely digit
//----------------------------
function displayLabel(data) {
	var max = data[0];
	var maxIndex = 0;

	for (var i = 1; i < data.length; i++) {
		if (data[i] > max) {
			maxIndex = i;
			max = data[i];
		}
	}
	$(".prediction-text").html("Predicting you drew <b>" + maxIndex + "</b> with <b>" + Math.trunc(max * 100) + "%</b> confidence");
}
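The label logic in displayLabel() is a plain argmax over the 10 softmax probabilities. Pulled out of the DOM code, it can be sketched as a pure function that is easy to test; topPrediction is a hypothetical name, and the probabilities in the usage example are illustrative numbers, not real model output.

```javascript
// Hypothetical helper with the same logic as displayLabel(), minus the DOM:
// find the most probable digit and its confidence as a whole percentage.
function topPrediction(probabilities) {
  let maxIndex = 0;
  for (let i = 1; i < probabilities.length; i++) {
    // strict ">" keeps the lower digit on ties, as in displayLabel()
    if (probabilities[i] > probabilities[maxIndex]) {
      maxIndex = i;
    }
  }
  // Math.trunc drops the fractional part, so 0.996 is reported as 99%
  const percent = Math.trunc(probabilities[maxIndex] * 100);
  return { digit: maxIndex, percent };
}

// Illustrative softmax output for a drawing the model reads as a 5.
const probs = [0.001, 0.0, 0.002, 0.001, 0.0, 0.991, 0.003, 0.0, 0.002, 0.0];
console.log(topPrediction(probs)); // { digit: 5, percent: 99 }
```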

Finally, testing

handwritten digit recognition demo

That’s it! As you can see, the result is pretty good: the model predicts that I have drawn the digit 5 with 99% confidence.


There is no doubt that a great deal of research has been conducted in the field of handwriting recognition, and TensorFlow.js makes trained deep learning models like this one accessible in the browser. I hope you had fun with this article, and I encourage you to explore the library further. I can’t wait to see the creative ideas people come up with to put this cutting-edge technology to practical use.

Thank you for reading. Let me know in the comments if you have any questions.

15 Comments

  • Mark on Nov 4, 2019

    Very helpful, well explained, easy to follow and understood.
    Owner is very pro active and reply to any questions immediately.

  • RAJAT on Apr 18, 2020

    SIr, project is very simply understandable. Thanks for this project.

    • BHAVESH KUMAR on Apr 11, 2023

      hey bro my project is not running what should i do

  • Rodrigo Faustino (@RodrigoFautino) on May 7, 2020

    congratulations on the work, amazing, I tested your project, what I would like to know is how to train it could not be replaced to save in json mode or better in a database to compare later, my idea is to do a signature authentication, a screen of register where the user registers his signature as if it were a username and password, and on the login page he only signs to enter

    • benson_ruan on May 7, 2020

      Hi Rodrigo, from what you describe, you are looking for image similarity comparison: measuring how similar two signatures are. I found this GitHub repo that might be useful for you: https://github.com/ftlabs/Hancock

  • Sarita on May 11, 2020

    will you please tell me
    which tensorflowjs api should i use as when i was using tensorflowjs latest api i getting so many exection and when i used tensorflowjs api(1.2.6) i am getting ConnectionResetError.

  • Yeo Keat on Nov 26, 2020

    Hi Sir, thank you for sharing such a great tutorial for us! I keen to try out this application and run on local server, but I got this issue and I have google the solution tried several ways, the problem still unable to solve. Therefore, I hope to get advise from you on how to solve.

    The problem that I encountered as follows:

    Fetch API cannot load file:///D:/Hand-Written-Digit-Recognition-master/models/model.json. URL scheme must be “http” or “https” for CORS request.

    • benson_ruan on Dec 6, 2020

      Hi, after checking out the code from the GitHub repo, you can’t just open the HTML file. You need to set up a local web server that points to the root directory of the code, and access the page through http://localhost/index.html

  • Nikhil on Feb 9, 2022

    How to set up your localhost to point to the root directory of the code, and access through http://localhost/index.html?

    • benson_ruan on Feb 28, 2023

      1. Install a web server such as Apache or Nginx on your local machine.
      2. Configure your web server to point to the root directory of your code.
      3. Start the web server.
      4. Open a web browser and go to http://localhost/index.html.

  • Antara on Jul 7, 2023

    what library do we use if we wanna detect alphabets along with numbers
