Commit 7120e22f authored by Markus Krug

added an ensemble training method

parent 547f3cbb
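The diff below adds an ensemble training path built around FeatureDifferenceEnsembleClassifier. As a minimal sketch (not part of the commit), the new classes could be wired together roughly like this; the helper class EnsembleTrainingSketch, the parameter values, and the argument names are illustrative assumptions, while the types and methods are the ones introduced in this commit:

import java.util.List;

import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.rulelearning.algorithm.impl.EEnsembleClassificationMethod;
import de.uniwue.ls6.rulelearning.algorithm.impl.FeatureDifferenceEnsembleClassifier;
import de.uniwue.ls6.rulelearning.instanceloading.featuregen.collection.FeatureGeneratorCollection;
import de.uniwue.ls6.rulelearning.instanceloading.featuregenerator.AFeatureGenerator;

public class EnsembleTrainingSketch {

    // Wires up the ensemble added in this commit; the generators and training
    // instances are assumed to come from the existing instance-loading pipeline.
    public static FeatureDifferenceEnsembleClassifier train(List<AFeatureGenerator> generators,
                                                            Instance[] trainingInstances) {
        FeatureGeneratorCollection collection = new FeatureGeneratorCollection();
        collection.addFeatureGenerator(generators);
        FeatureDifferenceEnsembleClassifier ensemble = new FeatureDifferenceEnsembleClassifier(
                collection, EEnsembleClassificationMethod.MajorityVote, /* beamSize */ 5);
        ensemble.setMaxClassifierSize(10); // cap on the number of ensemble members
        ensemble.learn(trainingInstances);
        return ensemble;
    }
}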
......@@ -169,6 +169,47 @@ public class RepresentationRule {
public double getPrecision() {
return precision;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((conditionSet == null) ? 0 : conditionSet.hashCode());
result = prime * result + ((label == null) ? 0 : label.hashCode());
result = prime * result + maximumScore;
long temp;
temp = Double.doubleToLongBits(precision);
result = prime * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
RepresentationRule other = (RepresentationRule) obj;
if (conditionSet == null) {
if (other.conditionSet != null)
return false;
} else if (!conditionSet.equals(other.conditionSet))
return false;
if (label == null) {
if (other.label != null)
return false;
} else if (!label.equals(other.label))
return false;
if (maximumScore != other.maximumScore)
return false;
if (Double.doubleToLongBits(precision) != Double.doubleToLongBits(other.precision))
return false;
return true;
}
......
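The generated hashCode/equals overrides give RepresentationRule value semantics, which the new duplicate check in RulePass.addRule further down relies on. A standalone sketch, using a hypothetical Rule class rather than the project's, of why contains-based deduplication needs these overrides:

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical value class mirroring the pattern above; not part of the project.
class Rule {
    final String label;
    final double precision;
    Rule(String label, double precision) { this.label = label; this.precision = precision; }
    @Override public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof Rule)) return false;
        Rule r = (Rule) o;
        return precision == r.precision && Objects.equals(label, r.label);
    }
    @Override public int hashCode() { return Objects.hash(label, precision); }
}

public class DedupSketch {
    public static void main(String[] args) {
        List<Rule> rules = new ArrayList<>();
        Rule a = new Rule("PER", 0.9);
        Rule b = new Rule("PER", 0.9); // equal by value, different object
        if (!rules.contains(a)) rules.add(a);
        if (!rules.contains(b)) rules.add(b); // skipped because a.equals(b)
        System.out.println(rules.size()); // 1; without equals/hashCode it would be 2
    }
}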
......@@ -39,7 +39,7 @@ public class RulePass {
* @param ruleSet
* List of Rules, ordered by score
*/
public RulePass(List<RepresentationRule> ruleSet,int label) {
public RulePass(List<RepresentationRule> ruleSet, int label) {
super();
this.ruleSet = ruleSet;
this.label = label;
......@@ -52,7 +52,9 @@ public class RulePass {
* rule to be added to the pass
*/
public void addRule(RepresentationRule rule) {
this.ruleSet.add(rule);
if (!this.ruleSet.contains(rule)) {
this.ruleSet.add(rule);
}
}
/**
......@@ -88,8 +90,8 @@ public class RulePass {
}
return false;
}
public RepresentationRule apply(Instance ins){
public RepresentationRule apply(Instance ins) {
for (RepresentationRule rule : ruleSet) {
if (rule.isApplicable(ins))
return rule;
......@@ -104,7 +106,5 @@ public class RulePass {
public void setLabel(int label) {
this.label = label;
}
}
......@@ -109,6 +109,8 @@ public class MatrixMcMatrixFace {
}
// TODO this needs a speedup because it calculates all features backwards
// which is unnecessary since it is stored in the mapping
public MatrixPoint getLocationOfMaximum(List<MatrixMapping> mappings) {
double maxScore = 0;
MatrixPoint bestEntry = null;
......@@ -120,13 +122,13 @@ public class MatrixMcMatrixFace {
if (scoreCurrent > maxScore) {
maxScore = scoreCurrent;
bestEntry = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
mostSimpleRule = MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(),
mappings, mappings.size() > 0 ? true : false);
mostSimpleRule = MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(), mappings,
mappings.size() > 0 ? true : false);
} else if (scoreCurrent == maxScore && scoreCurrent > 0) {
// keep the simpler rule
MatrixPoint loc = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
Set<Point> featuresForIndex = MatrixUtil.determineFeaturesForIndex(loc.getLocation(),
mappings, mappings.size() > 0 ? true : false);
Set<Point> featuresForIndex = MatrixUtil.determineFeaturesForIndex(loc.getLocation(), mappings,
mappings.size() > 0 ? true : false);
if (mostSimpleRule == null) {
bestEntry = loc;
mostSimpleRule = featuresForIndex;
......
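When two matrix cells reach the same score, getLocationOfMaximum keeps the rule built from fewer features ("keep the simpler rule"). A standalone sketch of that selection criterion under the assumed semantics (higher score wins, ties prefer fewer features), using plain score/feature-count pairs instead of MatrixPoint:

// Standalone illustration of the tie-breaking used above.
public class TieBreakSketch {
    static final class Candidate {
        final double score;
        final int featureCount;
        Candidate(double score, int featureCount) { this.score = score; this.featureCount = featureCount; }
    }

    static Candidate best(Candidate[] candidates) {
        Candidate best = null;
        for (Candidate c : candidates) {
            if (best == null
                    || c.score > best.score
                    || (c.score == best.score && c.featureCount < best.featureCount)) {
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Candidate winner = best(new Candidate[] {
                new Candidate(4.0, 3), new Candidate(4.0, 2), new Candidate(3.0, 1) });
        System.out.println(winner.featureCount); // 2: same score, the simpler rule wins
    }
}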
......@@ -71,25 +71,28 @@ public class MatrixUtil {
}
}
// // get the features of this point
// Set<Point> featuresOfPoints = determineFeaturesForIndex(new Point(entry.column(), entry.row()),
// mappings, mappings.size() > 0 ? true : false);
//
// // only keep each feature combination exactly once!
// if (uniqueFeatureCombinations.contains(featuresOfPoints)) {
// continue;
// }
// uniqueFeatureCombinations.add(featuresOfPoints);
//
// // furthermore we can filter all those feature combinations that
// // resemble exactly the same instance set
// // because the algorithm cannot differentiate between those
// Set<Instance> instancesForFeatureSet = determineInstancesForFeatures(featuresOfPoints, indexMap);
//
// if (uniqueInstance.contains(instancesForFeatureSet)) {
// continue;
// }
// uniqueInstance.add(instancesForFeatureSet);
// // get the features of this point
// Set<Point> featuresOfPoints = determineFeaturesForIndex(new
// Point(entry.column(), entry.row()),
// mappings, mappings.size() > 0 ? true : false);
//
// // only keep each feature combination exactly once!
// if (uniqueFeatureCombinations.contains(featuresOfPoints)) {
// continue;
// }
// uniqueFeatureCombinations.add(featuresOfPoints);
//
// // furthermore we can filter all those feature combinations that
// // resemble exactly the same instance set
// // because the algorithm cannot differentiate between those
// Set<Instance> instancesForFeatureSet =
// determineInstancesForFeatures(featuresOfPoints, indexMap);
//
// if (uniqueInstance.contains(instancesForFeatureSet)) {
// continue;
// }
// uniqueInstance.add(instancesForFeatureSet);
// finally we can decide to keep our feature for the next
// iteration
......@@ -124,7 +127,6 @@ public class MatrixUtil {
ArrayList<MatrixMapping> arrayList = new ArrayList<MatrixMapping>(mappings);
arrayList.add(matrixMapping);
matrixMapping.inferDenseMapValues(arrayList);
// System.out.println(entries.size()+"==");
// debug
......@@ -200,8 +202,8 @@ public class MatrixUtil {
fullFeats.addAll(featE1);
fullFeats.addAll(featE2);
if (uniqueFeatures.contains(fullFeats))
continue;
if (uniqueFeatures.contains(fullFeats))
continue;
uniqueFeatures.add(fullFeats);
Set<Instance> fullInstances = new HashSet<>();
......@@ -211,8 +213,8 @@ public class MatrixUtil {
fullInstances.addAll(insE1);
fullInstances.retainAll(insE2);
if (uniqueInstances.contains(fullInstances))
continue;
if (uniqueInstances.contains(fullInstances))
continue;
uniqueInstances.add(fullInstances);
......@@ -268,6 +270,8 @@ public class MatrixUtil {
.collect(Collector.of(matrixConstructor, accumulator, join, Collector.Characteristics.UNORDERED));
}
// TODO speed this up by a simple lookup of the indices in the previous
// mapping
public static Set<Point> determineFeaturesForIndex(Point index, List<MatrixMapping> mappings,
boolean revertKroneckerFirst) {
......
package de.uniwue.ls6.rulelearning.instanceloading.featuregen.collection;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import de.uniwue.ls6.rulelearning.instanceloading.featuregenerator.AFeatureGenerator;
public class FeatureGeneratorCollection {
private List<AFeatureGenerator> featureGenerator;
private AFeatureGenerator goldGen;
public FeatureGeneratorCollection() {
featureGenerator = new ArrayList<AFeatureGenerator>();
}
public void setGoldGenerator(AFeatureGenerator goldGen) {
this.goldGen = goldGen;
}
public void addFeatureGenerator(AFeatureGenerator gen) {
this.featureGenerator.add(gen);
}
public void addFeatureGenerator(List<AFeatureGenerator> gens) {
this.featureGenerator.addAll(gens);
}
public void removeFeatureGenerator(AFeatureGenerator gen) {
this.featureGenerator.remove(gen);
}
public Set<FeatureGeneratorCollection> getFeatureCombinations(int maxClassifierSize) {
// perform a BFS through all subsets of feature collections starting at
// the full feature set
// init
FeatureGeneratorCollection initColl = new FeatureGeneratorCollection();
initColl.addFeatureGenerator(featureGenerator);
Set<FeatureGeneratorCollection> completeCollection = new HashSet<FeatureGeneratorCollection>();
completeCollection.add(initColl);
List<FeatureGeneratorCollection> bfsList = new ArrayList<FeatureGeneratorCollection>();
bfsList.add(initColl);
while (bfsList.size() != 0) {
// expand the last state
FeatureGeneratorCollection stateInFocus = bfsList.remove(0);
// expand that state
List<FeatureGeneratorCollection> expandedStates = expandState(stateInFocus);
// add those to bfslist
bfsList.addAll(expandedStates);
completeCollection.addAll(expandedStates);
// check if enough classifiers have been created
if (completeCollection.size() > maxClassifierSize)
break;
}
return completeCollection;
}
private List<FeatureGeneratorCollection> expandState(FeatureGeneratorCollection stateInFocus) {
List<FeatureGeneratorCollection> expandedStates = new ArrayList<FeatureGeneratorCollection>();
for (int i = 0; i < stateInFocus.featureGenerator.size(); i++) {
// create a new state and add all but i
FeatureGeneratorCollection collNew = new FeatureGeneratorCollection();
for (int k = 0; k < stateInFocus.featureGenerator.size(); k++) {
if (k == i)
continue;
else {
collNew.addFeatureGenerator(stateInFocus.featureGenerator.get(k));
}
}
expandedStates.add(collNew);
}
return expandedStates;
}
public Set<String> getFeatureIds() {
Set<String> ids = new HashSet<String>();
for (AFeatureGenerator gen : featureGenerator) {
ids.add(gen.getFeatureId());
}
return ids;
}
public AFeatureGenerator[] getGenerators() {
return featureGenerator.toArray(new AFeatureGenerator[0]);
}
}
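getFeatureCombinations above enumerates ensemble members breadth-first: starting from the full generator set, expandState produces every subset with exactly one generator removed, and the search stops once more than maxClassifierSize collections exist. A standalone sketch of the same expansion over plain strings (the cutoff check is approximated):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Standalone BFS over feature-set subsets, mirroring getFeatureCombinations
// with strings instead of AFeatureGenerator instances.
public class SubsetBfsSketch {
    static Set<List<String>> combinations(List<String> full, int maxSize) {
        Set<List<String>> all = new LinkedHashSet<>();
        List<List<String>> queue = new ArrayList<>();
        all.add(full);
        queue.add(full);
        while (!queue.isEmpty() && all.size() <= maxSize) {
            List<String> state = queue.remove(0);
            for (int i = 0; i < state.size(); i++) {
                List<String> next = new ArrayList<>(state);
                next.remove(i); // drop exactly one generator
                if (all.add(next)) {
                    queue.add(next);
                }
            }
        }
        return all;
    }

    public static void main(String[] args) {
        System.out.println(combinations(Arrays.asList("pos", "lemma", "shape"), 10));
        // [[pos, lemma, shape], [lemma, shape], [pos, shape], [pos, lemma], ...]
    }
}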
......@@ -12,4 +12,8 @@ public abstract class AFeatureGenerator {
public abstract String[] generateFeatures(AnnotationFS token);
public String getFeatureId() {
return this.featureIdentifier;
}
}
......@@ -9,6 +9,7 @@ import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.slf4j.Logger;
......@@ -25,7 +26,6 @@ import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
import de.uniwue.ls6.util.MatrixPoint;
import de.uniwue.ls6.util.MatrixUtil;
import no.uib.cipr.matrix.Matrices;
public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentationRuleLearningAlgorithm {
......@@ -33,11 +33,14 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
private int goldLabel;
private int otherLabel;
private int beamSize;
private int lookahead;
private int maxExpandSize = 8;
// dropOut: probability (0..1) of still keeping an instance even though a rule is applicable
private double dropOut = -1;
final Logger logger = LoggerFactory.getLogger(BinaryRepresentationRuleLearningAlgorithm.class);
private Map<Point, Set<Instance>> indexMap;
private Random randomSeed;
public BinaryRepresentationRuleLearningAlgorithm(int goldLabel, int otherLabel, int beamSize) {
passes = new LinkedList<RulePass>();
......@@ -114,6 +117,9 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
logger.info("Trainingset binary Labelaccuracy: " + labelAccuracy * 100 + "%");
logger.info("Trainingset binary Precision: " + prec * 100 + "%");
logger.info("Trainingset binary Recall: " + rec * 100 + "%");
// free the expensive object; this saves a lot of RAM
indexMap = null;
}
......@@ -197,10 +203,29 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
Set<Instance> remainingInstances = new HashSet<Instance>();
for (Instance ins : temporaryCopy) {
// if dropout is used
boolean dropOutInstance = false;
if (dropOut != -1) {
double nextDouble = randomSeed.nextDouble();
if (nextDouble < dropOut) {
dropOutInstance = true;
}
}
if (!learnedRule.isApplicable(ins)) {
remainingInstances.add(ins);
} else {
initialMatrix.removeInstance(ins);
// rule is applicable
if (dropOutInstance) {
// if the instance was selected for dropOut we still keep it
remainingInstances.add(ins);
// otherwise we drop it!
} else {
initialMatrix.removeInstance(ins);
}
}
}
return remainingInstances;
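With the new dropOut option, an instance covered by a freshly learned rule is kept with probability dropOut instead of always being removed, so later rules can still train on part of the covered data. A standalone sketch of that sampling step (assumed semantics; a plain predicate stands in for RepresentationRule.isApplicable):

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Predicate;

// Standalone sketch of the dropOut filtering above: covered items are kept
// with probability dropOut, uncovered items are always kept.
public class DropOutSketch {
    static <T> List<T> remaining(List<T> items, Predicate<T> ruleApplies, double dropOut, Random rnd) {
        List<T> kept = new ArrayList<>();
        for (T item : items) {
            boolean keepDespiteCoverage = dropOut != -1 && rnd.nextDouble() < dropOut;
            if (!ruleApplies.test(item) || keepDespiteCoverage) {
                kept.add(item);
            }
            // otherwise the item is dropped (the algorithm removes it from the matrix)
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Integer> data = List.of(1, 2, 3, 4, 5, 6, 7, 8);
        // the rule "covers" even numbers; roughly half of them are kept anyway
        System.out.println(remaining(data, n -> n % 2 == 0, 0.5, new Random(42)));
    }
}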
......@@ -222,7 +247,6 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
// initial values
int maximumScore = 0;
MatrixPoint maxEntryLocation = iterationMatrix.getLocationOfMaximum(mappings);
lookahead = 0;
while (true) {
maximumScore = iterationMatrix.getMaximumScore();
......@@ -243,8 +267,8 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
// expand in kronecker fashion
long time = System.currentTimeMillis();
iterationMatrix = MatrixUtil.performKroneckerExpansionByIndex(mappings, instances,indexMap,goldLabel);
System.out.println(System.currentTimeMillis() - time);
iterationMatrix = MatrixUtil.performKroneckerExpansionByIndex(mappings, instances, indexMap, goldLabel);
// System.out.println(System.currentTimeMillis() - time);
// assert that the maximum is growing
assert (maximumScore <= iterationMatrix.getMaximumScore()) : "Maximum decreased within iteration!";
......@@ -257,7 +281,7 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
// determine the best rule that has not been used before! TODO is this
// even necessary???
if (maxEntryLocation == null || maximumScore<=1)
if (maxEntryLocation == null || maximumScore <= 1)
return null;
// recalculate the rule
......@@ -275,6 +299,8 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
private boolean betterRuleCanBeLearned(List<MatrixMapping> mappings) {
if (mappings.size() > maxExpandSize)
return false;
if (mappings.size() < 2)
return true;
return betterRuleCanBeLearned(mappings.get(mappings.size() - 2), mappings.get(mappings.size() - 1));
......@@ -289,23 +315,23 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
return false;
}
// compare the feature sets "deeply", that is, we convert the features to
// instances and compare on instances
Set<Set<Instance>> lastMappingInstanceSets = new HashSet<>();
for (Set<Point> lastFeats : lastMapping.getDenseIndexToFeaturesMapping().values()) {
lastMappingInstanceSets.add(MatrixUtil.determineInstancesForFeatures(lastFeats, indexMap));
}
//perform deep equals
for(Set<Point> newFeats : newMapping.getDenseIndexToFeaturesMapping().values()){
if(!lastMappingInstanceSets.contains(MatrixUtil.determineInstancesForFeatures(newFeats, indexMap))){
// perform deep equals
for (Set<Point> newFeats : newMapping.getDenseIndexToFeaturesMapping().values()) {
if (!lastMappingInstanceSets.contains(MatrixUtil.determineInstancesForFeatures(newFeats, indexMap))) {
return true;
}
}
//System.out.println("better");
//System.out.println(lastMapping.getInverseMappingMap().size() + "\t" + newMapping.getInverseMappingMap().size());
// System.out.println("better");
// System.out.println(lastMapping.getInverseMappingMap().size() + "\t" +
// newMapping.getInverseMappingMap().size());
return false;
}
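betterRuleCanBeLearned compares the last two mappings "deeply": each feature combination is projected to the set of instances it covers, and only if the new mapping covers an instance set the previous one did not can a further Kronecker expansion pay off. A standalone sketch of that membership test over integer instance ids:

import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Standalone sketch of the deep-equality test: compare mappings by the
// instance sets their feature combinations cover, not by the features themselves.
public class DeepCompareSketch {
    static boolean introducesNewCoverage(List<Set<Integer>> lastCoverage, List<Set<Integer>> newCoverage) {
        Set<Set<Integer>> seen = new HashSet<>(lastCoverage);
        for (Set<Integer> covered : newCoverage) {
            if (!seen.contains(covered)) {
                return true; // a new instance set appeared, a better rule may exist
            }
        }
        return false; // every new combination covers an already-known instance set
    }

    public static void main(String[] args) {
        List<Set<Integer>> last = List.of(Set.of(1, 2), Set.of(3));
        List<Set<Integer>> fresh = List.of(Set.of(1, 2), Set.of(3), Set.of(1, 2)); // nothing new
        System.out.println(introducesNewCoverage(last, fresh)); // false
    }
}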
......@@ -313,7 +339,7 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
private boolean betterRuleCanBeLearned(int maximumScore, MatrixMcMatrixFace matrixInFocus) {
if (matrixInFocus.getMaximumScore() > maximumScore) {
lookahead = 0;
maxExpandSize = 0;
return true;
} else {
if (matrixInFocus.getMaximumScore() < maximumScore) {
......@@ -323,8 +349,8 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
// System.out.println("WHY!");
return false;
}
if (lookahead < 7) {
lookahead++;
if (maxExpandSize < 7) {
maxExpandSize++;
return true;
}
}
......@@ -352,6 +378,11 @@ public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentatio
return goldLabel;
}
public void setUseDropOut(Random random, double dropOutFactor) {
this.dropOut = dropOutFactor;
this.randomSeed = random;
}
public void setGoldLabel(int goldLabel) {
this.goldLabel = goldLabel;
}
......
package de.uniwue.ls6.rulelearning.algorithm.impl;
public enum EEnsembleClassificationMethod {
Probability, MajorityVote;
}
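EEnsembleClassificationMethod selects how the ensemble members' predictions are combined: MajorityVote takes the most frequent label, while Probability presumably weights votes by classifier confidence. A standalone sketch of the majority-vote variant over integer labels:

import java.util.HashMap;
import java.util.Map;

// Standalone majority-vote sketch: pick the label predicted by most ensemble members.
public class MajorityVoteSketch {
    static int majorityVote(int[] predictedLabels) {
        Map<Integer, Integer> counts = new HashMap<>();
        int best = predictedLabels[0];
        for (int label : predictedLabels) {
            int c = counts.merge(label, 1, Integer::sum);
            if (c > counts.getOrDefault(best, 0)) {
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(majorityVote(new int[] { 2, 1, 2, 3, 2 })); // 2
    }
}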
package de.uniwue.ls6.rulelearning.algorithm.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
import de.uniwue.ls6.rulelearning.instanceloading.featuregen.collection.FeatureGeneratorCollection;
public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleLearningAlgorithm {
List<MultiClassRepresentationRuleAlgorithm> classifier;
private FeatureGeneratorCollection collection;
// beam size of every classifier
private int beamSize;
private int maxClassifierSize = -1;
// dropOut: probability (0..1) of still keeping an instance even though a rule is applicable
private double dropOut = -1;
private Random randomSeed;
private EEnsembleClassificationMethod classificationMethod;
public FeatureDifferenceEnsembleClassifier(FeatureGeneratorCollection collection,
EEnsembleClassificationMethod method, int beamSize) {
this.collection = collection;
this.beamSize = beamSize;
classificationMethod = method;
classifier = new ArrayList<>();
}
public void setMaxClassifierSize(int maxClassifierSize) {
this.maxClassifierSize = maxClassifierSize;
}
@Override
public void learn(Instance... instances) {
// learn maxClassifierSize different multiclass classifiers
Set<FeatureGeneratorCollection> featureCollections = collection.getFeatureCombinations(maxClassifierSize);
for (FeatureGeneratorCollection coll : featureCollections) {
// change the instances to only contain the features in coll
Set<Instance> adjustedFeatureInstances = adjustInstances(instances, coll);
MultiClassRepresentationRuleAlgorithm algo = new MultiClassRepresentationRuleAlgorithm(beamSize);
if (dropOut != -1) {
algo.setUseDropOut(randomSeed, dropOut);
}
algo.learn(adjustedFeatureInstances.toArray(new Instance[0]));