mobileNet/随机森林/decision-tree.js
2025-08-11 17:44:57 +08:00

394 lines
12 KiB
JavaScript

var dt = (function () {
/**
* Creates an instance of DecisionTree
*
* @constructor
* @param builder - contains training set and
* some configuration parameters
*/
function DecisionTree(builder) {
this.root = buildDecisionTree({
trainingSet: builder.trainingSet,
ignoredAttributes: arrayToHashSet(builder.ignoredAttributes),
categoryAttr: builder.categoryAttr || 'category',
minItemsCount: builder.minItemsCount || 1,
entropyThrehold: builder.entropyThrehold || 0.01,
maxTreeDepth: builder.maxTreeDepth || 70
});
}
DecisionTree.prototype.predict = function (item) {
return predict(this.root, item);
}
/**
* Creates an instance of RandomForest
* with specific number of trees
*
* @constructor
* @param builder - contains training set and some
* configuration parameters for
* building decision trees
*/
function RandomForest(builder, treesNumber) {
this.trees = buildRandomForest(builder, treesNumber);
}
RandomForest.prototype.predict = function (item) {
return predictRandomForest(this.trees, item);
}
/**
* Transforming array to object with such attributes
* as elements of array (afterwards it can be used as HashSet)
*/
function arrayToHashSet(array) {
var hashSet = {};
if (array) {
for(var i in array) {
var attr = array[i];
hashSet[attr] = true;
}
}
return hashSet;
}
/**
* Calculating how many objects have the same
* values of specific attribute.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function countUniqueValues(items, attr) {
var counter = {};
// detecting different values of attribute
for (var i = items.length - 1; i >= 0; i--) {
// items[i][attr] - value of attribute
counter[items[i][attr]] = 0;
}
// counting number of occurrences of each of values
// of attribute
for (var i = items.length - 1; i >= 0; i--) {
counter[items[i][attr]] += 1;
}
return counter;
}
/**
* Calculating entropy of array of objects
* by specific attribute.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function entropy(items, attr) {
// counting number of occurrences of each of values
// of attribute
var counter = countUniqueValues(items, attr);
var entropy = 0;
var p;
for (var i in counter) {
p = counter[i] / items.length;
entropy += -p * Math.log(p);
}
return entropy;
}
/**
* Splitting array of objects by value of specific attribute,
* using specific predicate and pivot.
*
* Items which matched by predicate will be copied to
* the new array called 'match', and the rest of the items
* will be copied to array with name 'notMatch'
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*
* @param predicate - function(x, y)
* which returns 'true' or 'false'
*
* @param pivot - used as the second argument when
* calling predicate function:
* e.g. predicate(item[attr], pivot)
*/
function split(items, attr, predicate, pivot) {
var match = [];
var notMatch = [];
var item,
attrValue;
for (var i = items.length - 1; i >= 0; i--) {
item = items[i];
attrValue = item[attr];
if (predicate(attrValue, pivot)) {
match.push(item);
} else {
notMatch.push(item);
}
};
return {
match: match,
notMatch: notMatch
};
}
/**
* Finding value of specific attribute which is most frequent
* in given array of objects.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function mostFrequentValue(items, attr) {
// counting number of occurrences of each of values
// of attribute
var counter = countUniqueValues(items, attr);
var mostFrequentCount = 0;
var mostFrequentValue;
for (var value in counter) {
if (counter[value] > mostFrequentCount) {
mostFrequentCount = counter[value];
mostFrequentValue = value;
}
};
return mostFrequentValue;
}
var predicates = {
'==': function (a, b) { return a == b },
'>=': function (a, b) { return a >= b }
};
/**
* Function for building decision tree
*/
function buildDecisionTree(builder) {
var trainingSet = builder.trainingSet;
var minItemsCount = builder.minItemsCount;
var categoryAttr = builder.categoryAttr;
var entropyThrehold = builder.entropyThrehold;
var maxTreeDepth = builder.maxTreeDepth;
var ignoredAttributes = builder.ignoredAttributes;
if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) {
// restriction by maximal depth of tree
// or size of training set is to small
// so we have to terminate process of building tree
return {
category: mostFrequentValue(trainingSet, categoryAttr)
};
}
var initialEntropy = entropy(trainingSet, categoryAttr);
if (initialEntropy <= entropyThrehold) {
// entropy of training set too small
// (it means that training set is almost homogeneous),
// so we have to terminate process of building tree
return {
category: mostFrequentValue(trainingSet, categoryAttr)
};
}
// used as hash-set for avoiding the checking of split by rules
// with the same 'attribute-predicate-pivot' more than once
var alreadyChecked = {};
// this variable expected to contain rule, which splits training set
// into subsets with smaller values of entropy (produces informational gain)
var bestSplit = {gain: 0};
for (var i = trainingSet.length - 1; i >= 0; i--) {
var item = trainingSet[i];
// iterating over all attributes of item
for (var attr in item) {
if ((attr == categoryAttr) || ignoredAttributes[attr]) {
continue;
}
// let the value of current attribute be the pivot
var pivot = item[attr];
// pick the predicate
// depending on the type of the attribute value
var predicateName;
if (typeof pivot == 'number') {
predicateName = '>=';
} else {
// there is no sense to compare non-numeric attributes
// so we will check only equality of such attributes
predicateName = '==';
}
var attrPredPivot = attr + predicateName + pivot;
if (alreadyChecked[attrPredPivot]) {
// skip such pairs of 'attribute-predicate-pivot',
// which been already checked
continue;
}
alreadyChecked[attrPredPivot] = true;
var predicate = predicates[predicateName];
// splitting training set by given 'attribute-predicate-value'
var currSplit = split(trainingSet, attr, predicate, pivot);
// calculating entropy of subsets
var matchEntropy = entropy(currSplit.match, categoryAttr);
var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr);
// calculating informational gain
var newEntropy = 0;
newEntropy += matchEntropy * currSplit.match.length;
newEntropy += notMatchEntropy * currSplit.notMatch.length;
newEntropy /= trainingSet.length;
var currGain = initialEntropy - newEntropy;
if (currGain > bestSplit.gain) {
// remember pairs 'attribute-predicate-value'
// which provides informational gain
bestSplit = currSplit;
bestSplit.predicateName = predicateName;
bestSplit.predicate = predicate;
bestSplit.attribute = attr;
bestSplit.pivot = pivot;
bestSplit.gain = currGain;
}
}
}
if (!bestSplit.gain) {
// can't find optimal split
return { category: mostFrequentValue(trainingSet, categoryAttr) };
}
// building subtrees
builder.maxTreeDepth = maxTreeDepth - 1;
builder.trainingSet = bestSplit.match;
var matchSubTree = buildDecisionTree(builder);
builder.trainingSet = bestSplit.notMatch;
var notMatchSubTree = buildDecisionTree(builder);
return {
attribute: bestSplit.attribute,
predicate: bestSplit.predicate,
predicateName: bestSplit.predicateName,
pivot: bestSplit.pivot,
match: matchSubTree,
notMatch: notMatchSubTree,
matchedCount: bestSplit.match.length,
notMatchedCount: bestSplit.notMatch.length
};
}
/**
* Classifying item, using decision tree
*/
function predict(tree, item) {
var attr,
value,
predicate,
pivot;
// Traversing tree from the root to leaf
while(true) {
if (tree.category) {
// only leafs contains predicted category
return tree.category;
}
attr = tree.attribute;
value = item[attr];
predicate = tree.predicate;
pivot = tree.pivot;
// move to one of subtrees
if (predicate(value, pivot)) {
tree = tree.match;
} else {
tree = tree.notMatch;
}
}
}
/**
* Building array of decision trees
*/
function buildRandomForest(builder, treesNumber) {
var items = builder.trainingSet;
// creating training sets for each tree
var trainingSets = [];
for (var t = 0; t < treesNumber; t++) {
trainingSets[t] = [];
}
for (var i = items.length - 1; i >= 0 ; i--) {
// assigning items to training sets of each tree
// using 'round-robin' strategy
var correspondingTree = i % treesNumber;
trainingSets[correspondingTree].push(items[i]);
}
// building decision trees
var forest = [];
for (var t = 0; t < treesNumber; t++) {
builder.trainingSet = trainingSets[t];
var tree = new DecisionTree(builder);
forest.push(tree);
}
return forest;
}
/**
* Each of decision tree classifying item
* ('voting' that item corresponds to some class).
*
* This function returns hash, which contains
* all classifying results, and number of votes
* which were given for each of classifying results
*/
function predictRandomForest(forest, item) {
var result = {};
for (var i in forest) {
var tree = forest[i];
var prediction = tree.predict(item);
result[prediction] = result[prediction] ? result[prediction] + 1 : 1;
}
return result;
}
var exports = {};
exports.DecisionTree = DecisionTree;
exports.RandomForest = RandomForest;
return exports;
})();