{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Converting a TensorFlow.js Speech-Commands Model to Python and TFLite formats\n", "\n", "This notebook showcases how to convert a [TensorFlow.js (TF.js) Speech Commands model](https://www.npmjs.com/package/@tensorflow-models/speech-commands) to the Python (`tensorflow.keras`) and [TFLite](https://www.tensorflow.org/lite) formats. The TFLite format enables the model to be deployed to mobile enviroments such as Android phones.\n", "\n", "The technique outlined in this notebook are applicable to:\n", "- the original Speech Commands models (including the 18w and directional4w) variants,\n", "- transfer-learned models based on the original models, which can be trained and exported from [Teachable Machine's Audio Project](https://teachablemachine.withgoogle.com/train/audio)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, install the required `tensorflow` and `tensorflowjs` Python packages." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# We need scipy for .wav file IO.\n", "!pip install tensorflowjs==2.1.0 scipy==1.4.1\n", "# TensorFlow 2.3.0 is required due to https://github.com/tensorflow/tensorflow/issues/38135\n", "# TODO: Switch to 2.3.0 final release when it comes out.\n", "!pip install tensorflow-cpu==2.3.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we download the files of the original or transfer-learned TF.js Speech Commands model. \n", "The code example here downloads the original model. But the approach is the same for a transfer-learned model downloaded from Teachable Machine, except that the files may come in as a ZIP archive in the case of Teachable Machine and hence requires unzippping." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!mkdir -p /tmp/tfjs-sc-model\n", "!curl -o /tmp/tfjs-sc-model/metadata.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/metadata.json\n", "!curl -o /tmp/tfjs-sc-model/model.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/model.json\n", "!curl -o /tmp/tfjs-sc-model/group1-shard1of2 -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard1of2\n", "!curl -o /tmp/tfjs-sc-model/group1-shard2of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard2of2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "import tensorflow as tf\n", "import tensorflowjs as tfjs" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Specify the path to the TensorFlow.js Speech Commands model,\n", "# either original or transfer-learned on https://teachablemachine.withgoogle.com/)\n", "tfjs_model_json_path = '/tmp/tfjs-sc-model/model.json'\n", "\n", "# This is the main classifier model.\n", "model = tfjs.converters.load_keras_model(tfjs_model_json_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a required step, we download the audio preprocessing layer that replicates\n", "[WebAudio](https://developer.mozilla.org/en-US/docs/Web/API/Web_Audio_API)'s\n", "[Fourier transform](https://en.wikipedia.org/wiki/Fast_Fourier_transform) for\n", "non-browser environments such as Android phones." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./sc_preproc_model/\r\n", "./sc_preproc_model/assets/\r\n", "./sc_preproc_model/variables/\r\n", "./sc_preproc_model/variables/variables.data-00000-of-00001\r\n", "./sc_preproc_model/variables/variables.index\r\n", "./sc_preproc_model/saved_model.pb\r\n" ] } ], "source": [ "!curl -o /tmp/tfjs-sc-model/sc_preproc_model.tar.gz -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/sc_preproc_model.tar.gz\n", "!cd /tmp/tfjs-sc-model && tar xzvf ./sc_preproc_model.tar.gz" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n", "Model: \"audio_preproc\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "audio_preprocessing_layer (A (None, None, None, 1) 2048 \n", "=================================================================\n", "Total params: 2,048\n", "Trainable params: 0\n", "Non-trainable params: 2,048\n", "_________________________________________________________________\n", "Input audio length = 44032\n" ] } ], "source": [ "# Load the preprocessing layer (wrapped in a tf.keras Model).\n", "preproc_model_path = '/tmp/tfjs-sc-model/sc_preproc_model'\n", "preproc_model = tf.keras.models.load_model(preproc_model_path)\n", "preproc_model.summary()\n", "\n", "# From the input_shape of the preproc_model, we can determine the\n", "# required length of the input audio snippet.\n", "input_length = preproc_model.input_shape[-1]\n", "print(\"Input audio length = %d\" % input_length)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"combined_model\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "audio_preproc (Sequential) (None, None, None, 1) 2048 \n", "_________________________________________________________________\n", "sequential (Sequential) (None, 20) 1468684 \n", "=================================================================\n", "Total params: 1,470,732\n", "Trainable params: 1,468,684\n", "Non-trainable params: 2,048\n", "_________________________________________________________________\n" ] } ], "source": [ "# Construct the new non-browser model by combining the preprocessing\n", "# layer with the main classifier model.\n", "\n", "combined_model = tf.keras.Sequential(name='combined_model')\n", "combined_model.add(preproc_model)\n", "combined_model.add(model)\n", "combined_model.build([None, input_length])\n", "combined_model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to quickly test that the converted model works, let's download a sample .wav file." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "!curl -o /tmp/tfjs-sc-model/audio_sample_one_male_adult.wav -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/audio_sample_one_male_adult.wav" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Listen to the audio sample.\n", "wav_file_path = '/tmp/tfjs-sc-model/audio_sample_one_male_adult.wav'\n", "import IPython.display as ipd\n", "ipd.Audio(wav_file_path) # Play the .wav file." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Read the wav file and truncate it to the an input length\n", "# suitable for the model.\n", "from scipy.io import wavfile\n", "\n", "# fs: sample rate in Hz; xs: the audio PCM samples.\n", "fs, xs = wavfile.read(wav_file_path)\n", "\n", "if len(xs) >= input_length:\n", " xs = xs[:input_length]\n", "else:\n", " raise ValueError(\"Audio from .wav file is too short\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "top-5 class probabilities:\n", " one: 1.0000e+00\n", " nine: 5.0455e-19\n", " _unknown_: 1.0553e-20\n", " down: 4.0031e-26\n", " no: 3.8358e-26\n" ] } ], "source": [ "# Try running some examples through the combined model.\n", "input_tensor = tf.constant(xs, shape=(1, input_length), dtype=tf.float32) / 32768.0\n", "# The model outputs the probabilties for the classes (`probs`).\n", "probs = combined_model.predict(input_tensor)\n", "\n", "# Read class labels of the model.\n", "metadata_json_path = '/tmp/tfjs-sc-model/metadata.json'\n", "\n", "with open(metadata_json_path, 'r') as f:\n", " metadata = json.load(f)\n", " class_labels = metadata[\"words\"]\n", "\n", "# Get sorted probabilities and their corresponding class labels.\n", "probs_and_labels = list(zip(probs[0].tolist(), class_labels))\n", "# Sort the probabilities in descending order.\n", "probs_and_labels = sorted(probs_and_labels, key=lambda x: -x[0])\n", "probs_and_labels\n", "# len(probs_and_labels)\n", "\n", "# Print the top-5 labels:\n", "print('top-5 class probabilities:')\n", "for i in range(5):\n", " prob, label = probs_and_labels[i]\n", " print('%20s: %.4e' % (label, prob))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", "WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", "INFO:tensorflow:Assets written to: /tmp/tmplb12fskv/assets\n", "Saved tflite file at: /tmp/tfjs-sc-model/combined_model.tflite\n" ] } ], "source": [ "# Save the model as a tflite file.\n", "tflite_output_path = '/tmp/tfjs-sc-model/combined_model.tflite'\n", "converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)\n", "converter.target_spec.supported_ops = [\n", " tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS\n", "]\n", "with open(tflite_output_path, 'wb') as f:\n", " f.write(converter.convert())\n", "print(\"Saved tflite file at: %s\" % tflite_output_path)" ] } ], "metadata": { "colab": { "name": "tflite_conversion.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }