{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Converting a TensorFlow.js Speech-Commands Model to Python and TFLite formats\n",
"\n",
"This notebook showcases how to convert a [TensorFlow.js (TF.js) Speech Commands model](https://www.npmjs.com/package/@tensorflow-models/speech-commands) to the Python (`tensorflow.keras`) and [TFLite](https://www.tensorflow.org/lite) formats. The TFLite format enables the model to be deployed to mobile enviroments such as Android phones.\n",
"\n",
"The technique outlined in this notebook are applicable to:\n",
"- the original Speech Commands models (including the 18w and directional4w) variants,\n",
"- transfer-learned models based on the original models, which can be trained and exported from [Teachable Machine's Audio Project](https://teachablemachine.withgoogle.com/train/audio)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, install the required `tensorflow` and `tensorflowjs` Python packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# We need scipy for .wav file IO.\n",
"!pip install tensorflowjs==2.1.0 scipy==1.4.1\n",
"# TensorFlow 2.3.0 is required due to https://github.com/tensorflow/tensorflow/issues/38135\n",
"# TODO: Switch to 2.3.0 final release when it comes out.\n",
"!pip install tensorflow-cpu==2.3.0"
]
},
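{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Optionally, verify that the pinned versions were installed before proceeding. This is a quick sanity check, not part of the conversion itself."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Optional sanity check: confirm the installed package versions.\n",
  "import tensorflow as tf\n",
  "import tensorflowjs as tfjs\n",
  "print('tensorflow version:', tf.__version__)\n",
  "print('tensorflowjs version:', tfjs.__version__)"
 ]
},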
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below we download the files of the original or transfer-learned TF.js Speech Commands model. \n",
"The code example here downloads the original model. But the approach is the same for a transfer-learned model downloaded from Teachable Machine, except that the files may come in as a ZIP archive in the case of Teachable Machine and hence requires unzippping."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p /tmp/tfjs-sc-model\n",
"!curl -o /tmp/tfjs-sc-model/metadata.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/metadata.json\n",
"!curl -o /tmp/tfjs-sc-model/model.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/model.json\n",
"!curl -o /tmp/tfjs-sc-model/group1-shard1of2 -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard1of2\n",
"!curl -o /tmp/tfjs-sc-model/group1-shard2of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard2of2"
]
},
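{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "If you are instead working with a transfer-learned model exported from Teachable Machine, unzip the downloaded archive first. The snippet below is a minimal sketch; the archive path `/tmp/tfjs-sc-model/teachable_machine_model.zip` is a hypothetical placeholder for whatever file you actually downloaded."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Optional: only needed for a Teachable Machine export, which arrives\n",
  "# as a ZIP archive instead of loose files.\n",
  "import zipfile\n",
  "\n",
  "# Hypothetical path; substitute the actual archive you downloaded.\n",
  "tm_zip_path = '/tmp/tfjs-sc-model/teachable_machine_model.zip'\n",
  "with zipfile.ZipFile(tm_zip_path, 'r') as zf:\n",
  "    zf.extractall('/tmp/tfjs-sc-model')"
 ]
},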
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"import tensorflow as tf\n",
"import tensorflowjs as tfjs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Specify the path to the TensorFlow.js Speech Commands model,\n",
"# either original or transfer-learned on https://teachablemachine.withgoogle.com/)\n",
"tfjs_model_json_path = '/tmp/tfjs-sc-model/model.json'\n",
"\n",
"# This is the main classifier model.\n",
"model = tfjs.converters.load_keras_model(tfjs_model_json_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a required step, we download the audio preprocessing layer that replicates\n",
"[WebAudio](https://developer.mozilla.org/en-US/docs/Web/API/Web_Audio_API)'s\n",
"[Fourier transform](https://en.wikipedia.org/wiki/Fast_Fourier_transform) for\n",
"non-browser environments such as Android phones."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"./sc_preproc_model/\r\n",
"./sc_preproc_model/assets/\r\n",
"./sc_preproc_model/variables/\r\n",
"./sc_preproc_model/variables/variables.data-00000-of-00001\r\n",
"./sc_preproc_model/variables/variables.index\r\n",
"./sc_preproc_model/saved_model.pb\r\n"
]
}
],
"source": [
"!curl -o /tmp/tfjs-sc-model/sc_preproc_model.tar.gz -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/sc_preproc_model.tar.gz\n",
"!cd /tmp/tfjs-sc-model && tar xzvf ./sc_preproc_model.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n",
"Model: \"audio_preproc\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"audio_preprocessing_layer (A (None, None, None, 1) 2048 \n",
"=================================================================\n",
"Total params: 2,048\n",
"Trainable params: 0\n",
"Non-trainable params: 2,048\n",
"_________________________________________________________________\n",
"Input audio length = 44032\n"
]
}
],
"source": [
"# Load the preprocessing layer (wrapped in a tf.keras Model).\n",
"preproc_model_path = '/tmp/tfjs-sc-model/sc_preproc_model'\n",
"preproc_model = tf.keras.models.load_model(preproc_model_path)\n",
"preproc_model.summary()\n",
"\n",
"# From the input_shape of the preproc_model, we can determine the\n",
"# required length of the input audio snippet.\n",
"input_length = preproc_model.input_shape[-1]\n",
"print(\"Input audio length = %d\" % input_length)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"combined_model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"audio_preproc (Sequential) (None, None, None, 1) 2048 \n",
"_________________________________________________________________\n",
"sequential (Sequential) (None, 20) 1468684 \n",
"=================================================================\n",
"Total params: 1,470,732\n",
"Trainable params: 1,468,684\n",
"Non-trainable params: 2,048\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Construct the new non-browser model by combining the preprocessing\n",
"# layer with the main classifier model.\n",
"\n",
"combined_model = tf.keras.Sequential(name='combined_model')\n",
"combined_model.add(preproc_model)\n",
"combined_model.add(model)\n",
"combined_model.build([None, input_length])\n",
"combined_model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to quickly test that the converted model works, let's download a sample .wav file."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"!curl -o /tmp/tfjs-sc-model/audio_sample_one_male_adult.wav -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/audio_sample_one_male_adult.wav"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Listen to the audio sample.\n",
"wav_file_path = '/tmp/tfjs-sc-model/audio_sample_one_male_adult.wav'\n",
"import IPython.display as ipd\n",
"ipd.Audio(wav_file_path) # Play the .wav file."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Read the wav file and truncate it to the an input length\n",
"# suitable for the model.\n",
"from scipy.io import wavfile\n",
"\n",
"# fs: sample rate in Hz; xs: the audio PCM samples.\n",
"fs, xs = wavfile.read(wav_file_path)\n",
"\n",
"if len(xs) >= input_length:\n",
" xs = xs[:input_length]\n",
"else:\n",
" raise ValueError(\"Audio from .wav file is too short\")"
]
},
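{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "If your own .wav file is shorter than `input_length`, zero-padding is a possible alternative to raising an error. This is a sketch under the assumption of a mono (1-D) `xs` array, matching the sample file used here."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Alternative to raising an error above: pad short clips with zeros\n",
  "# (i.e., silence) up to the required input length. Assumes mono audio.\n",
  "import numpy as np\n",
  "\n",
  "if len(xs) < input_length:\n",
  "    xs = np.pad(xs, (0, input_length - len(xs)))"
 ]
},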
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"top-5 class probabilities:\n",
" one: 1.0000e+00\n",
" nine: 5.0455e-19\n",
" _unknown_: 1.0553e-20\n",
" down: 4.0031e-26\n",
" no: 3.8358e-26\n"
]
}
],
"source": [
"# Try running some examples through the combined model.\n",
"input_tensor = tf.constant(xs, shape=(1, input_length), dtype=tf.float32) / 32768.0\n",
"# The model outputs the probabilties for the classes (`probs`).\n",
"probs = combined_model.predict(input_tensor)\n",
"\n",
"# Read class labels of the model.\n",
"metadata_json_path = '/tmp/tfjs-sc-model/metadata.json'\n",
"\n",
"with open(metadata_json_path, 'r') as f:\n",
" metadata = json.load(f)\n",
" class_labels = metadata[\"words\"]\n",
"\n",
"# Get sorted probabilities and their corresponding class labels.\n",
"probs_and_labels = list(zip(probs[0].tolist(), class_labels))\n",
"# Sort the probabilities in descending order.\n",
"probs_and_labels = sorted(probs_and_labels, key=lambda x: -x[0])\n",
"probs_and_labels\n",
"# len(probs_and_labels)\n",
"\n",
"# Print the top-5 labels:\n",
"print('top-5 class probabilities:')\n",
"for i in range(5):\n",
" prob, label = probs_and_labels[i]\n",
" print('%20s: %.4e' % (label, prob))"
]
},
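{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Since `combined_model` is a regular `tf.keras` model, it can also be saved in the Python-side (SavedModel) format mentioned at the top of this notebook. A minimal sketch follows; the output path is arbitrary."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Save the combined model in the TensorFlow SavedModel format, so it can\n",
  "# be reloaded in Python later with tf.keras.models.load_model().\n",
  "saved_model_path = '/tmp/tfjs-sc-model/combined_model_savedmodel'\n",
  "combined_model.save(saved_model_path)"
 ]
},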
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"INFO:tensorflow:Assets written to: /tmp/tmplb12fskv/assets\n",
"Saved tflite file at: /tmp/tfjs-sc-model/combined_model.tflite\n"
]
}
],
"source": [
"# Save the model as a tflite file.\n",
"tflite_output_path = '/tmp/tfjs-sc-model/combined_model.tflite'\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)\n",
"converter.target_spec.supported_ops = [\n",
" tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS\n",
"]\n",
"with open(tflite_output_path, 'wb') as f:\n",
" f.write(converter.convert())\n",
"print(\"Saved tflite file at: %s\" % tflite_output_path)"
]
}
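,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a final sanity check, you can run the same audio snippet through the saved .tflite file with the TFLite interpreter and compare the result against the Keras model's output. This is a sketch; because the conversion above enabled `SELECT_TF_OPS`, the interpreter must include Flex-op support (true for the interpreter bundled with the standard `tensorflow` pip package)."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Run the same input through the TFLite interpreter and compare the\n",
  "# class probabilities against the Keras model's output.\n",
  "interpreter = tf.lite.Interpreter(model_path=tflite_output_path)\n",
  "interpreter.allocate_tensors()\n",
  "input_details = interpreter.get_input_details()\n",
  "output_details = interpreter.get_output_details()\n",
  "\n",
  "interpreter.set_tensor(input_details[0]['index'], input_tensor.numpy())\n",
  "interpreter.invoke()\n",
  "tflite_probs = interpreter.get_tensor(output_details[0]['index'])\n",
  "\n",
  "print('Max absolute difference vs. Keras output: %g' %\n",
  "      abs(tflite_probs - probs).max())"
 ]
}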
],
"metadata": {
"colab": {
"name": "tflite_conversion.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}