/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';
import Plotly from 'plotly.js-dist';

import * as SpeechCommands from '@tensorflow-models/speech-commands';

import {DatasetViz, removeNonFixedChildrenFromWordDiv} from './dataset-vis';
import {hideCandidateWords, logToStatusDisplay, plotPredictions, plotSpectrogram, populateCandidateWords, showCandidateWords} from './ui';

const startButton = document.getElementById('start');
const stopButton = document.getElementById('stop');
const predictionCanvas = document.getElementById('prediction-canvas');

const probaThresholdInput = document.getElementById('proba-threshold');
const epochsInput = document.getElementById('epochs');
const fineTuningEpochsInput = document.getElementById('fine-tuning-epochs');

const datasetIOButton = document.getElementById('dataset-io');
const datasetIOInnerDiv = document.getElementById('dataset-io-inner');
const downloadAsFileButton = document.getElementById('download-dataset');
const datasetFileInput = document.getElementById('dataset-file-input');
const uploadFilesButton = document.getElementById('upload-dataset');

const evalModelOnDatasetButton =
    document.getElementById('eval-model-on-dataset');
const evalResultsSpan = document.getElementById('eval-results');

const modelIOButton = document.getElementById('model-io');
const transferModelSaveLoadInnerDiv =
    document.getElementById('transfer-model-save-load-inner');
const loadTransferModelButton = document.getElementById('load-transfer-model');
const saveTransferModelButton = document.getElementById('save-transfer-model');
const savedTransferModelsSelect =
    document.getElementById('saved-transfer-models');
const deleteTransferModelButton =
    document.getElementById('delete-transfer-model');

const BACKGROUND_NOISE_TAG = SpeechCommands.BACKGROUND_NOISE_TAG;

/**
 * Transfer learning-related UI components.
 */
const transferModelNameInput = document.getElementById('transfer-model-name');
const learnWordsInput = document.getElementById('learn-words');
const durationMultiplierSelect = document.getElementById('duration-multiplier');
const enterLearnWordsButton = document.getElementById('enter-learn-words');
const includeTimeDomainWaveformCheckbox =
    document.getElementById('include-audio-waveform');
const collectButtonsDiv = document.getElementById('collect-words');
const startTransferLearnButton =
    document.getElementById('start-transfer-learn');

const XFER_MODEL_NAME = 'xfer-model';

// Minimum required number of examples per class for transfer learning.
const MIN_EXAMPLES_PER_CLASS = 8;

let recognizer;
let transferWords;
let transferRecognizer;
let transferDurationMultiplier;

(async function() {
  logToStatusDisplay('Creating recognizer...');
  recognizer = SpeechCommands.create('BROWSER_FFT');

  await populateSavedTransferModelsSelect();

  // Make sure the tf.Model is loaded through HTTP. If this is not
  // called here, the tf.Model will be loaded the first time
  // `listen()` is called.
  recognizer.ensureModelLoaded()
      .then(() => {
        startButton.disabled = false;
        enterLearnWordsButton.disabled = false;
        loadTransferModelButton.disabled = false;
        deleteTransferModelButton.disabled = false;

        transferModelNameInput.value = `model-${getDateString()}`;

        logToStatusDisplay('Model loaded.');

        const params = recognizer.params();
        logToStatusDisplay(`sampleRateHz: ${params.sampleRateHz}`);
        logToStatusDisplay(`fftSize: ${params.fftSize}`);
        logToStatusDisplay(
            `spectrogramDurationMillis: ` +
            `${params.spectrogramDurationMillis.toFixed(2)}`);
        logToStatusDisplay(
            `tf.Model input shape: ` +
            `${JSON.stringify(recognizer.modelInputShape())}`);
      })
      .catch(err => {
        logToStatusDisplay(
            'Failed to load model for recognizer: ' + err.message);
      });
})();
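
// Summary of the initialization above: a BROWSER_FFT recognizer is created,
// the drop-down of previously saved transfer models is populated, and the
// base model is downloaded over HTTP. The start, "Enter transfer words",
// load, and delete buttons stay disabled until the base model has loaded,
// and the default transfer-model name is seeded with the current date.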

startButton.addEventListener('click', () => {
  const activeRecognizer =
      transferRecognizer == null ? recognizer : transferRecognizer;
  populateCandidateWords(activeRecognizer.wordLabels());

  const suppressionTimeMillis = 1000;
  activeRecognizer
      .listen(
          result => {
            plotPredictions(
                predictionCanvas, activeRecognizer.wordLabels(), result.scores,
                3, suppressionTimeMillis);
          },
          {
            includeSpectrogram: true,
            suppressionTimeMillis,
            probabilityThreshold: Number.parseFloat(probaThresholdInput.value)
          })
      .then(() => {
        startButton.disabled = true;
        stopButton.disabled = false;
        showCandidateWords();
        logToStatusDisplay('Streaming recognition started.');
      })
      .catch(err => {
        logToStatusDisplay(
            'ERROR: Failed to start streaming recognition: ' + err.message);
      });
});
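
// The options passed to listen() above come from the speech-commands API:
// `probabilityThreshold` (read from the threshold input) sets the minimum
// score required before the result callback fires, and `suppressionTimeMillis`
// suppresses further recognition callbacks for that long after each event,
// which is why the same value is forwarded to plotPredictions(). See the
// @tensorflow-models/speech-commands README for the authoritative semantics.
// For example, a standalone (non-UI) use of the same API might look like:
//   recognizer.listen(result => console.log(result.scores), {
//     probabilityThreshold: 0.75,
//     suppressionTimeMillis: 1000
//   });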

stopButton.addEventListener('click', () => {
  const activeRecognizer =
      transferRecognizer == null ? recognizer : transferRecognizer;
  activeRecognizer.stopListening()
      .then(() => {
        startButton.disabled = false;
        stopButton.disabled = true;
        hideCandidateWords();
        logToStatusDisplay('Streaming recognition stopped.');
      })
      .catch(err => {
        logToStatusDisplay(
            'ERROR: Failed to stop streaming recognition: ' + err.message);
      });
});

/**
 * Transfer learning logic.
 */

/** Scroll to the bottom of the page */
function scrollToPageBottom() {
  const scrollingElement = (document.scrollingElement || document.body);
  scrollingElement.scrollTop = scrollingElement.scrollHeight;
}

let collectWordButtons = {};
let datasetViz;

function createProgressBarAndIntervalJob(parentElement, durationSec) {
  const progressBar = document.createElement('progress');
  progressBar.value = 0;
  progressBar.style['width'] = `${Math.round(window.innerWidth * 0.25)}px`;
  // Update progress bar in increments.
  const intervalJob = setInterval(() => {
    progressBar.value += 0.05;
  }, durationSec * 1e3 / 20);
  parentElement.appendChild(progressBar);
  return {progressBar, intervalJob};
}
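
// A <progress> element's `max` defaults to 1, so adding 0.05 to `value`
// twenty times (once every durationSec / 20 seconds) fills the bar over
// roughly `durationSec` seconds. Callers are expected to clearInterval() the
// returned `intervalJob` and remove `progressBar` from the DOM when the
// recording ends, as the collect-example click handler below does.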

/**
 * Create div elements for transfer words.
 *
 * @param {string[]} transferWords The array of transfer words.
 * @returns {Object} An object mapping each word to the div element created
 *   for it.
 */
function createWordDivs(transferWords) {
  // Clear collectButtonsDiv first.
  while (collectButtonsDiv.firstChild) {
    collectButtonsDiv.removeChild(collectButtonsDiv.firstChild);
  }
  datasetViz = new DatasetViz(
      transferRecognizer, collectButtonsDiv, MIN_EXAMPLES_PER_CLASS,
      startTransferLearnButton, downloadAsFileButton,
      transferDurationMultiplier);

  const wordDivs = {};
  for (const word of transferWords) {
    const wordDiv = document.createElement('div');
    wordDiv.classList.add('word-div');
    wordDivs[word] = wordDiv;
    wordDiv.setAttribute('word', word);
    const button = document.createElement('button');
    button.setAttribute('isFixed', 'true');
    button.style['display'] = 'inline-block';
    button.style['vertical-align'] = 'middle';

    const displayWord = word === BACKGROUND_NOISE_TAG ? 'noise' : word;

    button.textContent = `${displayWord} (0)`;
    wordDiv.appendChild(button);
    wordDiv.className = 'transfer-word';
    collectButtonsDiv.appendChild(wordDiv);
    collectWordButtons[word] = button;

    let durationInput;
    if (word === BACKGROUND_NOISE_TAG) {
      // Create noise duration input.
      durationInput = document.createElement('input');
      durationInput.setAttribute('isFixed', 'true');
      durationInput.value = '10';
      durationInput.style['width'] = '100px';
      wordDiv.appendChild(durationInput);
      // Create time-unit span for noise duration.
      const timeUnitSpan = document.createElement('span');
      timeUnitSpan.setAttribute('isFixed', 'true');
      timeUnitSpan.classList.add('settings');
      timeUnitSpan.style['vertical-align'] = 'middle';
      timeUnitSpan.textContent = 'seconds';
      wordDiv.appendChild(timeUnitSpan);
    }

    button.addEventListener('click', async () => {
      disableAllCollectWordButtons();
      removeNonFixedChildrenFromWordDiv(wordDiv);

      const collectExampleOptions = {};
      let durationSec;
      let intervalJob;
      let progressBar;

      if (word === BACKGROUND_NOISE_TAG) {
        // If the word type is background noise, display a progress bar during
        // sound collection and do not show an incrementally updating
        // spectrogram.
        // _background_noise_ examples are special, in that the user can
        // specify the length of the recording (in seconds).
        collectExampleOptions.durationSec =
            Number.parseFloat(durationInput.value);
        durationSec = collectExampleOptions.durationSec;

        const barAndJob = createProgressBarAndIntervalJob(wordDiv, durationSec);
        progressBar = barAndJob.progressBar;
        intervalJob = barAndJob.intervalJob;
      } else {
        // If this is not a background-noise word type and if the duration
        // multiplier is >1 (> ~1 s recording), show an incrementally
        // updating spectrogram in real time.
        collectExampleOptions.durationMultiplier = transferDurationMultiplier;
        let tempSpectrogramData;
        const tempCanvas = document.createElement('canvas');
        tempCanvas.style['margin-left'] = '132px';
        tempCanvas.height = 50;
        wordDiv.appendChild(tempCanvas);

        collectExampleOptions.snippetDurationSec = 0.1;
        collectExampleOptions.onSnippet = async (spectrogram) => {
          if (tempSpectrogramData == null) {
            tempSpectrogramData = spectrogram.data;
          } else {
            tempSpectrogramData = SpeechCommands.utils.concatenateFloat32Arrays(
                [tempSpectrogramData, spectrogram.data]);
          }
          plotSpectrogram(
              tempCanvas, tempSpectrogramData, spectrogram.frameSize,
              spectrogram.frameSize, {pixelsPerFrame: 2});
        };
      }

      collectExampleOptions.includeRawAudio =
          includeTimeDomainWaveformCheckbox.checked;
      const spectrogram =
          await transferRecognizer.collectExample(word, collectExampleOptions);

      if (intervalJob != null) {
        clearInterval(intervalJob);
      }
      if (progressBar != null) {
        wordDiv.removeChild(progressBar);
      }
      const examples = transferRecognizer.getExamples(word);
      const example = examples[examples.length - 1];
      await datasetViz.drawExample(
          wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
      enableAllCollectWordButtons();
    });
  }
  return wordDivs;
}

enterLearnWordsButton.addEventListener('click', () => {
  const modelName = transferModelNameInput.value;
  if (modelName == null || modelName.length === 0) {
    enterLearnWordsButton.textContent = 'Need model name!';
    setTimeout(() => {
      enterLearnWordsButton.textContent = 'Enter transfer words';
    }, 2000);
    return;
  }

  // We disable the option to upload an existing dataset from files
  // once the "Enter transfer words" button has been clicked.
  // However, the user can still load an existing dataset from
  // files first and keep appending examples to it.
  disableFileUploadControls();
  enterLearnWordsButton.disabled = true;

  transferDurationMultiplier = durationMultiplierSelect.value;

  learnWordsInput.disabled = true;
  enterLearnWordsButton.disabled = true;
  transferWords = learnWordsInput.value.trim().split(',').map(w => w.trim());
  transferWords.sort();
  if (transferWords == null || transferWords.length <= 1) {
    logToStatusDisplay('ERROR: Invalid list of transfer words.');
    return;
  }

  transferRecognizer = recognizer.createTransfer(modelName);
  createWordDivs(transferWords);

  scrollToPageBottom();
});

function disableAllCollectWordButtons() {
  for (const word in collectWordButtons) {
    collectWordButtons[word].disabled = true;
  }
}

function enableAllCollectWordButtons() {
  for (const word in collectWordButtons) {
    collectWordButtons[word].disabled = false;
  }
}

function disableFileUploadControls() {
  datasetFileInput.disabled = true;
  uploadFilesButton.disabled = true;
}

startTransferLearnButton.addEventListener('click', async () => {
  startTransferLearnButton.disabled = true;
  startButton.disabled = true;
  startTransferLearnButton.textContent = 'Transfer learning starting...';
  await tf.nextFrame();

  const INITIAL_PHASE = 'initial';
  const FINE_TUNING_PHASE = 'fineTuningPhase';

  const epochs = parseInt(epochsInput.value);
  const fineTuningEpochs = parseInt(fineTuningEpochsInput.value);
  const trainLossValues = {};
  const valLossValues = {};
  const trainAccValues = {};
  const valAccValues = {};

  for (const phase of [INITIAL_PHASE, FINE_TUNING_PHASE]) {
    const phaseSuffix = phase === FINE_TUNING_PHASE ? ' (FT)' : '';
    const lineWidth = phase === FINE_TUNING_PHASE ? 2 : 1;
    trainLossValues[phase] = {
      x: [],
      y: [],
      name: 'train' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    valLossValues[phase] = {
      x: [],
      y: [],
      name: 'val' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    trainAccValues[phase] = {
      x: [],
      y: [],
      name: 'train' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    valAccValues[phase] = {
      x: [],
      y: [],
      name: 'val' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
  }

  function plotLossAndAccuracy(epoch, loss, acc, val_loss, val_acc, phase) {
    // Plot fine-tuning epochs after the initial-phase epochs on the x-axis.
    const displayEpoch = phase === FINE_TUNING_PHASE ? (epoch + epochs) : epoch;
    trainLossValues[phase].x.push(displayEpoch);
    trainLossValues[phase].y.push(loss);
    trainAccValues[phase].x.push(displayEpoch);
    trainAccValues[phase].y.push(acc);
    valLossValues[phase].x.push(displayEpoch);
    valLossValues[phase].y.push(val_loss);
    valAccValues[phase].x.push(displayEpoch);
    valAccValues[phase].y.push(val_acc);

    Plotly.newPlot(
        'loss-plot',
        [
          trainLossValues[INITIAL_PHASE], valLossValues[INITIAL_PHASE],
          trainLossValues[FINE_TUNING_PHASE], valLossValues[FINE_TUNING_PHASE]
        ],
        {
          width: 480,
          height: 360,
          xaxis: {title: 'Epoch #'},
          yaxis: {title: 'Loss'},
          font: {size: 18}
        });
    Plotly.newPlot(
        'accuracy-plot',
        [
          trainAccValues[INITIAL_PHASE], valAccValues[INITIAL_PHASE],
          trainAccValues[FINE_TUNING_PHASE], valAccValues[FINE_TUNING_PHASE]
        ],
        {
          width: 480,
          height: 360,
          xaxis: {title: 'Epoch #'},
          yaxis: {title: 'Accuracy'},
          font: {size: 18}
        });
    startTransferLearnButton.textContent = phase === INITIAL_PHASE ?
        `Transfer-learning... (${(epoch / epochs * 1e2).toFixed(0)}%)` :
        `Transfer-learning (fine-tuning)... (${
            (epoch / fineTuningEpochs * 1e2).toFixed(0)}%)`;

    scrollToPageBottom();
  }

  disableAllCollectWordButtons();
  const augmentByMixingNoiseRatio =
      document.getElementById('augment-by-mixing-noise').checked ? 0.5 : null;
  console.log(`augmentByMixingNoiseRatio = ${augmentByMixingNoiseRatio}`);
  await transferRecognizer.train({
    epochs,
    validationSplit: 0.25,
    augmentByMixingNoiseRatio,
    callback: {
      onEpochEnd: async (epoch, logs) => {
        plotLossAndAccuracy(
            epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
            INITIAL_PHASE);
      }
    },
    fineTuningEpochs,
    fineTuningCallback: {
      onEpochEnd: async (epoch, logs) => {
        plotLossAndAccuracy(
            epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
            FINE_TUNING_PHASE);
      }
    }
  });
  saveTransferModelButton.disabled = false;
  transferModelNameInput.value = transferRecognizer.name;
  transferModelNameInput.disabled = true;
  startTransferLearnButton.textContent = 'Transfer learning complete.';
  transferModelNameInput.disabled = false;
  startButton.disabled = false;
  evalModelOnDatasetButton.disabled = false;
});
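
// The train() call above runs in two phases. The first `epochs` epochs train
// the transfer head on the collected examples, using a 25% validation split
// and optional noise-mixing augmentation; the following `fineTuningEpochs`
// continue training with part of the base model unfrozen (see the
// speech-commands documentation for exactly which layers). Both onEpochEnd
// callbacks feed the same Plotly curves, with fine-tuning epochs offset on
// the x-axis by `epochs` inside plotLossAndAccuracy().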

downloadAsFileButton.addEventListener('click', () => {
  const basename = getDateString();
  const artifacts = transferRecognizer.serializeExamples();

  // Trigger downloading of the data .bin file.
  const anchor = document.createElement('a');
  anchor.download = `${basename}.bin`;
  anchor.href = window.URL.createObjectURL(
      new Blob([artifacts], {type: 'application/octet-stream'}));
  anchor.click();
});

/**
 * Get a date-and-time string (e.g., "2019-06-25T14.30.05") used as the base
 * name for downloaded dataset files and as the default transfer-model name.
 */
function getDateString() {
  const d = new Date();
  const year = `${d.getFullYear()}`;
  let month = `${d.getMonth() + 1}`;
  let day = `${d.getDate()}`;
  if (month.length < 2) {
    month = `0${month}`;
  }
  if (day.length < 2) {
    day = `0${day}`;
  }
  let hour = `${d.getHours()}`;
  if (hour.length < 2) {
    hour = `0${hour}`;
  }
  let minute = `${d.getMinutes()}`;
  if (minute.length < 2) {
    minute = `0${minute}`;
  }
  let second = `${d.getSeconds()}`;
  if (second.length < 2) {
    second = `0${second}`;
  }
  return `${year}-${month}-${day}T${hour}.${minute}.${second}`;
}

uploadFilesButton.addEventListener('click', async () => {
  const files = datasetFileInput.files;
  if (files == null || files.length !== 1) {
    throw new Error('Must select exactly one file.');
  }
  const datasetFileReader = new FileReader();
  datasetFileReader.onload = async event => {
    try {
      await loadDatasetInTransferRecognizer(event.target.result);
    } catch (err) {
      const originalTextContent = uploadFilesButton.textContent;
      uploadFilesButton.textContent = err.message;
      setTimeout(() => {
        uploadFilesButton.textContent = originalTextContent;
      }, 2000);
    }
    durationMultiplierSelect.value = `${transferDurationMultiplier}`;
    durationMultiplierSelect.disabled = true;
    enterLearnWordsButton.disabled = true;
  };
  datasetFileReader.onerror = () =>
      console.error(`Failed to read binary data from file '${files[0].name}'.`);
  datasetFileReader.readAsArrayBuffer(files[0]);
});

async function loadDatasetInTransferRecognizer(serialized) {
  const modelName = transferModelNameInput.value;
  if (modelName == null || modelName.length === 0) {
    throw new Error('Need model name!');
  }

  if (transferRecognizer == null) {
    transferRecognizer = recognizer.createTransfer(modelName);
  }
  transferRecognizer.loadExamples(serialized);
  const exampleCounts = transferRecognizer.countExamples();
  transferWords = [];
  const modelNumFrames = transferRecognizer.modelInputShape()[1];
  const durationMultipliers = [];
  for (const word in exampleCounts) {
    transferWords.push(word);
    const examples = transferRecognizer.getExamples(word);
    for (const example of examples) {
      const spectrogram = example.example.spectrogram;
      // Ignore _background_noise_ examples when determining the duration
      // multiplier of the dataset.
      if (word !== BACKGROUND_NOISE_TAG) {
        durationMultipliers.push(Math.round(
            spectrogram.data.length / spectrogram.frameSize / modelNumFrames));
      }
    }
  }
  transferWords.sort();
  learnWordsInput.value = transferWords.join(',');

  // Determine the transferDurationMultiplier value from the dataset.
  transferDurationMultiplier =
      durationMultipliers.length > 0 ? Math.max(...durationMultipliers) : 1;
  console.log(
      `Determined transferDurationMultiplier from uploaded ` +
      `dataset: ${transferDurationMultiplier}`);

  createWordDivs(transferWords);
  datasetViz.redrawAll();
}
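
// Note on the duration-multiplier inference above: for each non-noise
// example, spectrogram.data.length / spectrogram.frameSize is the number of
// spectrogram frames in that example, and dividing by the model's own frame
// count (modelInputShape()[1]) recovers the multiplier the example was
// recorded with. Taking the maximum across the dataset presumably ensures
// newly collected examples match the longest existing ones.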

evalModelOnDatasetButton.addEventListener('click', async () => {
  const files = datasetFileInput.files;
  if (files == null || files.length !== 1) {
    throw new Error('Must select exactly one file.');
  }
  evalModelOnDatasetButton.disabled = true;
  const datasetFileReader = new FileReader();
  datasetFileReader.onload = async event => {
    try {
      if (transferRecognizer == null) {
        throw new Error('There is no model!');
      }

      // Load the dataset and perform evaluation of the transfer
      // model using the dataset.
      transferRecognizer.loadExamples(event.target.result);
      const evalResult = await transferRecognizer.evaluate({
        windowHopRatio: 0.25,
        wordProbThresholds: [
          0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.5,
          0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0
        ]
      });
      // Plot the ROC curve. (`mode` and `marker` belong on the trace, not
      // the layout, so they are set on the data object here.)
      const rocDataForPlot =
          {x: [], y: [], mode: 'markers', marker: {size: 7}};
      evalResult.rocCurve.forEach(item => {
        rocDataForPlot.x.push(item.fpr);
        rocDataForPlot.y.push(item.tpr);
      });

      Plotly.newPlot('roc-plot', [rocDataForPlot], {
        width: 360,
        height: 360,
        xaxis: {title: 'False positive rate (FPR)', range: [0, 1]},
        yaxis: {title: 'True positive rate (TPR)', range: [0, 1]},
        font: {size: 18}
      });
      evalResultsSpan.textContent = `AUC = ${evalResult.auc}`;
    } catch (err) {
      const originalTextContent = evalModelOnDatasetButton.textContent;
      evalModelOnDatasetButton.textContent = err.message;
      setTimeout(() => {
        evalModelOnDatasetButton.textContent = originalTextContent;
      }, 2000);
    }
    evalModelOnDatasetButton.disabled = false;
  };
  datasetFileReader.onerror = () =>
      console.error(`Failed to read binary data from file '${files[0].name}'.`);
  datasetFileReader.readAsArrayBuffer(files[0]);
});
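
// evaluate() scores the loaded examples at each of the listed probability
// thresholds and returns one {fpr, tpr} point per threshold plus an overall
// `auc` value, which the handler above plots as an ROC curve and prints.
// `windowHopRatio: 0.25` controls how far successive scoring windows advance
// relative to the window length; see the speech-commands API docs for the
// exact definition.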

async function populateSavedTransferModelsSelect() {
  const savedModelKeys = await SpeechCommands.listSavedTransferModels();
  while (savedTransferModelsSelect.firstChild) {
    savedTransferModelsSelect.removeChild(savedTransferModelsSelect.firstChild);
  }
  if (savedModelKeys.length > 0) {
    for (const key of savedModelKeys) {
      const option = document.createElement('option');
      option.textContent = key;
      option.id = key;
      savedTransferModelsSelect.appendChild(option);
    }
    loadTransferModelButton.disabled = false;
  }
}

saveTransferModelButton.addEventListener('click', async () => {
  await transferRecognizer.save();
  await populateSavedTransferModelsSelect();
  saveTransferModelButton.textContent = 'Model saved!';
  saveTransferModelButton.disabled = true;
});

loadTransferModelButton.addEventListener('click', async () => {
  const transferModelName = savedTransferModelsSelect.value;
  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer(transferModelName);
  await transferRecognizer.load();
  transferModelNameInput.value = transferModelName;
  transferModelNameInput.disabled = true;
  learnWordsInput.value = transferRecognizer.wordLabels().join(',');
  learnWordsInput.disabled = true;
  durationMultiplierSelect.disabled = true;
  enterLearnWordsButton.disabled = true;
  saveTransferModelButton.disabled = true;
  loadTransferModelButton.disabled = true;
  loadTransferModelButton.textContent = 'Model loaded!';
});

modelIOButton.addEventListener('click', () => {
  if (modelIOButton.textContent.endsWith(' >>')) {
    transferModelSaveLoadInnerDiv.style.display = 'inline-block';
    modelIOButton.textContent = modelIOButton.textContent.replace(' >>', ' <<');
  } else {
    transferModelSaveLoadInnerDiv.style.display = 'none';
    modelIOButton.textContent = modelIOButton.textContent.replace(' <<', ' >>');
  }
});

deleteTransferModelButton.addEventListener('click', async () => {
  const transferModelName = savedTransferModelsSelect.value;
  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer(transferModelName);
  await SpeechCommands.deleteSavedTransferModel(transferModelName);
  deleteTransferModelButton.disabled = true;
  deleteTransferModelButton.textContent = `Deleted "${transferModelName}"`;
  await populateSavedTransferModelsSelect();
});

datasetIOButton.addEventListener('click', () => {
  if (datasetIOButton.textContent.endsWith(' >>')) {
    datasetIOInnerDiv.style.display = 'inline-block';
    datasetIOButton.textContent =
        datasetIOButton.textContent.replace(' >>', ' <<');
  } else {
    datasetIOInnerDiv.style.display = 'none';
    datasetIOButton.textContent =
        datasetIOButton.textContent.replace(' <<', ' >>');
  }
});