{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Converting a TensorFlow.js Speech-Commands Model to Python and TFLite formats\n",
        "\n",
        "This notebook shows how to convert a [TensorFlow.js (TF.js) Speech Commands model](https://www.npmjs.com/package/@tensorflow-models/speech-commands) to the Python (`tensorflow.keras`) and [TFLite](https://www.tensorflow.org/lite) formats. The TFLite format enables the model to be deployed to mobile environments such as Android phones.\n",
        "\n",
        "The techniques outlined in this notebook are applicable to:\n",
        "- the original Speech Commands models (including the 18w and directional4w variants),\n",
        "- transfer-learned models based on the original models, which can be trained and exported from [Teachable Machine's Audio Project](https://teachablemachine.withgoogle.com/train/audio)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, install the required `tensorflow` and `tensorflowjs` Python packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# We need scipy for .wav file IO.\n",
        "!pip install tensorflowjs==2.1.0 scipy==1.4.1\n",
        "# TensorFlow 2.3.0 is required due to https://github.com/tensorflow/tensorflow/issues/38135\n",
        "!pip install tensorflow-cpu==2.3.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Below we download the files of the original or transfer-learned TF.js Speech Commands model.\n",
        "The code example here downloads the original model. The approach is the same for a transfer-learned model downloaded from Teachable Machine, except that the files may come as a ZIP archive and hence need to be unzipped first."
      ]
    },
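    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a model exported from Teachable Machine, the unzipping step mentioned above might look like the following minimal sketch. The helper name `extract_model_archive` and the archive path in the usage note are hypothetical. The sketch also assumes that the archive holds the model files (such as `model.json` and `metadata.json`) at its top level, which is worth checking for your own export.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import zipfile\n",
        "\n",
        "\n",
        "def extract_model_archive(zip_path, dest_dir):\n",
        "    # Hypothetical helper: unpack a model ZIP archive into dest_dir\n",
        "    # and return the sorted names of the archive members.\n",
        "    os.makedirs(dest_dir, exist_ok=True)\n",
        "    with zipfile.ZipFile(zip_path) as archive:\n",
        "        archive.extractall(dest_dir)\n",
        "        return sorted(archive.namelist())\n",
        "```\n",
        "\n",
        "A call such as `extract_model_archive('/tmp/my_tm_model.zip', '/tmp/tfjs-sc-model')` (with a hypothetical archive path) would then place the files where the later cells of this notebook look for them."
      ]
    },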
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "!mkdir -p /tmp/tfjs-sc-model\n",
        "!curl -o /tmp/tfjs-sc-model/metadata.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/metadata.json\n",
        "!curl -o /tmp/tfjs-sc-model/model.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/model.json\n",
        "!curl -o /tmp/tfjs-sc-model/group1-shard1of2 -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard1of2\n",
        "!curl -o /tmp/tfjs-sc-model/group1-shard2of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard2of2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflowjs as tfjs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Specify the path to the TensorFlow.js Speech Commands model\n",
        "# (either original or transfer-learned on https://teachablemachine.withgoogle.com/).\n",
        "tfjs_model_json_path = '/tmp/tfjs-sc-model/model.json'\n",
        "\n",
        "# This is the main classifier model.\n",
        "model = tfjs.converters.load_keras_model(tfjs_model_json_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a required step, we download the audio preprocessing layer that replicates\n",
        "[WebAudio](https://developer.mozilla.org/en-US/docs/Web/API/Web_Audio_API)'s\n",
        "[Fourier transform](https://en.wikipedia.org/wiki/Fast_Fourier_transform) for\n",
        "non-browser environments such as Android phones."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "./sc_preproc_model/\r\n",
            "./sc_preproc_model/assets/\r\n",
            "./sc_preproc_model/variables/\r\n",
            "./sc_preproc_model/variables/variables.data-00000-of-00001\r\n",
            "./sc_preproc_model/variables/variables.index\r\n",
            "./sc_preproc_model/saved_model.pb\r\n"
          ]
        }
      ],
      "source": [
        "!curl -o /tmp/tfjs-sc-model/sc_preproc_model.tar.gz -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/sc_preproc_model.tar.gz\n",
        "!cd /tmp/tfjs-sc-model && tar xzvf ./sc_preproc_model.tar.gz"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n",
            "Model: \"audio_preproc\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "audio_preprocessing_layer (A (None, None, None, 1)     2048      \n",
            "=================================================================\n",
            "Total params: 2,048\n",
            "Trainable params: 0\n",
            "Non-trainable params: 2,048\n",
            "_________________________________________________________________\n",
            "Input audio length = 44032\n"
          ]
        }
      ],
      "source": [
        "# Load the preprocessing layer (wrapped in a tf.keras Model).\n",
        "preproc_model_path = '/tmp/tfjs-sc-model/sc_preproc_model'\n",
        "preproc_model = tf.keras.models.load_model(preproc_model_path)\n",
        "preproc_model.summary()\n",
        "\n",
        "# From the input_shape of the preproc_model, we can determine the\n",
        "# required length of the input audio snippet.\n",
        "input_length = preproc_model.input_shape[-1]\n",
        "print(\"Input audio length = %d\" % input_length)"
      ]
    },
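    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The input length printed above can be related to a clip duration. As a rough sketch, assuming the preprocessing layer consumes non-overlapping frames of 1024 samples and the audio is sampled at 44100 Hz (both values are assumptions made here for illustration, not values read from the model files), the arithmetic works out as follows:\n",
        "\n",
        "```python\n",
        "# Assumed values; neither is read from the model files.\n",
        "ASSUMED_FRAME_SIZE = 1024\n",
        "ASSUMED_SAMPLE_RATE_HZ = 44100\n",
        "\n",
        "\n",
        "def describe_input_length(input_length):\n",
        "    # Return (whole frames, leftover samples, duration in seconds).\n",
        "    num_frames, remainder = divmod(input_length, ASSUMED_FRAME_SIZE)\n",
        "    duration_s = input_length / ASSUMED_SAMPLE_RATE_HZ\n",
        "    return num_frames, remainder, duration_s\n",
        "\n",
        "\n",
        "num_frames, remainder, duration_s = describe_input_length(44032)\n",
        "print('frames=%d remainder=%d duration=%.3f s' % (num_frames, remainder, duration_s))\n",
        "```\n",
        "\n",
        "Under these assumptions, 44032 samples correspond to 43 whole frames and just under one second of audio. For your own recordings, compare the sample rate reported by the .wav reader with the assumed one."
      ]
    },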
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "scrolled": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"combined_model\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "audio_preproc (Sequential)   (None, None, None, 1)     2048      \n",
            "_________________________________________________________________\n",
            "sequential (Sequential)      (None, 20)                1468684   \n",
            "=================================================================\n",
            "Total params: 1,470,732\n",
            "Trainable params: 1,468,684\n",
            "Non-trainable params: 2,048\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "# Construct the new non-browser model by combining the preprocessing\n",
        "# layer with the main classifier model.\n",
        "\n",
        "combined_model = tf.keras.Sequential(name='combined_model')\n",
        "combined_model.add(preproc_model)\n",
        "combined_model.add(model)\n",
        "combined_model.build([None, input_length])\n",
        "combined_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In order to quickly test that the converted model works, let's download a sample .wav file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "!curl -o /tmp/tfjs-sc-model/audio_sample_one_male_adult.wav -fSsL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/audio_sample_one_male_adult.wav"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "                \n",
              "              "
            ],
            "text/plain": [
              ""
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Listen to the audio sample.\n",
        "wav_file_path = '/tmp/tfjs-sc-model/audio_sample_one_male_adult.wav'\n",
        "import IPython.display as ipd\n",
        "ipd.Audio(wav_file_path)  # Play the .wav file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Read the wav file and truncate it to an input length\n",
        "# suitable for the model.\n",
        "from scipy.io import wavfile\n",
        "\n",
        "# fs: sample rate in Hz; xs: the audio PCM samples.\n",
        "fs, xs = wavfile.read(wav_file_path)\n",
        "\n",
        "if len(xs) >= input_length:\n",
        "    xs = xs[:input_length]\n",
        "else:\n",
        "    raise ValueError(\"Audio from .wav file is too short\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "top-5 class probabilities:\n",
            "                 one: 1.0000e+00\n",
            "                nine: 5.0455e-19\n",
            "           _unknown_: 1.0553e-20\n",
            "                down: 4.0031e-26\n",
            "                  no: 3.8358e-26\n"
          ]
        }
      ],
      "source": [
        "# Try running some examples through the combined model.\n",
        "input_tensor = tf.constant(xs, shape=(1, input_length), dtype=tf.float32) / 32768.0\n",
        "# The model outputs the probabilities for the classes (`probs`).\n",
        "probs = combined_model.predict(input_tensor)\n",
        "\n",
        "# Read class labels of the model.\n",
        "metadata_json_path = '/tmp/tfjs-sc-model/metadata.json'\n",
        "\n",
        "with open(metadata_json_path, 'r') as f:\n",
        "    metadata = json.load(f)\n",
        "    class_labels = metadata[\"words\"]\n",
        "\n",
        "# Pair each probability with its class label.\n",
        "probs_and_labels = list(zip(probs[0].tolist(), class_labels))\n",
        "# Sort the pairs by probability in descending order.\n",
        "probs_and_labels = sorted(probs_and_labels, key=lambda x: -x[0])\n",
        "\n",
        "# Print the top-5 labels (or fewer, if the model has fewer than 5 classes):\n",
        "print('top-5 class probabilities:')\n",
        "for prob, label in probs_and_labels[:5]:\n",
        "    print('%20s: %.4e' % (label, prob))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
            "WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
            "INFO:tensorflow:Assets written to: /tmp/tmplb12fskv/assets\n",
            "Saved tflite file at: /tmp/tfjs-sc-model/combined_model.tflite\n"
          ]
        }
      ],
      "source": [
        "# Save the model as a tflite file.\n",
        "tflite_output_path = '/tmp/tfjs-sc-model/combined_model.tflite'\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)\n",
        "converter.target_spec.supported_ops = [\n",
        "    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS\n",
        "]\n",
        "with open(tflite_output_path, 'wb') as f:\n",
        "    f.write(converter.convert())\n",
        "print(\"Saved tflite file at: %s\" % tflite_output_path)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "tflite_conversion.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}