140 lines
4.5 KiB
TypeScript
140 lines
4.5 KiB
TypeScript
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
|
|
// Load the binding
|
|
import '@tensorflow/tfjs-node';
|
|
|
|
import chalk from 'chalk';
|
|
import * as ora from 'ora';
|
|
import * as Vorpal from 'vorpal';
|
|
|
|
import {AudioModel} from './audio_model';
|
|
import {Dataset} from './dataset';
|
|
import {WavFileFeatureExtractor} from './wav_file_feature_extractor';
|
|
|
|
// Install an empty class named AudioContext on the Node global object.
// NOTE(review): presumably a module used later expects the browser-only
// AudioContext global to exist — confirm which one.
Object.assign(global, {AudioContext: class AudioContext {}});
|
|
|
|
/**
 * Shape handed to the AudioModel constructor as its first argument by the
 * `create_model` command.
 * NOTE(review): the meaning of the individual dimensions (98, 40, 1) is not
 * visible in this file — confirm against AudioModel.
 */
export const MODEL_SHAPE: number[] = [98, 40, 1];
|
|
/**
 * Single-question prompt descriptor asking for the label list.
 * NOTE(review): shaped like an inquirer-style question list; it is not
 * referenced by any command in this file — confirm where it is consumed.
 */
export const labelsMsg = [
  {
    type: 'input',
    name: 'labels',
    // Spelling fixed: the prompt previously read "seperate".
    message: 'Enter labels (separate by comma)',
  },
];
|
|
/**
 * Two-question prompt descriptor: first the directory holding the files,
 * then the label to associate with that directory.
 * NOTE(review): shaped like an inquirer-style question list; it is not
 * referenced by any command in this file — confirm where it is consumed.
 */
export const trainingMsg = [
  {
    type: 'input',
    name: 'dir',
    message: 'Enter file directory',
  },
  {
    type: 'input',
    name: 'label',
    message: 'Enter label for the directory',
  },
];
|
|
/**
 * Single-question prompt descriptor asking for the file name the model
 * should be written to.
 * NOTE(review): shaped like an inquirer-style question list; it is not
 * referenced by any command in this file — confirm where it is consumed.
 */
export const filenameMsg = [
  {type: 'input', name: 'filename', message: 'Enter target filename for the model'},
];
|
|
// Model assigned by the `create_model` command; it holds no value until that
// command has been run.
let model: AudioModel;

// Labels given to the most recent `create_model` invocation.
let labels: string[];

// CLI instance on which every command in this file is registered.
const vorpal = new Vorpal();

// Progress spinner shared by the commands; some commands replace it with a
// freshly created ora instance.
let spinner = ora();
|
|
// `create_model [labels...]` (alias `c`): builds a new AudioModel for the
// given labels and stores both the model and the labels in module state.
//
// The command pattern declares `labels` as optional, so `args.labels` may be
// absent. In that case nothing is created, any previously created model is
// kept, and the callback is still invoked so the prompt returns.
vorpal.command('create_model [labels...]')
    .alias('c')
    .description('create the audio model')
    .action((args, cb) => {
      const requested: unknown = args.labels;
      if (!Array.isArray(requested) || requested.length === 0) {
        console.log(chalk.red('create_model needs at least one label'));
        cb();
        return;
      }

      // NOTE(review): vorpal may deliver numeric-looking arguments as numbers;
      // convert so the value really is the string[] the declaration promises.
      labels = requested.map((label) => String(label));
      console.log(labels);

      model = new AudioModel(
          MODEL_SHAPE, labels, new Dataset(labels.length),
          new WavFileFeatureExtractor());
      cb();
    });
|
|
|
|
// `load_dataset all <dir>` (alias `la`): asks the current model to load
// everything under <dir> via `model.loadAll`, mirroring progress on the
// spinner.
//
// The progress callback receives a text and an optional `finished` flag:
// finished messages are persisted with a success mark, all others just update
// the running spinner's text.
//
// If no model has been created yet the command reports that and returns
// instead of dereferencing an unset `model`. A failed load is reported through
// the spinner, as the `load_dataset <dir> <label>` command does, rather than
// leaving the spinner running.
vorpal
    .command(
        'load_dataset all <dir>',
        'Load all the data from the root directory by the labels')
    .alias('la')
    .action(async (args) => {
      if (!model) {
        console.log(chalk.red('no model yet: run create_model first'));
        return;
      }

      spinner.start('load dataset ...');
      try {
        await model.loadAll(
            args.dir as string, (text: string, finished?: boolean) => {
              if (finished) {
                spinner.succeed(text);
                return;
              }
              // succeed() stops the spinner, so restart it before showing the
              // next progress message.
              spinner.start();
              spinner.text = text;
              spinner.render();
            });
        spinner.stop();
      } catch (err) {
        spinner.fail(`failed to load: ${err}`);
      }
    });
