/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
|
|
|
// tslint:disable-next-line:no-require-imports
|
|
const packageJSON = require('../package.json');
|
|
import * as tf from '@tensorflow/tfjs-core';
|
|
import * as tfl from '@tensorflow/tfjs-layers';
|
|
import * as speechCommands from './index';
|
|
|
|
/** Checks on the symbols exported from the package entry point (`./index`). */
describe('Public API', () => {
  it('version matches package.json', () => {
    const exportedVersion = speechCommands.version;

    // The exported version must be a string and must agree with the `version`
    // field of the package.json loaded at the top of this file.
    expect(typeof exportedVersion).toEqual('string');
    expect(exportedVersion).toEqual(packageJSON.version);
  });
});
|
|
|
|
/**
 * Tests for `speechCommands.create()` when the model is supplied directly as
 * in-memory artifacts plus a metadata object.
 */
describe('Creating recognizer', () => {
  /**
   * Builds a small sequential model (conv2d -> flatten -> 3-unit softmax
   * dense) with input shape [86, 500, 1] and captures its serialized form
   * through an in-memory save handler.
   *
   * @returns The artifacts object that `model.save()` passed to the handler.
   * @throws Error if the save handler was never invoked, so callers never
   *     receive `undefined` in place of real artifacts.
   */
  async function makeModelArtifacts(): Promise<tf.io.ModelArtifacts> {
    const model = tfl.sequential();
    model.add(tfl.layers.conv2d({
      filters: 8,
      kernelSize: 3,
      activation: 'relu',
      inputShape: [86, 500, 1]
    }));
    model.add(tfl.layers.flatten());
    model.add(tfl.layers.dense({units: 3, activation: 'softmax'}));

    // Deliberately left without an initializer: the value is only assigned
    // inside the save handler below and is checked before being returned.
    let modelArtifacts: tf.io.ModelArtifacts|undefined;
    await model.save(tf.io.withSaveHandler(async artifacts => {
      modelArtifacts = artifacts;
      // Resolve with a well-formed SaveResult instead of a bare `null`, so the
      // handler matches the declared save-handler contract.
      return {
        modelArtifactsInfo: {dateSaved: new Date(), modelTopologyType: 'JSON'}
      };
    }));

    if (modelArtifacts == null) {
      throw new Error(
          'Save handler was not invoked; no model artifacts were captured.');
    }
    return modelArtifacts;
  }

  /**
   * Creates recognizer metadata with three word labels (the background-noise
   * tag, 'foo' and 'bar') stamped with the current library version.
   *
   * @returns A metadata object matching the 3-unit output of the test model.
   */
  function makeMetadata(): speechCommands.SpeechCommandRecognizerMetadata {
    return {
      wordLabels: [speechCommands.BACKGROUND_NOISE_TAG, 'foo', 'bar'],
      tfjsSpeechCommandsVersion: speechCommands.version
    };
  }

  it('Create recognizer from artifacts and metadata objects', async () => {
    const modelArtifacts = await makeModelArtifacts();
    const metadata = makeMetadata();
    const recognizer =
        speechCommands.create('BROWSER_FFT', null, modelArtifacts, metadata);
    await recognizer.ensureModelLoaded();

    // The labels must be exactly the ones supplied in the metadata object.
    expect(recognizer.wordLabels()).toEqual([
      speechCommands.BACKGROUND_NOISE_TAG, 'foo', 'bar'
    ]);
    // The reported input shape is the model's inputShape preceded by a null
    // leading dimension.
    expect(recognizer.modelInputShape()).toEqual([null, 86, 500, 1]);
  });
});
|