Commit 17bb7a2b authored by Markus Krug's avatar Markus Krug
Browse files

started the creation of the multiclass classifier

parent 87d4a58d
......@@ -169,6 +169,16 @@ public class RepresentationRule {
/**
 * Returns the precision stored for this rule.
 *
 * @return the current value of the {@code precision} field
 */
public double getPrecision() {
    return this.precision;
}
/**
 * Returns the label stored for this rule.
 *
 * @return the current value of the {@code label} field
 */
public Integer getLabel() {
    return this.label;
}
/**
 * Sets the label stored for this rule.
 *
 * @param label the new value of the {@code label} field
 */
public void setLabel(Integer label) {
this.label = label;
}
@Override
public int hashCode() {
......
......@@ -24,13 +24,18 @@ public class RulePass {
*/
private int label;
/**
 * Creates a rule pass for the given label with an initially empty rule
 * list.
 *
 * @param label the label stored in this pass
 */
public RulePass(int label) {
    // the superclass constructor is invoked implicitly
    this.label = label;
    this.ruleSet = new ArrayList<>();
}
/**
 * Default constructor, only creates an empty rule list. The {@code label}
 * field is not assigned here and keeps its default value.
 */
public RulePass() {
    super();
    ruleSet = new ArrayList<>();
}
/**
......
package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
......@@ -169,8 +170,7 @@ public class Instance {
int kroneckerCol = colE1 + e2.column();
int kroneckerRow = rowE1 + e2.row();
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
if (e1.row() == e2.row() && e1.column() == e2.column())
continue outer;
......@@ -189,6 +189,80 @@ public class Instance {
return expanedMatrix;
}
/**
 * Performs the Kronecker expansion of this instance; should be executed in
 * parallel for many instances.
 *
 * @param mappings the available mappings; only the last element is used
 * @return a matrix whose side length is the square of the dense dimension
 *         of the last mapping
 */
public FlexCompColMatrix expand2(List<MatrixMapping> mappings) {
    MatrixMapping newest = mappings.get(mappings.size() - 1);
    int dim = newest.getDenseMatrixDimension();

    // dense matrix for this instance
    FlexCompColMatrix dense = new FlexCompColMatrix(dim, dim);
    // dense indices of the contained feature sets, grouped by set size
    Map<Integer, List<Point>> pointsBySize = new HashMap<>();
    int largestSize = 0;

    for (Point index : newest.getInverseMappingMap().keySet()) {
        Set<Point> featureSet = newest.getFeaturesForDenseIndex(index);
        // only feature sets contained in this instance are of interest
        if (!containsFeature(featureSet)) {
            continue;
        }
        // put a 1 into the according slot; x is always the column
        dense.add(index.y, index.x, 1);

        int size = featureSet.size();
        largestSize = Math.max(largestSize, size);

        List<Point> bucket = pointsBySize.get(size);
        if (bucket == null) {
            bucket = new ArrayList<>();
            pointsBySize.put(size, bucket);
        }
        bucket.add(index);
    }

    // perform the Kronecker expansion from the entries of the dense matrix
    int expandedDim = dim * dim;
    FlexCompColMatrix expanded = new FlexCompColMatrix(expandedDim, expandedDim);

    // keep the product of every entry with itself
    for (MatrixEntry entry : dense) {
        int row = entry.row() * dim + entry.row();
        int col = entry.column() * dim + entry.column();
        expanded.set(row, col, 1);
    }

    // aside from those, only combine size groups whose sizes add up to more
    // than the largest feature set seen above
    for (int outerSize = largestSize; outerSize > 0; outerSize--) {
        List<Point> outerPoints = pointsBySize.get(outerSize);
        if (outerPoints == null) {
            continue;
        }
        // the inner size runs downwards and stops as soon as the sum of
        // both sizes no longer exceeds the largest size
        for (int innerSize = outerSize; innerSize > 0 && innerSize + outerSize > largestSize; innerSize--) {
            List<Point> innerPoints = pointsBySize.get(innerSize);
            if (innerPoints == null) {
                continue;
            }
            for (Point outer : outerPoints) {
                int rowOffset = outer.y * dim;
                int colOffset = outer.x * dim;
                for (Point inner : innerPoints) {
                    // indices of the combined entry in the expanded matrix
                    expanded.set(rowOffset + inner.y, colOffset + inner.x, 1);
                }
            }
        }
    }
    return expanded;
}
@Override
public int hashCode() {
final int prime = 31;
......
......@@ -170,6 +170,16 @@ public class MatrixMcMatrixFace {
}
return sb.toString();
}
/**
 * Returns the gold label of this matrix.
 *
 * @return the current value of the {@code goldLabel} field
 */
public int getGoldLabel() {
    return this.goldLabel;
}
/**
 * Sets the gold label of this matrix.
 *
 * @param goldLabel the new value of the {@code goldLabel} field
 */
public void setGoldLabel(int goldLabel) {
this.goldLabel = goldLabel;
}
@Override
public int hashCode() {
......
package de.uniwue.ls6.datastructure;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Stores, for every {@link MatrixMcMatrixFace}, the list of
 * {@link MatrixMapping}s that belong to it.
 */
public class MultiClassMapping {

    // saves the correspondences between matrices and a list of mappings
    private Map<MatrixMcMatrixFace, List<MatrixMapping>> map;

    /** Creates an empty mapping. */
    public MultiClassMapping() {
        map = new HashMap<>();
    }

    /**
     * Appends a mapping to the list registered for the given matrix; the
     * list is created on first use.
     *
     * @param matrix  the matrix the mapping belongs to
     * @param mapping the mapping to append
     */
    public void addMapping(MatrixMcMatrixFace matrix, MatrixMapping mapping) {
        mappingsFor(matrix).add(mapping);
    }

    /**
     * Returns the mappings registered for the given matrix.
     *
     * @param matrix the matrix to look up
     * @return the stored list, or a new empty list (which is not stored)
     *         when nothing is registered for the matrix
     */
    public List<MatrixMapping> getMatrixMapping(MatrixMcMatrixFace matrix) {
        List<MatrixMapping> list = map.get(matrix);
        if (list == null) {
            return new ArrayList<>();
        }
        return list;
    }

    /**
     * Appends all mappings of the other object to this one. Matrices that
     * are not yet known here get a new list instead of causing a
     * NullPointerException.
     *
     * @param mappingForMaximum the mappings to merge into this object
     */
    public void add(MultiClassMapping mappingForMaximum) {
        for (Map.Entry<MatrixMcMatrixFace, List<MatrixMapping>> entry : mappingForMaximum.map.entrySet()) {
            mappingsFor(entry.getKey()).addAll(entry.getValue());
        }
    }

    /**
     * Removes all mappings that are contained in the other object. Matrices
     * that are not known here are skipped, since there is nothing to remove
     * for them.
     *
     * @param mappingForMaximum the mappings to remove from this object
     */
    public void remove(MultiClassMapping mappingForMaximum) {
        for (Map.Entry<MatrixMcMatrixFace, List<MatrixMapping>> entry : mappingForMaximum.map.entrySet()) {
            List<MatrixMapping> own = map.get(entry.getKey());
            if (own != null) {
                own.removeAll(entry.getValue());
            }
        }
    }

    /**
     * Returns the size of the longest mapping list.
     *
     * @return the largest list size, or 0 when no list is stored
     */
    public int maxSize() {
        int maxSize = 0;
        for (List<MatrixMapping> mappings : map.values()) {
            if (mappings.size() > maxSize) {
                maxSize = mappings.size();
            }
        }
        return maxSize;
    }

    // returns the list stored for the matrix, registering a new one if absent
    private List<MatrixMapping> mappingsFor(MatrixMcMatrixFace matrix) {
        List<MatrixMapping> matrixMappings = map.get(matrix);
        if (matrixMappings == null) {
            matrixMappings = new ArrayList<>();
            map.put(matrix, matrixMappings);
        }
        return matrixMappings;
    }
}
package de.uniwue.ls6.datastructure;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import de.uniwue.ls6.util.MatrixPoint;
import no.uib.cipr.matrix.MatrixEntry;
public class MultiClassMatrixMcMatrixFace {
private List<MatrixMcMatrixFace> matrices;
private int windowSize;
public MultiClassMatrixMcMatrixFace(int amountFeatures, int windowSize, Collection<Integer> labels) {
this.windowSize = windowSize;
matrices = new ArrayList<>();
for (Integer label : labels) {
matrices.add(new MatrixMcMatrixFace(amountFeatures, windowSize, label));
}
}
public void addInstances(Instance... instances) {
public MultiClassMatrixMcMatrixFace() {
for (MatrixMcMatrixFace matrix : matrices) {
matrix.addInstance(instances);
}
}
public void removeInstance(Instance i) {
for (MatrixMcMatrixFace matrix : matrices) {
matrix.removeInstance(i);
}
}
public int getMaximumScore() {
double maxScore = Double.MIN_VALUE;
for (MatrixMcMatrixFace matrix : matrices) {
double scoreMa = matrix.getMaximumScore();
if (scoreMa > maxScore)
maxScore = scoreMa;
}
return (int) maxScore;
}
public int getWindowSize() {
return windowSize;
}
public void setWindowSize(int windowSize) {
this.windowSize = windowSize;
}
public MatrixPoint getLocationOfMaximum(MultiClassMapping mapping) {
double maxScore = Integer.MIN_VALUE;
MatrixMcMatrixFace matrixFace = null;
MatrixPoint bestPoint = null;
for(MatrixMcMatrixFace matrix : matrices){
MatrixPoint locationOfMaximum = matrix.getLocationOfMaximum(mapping.getMatrixMapping(matrix));
if(locationOfMaximum.getScore()>maxScore){
maxScore= locationOfMaximum.getScore();
matrixFace = matrix;
bestPoint=locationOfMaximum;
}
}
bestPoint.setAccordingMatrix(matrixFace);
return bestPoint;
}
}
......@@ -2,6 +2,8 @@ package de.uniwue.ls6.util;
import java.awt.Point;
import de.uniwue.ls6.datastructure.MatrixMcMatrixFace;
public class MatrixPoint {
private int x;
......@@ -9,6 +11,7 @@ public class MatrixPoint {
private double score;
private double fp;
private double tp;
private MatrixMcMatrixFace accordingMatrix;
public MatrixPoint(int x, int y, double score, double tp, double fp) {
super();
......@@ -34,6 +37,16 @@ public class MatrixPoint {
/**
 * Sets the y value of this point.
 *
 * @param y the new value of the {@code y} field
 */
public void setY(int y) {
this.y = y;
}
/**
 * Returns the matrix attached to this point.
 *
 * @return the current value of the {@code accordingMatrix} field
 */
public MatrixMcMatrixFace getAccordingMatrix() {
    return this.accordingMatrix;
}
/**
 * Attaches a matrix to this point.
 *
 * @param accordingMatrix the new value of the {@code accordingMatrix} field
 */
public void setAccordingMatrix(MatrixMcMatrixFace accordingMatrix) {
this.accordingMatrix = accordingMatrix;
}
public double getScore() {
return score;
......
package de.uniwue.ls6.util;
import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.MatrixMapping;
import de.uniwue.ls6.datastructure.MatrixMcMatrixFace;
import de.uniwue.ls6.datastructure.MultiClassMapping;
import de.uniwue.ls6.datastructure.MultiClassMatrixMcMatrixFace;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
/**
 * Static helpers for the multi-class matrices. Both methods are still
 * unimplemented stubs.
 */
public class MulticlassMatrixUtil {
static final Logger logger = LoggerFactory.getLogger(MulticlassMatrixUtil.class);
/**
 * Not implemented yet: ignores all arguments and always returns
 * {@code null}.
 */
public static MultiClassMapping getMappingForMaximum(MultiClassMatrixMcMatrixFace iterationMatrix, int maximum, Point maxEntryLocation,
MultiClassMapping multiClassMapping, Set<Instance> instances, Map<Point, Set<Instance>> indexMap, int beamSize) {
// TODO
return null;
}
/**
 * Not implemented yet: ignores all arguments and always returns
 * {@code null}.
 */
public static MultiClassMatrixMcMatrixFace performKroneckerExpansion(MultiClassMapping multiClassMapping,
Collection<Instance> instances) {
// TODO
return null;
}
}
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5">
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.7">
<attributes>
<attribute name="maven.pomderived" value="true"/>
</attributes>
......
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.5
org.eclipse.jdt.core.compiler.compliance=1.5
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.7
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.7
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.source=1.5
org.eclipse.jdt.core.compiler.source=1.7
......@@ -4,7 +4,9 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.uima.cas.text.AnnotationFS;
......@@ -38,8 +40,14 @@ public class InListFeatureGenerator extends AFeatureGenerator {
/**
 * Generates the list-membership features for a token.
 *
 * @param token the token whose covered text is looked up in the list
 * @return a single {@code =NOT_IN_LIST} feature when the covered text is
 *         not a list entry; otherwise three features: one carrying the
 *         first two characters of the text, one carrying the first
 *         character, and the plain {@code =IN_LIST} feature
 */
public String[] generateFeatures(AnnotationFS token) {
    String text = token.getCoveredText();
    if (!listEntries.contains(text)) {
        return new String[]{super.featureIdentifier+"=NOT_IN_LIST" };
    }
    List<String> list = new ArrayList<>();
    // prefix of (at most) two characters
    String id2 = text.substring(0, Math.min(2, text.length()));
    list.add(super.featureIdentifier+id2+"IN_LIST");
    // prefix of (at most) one character; previously this repeated the
    // two-character prefix and therefore duplicated the feature above
    String id1 = text.substring(0, Math.min(1, text.length()));
    list.add(super.featureIdentifier+id1+"IN_LIST");
    String idFull = super.featureIdentifier+"=IN_LIST";
    list.add(idFull);
    return list.toArray(new String[0]);
}
......
......@@ -19,7 +19,7 @@ import de.uniwue.ls6.rulelearning.instanceloading.featuregen.collection.FeatureG
public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleLearningAlgorithm {
List<MultiClassRepresentationRuleAlgorithm> classifier;
List<MultiClassOneVsAllRepresentationRuleAlgorithm> classifier;
private FeatureGeneratorCollection collection;
// beam size of every classifier
......@@ -54,7 +54,7 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
// change the instances to only contain the features in coll
Set<Instance> adjustedFeatureInstances = adjustInstances(instances, coll);
MultiClassRepresentationRuleAlgorithm algo = new MultiClassRepresentationRuleAlgorithm(beamSize);
MultiClassOneVsAllRepresentationRuleAlgorithm algo = new MultiClassOneVsAllRepresentationRuleAlgorithm(beamSize);
if (dropOut != -1) {
algo.setUseDropOut(randomSeed, dropOut);
}
......@@ -136,7 +136,7 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
public ALabelling apply(Instance instanceToClassify) {
List<ALabelling> labellings = new ArrayList<>();
for (MultiClassRepresentationRuleAlgorithm algo : classifier) {
for (MultiClassOneVsAllRepresentationRuleAlgorithm algo : classifier) {
ALabelling apply = algo.apply(instanceToClassify);
labellings.add(apply);
......@@ -151,7 +151,7 @@ public class FeatureDifferenceEnsembleClassifier implements IRepresentationRuleL
this.randomSeed = random;
}
public List<MultiClassRepresentationRuleAlgorithm> getClassifier() {
public List<MultiClassOneVsAllRepresentationRuleAlgorithm> getClassifier() {
return this.classifier;
}
......
......@@ -17,9 +17,9 @@ import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
public class MultiClassRepresentationRuleAlgorithm implements IRepresentationRuleLearningAlgorithm {
public class MultiClassOneVsAllRepresentationRuleAlgorithm implements IRepresentationRuleLearningAlgorithm {
final Logger logger = LoggerFactory.getLogger(MultiClassRepresentationRuleAlgorithm.class);
final Logger logger = LoggerFactory.getLogger(MultiClassOneVsAllRepresentationRuleAlgorithm.class);
List<BinaryRepresentationRuleLearningAlgorithm> binaryClassifiers;
private int beamSize;
......@@ -31,7 +31,7 @@ public class MultiClassRepresentationRuleAlgorithm implements IRepresentationRul
private Map<Instance, Integer> instanceToLabelMapping;
private Random randomSeed;
/**
 * Creates the algorithm with the given beam size and an empty list of
 * labels whose classifier training is skipped.
 *
 * @param beamSize the beam size stored in this algorithm
 */
public MultiClassOneVsAllRepresentationRuleAlgorithm(int beamSize) {
    this.beamSize = beamSize;
    this.skipClassifierTrainingForLabel = new ArrayList<>();
}
......
package de.uniwue.ls6.rulelearning.algorithm.impl;
import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.algorithm.datastructure.RulePass;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.MatrixMapping;
import de.uniwue.ls6.datastructure.MatrixMcMatrixFace;
import de.uniwue.ls6.datastructure.MultiClassMapping;
import de.uniwue.ls6.datastructure.MultiClassMatrixMcMatrixFace;
import de.uniwue.ls6.datastructure.SimpleLabelling;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
import de.uniwue.ls6.util.MatrixPoint;
import de.uniwue.ls6.util.MatrixUtil;
import de.uniwue.ls6.util.MulticlassMatrixUtil;
public class MultiClassRuleLearningAlgorithm implements IRepresentationRuleLearningAlgorithm {
List<RulePass> passes;
private int beamSize;
private int maxExpandSize = 8;
// dropOut: share (in %) of instances that are still kept even if the rule is applicable
private double dropOut = -1;
final Logger logger = LoggerFactory.getLogger(MultiClassRuleLearningAlgorithm.class);
private Map<Point, Set<Instance>> indexMap;
private Random randomSeed;
/**
 * Creates the algorithm with the given beam size and an empty list of rule
 * passes.
 *
 * @param beamSize the beam size stored in this algorithm
 */
public MultiClassRuleLearningAlgorithm(int beamSize) {
    this.beamSize = beamSize;
    this.passes = new LinkedList<>();
}
public void learn(Instance... instances) {
if (instances.length == 0)
throw new IllegalArgumentException("Plz give data to train!");
logger.info("Start to index the instances");
indexInstances(instances);
Set<I