[MF] Modify the pose and audio implementations [DOC] Update the pose documentation [CF] Create the audio documentation

This commit is contained in:
51hhh 2025-08-14 11:02:56 +08:00
parent 6b9c260767
commit 9a8a767656
62 changed files with 23177 additions and 4 deletions

View File

@ -29,6 +29,15 @@
- The app will start analyzing your pose in real time and display the prediction results below.
- Click "Stop Prediction" again to pause.
It is recommended to start a local server with the Live Server plugin and open index.html.
## Known Issues
The joint skeleton drawn via the CSS styling appears smaller than the actual pose. The current approach scales the model's keypoint output together with the video size; it is not yet clear whether the discrepancy comes from this scaling or from the model output itself.
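A hypothetical sketch of the scaling step described above, assuming the model returns keypoints in the video's intrinsic pixel coordinates and they are mapped to the displayed canvas size (the names and structure here are illustrative, not the project's actual code):

```js
// Hypothetical sketch: scale keypoints from the video's intrinsic resolution
// to the size of the canvas they are drawn on.
function scaleKeypoints(keypoints, video, canvas) {
  const scaleX = canvas.width / video.videoWidth;
  const scaleY = canvas.height / video.videoHeight;
  return keypoints.map(kp => ({...kp, x: kp.x * scaleX, y: kp.y * scaleY}));
}
```

Logging `video.videoWidth`/`videoHeight` against the canvas's CSS display size at this step should reveal whether the undersized skeleton comes from the scaling or is already present in the model output.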
## 📁 Project Structure
```

View File

@ -318,7 +318,6 @@ function flattenPose(pose) {
}
function drawPose(pose) {
// ... (This function needs no changes; omitted here for brevity)
// Draw keypoints and skeleton...
if (pose.keypoints) {
// Draw keypoints

音频分类/README.md Normal file
View File

@ -0,0 +1,50 @@
# Browser-Based Audio Classifier (Enhanced Background-Noise Separation)
## Introduction
This project is a browser-based audio classifier that uses TensorFlow.js and the Speech Commands model to recognize user-defined sound classes. **Unlike a conventional audio classifier, this version puts particular emphasis on separating and handling background noise, which improves classification accuracy.**
The app lets users (see the sketch after this list):
1. **Record background-noise samples:** used in training so the model can tell target sounds apart from ambient noise.
2. **Add custom sound classes:** for example "clap", "finger snap", or "alarm tone".
3. **Record samples for custom classes:** used to train the model to recognize specific sounds.
4. **Train the model:** train the classifier on the recorded background-noise and custom-sound data.
5. **Recognize in real time:** use the trained model to classify microphone input live.
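This workflow maps onto the speech-commands transfer-learning API documented later in this commit (speech-commands/README.md). A minimal sketch, assuming the library is available as the global `speechCommands`, that the code runs inside an async function, and that the class names and epoch count are placeholders:

```js
// Minimal sketch of the workflow above, using the transfer-learning API
// from speech-commands/README.md. Class names and epochs are placeholders.
const recognizer = speechCommands.create('BROWSER_FFT');
await recognizer.ensureModelLoaded();

// A named transfer model holds the user-defined vocabulary.
const transfer = recognizer.createTransfer('my-sounds');

// 1. Record background-noise samples.
await transfer.collectExample('_background_noise_');
// 2-3. Record samples for each custom class.
await transfer.collectExample('clap');
await transfer.collectExample('snap');

// 4. Train on the collected examples.
await transfer.train({epochs: 25});

// 5. Real-time recognition of microphone input.
await transfer.listen(result => {
  const words = transfer.wordLabels();
  const top = result.scores.indexOf(Math.max(...result.scores));
  console.log(`${words[top]}: ${result.scores[top].toFixed(2)}`);
}, {probabilityThreshold: 0.75});
```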
## Features
* **Background-noise separation:** recording and learning background noise improves the classifier's accuracy in noisy environments.
* **Custom classes:** users can add and train any sound classes they need.
* **Real-time recognition:** once training is complete, live recognition can start immediately.
* **Runs entirely in the browser:** all processing happens client-side; no server is required.
* **User-friendly interface:** a simple, intuitive UI that is easy to operate.
## Tech Stack
* **TensorFlow.js:** runs the machine-learning model in the browser.
* **Speech Commands model:** the pre-trained speech-command model provided by TensorFlow.js, used for transfer learning.
## Getting Started
It is recommended to start a local server with the Live Server plugin and open voice.html.
**Note: opening index.html directly requires granting microphone permission repeatedly; serving the page with the Live Server plugin resolves this.**
## Audio Slicing
TODO
Need to review how the speech-commands API implements `collectExample` and what it accepts:
`.\speech-commands\src\browser_fft_recognizer.ts`
667,9: `async collectExample(word: string, options?: ExampleCollectionOptions):`
The current approach calls `collectExample` with a `word` argument; the library then records the audio automatically, normalizes the sample rate, generates the spectrogram, and feeds it to the model. It does not expose an interface for passing in an existing audio clip directly.
**To support one-shot recording, the audio file would have to be converted into a spectrogram manually and then fed to the model.**
The speech-commands folder in this directory is the source repository of the imported `https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands@latest/dist/speech-commands.min.js` bundle.
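For reference, the offline-recognition API described in speech-commands/README.md already accepts a pre-computed spectrogram tensor, so the manual path could decode the audio, build FFT frames matching `modelInputShape()`, and call `recognize()` directly. A rough sketch, assuming TensorFlow.js is available as `tf` and `mySpectrogramData` is a `Float32Array` whose length already matches the model's frame count and FFT size:

```js
// Rough sketch of the manual path: feed a pre-computed spectrogram to the
// offline recognize() API (see speech-commands/README.md in this commit).
const recognizer = speechCommands.create('BROWSER_FFT');
await recognizer.ensureModelLoaded();

// e.g. [null, 43, 232, 1]: batch x frames x frequency bins x 1
const inputShape = recognizer.modelInputShape();

const x = tf.tensor4d(mySpectrogramData, [1].concat(inputShape.slice(1)));
const output = await recognizer.recognize(x);
console.log(output.scores);
tf.dispose([x, output]);
```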

音频分类/speech-commands/.gitignore vendored Normal file
View File

@ -0,0 +1,18 @@
node_modules/
coverage/
package-lock.json
npm-debug.log
yarn-error.log
.DS_Store
dist/
.idea/
*.tgz
.cache
bazel-*
*.pyc
model.json
metadata.json
weights.bin

View File

@ -0,0 +1,16 @@
.yalc/
.vscode/
.rpt2_cache/
demo/
scripts/
src/
training/
coverage/
node_modules/
karma.conf.js
*.tgz
.travis.yml
.npmignore
tslint.json
yarn.lock
yalc.lock

View File

@ -0,0 +1,26 @@
{
"search.exclude": {
"**/node_modules": true,
"coverage/": true,
"**/dist/": true,
"**/yarn.lock": true,
"**/.rpt2_cache/": true,
"**/.yalc/": true
},
"tslint.enable": true,
"tslint.run": "onType",
"tslint.configFile": "tslint.json",
"files.trimTrailingWhitespace": true,
"editor.tabSize": 2,
"editor.insertSpaces": true,
"[typescript]": {
"editor.formatOnSave": true
},
"editor.rulers": [80],
"clang-format.style": "Google",
"files.insertFinalNewline": true,
"editor.detectIndentation": false,
"editor.wrappingIndent": "none",
"typescript.tsdk": "./node_modules/typescript/lib",
"clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
}

View File

@ -0,0 +1,319 @@
# Speech Command Recognizer
The Speech Command Recognizer is a JavaScript module that enables
recognition of spoken commands comprised of simple isolated English
words from a small vocabulary. The default vocabulary includes the following
words: the ten digits from "zero" to "nine", "up", "down", "left", "right",
"go", "stop", "yes", "no", as well as the additional categories of
"unknown word" and "background noise".
It uses the web browser's
[WebAudio API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Audio_API).
It is built on top of [TensorFlow.js](https://js.tensorflow.org) and can
perform inference and transfer learning entirely in the browser, using
WebGL GPU acceleration.
The underlying deep neural network has been trained using the
[TensorFlow Speech Commands Dataset](https://www.tensorflow.org/datasets/catalog/speech_commands).
For more details on the data set, see:
Warden, P. (2018) "Speech commands: A dataset for limited-vocabulary
speech recognition" https://arxiv.org/pdf/1804.03209.pdf
## API Usage
A speech command recognizer can be used in two ways:
1. **Online streaming recognition**, during which the library automatically
opens an audio input channel using the browser's
[`getUserMedia`](https://developer.mozilla.org/en-US/docs/Web/API/MediaDevices/getUserMedia)
and
[WebAudio](https://developer.mozilla.org/en-US/docs/Web/API/Web_Audio_API)
APIs (requesting permission from user) and performs real-time recognition on
the audio input.
2. **Offline recognition**, in which you provide a pre-constructed TensorFlow.js
[Tensor](https://js.tensorflow.org/api/latest/#tensor) object or a
`Float32Array` and the recognizer will return the recognition results.
### Online streaming recognition
To use the speech-command recognizer, first create a recognizer instance,
then start the streaming recognition by calling its `listen()` method.
```js
const tf = require('@tensorflow/tfjs');
const speechCommands = require('@tensorflow-models/speech-commands');
// When calling `create()`, you must provide the type of the audio input.
// The two available options are `BROWSER_FFT` and `SOFT_FFT`.
// - BROWSER_FFT uses the browser's native Fourier transform.
// - SOFT_FFT uses JavaScript implementations of Fourier transform
// (not implemented yet).
const recognizer = speechCommands.create('BROWSER_FFT');
// Make sure that the underlying model and metadata are loaded via HTTPS
// requests.
await recognizer.ensureModelLoaded();
// See the array of words that the recognizer is trained to recognize.
console.log(recognizer.wordLabels());
// `listen()` takes two arguments:
// 1. A callback function that is invoked anytime a word is recognized.
// 2. A configuration object with adjustable fields such as
// - includeSpectrogram
// - probabilityThreshold
// - includeEmbedding
recognizer.listen(result => {
// - result.scores contains the probability scores that correspond to
// recognizer.wordLabels().
// - result.spectrogram contains the spectrogram of the recognized word.
}, {
includeSpectrogram: true,
probabilityThreshold: 0.75
});
// Stop the recognition in 10 seconds.
setTimeout(() => recognizer.stopListening(), 10e3);
```
#### Vocabularies
When calling `speechCommands.create()`, you can specify the vocabulary
the loaded model will be able to recognize. This is specified as the second,
optional argument to `speechCommands.create()`. For example:
```js
const recognizer = speechCommands.create('BROWSER_FFT', 'directional4w');
```
Currently, the supported vocabularies are:
- '18w' (default): The 20-item vocabulary, consisting of:
'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
'eight', 'nine', 'up', 'down', 'left', 'right', 'go', 'stop',
'yes', and 'no', in addition to '_background_noise_' and '_unknown_'.
- 'directional4w': The four directional words: 'up', 'down', 'left', and
'right', in addition to '_background_noise_' and '_unknown_'.
'18w' is the default vocabulary.
#### Parameters for online streaming recognition
As the example above shows, you can specify optional parameters when calling
`listen()`. The supported parameters are:
* `overlapFactor`: Controls how often the recognizer performs prediction on
spectrograms. Must be >=0 and <1 (default: 0.5). For example,
if each spectrogram is 1000 ms long and `overlapFactor` is set to 0.25,
the prediction will happen every 750 ms.
* `includeSpectrogram`: Let the callback function be invoked with the
spectrogram data included in the argument. Default: `false`.
* `probabilityThreshold`: The callback function will be invoked if and only if
the maximum probability score of all the words is greater than this threshold.
Default: `0`.
* `invokeCallbackOnNoiseAndUnknown`: Whether the callback function will be
invoked if the "word" with the maximum probability score is the "unknown"
or "background noise" token. Default: `false`.
* `includeEmbedding`: Whether an internal activation from the underlying model
will be included in the callback argument, in addition to the probability
scores. Note: if this field is set as `true`, the value of
`invokeCallbackOnNoiseAndUnknown` will be overridden to `true` and the
value of `probabilityThreshold` will be overridden to `0`.
### Offline recognition
To perform offline recognition, you need to have obtained the spectrogram
of an audio snippet through a certain means, e.g., by loading the data
from a .wav file or synthesizing the spectrogram programmatically.
Assuming you have the spectrogram stored in an Array of numbers or
a Float32Array, you can create a `tf.Tensor` object. Note that the
shape of the Tensor must match the expectation of the recognizer instance.
E.g.,
```js
const tf = require('@tensorflow/tfjs');
const speechCommands = require('@tensorflow-models/speech-commands');
const recognizer = speechCommands.create('BROWSER_FFT');
// Inspect the input shape of the recognizer's underlying tf.Model.
console.log(recognizer.modelInputShape());
// You will get something like [null, 43, 232, 1].
// - The first dimension (null) is an undetermined batch dimension.
// - The second dimension (e.g., 43) is the number of audio frames.
// - The third dimension (e.g., 232) is the number of frequency data points in
// every frame (i.e., column) of the spectrogram
// - The last dimension (e.g., 1) is fixed at 1. This follows the convention of
// convolutional neural networks in TensorFlow.js and Keras.
// Inspect the sampling frequency and FFT size:
console.log(recognizer.params().sampleRateHz);
console.log(recognizer.params().fftSize);
const x = tf.tensor4d(
mySpectrogramData, [1].concat(recognizer.modelInputShape().slice(1)));
const output = await recognizer.recognize(x);
// output has the same format as `result` in the online streaming example
// above: the `scores` field contains the probabilities of the words.
tf.dispose([x, output]);
```
Note that you must provide a spectrogram value to the `recognize()` call
in order to perform the offline recognition. If `recognize()` is called
without a first argument, it will perform one-shot online recognition
by collecting a frame of audio via WebAudio.
### Preloading model
By default, a recognizer object will load the underlying
tf.Model via HTTP requests to a centralized location, when its
`listen()` or `recognize()` method is called the first time.
You can pre-load the model to reduce the latency of the first calls
to these methods. To do that, use the `ensureModelLoaded()` method of the
recognizer object. The `ensureModelLoaded()` method also "warms up" model after
the model is loaded. "Warm up" means running a few dummy examples through the
model for inference to make sure that the necessary states are set up, so that
subsequent inferences can be fast.
### Transfer learning
**Transfer learning** is the process of taking a model trained
previously on a dataset (say dataset A) and applying it on a
different dataset (say dataset B).
To achieve transfer learning, the model needs to be slightly modified and
re-trained on dataset B. However, thanks to the training on
the original dataset (A), the training on the new dataset (B) takes much less
time and computational resource, in addition to requiring a much smaller amount of
data than the original training data. The modification process involves removing the
top (output) dense layer of the original model and keeping the "base" of the
model. Due to its previous training, the base can be used as a good feature
extractor for any data similar to the original training data.
The removed dense layer is replaced with a new dense layer configured
specifically for the new dataset.
The speech-command model is a model suitable for transfer learning on
previously unseen spoken words. The original model has been trained on a relatively
large dataset (~50k examples from 20 classes). It can be used for transfer learning on
words different from the original vocabulary. We provide an API to perform
this type of transfer learning. The steps are listed in the example
code snippet below:
```js
const baseRecognizer = speechCommands.create('BROWSER_FFT');
await baseRecognizer.ensureModelLoaded();
// Each instance of speech-command recognizer supports multiple
// transfer-learning models, each of which can be trained for a different
// new vocabulary.
// Therefore we give a name to the transfer-learning model we are about to
// train ('colors' in this case).
const transferRecognizer = baseRecognizer.createTransfer('colors');
// Call `collectExample()` to collect a number of audio examples
// via WebAudio.
await transferRecognizer.collectExample('red');
await transferRecognizer.collectExample('green');
await transferRecognizer.collectExample('blue');
await transferRecognizer.collectExample('red');
// Don't forget to collect some background-noise examples, so that the
// transfer-learned model will be able to detect moments of silence.
await transferRecognizer.collectExample('_background_noise_');
await transferRecognizer.collectExample('green');
await transferRecognizer.collectExample('blue');
await transferRecognizer.collectExample('_background_noise_');
// ... You would typically want to put `collectExample`
// in the callback of a UI button to allow the user to collect
// any desired number of examples in random order.
// You can check the counts of examples for different words that have been
// collect for this transfer-learning model.
console.log(transferRecognizer.countExamples());
// e.g., {'red': 2, 'green': 2, 'blue': 2, '_background_noise_': 2};
// Start training of the transfer-learning model.
// You can specify `epochs` (number of training epochs) and `callback`
// (the Model.fit callback to use during training), among other configuration
// fields.
await transferRecognizer.train({
epochs: 25,
callback: {
onEpochEnd: async (epoch, logs) => {
console.log(`Epoch ${epoch}: loss=${logs.loss}, accuracy=${logs.acc}`);
}
}
});
// After the transfer learning completes, you can start online streaming
// recognition using the new model.
await transferRecognizer.listen(result => {
// - result.scores contains the scores for the new vocabulary, which
// can be checked with:
const words = transferRecognizer.wordLabels();
// `result.scores` contains the scores for the new words, not the original
// words.
for (let i = 0; i < words.length; ++i) {
console.log(`score for word '${words[i]}' = ${result.scores[i]}`);
}
}, {probabilityThreshold: 0.75});
// Stop the recognition in 10 seconds.
setTimeout(() => transferRecognizer.stopListening(), 10e3);
```
### Serialize examples from a transfer recognizer.
Once examples have been collected with a transfer recognizer,
you can export the examples in serialized form with the `serializeExamples()`
method, e.g.,
```js
const serialized = transferRecognizer.serializeExamples();
```
`serialized` is a binary `ArrayBuffer` amenable to storage and transmission.
It contains the spectrogram data of the examples, as well as metadata such
as word labels.
You can also serialize the examples from a subset of the words in the
transfer recognizer's vocabulary, e.g.,
```js
const serializedWithOnlyFoo = transferRecognizer.serializeExamples('foo');
// Or
const serializedWithOnlyFooAndBar = transferRecognizer.serializeExamples(['foo', 'bar']);
```
The serialized examples can later be loaded into another instance of
transfer recognizer with the `loadExamples()` method, e.g.,
```js
const clearExisting = false;
newTransferRecognizer.loadExamples(serialized, clearExisting);
```
Setting the `clearExisting` flag to `false` ensures that the examples
`newTransferRecognizer` already holds are preserved. If it is `true`, the
existing examples will be cleared. If `clearExisting` is not specified, it
defaults to `false`.
## Live demo
A developer-oriented live demo is available at
[this address](https://storage.googleapis.com/tfjs-speech-model-test/2019-01-03a/dist/index.html).
## How to run the demo from source code
The demo/ folder contains a live demo of the speech-command recognizer.
To run it, do
```sh
cd speech-commands
yarn
yarn publish-local
cd demo
yarn
yarn link-local
yarn watch
```

View File

@ -0,0 +1,48 @@
steps:
# Install common dependencies.
- name: 'node:16'
id: 'yarn-common'
entrypoint: 'yarn'
args: ['install']
# Install tfjs dependencies.
- name: 'node:16'
dir: 'speech-commands'
entrypoint: 'yarn'
id: 'yarn'
args: ['install']
waitFor: ['yarn-common']
# Lint.
- name: 'node:16'
dir: 'speech-commands'
entrypoint: 'yarn'
id: 'lint'
args: ['lint']
waitFor: ['yarn']
# Build.
- name: 'node:16'
dir: 'speech-commands'
entrypoint: 'yarn'
id: 'build'
args: ['build']
waitFor: ['yarn']
# Run tests.
- name: 'node:16'
dir: 'speech-commands'
entrypoint: 'yarn'
id: 'test'
args: ['test']
waitFor: ['yarn']
# General configuration
timeout: 1800s
logsBucket: 'gs://tfjs-build-logs'
substitutions:
_NIGHTLY: ''
options:
logStreamingOption: 'STREAM_ON'
substitution_option: 'ALLOW_LOOSE'

View File

@ -0,0 +1,18 @@
{
"presets": [
[
"env",
{
"esmodules": false,
"targets": {
"browsers": [
"> 3%"
]
}
}
]
],
"plugins": [
"@babel/plugin-transform-runtime"
]
}

View File

@ -0,0 +1,321 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as speechCommands from '../src';
import {plotSpectrogram} from './ui';
/** Remove the children of a div that do not have the isFixed attribute. */
export function removeNonFixedChildrenFromWordDiv(wordDiv) {
for (let i = wordDiv.children.length - 1; i >= 0; --i) {
if (wordDiv.children[i].getAttribute('isFixed') == null) {
wordDiv.removeChild(wordDiv.children[i]);
} else {
break;
}
}
}
/**
* Get the relative x-coordinate of a click event in a canvas.
*
* @param {HTMLCanvasElement} canvasElement The canvas in which the click
* event happened.
* @param {Event} event The click event object.
* @return {number} The relative x-coordinate: a `number` between 0 and 1.
*/
function getCanvasClickRelativeXCoordinate(canvasElement, event) {
let x;
if (event.pageX) {
x = event.pageX;
} else {
x = event.clientX + document.body.scrollLeft +
document.documentElement.scrollLeft;
}
x -= canvasElement.offsetLeft;
return x / canvasElement.width;
}
/**
* Dataset visualizer that supports
*
* - Display of words and spectrograms
* - Navigation through examples
* - Deletion of examples
*/
export class DatasetViz {
/**
* Constructor of DatasetViz
*
* @param {Object} transferRecognizer An instance of
* `speechCommands.TransferSpeechCommandRecognizer`.
* @param {HTMLDivElement} topLevelContainer The div element that
* holds the div elements for the individual words. It is assumed
* that each element has its "word" attribute set to the word.
* @param {number} minExamplesPerClass Minimum number of examples
* per word class required for the start-transfer-learning button
* to be enabled.
* @param {HTMLButtonElement} startTransferLearnButton The button
* which starts the transfer learning when clicked.
* @param {HTMLButtonElement} downloadAsFileButton The button
* that triggers downloading of the dataset as a file when clicked.
* @param {number} transferDurationMultiplier Optional duration
* multiplier (the ratio between the length of the example
* and the length expected by the model.) Defaults to 1.
*/
constructor(
transferRecognizer, topLevelContainer, minExamplesPerClass,
startTransferLearnButton, downloadAsFileButton,
transferDurationMultiplier = 1) {
this.transferRecognizer = transferRecognizer;
this.container = topLevelContainer;
this.minExamplesPerClass = minExamplesPerClass;
this.startTransferLearnButton = startTransferLearnButton;
this.downloadAsFileButton = downloadAsFileButton;
this.transferDurationMultiplier = transferDurationMultiplier;
// Navigation indices for the words.
this.navIndices = {};
}
/** Get the set of words in the dataset visualizer. */
words_() {
const words = [];
for (const element of this.container.children) {
words.push(element.getAttribute('word'));
}
return words;
}
/**
* Draw an example.
*
* @param {HTMLDivElement} wordDiv The div element for the word. It is assumed
* that it contains the word button as the first child and the canvas as the
* second.
* @param {string} word The word of the example being added.
* @param {SpectrogramData} spectrogram Optional spectrogram data.
* If provided, will use it as is. If not provided, will use WebAudio
* to collect an example.
* @param {RawAudio} rawAudio Raw audio waveform. Optional
* @param {string} uid UID of the example being drawn. Must match the UID
* of the example from `this.transferRecognizer`.
*/
async drawExample(wordDiv, word, spectrogram, rawAudio, uid) {
if (uid == null) {
throw new Error('Error: UID is not provided for pre-existing example.');
}
removeNonFixedChildrenFromWordDiv(wordDiv);
// Create the left and right nav buttons.
const leftButton = document.createElement('button');
leftButton.textContent = '←';
wordDiv.appendChild(leftButton);
const rightButton = document.createElement('button');
rightButton.textContent = '→';
wordDiv.appendChild(rightButton);
// Determine the position of the example in the word of the dataset.
const exampleUIDs =
this.transferRecognizer.getExamples(word).map(ex => ex.uid);
const position = exampleUIDs.indexOf(uid);
this.navIndices[word] = exampleUIDs.indexOf(uid);
if (position > 0) {
leftButton.addEventListener('click', () => {
this.redraw(word, exampleUIDs[position - 1]);
});
} else {
leftButton.disabled = true;
}
if (position < exampleUIDs.length - 1) {
rightButton.addEventListener('click', () => {
this.redraw(word, exampleUIDs[position + 1]);
});
} else {
rightButton.disabled = true;
}
// Spectrogram canvas.
const exampleCanvas = document.createElement('canvas');
exampleCanvas.style['display'] = 'inline-block';
exampleCanvas.style['vertical-align'] = 'middle';
exampleCanvas.height = 60;
exampleCanvas.width = 80;
exampleCanvas.style['padding'] = '3px';
// Set up the click callback for the spectrogram canvas. When clicked,
// the keyFrameIndex will be set.
if (word !== speechCommands.BACKGROUND_NOISE_TAG) {
exampleCanvas.addEventListener('click', event => {
const relativeX =
getCanvasClickRelativeXCoordinate(exampleCanvas, event);
const numFrames = spectrogram.data.length / spectrogram.frameSize;
const keyFrameIndex = Math.floor(numFrames * relativeX);
console.log(
`relativeX=${relativeX}; ` +
`changed keyFrameIndex to ${keyFrameIndex}`);
this.transferRecognizer.setExampleKeyFrameIndex(uid, keyFrameIndex);
this.redraw(word, uid);
});
}
wordDiv.appendChild(exampleCanvas);
const modelNumFrames = this.transferRecognizer.modelInputShape()[1];
await plotSpectrogram(
exampleCanvas, spectrogram.data, spectrogram.frameSize,
spectrogram.frameSize, {
pixelsPerFrame: exampleCanvas.width / modelNumFrames,
maxPixelWidth: Math.round(0.4 * window.innerWidth),
markKeyFrame: this.transferDurationMultiplier > 1 &&
word !== speechCommands.BACKGROUND_NOISE_TAG,
keyFrameIndex: spectrogram.keyFrameIndex
});
if (rawAudio != null) {
const playButton = document.createElement('button');
playButton.textContent = '▶️';
playButton.addEventListener('click', () => {
playButton.disabled = true;
speechCommands.utils.playRawAudio(
rawAudio, () => playButton.disabled = false);
});
wordDiv.appendChild(playButton);
}
// Create Delete button.
const deleteButton = document.createElement('button');
deleteButton.textContent = 'X';
wordDiv.appendChild(deleteButton);
// Callback for delete button.
deleteButton.addEventListener('click', () => {
this.transferRecognizer.removeExample(uid);
// TODO(cais): Smarter logic for which example to draw after deletion.
// Right now it always redraws the last available one.
this.redraw(word);
});
this.updateButtons_();
}
/**
* Redraw the spectrogram and buttons for a word.
*
* @param {string} word The word being redrawn. This must belong to the
* vocabulary currently held by the transferRecognizer.
* @param {string} uid Optional UID for the example to render. If not
* specified, the last available example of the dataset will be drawn.
*/
async redraw(word, uid) {
if (word == null) {
throw new Error('word is not specified');
}
let divIndex;
for (divIndex = 0; divIndex < this.container.children.length; ++divIndex) {
if (this.container.children[divIndex].getAttribute('word') === word) {
break;
}
}
if (divIndex === this.container.children.length) {
throw new Error(`Cannot find div corresponding to word ${word}`);
}
const wordDiv = this.container.children[divIndex];
const exampleCounts = this.transferRecognizer.isDatasetEmpty() ?
{} :
this.transferRecognizer.countExamples();
if (word in exampleCounts) {
const examples = this.transferRecognizer.getExamples(word);
let example;
if (uid == null) {
// Example UID is not specified. Draw the last one available.
example = examples[examples.length - 1];
} else {
// Example UID is specified. Find the example and update navigation
// indices.
for (let index = 0; index < examples.length; ++index) {
if (examples[index].uid === uid) {
example = examples[index];
}
}
}
const spectrogram = example.example.spectrogram;
await this.drawExample(
wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
} else {
removeNonFixedChildrenFromWordDiv(wordDiv);
}
this.updateButtons_();
}
/**
* Redraw the spectrograms and buttons for all words.
*
* For each word, the last available example is rendered.
**/
redrawAll() {
for (const word of this.words_()) {
this.redraw(word);
}
}
/** Update the button states according to the state of transferRecognizer. */
updateButtons_() {
const exampleCounts = this.transferRecognizer.isDatasetEmpty() ?
{} :
this.transferRecognizer.countExamples();
const minCountByClass =
this.words_()
.map(word => exampleCounts[word] || 0)
.reduce((prev, current) => current < prev ? current : prev);
for (const element of this.container.children) {
const word = element.getAttribute('word');
const button = element.children[0];
const displayWord =
word === speechCommands.BACKGROUND_NOISE_TAG ? 'noise' : word;
const exampleCount = exampleCounts[word] || 0;
if (exampleCount === 0) {
button.textContent = `${displayWord} (${exampleCount})`;
} else {
const pos = this.navIndices[word] + 1;
button.textContent = `${displayWord} (${pos}/${exampleCount})`;
}
}
const requiredMinCountPerClass =
Math.ceil(this.minExamplesPerClass / this.transferDurationMultiplier);
if (minCountByClass >= requiredMinCountPerClass) {
this.startTransferLearnButton.textContent = 'Start transfer learning';
this.startTransferLearnButton.disabled = false;
} else {
this.startTransferLearnButton.textContent =
`Need at least ${requiredMinCountPerClass} examples per word`;
this.startTransferLearnButton.disabled = true;
}
this.downloadAsFileButton.disabled =
this.transferRecognizer.isDatasetEmpty();
}
}

View File

@ -0,0 +1,98 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>TensorFlow.js Speech Commands Model Demo</title>
<link href="style.css" rel="stylesheet" type="text/css">
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body>
<div class="start-stop">
<button id="start" disabled="true">Start</button>
<button id="stop" disabled="true">Stop</button>
</div>
<div class='main-model'>
<div class="settings">
<span style="display:none">Prob. threshold:</span>
<input style="display:none" class="settings" size="5" id="proba-threshold" value="0.75">
</div>
<div id="candidate-words" class="candidate-words-hidden"></div>
</div>
<div class="footer" style="display: none;">
<textarea id="status-display" style="display: none" cols="80" readonly="true"></textarea>
</div>
<div class="transfer-learn-section">
<input id="transfer-model-name" size="20" placeholder="model name">
<input id="learn-words" size="36" value="_background_noise_,red,green">
<select id="duration-multiplier">
<option value="1">Duration x1</option>
<option value="2" selected="true">Duration x2</option>
</select>
<input type="checkbox" id="include-audio-waveform">
<span id="include-audio-waveform-label">Include audio waveform</span>
<button id="enter-learn-words" disabled="true">Enter transfer words</button>
<div id="transfer-learn-history"></div>
<div id="collect-words"></div>
<div class="collapsible-region">
<button id="dataset-io">Dataset IO >></button>
<div class="collapsible-region-inner" id="dataset-io-inner">
<div>
<button id="download-dataset" disabled="true">↓ Download dataset as file</button>
<div>
<input type="file" id="dataset-file-input">
<button id="upload-dataset">↑ Upload dataset</button>
<button id="eval-model-on-dataset">Evaluate model on dataset</button>
</div>
</div>
</div>
</div>
<div class="settings">
<span>Epochs:</span>
<input class="settings" size="5" id="epochs" value="100">
<span>Fine-tuning (FT) epochs:</span>
<input class="settings" size="5" id="fine-tuning-epochs" value="0">
<span>Augment by mixing noise:</span>
<input type="checkbox" id="augment-by-mixing-noise">
<button id="start-transfer-learn" disabled="true">Start transfer learning</button>
</div>
<div id="plots">
<div id="loss-plot" class="plots"></div>
<div id="accuracy-plot" class="plots"></div>
<div>
<div>
<span id="eval-results" class="eval-results"></span>
</div>
<div id="roc-plot" class="plots"></div>
</div>
</div>
<div class="collapsible-region">
<button id="model-io">Model IO >></button>
<div class="collapsible-region-inner" id="transfer-model-save-load-inner">
<div>
<button id="load-transfer-model" disabled="true">Load:</button>
<select id="saved-transfer-models">
<option value="1"></option>
</select>
<button id="delete-transfer-model" disabled="true">Delete</button>
</div>
<div>
<button id="save-transfer-model" disabled="true">Save model</button>
</div>
</div>
</div>
</div>
<script src="index.js"></script>
</body>
</html>

View File

@ -0,0 +1,713 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
import Plotly from 'plotly.js-dist';
import * as SpeechCommands from '@tensorflow-models/speech-commands';
import {DatasetViz, removeNonFixedChildrenFromWordDiv} from './dataset-vis';
import {hideCandidateWords, logToStatusDisplay, plotPredictions, plotSpectrogram, populateCandidateWords, showCandidateWords} from './ui';
const startButton = document.getElementById('start');
const stopButton = document.getElementById('stop');
const predictionCanvas = document.getElementById('prediction-canvas');
const probaThresholdInput = document.getElementById('proba-threshold');
const epochsInput = document.getElementById('epochs');
const fineTuningEpochsInput = document.getElementById('fine-tuning-epochs');
const datasetIOButton = document.getElementById('dataset-io');
const datasetIOInnerDiv = document.getElementById('dataset-io-inner');
const downloadAsFileButton = document.getElementById('download-dataset');
const datasetFileInput = document.getElementById('dataset-file-input');
const uploadFilesButton = document.getElementById('upload-dataset');
const evalModelOnDatasetButton =
document.getElementById('eval-model-on-dataset');
const evalResultsSpan = document.getElementById('eval-results');
const modelIOButton = document.getElementById('model-io');
const transferModelSaveLoadInnerDiv =
document.getElementById('transfer-model-save-load-inner');
const loadTransferModelButton = document.getElementById('load-transfer-model');
const saveTransferModelButton = document.getElementById('save-transfer-model');
const savedTransferModelsSelect =
document.getElementById('saved-transfer-models');
const deleteTransferModelButton =
document.getElementById('delete-transfer-model');
const BACKGROUND_NOISE_TAG = SpeechCommands.BACKGROUND_NOISE_TAG;
/**
* Transfer learning-related UI components.
*/
const transferModelNameInput = document.getElementById('transfer-model-name');
const learnWordsInput = document.getElementById('learn-words');
const durationMultiplierSelect = document.getElementById('duration-multiplier');
const enterLearnWordsButton = document.getElementById('enter-learn-words');
const includeTimeDomainWaveformCheckbox =
document.getElementById('include-audio-waveform');
const collectButtonsDiv = document.getElementById('collect-words');
const startTransferLearnButton =
document.getElementById('start-transfer-learn');
const XFER_MODEL_NAME = 'xfer-model';
// Minimum required number of examples per class for transfer learning.
const MIN_EXAMPLES_PER_CLASS = 8;
let recognizer;
let transferWords;
let transferRecognizer;
let transferDurationMultiplier;
(async function() {
logToStatusDisplay('Creating recognizer...');
recognizer = SpeechCommands.create('BROWSER_FFT');
await populateSavedTransferModelsSelect();
// Make sure the tf.Model is loaded through HTTP. If this is not
// called here, the tf.Model will be loaded the first time
// `listen()` is called.
recognizer.ensureModelLoaded()
.then(() => {
startButton.disabled = false;
enterLearnWordsButton.disabled = false;
loadTransferModelButton.disabled = false;
deleteTransferModelButton.disabled = false;
transferModelNameInput.value = `model-${getDateString()}`;
logToStatusDisplay('Model loaded.');
const params = recognizer.params();
logToStatusDisplay(`sampleRateHz: ${params.sampleRateHz}`);
logToStatusDisplay(`fftSize: ${params.fftSize}`);
logToStatusDisplay(
`spectrogramDurationMillis: ` +
`${params.spectrogramDurationMillis.toFixed(2)}`);
logToStatusDisplay(
`tf.Model input shape: ` +
`${JSON.stringify(recognizer.modelInputShape())}`);
})
.catch(err => {
logToStatusDisplay(
'Failed to load model for recognizer: ' + err.message);
});
})();
startButton.addEventListener('click', () => {
const activeRecognizer =
transferRecognizer == null ? recognizer : transferRecognizer;
populateCandidateWords(activeRecognizer.wordLabels());
const suppressionTimeMillis = 1000;
activeRecognizer
.listen(
result => {
plotPredictions(
predictionCanvas, activeRecognizer.wordLabels(), result.scores,
3, suppressionTimeMillis);
},
{
includeSpectrogram: true,
suppressionTimeMillis,
probabilityThreshold: Number.parseFloat(probaThresholdInput.value)
})
.then(() => {
startButton.disabled = true;
stopButton.disabled = false;
showCandidateWords();
logToStatusDisplay('Streaming recognition started.');
})
.catch(err => {
logToStatusDisplay(
'ERROR: Failed to start streaming display: ' + err.message);
});
});
stopButton.addEventListener('click', () => {
const activeRecognizer =
transferRecognizer == null ? recognizer : transferRecognizer;
activeRecognizer.stopListening()
.then(() => {
startButton.disabled = false;
stopButton.disabled = true;
hideCandidateWords();
logToStatusDisplay('Streaming recognition stopped.');
})
.catch(err => {
logToStatusDisplay(
'ERROR: Failed to stop streaming display: ' + err.message);
});
});
/**
* Transfer learning logic.
*/
/** Scroll to the bottom of the page */
function scrollToPageBottom() {
const scrollingElement = (document.scrollingElement || document.body);
scrollingElement.scrollTop = scrollingElement.scrollHeight;
}
let collectWordButtons = {};
let datasetViz;
function createProgressBarAndIntervalJob(parentElement, durationSec) {
const progressBar = document.createElement('progress');
progressBar.value = 0;
progressBar.style['width'] = `${Math.round(window.innerWidth * 0.25)}px`;
// Update progress bar in increments.
const intervalJob = setInterval(() => {
progressBar.value += 0.05;
}, durationSec * 1e3 / 20);
parentElement.appendChild(progressBar);
return {progressBar, intervalJob};
}
/**
* Create div elements for transfer words.
*
* @param {string[]} transferWords The array of transfer words.
* @returns {Object} An object mapping each word to the div element created for it.
*/
function createWordDivs(transferWords) {
// Clear collectButtonsDiv first.
while (collectButtonsDiv.firstChild) {
collectButtonsDiv.removeChild(collectButtonsDiv.firstChild);
}
datasetViz = new DatasetViz(
transferRecognizer, collectButtonsDiv, MIN_EXAMPLES_PER_CLASS,
startTransferLearnButton, downloadAsFileButton,
transferDurationMultiplier);
const wordDivs = {};
for (const word of transferWords) {
const wordDiv = document.createElement('div');
wordDiv.classList.add('word-div');
wordDivs[word] = wordDiv;
wordDiv.setAttribute('word', word);
const button = document.createElement('button');
button.setAttribute('isFixed', 'true');
button.style['display'] = 'inline-block';
button.style['vertical-align'] = 'middle';
const displayWord = word === BACKGROUND_NOISE_TAG ? 'noise' : word;
button.textContent = `${displayWord} (0)`;
wordDiv.appendChild(button);
wordDiv.className = 'transfer-word';
collectButtonsDiv.appendChild(wordDiv);
collectWordButtons[word] = button;
let durationInput;
if (word === BACKGROUND_NOISE_TAG) {
// Create noise duration input.
durationInput = document.createElement('input');
durationInput.setAttribute('isFixed', 'true');
durationInput.value = '10';
durationInput.style['width'] = '100px';
wordDiv.appendChild(durationInput);
// Create time-unit span for noise duration.
const timeUnitSpan = document.createElement('span');
timeUnitSpan.setAttribute('isFixed', 'true');
timeUnitSpan.classList.add('settings');
timeUnitSpan.style['vertical-align'] = 'middle';
timeUnitSpan.textContent = 'seconds';
wordDiv.appendChild(timeUnitSpan);
}
button.addEventListener('click', async () => {
disableAllCollectWordButtons();
removeNonFixedChildrenFromWordDiv(wordDiv);
const collectExampleOptions = {};
let durationSec;
let intervalJob;
let progressBar;
if (word === BACKGROUND_NOISE_TAG) {
// If the word type is background noise, display a progress bar during
// sound collection and do not show an incrementally updating
// spectrogram.
// _background_noise_ examples are special, in that user can specify
// the length of the recording (in seconds).
collectExampleOptions.durationSec =
Number.parseFloat(durationInput.value);
durationSec = collectExampleOptions.durationSec;
const barAndJob = createProgressBarAndIntervalJob(wordDiv, durationSec);
progressBar = barAndJob.progressBar;
intervalJob = barAndJob.intervalJob;
} else {
// If this is not a background-noise word type and if the duration
// multiplier is >1 (> ~1 s recording), show an incrementally
// updating spectrogram in real time.
collectExampleOptions.durationMultiplier = transferDurationMultiplier;
let tempSpectrogramData;
const tempCanvas = document.createElement('canvas');
tempCanvas.style['margin-left'] = '132px';
tempCanvas.height = 50;
wordDiv.appendChild(tempCanvas);
collectExampleOptions.snippetDurationSec = 0.1;
collectExampleOptions.onSnippet = async (spectrogram) => {
if (tempSpectrogramData == null) {
tempSpectrogramData = spectrogram.data;
} else {
tempSpectrogramData = SpeechCommands.utils.concatenateFloat32Arrays(
[tempSpectrogramData, spectrogram.data]);
}
plotSpectrogram(
tempCanvas, tempSpectrogramData, spectrogram.frameSize,
spectrogram.frameSize, {pixelsPerFrame: 2});
}
}
collectExampleOptions.includeRawAudio =
includeTimeDomainWaveformCheckbox.checked;
const spectrogram =
await transferRecognizer.collectExample(word, collectExampleOptions);
if (intervalJob != null) {
clearInterval(intervalJob);
}
if (progressBar != null) {
wordDiv.removeChild(progressBar);
}
const examples = transferRecognizer.getExamples(word)
const example = examples[examples.length - 1];
await datasetViz.drawExample(
wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
enableAllCollectWordButtons();
});
}
return wordDivs;
}
enterLearnWordsButton.addEventListener('click', () => {
const modelName = transferModelNameInput.value;
if (modelName == null || modelName.length === 0) {
enterLearnWordsButton.textContent = 'Need model name!';
setTimeout(() => {
enterLearnWordsButton.textContent = 'Enter transfer words';
}, 2000);
return;
}
// We disable the option to upload an existing dataset from files
// once the "Enter transfer words" button has been clicked.
// However, the user can still load an existing dataset from
// files first and keep appending examples to it.
disableFileUploadControls();
enterLearnWordsButton.disabled = true;
transferDurationMultiplier = durationMultiplierSelect.value;
learnWordsInput.disabled = true;
enterLearnWordsButton.disabled = true;
transferWords = learnWordsInput.value.trim().split(',').map(w => w.trim());
transferWords.sort();
if (transferWords == null || transferWords.length <= 1) {
logToStatusDisplay('ERROR: Invalid list of transfer words.');
return;
}
transferRecognizer = recognizer.createTransfer(modelName);
createWordDivs(transferWords);
scrollToPageBottom();
});
function disableAllCollectWordButtons() {
for (const word in collectWordButtons) {
collectWordButtons[word].disabled = true;
}
}
function enableAllCollectWordButtons() {
for (const word in collectWordButtons) {
collectWordButtons[word].disabled = false;
}
}
function disableFileUploadControls() {
datasetFileInput.disabled = true;
uploadFilesButton.disabled = true;
}
startTransferLearnButton.addEventListener('click', async () => {
startTransferLearnButton.disabled = true;
startButton.disabled = true;
startTransferLearnButton.textContent = 'Transfer learning starting...';
await tf.nextFrame();
const INITIAL_PHASE = 'initial';
const FINE_TUNING_PHASE = 'fineTuningPhase';
const epochs = parseInt(epochsInput.value);
const fineTuningEpochs = parseInt(fineTuningEpochsInput.value);
const trainLossValues = {};
const valLossValues = {};
const trainAccValues = {};
const valAccValues = {};
for (const phase of [INITIAL_PHASE, FINE_TUNING_PHASE]) {
const phaseSuffix = phase === FINE_TUNING_PHASE ? ' (FT)' : '';
const lineWidth = phase === FINE_TUNING_PHASE ? 2 : 1;
trainLossValues[phase] = {
x: [],
y: [],
name: 'train' + phaseSuffix,
mode: 'lines',
line: {width: lineWidth}
};
valLossValues[phase] = {
x: [],
y: [],
name: 'val' + phaseSuffix,
mode: 'lines',
line: {width: lineWidth}
};
trainAccValues[phase] = {
x: [],
y: [],
name: 'train' + phaseSuffix,
mode: 'lines',
line: {width: lineWidth}
};
valAccValues[phase] = {
x: [],
y: [],
name: 'val' + phaseSuffix,
mode: 'lines',
line: {width: lineWidth}
};
}
function plotLossAndAccuracy(epoch, loss, acc, val_loss, val_acc, phase) {
const displayEpoch = phase === FINE_TUNING_PHASE ? (epoch + epochs) : epoch;
trainLossValues[phase].x.push(displayEpoch);
trainLossValues[phase].y.push(loss);
trainAccValues[phase].x.push(displayEpoch);
trainAccValues[phase].y.push(acc);
valLossValues[phase].x.push(displayEpoch);
valLossValues[phase].y.push(val_loss);
valAccValues[phase].x.push(displayEpoch);
valAccValues[phase].y.push(val_acc);
Plotly.newPlot(
'loss-plot',
[
trainLossValues[INITIAL_PHASE], valLossValues[INITIAL_PHASE],
trainLossValues[FINE_TUNING_PHASE], valLossValues[FINE_TUNING_PHASE]
],
{
width: 480,
height: 360,
xaxis: {title: 'Epoch #'},
yaxis: {title: 'Loss'},
font: {size: 18}
});
Plotly.newPlot(
'accuracy-plot',
[
trainAccValues[INITIAL_PHASE], valAccValues[INITIAL_PHASE],
trainAccValues[FINE_TUNING_PHASE], valAccValues[FINE_TUNING_PHASE]
],
{
width: 480,
height: 360,
xaxis: {title: 'Epoch #'},
yaxis: {title: 'Accuracy'},
font: {size: 18}
});
startTransferLearnButton.textContent = phase === INITIAL_PHASE ?
`Transfer-learning... (${(epoch / epochs * 1e2).toFixed(0)}%)` :
`Transfer-learning (fine-tuning)... (${
(epoch / fineTuningEpochs * 1e2).toFixed(0)}%)`
scrollToPageBottom();
}
disableAllCollectWordButtons();
const augmentByMixingNoiseRatio =
document.getElementById('augment-by-mixing-noise').checked ? 0.5 : null;
console.log(`augmentByMixingNoiseRatio = ${augmentByMixingNoiseRatio}`);
await transferRecognizer.train({
epochs,
validationSplit: 0.25,
augmentByMixingNoiseRatio,
callback: {
onEpochEnd: async (epoch, logs) => {
plotLossAndAccuracy(
epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
INITIAL_PHASE);
}
},
fineTuningEpochs,
fineTuningCallback: {
onEpochEnd: async (epoch, logs) => {
plotLossAndAccuracy(
epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
FINE_TUNING_PHASE);
}
}
});
saveTransferModelButton.disabled = false;
transferModelNameInput.value = transferRecognizer.name;
transferModelNameInput.disabled = true;
startTransferLearnButton.textContent = 'Transfer learning complete.';
transferModelNameInput.disabled = false;
startButton.disabled = false;
evalModelOnDatasetButton.disabled = false;
});
downloadAsFileButton.addEventListener('click', () => {
const basename = getDateString();
const artifacts = transferRecognizer.serializeExamples();
// Trigger downloading of the data .bin file.
const anchor = document.createElement('a');
anchor.download = `${basename}.bin`;
anchor.href = window.URL.createObjectURL(
new Blob([artifacts], {type: 'application/octet-stream'}));
anchor.click();
});
/** Get the base name of the downloaded files based on current dataset. */
function getDateString() {
const d = new Date();
const year = `${d.getFullYear()}`;
let month = `${d.getMonth() + 1}`;
let day = `${d.getDate()}`;
if (month.length < 2) {
month = `0${month}`;
}
if (day.length < 2) {
day = `0${day}`;
}
let hour = `${d.getHours()}`;
if (hour.length < 2) {
hour = `0${hour}`;
}
let minute = `${d.getMinutes()}`;
if (minute.length < 2) {
minute = `0${minute}`;
}
let second = `${d.getSeconds()}`;
if (second.length < 2) {
second = `0${second}`;
}
return `${year}-${month}-${day}T${hour}.${minute}.${second}`;
}
uploadFilesButton.addEventListener('click', async () => {
const files = datasetFileInput.files;
if (files == null || files.length !== 1) {
throw new Error('Must select exactly one file.');
}
const datasetFileReader = new FileReader();
datasetFileReader.onload = async event => {
try {
await loadDatasetInTransferRecognizer(event.target.result);
} catch (err) {
const originalTextContent = uploadFilesButton.textContent;
uploadFilesButton.textContent = err.message;
setTimeout(() => {
uploadFilesButton.textContent = originalTextContent;
}, 2000);
}
durationMultiplierSelect.value = `${transferDurationMultiplier}`;
durationMultiplierSelect.disabled = true;
enterLearnWordsButton.disabled = true;
};
datasetFileReader.onerror = () =>
console.error(`Failed to read binary data from file '${files[0].name}'.`);
datasetFileReader.readAsArrayBuffer(files[0]);
});
async function loadDatasetInTransferRecognizer(serialized) {
const modelName = transferModelNameInput.value;
if (modelName == null || modelName.length === 0) {
throw new Error('Need model name!');
}
if (transferRecognizer == null) {
transferRecognizer = recognizer.createTransfer(modelName);
}
transferRecognizer.loadExamples(serialized);
const exampleCounts = transferRecognizer.countExamples();
transferWords = [];
const modelNumFrames = transferRecognizer.modelInputShape()[1];
const durationMultipliers = [];
for (const word in exampleCounts) {
transferWords.push(word);
const examples = transferRecognizer.getExamples(word);
for (const example of examples) {
const spectrogram = example.example.spectrogram;
// Ignore _background_noise_ examples when determining the duration
// multiplier of the dataset.
if (word !== BACKGROUND_NOISE_TAG) {
durationMultipliers.push(Math.round(
spectrogram.data.length / spectrogram.frameSize / modelNumFrames));
}
}
}
transferWords.sort();
learnWordsInput.value = transferWords.join(',');
// Determine the transferDurationMultiplier value from the dataset.
transferDurationMultiplier =
durationMultipliers.length > 0 ? Math.max(...durationMultipliers) : 1;
console.log(
`Determined transferDurationMultiplier from uploaded ` +
`dataset: ${transferDurationMultiplier}`);
createWordDivs(transferWords);
datasetViz.redrawAll();
}
evalModelOnDatasetButton.addEventListener('click', async () => {
const files = datasetFileInput.files;
if (files == null || files.length !== 1) {
throw new Error('Must select exactly one file.');
}
evalModelOnDatasetButton.disabled = true;
const datasetFileReader = new FileReader();
datasetFileReader.onload = async event => {
try {
if (transferRecognizer == null) {
throw new Error('There is no model!');
}
// Load the dataset and perform evaluation of the transfer
// model using the dataset.
transferRecognizer.loadExamples(event.target.result);
const evalResult = await transferRecognizer.evaluate({
windowHopRatio: 0.25,
wordProbThresholds: [
0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.5,
0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0
]
});
// Plot the ROC curve.
const rocDataForPlot = {x: [], y: []};
evalResult.rocCurve.forEach(item => {
rocDataForPlot.x.push(item.fpr);
rocDataForPlot.y.push(item.tpr);
});
Plotly.newPlot('roc-plot', [rocDataForPlot], {
width: 360,
height: 360,
mode: 'markers',
marker: {size: 7},
xaxis: {title: 'False positive rate (FPR)', range: [0, 1]},
yaxis: {title: 'True positive rate (TPR)', range: [0, 1]},
font: {size: 18}
});
evalResultsSpan.textContent = `AUC = ${evalResult.auc}`;
} catch (err) {
const originalTextContent = evalModelOnDatasetButton.textContent;
evalModelOnDatasetButton.textContent = err.message;
setTimeout(() => {
evalModelOnDatasetButton.textContent = originalTextContent;
}, 2000);
}
evalModelOnDatasetButton.disabled = false;
};
datasetFileReader.onerror = () =>
console.error(`Failed to read binary data from file '${files[0].name}'.`);
datasetFileReader.readAsArrayBuffer(files[0]);
});
async function populateSavedTransferModelsSelect() {
const savedModelKeys = await SpeechCommands.listSavedTransferModels();
while (savedTransferModelsSelect.firstChild) {
savedTransferModelsSelect.removeChild(savedTransferModelsSelect.firstChild);
}
if (savedModelKeys.length > 0) {
for (const key of savedModelKeys) {
const option = document.createElement('option');
option.textContent = key;
option.id = key;
savedTransferModelsSelect.appendChild(option);
}
loadTransferModelButton.disabled = false;
}
}
saveTransferModelButton.addEventListener('click', async () => {
await transferRecognizer.save();
await populateSavedTransferModelsSelect();
saveTransferModelButton.textContent = 'Model saved!';
saveTransferModelButton.disabled = true;
});
loadTransferModelButton.addEventListener('click', async () => {
const transferModelName = savedTransferModelsSelect.value;
await recognizer.ensureModelLoaded();
transferRecognizer = recognizer.createTransfer(transferModelName);
await transferRecognizer.load();
transferModelNameInput.value = transferModelName;
transferModelNameInput.disabled = true;
learnWordsInput.value = transferRecognizer.wordLabels().join(',');
learnWordsInput.disabled = true;
durationMultiplierSelect.disabled = true;
enterLearnWordsButton.disabled = true;
saveTransferModelButton.disabled = true;
loadTransferModelButton.disabled = true;
loadTransferModelButton.textContent = 'Model loaded!';
});
modelIOButton.addEventListener('click', () => {
if (modelIOButton.textContent.endsWith(' >>')) {
transferModelSaveLoadInnerDiv.style.display = 'inline-block';
modelIOButton.textContent = modelIOButton.textContent.replace(' >>', ' <<');
} else {
transferModelSaveLoadInnerDiv.style.display = 'none';
modelIOButton.textContent = modelIOButton.textContent.replace(' <<', ' >>');
}
});
deleteTransferModelButton.addEventListener('click', async () => {
const transferModelName = savedTransferModelsSelect.value;
await recognizer.ensureModelLoaded();
transferRecognizer = recognizer.createTransfer(transferModelName);
await SpeechCommands.deleteSavedTransferModel(transferModelName);
deleteTransferModelButton.disabled = true;
deleteTransferModelButton.textContent = `Deleted "${transferModelName}"`;
await populateSavedTransferModelsSelect();
});
datasetIOButton.addEventListener('click', () => {
if (datasetIOButton.textContent.endsWith(' >>')) {
datasetIOInnerDiv.style.display = 'inline-block';
datasetIOButton.textContent =
datasetIOButton.textContent.replace(' >>', ' <<');
} else {
datasetIOInnerDiv.style.display = 'none';
datasetIOButton.textContent =
datasetIOButton.textContent.replace(' <<', ' >>');
}
});

View File

@ -0,0 +1,60 @@
{
"name": "tfjs-models-speech-commands-demo",
"version": "0.0.1",
"description": "",
"main": "index.js",
"license": "Apache-2.0",
"private": true,
"engines": {
"node": ">=8.9.0"
},
"dependencies": {
"@tensorflow-models/speech-commands": "file:../dist",
"stats.js": "^0.17.0"
},
"scripts": {
"build-model": "cd .. && yarn && yarn build",
"watch": "yarn build-model && cross-env NODE_OPTIONS=--max_old_space_size=4096 NODE_ENV=development parcel index.html --no-hmr --open",
"build": "yarn build-model && cross-env NODE_OPTIONS=--max_old_space_size=4096 NODE_ENV=production parcel build index.html --public-url ./",
"lint": "eslint .",
"link-local": "yalc link @tensorflow-models/speech-commands"
},
"browser": {
"crypto": false
},
"devDependencies": {
"@babel/core": "^7.0.0-0",
"@babel/plugin-transform-runtime": "^7.1.0",
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"babel-preset-es2017": "^6.24.1",
"clang-format": "~1.2.2",
"cross-env": "^5.2.0",
"dat.gui": "^0.7.1",
"eslint": "^4.19.1",
"eslint-config-google": "^0.9.1",
"parcel-bundler": "~1.12.5",
"plotly.js-dist": "^1.39.4",
"yalc": "~1.0.0-pre.50"
},
"resolutions": {
"is-svg": "4.3.1"
},
"eslintConfig": {
"extends": "google",
"rules": {
"require-jsdoc": 0,
"valid-jsdoc": 0
},
"env": {
"es6": true
},
"parserOptions": {
"ecmaVersion": 8,
"sourceType": "module"
}
},
"eslintIgnore": [
"dist/"
]
}

View File

@ -0,0 +1,183 @@
body {
margin: 30px 0 0 30px;
font: 400 11px system-ui;
}
button {
color: #ff8300;
background-color: #ffffff;
border-style: solid;
border-width: 2px;
border-color: #ff8300;
border-radius: 10px;
font-size: 20px;
margin: 5px;
padding: 15px;
}
button:disabled {
color: #a0a0a0;
border-color: #a0a0a0;
}
input:disabled {
color: #a0a0a0;
border-color: #a0a0a0;
}
select:disabled {
color: #a0a0a0;
border-color: #a0a0a0;
}
select {
color: #ff8300;
background-color: #ffffff;
border-style: solid;
border-width: 2px;
border-color: #ff8300;
border-radius: 10px;
font-size: 20px;
margin: 5px;
padding: 15px;
}
.transfer-learn-section input {
color: #0000ff;
background-color: #ffffff;
border-style: solid;
border-width: 2px;
border-color: #0000ff;
border-radius: 10px;
font-size: 20px;
margin: 5px;
padding: 15px;
position: relative;
}
textarea {
font-size: 20px;
width: 90%;
height: 80%;
border: 2px solid #888;
border-radius: 10px;
resize: none;
}
.footer {
height: 20%;
}
.transfer-learn {
position: absolute;
top: 40%;
}
.word-div {
border-radius: 10px;
margin: 3px;
}
.candidate-word {
border: 1px solid gray;
background-color: lightyellow;
margin: 5px;
border-radius: 3px;
width: 10vw;
padding: 15px;
text-align: center;
}
.candidate-word-active {
border: 2px solid gray;
background-color: lightgreen;
}
.candidate-word-label {
font-weight: bold;
background-color: orange;
width: 250px;
}
.candidate-words-hidden {
display: none !important;
}
#candidate-words {
display: flex;
flex-wrap: wrap;
font-size: 20px;
}
#collect-words {
display: flex;
flex-wrap: wrap;
flex-direction: column;
}
#plots {
display: flex;
}
.settings {
font-size: 17px;
}
.collapsible-region {
font-size: 17px;
border-style: solid;
border-width: 1px;
border-color: #808080;
border-radius: 10px;
}
.collapsible-region-inner {
display: none;
}
#model-io {
vertical-align: top;
}
#dataset-io {
vertical-align: top;
}
.start-stop {
font-size: 17px;
}
.settings input {
color: #0000ff;
background-color: #ffffff;
border-style: solid;
border-width: 2px;
border-color: #0000ff;
border-radius: 10px;
margin: 5px;
font-size: 17px;
}
.transfer-word {
display: flex;
width: 100%;
justify-content: left;
vertical-align: middle;
text-align: center;
}
.eval-results {
font-size: 17px;
}
input[type=checkbox] {
transform: scale(2);
}
#include-audio-waveform {
margin-left: 20px;
}
#include-audio-waveform-label {
font-size: 17px;
}

View File

@ -0,0 +1,216 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as SpeechCommands from '../src';
import {BACKGROUND_NOISE_TAG, UNKNOWN_TAG} from '../src';
const statusDisplay = document.getElementById('status-display');
const candidateWordsContainer = document.getElementById('candidate-words');
/**
* Log a message to a textarea.
*
* @param {string} message Message to be logged.
*/
export function logToStatusDisplay(message) {
const date = new Date();
statusDisplay.value += `[${date.toISOString()}] ` + message + '\n';
statusDisplay.scrollTop = statusDisplay.scrollHeight;
}
let candidateWordSpans;
/**
* Display candidate words in the UI.
*
* The background-noise "word" will be omitted.
*
* @param {*} words Candidate words.
*/
export function populateCandidateWords(words) {
candidateWordSpans = {};
while (candidateWordsContainer.firstChild) {
candidateWordsContainer.removeChild(candidateWordsContainer.firstChild);
}
for (const word of words) {
if (word === BACKGROUND_NOISE_TAG || word === UNKNOWN_TAG) {
continue;
}
const wordSpan = document.createElement('span');
wordSpan.textContent = word;
wordSpan.classList.add('candidate-word');
candidateWordsContainer.appendChild(wordSpan);
candidateWordSpans[word] = wordSpan;
}
}
export function showCandidateWords() {
candidateWordsContainer.classList.remove('candidate-words-hidden');
}
export function hideCandidateWords() {
candidateWordsContainer.classList.add('candidate-words-hidden');
}
/**
* Show an audio spectrogram in a canvas.
*
* @param {HTMLCanvasElement} canvas The canvas element to draw the
* spectrogram in.
* @param {Float32Array} frequencyData The flat array for the spectrogram
* data.
* @param {number} fftSize Number of frequency points per frame.
* @param {number} fftDisplaySize Number of frequency points to show. Must be
* <= fftSize.
* @param {Object} config Optional configuration object, with the following
* supported fields:
* - pixelsPerFrame {number} Number of pixels along the width dimension of
* the canvas for each frame of spectrogram.
* - maxPixelWidth {number} Maximum width in pixels.
* - markKeyFrame {bool} Whether to mark the index of the frame
* with the maximum intensity or a predetermined key frame.
* - keyFrameIndex {number?} Predetermined key frame index.
*/
export async function plotSpectrogram(
canvas, frequencyData, fftSize, fftDisplaySize, config) {
if (fftDisplaySize == null) {
fftDisplaySize = fftSize;
}
if (config == null) {
config = {};
}
// Get the maximum and minimum.
let min = Infinity;
let max = -Infinity;
for (let i = 0; i < frequencyData.length; ++i) {
const x = frequencyData[i];
if (x !== -Infinity) {
if (x < min) {
min = x;
}
if (x > max) {
max = x;
}
}
}
if (min >= max) {
return;
}
const context = canvas.getContext('2d');
context.clearRect(0, 0, canvas.width, canvas.height);
const numFrames = frequencyData.length / fftSize;
if (config.pixelsPerFrame != null) {
let realWidth = Math.round(config.pixelsPerFrame * numFrames);
if (config.maxPixelWidth != null && realWidth > config.maxPixelWidth) {
realWidth = config.maxPixelWidth;
}
canvas.width = realWidth;
}
const pixelWidth = canvas.width / numFrames;
const pixelHeight = canvas.height / fftDisplaySize;
for (let i = 0; i < numFrames; ++i) {
const x = pixelWidth * i;
const spectrum = frequencyData.subarray(i * fftSize, (i + 1) * fftSize);
if (spectrum[0] === -Infinity) {
break;
}
for (let j = 0; j < fftDisplaySize; ++j) {
const y = canvas.height - (j + 1) * pixelHeight;
let colorValue = (spectrum[j] - min) / (max - min);
colorValue = Math.pow(colorValue, 3);
colorValue = Math.round(255 * colorValue);
const fillStyle =
`rgb(${colorValue},${255 - colorValue},${255 - colorValue})`;
context.fillStyle = fillStyle;
context.fillRect(x, y, pixelWidth, pixelHeight);
}
}
if (config.markKeyFrame) {
const keyFrameIndex = config.keyFrameIndex == null ?
await SpeechCommands
.getMaxIntensityFrameIndex(
{data: frequencyData, frameSize: fftSize})
.data() :
config.keyFrameIndex;
// Draw lines to mark the maximum-intensity frame.
context.strokeStyle = 'black';
context.beginPath();
context.moveTo(pixelWidth * keyFrameIndex, 0);
context.lineTo(pixelWidth * keyFrameIndex, canvas.height * 0.1);
context.stroke();
context.beginPath();
context.moveTo(pixelWidth * keyFrameIndex, canvas.height * 0.9);
context.lineTo(pixelWidth * keyFrameIndex, canvas.height);
context.stroke();
}
}
/**
* Plot top-K predictions from a speech command recognizer.
*
* @param {HTMLCanvasElement} canvas The canvas to render the predictions in.
* @param {string[]} candidateWords Candidate word array.
* @param {Float32Array | number[]} probabilities Probability scores from the
* speech command recognizer. Must be of the same length as `candidateWords`.
* @param {number} topK Number of top scores to render.
* @param {number} timeToLiveMillis Optional time to live for the active label
* highlighting. If not provided, the highlighting will live
* indefinitely until the next highlighting.
*/
export function plotPredictions(
canvas, candidateWords, probabilities, topK, timeToLiveMillis) {
if (topK != null) {
let wordsAndProbs = [];
for (let i = 0; i < candidateWords.length; ++i) {
wordsAndProbs.push([candidateWords[i], probabilities[i]]);
}
wordsAndProbs.sort((a, b) => (b[1] - a[1]));
wordsAndProbs = wordsAndProbs.slice(0, topK);
candidateWords = wordsAndProbs.map(item => item[0]);
probabilities = wordsAndProbs.map(item => item[1]);
// Highlight the top word.
const topWord = wordsAndProbs[0][0];
console.log(
`"${topWord}" (p=${wordsAndProbs[0][1].toFixed(6)}) @ ` +
new Date().toTimeString());
for (const word in candidateWordSpans) {
if (word === topWord) {
candidateWordSpans[word].classList.add('candidate-word-active');
if (timeToLiveMillis != null) {
setTimeout(() => {
if (candidateWordSpans[word]) {
candidateWordSpans[word].classList.remove(
'candidate-word-active');
}
}, timeToLiveMillis);
}
} else {
candidateWordSpans[word].classList.remove('candidate-word-active');
}
}
}
}
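/**
 * Minimal usage sketch for the functions above (hypothetical helper; the word
 * list, probabilities and numbers are illustrative assumptions, not values
 * produced by the actual demo).
 */
export function demoPlotPredictions() {
  const words = [BACKGROUND_NOISE_TAG, 'clap', 'snap', 'whistle'];
  const probs = [0.05, 0.85, 0.07, 0.03];
  populateCandidateWords(words);
  showCandidateWords();
  // Highlight the top-3 words for one second; the canvas argument is unused
  // by the plotPredictions implementation above, so null is passed here.
  plotPredictions(null, words, probs, 3, 1000);
}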

File diff suppressed because it is too large.

View File

@ -0,0 +1,56 @@
{
"name": "@tensorflow-models/speech-commands",
"version": "0.5.4",
"description": "Speech-command recognizer in TensorFlow.js",
"main": "dist/index.js",
"unpkg": "dist/speech-commands.min.js",
"jsdelivr": "dist/speech-commands.min.js",
"jsnext:main": "dist/speech-commands.esm.js",
"module": "dist/speech-commands.esm.js",
"types": "dist/index.d.ts",
"repository": {
"type": "git",
"url": "https://github.com/tensorflow/tfjs-models.git"
},
"peerDependencies": {
"@tensorflow/tfjs-core": "^4.13.0",
"@tensorflow/tfjs-data": "^4.13.0",
"@tensorflow/tfjs-layers": "^4.13.0"
},
"devDependencies": {
"@tensorflow/tfjs-core": "^4.13.0",
"@tensorflow/tfjs-data": "^4.13.0",
"@tensorflow/tfjs-layers": "^4.13.0",
"@tensorflow/tfjs-node": "^4.13.0",
"@types/jasmine": "~2.8.8",
"@types/rimraf": "^2.0.2",
"@types/tempfile": "^2.0.0",
"babel-core": "~6.26.0",
"babel-plugin-transform-runtime": "~6.23.0",
"clang-format": "^1.2.4",
"dct": "^0.0.3",
"jasmine": "^3.2.0",
"jasmine-core": "^3.2.1",
"kissfft-js": "^0.1.8",
"rimraf": "2.6.2",
"rollup": "~0.58.2",
"rollup-plugin-node-resolve": "~3.3.0",
"rollup-plugin-typescript2": "~0.13.0",
"rollup-plugin-uglify": "~3.0.0",
"tempfile": "2.0.0",
"ts-node": "~5.0.0",
"tslib": "1.8.0",
"tslint": "~5.18.0",
"tslint-no-circular-imports": "^0.6.1",
"typescript": "~3.5.3",
"yalc": "~1.0.0-pre.21"
},
"scripts": {
"build": "tsc",
"lint": "tslint -p . -t verbose",
"publish-local": "yarn build && rollup -c && yalc push",
"build-npm": "yarn build && rollup -c",
"test": "ts-node --skip-ignore --project tsconfig.test.json run_tests.ts"
},
"license": "Apache-2.0"
}

View File

@ -0,0 +1,77 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import node from 'rollup-plugin-node-resolve';
import typescript from 'rollup-plugin-typescript2';
import uglify from 'rollup-plugin-uglify';
const PREAMBLE = `/**
* @license
* Copyright ${(new Date).getFullYear()} Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/`;
function minify() {
return uglify({ output: { preamble: PREAMBLE } });
}
function config({ plugins = [], output = {} }) {
return {
input: 'src/index.ts',
plugins: [
typescript({ tsconfigOverride: { compilerOptions: { module: 'ES2015' } } }),
node(), ...plugins
],
output: {
banner: PREAMBLE,
globals: {
'@tensorflow/tfjs-core': 'tf',
'@tensorflow/tfjs-layers': 'tf',
'@tensorflow/tfjs-data': 'tf.data',
},
...output
},
external: [
'@tensorflow/tfjs-core',
'@tensorflow/tfjs-layers',
'@tensorflow/tfjs-data',
]
};
}
const packageName = 'speechCommands';
export default [
config({output: {format: 'umd', name: packageName, file: 'dist/speech-commands.js'}}),
config({
plugins: [minify()],
output: {format: 'umd', name: packageName, file: 'dist/speech-commands.min.js'}
}),
config({
plugins: [minify()],
output: {format: 'es', file: 'dist/speech-commands.esm.js'}
})
];

View File

@ -0,0 +1,21 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as jasmine_util from '@tensorflow/tfjs-core/dist/jasmine_util';
import {runTests} from '../test_util';
runTests(jasmine_util);

View File

@ -0,0 +1,331 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Audio FFT Feature Extractor based on Browser-Native FFT.
*/
import * as tf from '@tensorflow/tfjs-core';
import {getAudioContextConstructor, getAudioMediaStream} from './browser_fft_utils';
import {FeatureExtractor, RecognizerParams} from './types';
export type SpectrogramCallback = (freqData: tf.Tensor, timeData?: tf.Tensor) =>
Promise<boolean>;
/**
* Configurations for constructing BrowserFftFeatureExtractor.
*/
export interface BrowserFftFeatureExtractorConfig extends RecognizerParams {
/**
* Number of audio frames (i.e., frequency columns) per spectrogram.
*/
numFramesPerSpectrogram: number;
/**
* Suppression period in milliseconds.
*
* How much time to rest (not call the spectrogramCallback) every time
* a word with probability score above threshold is recognized.
*/
suppressionTimeMillis: number;
/**
* A callback that is invoked every time a full spectrogram becomes
* available.
*
* `x` is a single-example tf.Tensor instance that includes the batch
* dimension.
* The return value is assumed to be a flag indicating whether the
* suppression period should initiate, e.g., when a word is recognized.
*/
spectrogramCallback: SpectrogramCallback;
/**
* Truncate each spectrogram column at how many frequency points.
*
* If `null` or `undefined`, will do no truncation.
*/
columnTruncateLength?: number;
/**
* Overlap factor. Must be >=0 and <1.
* For example, if the model takes a frame length of 1000 ms,
* and if overlap factor is 0.4, there will be a 400ms
* overlap between two successive frames, i.e., frames
* will be taken every 600 ms.
*/
overlapFactor: number;
/**
* Whether to collect the raw time-domain audio waveform in addition to the
* spectrogram.
*
* Default: `false`.
*/
includeRawAudio?: boolean;
}
/**
* Audio feature extractor based on Browser-native FFT.
*
* Uses AudioContext and analyser node.
*/
export class BrowserFftFeatureExtractor implements FeatureExtractor {
// Number of frames (i.e., columns) per spectrogram used for classification.
readonly numFrames: number;
// Audio sampling rate in Hz.
readonly sampleRateHz: number;
// The FFT length for each spectrogram column.
readonly fftSize: number;
// Truncation length for spectrogram columns.
readonly columnTruncateLength: number;
// Overlap factor: the fraction by which two consecutive spectrograms
// overlap (e.g., 0.5 means each spectrogram overlaps the previous by half).
readonly overlapFactor: number;
readonly includeRawAudio: boolean;
private readonly spectrogramCallback: SpectrogramCallback;
private stream: MediaStream;
// tslint:disable-next-line:no-any
private audioContextConstructor: any;
private audioContext: AudioContext;
private analyser: AnalyserNode;
private tracker: Tracker;
private freqData: Float32Array;
private timeData: Float32Array;
private freqDataQueue: Float32Array[];
private timeDataQueue: Float32Array[];
// tslint:disable-next-line:no-any
private frameIntervalTask: any;
private frameDurationMillis: number;
private suppressionTimeMillis: number;
/**
* Constructor of BrowserFftFeatureExtractor.
*
* @param config Required configuration object.
*/
constructor(config: BrowserFftFeatureExtractorConfig) {
if (config == null) {
throw new Error(
`Required configuration object is missing for ` +
`BrowserFftFeatureExtractor constructor`);
}
if (config.spectrogramCallback == null) {
throw new Error(`spectrogramCallback cannot be null or undefined`);
}
if (!(config.numFramesPerSpectrogram > 0)) {
throw new Error(
`Invalid value in numFramesPerSpectrogram: ` +
`${config.numFramesPerSpectrogram}`);
}
if (config.suppressionTimeMillis < 0) {
throw new Error(
`Expected suppressionTimeMillis to be >= 0, ` +
`but got ${config.suppressionTimeMillis}`);
}
this.suppressionTimeMillis = config.suppressionTimeMillis;
this.spectrogramCallback = config.spectrogramCallback;
this.numFrames = config.numFramesPerSpectrogram;
this.sampleRateHz = config.sampleRateHz || 44100;
this.fftSize = config.fftSize || 1024;
this.frameDurationMillis = this.fftSize / this.sampleRateHz * 1e3;
this.columnTruncateLength = config.columnTruncateLength || this.fftSize;
this.overlapFactor = config.overlapFactor;
this.includeRawAudio = config.includeRawAudio;
tf.util.assert(
this.overlapFactor >= 0 && this.overlapFactor < 1,
() => `Expected overlapFactor to be >= 0 and < 1, ` +
`but got ${this.overlapFactor}`);
if (this.columnTruncateLength > this.fftSize) {
throw new Error(
`columnTruncateLength ${this.columnTruncateLength} exceeds ` +
`fftSize (${this.fftSize}).`);
}
this.audioContextConstructor = getAudioContextConstructor();
}
async start(audioTrackConstraints?: MediaTrackConstraints):
Promise<Float32Array[]|void> {
if (this.frameIntervalTask != null) {
throw new Error(
'Cannot start already-started BrowserFftFeatureExtractor');
}
this.stream = await getAudioMediaStream(audioTrackConstraints);
this.audioContext = new this.audioContextConstructor(
{sampleRate: this.sampleRateHz}) as AudioContext;
const streamSource = this.audioContext.createMediaStreamSource(this.stream);
this.analyser = this.audioContext.createAnalyser();
this.analyser.fftSize = this.fftSize * 2;
this.analyser.smoothingTimeConstant = 0.0;
streamSource.connect(this.analyser);
// Reset the queue.
this.freqDataQueue = [];
this.freqData = new Float32Array(this.fftSize);
if (this.includeRawAudio) {
this.timeDataQueue = [];
this.timeData = new Float32Array(this.fftSize);
}
const period =
Math.max(1, Math.round(this.numFrames * (1 - this.overlapFactor)));
this.tracker = new Tracker(
period,
Math.round(this.suppressionTimeMillis / this.frameDurationMillis));
this.frameIntervalTask = setInterval(
this.onAudioFrame.bind(this), this.fftSize / this.sampleRateHz * 1e3);
}
private async onAudioFrame() {
this.analyser.getFloatFrequencyData(this.freqData);
if (this.freqData[0] === -Infinity) {
return;
}
this.freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
if (this.includeRawAudio) {
this.analyser.getFloatTimeDomainData(this.timeData);
this.timeDataQueue.push(this.timeData.slice());
}
if (this.freqDataQueue.length > this.numFrames) {
// Drop the oldest frame (least recent).
this.freqDataQueue.shift();
}
const shouldFire = this.tracker.tick();
if (shouldFire) {
const freqData = flattenQueue(this.freqDataQueue);
const freqDataTensor = getInputTensorFromFrequencyData(
freqData, [1, this.numFrames, this.columnTruncateLength, 1]);
let timeDataTensor: tf.Tensor;
if (this.includeRawAudio) {
const timeData = flattenQueue(this.timeDataQueue);
timeDataTensor = getInputTensorFromFrequencyData(
timeData, [1, this.numFrames * this.fftSize]);
}
const shouldRest =
await this.spectrogramCallback(freqDataTensor, timeDataTensor);
if (shouldRest) {
this.tracker.suppress();
}
tf.dispose([freqDataTensor, timeDataTensor]);
}
}
async stop(): Promise<void> {
if (this.frameIntervalTask == null) {
throw new Error(
'Cannot stop because there is no ongoing streaming activity.');
}
clearInterval(this.frameIntervalTask);
this.frameIntervalTask = null;
this.analyser.disconnect();
this.audioContext.close();
if (this.stream != null && this.stream.getTracks().length > 0) {
this.stream.getTracks()[0].stop();
}
}
setConfig(params: RecognizerParams) {
throw new Error(
'setConfig() is not implemented for BrowserFftFeatureExtractor.');
}
getFeatures(): Float32Array[] {
throw new Error(
'getFeatures() is not implemented for ' +
'BrowserFftFeatureExtractor. Use the spectrogramCallback ' +
'field of the constructor config instead.');
}
}
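/**
 * Minimal usage sketch (hypothetical helper; the numbers below are
 * illustrative assumptions, mirroring the values used in the unit tests):
 * builds an extractor that surfaces one spectrogram tensor of shape
 * [1, 43, 232, 1] roughly every second (43 frames x ~23 ms per frame,
 * overlapFactor 0).
 */
export function createExampleExtractor(): BrowserFftFeatureExtractor {
  return new BrowserFftFeatureExtractor({
    spectrogramCallback: async (freqData: tf.Tensor) => {
      console.log('Spectrogram shape:', freqData.shape);
      return false;  // Do not start a suppression period.
    },
    numFramesPerSpectrogram: 43,
    columnTruncateLength: 232,
    suppressionTimeMillis: 0,
    overlapFactor: 0
  });
}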
export function flattenQueue(queue: Float32Array[]): Float32Array {
const frameSize = queue[0].length;
const freqData = new Float32Array(queue.length * frameSize);
queue.forEach((data, i) => freqData.set(data, i * frameSize));
return freqData;
}
export function getInputTensorFromFrequencyData(
freqData: Float32Array, shape: number[]): tf.Tensor {
const vals = new Float32Array(tf.util.sizeFromShape(shape));
// If the data is shorter than the output size, it is aligned to the end of
// the buffer and the leading values are left as zeros.
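// For example, a 4-element freqData written into a buffer of size 6 yields
// [0, 0, d0, d1, d2, d3], where d0...d3 are the data values.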
vals.set(freqData, vals.length - freqData.length);
return tf.tensor(vals, shape);
}
/**
* A class that manages the firing of events based on periods
* and suppression time.
*/
export class Tracker {
readonly period: number;
readonly suppressionTime: number;
private counter: number;
private suppressionOnset: number;
/**
* Constructor of Tracker.
*
* @param period The event-firing period, in number of frames.
* @param suppressionPeriod The suppression period, in number of frames.
*/
constructor(period: number, suppressionPeriod: number) {
this.period = period;
this.suppressionTime = suppressionPeriod == null ? 0 : suppressionPeriod;
this.counter = 0;
tf.util.assert(
this.period > 0,
() => `Expected period to be positive, but got ${this.period}`);
}
/**
* Mark a frame.
*
* @returns Whether the event should be fired at the current frame.
*/
tick(): boolean {
this.counter++;
const shouldFire = (this.counter % this.period === 0) &&
(this.suppressionOnset == null ||
this.counter - this.suppressionOnset > this.suppressionTime);
return shouldFire;
}
/**
* Order the beginning of a suppression period.
*/
suppress() {
this.suppressionOnset = this.counter;
}
}
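/**
 * Minimal usage sketch (hypothetical helper; 43 frames, overlapFactor 0.5 and
 * a 10-frame suppression period are illustrative assumptions): mirrors the
 * period computation in BrowserFftFeatureExtractor.start() to show how
 * Tracker paces firings. With these settings, tick() returns true on frames
 * 22, 44, 66, 88, ...
 */
export function demoTrackerPacing(numTicks = 100): number[] {
  const numFrames = 43;
  const overlapFactor = 0.5;
  const period = Math.max(1, Math.round(numFrames * (1 - overlapFactor)));
  const tracker = new Tracker(period, 10);
  const firingFrames: number[] = [];
  for (let frame = 1; frame <= numTicks; ++frame) {
    if (tracker.tick()) {
      firingFrames.push(frame);
      tracker.suppress();  // Pretend a word was recognized at this firing.
    }
  }
  return firingFrames;
}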

View File

@ -0,0 +1,327 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import {describeWithFlags, NODE_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util';
import {BrowserFftFeatureExtractor, flattenQueue, getInputTensorFromFrequencyData} from './browser_fft_extractor';
import * as BrowserFftUtils from './browser_fft_utils';
import {FakeAudioContext, FakeAudioMediaStream} from './browser_test_utils';
import {expectTensorsClose} from './test_utils';
const testEnvs = NODE_ENVS;
describeWithFlags('flattenQueue', testEnvs, () => {
it('3 frames, 2 values each', () => {
const queue = [[1, 1], [2, 2], [3, 3]].map(x => new Float32Array(x));
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1, 2, 2, 3, 3]));
});
it('2 frames, 2 values each', () => {
const queue = [[1, 1], [2, 2]].map(x => new Float32Array(x));
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1, 2, 2]));
});
it('1 frame, 2 values each', () => {
const queue = [[1, 1]].map(x => new Float32Array(x));
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1]));
});
});
describeWithFlags('getInputTensorFromFrequencyData', testEnvs, () => {
it('6 frames, 2 vals each', () => {
const freqData = new Float32Array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]);
const numFrames = 6;
const fftSize = 2;
const tensor =
getInputTensorFromFrequencyData(freqData, [1, numFrames, fftSize, 1]);
expectTensorsClose(tensor, tf.tensor4d(freqData, [1, 6, 2, 1]));
});
});
describeWithFlags('BrowserFftFeatureExtractor', testEnvs, () => {
function setUpFakes() {
spyOn(BrowserFftUtils, 'getAudioContextConstructor')
.and.callFake(() => FakeAudioContext.createInstance);
spyOn(BrowserFftUtils, 'getAudioMediaStream')
.and.callFake(() => new FakeAudioMediaStream());
}
it('constructor', () => {
setUpFakes();
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
suppressionTimeMillis: 1000,
overlapFactor: 0
});
expect(extractor.fftSize).toEqual(1024);
expect(extractor.numFrames).toEqual(43);
expect(extractor.columnTruncateLength).toEqual(225);
expect(extractor.overlapFactor).toBeCloseTo(0);
});
it('constructor errors due to null config', () => {
expect(() => new BrowserFftFeatureExtractor(null))
.toThrowError(/Required configuration object is missing/);
});
it('constructor errors due to missing spectrogramCallback', () => {
expect(() => new BrowserFftFeatureExtractor({
spectrogramCallback: null,
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
suppressionTimeMillis: 1000,
overlapFactor: 0
}))
.toThrowError(/spectrogramCallback cannot be null or undefined/);
});
it('constructor errors due to invalid numFramesPerSpectrogram', () => {
expect(() => new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: -2,
columnTruncateLength: 225,
overlapFactor: 0,
suppressionTimeMillis: 1000
}))
.toThrowError(/Invalid value in numFramesPerSpectrogram: -2/);
});
it('constructor errors due to negative overlapFactor', () => {
expect(() => new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
overlapFactor: -0.1,
suppressionTimeMillis: 1000
}))
.toThrowError(/Expected overlapFactor/);
});
it('constructor errors due to columnTruncateLength too large', () => {
expect(() => new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: 43,
columnTruncateLength: 1600, // > 1024 and leads to Error.
overlapFactor: 0,
suppressionTimeMillis: 1000
}))
.toThrowError(/columnTruncateLength .* exceeds fftSize/);
});
it('constructor errors due to negative suppressionTimeMillis', () => {
expect(() => new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: 43,
columnTruncateLength: 1600,
overlapFactor: 0,
suppressionTimeMillis: -1000 // <0 and leads to Error.
}))
.toThrowError(/Expected suppressionTimeMillis to be >= 0/);
});
it('start and stop: overlapFactor = 0', done => {
setUpFakes();
const timeDelta = 50;
const spectrogramDurationMillis = 1024 / 44100 * 43 * 1e3 - timeDelta;
const numCallbacksToComplete = 3;
let numCallbacksCompleted = 0;
const tensorCounts: number[] = [];
const callbackTimestamps: number[] = [];
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => {
callbackTimestamps.push(tf.util.now());
if (callbackTimestamps.length > 1) {
expect(
callbackTimestamps[callbackTimestamps.length - 1] -
callbackTimestamps[callbackTimestamps.length - 2])
.toBeGreaterThanOrEqual(spectrogramDurationMillis);
}
expect(x.shape).toEqual([1, 43, 225, 1]);
tensorCounts.push(tf.memory().numTensors);
if (tensorCounts.length > 1) {
// Assert no memory leak.
expect(tensorCounts[tensorCounts.length - 1])
.toEqual(tensorCounts[tensorCounts.length - 2]);
}
if (++numCallbacksCompleted >= numCallbacksToComplete) {
await extractor.stop();
done();
}
return false;
},
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
overlapFactor: 0,
suppressionTimeMillis: 0
});
extractor.start();
});
it('start and stop: correct rotating buffer size', done => {
setUpFakes();
const numFramesPerSpectrogram = 43;
const columnTruncateLength = 225;
const numCallbacksToComplete = 3;
let numCallbacksCompleted = 0;
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => {
const xData = await x.data();
// Verify the correctness of the spectrogram data.
for (let i = 0; i < xData.length; ++i) {
const segment = Math.floor(i / columnTruncateLength);
const expected = segment * 1024 + (i % columnTruncateLength) +
1024 * numFramesPerSpectrogram * numCallbacksCompleted;
expect(xData[i]).toEqual(expected);
}
if (++numCallbacksCompleted >= numCallbacksToComplete) {
await extractor.stop();
done();
}
return false;
},
numFramesPerSpectrogram,
columnTruncateLength,
overlapFactor: 0,
suppressionTimeMillis: 0
});
extractor.start();
});
it('start and stop: overlapFactor = 0.5', done => {
setUpFakes();
const numCallbacksToComplete = 5;
let numCallbacksCompleted = 0;
const spectrogramTensors: tf.Tensor[] = [];
const callbackTimestamps: number[] = [];
const spectrogramDurationMillis = 1024 / 44100 * 43 * 1e3;
const numFramesPerSpectrogram = 43;
const columnTruncateLength = 225;
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => {
callbackTimestamps.push(tf.util.now());
if (callbackTimestamps.length > 1) {
expect(
callbackTimestamps[callbackTimestamps.length - 1] -
callbackTimestamps[callbackTimestamps.length - 2])
.toBeGreaterThanOrEqual(spectrogramDurationMillis * 0.5);
// Verify the content of the spectrogram data.
const xData = await x.data();
expect(xData[xData.length - 1])
.toEqual(callbackTimestamps.length * 22 * 1024 - 800);
}
expect(x.shape).toEqual([1, 43, 225, 1]);
spectrogramTensors.push(tf.clone(x));
if (++numCallbacksCompleted >= numCallbacksToComplete) {
await extractor.stop();
done();
}
return false;
},
numFramesPerSpectrogram,
columnTruncateLength,
overlapFactor: 0.5,
suppressionTimeMillis: 0
});
extractor.start();
});
it('start and stop: the first frame is captured', done => {
setUpFakes();
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => {
expect(x.shape).toEqual([1, 43, 225, 1]);
const xData = x.dataSync();
// Verify that the first frame is not all zero or any constant value
// We don't compare the values against zero directly, because the
// spectrogram data is normalized here. The assertions below are also
// based on the fact that the fake audio context outputs linearly
// increasing sample values.
expect(xData[1]).toBeGreaterThan(xData[0]);
expect(xData[2]).toBeGreaterThan(xData[1]);
await extractor.stop();
done();
return false;
},
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
overlapFactor: 0,
suppressionTimeMillis: 0
});
extractor.start();
});
it('start and stop: suppressionTimeMillis = 1500', done => {
setUpFakes();
const numCallbacksToComplete = 2;
const suppressionTimeMillis = 1500;
let numCallbacksCompleted = 0;
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => {
if (++numCallbacksCompleted >= numCallbacksToComplete) {
const tEnd = tf.util.now();
// Due to the suppression, the time elapsed between the two
// consecutive callbacks should be at least the suppression time.
expect(tEnd - tBegin).toBeGreaterThanOrEqual(suppressionTimeMillis);
await extractor.stop();
done();
}
return true; // Returning true causes suppression.
},
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
overlapFactor: 0.25,
suppressionTimeMillis
});
const tBegin = tf.util.now();
extractor.start();
});
it('stopping unstarted extractor leads to Error', async () => {
setUpFakes();
const extractor = new BrowserFftFeatureExtractor({
spectrogramCallback: async (x: tf.Tensor) => false,
numFramesPerSpectrogram: 43,
columnTruncateLength: 225,
overlapFactor: 0,
suppressionTimeMillis: 1000
});
let caughtError: Error;
try {
await extractor.stop();
} catch (err) {
caughtError = err;
}
expect(caughtError.message)
.toMatch(/Cannot stop because there is no ongoing streaming activity/);
});
});

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@ -0,0 +1,131 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import {promisify} from 'util';
import {RawAudioData} from './types';
export async function loadMetadataJson(url: string):
Promise<{wordLabels: string[]}> {
const HTTP_SCHEME = 'http://';
const HTTPS_SCHEME = 'https://';
const FILE_SCHEME = 'file://';
if (url.indexOf(HTTP_SCHEME) === 0 || url.indexOf(HTTPS_SCHEME) === 0) {
const response = await fetch(url);
const parsed = await response.json();
return parsed;
} else if (url.indexOf(FILE_SCHEME) === 0) {
// tslint:disable-next-line:no-require-imports
const fs = require('fs');
const readFile = promisify(fs.readFile);
return JSON.parse(
await readFile(url.slice(FILE_SCHEME.length), {encoding: 'utf-8'}));
} else {
throw new Error(
`Unsupported URL scheme in metadata URL: ${url}. ` +
`Supported schemes are: http://, https://, and ` +
`(node.js-only) file://`);
}
}
let EPSILON: number = null;
/**
* Normalize the input into zero mean and unit standard deviation.
*
* This function is safe against division-by-zero: in case the standard
* deviation is zero, the output will be all-zero.
*
* @param x Input tensor.
* @returns The normalized output tensor.
*/
export function normalize(x: tf.Tensor): tf.Tensor {
if (EPSILON == null) {
EPSILON = tf.backend().epsilon();
}
return tf.tidy(() => {
const {mean, variance} = tf.moments(x);
// Add an EPSILON to the denominator to prevent division-by-zero.
return tf.div(tf.sub(x, mean), tf.add(tf.sqrt(variance), EPSILON));
});
}
/**
* Z-Normalize the elements of a Float32Array.
*
* Subtract the mean and divide the result by the standard deviation.
*
* @param x The Float32Array to normalize.
* @return Normalized Float32Array.
*/
export function normalizeFloat32Array(x: Float32Array): Float32Array {
if (x.length < 2) {
throw new Error(
'Cannot normalize a Float32Array with fewer than 2 elements.');
}
if (EPSILON == null) {
EPSILON = tf.backend().epsilon();
}
return tf.tidy(() => {
const {mean, variance} = tf.moments(tf.tensor1d(x));
const meanVal = mean.arraySync() as number;
const stdVal = Math.sqrt(variance.arraySync() as number);
const yArray = Array.from(x).map(y => (y - meanVal) / (stdVal + EPSILON));
return new Float32Array(yArray);
});
}
export function getAudioContextConstructor(): AudioContext {
// tslint:disable-next-line:no-any
return (window as any).AudioContext || (window as any).webkitAudioContext;
}
export async function getAudioMediaStream(
audioTrackConstraints?: MediaTrackConstraints): Promise<MediaStream> {
return navigator.mediaDevices.getUserMedia({
audio: audioTrackConstraints == null ? true : audioTrackConstraints,
video: false
});
}
/**
* Play raw audio waveform
* @param rawAudio Raw audio data, including the waveform and the sampling rate.
* @param onEnded Callback function to execute when the playing ends.
*/
export function playRawAudio(
rawAudio: RawAudioData, onEnded: () => void|Promise<void>): void {
const audioContextConstructor =
// tslint:disable-next-line:no-any
(window as any).AudioContext || (window as any).webkitAudioContext;
const audioContext: AudioContext = new audioContextConstructor();
const arrayBuffer =
audioContext.createBuffer(1, rawAudio.data.length, rawAudio.sampleRateHz);
const nowBuffering = arrayBuffer.getChannelData(0);
nowBuffering.set(rawAudio.data);
const source = audioContext.createBufferSource();
source.buffer = arrayBuffer;
source.connect(audioContext.destination);
source.start();
source.onended = () => {
if (onEnded != null) {
onEnded();
}
};
}
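/**
 * Minimal usage sketch (hypothetical helper; the tone parameters are
 * illustrative assumptions): plays one second of a 440 Hz tone through
 * playRawAudio(). RawAudioData obtained elsewhere (e.g., when raw audio
 * collection is enabled) can be played back the same way.
 */
export function playTestTone(frequencyHz = 440, durationSec = 1): void {
  const sampleRateHz = 44100;
  const data = new Float32Array(Math.round(sampleRateHz * durationSec));
  for (let i = 0; i < data.length; ++i) {
    data[i] = 0.1 * Math.sin(2 * Math.PI * frequencyHz * i / sampleRateHz);
  }
  playRawAudio({data, sampleRateHz}, () => console.log('Tone finished.'));
}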

View File

@ -0,0 +1,56 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import {normalize, normalizeFloat32Array} from './browser_fft_utils';
import {expectTensorsClose} from './test_utils';
describe('normalize', () => {
it('Non-constant value; no memory leak', () => {
const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
const numTensors0 = tf.memory().numTensors;
const y = normalize(x);
// Assert no memory leak.
expect(tf.memory().numTensors).toEqual(numTensors0 + 1);
expectTensorsClose(
y,
tf.tensor4d(
[-1.3416406, -0.4472135, 0.4472135, 1.3416406], [1, 2, 2, 1]));
const {mean, variance} = tf.moments(y);
expectTensorsClose(mean, tf.scalar(0));
expectTensorsClose(variance, tf.scalar(1));
});
it('Constant value', () => {
const x = tf.tensor4d([42, 42, 42, 42], [1, 2, 2, 1]);
const y = normalize(x);
expectTensorsClose(y, tf.tensor4d([0, 0, 0, 0], [1, 2, 2, 1]));
});
});
describe('normalizeFloat32Array', () => {
it('Length-4 input', () => {
const xs = new Float32Array([1, 2, 3, 4]);
const numTensors0 = tf.memory().numTensors;
const ys = tf.tensor1d(normalizeFloat32Array(xs));
// Assert no memory leak. (The extra comes from the tf.tensor1d() call
// in the testing code.)
expect(tf.memory().numTensors).toEqual(numTensors0 + 1);
expectTensorsClose(
ys, tf.tensor1d([-1.3416406, -0.4472135, 0.4472135, 1.3416406]));
});
});

View File

@ -0,0 +1,77 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Testing Utilities for Browser Audio Feature Extraction.
*/
export class FakeAudioContext {
readonly sampleRate = 44100;
static createInstance() {
return new FakeAudioContext();
}
createMediaStreamSource() {
return new FakeMediaStreamAudioSourceNode();
}
createAnalyser() {
return new FakeAnalyser();
}
close(): void {}
}
export class FakeAudioMediaStream {
constructor() {}
getTracks(): Array<{}> {
return [];
}
}
class FakeMediaStreamAudioSourceNode {
constructor() {}
connect(node: {}): void {}
}
class FakeAnalyser {
fftSize: number;
smoothingTimeConstant: number;
private x: number;
constructor() {
this.x = 0;
}
getFloatFrequencyData(data: Float32Array): void {
const xs: number[] = [];
for (let i = 0; i < this.fftSize / 2; ++i) {
xs.push(this.x++);
}
data.set(new Float32Array(xs));
}
getFloatTimeDomainData(data: Float32Array): void {
const xs: number[] = [];
for (let i = 0; i < this.fftSize / 2; ++i) {
xs.push(-(this.x++));
}
data.set(new Float32Array(xs));
}
disconnect(): void {}
}

View File

@ -0,0 +1,977 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import * as tfd from '@tensorflow/tfjs-data';
import {normalize} from './browser_fft_utils';
import {arrayBuffer2String, concatenateArrayBuffers, getRandomInteger, getUID, string2ArrayBuffer} from './generic_utils';
import {balancedTrainValSplitNumArrays} from './training_utils';
import {AudioDataAugmentationOptions, Example, SpectrogramData} from './types';
// Descriptor for serialized dataset files: stands for:
// TensorFlow.js Speech-Commands Dataset.
// DO NOT EVER CHANGE THIS!
export const DATASET_SERIALIZATION_DESCRIPTOR = 'TFJSSCDS';
// A version number for the serialization. Since this needs
// to be encoded within a length-1 Uint8 array, it must be
// 1. a positive integer.
// 2. monotonically increasing over its change history.
// Item 1 is checked by unit tests.
export const DATASET_SERIALIZATION_VERSION = 1;
/**
* Specification for an `Example` (see above).
*
* Used for serialization of `Example`.
*/
export interface ExampleSpec {
/** A label for the example. */
label: string;
/** Number of frames in the spectrogram. */
spectrogramNumFrames: number;
/** The length of each frame in the spectrogram. */
spectrogramFrameSize: number;
/** The key frame index of the spectrogram. */
spectrogramKeyFrameIndex?: number;
/** Number of samples in the raw PCM-format audio (if any). */
rawAudioNumSamples?: number;
/** Sampling rate of the raw audio (if any). */
rawAudioSampleRateHz?: number;
}
/**
* Serialized Dataset, containing a number of `Example`s in their
* serialized format.
*
* This format consists of a plain-old JSON object as the manifest,
* along with a flattened binary `ArrayBuffer`. The format facilitates
* storage and transmission.
*/
export interface SerializedExamples {
/**
* Specifications of the serialized `Example`s, serialized as a string.
*/
manifest: ExampleSpec[];
/**
* Serialized binary data from the `Example`s.
*
* Including the spectrograms and the raw audio (if any).
*
* For example, assuming `manifest.length` is `N`, the format of the
* `ArrayBuffer` is as follows:
*
* [spectrogramData1, rawAudio1 (if any),
* spectrogramData2, rawAudio2 (if any),
* ...
* spectrogramDataN, rawAudioN (if any)]
*/
data: ArrayBuffer;
}
export const BACKGROUND_NOISE_TAG = '_background_noise_';
/**
* Configuration for getting spectrograms as tensors.
*/
export interface GetDataConfig extends AudioDataAugmentationOptions {
/**
* Number of frames.
*
* This must be smaller than or equal to the # of frames of each
* example held by the dataset.
*
* If the # of frames of an example is greater than this number,
* the following heuristics will be used to extract >= 1 examples
* of length numFrames from the original example:
*
* - If the label of the example is `BACKGROUND_NOISE_TAG`,
* the example will be split into multiple examples using the
* `hopFrames` parameter (see below).
* - If the label of the example is not `BACKGROUND_NOISE_TAG`,
* the example will be split into multiple examples that
* all contain the maximum-intensity frame using the `hopFrames`
* parameter.
*/
numFrames?: number;
/**
* Hop length in number of frames.
*
* Used when splitting a long example into multiple shorter ones.
*
* Must be provided if any such long examples exist.
*/
hopFrames?: number;
/**
* Whether the spectrogram of each example will be normalized.
*
* Normalization means:
* - Subtracting the mean, and
* - Dividing the result by the standard deviation.
*
* Default: `true`.
*/
normalize?: boolean;
/**
* Whether the examples will be shuffled prior to merged into
* `tf.Tensor`s.
*
* Default: `true`.
*/
shuffle?: boolean;
/**
* Whether to obtain a `tf.data.Dataset` object.
*
* Default: `false`.
*/
getDataset?: boolean;
/**
* Batch size for dataset.
*
* Applicable only if `getDataset === true`.
*/
datasetBatchSize?: number;
/**
* Validation split for the dataset.
*
* Applicable only if `getDataset === true`.
*
* The data will be divided into two fractions of relative sizes
* `[1 - datasetValidationSplit, datasetValidationSplit]`, for the
* training and validation `tf.data.Dataset` objects, respectively.
*
* Must be a number between 0 and 1.
* Default: 0.15.
*/
datasetValidationSplit?: number;
}
// tslint:disable-next-line:no-any
export type SpectrogramAndTargetsTfDataset = tfd.Dataset<{}>;
/**
* A serializable, mutable set of speech/audio `Example`s;
*/
export class Dataset {
private examples: {[id: string]: Example};
private label2Ids: {[label: string]: string[]};
/**
* Constructor of `Dataset`.
*
* If called with no arguments (i.e., `artifacts` == null), an empty dataset
* will be constructed.
*
* Else, the dataset will be deserialized from `artifacts`.
*
* @param serialized Optional serialization artifacts to deserialize.
*/
constructor(serialized?: ArrayBuffer) {
this.examples = {};
this.label2Ids = {};
if (serialized != null) {
// Deserialize from the provided artifacts.
const artifacts = arrayBuffer2SerializedExamples(serialized);
let offset = 0;
for (let i = 0; i < artifacts.manifest.length; ++i) {
const spec = artifacts.manifest[i];
let byteLen = spec.spectrogramNumFrames * spec.spectrogramFrameSize;
if (spec.rawAudioNumSamples != null) {
byteLen += spec.rawAudioNumSamples;
}
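// Each spectrogram / raw-audio value is a float32, i.e., 4 bytes per value.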
byteLen *= 4;
this.addExample(deserializeExample(
{spec, data: artifacts.data.slice(offset, offset + byteLen)}));
offset += byteLen;
}
}
}
/**
* Add an `Example` to the `Dataset`
*
* @param example A `Example`, with a label. The label must be a non-empty
* string.
* @returns The UID for the added `Example`.
*/
addExample(example: Example): string {
tf.util.assert(example != null, () => 'Got null or undefined example');
tf.util.assert(
example.label != null && example.label.length > 0,
() => `Expected label to be a non-empty string, ` +
`but got ${JSON.stringify(example.label)}`);
const uid = getUID();
this.examples[uid] = example;
if (!(example.label in this.label2Ids)) {
this.label2Ids[example.label] = [];
}
this.label2Ids[example.label].push(uid);
return uid;
}
/**
* Merge the incoming dataset into this dataset
*
* @param dataset The incoming dataset to be merged into this dataset.
*/
merge(dataset: Dataset): void {
tf.util.assert(
dataset !== this, () => 'Cannot merge a dataset into itself');
const vocab = dataset.getVocabulary();
for (const word of vocab) {
const examples = dataset.getExamples(word);
for (const example of examples) {
this.addExample(example.example);
}
}
}
/**
* Get a map from `Example` label to number of `Example`s with the label.
*
* @returns A map from label to number of example counts under that label.
*/
getExampleCounts(): {[label: string]: number} {
const counts: {[label: string]: number} = {};
for (const uid in this.examples) {
const example = this.examples[uid];
if (!(example.label in counts)) {
counts[example.label] = 0;
}
counts[example.label]++;
}
return counts;
}
/**
* Get all examples of a given label, with their UIDs.
*
* @param label The requested label.
* @return All examples of the given `label`, along with their UIDs.
* The examples are sorted in the order in which they are added to the
* `Dataset`.
* @throws Error if label is `null` or `undefined`.
*/
getExamples(label: string): Array<{uid: string, example: Example}> {
tf.util.assert(
label != null,
() =>
`Expected label to be a string, but got ${JSON.stringify(label)}`);
tf.util.assert(
label in this.label2Ids,
() => `No example of label "${label}" exists in dataset`);
const output: Array<{uid: string, example: Example}> = [];
this.label2Ids[label].forEach(id => {
output.push({uid: id, example: this.examples[id]});
});
return output;
}
/**
* Get all examples and labels as tensors.
*
* - If `label` is provided and exists in the vocabulary of the `Dataset`,
* the spectrograms of all `Example`s under the `label` will be returned
* as a 4D `tf.Tensor` as `xs`. The shape of the `tf.Tensor` will be
* `[numExamples, numFrames, frameSize, 1]`
* where
* - `numExamples` is the number of `Example`s with the label
* - `numFrames` is the number of frames in each spectrogram
* - `frameSize` is the size of each spectrogram frame.
* No label Tensor will be returned.
* - If `label` is not provided, all `Example`s will be returned as `xs`.
* In addition, `ys` will contain a one-hot encoded list of labels.
* - The shape of `xs` will be: `[numExamples, numFrames, frameSize, 1]`
* - The shape of `ys` will be: `[numExamples, vocabularySize]`.
*
* @returns If `config.getDataset` is `true`, returns two `tf.data.Dataset`
* objects, one for training and one for validation.
* Else, `xs` and `ys` tensors. See description above.
* @throws Error
* - if not all the involved spectrograms have matching `numFrames` and
* `frameSize`, or
* - if `label` is provided and is not present in the vocabulary of the
* `Dataset`, or
* - if the `Dataset` is currently empty.
*/
getData(label?: string, config?: GetDataConfig): {
xs: tf.Tensor4D,
ys?: tf.Tensor2D
}|[SpectrogramAndTargetsTfDataset, SpectrogramAndTargetsTfDataset] {
tf.util.assert(
this.size() > 0,
() =>
`Cannot get spectrograms as tensors because the dataset is empty`);
const vocab = this.getVocabulary();
if (label != null) {
tf.util.assert(
vocab.indexOf(label) !== -1,
() => `Label ${label} is not in the vocabulary ` +
`(${JSON.stringify(vocab)})`);
} else {
// If all words are requested, there must be at least two words in the
// vocabulary to make one-hot encoding possible.
tf.util.assert(
vocab.length > 1,
() => `One-hot encoding of labels requires the vocabulary to have ` +
`at least two words, but it has only ${vocab.length} word.`);
}
if (config == null) {
config = {};
}
// Get the numFrames lengths of all the examples currently held by the
// dataset.
const sortedUniqueNumFrames = this.getSortedUniqueNumFrames();
let numFrames: number;
let hopFrames: number;
if (sortedUniqueNumFrames.length === 1) {
numFrames = config.numFrames == null ? sortedUniqueNumFrames[0] :
config.numFrames;
hopFrames = config.hopFrames == null ? 1 : config.hopFrames;
} else {
numFrames = config.numFrames;
tf.util.assert(
numFrames != null && Number.isInteger(numFrames) && numFrames > 0,
() => `There are ${
sortedUniqueNumFrames.length} unique lengths among ` +
`the ${this.size()} examples of this Dataset, hence numFrames ` +
`is required. But it is not provided.`);
tf.util.assert(
numFrames <= sortedUniqueNumFrames[0],
() => `numFrames (${numFrames}) exceeds the minimum numFrames ` +
`(${sortedUniqueNumFrames[0]}) among the examples of ` +
`the Dataset.`);
hopFrames = config.hopFrames;
tf.util.assert(
hopFrames != null && Number.isInteger(hopFrames) && hopFrames > 0,
() => `There are ${
sortedUniqueNumFrames.length} unique lengths among ` +
`the ${this.size()} examples of this Dataset, hence hopFrames ` +
`is required. But it is not provided.`);
}
// Normalization is performed by default.
const toNormalize = config.normalize == null ? true : config.normalize;
return tf.tidy(() => {
let xTensors: tf.Tensor3D[] = [];
let xArrays: Float32Array[] = [];
let labelIndices: number[] = [];
let uniqueFrameSize: number;
for (let i = 0; i < vocab.length; ++i) {
const currentLabel = vocab[i];
if (label != null && currentLabel !== label) {
continue;
}
const ids = this.label2Ids[currentLabel];
for (const id of ids) {
const example = this.examples[id];
const spectrogram = example.spectrogram;
const frameSize = spectrogram.frameSize;
if (uniqueFrameSize == null) {
uniqueFrameSize = frameSize;
} else {
tf.util.assert(
frameSize === uniqueFrameSize,
() => `Mismatch in frameSize ` +
`(${frameSize} vs ${uniqueFrameSize})`);
}
const snippetLength = spectrogram.data.length / frameSize;
let focusIndex = null;
if (currentLabel !== BACKGROUND_NOISE_TAG) {
focusIndex = spectrogram.keyFrameIndex == null ?
getMaxIntensityFrameIndex(spectrogram).dataSync()[0] :
spectrogram.keyFrameIndex;
}
// TODO(cais): See if we can get rid of dataSync();
const snippet =
tf.tensor3d(spectrogram.data, [snippetLength, frameSize, 1]);
const windows =
getValidWindows(snippetLength, focusIndex, numFrames, hopFrames);
for (const window of windows) {
const windowedSnippet = tf.tidy(() => {
const output = tf.slice(snippet,
[window[0], 0, 0], [window[1] - window[0], -1, -1]);
return toNormalize ? normalize(output) : output;
});
if (config.getDataset) {
// TODO(cais): See if we can do away with dataSync();
// TODO(cais): Shuffling?
xArrays.push(windowedSnippet.dataSync() as Float32Array);
} else {
xTensors.push(windowedSnippet as tf.Tensor3D);
}
if (label == null) {
labelIndices.push(i);
}
}
tf.dispose(snippet); // For memory saving.
}
}
if (config.augmentByMixingNoiseRatio != null) {
this.augmentByMixingNoise(
config.getDataset ? xArrays :
xTensors as Array<Float32Array|tf.Tensor>,
labelIndices, config.augmentByMixingNoiseRatio);
}
const shuffle = config.shuffle == null ? true : config.shuffle;
if (config.getDataset) {
const batchSize =
config.datasetBatchSize == null ? 32 : config.datasetBatchSize;
// Split the data into two splits: training and validation.
const valSplit = config.datasetValidationSplit == null ?
0.15 :
config.datasetValidationSplit;
tf.util.assert(
valSplit > 0 && valSplit < 1,
() => `Invalid dataset validation split: ${valSplit}`);
const zippedXandYArrays =
xArrays.map((xArray, i) => [xArray, labelIndices[i]]);
tf.util.shuffle(
zippedXandYArrays); // Shuffle the data before splitting.
xArrays = zippedXandYArrays.map(item => item[0]) as Float32Array[];
const yArrays = zippedXandYArrays.map(item => item[1]) as number[];
const {trainXs, trainYs, valXs, valYs} =
balancedTrainValSplitNumArrays(xArrays, yArrays, valSplit);
// TODO(cais): The typing around Float32Array is not working properly
// for tf.data currently. Tighten the types when the tf.data bug is
// fixed.
// tslint:disable:no-any
const xTrain =
tfd.array(trainXs as any).map(x => tf.tensor3d(x as any, [
numFrames, uniqueFrameSize, 1
]));
const yTrain = tfd.array(trainYs).map(
y => tf.squeeze(tf.oneHot([y], vocab.length), [0]));
// TODO(cais): See if we can tighten the typing.
let trainDataset = tfd.zip({xs: xTrain, ys: yTrain});
if (shuffle) {
// Shuffle the dataset.
trainDataset = trainDataset.shuffle(xArrays.length);
}
trainDataset = trainDataset.batch(batchSize).prefetch(4);
const xVal =
tfd.array(valXs as any).map(x => tf.tensor3d(x as any, [
numFrames, uniqueFrameSize, 1
]));
const yVal = tfd.array(valYs).map(
y => tf.squeeze(tf.oneHot([y], vocab.length), [0]));
let valDataset = tfd.zip({xs: xVal, ys: yVal});
valDataset = valDataset.batch(batchSize).prefetch(4);
// tslint:enable:no-any
// tslint:disable-next-line:no-any
return [trainDataset, valDataset] as any;
} else {
if (shuffle) {
// Shuffle the data.
const zipped: Array<{x: tf.Tensor3D, y: number}> = [];
xTensors.forEach((xTensor, i) => {
zipped.push({x: xTensor, y: labelIndices[i]});
});
tf.util.shuffle(zipped);
xTensors = zipped.map(item => item.x);
labelIndices = zipped.map(item => item.y);
}
const targets = label == null ?
tf.cast(tf.oneHot(tf.tensor1d(labelIndices, 'int32'), vocab.length),
'float32') :
undefined;
return {
xs: tf.stack(xTensors) as tf.Tensor4D,
ys: targets as tf.Tensor2D
};
}
});
}
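// Usage sketch (illustrative assumptions: a vocabulary of
// ['_background_noise_', 'word'] and examples of 43 frames x frameSize 232):
//   const {xs, ys} = dataset.getData(undefined, {numFrames: 43, hopFrames: 1});
// yields xs of shape [numExamples, 43, 232, 1] and one-hot ys of shape
// [numExamples, 2], per the contract documented above.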
private augmentByMixingNoise<T extends tf.Tensor|Float32Array>(
xs: T[], labelIndices: number[], ratio: number): void {
if (xs == null || xs.length === 0) {
throw new Error(
`Cannot perform augmentation because data is null or empty`);
}
const isTypedArray = xs[0] instanceof Float32Array;
const vocab = this.getVocabulary();
const noiseExampleIndices: number[] = [];
const wordExampleIndices: number[] = [];
for (let i = 0; i < labelIndices.length; ++i) {
if (vocab[labelIndices[i]] === BACKGROUND_NOISE_TAG) {
noiseExampleIndices.push(i);
} else {
wordExampleIndices.push(i);
}
}
if (noiseExampleIndices.length === 0) {
throw new Error(
`Cannot perform augmentation by mixing with noise when ` +
`there is no example with label ${BACKGROUND_NOISE_TAG}`);
}
const mixedXTensors: Array<tf.Tensor|Float32Array> = [];
const mixedLabelIndices: number[] = [];
for (const index of wordExampleIndices) {
const noiseIndex = // Randomly sample from the noises, with replacement.
noiseExampleIndices[getRandomInteger(0, noiseExampleIndices.length)];
const signalTensor = isTypedArray ?
tf.tensor1d(xs[index] as Float32Array) :
xs[index] as tf.Tensor;
const noiseTensor = isTypedArray ?
tf.tensor1d(xs[noiseIndex] as Float32Array) :
xs[noiseIndex] as tf.Tensor;
const mixed: tf.Tensor =
tf.tidy(() => normalize(
tf.add(signalTensor, tf.mul(noiseTensor, ratio))));
if (isTypedArray) {
mixedXTensors.push(mixed.dataSync() as Float32Array);
} else {
mixedXTensors.push(mixed);
}
mixedLabelIndices.push(labelIndices[index]);
}
console.log(
`Data augmentation: mixing noise: added ${mixedXTensors.length} ` +
`examples`);
mixedXTensors.forEach(tensor => xs.push(tensor as T));
labelIndices.push(...mixedLabelIndices);
}
private getSortedUniqueNumFrames(): number[] {
const numFramesSet = new Set<number>();
const vocab = this.getVocabulary();
for (let i = 0; i < vocab.length; ++i) {
const label = vocab[i];
const ids = this.label2Ids[label];
for (const id of ids) {
const spectrogram = this.examples[id].spectrogram;
const numFrames = spectrogram.data.length / spectrogram.frameSize;
numFramesSet.add(numFrames);
}
}
const uniqueNumFrames = [...numFramesSet];
uniqueNumFrames.sort();
return uniqueNumFrames;
}
/**
* Remove an example from the `Dataset`.
*
* @param uid The UID of the example to remove.
* @throws Error if the UID doesn't exist in the `Dataset`.
*/
removeExample(uid: string): void {
if (!(uid in this.examples)) {
throw new Error(`Nonexistent example UID: ${uid}`);
}
const label = this.examples[uid].label;
delete this.examples[uid];
const index = this.label2Ids[label].indexOf(uid);
this.label2Ids[label].splice(index, 1);
if (this.label2Ids[label].length === 0) {
delete this.label2Ids[label];
}
}
/**
* Set the key frame index of a given example.
*
* @param uid The UID of the example of which the `keyFrameIndex` is to be
* set.
* @param keyFrameIndex The desired value of the `keyFrameIndex`. Must
* be >= 0, < the number of frames of the example, and an integer.
* @throws Error If the UID and/or the `keyFrameIndex` value is invalid.
*/
setExampleKeyFrameIndex(uid: string, keyFrameIndex: number) {
if (!(uid in this.examples)) {
throw new Error(`Nonexistent example UID: ${uid}`);
}
const spectrogram = this.examples[uid].spectrogram;
const numFrames = spectrogram.data.length / spectrogram.frameSize;
tf.util.assert(
keyFrameIndex >= 0 && keyFrameIndex < numFrames &&
Number.isInteger(keyFrameIndex),
() => `Invalid keyFrameIndex: ${keyFrameIndex}. ` +
`Must be >= 0, < ${numFrames}, and an integer.`);
spectrogram.keyFrameIndex = keyFrameIndex;
}
/**
* Get the total number of `Example` currently held by the `Dataset`.
*
* @returns Total `Example` count.
*/
size(): number {
return Object.keys(this.examples).length;
}
/**
   * Get the total duration of the `Example`s currently held by the `Dataset`,
   * in milliseconds.
*
* @return Total duration in milliseconds.
*/
durationMillis(): number {
let durMillis = 0;
const DEFAULT_FRAME_DUR_MILLIS = 23.22;
for (const key in this.examples) {
const spectrogram = this.examples[key].spectrogram;
      const frameDurMillis =
          spectrogram.frameDurationMillis || DEFAULT_FRAME_DUR_MILLIS;
durMillis +=
spectrogram.data.length / spectrogram.frameSize * frameDurMillis;
}
return durMillis;
}
/**
* Query whether the `Dataset` is currently empty.
*
* I.e., holds zero examples.
*
* @returns Whether the `Dataset` is currently empty.
*/
empty(): boolean {
return this.size() === 0;
}
/**
* Remove all `Example`s from the `Dataset`.
*/
clear(): void {
this.examples = {};
}
/**
* Get the list of labels among all `Example`s the `Dataset` currently holds.
*
* @returns A sorted Array of labels, for the unique labels that belong to all
* `Example`s currently held by the `Dataset`.
*/
getVocabulary(): string[] {
const vocab = new Set<string>();
for (const uid in this.examples) {
const example = this.examples[uid];
vocab.add(example.label);
}
const sortedVocab = [...vocab];
sortedVocab.sort();
return sortedVocab;
}
/**
* Serialize the `Dataset`.
*
* The `Examples` are sorted in the following order:
* - First, the labels in the vocabulary are sorted.
* - Second, the `Example`s for every label are sorted by the order in
* which they are added to this `Dataset`.
*
* @param wordLabels Optional word label(s) to serialize. If specified, only
* the examples with labels matching the argument will be serialized. If
* any specified word label does not exist in the vocabulary of this
* dataset, an Error will be thrown.
   * @returns An `ArrayBuffer` object amenable to transmission and storage.
*/
serialize(wordLabels?: string|string[]): ArrayBuffer {
const vocab = this.getVocabulary();
tf.util.assert(!this.empty(), () => `Cannot serialize empty Dataset`);
if (wordLabels != null) {
if (!Array.isArray(wordLabels)) {
wordLabels = [wordLabels];
}
wordLabels.forEach(wordLabel => {
if (vocab.indexOf(wordLabel) === -1) {
throw new Error(
`Word label "${wordLabel}" does not exist in the ` +
`vocabulary of this dataset. The vocabulary is: ` +
`${JSON.stringify(vocab)}.`);
}
});
}
const manifest: ExampleSpec[] = [];
const buffers: ArrayBuffer[] = [];
for (const label of vocab) {
if (wordLabels != null && wordLabels.indexOf(label) === -1) {
continue;
}
const ids = this.label2Ids[label];
for (const id of ids) {
const artifact = serializeExample(this.examples[id]);
manifest.push(artifact.spec);
buffers.push(artifact.data);
}
}
return serializedExamples2ArrayBuffer(
{manifest, data: concatenateArrayBuffers(buffers)});
}
}
/** Serialize an `Example`. */
export function serializeExample(example: Example):
{spec: ExampleSpec, data: ArrayBuffer} {
const hasRawAudio = example.rawAudio != null;
const spec: ExampleSpec = {
label: example.label,
spectrogramNumFrames:
example.spectrogram.data.length / example.spectrogram.frameSize,
spectrogramFrameSize: example.spectrogram.frameSize,
};
if (example.spectrogram.keyFrameIndex != null) {
spec.spectrogramKeyFrameIndex = example.spectrogram.keyFrameIndex;
}
let data = example.spectrogram.data.buffer.slice(0);
if (hasRawAudio) {
spec.rawAudioNumSamples = example.rawAudio.data.length;
spec.rawAudioSampleRateHz = example.rawAudio.sampleRateHz;
// Account for the fact that the data are all float32.
data = concatenateArrayBuffers([data, example.rawAudio.data.buffer]);
}
return {spec, data};
}
/** Deserialize an `Example`. */
export function deserializeExample(
artifact: {spec: ExampleSpec, data: ArrayBuffer}): Example {
const spectrogram: SpectrogramData = {
frameSize: artifact.spec.spectrogramFrameSize,
data: new Float32Array(artifact.data.slice(
0,
4 * artifact.spec.spectrogramFrameSize *
artifact.spec.spectrogramNumFrames))
};
if (artifact.spec.spectrogramKeyFrameIndex != null) {
spectrogram.keyFrameIndex = artifact.spec.spectrogramKeyFrameIndex;
}
const ex: Example = {label: artifact.spec.label, spectrogram};
if (artifact.spec.rawAudioNumSamples != null) {
ex.rawAudio = {
sampleRateHz: artifact.spec.rawAudioSampleRateHz,
data: new Float32Array(artifact.data.slice(
4 * artifact.spec.spectrogramFrameSize *
artifact.spec.spectrogramNumFrames))
};
}
return ex;
}
/**
* Encode intermediate serialization format as an ArrayBuffer.
*
* Format of the binary ArrayBuffer:
* 1. An 8-byte descriptor (see above).
* 2. A 4-byte version number as Uint32.
* 3. A 4-byte number for the byte length of the JSON manifest.
* 4. The encoded JSON manifest
* 5. The binary data of the spectrograms, and raw audio (if any).
*
* @param serialized: Intermediate serialization format of a dataset.
* @returns The binary conversion result as an ArrayBuffer.
*/
function serializedExamples2ArrayBuffer(serialized: SerializedExamples):
ArrayBuffer {
const manifestBuffer =
string2ArrayBuffer(JSON.stringify(serialized.manifest));
const descriptorBuffer = string2ArrayBuffer(DATASET_SERIALIZATION_DESCRIPTOR);
const version = new Uint32Array([DATASET_SERIALIZATION_VERSION]);
const manifestLength = new Uint32Array([manifestBuffer.byteLength]);
const headerBuffer = concatenateArrayBuffers(
[descriptorBuffer, version.buffer, manifestLength.buffer]);
return concatenateArrayBuffers(
[headerBuffer, manifestBuffer, serialized.data]);
}
/** Decode an ArrayBuffer as intermediate serialization format. */
export function arrayBuffer2SerializedExamples(buffer: ArrayBuffer):
SerializedExamples {
tf.util.assert(buffer != null, () => 'Received null or undefined buffer');
// Check descriptor.
let offset = 0;
const descriptor = arrayBuffer2String(
buffer.slice(offset, DATASET_SERIALIZATION_DESCRIPTOR.length));
tf.util.assert(
descriptor === DATASET_SERIALIZATION_DESCRIPTOR,
() => `Deserialization error: Invalid descriptor`);
offset += DATASET_SERIALIZATION_DESCRIPTOR.length;
// Skip the version part for now. It may be used in the future.
offset += 4;
// Extract the length of the encoded manifest JSON as a Uint32.
const manifestLength = new Uint32Array(buffer, offset, 1);
offset += 4;
const manifestBeginByte = offset;
offset = manifestBeginByte + manifestLength[0];
const manifestBytes = buffer.slice(manifestBeginByte, offset);
const manifestString = arrayBuffer2String(manifestBytes);
const manifest = JSON.parse(manifestString);
const data = buffer.slice(offset);
return {manifest, data};
}
/**
* Get valid windows in a long snippet.
*
* Each window is represented by an inclusive left index and an exclusive
* right index.
*
 * @param snippetLength Length of the entire snippet. Must be a positive
* integer.
* @param focusIndex Optional. If `null` or `undefined`, an array of
* evenly-spaced windows will be generated. The array of windows will
* start from the first possible location (i.e., [0, windowLength]).
* If not `null` or `undefined`, must be an integer >= 0 and < snippetLength.
* @param windowLength Length of each window. Must be a positive integer and
* <= snippetLength.
 * @param windowHop Hops between successive windows. Must be a positive
* integer.
* @returns An array of [beginIndex, endIndex] pairs.
*/
export function getValidWindows(
snippetLength: number, focusIndex: number, windowLength: number,
windowHop: number): Array<[number, number]> {
tf.util.assert(
Number.isInteger(snippetLength) && snippetLength > 0,
() =>
`snippetLength must be a positive integer, but got ${snippetLength}`);
if (focusIndex != null) {
tf.util.assert(
Number.isInteger(focusIndex) && focusIndex >= 0,
() =>
`focusIndex must be a non-negative integer, but got ${focusIndex}`);
}
tf.util.assert(
Number.isInteger(windowLength) && windowLength > 0,
() => `windowLength must be a positive integer, but got ${windowLength}`);
tf.util.assert(
Number.isInteger(windowHop) && windowHop > 0,
() => `windowHop must be a positive integer, but got ${windowHop}`);
tf.util.assert(
windowLength <= snippetLength,
() => `windowLength (${windowLength}) exceeds snippetLength ` +
`(${snippetLength})`);
tf.util.assert(
focusIndex < snippetLength,
() => `focusIndex (${focusIndex}) equals or exceeds snippetLength ` +
`(${snippetLength})`);
if (windowLength === snippetLength) {
return [[0, snippetLength]];
}
const windows: Array<[number, number]> = [];
if (focusIndex == null) {
// Deal with the special case of no focus frame:
// Output an array of evenly-spaced windows, starting from
// the first possible location.
let begin = 0;
while (begin + windowLength <= snippetLength) {
windows.push([begin, begin + windowLength]);
begin += windowHop;
}
return windows;
}
const leftHalf = Math.floor(windowLength / 2);
let left = focusIndex - leftHalf;
if (left < 0) {
left = 0;
} else if (left + windowLength > snippetLength) {
left = snippetLength - windowLength;
}
while (true) {
if (left - windowHop < 0 || focusIndex >= left - windowHop + windowLength) {
break;
}
left -= windowHop;
}
while (left + windowLength <= snippetLength) {
if (focusIndex < left) {
break;
}
windows.push([left, left + windowLength]);
left += windowHop;
}
return windows;
}
/**
* Calculate an intensity profile from a spectrogram.
*
 * The intensity at each time frame is calculated by simply averaging all the
* spectral values that belong to that time frame.
*
* @param spectrogram The input spectrogram.
* @returns The temporal profile of the intensity as a 1D tf.Tensor of shape
* `[numFrames]`.
*/
export function spectrogram2IntensityCurve(spectrogram: SpectrogramData):
tf.Tensor {
return tf.tidy(() => {
const numFrames = spectrogram.data.length / spectrogram.frameSize;
const x = tf.tensor2d(spectrogram.data, [numFrames, spectrogram.frameSize]);
return tf.mean(x, -1);
});
}
/**
* Get the index to the maximum intensity frame.
*
* The intensity of each time frame is calculated as the arithmetic mean of
* all the spectral values belonging to that time frame.
*
* @param spectrogram The input spectrogram.
* @returns The index to the time frame containing the maximum intensity.
*/
export function getMaxIntensityFrameIndex(spectrogram: SpectrogramData):
tf.Scalar {
return tf.tidy(() => tf.argMax(spectrogram2IntensityCurve(spectrogram)));
}
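For orientation, here is a minimal usage sketch of the exported helpers above (it is not part of the library source; the 232-point frame size, the 43-frame window length, the 11-frame hop, and the import paths are illustrative assumptions): pick a key frame with `getMaxIntensityFrameIndex`, then cut model-sized windows around it with `getValidWindows`.

```ts
import * as tf from '@tensorflow/tfjs-core';

import {getMaxIntensityFrameIndex, getValidWindows} from './dataset';
import {SpectrogramData} from './types';

// A hypothetical long recording: 100 frames, 232 FFT points per frame.
const FRAME_SIZE = 232;
const NUM_FRAMES = 100;
const longSpectrogram: SpectrogramData = {
  data: new Float32Array(NUM_FRAMES * FRAME_SIZE),  // filled elsewhere
  frameSize: FRAME_SIZE
};

// Fall back to the loudest frame when no keyFrameIndex was set manually.
const keyFrame = getMaxIntensityFrameIndex(longSpectrogram).dataSync()[0];

// Cut 43-frame windows (a typical model input length) around the key frame,
// hopping 11 frames between successive windows.
const windows = getValidWindows(NUM_FRAMES, keyFrame, 43, 11);
for (const [begin, end] of windows) {
  const slice =
      longSpectrogram.data.subarray(begin * FRAME_SIZE, end * FRAME_SIZE);
  const x = tf.tensor4d(slice, [1, 43, FRAME_SIZE, 1]);
  // ... feed `x` to a model here ...
  x.dispose();
}
```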

File diff suppressed because it is too large

View File

@@ -0,0 +1,92 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Concatenate a number of ArrayBuffers into one.
*
* @param buffers A number of array buffers to concatenate.
* @returns Result of concatenating `buffers` in order.
*/
export function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer {
let totalByteLength = 0;
buffers.forEach((buffer: ArrayBuffer) => {
totalByteLength += buffer.byteLength;
});
const temp = new Uint8Array(totalByteLength);
let offset = 0;
buffers.forEach((buffer: ArrayBuffer) => {
temp.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
});
return temp.buffer;
}
/**
* Concatenate Float32Arrays.
*
* @param xs Float32Arrays to concatenate.
* @return The result of the concatenation.
*/
export function concatenateFloat32Arrays(xs: Float32Array[]): Float32Array {
let totalLength = 0;
xs.forEach(x => totalLength += x.length);
const concatenated = new Float32Array(totalLength);
let index = 0;
xs.forEach(x => {
concatenated.set(x, index);
index += x.length;
});
return concatenated;
}
/** Encode a string as an ArrayBuffer. */
export function string2ArrayBuffer(str: string): ArrayBuffer {
if (str == null) {
    throw new Error('Received null or undefined string');
}
// NOTE(cais): This implementation is inefficient in terms of memory.
  // But it works for UTF-8 strings. Just don't use it for very long strings.
const strUTF8 = unescape(encodeURIComponent(str));
const buf = new Uint8Array(strUTF8.length);
for (let i = 0; i < strUTF8.length; ++i) {
buf[i] = strUTF8.charCodeAt(i);
}
return buf.buffer;
}
/** Decode an ArrayBuffer as a string. */
export function arrayBuffer2String(buffer: ArrayBuffer): string {
if (buffer == null) {
    throw new Error('Received null or undefined buffer');
}
const buf = new Uint8Array(buffer);
return decodeURIComponent(escape(String.fromCharCode(...buf)));
}
/** Generate a pseudo-random UID. */
export function getUID(): string {
function s4() {
return Math.floor((1 + Math.random()) * 0x10000).toString(16).substring(1);
}
return s4() + s4() + '-' + s4() + '-' + s4() + '-' + s4() + '-' + s4() +
s4() + s4();
}
export function getRandomInteger(min: number, max: number): number {
return Math.floor((max - min) * Math.random()) + min;
}
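As a small, self-contained sketch (not part of the library) of how these helpers compose, the snippet below encodes a JSON header with `string2ArrayBuffer`, glues it to a binary payload with `concatenateArrayBuffers`, and recovers the header with `arrayBuffer2String`; this is the same header-plus-payload pattern the dataset serialization above relies on, and the label and values are placeholders.

```ts
import {arrayBuffer2String, concatenateArrayBuffers, string2ArrayBuffer} from './generic_utils';

// Encode a small JSON header followed by a Float32 payload.
const header = string2ArrayBuffer(JSON.stringify({label: 'clap', numFrames: 43}));
const payload = new Float32Array([0.1, 0.2, 0.3]).buffer;
const blob = concatenateArrayBuffers([header, payload]);

// To decode, the header's byte length must be known (the dataset format
// stores it up front as a Uint32 manifest-length field).
const decoded = JSON.parse(arrayBuffer2String(blob.slice(0, header.byteLength)));
console.log(decoded.label);  // 'clap'

// The remaining bytes are the Float32 payload.
const recovered = new Float32Array(blob.slice(header.byteLength));
console.log(recovered);  // [0.1, 0.2, 0.3] (up to float32 precision)
```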

View File

@@ -0,0 +1,81 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import {expectArraysEqual} from '@tensorflow/tfjs-core/dist/test_util';
import {arrayBuffer2String, concatenateFloat32Arrays, string2ArrayBuffer} from './generic_utils';
describe('string2ArrayBuffer and arrayBuffer2String', () => {
it('round trip: ASCII only', () => {
const str = 'Lorem_Ipsum_123 !@#$%^&*()';
expect(arrayBuffer2String(string2ArrayBuffer(str))).toEqual(str);
});
it('round trip: non-ASCII', () => {
const str = 'Welcome 欢迎 स्वागत हे ようこそ добро пожаловать 😀😀';
expect(arrayBuffer2String(string2ArrayBuffer(str))).toEqual(str);
});
it('round trip: empty string', () => {
const str = '';
expect(arrayBuffer2String(string2ArrayBuffer(str))).toEqual(str);
});
});
describe('concatenateFloat32Arrays', () => {
it('Two non-empty', () => {
const xs = new Float32Array([1, 3]);
const ys = new Float32Array([3, 7]);
expectArraysEqual(
concatenateFloat32Arrays([xs, ys]), new Float32Array([1, 3, 3, 7]));
expectArraysEqual(
concatenateFloat32Arrays([ys, xs]), new Float32Array([3, 7, 1, 3]));
// Assert that the original Float32Arrays are not altered.
expectArraysEqual(xs, new Float32Array([1, 3]));
expectArraysEqual(ys, new Float32Array([3, 7]));
});
it('Three unequal lengths non-empty', () => {
const array1 = new Float32Array([1]);
const array2 = new Float32Array([2, 3]);
const array3 = new Float32Array([4, 5, 6]);
expectArraysEqual(
concatenateFloat32Arrays([array1, array2, array3]),
new Float32Array([1, 2, 3, 4, 5, 6]));
});
it('One empty, one non-empty', () => {
const xs = new Float32Array([4, 2]);
const ys = new Float32Array(0);
expectArraysEqual(
concatenateFloat32Arrays([xs, ys]), new Float32Array([4, 2]));
expectArraysEqual(
concatenateFloat32Arrays([ys, xs]), new Float32Array([4, 2]));
// Assert that the original Float32Arrays are not altered.
expectArraysEqual(xs, new Float32Array([4, 2]));
expectArraysEqual(ys, new Float32Array(0));
});
it('Two empty', () => {
const xs = new Float32Array(0);
const ys = new Float32Array(0);
expectArraysEqual(concatenateFloat32Arrays([xs, ys]), new Float32Array(0));
expectArraysEqual(concatenateFloat32Arrays([ys, xs]), new Float32Array(0));
// Assert that the original Float32Arrays are not altered.
expectArraysEqual(xs, new Float32Array(0));
expectArraysEqual(ys, new Float32Array(0));
});
});

View File

@@ -0,0 +1,91 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import {BrowserFftSpeechCommandRecognizer} from './browser_fft_recognizer';
import {playRawAudio} from './browser_fft_utils';
import {concatenateFloat32Arrays} from './generic_utils';
import {FFT_TYPE, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata} from './types';
import { normalizeFloat32Array, normalize } from './browser_fft_utils';
/**
* Create an instance of speech-command recognizer.
*
 * @param fftType Type of FFT. The currently available option(s):
* - BROWSER_FFT: Obtains audio spectrograms using browser's native Fourier
* transform.
* @param vocabulary The vocabulary of the model to load. Possible options:
 *   - '18w' (default): The 18-word vocabulary, consisting of:
* 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
* 'eight', 'nine', 'up', 'down', 'left', 'right', 'go', 'stop',
* 'yes', and 'no', in addition to '_background_noise_' and '_unknown_'.
* - 'directional4w': The four directional words: 'up', 'down', 'left', and
* 'right', in addition to '_background_noise_' and '_unknown_'.
* Choosing a smaller vocabulary leads to better accuracy on the words of
* interest and a slightly smaller model size.
* @param customModelArtifactsOrURL A custom model URL pointing to a model.json
* file, or a set of modelArtifacts in `tf.io.ModelArtifacts` format.
* Supported schemes: http://, https://, and node.js-only: file://.
 * Mutually exclusive with `vocabulary`. If provided, `customMetadataOrURL`
 * must also be provided.
 * @param customMetadataOrURL A custom metadata URL pointing to a metadata.json
 * file. Must be provided together with `customModelArtifactsOrURL`, or a metadata
* object.
* @returns An instance of SpeechCommandRecognizer.
* @throws Error on invalid value of `fftType`.
*/
export function create(
fftType: FFT_TYPE, vocabulary?: string,
customModelArtifactsOrURL?: tf.io.ModelArtifacts|string,
customMetadataOrURL?: SpeechCommandRecognizerMetadata|
string): SpeechCommandRecognizer {
tf.util.assert(
customModelArtifactsOrURL == null && customMetadataOrURL == null ||
customModelArtifactsOrURL != null && customMetadataOrURL != null,
() => `customModelURL and customMetadataURL must be both provided or ` +
`both not provided.`);
if (customModelArtifactsOrURL != null) {
tf.util.assert(
vocabulary == null,
() => `vocabulary name must be null or undefined when modelURL ` +
`is provided.`);
}
if (fftType === 'BROWSER_FFT') {
return new BrowserFftSpeechCommandRecognizer(
vocabulary, customModelArtifactsOrURL, customMetadataOrURL);
} else if (fftType === 'SOFT_FFT') {
throw new Error(
'SOFT_FFT SpeechCommandRecognizer has not been implemented yet.');
} else {
throw new Error(`Invalid fftType: '${fftType}'`);
}
}
const utils = {
concatenateFloat32Arrays,
normalizeFloat32Array,
normalize,
playRawAudio
};
export {BACKGROUND_NOISE_TAG, Dataset, GetDataConfig as GetSpectrogramsAsTensorsConfig, getMaxIntensityFrameIndex, spectrogram2IntensityCurve, SpectrogramAndTargetsTfDataset} from './dataset';
export {AudioDataAugmentationOptions, Example, FFT_TYPE, RawAudioData, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
export {deleteSavedTransferModel, listSavedTransferModels, UNKNOWN_TAG} from './browser_fft_recognizer';
export {utils};
export {version} from './version';
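For context, the exported `create()` entry point above is typically used along the lines of the following sketch (written for this document as an illustration; the threshold and timeout values are arbitrary): create a recognizer, wait for the model to load, then stream predictions from the microphone.

```ts
import * as speechCommands from '@tensorflow-models/speech-commands';

async function run() {
  // Load the pretrained 18-word browser-FFT model.
  const recognizer = speechCommands.create('BROWSER_FFT');
  await recognizer.ensureModelLoaded();
  const labels = recognizer.wordLabels();

  // Stream predictions; report the top-scoring word for each result.
  await recognizer.listen(async result => {
    const scores = result.scores as Float32Array;
    let top = 0;
    scores.forEach((s, i) => { if (s > scores[top]) top = i; });
    console.log(`${labels[top]} (p=${scores[top].toFixed(2)})`);
  }, {probabilityThreshold: 0.75, overlapFactor: 0.5});

  // Stop streaming after ten seconds.
  setTimeout(() => recognizer.stopListening(), 10e3);
}

run();
```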

View File

@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line:no-require-imports
const packageJSON = require('../package.json');
import * as tf from '@tensorflow/tfjs-core';
import * as tfl from '@tensorflow/tfjs-layers';
import * as speechCommands from './index';
describe('Public API', () => {
it('version matches package.json', () => {
expect(typeof speechCommands.version).toEqual('string');
expect(speechCommands.version).toEqual(packageJSON.version);
});
});
describe('Creating recognizer', () => {
async function makeModelArtifacts(): Promise<tf.io.ModelArtifacts> {
const model = tfl.sequential();
model.add(tfl.layers.conv2d({
filters: 8,
kernelSize: 3,
activation: 'relu',
inputShape: [86, 500, 1]
}));
model.add(tfl.layers.flatten());
model.add(tfl.layers.dense({units: 3, activation: 'softmax'}));
let modelArtifacts: tf.io.ModelArtifacts;
await model.save(tf.io.withSaveHandler(artifacts => {
modelArtifacts = artifacts;
return null;
}));
return modelArtifacts;
}
function makeMetadata(): speechCommands.SpeechCommandRecognizerMetadata {
return {
wordLabels: [speechCommands.BACKGROUND_NOISE_TAG, 'foo', 'bar'],
tfjsSpeechCommandsVersion: speechCommands.version
};
}
  it('Create recognizer from artifacts and metadata objects', async () => {
const modelArtifacts = await makeModelArtifacts();
const metadata = makeMetadata();
const recognizer =
speechCommands.create('BROWSER_FFT', null, modelArtifacts, metadata);
await recognizer.ensureModelLoaded();
expect(recognizer.wordLabels()).toEqual([
speechCommands.BACKGROUND_NOISE_TAG, 'foo', 'bar'
]);
expect(recognizer.modelInputShape()).toEqual([null, 86, 500, 1]);
});
});

View File

@@ -0,0 +1,46 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor, test_util, util} from '@tensorflow/tfjs-core';
export function expectTensorsClose(
actual: Tensor|number[], expected: Tensor|number[], epsilon?: number) {
if (actual == null) {
throw new Error(
'First argument to expectTensorsClose() is not defined.');
}
if (expected == null) {
throw new Error(
'Second argument to expectTensorsClose() is not defined.');
}
if (actual instanceof Tensor && expected instanceof Tensor) {
if (actual.dtype !== expected.dtype) {
throw new Error(
`Data types do not match. Actual: '${actual.dtype}'. ` +
`Expected: '${expected.dtype}'`);
}
if (!util.arraysEqual(actual.shape, expected.shape)) {
throw new Error(
`Shapes do not match. Actual: [${actual.shape}]. ` +
`Expected: [${expected.shape}].`);
}
}
const actualData = actual instanceof Tensor ? actual.dataSync() : actual;
const expectedData =
expected instanceof Tensor ? expected.dataSync() : expected;
test_util.expectArraysClose(actualData, expectedData, epsilon);
}

View File

@@ -0,0 +1,164 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Utility functions for training and transfer learning of the speech-commands
* model.
*/
import * as tf from '@tensorflow/tfjs-core';
/**
* Split feature and target tensors into train and validation (val) splits.
*
 * Given a sufficient number of examples, the train and val sets will be
* balanced with respect to the classes.
*
* @param xs Features tensor, of shape [numExamples, ...].
* @param ys Targets tensors, of shape [numExamples, numClasses]. Assumed to be
* one-hot categorical encoding.
* @param valSplit A number > 0 and < 1, fraction of examples to use
* as the validation set.
* @returns trainXs: training features tensor; trainYs: training targets
* tensor; valXs: validation features tensor; valYs: validation targets
* tensor.
*/
export function balancedTrainValSplit(
xs: tf.Tensor, ys: tf.Tensor, valSplit: number): {
trainXs: tf.Tensor,
trainYs: tf.Tensor,
valXs: tf.Tensor,
valYs: tf.Tensor
} {
tf.util.assert(
valSplit > 0 && valSplit < 1,
() => `validationSplit is expected to be >0 and <1, ` +
`but got ${valSplit}`);
return tf.tidy(() => {
const classIndices = tf.argMax(ys, -1).dataSync();
const indicesByClasses: number[][] = [];
for (let i = 0; i < classIndices.length; ++i) {
const classIndex = classIndices[i];
if (indicesByClasses[classIndex] == null) {
indicesByClasses[classIndex] = [];
}
indicesByClasses[classIndex].push(i);
}
const numClasses = indicesByClasses.length;
const trainIndices: number[] = [];
const valIndices: number[] = [];
// Randomly shuffle the list of indices in each array.
indicesByClasses.map(classIndices => tf.util.shuffle(classIndices));
for (let i = 0; i < numClasses; ++i) {
const classIndices = indicesByClasses[i];
const cutoff = Math.round(classIndices.length * (1 - valSplit));
for (let j = 0; j < classIndices.length; ++j) {
if (j < cutoff) {
trainIndices.push(classIndices[j]);
} else {
valIndices.push(classIndices[j]);
}
}
}
const trainXs = tf.gather(xs, trainIndices);
const trainYs = tf.gather(ys, trainIndices);
const valXs = tf.gather(xs, valIndices);
const valYs = tf.gather(ys, valIndices);
return {trainXs, trainYs, valXs, valYs};
});
}
/**
* Same as balancedTrainValSplit, but for number arrays or Float32Arrays.
*/
export function balancedTrainValSplitNumArrays(
xs: number[][]|Float32Array[], ys: number[], valSplit: number): {
trainXs: number[][]|Float32Array[],
trainYs: number[],
valXs: number[][]|Float32Array[],
valYs: number[]
} {
tf.util.assert(
valSplit > 0 && valSplit < 1,
() => `validationSplit is expected to be >0 and <1, ` +
`but got ${valSplit}`);
const isXsFloat32Array = !Array.isArray(xs[0]);
const classIndices = ys;
const indicesByClasses: number[][] = [];
for (let i = 0; i < classIndices.length; ++i) {
const classIndex = classIndices[i];
if (indicesByClasses[classIndex] == null) {
indicesByClasses[classIndex] = [];
}
indicesByClasses[classIndex].push(i);
}
const numClasses = indicesByClasses.length;
const trainIndices: number[] = [];
const valIndices: number[] = [];
// Randomly shuffle the list of indices in each array.
indicesByClasses.map(classIndices => tf.util.shuffle(classIndices));
for (let i = 0; i < numClasses; ++i) {
const classIndices = indicesByClasses[i];
const cutoff = Math.round(classIndices.length * (1 - valSplit));
for (let j = 0; j < classIndices.length; ++j) {
if (j < cutoff) {
trainIndices.push(classIndices[j]);
} else {
valIndices.push(classIndices[j]);
}
}
}
if (isXsFloat32Array) {
const trainXs: Float32Array[] = [];
const trainYs: number[] = [];
const valXs: Float32Array[] = [];
const valYs: number[] = [];
for (const index of trainIndices) {
trainXs.push(xs[index] as Float32Array);
trainYs.push(ys[index]);
}
for (const index of valIndices) {
valXs.push(xs[index] as Float32Array);
valYs.push(ys[index]);
}
return {trainXs, trainYs, valXs, valYs};
} else {
const trainXs: number[][] = [];
const trainYs: number[] = [];
const valXs: number[][] = [];
const valYs: number[] = [];
for (const index of trainIndices) {
trainXs.push(xs[index] as number[]);
trainYs.push(ys[index]);
}
for (const index of valIndices) {
valXs.push(xs[index] as number[]);
valYs.push(ys[index]);
}
return {trainXs, trainYs, valXs, valYs};
}
}
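A brief sketch (illustrative values only) of the Float32Array variant defined above: integer class indices go in, and each class contributes proportionally to the validation split, so both splits stay class-balanced.

```ts
import {balancedTrainValSplitNumArrays} from './training_utils';

// Four examples, two per class, as flattened Float32Array features.
const xs: Float32Array[] = [
  new Float32Array([0, 1]), new Float32Array([1, 0]),
  new Float32Array([2, 3]), new Float32Array([3, 2])
];
const ys = [0, 0, 1, 1];  // integer class indices (not one-hot)

const {trainXs, trainYs, valXs, valYs} =
    balancedTrainValSplitNumArrays(xs, ys, 0.5);

// With valSplit = 0.5, each class contributes one example to the validation
// set: trainYs and valYs both come out as [0, 1].
console.log(trainYs, valYs);
console.log(trainXs.length, valXs.length);  // 2 2
```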

View File

@@ -0,0 +1,60 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import '@tensorflow/tfjs-node';
import * as tf from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import {describeWithFlags, NODE_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util';
import {expectTensorsClose} from './test_utils';
import {balancedTrainValSplit} from './training_utils';
describeWithFlags('balancedTrainValSplit', NODE_ENVS, () => {
it('Enough data for split', () => {
const xs = tf.randomNormal([8, 3]);
const ys = tf.oneHot(tf.tensor1d([0, 0, 0, 0, 1, 1, 1, 1], 'int32'), 2);
const {trainXs, trainYs, valXs, valYs} =
balancedTrainValSplit(xs, ys, 0.25);
expect(trainXs.shape).toEqual([6, 3]);
expect(trainYs.shape).toEqual([6, 2]);
expect(valXs.shape).toEqual([2, 3]);
expect(valYs.shape).toEqual([2, 2]);
expectTensorsClose(tf.sum(trainYs, 0), tf.tensor1d([3, 3], 'int32'));
expectTensorsClose(tf.sum(valYs, 0), tf.tensor1d([1, 1], 'int32'));
});
it('Not enough data for split', () => {
const xs = tf.randomNormal([8, 3]);
const ys = tf.oneHot(tf.tensor1d([0, 0, 0, 0, 1, 1, 1, 1], 'int32'), 2);
const {trainXs, trainYs, valXs, valYs} =
balancedTrainValSplit(xs, ys, 0.01);
expect(trainXs.shape).toEqual([8, 3]);
expect(trainYs.shape).toEqual([8, 2]);
expect(valXs.shape).toEqual([0, 3]);
expect(valYs.shape).toEqual([0, 2]);
});
it('Invalid valSplit leads to Error', () => {
const xs = tf.randomNormal([8, 3]);
const ys = tf.oneHot(tf.tensor1d([0, 0, 0, 0, 1, 1, 1, 1], 'int32'), 2);
expect(() => balancedTrainValSplit(xs, ys, -0.2)).toThrow();
expect(() => balancedTrainValSplit(xs, ys, 0)).toThrow();
expect(() => balancedTrainValSplit(xs, ys, 1)).toThrow();
expect(() => balancedTrainValSplit(xs, ys, 1.2)).toThrow();
});
});

View File

@@ -0,0 +1,754 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import * as tfl from '@tensorflow/tfjs-layers';
/**
* This file defines the interfaces related to SpeechCommandRecognizer.
*/
export type FFT_TYPE = 'BROWSER_FFT'|'SOFT_FFT';
export type RecognizerCallback = (result: SpeechCommandRecognizerResult) =>
Promise<void>;
/**
* Interface for a speech-command recognizer.
*/
export interface SpeechCommandRecognizer {
/**
* Load the underlying model instance and associated metadata.
*
* If the model and the metadata are already loaded, do nothing.
*/
ensureModelLoaded(): Promise<void>;
/**
* Start listening continuously to microphone input and perform recognition
* in a streaming fashion.
*
* @param callback the callback that will be invoked every time
* a recognition result is available.
* @param config optional configuration.
* @throws Error if there is already ongoing streaming recognition.
*/
listen(callback: RecognizerCallback, config?: StreamingRecognitionConfig):
Promise<void>;
/**
* Stop the ongoing streaming recognition (if any).
*
* @throws Error if no streaming recognition is ongoing.
*/
stopListening(): Promise<void>;
/**
* Check if this instance is currently performing
* streaming recognition.
*/
isListening(): boolean;
/**
* Recognize a single example of audio.
*
* If `input` is provided, will perform offline prediction.
* If `input` is not provided, a single frame of audio
   * will be collected from the microphone via WebAudio and predictions
* will be made on it.
*
   * @param input (Optional) tf.Tensor or Float32Array.
* If provided and a tf.Tensor, must match the input shape of the
* underlying tf.Model. If a Float32Array, the length must be
   * equal to (the model's required FFT length) *
   * (the model's required frame count).
* @returns A Promise of recognition result, with the following fields:
* - scores: the probability scores.
* - embedding: the embedding for the input audio (i.e., an internal
* activation from the model). Provided if and only if `includeEmbedding`
* is `true` in `config`.
* @throws Error on incorrect shape or length.
*/
recognize(input?: tf.Tensor|Float32Array, config?: RecognizeConfig):
Promise<SpeechCommandRecognizerResult>;
/**
   * Get the input shape of the tf.Model that underlies the recognizer.
*/
modelInputShape(): tfl.Shape;
/**
* Getter for word labels.
*
* The word labels are an alphabetically sorted Array of strings.
*/
wordLabels(): string[];
/**
* Get the parameters such as the required number of frames.
*/
params(): RecognizerParams;
/**
* Create a new recognizer based on this recognizer, for transfer learning.
*
* @param name Required name of the transfer learning recognizer. Must be a
* non-empty string.
* @returns An instance of TransferSpeechCommandRecognizer, which supports
* `collectExample()`, `train()`, as well as the same `listen()`
* `stopListening()` and `recognize()` as the base recognizer.
*/
createTransfer(name: string): TransferSpeechCommandRecognizer;
}
export interface ExampleCollectionOptions {
/**
* Multiplier for the duration.
*
* This is the ratio between the duration of the to-be-collected
* example and the duration of each input example accepted by the
* underlying convnet.
*
* If not provided, will default to 1.
*
* Must be a number >=1.
*/
durationMultiplier?: number;
/**
* Duration in seconds.
*
* Mutually exclusive with durationMultiplier.
* If specified, must be >0.
*/
durationSec?: number;
/**
* Optional constraints for the audio track.
*
* E.g., this can be used to select a microphone when multiple microphones
* are available on the system: `{deviceId: 'deadbeef'}`.
*/
audioTrackConstraints?: MediaTrackConstraints;
/**
   * Optional snippet duration in seconds.
*
* Must be supplied if `onSnippet` is specified.
*/
snippetDurationSec?: number;
/**
* Optional snippet callback.
*
* Must be provided if `snippetDurationSec` is specified.
*
* Gets called every snippetDurationSec with a latest slice of the
* spectrogram. It is the spectrogram accumulated since the last invocation of
* the callback (or for the first time, since when `collectExample()` is
* started).
*/
onSnippet?: (spectrogram: SpectrogramData) => Promise<void>;
/**
* Whether to collect the raw time-domain audio waveform in addition to the
* spectrogram.
*
* Default: `false`.
*/
includeRawAudio?: boolean;
}
/**
 * Metadata for a speech-commands recognizer.
*/
export interface SpeechCommandRecognizerMetadata {
/** Version of the speech-commands library. */
tfjsSpeechCommandsVersion: string;
/** Name of the model. */
modelName?: string;
  /** A time stamp for when this metadata is generated. */
timeStamp?: string;
/**
* Word labels for the recognizer model's output probability scores.
*
* The length of this array should be equal to the size of the last dimension
* of the model's output.
*/
wordLabels: string[];
}
/**
* Interface for a transfer-learning speech command recognizer.
*
* This inherits the `SpeechCommandRecognizer`. It adds methods for
* collecting and clearing examples for transfer learning, methods for
* querying the status of example collection, and for performing the
* transfer-learning training.
*/
export interface TransferSpeechCommandRecognizer extends
SpeechCommandRecognizer {
/**
* Collect an example for transfer learning via WebAudio.
*
* @param {string} word Name of the word. Must not overlap with any of the
* words the base model is trained to recognize.
   * @returns {SpectrogramData} The spectrogram of the acquired example.
* @throws Error, if word belongs to the set of words the base model is
* trained to recognize.
*/
collectExample(word: string, options?: ExampleCollectionOptions):
Promise<SpectrogramData>;
/**
* Clear all transfer learning examples collected so far.
*/
clearExamples(): void;
/**
* Get counts of the word examples that have been collected for a
* transfer-learning model.
*
* @returns {{[word: string]: number}} A map from word name to number of
* examples collected for that word so far.
*/
countExamples(): {[word: string]: number};
/**
* Train a transfer-learning model.
*
* The last dense layer of the base model is replaced with new softmax dense
* layer.
*
   * It is assumed that at least one category of data has been collected (using
   * multiple calls to the `collectExample` method).
*
   * @param config {TransferLearnConfig} Optional configurations for the
* training of the transfer-learning model.
* @returns {tf.History} A history object with the loss and accuracy values
* from the training of the transfer-learning model.
* @throws Error, if `modelName` is invalid or if not sufficient training
* examples have been collected yet.
*/
train(config?: TransferLearnConfig):
Promise<tfl.History|[tfl.History, tfl.History]>;
/**
* Perform evaluation of the model using the examples that the model
* has loaded.
*
   * The evaluation calculates an ROC curve by lumping the non-background-noise
* classes into a positive category and treating the background-noise
* class as the negative category.
*
* @param config Configuration object for the evaluation.
* @returns A Promise of the result of evaluation.
*/
evaluate(config: EvaluateConfig): Promise<EvaluateResult>;
/**
* Get examples currently held by the transfer-learning recognizer.
*
* @param label Label requested.
* @returns An array of `Example`s, along with their UIDs.
*/
getExamples(label: string): Array<{uid: string, example: Example}>;
/** Set the key frame index of a given example. */
setExampleKeyFrameIndex(uid: string, keyFrameIndex: number): void;
/**
* Load an array of serialized examples.
*
* @param serialized The examples in their serialized format.
* @param clearExisting Whether to clear the existing examples while
* performing the loading (default: false).
*/
loadExamples(serialized: ArrayBuffer, clearExisting?: boolean): void;
/**
* Serialize the existing examples.
*
* @param wordLabels Optional word label(s) to serialize. If specified, only
* the examples with labels matching the argument will be serialized. If
* any specified word label does not exist in the vocabulary of this
* transfer recognizer, an Error will be thrown.
* @returns An `ArrayBuffer` object amenable to transmission and storage.
*/
serializeExamples(wordLabels?: string|string[]): ArrayBuffer;
/**
* Remove an example from the dataset of the transfer recognizer.
*
* @param uid The UID for the example to be removed.
*/
removeExample(uid: string): void;
/**
* Check whether the dataset underlying this transfer recognizer is empty.
*
* @returns A boolean indicating whether the underlying dataset is empty.
*/
isDatasetEmpty(): boolean;
/**
* Save the transfer-learned model.
*
* By default, the model's topology and weights are saved to browser
* IndexedDB, and the associated metadata are saved to browser LocalStorage.
*
* The saved metadata includes (among other things) the word list.
*
* To save the model to another destination, use the optional argument
* `handlerOrURL`. Note that if you use the custom route, you'll
* currently have to handle the metadata (e.g., word list) saving yourself.
*
* @param handlerOrURL Optional custom save URL or IOHandler object. E.g.,
* `'downloads://my-file-name'`.
* @returns A `Promise` of a `SaveResult` object that summarizes the
* saving result.
*/
save(handlerOrURL?: string|tf.io.IOHandler): Promise<tf.io.SaveResult>;
/**
* Load the transfer-learned model.
*
* By default, the model's topology and weights are loaded from browser
* IndexedDB and the associated metadata are loaded from browser
* LocalStorage.
*
* To load the model from another destination, use the optional
* argument. Note that if you load the model from a custom URL or
* IOHandler, you'll currently have to load the metadata (e.g., word
* list) yourself.
*
* @param handlerOrURL Optional custom source URL or IOHandler object
* to load the data from. E.g.,
* `tf.io.browserFiles([modelJSONFile, weightsFile])`
*/
load(handlerOrURL?: string|tf.io.IOHandler): Promise<void>;
/**
* Get metadata about the transfer recognizer.
*
* The metadata includes but is not limited to: speech-commands library
* version, word labels that correspond to the model's probability outputs.
*/
getMetadata(): SpeechCommandRecognizerMetadata;
}
/**
* Interface for a snippet of audio spectrogram.
*/
export interface SpectrogramData {
/**
* The float32 data for the spectrogram.
*
* Stored frame by frame. For example, the first N elements
* belong to the first time frame and the next N elements belong
* to the second time frame, and so forth.
*/
data: Float32Array;
/**
* Number of points per frame, i.e., FFT length per frame.
*/
frameSize: number;
/**
* Duration of each frame in milliseconds.
*/
frameDurationMillis?: number;
/**
* Index to the key frame (0-based).
*
* A key frame is a frame in the spectrogram that belongs to
* the utterance of interest. It is used to distinguish the
* utterance part from the background-noise part.
*
* A typical use of key frame index: when multiple training examples are
   * extracted from a spectrogram, every example is guaranteed to include
* the key frame.
*
   * A key frame is not required. If it is missing, heuristic algorithms
* (e.g., finding the highest-intensity frame) can be used to calculate
* the key frame.
*/
keyFrameIndex?: number;
}
/**
* Interface for a result emitted by a speech-command recognizer.
*
* It is used in the callback of a recognizer's streaming or offline
* recognition method. It represents the result for a short snippet of
* audio.
*/
export interface SpeechCommandRecognizerResult {
/**
* Probability scores for the words.
*/
scores: Float32Array|Float32Array[];
/**
* Optional spectrogram data.
*/
spectrogram?: SpectrogramData;
/**
* Embedding (internal activation) for the input.
*
* This field is populated if and only if `includeEmbedding`
* is `true` in the configuration object used during the `recognize` call.
*/
embedding?: tf.Tensor;
}
export interface StreamingRecognitionConfig {
/**
* Overlap factor. Must be >=0 and <1.
* Defaults to 0.5.
* For example, if the model takes a frame length of 1000 ms,
* and if overlap factor is 0.4, there will be a 400ms
* overlap between two successive frames, i.e., frames
* will be taken every 600 ms.
*/
overlapFactor?: number;
/**
   * Amount of time in ms to suppress the recognizer after a word is recognized.
*
* Defaults to 1000 ms.
*/
suppressionTimeMillis?: number;
/**
   * Threshold that the maximum probability value in a model prediction
   * output must meet or exceed for the callback to be invoked; below it,
   * the callback will not be called.
*
* Must be a number >=0 and <=1.
*
* The value will be overridden to `0` if `includeEmbedding` is `true`.
*
* If `null` or `undefined`, will default to `0`.
*/
probabilityThreshold?: number;
/**
* Invoke the callback for background noise and unknown.
*
* The value will be overridden to `true` if `includeEmbedding` is `true`.
*
* Default: `false`.
*/
invokeCallbackOnNoiseAndUnknown?: boolean;
/**
   * Whether the spectrogram is to be provided in each recognition
* callback call.
*
* Default: `false`.
*/
includeSpectrogram?: boolean;
/**
* Whether to include the embedding (internal activation).
*
* If set as `true`, the values of the following configuration fields
* in this object will be overridden:
*
* - `probabilityThreshold` will be overridden to 0.
* - `invokeCallbackOnNoiseAndUnknown` will be overridden to `true`.
*
* Default: `false`.
*/
includeEmbedding?: boolean;
/**
* Optional constraints for the audio track.
*
* E.g., this can be used to select a microphone when multiple microphones
* are available on the system: `{deviceId: 'deadbeef'}`.
*/
audioTrackConstraints?: MediaTrackConstraints;
}
export interface RecognizeConfig {
/**
   * Whether the spectrogram is to be provided in each recognition
* callback call.
*
* Default: `false`.
*/
includeSpectrogram?: boolean;
/**
* Whether to include the embedding (internal activation).
*
* Default: `false`.
*/
includeEmbedding?: boolean;
}
export interface AudioDataAugmentationOptions {
/**
* Additive ratio for augmenting the data by mixing the word spectrograms
* with background-noise ones.
*
* If not `null` or `undefined`, will cause extra word spectrograms to be
* created through the equation:
* (normalizedWordSpectrogram +
* augmentByMixingNoiseRatio * normalizedNoiseSpectrogram)
*
* The normalizedNoiseSpectrogram will be drawn randomly from all noise
* snippets available. If no noise snippet is available, an Error will
* be thrown.
*
* Default: `undefined`.
*/
augmentByMixingNoiseRatio?: number;
// TODO(cais): Add other augmentation options, including augmentByReverb,
// augmentByTempoShift and augmentByFrequencyShift.
}
/**
* Configurations for the training of a transfer-learning recognizer.
*
* It is used during calls to the `TransferSpeechCommandRecognizer.train()`
* method.
*/
export interface TransferLearnConfig extends AudioDataAugmentationOptions {
/**
* Number of training epochs (default: 20).
*/
epochs?: number;
/**
* Optimizer to be used for training (default: 'sgd').
*/
optimizer?: string|tf.Optimizer;
/**
* Batch size of training (default: 128).
*/
batchSize?: number;
/**
* Validation split to be used during training.
*
* Default: null (no validation split).
*
   * Note that this split is different from the basic validation-split
* paradigm in TensorFlow.js. It makes sure that the distribution of the
* classes in the training and validation sets are approximately balanced.
*
* If specified, must be a number > 0 and < 1.
*/
validationSplit?: number;
/**
* Number of fine-tuning epochs to run after the initial `epochs` epochs
* of transfer-learning training.
*
* During the fine-tuning, the last dense layer of the truncated base
* model (i.e., the second-last dense layer of the original model) is
* unfrozen and updated through backpropagation.
*
* If specified, must be an integer > 0.
*/
fineTuningEpochs?: number;
/**
* The optimizer for fine-tuning after the initial transfer-learning
* training.
*
* This parameter is used only if `fineTuningEpochs` is specified
   * and is a positive integer.
*
* Default: 'sgd'.
*/
fineTuningOptimizer?: string|tf.Optimizer;
/**
* tf.Callback to be used during the initial training (i.e., not
* the fine-tuning phase).
*/
callback?: tfl.CustomCallbackArgs;
/**
   * tf.Callback to be used during the fine-tuning phase.
*
* This parameter is used only if `fineTuningEpochs` is specified
* and is a positive integer.
*/
fineTuningCallback?: tfl.CustomCallbackArgs;
/**
* Ratio between the window hop and the window width.
*
* Used during extraction of multiple spectrograms matching the underlying
   * model's input shape from a longer spectrogram.
*
* Defaults to 0.25.
*
* For example, if the spectrogram window accepted by the underlying model
* is 43 frames long, then the default windowHopRatio 0.25 will lead to
* a hop of Math.round(43 * 0.25) = 11 frames.
*/
windowHopRatio?: number;
/**
* The threshold for the total duration of the dataset above which
* `fitDataset()` will be used in lieu of `fit()`.
*
* Default: 60e3 (1 minute).
*/
fitDatasetDurationMillisThreshold?: number;
}
/**
* Type for a Receiver Operating Characteristics (ROC) curve.
*/
export type ROCCurve =
Array < {probThreshold?: number, /** Probability threshold */
fpr: number, /** False positive rate (FP / N) */
tpr: number /** True positive rate (TP / P) */
falsePositivesPerHour?: number /** FPR converted to per hour rate */
}>;
/**
* Model evaluation result.
*/
export interface EvaluateResult {
/**
* ROC curve.
*/
rocCurve?: ROCCurve;
/**
* Area under the (ROC) curve.
*/
auc?: number;
}
/**
* Model evaluation configuration.
*/
export interface EvaluateConfig {
/**
* Ratio between the window hop and the window width.
*
* Used during extraction of multiple spectrograms matching the underlying
   * model's input shape from a longer spectrogram.
*
* For example, if the spectrogram window accepted by the underlying model
* is 43 frames long, then the default windowHopRatio 0.25 will lead to
* a hop of Math.round(43 * 0.25) = 11 frames.
*/
windowHopRatio: number;
/**
* Word probability score thresholds, used to calculate the ROC.
*
* E.g., [0, 0.2, 0.4, 0.6, 0.8, 1.0].
*/
wordProbThresholds: number[];
}
/**
* Parameters for a speech-command recognizer.
*/
export interface RecognizerParams {
/**
   * Total duration per spectrogram, in milliseconds.
*/
spectrogramDurationMillis?: number;
/**
* FFT encoding size per spectrogram column.
*/
fftSize?: number;
/**
* Sampling rate, in Hz.
*/
sampleRateHz?: number;
}
/**
* Interface of an audio feature extractor.
*/
export interface FeatureExtractor {
/**
   * Configure the feature extractor.
*/
setConfig(params: RecognizerParams): void;
/**
* Start the feature extraction from the audio samples.
*/
start(audioTrackConstraints?: MediaTrackConstraints):
Promise<Float32Array[]|void>;
/**
* Stop the feature extraction.
*/
stop(): Promise<void>;
/**
* Get the extractor features collected since last call.
*/
getFeatures(): Float32Array[];
}
/** Snippet of pulse-code modulation (PCM) audio data. */
export interface RawAudioData {
/** Samples of the snippet. */
data: Float32Array;
/** Sampling rate, in Hz. */
sampleRateHz: number;
}
/**
* A short, labeled snippet of speech or audio.
*
* This can be used for training a transfer model based on the base
* speech-commands model, among other things.
*
* A set of `Example`s can make up a dataset.
*/
export interface Example {
/** A label for the example. */
label: string;
/** Spectrogram data. */
spectrogram: SpectrogramData;
/**
* Raw audio in PCM (pulse-code modulation) format.
*
* Optional.
*/
rawAudio?: RawAudioData;
}
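Taken together, the transfer-learning interfaces above (`createTransfer`, `collectExample`, `train`, `listen`, `save`, `serializeExamples`) are meant to be used roughly as in the sketch below; the word name, example counts, and training options are illustrative, not prescribed values.

```ts
import * as speechCommands from '@tensorflow-models/speech-commands';

async function transferLearn() {
  const base = speechCommands.create('BROWSER_FFT');
  await base.ensureModelLoaded();

  // A transfer recognizer with its own vocabulary.
  const transfer = base.createTransfer('my-sounds');

  // Collect a few microphone examples per word, plus background noise.
  for (let i = 0; i < 8; ++i) {
    await transfer.collectExample('clap');
    await transfer.collectExample('_background_noise_');
  }
  console.log(transfer.countExamples());  // e.g. {clap: 8, _background_noise_: 8}

  // Train a new softmax head; optionally mix noise into the word examples.
  await transfer.train({epochs: 25, augmentByMixingNoiseRatio: 0.5});

  // Use it like the base recognizer.
  await transfer.listen(async result => {
    console.log(transfer.wordLabels(), result.scores);
  }, {probabilityThreshold: 0.8});

  // Persist the model (IndexedDB/LocalStorage by default) and the examples.
  await transfer.save();
  const serialized = transfer.serializeExamples();  // ArrayBuffer for storage
  console.log(serialized.byteLength);
}

transferLearn();
```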

View File

@@ -0,0 +1,5 @@
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
const version = '0.5.4';
export {version};

View File

@@ -0,0 +1,26 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {version} from './index';
describe('version', () => {
it('version matches package.json', () => {
// tslint:disable-next-line:no-require-imports
const expected = require('../package.json').version;
expect(version).toBe(expected);
});
});

View File

@@ -0,0 +1,22 @@
# Training a TensorFlow.js model for Speech Commands Using Browser FFT
This directory contains two example notebooks. They demonstrate how to train
custom TensorFlow.js audio models and deploy them for inference. The models
trained this way expect inputs to be spectrograms in a format consistent with
[WebAudio's `getFloatFrequencyData`](https://developer.mozilla.org/en-US/docs/Web/API/AnalyserNode/getFloatFrequencyData).
Therefore they can be deployed to the browser using the speech-commands library
for inference.
Specifically,
- [training_custom_audio_model_in_python.ipynb](./training_custom_audio_model_in_python.ipynb)
  contains steps to preprocess a directory of audio examples stored as .wav
  files and to train a tf.keras model on the preprocessed data. It then
  demonstrates how the trained tf.keras model can be converted to a
  TensorFlow.js `LayersModel` that can be loaded with the speech-commands
  library's `create()` API (a loading sketch follows this list). In addition,
  the notebook also shows
the steps to convert the trained tf.keras model to a TFLite model for
inference on mobile devices.
- [tflite_conversion.ipynb](./tflite_conversion.ipynb) illustrates how
an audio model trained on [Teachable Machine](https://teachablemachine.withgoogle.com/train/audio)
can be converted to TFLite directly.
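Once the converted model.json and metadata.json are hosted, loading them in the browser is expected to look roughly like the sketch below; the URLs are placeholders, and the four-argument `create()` call follows the signature documented in the library source in this commit.

```ts
import * as speechCommands from '@tensorflow-models/speech-commands';

// Placeholder URLs for wherever the converted artifacts are hosted.
const MODEL_URL = 'https://example.com/my-audio-model/model.json';
const METADATA_URL = 'https://example.com/my-audio-model/metadata.json';

async function loadCustomModel() {
  const recognizer =
      speechCommands.create('BROWSER_FFT', null, MODEL_URL, METADATA_URL);
  await recognizer.ensureModelLoaded();
  console.log(recognizer.wordLabels());  // labels come from metadata.json

  await recognizer.listen(async result => {
    console.log(result.scores);
  }, {overlapFactor: 0.5, probabilityThreshold: 0.75});
}

loadCustomModel();
```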

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,125 @@
# Training a TensorFlow.js model for Speech Commands Using node.js
## Preparing data for training
Before you can train your model that uses spectrograms from the browser's
WebAudio as input features, you need to download the speech-commands [data set v0.01](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz) or [data set v0.02](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz).
## Training the TensorFlow.js Model
The node.js training package comes with a command line tool that assists with training. Here are the steps:
1. Install the Node.js module dependencies:
```bash
yarn
```
2. Start the CLI program:
```none
yarn start
```
The following commands are supported by the CLI:
```none
Commands:
help [command...] Provides help for a given command.
exit Exits application.
create_model [labels...] create the audio model
load_dataset all <dir> Load all the data from the root directory by the labels
load_dataset <dir> <label> Load the dataset from the directory with the label
dataset size Show the size of the dataset
train [epoch] train all audio dataset
save_model <filename> save the audio model
```
3. First create a model. For example, create a model with four labels (up, down, left, right):
```none
local@piyu~$ create_model up down left right
_________________________________________________________________
Layer (type) Output shape Param #
=================================================================
conv2d_Conv2D1 (Conv2D) [null,95,39,8] 72
_________________________________________________________________
max_pooling2d_MaxPooling2D1 [null,47,19,8] 0
_________________________________________________________________
conv2d_Conv2D2 (Conv2D) [null,44,18,32] 2080
_________________________________________________________________
max_pooling2d_MaxPooling2D2 [null,22,9,32] 0
_________________________________________________________________
conv2d_Conv2D3 (Conv2D) [null,19,8,32] 8224
_________________________________________________________________
max_pooling2d_MaxPooling2D3 [null,9,4,32] 0
_________________________________________________________________
conv2d_Conv2D4 (Conv2D) [null,6,3,32] 8224
_________________________________________________________________
max_pooling2d_MaxPooling2D4 [null,5,1,32] 0
_________________________________________________________________
flatten_Flatten1 (Flatten) [null,160] 0
_________________________________________________________________
dense_Dense1 (Dense) [null,2000] 322000
_________________________________________________________________
dropout_Dropout1 (Dropout) [null,2000] 0
_________________________________________________________________
dense_Dense2 (Dense) [null,4] 8004
=================================================================
Total params: 348604
Trainable params: 348604
Non-trainable params: 0
```
4. Load the dataset.
Use the 'load_dataset all' command to load data for all labels configured for the previously created model. The root directory is where you untarred the dataset file; each label should have a corresponding subdirectory under that root directory.
```none
local@piyu~$ load_dataset all /tmp/audio/data
✔ finished loading label: up (0)
✔ finished loading label: left (2)
✔ finished loading label: down (1)
✔ finished loading label: right (3)
```
You can also load data per label using the 'load_dataset' command. For example, to load data for the 'up' label:
```none
local@piyu~$ load_dataset /tmp/audio/data/up up
```
5. Show the dataset stats. You can review the dataset size and shape by running the 'dataset size' command.
```none
local@piyu~$ dataset size
dataset size = xs: 8534,98,40,1 ys: 8534,4
```
6. Train the model. You can optionally specify the number of epochs for the 'train' command.
```none
local@piyu~$ train 5
✔ epoch: 0, loss: 1.35054, accuracy: 0.34792, validation accuracy: 0.42740
✔ epoch: 1, loss: 1.23458, accuracy: 0.45339, validation accuracy: 0.50351
✔ epoch: 2, loss: 1.06478, accuracy: 0.55833, validation accuracy: 0.62529
✔ epoch: 3, loss: 0.88953, accuracy: 0.63073, validation accuracy: 0.68735
✔ epoch: 4, loss: 0.78241, accuracy: 0.67799, validation accuracy: 0.73770
```
7. Save the trained model.
```none
local@piyu~$ save_model /tmp/audio_model
✔ /tmp/audio_model saved.
```
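
Once saved, the directory contains a standard TensorFlow.js `LayersModel` (`model.json` plus weight files), so it can be loaded back for inference in Node.js. A minimal sketch, assuming the model was saved to `/tmp/audio_model` as above and using the `[98, 40, 1]` input shape from `MODEL_SHAPE`:

```ts
import * as tf from '@tensorflow/tfjs-node';

async function loadAndPredict() {
  // Load the LayersModel written by the `save_model` command.
  const model = await tf.loadLayersModel('file:///tmp/audio_model/model.json');
  // Dummy spectrogram input: [batch, time frames, mel bins, channels].
  const input = tf.zeros([1, 98, 40, 1]);
  const scores = model.predict(input) as tf.Tensor;
  scores.print();  // per-label probabilities for the dummy input
}

loadAndPredict();
```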
## Development

View File

@ -0,0 +1,229 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import * as tfl from '@tensorflow/tfjs-layers';
import * as fs from 'fs';
/// <reference path="./types/node-wav.d.ts" />
import * as wav from 'node-wav';
import * as path from 'path';
import {Dataset} from './dataset';
import {WavFileFeatureExtractor} from './wav_file_feature_extractor';
/**
* Audio model that creates a tf.Model for a fixed set of labels. It requires a
* feature extractor to convert the audio stream into input tensors for the
* internal tf.Model.
* It provides dataset loading, training, and model saving functions.
*/
export class AudioModel {
private model: tfl.LayersModel;
/**
*
* @param inputShape Input tensor shape.
* @param labels Audio command label list
* @param dataset Dataset class to store the loaded data.
* @param featureExtractor converter to extractor features from audio stream
* as input tensors
*/
constructor(
inputShape: number[], private labels: string[], private dataset: Dataset,
private featureExtractor: WavFileFeatureExtractor) {
this.featureExtractor.config({
melCount: 40,
bufferLength: 480,
hopLength: 160,
targetSr: 16000,
isMfccEnabled: true,
duration: 1.0
});
this.model = this.createModel(inputShape);
}
private createModel(inputShape: number[]): tfl.LayersModel {
const model = tfl.sequential();
model.add(tfl.layers.conv2d(
{filters: 8, kernelSize: [4, 2], activation: 'relu', inputShape}));
model.add(tfl.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tfl.layers.conv2d(
{filters: 32, kernelSize: [4, 2], activation: 'relu'}));
model.add(tfl.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tfl.layers.conv2d(
{filters: 32, kernelSize: [4, 2], activation: 'relu'}));
model.add(tfl.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tfl.layers.conv2d(
{filters: 32, kernelSize: [4, 2], activation: 'relu'}));
model.add(tfl.layers.maxPooling2d({poolSize: [2, 2], strides: [1, 2]}));
model.add(tfl.layers.flatten({}));
model.add(tfl.layers.dropout({rate: 0.25}));
model.add(tfl.layers.dense({units: 2000, activation: 'relu'}));
model.add(tfl.layers.dropout({rate: 0.5}));
model.add(
tfl.layers.dense({units: this.labels.length, activation: 'softmax'}));
model.compile({
loss: 'categoricalCrossentropy',
optimizer: tf.train.sgd(0.01),
metrics: ['accuracy']
});
model.summary();
return model;
}
/**
* Load the full dataset from the root directory. For each subdirectory whose
* name matches an entry in the model's label list, the contained audio files
* are converted to input tensors and stored in the dataset for training.
* @param dir The root directory of the audio dataset
* @param callback Callback function for displaying loading progress logs
*/
async loadAll(dir: string, callback: Function) {
const promises = [];
this.labels.forEach(async (label, index) => {
callback(`loading label: ${label} (${index})`);
promises.push(
this.loadDataArray(path.resolve(dir, label), callback).then(v => {
callback(`finished loading label: ${label} (${index})`, true);
return [v, index];
}));
});
let allSpecs = await Promise.all(promises);
allSpecs = allSpecs
.map((specs, i) => {
const index = specs[1];
return specs[0].map(spec => [spec, index]);
})
.reduce((acc, currentValue) => acc.concat(currentValue), []);
tf.util.shuffle(allSpecs);
const specs = allSpecs.map(spec => spec[0]);
const labels = allSpecs.map(spec => spec[1]);
this.dataset.addExamples(
this.melSpectrogramToInput(specs),
tf.oneHot(labels, this.labels.length));
}
/**
* Load one dataset from a directory; all contained audio files are converted
* to input tensors and stored in the dataset for training.
* @param dir The directory of the audio dataset
* @param label The label for the audio dataset
* @param callback Callback function for displaying loading progress logs
*/
async loadData(dir: string, label: string, callback: Function) {
const index = this.labels.indexOf(label);
const specs = await this.loadDataArray(dir, callback);
this.dataset.addExamples(
this.melSpectrogramToInput(specs),
tf.oneHot(tf.fill([specs.length], index, 'int32'), this.labels.length));
}
private loadDataArray(dir: string, callback: Function) {
return new Promise<Float32Array[][]>((resolve, reject) => {
fs.readdir(dir, (err, filenames) => {
if (err) {
reject(err);
}
let specs: Float32Array[][] = [];
filenames.forEach((filename) => {
callback('decoding ' + dir + '/' + filename + '...');
const spec = this.splitSpecs(this.decode(dir + '/' + filename));
if (!!spec) {
specs = specs.concat(spec);
}
callback('decoding ' + dir + '/' + filename + '...done');
});
resolve(specs);
});
});
}
private decode(filename: string) {
const result = wav.decode(fs.readFileSync(filename));
return this.featureExtractor.start(result.channelData[0]);
}
/**
* Train the model on the stored dataset. The method can be called multiple
* times.
* @param epochs number of training epochs
* @param trainCallback
*/
async train(epochs?: number, trainCallback?: tfl.CustomCallbackArgs) {
return this.model.fit(this.dataset.xs, this.dataset.ys, {
batchSize: 64,
epochs: epochs || 100,
shuffle: true,
validationSplit: 0.1,
callbacks: trainCallback
});
}
/**
* Save the model to the specified directory.
* @param dir Directory to store the model.
*/
save(dir: string): Promise<tf.io.SaveResult> {
return this.model.save('file://' + dir);
}
/**
* Return the size of the dataset in string.
*/
size(): string {
return this.dataset.xs ?
`xs: ${this.dataset.xs.shape} ys: ${this.dataset.ys.shape}` :
'0';
}
private splitSpecs(spec: Float32Array[]) {
if (spec.length >= 98) {
const output = [];
for (let i = 0; i <= (spec.length - 98); i += 32) {
output.push(spec.slice(i, i + 98));
}
return output;
}
return undefined;
}
private melSpectrogramToInput(specs: Float32Array[][]): tf.Tensor {
// Flatten the spectrograms into a single Float32Array (reshaped to 4D below).
const batch = specs.length;
const times = specs[0].length;
const freqs = specs[0][0].length;
const data = new Float32Array(batch * times * freqs);
console.log(data.length);
for (let j = 0; j < batch; j++) {
const spec = specs[j];
for (let i = 0; i < times; i++) {
const mel = spec[i];
const offset = j * freqs * times + i * freqs;
data.set(mel, offset);
}
}
// Normalization of the input to [0, 1] is currently disabled (see the commented-out call below).
const shape: [number, number, number, number] = [batch, times, freqs, 1];
// this.normalizeInPlace(data, 0, 1);
return tf.tensor4d(data, shape);
}
}

View File

@ -0,0 +1,139 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Load the binding
import '@tensorflow/tfjs-node';
import chalk from 'chalk';
import * as ora from 'ora';
import * as Vorpal from 'vorpal';
import {AudioModel} from './audio_model';
import {Dataset} from './dataset';
import {WavFileFeatureExtractor} from './wav_file_feature_extractor';
// tslint:disable-next-line:no-any
(global as any).AudioContext = class AudioContext {};
export const MODEL_SHAPE = [98, 40, 1];
export const labelsMsg = [
{type: 'input', name: 'labels', message: 'Enter labels (separate by comma)'}
];
export const trainingMsg = [
{type: 'input', name: 'dir', message: 'Enter file directory'},
{type: 'input', name: 'label', message: 'Enter label for the directory'}
];
export const filenameMsg = [{
type: 'input',
name: 'filename',
message: 'Enter target filename for the model'
}];
let model: AudioModel;
let labels: string[];
const vorpal = new Vorpal();
let spinner = ora();
vorpal.command('create_model [labels...]')
.alias('c')
.description('create the audio model')
.action((args, cb) => {
console.log(args.labels);
labels = args.labels as string[];
model = new AudioModel(
MODEL_SHAPE, labels, new Dataset(labels.length),
new WavFileFeatureExtractor());
cb();
});
vorpal
.command(
'load_dataset all <dir>',
'Load all the data from the root directory by the labels')
.alias('la')
.action((args) => {
spinner.start('load dataset ...');
return model
.loadAll(
args.dir as string,
(text: string, finished?: boolean) => {
if (finished) {
spinner.succeed(text);
} else {
spinner.start();
spinner.text = text;
spinner.render();
}
})
.then(() => spinner.stop());
});
vorpal
.command(
'load_dataset <dir> <label>',
'Load the dataset from the directory with the label')
.alias('l')
.action((args) => {
spinner = ora('creating tensors ...');
spinner.start();
return model
.loadData(
args.dir as string, args.label as string,
(text: string) => {
// console.log(text);
spinner.text = text;
spinner.render();
})
.then(() => spinner.stop(), (err) => {
spinner.fail(`failed to load: ${err}`);
});
});
vorpal.command('dataset size', 'Show the size of the dataset')
.alias('d')
.action((args, cb) => {
console.log(chalk.green(`dataset size = ${model.size()}`));
cb();
});
vorpal.command('train [epoch]')
.alias('t')
.description('train all audio dataset')
.action((args) => {
spinner = ora('training models ...').start();
return model
.train(parseInt(args.epoch as string, 10) || 20, {
onBatchEnd: async (batch, logs) => {
spinner.text = chalk.green(`loss: ${logs.loss.toFixed(5)}`);
spinner.render();
},
onEpochEnd: async (epoch, logs) => {
spinner.succeed(chalk.green(
`epoch: ${epoch}, loss: ${logs.loss.toFixed(5)}` +
`, accuracy: ${logs.acc.toFixed(5)}` +
`, validation accuracy: ${logs.val_acc.toFixed(5)}`));
spinner.start();
}
})
.then(() => spinner.stop());
});
vorpal.command('save_model <filename>')
.alias('s')
.description('save the audio model')
.action((args) => {
spinner.start(`saving to ${args.filename} ...`);
return model.save(args.filename as string).then(() => {
spinner.succeed(`${args.filename} saved.`);
}, () => spinner.fail(`failed to save ${args.filename}`));
});
vorpal.show();

View File

@ -0,0 +1,54 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
/**
* A dataset of audio examples which allows the user to add example Tensors
* for particular labels. This object will concat them into two large xs and ys.
*/
export class Dataset {
xs: tf.Tensor;
ys: tf.Tensor;
constructor(public numClasses: number) {}
/**
* Adds a batch of example/label pairs to the dataset; examples and labels must
* have matching batch sizes. For example, if the examples shape is [2, 20, 20]
* (batch size 2), the labels shape should be [2, 10] (for 10 classes).
*
* @param examples Batch of inputs
* @param labels Matching labels for inputs
*/
addExamples(examples: tf.Tensor, labels: tf.Tensor) {
if (this.xs == null) {
// For the first example that gets added, keep example and y so that the
// Dataset owns the memory of the inputs. This makes sure that
// if addExample() is called in a tf.tidy(), these Tensors will not get
// disposed.
this.xs = tf.keep(examples);
this.ys = tf.keep(labels);
} else {
const oldX = this.xs;
this.xs = tf.keep(this.xs.concat(examples, 0));
const oldY = this.ys;
this.ys = tf.keep(oldY.concat(labels, 0));
oldX.dispose();
oldY.dispose();
}
}
}
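// Example usage (illustrative only, not part of the training code):
//
//   const dataset = new Dataset(4);
//   dataset.addExamples(
//       tf.zeros([2, 98, 40, 1]),                     // two spectrogram inputs
//       tf.oneHot(tf.tensor1d([0, 3], 'int32'), 4));  // matching one-hot labels
//   console.log(dataset.xs.shape);  // [2, 98, 40, 1]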

View File

@ -0,0 +1,38 @@
{
"name": "audio-command-model-node",
"version": "0.0.1",
"description": "tfjs audio command model training in node.js",
"main": "./cli",
"license": "Apache-2.0",
"private": true,
"bin": "./cli",
"dependencies": {
"@tensorflow/tfjs": "^3.3.0",
"@tensorflow/tfjs-node": "^3.3.0",
"chalk": "^2.4.1",
"vorpal": "1.12.0",
"node-wav": "^0.0.2",
"ora": "^2.1.0",
"ts-node": "7.0.0",
"dct": "^0.0.3",
"kissfft-js": "0.1.8"
},
"scripts": {
"build": "tsc",
"lint": "tslint -p . -t verbose",
"start": "nodemon --exec ts-node -- cli.ts",
"ts-node": "ts-node cli.ts"
},
"devDependencies": {
"@types/chalk": "^2.2.0",
"@types/node": "^10.5.2",
"@types/ora": "^1.3.4",
"@types/inquirer": "^0.0.42",
"@types/minimist": "^1.2.0",
"clang-format": "~1.2.2",
"typescript": "3.5.3",
"nodemon": "1.18.2",
"tslint": "~6.1.3",
"tslint-no-circular-imports": "~0.7.0"
}
}

View File

@ -0,0 +1,26 @@
{
"compilerOptions": {
"module": "commonjs",
//"noImplicitAny": true,
"sourceMap": true,
"removeComments": true,
"preserveConstEnums": true,
"declaration": true,
"target": "es5",
"lib": ["es2015", "dom"],
"outDir": "./dist",
"noUnusedLocals": true,
"noImplicitReturns": true,
"noImplicitThis": true,
"alwaysStrict": true,
"noUnusedParameters": false,
"pretty": true,
"noFallthroughCasesInSwitch": true,
"allowUnreachableCode": false,
"downlevelIteration": true,
"moduleResolution": "node"
},
"include": [
"*.ts", "utils/**/*"
]
}

View File

@ -0,0 +1,17 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
declare module 'node-wav';

View File

@ -0,0 +1,18 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
declare module 'vorpal';

View File

@ -0,0 +1,259 @@
/**
* Copyright 2019 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
import * as DCT from 'dct';
import * as KissFFT from 'kissfft-js';
const SR = 16000;
const hannWindowMap: {[key: number]: number[]} = {};
let context: AudioContext;
export class AudioUtils {
startIndex = 0;
endIndex = 0;
bandMapper: number[] = [];
context: AudioContext;
constructor() {}
/**
* Gets a periodic Hann window of the given length.
* @param windowLength size of the Hann window
* @returns the periodic Hann window values
*/
GetPeriodicHann(windowLength: number): number[] {
if (!hannWindowMap[windowLength]) {
const window = [];
// Some platforms don't have M_PI, so define a local constant here.
for (let i = 0; i < windowLength; ++i) {
window[i] = 0.5 - 0.5 * Math.cos((2 * Math.PI * i) / windowLength);
}
hannWindowMap[windowLength] = window;
}
return hannWindowMap[windowLength];
}
/**
* Calculates the FFT for an array buffer. Output is an array.
*/
fft(y: Float32Array) {
const window = this.GetPeriodicHann(y.length);
y = y.map((v, index) => v * window[index]);
const fftSize = nextPowerOfTwo(y.length);
for (let i = y.length; i < fftSize; i++) {
y[i] = 0;
}
const fftr = new KissFFT.FFTR(fftSize);
const transform = fftr.forward(y);
fftr.dispose();
transform[fftSize] = transform[1];
transform[fftSize + 1] = 0;
transform[1] = 0;
return transform;
}
/**
* Calculates the DCT of a spectrogram frame.
* @param y spectrogram frame data
* @returns DCT-encoded coefficients
*/
dct(y: Float32Array): Float32Array {
const scale = Math.sqrt(2.0 / y.length);
return DCT(y, scale);
}
/**
* Given an interlaced complex array (y_i is real, y_(i+1) is imaginary),
* calculates the energies. Output is half the size.
*/
fftEnergies(y: Float32Array): Float32Array {
const out = new Float32Array(y.length / 2);
for (let i = 0; i < y.length / 2; i++) {
out[i] = y[i * 2] * y[i * 2] + y[i * 2 + 1] * y[i * 2 + 1];
}
return out;
}
/**
* Creates a mel filterbank map for the given melCount size.
* @param fftSize FFT frequency bin count
* @param [melCount] Mel filterbank count
* @param [lowHz] low filterbank cutoff frequency
* @param [highHz] high filterbank cutoff frequency
* @param [sr] sampling rate
* @returns mel filterbank weights
*/
createMelFilterbank(
fftSize: number, melCount = 40, lowHz = 20, highHz = 4000,
sr = SR): Float32Array {
const lowMel = this.hzToMel(lowHz);
const highMel = this.hzToMel(highHz);
// Construct linearly spaced array of melCount intervals, between lowMel and
// highMel.
const mels = [];
const melSpan = highMel - lowMel;
const melSpacing = melSpan / (melCount + 1);
for (let i = 0; i < melCount + 1; ++i) {
mels[i] = lowMel + (melSpacing * (i + 1));
}
// Always exclude DC; emulate HTK.
const hzPerSbin = 0.5 * sr / (fftSize - 1);
this.startIndex = Math.floor(1.5 + (lowHz / hzPerSbin));
this.endIndex = Math.ceil(highHz / hzPerSbin);
// Maps the input spectrum bin indices to filter bank channels/indices. For
// each FFT bin, band_mapper tells us which channel this bin contributes to
// on the right side of the triangle. Thus this bin also contributes to the
// left side of the next channel's triangle response.
this.bandMapper = [];
let channel = 0;
for (let i = 0; i < fftSize; ++i) {
const melf = this.hzToMel(i * hzPerSbin);
if ((i < this.startIndex) || (i > this.endIndex)) {
this.bandMapper[i] = -2; // Indicate an unused Fourier coefficient.
} else {
while ((mels[channel] < melf) && (channel < melCount)) {
++channel;
}
this.bandMapper[i] = channel - 1; // Can be == -1
}
}
// Create the weighting functions to taper the band edges. The contribution
// of any one FFT bin is based on its distance along the continuum between
// two mel-channel center frequencies. This bin contributes weights_[i] to
// the current channel and 1-weights_[i] to the next channel.
const weights = new Float32Array(fftSize);
for (let i = 0; i < fftSize; ++i) {
channel = this.bandMapper[i];
if ((i < this.startIndex) || (i > this.endIndex)) {
weights[i] = 0.0;
} else {
if (channel >= 0) {
weights[i] = (mels[channel + 1] - this.hzToMel(i * hzPerSbin)) /
(mels[channel + 1] - mels[channel]);
} else {
weights[i] =
(mels[0] - this.hzToMel(i * hzPerSbin)) / (mels[0] - lowMel);
}
}
}
return weights;
}
/**
* Given an array of FFT magnitudes, apply a filterbank. Output should be an
* array with size |filterbank|.
*/
applyFilterbank(
fftEnergies: Float32Array, filterbank: Float32Array,
melCount = 40): Float32Array {
const out = new Float32Array(melCount);
for (let i = this.startIndex; i <= this.endIndex;
i++) { // For each FFT bin
const specVal = Math.sqrt(fftEnergies[i]);
const weighted = specVal * filterbank[i];
let channel = this.bandMapper[i];
if (channel >= 0) {
out[channel] += weighted; // Right side of triangle, downward slope
}
channel++;
if (channel < melCount) {
out[channel] += (specVal - weighted); // Left side of triangle
}
}
for (let i = 0; i < out.length; ++i) {
let val = out[i];
if (val < 1e-12) {
val = 1e-12;
}
out[i] = Math.log(val);
}
return out;
}
private hzToMel(hz: number) {
return 1127.0 * Math.log(1.0 + hz / 700.0);
}
/**
* Computes cepstrum coefficients from the mel energy spectrogram.
* @param melEnergies array of mel filterbank energies
* @returns DCT of the mel energies (MFCC-style coefficients)
*/
cepstrumFromEnergySpectrum(melEnergies: Float32Array) {
return this.dct(melEnergies);
}
/**
* Plays back audio data from an array buffer using the given sample rate.
* @param buffer audio data
* @param [sampleRate] playback sample rate
*/
playbackArrayBuffer(buffer: Float32Array, sampleRate?: number) {
if (!context) {
context = new AudioContext();
}
if (!sampleRate) {
sampleRate = context.sampleRate;  // use the shared module-level AudioContext
}
const audioBuffer = context.createBuffer(1, buffer.length, sampleRate);
const audioBufferData = audioBuffer.getChannelData(0);
audioBufferData.set(buffer);
const source = context.createBufferSource();
source.buffer = audioBuffer;
source.connect(context.destination);
source.start();
}
/**
* Resamples web audio data to the target sample rate.
* @param audioBuffer Audio data
* @param targetSr Target sample rate
* @returns resampled web audio data
*/
resampleWebAudio(audioBuffer: AudioBuffer, targetSr: number):
Promise<AudioBuffer> {
const sourceSr = audioBuffer.sampleRate;
const lengthRes = audioBuffer.length * targetSr / sourceSr;
const offlineCtx = new OfflineAudioContext(1, lengthRes, targetSr);
return new Promise((resolve, reject) => {
const bufferSource = offlineCtx.createBufferSource();
bufferSource.buffer = audioBuffer;
offlineCtx.oncomplete = (event) => {
resolve(event.renderedBuffer);
};
bufferSource.connect(offlineCtx.destination);
bufferSource.start();
offlineCtx.startRendering();
});
}
}
/**
* Next power of two value for the given number.
* @param value
* @returns
*/
export function nextPowerOfTwo(value: number) {
const exponent = Math.ceil(Math.log2(value));
return 1 << exponent;
}

View File

@ -0,0 +1,108 @@
/**
* Copyright 2019 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/**
* A circular buffer of audio samples: write Float32Array chunks of arbitrary
* size, read them back in arbitrarily sized chunks, and check whether enough
* data is available to grab a given amount.
*/
export class CircularAudioBuffer {
buffer: Float32Array;
// The index that the buffer is filled up to (exclusive). New data is written
// starting at currentIndex; data can be read from [0, currentIndex).
currentIndex: number;
constructor(maxLength: number) {
this.buffer = new Float32Array(maxLength);
this.currentIndex = 0;
}
/**
* Add a new buffer of data. Called when we get new audio input samples.
*/
addBuffer(newBuffer: Float32Array) {
// Do we have enough space left in this buffer?
const remaining = this.buffer.length - this.currentIndex;
if (this.currentIndex + newBuffer.length > this.buffer.length) {
console.error(
`Not enough space to write ${newBuffer.length}` +
` to this circular buffer with ${remaining} left.`);
return;
}
this.buffer.set(newBuffer, this.currentIndex);
this.currentIndex += newBuffer.length;
}
/**
* How many samples are stored currently?
*/
getLength() {
return this.currentIndex;
}
/**
* How much space remains?
*/
getRemainingLength() {
return this.buffer.length - this.currentIndex;
}
/**
* Return the first N samples of the buffer, and remove them. Called when we
* want to get a buffer of audio data of a fixed size.
*/
popBuffer(length: number) {
// Do we have enough data to read back?
if (this.currentIndex < length) {
console.error(
`This circular buffer doesn't have ${length} entries in it.`);
return undefined;
}
if (length === 0) {
console.warn(`Calling popBuffer(0) does nothing.`);
return undefined;
}
const popped = this.buffer.slice(0, length);
const remaining = this.buffer.slice(length, this.buffer.length);
// Remove the popped entries from the buffer.
this.buffer.fill(0);
this.buffer.set(remaining, 0);
// Send the currentIndex back.
this.currentIndex -= length;
return popped;
}
/**
* Get the first part of the buffer without mutating it.
*/
getBuffer(length?: number) {
if (!length) {
length = this.getLength();
}
// Do we have enough data to read back?
if (this.currentIndex < length) {
console.error(
`This circular buffer doesn't have ${length} entries in it.`);
return undefined;
}
return this.buffer.slice(0, length);
}
clear() {
this.currentIndex = 0;
this.buffer.fill(0);
}
}
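// Example usage (illustrative only): accumulate incoming audio frames and pop
// fixed-size analysis windows once enough samples are available.
//
//   const circ = new CircularAudioBuffer(16000);
//   circ.addBuffer(new Float32Array(1024));  // e.g. one audio callback frame
//   if (circ.getLength() >= 480) {
//     const frame = circ.popBuffer(480);     // Float32Array of length 480
//   }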

View File

@ -0,0 +1,60 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
/**
* A dataset of audio examples which allows the user to add example Tensors
* for particular labels. This object will concat them into two large xs and ys.
*/
export class Dataset {
xs: tf.Tensor[];
ys: tf.Tensor;
constructor(public numClasses: number) {}
/**
* Adds an example to the dataset.
* @param {Tensor} example A tensor representing the example.
* It can be an image, an activation, or any other type of Tensor.
* @param {number} label The label of the example. Should be a number.
*/
addExample(example: tf.Tensor|tf.Tensor[], label: number) {
example = Array.isArray(example) ? example : [example];
// One-hot encode the label.
const y =
tf.tidy(() => tf.oneHot(tf.tensor1d([label]).toInt(), this.numClasses));
if (this.xs == null) {
// For the first example that gets added, keep example and y so that the
// Dataset owns the memory of the inputs. This makes sure that
// if addExample() is called in a tf.tidy(), these Tensors will not get
// disposed.
this.xs = example.map(tensor => tf.keep(tensor));
this.ys = tf.keep(y);
} else {
const oldX = this.xs;
this.xs = example.map(
(tensor, index) => tf.keep(this.xs[index].concat(tensor, 0)));
const oldY = this.ys;
this.ys = tf.keep(oldY.concat(y, 0));
oldX.forEach(tensor => tensor.dispose());
oldY.dispose();
y.dispose();
}
}
}

View File

@ -0,0 +1,53 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {InferenceModel} from '@tensorflow/tfjs';
export interface Params {
inputBufferLength?: number;
bufferLength?: number;
hopLength?: number;
duration?: number;
fftSize?: number;
melCount?: number;
targetSr?: number;
isMfccEnabled?: boolean;
}
export interface FeatureExtractor {
config(params: Params): void;
start(samples?: Float32Array): Promise<Float32Array[]>|void;
stop(): void;
getFeatures(): Float32Array[];
getImages(): Float32Array[];
}
export enum ModelType {
FROZEN_MODEL = 0,
FROZEN_MODEL_NATIVE,
TF_MODEL
}
export const BUFFER_LENGTH = 1024;
export const HOP_LENGTH = 444;
export const MEL_COUNT = 40;
export const EXAMPLE_SR = 44100;
export const DURATION = 1.0;
export const IS_MFCC_ENABLED = true;
export const MIN_SAMPLE = 3;
export const DETECTION_THRESHOLD = 0.5;
export const SUPPRESSION_TIME = 500;
export const MODELS: {[key: number]: InferenceModel} = {};

View File

@ -0,0 +1,17 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
declare module 'dct';

View File

@ -0,0 +1,17 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
declare module 'kissfft-js';

View File

@ -0,0 +1,95 @@
import {AudioUtils} from './utils/audio_utils';
import {Params} from './utils/types';
import {nextPowerOfTwo} from './utils/audio_utils';
export class WavFileFeatureExtractor {
private features: Float32Array[];
// Target sample rate.
targetSr = 16000;
// How long the buffer is.
bufferLength = 480;
// How many mel bins to use.
melCount = 40;
// Number of samples to hop over for every new column.
hopLength = 160;
// How long the total duration is.
duration = 1.0;
// Whether to use MFCC or Mel features.
isMfccEnabled = true;
fftSize = 512;
// How many buffers to keep in the spectrogram.
bufferCount: number;
// The mel filterbank (calculate it only once).
melFilterbank: Float32Array;
audioUtils = new AudioUtils();
config(params: Params) {
Object.assign(this, params);
this.bufferCount = Math.floor(
(this.duration * this.targetSr - this.bufferLength) /
this.hopLength) +
1;
if (this.hopLength > this.bufferLength) {
console.error('Hop length must be smaller than buffer length.');
}
// The mel filterbank is actually half of the size of the number of samples,
// since the FFT array is complex valued.
this.fftSize = nextPowerOfTwo(this.bufferLength);
this.melFilterbank = this.audioUtils.createMelFilterbank(
this.fftSize / 2 + 1, this.melCount);
}
start(samples: Float32Array): Float32Array[] {
this.features = [];
// Split the samples into as many full analysis frames as are available. Note
// that there may be multiple, and if there are, we should process them all.
const buffers = this.getFullBuffers(samples);
for (const buffer of buffers) {
// console.log(`Got buffer of length ${buffer.length}.`);
// Extract the mel values for this new frame of audio data.
const fft = this.audioUtils.fft(buffer);
const fftEnergies = this.audioUtils.fftEnergies(fft);
const melEnergies =
this.audioUtils.applyFilterbank(fftEnergies, this.melFilterbank);
const mfccs = this.audioUtils.cepstrumFromEnergySpectrum(melEnergies);
if (this.isMfccEnabled) {
this.features.push(mfccs);
} else {
this.features.push(melEnergies);
}
}
return this.features;
}
stop() {}
transform(data: Float32Array) {
return data;
}
getFeatures(): Float32Array[] {
return this.features;
}
getImages(): Float32Array[] {
throw new Error('Method not implemented.');
}
/**
* Get as many full frames as are available from the given sample array.
*/
private getFullBuffers(sample: Float32Array) {
const out = [];
let index = 0;
// While we have enough data in the buffer.
while (index <= sample.length - this.bufferLength) {
// Get a buffer of desired size.
const buffer = sample.slice(index, index + this.bufferLength);
index += this.hopLength;
out.push(buffer);
}
return out;
}
}
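// Example usage (illustrative only): extract MFCC features from one second of
// 16 kHz mono samples, e.g. decoded from a .wav file with node-wav.
//
//   const extractor = new WavFileFeatureExtractor();
//   extractor.config({melCount: 40, bufferLength: 480, hopLength: 160,
//                     targetSr: 16000, isMfccEnabled: true, duration: 1.0});
//   const features = extractor.start(samples);  // ~98 frames of 40 MFCCs each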

File diff suppressed because it is too large

View File

@ -0,0 +1,13 @@
{
"extends": "../tsconfig",
"include": [
"src/"
],
"exclude": [
"node_modules/"
],
"compilerOptions": {
"outDir": "./dist",
"downlevelIteration": true
}
}

View File

@ -0,0 +1,14 @@
{
"extends": "../tsconfig.test",
"include": [
"src/"
],
"exclude": [
"node_modules/"
],
"compilerOptions": {
"importHelpers": true,
"outDir": "./dist",
"downlevelIteration": true
}
}

View File

@ -0,0 +1,3 @@
{
"extends": "../tslint.json"
}

File diff suppressed because it is too large

View File

@ -168,10 +168,11 @@
等待模型训练完成并开始识别... 等待模型训练完成并开始识别...
</div> </div>
<!-- 引入 TensorFlow.js 库 -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
<!-- 引入 Speech Commands 模型库 --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands@latest/dist/speech-commands.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands@0.5.4/dist/speech-commands.min.js"></script>
<!-- 你的 JavaScript 代码 --> <!-- 你的 JavaScript 代码 -->
<script src="script.js"></script> <script src="script.js"></script>