Commit 29395dd3 authored by Markus Krug

small refactoring

parent 7120e22f
@@ -13,6 +13,8 @@ import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
import de.uniwue.ls6.rulelearning.ensemble.EEnsembleClassificationMethod;
import de.uniwue.ls6.rulelearning.ensemble.EnsembleClassificationMethodfactory;
import de.uniwue.ls6.rulelearning.instanceloading.featuregen.collection.FeatureGeneratorCollection;
public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleLearningAlgorithm {
@@ -47,8 +49,7 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
// learn maxClassifierSize different multiclass classifier
Set<FeatureGeneratorCollection> featureCollections = collection.getFeatureCombinations(maxClassifierSize);
for (FeatureGeneratorCollection coll : featureCollections) {
featureCollections.parallelStream().forEach((FeatureGeneratorCollection coll) -> {
// change the instances to only contain the features in coll
Set<Instance> adjustedFeatureInstances = adjustInstances(instances, coll);
@@ -59,8 +60,11 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
}
algo.learn(adjustedFeatureInstances.toArray(new Instance[0]));
classifier.add(algo);
}
synchronized (classifier) {
classifier.add(algo);
}
});
}
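The hunk above replaces the sequential for-loop over featureCollections with featureCollections.parallelStream().forEach(...) and wraps classifier.add(algo) in a synchronized block, since the shared classifier list is not thread-safe. A minimal, self-contained sketch of that pattern (the Learner class and the int[] training data are stand-ins, not types from this project):

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class ParallelTrainingSketch {

    // Stand-in for the project's learner type (hypothetical).
    static final class Learner {
        final String featureName;
        Learner(String featureName) { this.featureName = featureName; }
        void learn(int[] data) { /* training work would happen here */ }
    }

    public static void main(String[] args) {
        Set<String> featureCollections = Set.of("word", "ngram", "prefix");
        List<Learner> classifier = new ArrayList<>();

        // Same shape as the diff: train in parallel, but serialize access to the
        // shared, non-thread-safe ArrayList when adding the finished learner.
        featureCollections.parallelStream().forEach(coll -> {
            Learner algo = new Learner(coll);
            algo.learn(new int[] { 1, 2, 3 });
            synchronized (classifier) {
                classifier.add(algo);
            }
        });

        // Alternative without manual locking: collect() assembles the result list
        // thread-safely and keeps the lambda free of side effects.
        List<Learner> collected = featureCollections.parallelStream()
                .map(coll -> {
                    Learner algo = new Learner(coll);
                    algo.learn(new int[] { 1, 2, 3 });
                    return algo;
                })
                .collect(Collectors.toList());

        System.out.println(classifier.size() + " / " + collected.size());
    }
}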
@@ -80,8 +84,8 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
if (list.size() > max)
max = list.size();
}
//this is a clone
Instance clone = new Instance(i.getNrCols(), max,i.getLabel(), i.getId());
// this is a clone
Instance clone = new Instance(i.getNrCols(), max, i.getLabel(), i.getId());
// set the values
for (int col = 0; col < retainedFeatures.size(); col++) {
@@ -138,68 +142,9 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
labellings.add(apply);
}
if (classificationMethod == EEnsembleClassificationMethod.MajorityVote) {
return getMajorityLabel(labellings);
} else if (classificationMethod == EEnsembleClassificationMethod.Probability) {
return getMostProbableLabel(labellings);
}
// unsupported!!
return null;
}
private ALabelling getMostProbableLabel(List<ALabelling> labellings) {
Map<Integer, Double> labelCntMap = new HashMap<>();
for (ALabelling l : labellings) {
if (labelCntMap.containsKey(l.getLabel())) {
Double prob = labelCntMap.get(l.getLabel());
prob += l.getScore();
labelCntMap.put(l.getLabel(), prob);
} else {
labelCntMap.put(l.getLabel(), l.getScore());
}
}
double maxScore = 0.0d;
int bestLabel = -1;
for (Integer label : labelCntMap.keySet()) {
if (labelCntMap.get(label) > maxScore) {
maxScore = labelCntMap.get(label);
bestLabel = label;
}
}
return new SimpleLabelling(bestLabel, maxScore);
return EnsembleClassificationMethodfactory.createLabellingStrategy(classificationMethod).classify(labellings);
}
private ALabelling getMajorityLabel(List<ALabelling> labellings) {
Map<Integer, Integer> labelCntMap = new HashMap<>();
for (ALabelling l : labellings) {
if (labelCntMap.containsKey(l.getLabel())) {
Integer integer = labelCntMap.get(l.getLabel());
integer++;
labelCntMap.put(l.getLabel(), integer);
} else {
labelCntMap.put(l.getLabel(), 1);
}
}
int maxFreq = 0;
int bestLabel = -1;
for (Integer label : labelCntMap.keySet()) {
if (labelCntMap.get(label) > maxFreq) {
maxFreq = labelCntMap.get(label);
bestLabel = label;
}
}
return new SimpleLabelling(bestLabel, maxFreq);
}
public void setUseDropOut(Random random, double dropOutFactor) {
this.dropOut = dropOutFactor;
......
package de.uniwue.ls6.rulelearning.algorithm.impl;
package de.uniwue.ls6.rulelearning.ensemble;
public enum EEnsembleClassificationMethod {Probability,MajorityVote;
public enum EEnsembleClassificationMethod {Probability,MajorityVote,Max_Average_Probability,Max_Probability;
}
package de.uniwue.ls6.rulelearning.ensemble;
public class EnsembleClassificationMethodfactory {
public static IEnsembleLabellingStrategy createLabellingStrategy(EEnsembleClassificationMethod method) {
switch (method) {
case Probability:
return new MaxProbabilityLabellingStrategy();
case MajorityVote:
return new MajorityVoteLabellingStrategy();
case Max_Average_Probability:
return null;
default:
// Unsupported
return null;
}
}
}
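The factory above still returns null for Max_Average_Probability and falls through to the default (also null) for Max_Probability. A hedged sketch of what a Max_Average_Probability strategy might look like, reusing the ALabelling/SimpleLabelling API exactly as the committed strategies do; the class name and the averaging behaviour are assumptions, not part of this commit:

package de.uniwue.ls6.rulelearning.ensemble;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.SimpleLabelling;

// Hypothetical strategy: average the scores per label and return the best mean.
public class MaxAverageProbabilityLabellingStrategy implements IEnsembleLabellingStrategy {

    @Override
    public ALabelling classify(List<ALabelling> labellings) {
        Map<Integer, Double> scoreSum = new HashMap<>();
        Map<Integer, Integer> votes = new HashMap<>();
        for (ALabelling l : labellings) {
            scoreSum.merge(l.getLabel(), l.getScore(), Double::sum);
            votes.merge(l.getLabel(), 1, Integer::sum);
        }
        double bestAverage = 0.0d;
        int bestLabel = -1;
        for (Integer label : scoreSum.keySet()) {
            double average = scoreSum.get(label) / votes.get(label);
            if (average > bestAverage) {
                bestAverage = average;
                bestLabel = label;
            }
        }
        return new SimpleLabelling(bestLabel, bestAverage);
    }
}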
package de.uniwue.ls6.rulelearning.ensemble;
import java.util.Arrays;
import java.util.List;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
public class EnsembleClassifierApplication {
private List<IRepresentationRuleLearningAlgorithm> algorithms;
// this class can not train but only classify
public EnsembleClassifierApplication(IRepresentationRuleLearningAlgorithm... algos) {
algorithms = Arrays.asList(algos);
}
public ALabelling classify(EEnsembleClassificationMethod classificationMethod, Instance i) {
return null;
}
}
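classify(...) above is still a stub that returns null. One way it could be completed (not part of this commit) is to apply every wrapped algorithm and delegate the resulting labellings to the factory-created strategy; this sketch assumes IRepresentationRuleLearningAlgorithm exposes something like apply(Instance) returning an ALabelling, which this diff does not show:

// Hypothetical completion; assumes IRepresentationRuleLearningAlgorithm#apply(Instance)
// returns an ALabelling and that java.util.List/ArrayList are imported.
public ALabelling classify(EEnsembleClassificationMethod classificationMethod, Instance i) {
    List<ALabelling> labellings = new ArrayList<>();
    for (IRepresentationRuleLearningAlgorithm algo : algorithms) {
        labellings.add(algo.apply(i));
    }
    return EnsembleClassificationMethodfactory
            .createLabellingStrategy(classificationMethod)
            .classify(labellings);
}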
package de.uniwue.ls6.rulelearning.ensemble;
import java.util.List;
import de.uniwue.ls6.datastructure.ALabelling;
public interface IEnsembleLabellingStrategy {
public ALabelling classify(List<ALabelling> labellings);
}
package de.uniwue.ls6.rulelearning.ensemble;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.SimpleLabelling;
public class MajorityVoteLabellingStrategy implements IEnsembleLabellingStrategy{
@Override
public ALabelling classify(List<ALabelling> labellings) {
Map<Integer, Integer> labelCntMap = new HashMap<>();
for (ALabelling l : labellings) {
if (labelCntMap.containsKey(l.getLabel())) {
Integer integer = labelCntMap.get(l.getLabel());
integer++;
labelCntMap.put(l.getLabel(), integer);
} else {
labelCntMap.put(l.getLabel(), 1);
}
}
int maxFreq = 0;
int bestLabel = -1;
for (Integer label : labelCntMap.keySet()) {
if (labelCntMap.get(label) > maxFreq) {
maxFreq = labelCntMap.get(label);
bestLabel = label;
}
}
return new SimpleLabelling(bestLabel, maxFreq);
}
}
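The counting loop in MajorityVoteLabellingStrategy can also be written with Map.merge and a max-by-value lookup; a behaviour-equivalent sketch of the method body (except that an empty input throws instead of yielding label -1), not the committed code:

// Compact variant of the majority vote using Map.merge (sketch only).
Map<Integer, Integer> counts = new HashMap<>();
for (ALabelling l : labellings) {
    counts.merge(l.getLabel(), 1, Integer::sum);
}
Map.Entry<Integer, Integer> best = counts.entrySet().stream()
        .max(Map.Entry.comparingByValue())
        .orElseThrow(() -> new IllegalArgumentException("no labellings given"));
return new SimpleLabelling(best.getKey(), best.getValue());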
package de.uniwue.ls6.rulelearning.ensemble;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.SimpleLabelling;
public class MaxProbabilityLabellingStrategy implements IEnsembleLabellingStrategy{
@Override
public ALabelling classify(List<ALabelling> labellings) {
Map<Integer, Double> labelCntMap = new HashMap<>();
for (ALabelling l : labellings) {
if (labelCntMap.containsKey(l.getLabel())) {
Double prob = labelCntMap.get(l.getLabel());
prob += l.getScore();
labelCntMap.put(l.getLabel(), prob);
} else {
labelCntMap.put(l.getLabel(), l.getScore());
}
}
double maxScore = 0.0d;
int bestLabel = -1;
for (Integer label : labelCntMap.keySet()) {
if (labelCntMap.get(label) > maxScore) {
maxScore = labelCntMap.get(label);
bestLabel = label;
}
}
return new SimpleLabelling(bestLabel, maxScore);
}
}
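A short usage sketch of the extracted strategies via the factory, mirroring the one-line delegation that now replaces the two private helper methods in FeatureDifferenceEnsembleClassifier; the SimpleLabelling(label, score) instances below are made-up inputs for illustration:

// Hypothetical usage of the refactored ensemble strategies.
List<ALabelling> labellings = Arrays.asList(
        new SimpleLabelling(1, 0.4), new SimpleLabelling(2, 0.9), new SimpleLabelling(1, 0.3));

ALabelling byVote = EnsembleClassificationMethodfactory
        .createLabellingStrategy(EEnsembleClassificationMethod.MajorityVote)
        .classify(labellings);   // label 1 wins with 2 of 3 votes

ALabelling byScore = EnsembleClassificationMethodfactory
        .createLabellingStrategy(EEnsembleClassificationMethod.Probability)
        .classify(labellings);   // label 2 wins with summed score 0.9 vs. 0.7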
@@ -32,15 +32,16 @@ public class FirstTest {
File doc2 = new File("resources\\Ahlefeld,-Charlotte-von_Erna1421[Lukas].xmi.xmi");
File typesystem = new File("resources\\MiKalliTypesystem.xml");
File bigDoc = new File("C:\\Users\\mkrug\\annoTest\\TestProject\\output\\hp5_utf8.txt.xmi");
File conllTrainXmi = new File("C:\\sandbox-Markus\\ner\\testConll\\conlldeu.train.xmi");
File korpusFOlder = new File("X:\\Neuer Ordner\\output+speech");
MultiClassRepresentationRuleAlgorithm algorithm = new MultiClassRepresentationRuleAlgorithm(1000);
algorithm.setUseDropOut(new Random(42111337), 0.9);
MultiClassRepresentationRuleAlgorithm algorithm = new MultiClassRepresentationRuleAlgorithm(100);
algorithm.setUseDropOut(new Random(42111337), 0.7);
TypeSystemDescription tsd = TypeSystemDescriptionFactory
.createTypeSystemDescriptionFromPath(typesystem.toURL().toString());
List<Instance> instances = InstanceCreationFactory.createWindowedInstancesFromUIMA(document, 0,0, 0,
List<Instance> instances = InstanceCreationFactory.createWindowedInstancesFromUIMA(conllTrainXmi, 0,1, 1,
"de.uniwue.kalimachos.coref.type.POS", tsd, new POSTagFeatureGenerator("POSTag"),
new WordFeaturegenerator(), new IsUppercaseFeatureGenerator(), new PrefixNGenerator(3),
new NGramGenerator(),new WordCategorization());
......
@@ -11,9 +11,9 @@ import org.apache.uima.resource.metadata.TypeSystemDescription;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.impl.EEnsembleClassificationMethod;
import de.uniwue.ls6.rulelearning.algorithm.impl.FeatureDifferenceEnsembleClassifier;
import de.uniwue.ls6.rulelearning.algorithm.impl.MultiClassRepresentationRuleAlgorithm;
import de.uniwue.ls6.rulelearning.ensemble.EEnsembleClassificationMethod;
import de.uniwue.ls6.rulelearning.evaluation.eval.LabelAccuracyEvaluation;
import de.uniwue.ls6.rulelearning.evaluation.fold.FoldUtil;
import de.uniwue.ls6.rulelearning.evaluation.fold.UnstructuredFold;
@@ -35,17 +35,18 @@ public class FirstTestEnsemble {
File doc2 = new File("resources\\Ahlefeld,-Charlotte-von_Erna1421[Lukas].xmi.xmi");
File typesystem = new File("resources\\MiKalliTypesystem.xml");
File bigDoc = new File("C:\\Users\\mkrug\\annoTest\\TestProject\\output\\hp5_utf8.txt.xmi");
File conllTrainXmi = new File("C:\\sandbox-Markus\\ner\\testConll\\conlldeu.train.xmi");
File korpusFOlder = new File("X:\\Neuer Ordner\\output+speech");
FeatureGeneratorCollection collection = createCollection();
FeatureDifferenceEnsembleClassifier ensemble = new FeatureDifferenceEnsembleClassifier(collection, EEnsembleClassificationMethod.Probability, 100);
ensemble.setMaxClassifierSize(50);
ensemble.setUseDropOut(new Random(42111337), 0.9);
FeatureDifferenceEnsembleClassifier ensemble = new FeatureDifferenceEnsembleClassifier(collection,
EEnsembleClassificationMethod.Probability, 100);
ensemble.setMaxClassifierSize(40);
ensemble.setUseDropOut(new Random(42111337), 0.8);
TypeSystemDescription tsd = TypeSystemDescriptionFactory
.createTypeSystemDescriptionFromPath(typesystem.toURL().toString());
List<Instance> instances = InstanceCreationFactory.createWindowedInstancesFromUIMA(document, 0,2, 2,
List<Instance> instances = InstanceCreationFactory.createWindowedInstancesFromUIMA(conllTrainXmi, 0, 1, 1,
"de.uniwue.kalimachos.coref.type.POS", tsd, new POSTagFeatureGenerator("POSTag"),
collection.getGenerators());
@@ -58,7 +59,6 @@ public class FirstTestEnsemble {
// evaluate
List<ALabelling> goldLabels = new ArrayList<>();
List<ALabelling> systemLabels = new ArrayList<>();
@@ -69,9 +69,9 @@ public class FirstTestEnsemble {
// evaluate this fold
System.out.println(ensemble.getClassifier().size()+"==");
String evaluateToString = new LabelAccuracyEvaluation().evaluateToString(goldLabels.toArray(new ALabelling[0]),
systemLabels.toArray(new ALabelling[0]));
System.out.println(ensemble.getClassifier().size() + "==");
String evaluateToString = new LabelAccuracyEvaluation()
.evaluateToString(goldLabels.toArray(new ALabelling[0]), systemLabels.toArray(new ALabelling[0]));
System.out.println(evaluateToString);
break;
}
@@ -80,14 +80,13 @@ public class FirstTestEnsemble {
private static FeatureGeneratorCollection createCollection() {
FeatureGeneratorCollection collection = new FeatureGeneratorCollection();
collection.addFeatureGenerator(new WordFeaturegenerator());
collection.addFeatureGenerator( new IsUppercaseFeatureGenerator());
collection.addFeatureGenerator( new PrefixNGenerator(3));
collection.addFeatureGenerator(new IsUppercaseFeatureGenerator());
collection.addFeatureGenerator(new PrefixNGenerator(3));
collection.addFeatureGenerator(new NGramGenerator());
collection.addFeatureGenerator(new WordCategorization());
return collection;
}
......