/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';
import Plotly from 'plotly.js-dist';

import * as SpeechCommands from '@tensorflow-models/speech-commands';

import {DatasetViz, removeNonFixedChildrenFromWordDiv} from './dataset-vis';
import {hideCandidateWords, logToStatusDisplay, plotPredictions, plotSpectrogram, populateCandidateWords, showCandidateWords} from './ui';

const startButton = document.getElementById('start');
const stopButton = document.getElementById('stop');
const predictionCanvas = document.getElementById('prediction-canvas');

const probaThresholdInput = document.getElementById('proba-threshold');
const epochsInput = document.getElementById('epochs');
const fineTuningEpochsInput = document.getElementById('fine-tuning-epochs');

const datasetIOButton = document.getElementById('dataset-io');
const datasetIOInnerDiv = document.getElementById('dataset-io-inner');
const downloadAsFileButton = document.getElementById('download-dataset');
const datasetFileInput = document.getElementById('dataset-file-input');
const uploadFilesButton = document.getElementById('upload-dataset');

const evalModelOnDatasetButton =
    document.getElementById('eval-model-on-dataset');
const evalResultsSpan = document.getElementById('eval-results');

const modelIOButton = document.getElementById('model-io');
const transferModelSaveLoadInnerDiv =
    document.getElementById('transfer-model-save-load-inner');
const loadTransferModelButton = document.getElementById('load-transfer-model');
const saveTransferModelButton = document.getElementById('save-transfer-model');
const savedTransferModelsSelect =
    document.getElementById('saved-transfer-models');
const deleteTransferModelButton =
    document.getElementById('delete-transfer-model');

const BACKGROUND_NOISE_TAG = SpeechCommands.BACKGROUND_NOISE_TAG;

/**
 * Transfer learning-related UI components.
 */
const transferModelNameInput = document.getElementById('transfer-model-name');
const learnWordsInput = document.getElementById('learn-words');
const durationMultiplierSelect = document.getElementById('duration-multiplier');
const enterLearnWordsButton = document.getElementById('enter-learn-words');
const includeTimeDomainWaveformCheckbox =
    document.getElementById('include-audio-waveform');
const collectButtonsDiv = document.getElementById('collect-words');
const startTransferLearnButton =
    document.getElementById('start-transfer-learn');

const XFER_MODEL_NAME = 'xfer-model';

// Minimum required number of examples per class for transfer learning.
const MIN_EXAMPLES_PER_CLASS = 8;

let recognizer;
let transferWords;
let transferRecognizer;
let transferDurationMultiplier;

(async function() {
  logToStatusDisplay('Creating recognizer...');
  recognizer = SpeechCommands.create('BROWSER_FFT');

  await populateSavedTransferModelsSelect();

  // Make sure the tf.Model is loaded through HTTP. If this is not
  // called here, the tf.Model will be loaded the first time
  // `listen()` is called.
  recognizer.ensureModelLoaded()
      .then(() => {
        startButton.disabled = false;
        enterLearnWordsButton.disabled = false;
        loadTransferModelButton.disabled = false;
        deleteTransferModelButton.disabled = false;

        transferModelNameInput.value = `model-${getDateString()}`;

        logToStatusDisplay('Model loaded.');

        const params = recognizer.params();
        logToStatusDisplay(`sampleRateHz: ${params.sampleRateHz}`);
        logToStatusDisplay(`fftSize: ${params.fftSize}`);
        logToStatusDisplay(
            `spectrogramDurationMillis: ` +
            `${params.spectrogramDurationMillis.toFixed(2)}`);
        logToStatusDisplay(
            `tf.Model input shape: ` +
            `${JSON.stringify(recognizer.modelInputShape())}`);
      })
      .catch(err => {
        logToStatusDisplay(
            'Failed to load model for recognizer: ' + err.message);
      });
})();

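// Start streaming recognition, using the transfer recognizer if one has been
// created, otherwise the base recognizer.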
startButton.addEventListener('click', () => {
  const activeRecognizer =
      transferRecognizer == null ? recognizer : transferRecognizer;
  populateCandidateWords(activeRecognizer.wordLabels());

  const suppressionTimeMillis = 1000;
  activeRecognizer
      .listen(
          result => {
            plotPredictions(
                predictionCanvas, activeRecognizer.wordLabels(), result.scores,
                3, suppressionTimeMillis);
          },
          {
            includeSpectrogram: true,
            suppressionTimeMillis,
            probabilityThreshold: Number.parseFloat(probaThresholdInput.value)
          })
      .then(() => {
        startButton.disabled = true;
        stopButton.disabled = false;
        showCandidateWords();
        logToStatusDisplay('Streaming recognition started.');
      })
      .catch(err => {
        logToStatusDisplay(
            'ERROR: Failed to start streaming display: ' + err.message);
      });
});

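// Stop the active streaming recognition session and hide the candidate words.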
stopButton.addEventListener('click', () => {
  const activeRecognizer =
      transferRecognizer == null ? recognizer : transferRecognizer;
  activeRecognizer.stopListening()
      .then(() => {
        startButton.disabled = false;
        stopButton.disabled = true;
        hideCandidateWords();
        logToStatusDisplay('Streaming recognition stopped.');
      })
      .catch(err => {
        logToStatusDisplay(
            'ERROR: Failed to stop streaming display: ' + err.message);
      });
});

/**
 * Transfer learning logic.
 */

/** Scroll to the bottom of the page */
function scrollToPageBottom() {
  const scrollingElement = (document.scrollingElement || document.body);
  scrollingElement.scrollTop = scrollingElement.scrollHeight;
}

let collectWordButtons = {};
let datasetViz;

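/**
 * Append a progress bar to `parentElement` and start a timer that fills it
 * over `durationSec` seconds (used while recording background noise).
 *
 * @param {HTMLElement} parentElement Element the progress bar is added to.
 * @param {number} durationSec Duration the progress bar should span, in seconds.
 * @returns {{progressBar: HTMLProgressElement, intervalJob: number}} The
 *     progress bar and the interval handle (pass to `clearInterval()` to stop).
 */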
function createProgressBarAndIntervalJob(parentElement, durationSec) {
  const progressBar = document.createElement('progress');
  progressBar.value = 0;
  progressBar.style['width'] = `${Math.round(window.innerWidth * 0.25)}px`;
  // Update progress bar in increments.
  const intervalJob = setInterval(() => {
    progressBar.value += 0.05;
  }, durationSec * 1e3 / 20);
  parentElement.appendChild(progressBar);
  return {progressBar, intervalJob};
}

/**
 * Create div elements for transfer words.
 *
 * @param {string[]} transferWords The array of transfer words.
 * @returns {Object} An object mapping each word to the div element created for it.
 */
function createWordDivs(transferWords) {
  // Clear collectButtonsDiv first.
  while (collectButtonsDiv.firstChild) {
    collectButtonsDiv.removeChild(collectButtonsDiv.firstChild);
  }
  datasetViz = new DatasetViz(
      transferRecognizer, collectButtonsDiv, MIN_EXAMPLES_PER_CLASS,
      startTransferLearnButton, downloadAsFileButton,
      transferDurationMultiplier);

  const wordDivs = {};
  for (const word of transferWords) {
    const wordDiv = document.createElement('div');
    wordDiv.classList.add('word-div');
    wordDivs[word] = wordDiv;
    wordDiv.setAttribute('word', word);
    const button = document.createElement('button');
    button.setAttribute('isFixed', 'true');
    button.style['display'] = 'inline-block';
    button.style['vertical-align'] = 'middle';

    const displayWord = word === BACKGROUND_NOISE_TAG ? 'noise' : word;

    button.textContent = `${displayWord} (0)`;
    wordDiv.appendChild(button);
    wordDiv.className = 'transfer-word';
    collectButtonsDiv.appendChild(wordDiv);
    collectWordButtons[word] = button;

    let durationInput;
    if (word === BACKGROUND_NOISE_TAG) {
      // Create noise duration input.
      durationInput = document.createElement('input');
      durationInput.setAttribute('isFixed', 'true');
      durationInput.value = '10';
      durationInput.style['width'] = '100px';
      wordDiv.appendChild(durationInput);
      // Create time-unit span for noise duration.
      const timeUnitSpan = document.createElement('span');
      timeUnitSpan.setAttribute('isFixed', 'true');
      timeUnitSpan.classList.add('settings');
      timeUnitSpan.style['vertical-align'] = 'middle';
      timeUnitSpan.textContent = 'seconds';
      wordDiv.appendChild(timeUnitSpan);
    }

    button.addEventListener('click', async () => {
      disableAllCollectWordButtons();
      removeNonFixedChildrenFromWordDiv(wordDiv);

      const collectExampleOptions = {};
      let durationSec;
      let intervalJob;
      let progressBar;

      if (word === BACKGROUND_NOISE_TAG) {
        // If the word type is background noise, display a progress bar during
        // sound collection and do not show an incrementally updating
        // spectrogram.
        // _background_noise_ examples are special, in that the user can
        // specify the length of the recording (in seconds).
        collectExampleOptions.durationSec =
            Number.parseFloat(durationInput.value);
        durationSec = collectExampleOptions.durationSec;

        const barAndJob = createProgressBarAndIntervalJob(wordDiv, durationSec);
        progressBar = barAndJob.progressBar;
        intervalJob = barAndJob.intervalJob;
      } else {
        // If this is not a background-noise word type and if the duration
        // multiplier is >1 (> ~1 s recording), show an incrementally
        // updating spectrogram in real time.
        collectExampleOptions.durationMultiplier = transferDurationMultiplier;
        let tempSpectrogramData;
        const tempCanvas = document.createElement('canvas');
        tempCanvas.style['margin-left'] = '132px';
        tempCanvas.height = 50;
        wordDiv.appendChild(tempCanvas);

        collectExampleOptions.snippetDurationSec = 0.1;
        collectExampleOptions.onSnippet = async (spectrogram) => {
          if (tempSpectrogramData == null) {
            tempSpectrogramData = spectrogram.data;
          } else {
            tempSpectrogramData = SpeechCommands.utils.concatenateFloat32Arrays(
                [tempSpectrogramData, spectrogram.data]);
          }
          plotSpectrogram(
              tempCanvas, tempSpectrogramData, spectrogram.frameSize,
              spectrogram.frameSize, {pixelsPerFrame: 2});
        };
      }

      collectExampleOptions.includeRawAudio =
          includeTimeDomainWaveformCheckbox.checked;
      const spectrogram =
          await transferRecognizer.collectExample(word, collectExampleOptions);

      if (intervalJob != null) {
        clearInterval(intervalJob);
      }
      if (progressBar != null) {
        wordDiv.removeChild(progressBar);
      }
      const examples = transferRecognizer.getExamples(word);
      const example = examples[examples.length - 1];
      await datasetViz.drawExample(
          wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
      enableAllCollectWordButtons();
    });
  }
  return wordDivs;
}

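// Parse the comma-separated list of transfer words, create a transfer
// recognizer under the given model name, and build the per-word
// example-collection UI.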
enterLearnWordsButton.addEventListener('click', () => {
  const modelName = transferModelNameInput.value;
  if (modelName == null || modelName.length === 0) {
    enterLearnWordsButton.textContent = 'Need model name!';
    setTimeout(() => {
      enterLearnWordsButton.textContent = 'Enter transfer words';
    }, 2000);
    return;
  }

  // We disable the option to upload an existing dataset from files
  // once the "Enter transfer words" button has been clicked.
  // However, the user can still load an existing dataset from
  // files first and keep appending examples to it.
  disableFileUploadControls();
  enterLearnWordsButton.disabled = true;

  transferDurationMultiplier = durationMultiplierSelect.value;

  learnWordsInput.disabled = true;
  enterLearnWordsButton.disabled = true;
  transferWords = learnWordsInput.value.trim().split(',').map(w => w.trim());
  transferWords.sort();
  if (transferWords == null || transferWords.length <= 1) {
    logToStatusDisplay('ERROR: Invalid list of transfer words.');
    return;
  }

  transferRecognizer = recognizer.createTransfer(modelName);
  createWordDivs(transferWords);

  scrollToPageBottom();
});

function disableAllCollectWordButtons() {
  for (const word in collectWordButtons) {
    collectWordButtons[word].disabled = true;
  }
}

function enableAllCollectWordButtons() {
  for (const word in collectWordButtons) {
    collectWordButtons[word].disabled = false;
  }
}

function disableFileUploadControls() {
  datasetFileInput.disabled = true;
  uploadFilesButton.disabled = true;
}

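// Run transfer learning in two phases (initial training, then fine-tuning)
// and plot the loss and accuracy curves for both phases with Plotly.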
startTransferLearnButton.addEventListener('click', async () => {
  startTransferLearnButton.disabled = true;
  startButton.disabled = true;
  startTransferLearnButton.textContent = 'Transfer learning starting...';
  await tf.nextFrame();

  const INITIAL_PHASE = 'initial';
  const FINE_TUNING_PHASE = 'fineTuningPhase';

  const epochs = parseInt(epochsInput.value);
  const fineTuningEpochs = parseInt(fineTuningEpochsInput.value);
  const trainLossValues = {};
  const valLossValues = {};
  const trainAccValues = {};
  const valAccValues = {};

  for (const phase of [INITIAL_PHASE, FINE_TUNING_PHASE]) {
    const phaseSuffix = phase === FINE_TUNING_PHASE ? ' (FT)' : '';
    const lineWidth = phase === FINE_TUNING_PHASE ? 2 : 1;
    trainLossValues[phase] = {
      x: [],
      y: [],
      name: 'train' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    valLossValues[phase] = {
      x: [],
      y: [],
      name: 'val' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    trainAccValues[phase] = {
      x: [],
      y: [],
      name: 'train' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
    valAccValues[phase] = {
      x: [],
      y: [],
      name: 'val' + phaseSuffix,
      mode: 'lines',
      line: {width: lineWidth}
    };
  }

  function plotLossAndAccuracy(epoch, loss, acc, val_loss, val_acc, phase) {
    const displayEpoch = phase === FINE_TUNING_PHASE ? (epoch + epochs) : epoch;
    trainLossValues[phase].x.push(displayEpoch);
    trainLossValues[phase].y.push(loss);
    trainAccValues[phase].x.push(displayEpoch);
    trainAccValues[phase].y.push(acc);
    valLossValues[phase].x.push(displayEpoch);
    valLossValues[phase].y.push(val_loss);
    valAccValues[phase].x.push(displayEpoch);
    valAccValues[phase].y.push(val_acc);

    Plotly.newPlot(
        'loss-plot',
        [
          trainLossValues[INITIAL_PHASE], valLossValues[INITIAL_PHASE],
          trainLossValues[FINE_TUNING_PHASE], valLossValues[FINE_TUNING_PHASE]
        ],
        {
          width: 480,
          height: 360,
          xaxis: {title: 'Epoch #'},
          yaxis: {title: 'Loss'},
          font: {size: 18}
        });
    Plotly.newPlot(
        'accuracy-plot',
        [
          trainAccValues[INITIAL_PHASE], valAccValues[INITIAL_PHASE],
          trainAccValues[FINE_TUNING_PHASE], valAccValues[FINE_TUNING_PHASE]
        ],
        {
          width: 480,
          height: 360,
          xaxis: {title: 'Epoch #'},
          yaxis: {title: 'Accuracy'},
          font: {size: 18}
        });
    startTransferLearnButton.textContent = phase === INITIAL_PHASE ?
        `Transfer-learning... (${(epoch / epochs * 1e2).toFixed(0)}%)` :
        `Transfer-learning (fine-tuning)... (${
            (epoch / fineTuningEpochs * 1e2).toFixed(0)}%)`;

    scrollToPageBottom();
  }

  disableAllCollectWordButtons();
  const augmentByMixingNoiseRatio =
      document.getElementById('augment-by-mixing-noise').checked ? 0.5 : null;
  console.log(`augmentByMixingNoiseRatio = ${augmentByMixingNoiseRatio}`);
  await transferRecognizer.train({
    epochs,
    validationSplit: 0.25,
    augmentByMixingNoiseRatio,
    callback: {
      onEpochEnd: async (epoch, logs) => {
        plotLossAndAccuracy(
            epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
            INITIAL_PHASE);
      }
    },
    fineTuningEpochs,
    fineTuningCallback: {
      onEpochEnd: async (epoch, logs) => {
        plotLossAndAccuracy(
            epoch, logs.loss, logs.acc, logs.val_loss, logs.val_acc,
            FINE_TUNING_PHASE);
      }
    }
  });
  saveTransferModelButton.disabled = false;
  transferModelNameInput.value = transferRecognizer.name;
  transferModelNameInput.disabled = true;
  startTransferLearnButton.textContent = 'Transfer learning complete.';
  transferModelNameInput.disabled = false;
  startButton.disabled = false;
  evalModelOnDatasetButton.disabled = false;
});

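// Serialize the collected examples and trigger a download of the resulting
// .bin file.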
downloadAsFileButton.addEventListener('click', () => {
  const basename = getDateString();
  const artifacts = transferRecognizer.serializeExamples();

  // Trigger downloading of the data .bin file.
  const anchor = document.createElement('a');
  anchor.download = `${basename}.bin`;
  anchor.href = window.URL.createObjectURL(
      new Blob([artifacts], {type: 'application/octet-stream'}));
  anchor.click();
});

/** Get a date-time string used as the base name of downloaded files. */
function getDateString() {
  const d = new Date();
  const year = `${d.getFullYear()}`;
  let month = `${d.getMonth() + 1}`;
  let day = `${d.getDate()}`;
  if (month.length < 2) {
    month = `0${month}`;
  }
  if (day.length < 2) {
    day = `0${day}`;
  }
  let hour = `${d.getHours()}`;
  if (hour.length < 2) {
    hour = `0${hour}`;
  }
  let minute = `${d.getMinutes()}`;
  if (minute.length < 2) {
    minute = `0${minute}`;
  }
  let second = `${d.getSeconds()}`;
  if (second.length < 2) {
    second = `0${second}`;
  }
  return `${year}-${month}-${day}T${hour}.${minute}.${second}`;
}

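// Read a previously downloaded dataset (.bin) file and load its examples into
// the transfer recognizer.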
uploadFilesButton.addEventListener('click', async () => {
  const files = datasetFileInput.files;
  if (files == null || files.length !== 1) {
    throw new Error('Must select exactly one file.');
  }
  const datasetFileReader = new FileReader();
  datasetFileReader.onload = async event => {
    try {
      await loadDatasetInTransferRecognizer(event.target.result);
    } catch (err) {
      const originalTextContent = uploadFilesButton.textContent;
      uploadFilesButton.textContent = err.message;
      setTimeout(() => {
        uploadFilesButton.textContent = originalTextContent;
      }, 2000);
    }
    durationMultiplierSelect.value = `${transferDurationMultiplier}`;
    durationMultiplierSelect.disabled = true;
    enterLearnWordsButton.disabled = true;
  };
  datasetFileReader.onerror = () =>
      console.error(`Failed to read binary data from file '${files[0].name}'.`);
  datasetFileReader.readAsArrayBuffer(files[0]);
});

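/**
 * Load serialized examples into the transfer recognizer, derive the transfer
 * words and duration multiplier from the dataset, and rebuild the
 * example-collection UI.
 *
 * @param {ArrayBuffer} serialized Serialized examples, as produced by
 *     `serializeExamples()`.
 */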
async function loadDatasetInTransferRecognizer(serialized) {
  const modelName = transferModelNameInput.value;
  if (modelName == null || modelName.length === 0) {
    throw new Error('Need model name!');
  }

  if (transferRecognizer == null) {
    transferRecognizer = recognizer.createTransfer(modelName);
  }
  transferRecognizer.loadExamples(serialized);
  const exampleCounts = transferRecognizer.countExamples();
  transferWords = [];
  const modelNumFrames = transferRecognizer.modelInputShape()[1];
  const durationMultipliers = [];
  for (const word in exampleCounts) {
    transferWords.push(word);
    const examples = transferRecognizer.getExamples(word);
    for (const example of examples) {
      const spectrogram = example.example.spectrogram;
      // Ignore _background_noise_ examples when determining the duration
      // multiplier of the dataset.
      if (word !== BACKGROUND_NOISE_TAG) {
        durationMultipliers.push(Math.round(
            spectrogram.data.length / spectrogram.frameSize / modelNumFrames));
      }
    }
  }
  transferWords.sort();
  learnWordsInput.value = transferWords.join(',');

  // Determine the transferDurationMultiplier value from the dataset.
  transferDurationMultiplier =
      durationMultipliers.length > 0 ? Math.max(...durationMultipliers) : 1;
  console.log(
      `Determined transferDurationMultiplier from uploaded ` +
      `dataset: ${transferDurationMultiplier}`);

  createWordDivs(transferWords);
  datasetViz.redrawAll();
}

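// Evaluate the transfer model on a dataset file at a range of probability
// thresholds and plot the resulting ROC curve.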
evalModelOnDatasetButton.addEventListener('click', async () => {
  const files = datasetFileInput.files;
  if (files == null || files.length !== 1) {
    throw new Error('Must select exactly one file.');
  }
  evalModelOnDatasetButton.disabled = true;
  const datasetFileReader = new FileReader();
  datasetFileReader.onload = async event => {
    try {
      if (transferRecognizer == null) {
        throw new Error('There is no model!');
      }

      // Load the dataset and perform evaluation of the transfer
      // model using the dataset.
      transferRecognizer.loadExamples(event.target.result);
      const evalResult = await transferRecognizer.evaluate({
        windowHopRatio: 0.25,
        wordProbThresholds: [
          0,    0.05, 0.1,  0.15, 0.2,  0.25, 0.3,  0.35, 0.4,  0.5,
          0.55, 0.6,  0.65, 0.7,  0.75, 0.8,  0.85, 0.9,  0.95, 1.0
        ]
      });
      // Plot the ROC curve.
      const rocDataForPlot = {x: [], y: []};
      evalResult.rocCurve.forEach(item => {
        rocDataForPlot.x.push(item.fpr);
        rocDataForPlot.y.push(item.tpr);
      });

      Plotly.newPlot('roc-plot', [rocDataForPlot], {
        width: 360,
        height: 360,
        mode: 'markers',
        marker: {size: 7},
        xaxis: {title: 'False positive rate (FPR)', range: [0, 1]},
        yaxis: {title: 'True positive rate (TPR)', range: [0, 1]},
        font: {size: 18}
      });
      evalResultsSpan.textContent = `AUC = ${evalResult.auc}`;
    } catch (err) {
      const originalTextContent = evalModelOnDatasetButton.textContent;
      evalModelOnDatasetButton.textContent = err.message;
      setTimeout(() => {
        evalModelOnDatasetButton.textContent = originalTextContent;
      }, 2000);
    }
    evalModelOnDatasetButton.disabled = false;
  };
  datasetFileReader.onerror = () =>
      console.error(`Failed to read binary data from file '${files[0].name}'.`);
  datasetFileReader.readAsArrayBuffer(files[0]);
});

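/** Refresh the dropdown that lists the locally saved transfer models. */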
async function populateSavedTransferModelsSelect() {
  const savedModelKeys = await SpeechCommands.listSavedTransferModels();
  while (savedTransferModelsSelect.firstChild) {
    savedTransferModelsSelect.removeChild(savedTransferModelsSelect.firstChild);
  }
  if (savedModelKeys.length > 0) {
    for (const key of savedModelKeys) {
      const option = document.createElement('option');
      option.textContent = key;
      option.id = key;
      savedTransferModelsSelect.appendChild(option);
    }
    loadTransferModelButton.disabled = false;
  }
}

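// Save the current transfer model and refresh the saved-models dropdown.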
saveTransferModelButton.addEventListener('click', async () => {
  await transferRecognizer.save();
  await populateSavedTransferModelsSelect();
  saveTransferModelButton.textContent = 'Model saved!';
  saveTransferModelButton.disabled = true;
});

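// Load the selected saved transfer model and lock the UI fields that
// describe it.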
loadTransferModelButton.addEventListener('click', async () => {
  const transferModelName = savedTransferModelsSelect.value;
  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer(transferModelName);
  await transferRecognizer.load();
  transferModelNameInput.value = transferModelName;
  transferModelNameInput.disabled = true;
  learnWordsInput.value = transferRecognizer.wordLabels().join(',');
  learnWordsInput.disabled = true;
  durationMultiplierSelect.disabled = true;
  enterLearnWordsButton.disabled = true;
  saveTransferModelButton.disabled = true;
  loadTransferModelButton.disabled = true;
  loadTransferModelButton.textContent = 'Model loaded!';
});

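// Toggle visibility of the model save/load controls.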
modelIOButton.addEventListener('click', () => {
  if (modelIOButton.textContent.endsWith(' >>')) {
    transferModelSaveLoadInnerDiv.style.display = 'inline-block';
    modelIOButton.textContent = modelIOButton.textContent.replace(' >>', ' <<');
  } else {
    transferModelSaveLoadInnerDiv.style.display = 'none';
    modelIOButton.textContent = modelIOButton.textContent.replace(' <<', ' >>');
  }
});

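// Delete the selected saved transfer model and refresh the saved-models
// dropdown.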
deleteTransferModelButton.addEventListener('click', async () => {
  const transferModelName = savedTransferModelsSelect.value;
  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer(transferModelName);
  await SpeechCommands.deleteSavedTransferModel(transferModelName);
  deleteTransferModelButton.disabled = true;
  deleteTransferModelButton.textContent = `Deleted "${transferModelName}"`;
  await populateSavedTransferModelsSelect();
});

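// Toggle visibility of the dataset download/upload controls.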
datasetIOButton.addEventListener('click', () => {
  if (datasetIOButton.textContent.endsWith(' >>')) {
    datasetIOInnerDiv.style.display = 'inline-block';
    datasetIOButton.textContent =
        datasetIOButton.textContent.replace(' >>', ' <<');
  } else {
    datasetIOInnerDiv.style.display = 'none';
    datasetIOButton.textContent =
        datasetIOButton.textContent.replace(' <<', ' >>');
  }
});