|
|
// `load_dataset <dir> <label>` (alias `l`): asks the current model to load the
// files in <dir> under <label> via `model.loadData`, showing each progress
// message on a fresh spinner.
//
// If no model has been created yet the command reports that and returns
// instead of dereferencing an unset `model`. Any failure, including an
// exception thrown synchronously by `loadData`, ends the spinner with a
// failure message.
vorpal
    .command(
        'load_dataset <dir> <label>',
        'Load the dataset from the directory with the label')
    .alias('l')
    .action(async (args) => {
      if (!model) {
        console.log(chalk.red('no model yet: run create_model first'));
        return;
      }

      spinner = ora('creating tensors ...');
      spinner.start();
      try {
        await model.loadData(
            args.dir as string, args.label as string, (text: string) => {
              spinner.text = text;
              spinner.render();
            });
        spinner.stop();
      } catch (err) {
        spinner.fail(`failed to load: ${err}`);
      }
    });
|
|
// `dataset size` (alias `d`): prints the value returned by `model.size()`.
//
// When no model has been created yet a hint is printed instead. The callback
// is invoked in both cases so the prompt always returns.
vorpal.command('dataset size', 'Show the size of the dataset')
    .alias('d')
    .action((args, cb) => {
      if (!model) {
        console.log(chalk.red('no model yet: run create_model first'));
      } else {
        console.log(chalk.green(`dataset size = ${model.size()}`));
      }
      cb();
    });
|
|
// `train [epoch]` (alias `t`): runs `model.train` for the requested number of
// epochs and reports progress on a fresh spinner.
//
// - The epoch count falls back to 20 when the argument is missing, not a
//   number, or 0.
// - After every batch the spinner text shows the current loss.
// - After every epoch a summary line is persisted and the spinner restarted.
//
// If no model has been created yet the command reports that and returns. A
// failure during training (including an exception from one of the callbacks)
// ends the spinner with a failure message instead of leaving it running.
vorpal.command('train [epoch]')
    .alias('t')
    .description('train all audio dataset')
    .action(async (args) => {
      if (!model) {
        console.log(chalk.red('no model yet: run create_model first'));
        return;
      }

      const epochs = parseInt(args.epoch as string, 10) || 20;
      spinner = ora('training models ...').start();
      try {
        await model.train(epochs, {
          onBatchEnd: async (batch, logs) => {
            spinner.text = chalk.green(`loss: ${logs.loss.toFixed(5)}`);
            spinner.render();
          },
          onEpochEnd: async (epoch, logs) => {
            spinner.succeed(chalk.green(
                `epoch: ${epoch}, loss: ${logs.loss.toFixed(5)}` +
                `, accuracy: ${logs.acc.toFixed(5)}` +
                `, validation accuracy: ${logs.val_acc.toFixed(5)}`));
            // succeed() stops the spinner; restart it for the next epoch.
            spinner.start();
          }
        });
        spinner.stop();
      } catch (err) {
        spinner.fail(`failed to train: ${err}`);
      }
    });
|
|
// `save_model <filename>` (alias `s`): saves the current model through
// `model.save(filename)` and reports the outcome on the spinner.
//
// If no model has been created yet the command reports that and returns. On
// failure the underlying error is included in the message so the cause is not
// lost.
vorpal.command('save_model <filename>')
    .alias('s')
    .description('save the audio model')
    .action(async (args) => {
      if (!model) {
        console.log(chalk.red('no model yet: run create_model first'));
        return;
      }

      const filename = args.filename as string;
      spinner.start(`saving to ${filename} ...`);
      try {
        await model.save(filename);
        spinner.succeed(`${filename} saved.`);
      } catch (err) {
        spinner.fail(`failed to save ${filename}: ${err}`);
      }
    });
|
|
|
|
// Every command has been registered at this point; call show() on the vorpal
// instance to bring up its interactive prompt.
vorpal.show();
|