Commit f286719e authored by Markus Krug's avatar Markus Krug
Browse files

Merge branch 'ruleLearningDropout' into 'master'

Rule learning dropout



See merge request !2
parents 547f3cbb f2dcc9c0
......@@ -169,6 +169,47 @@ public class RepresentationRule {
/**
 * Returns the value currently held in this rule's {@code precision} field.
 *
 * @return the stored precision
 */
public double getPrecision() {
    return this.precision;
}
/**
 * Builds the hash from the same four fields that {@link #equals(Object)} compares:
 * {@code conditionSet}, {@code label}, {@code maximumScore} and {@code precision}.
 * A {@code null} field contributes 0.
 *
 * @return the combined hash code
 */
@Override
public int hashCode() {
    int conditionHash = (conditionSet != null) ? conditionSet.hashCode() : 0;
    int labelHash = (label != null) ? label.hashCode() : 0;
    // fold the 64 bits of the double into 32 bits
    long precisionBits = Double.doubleToLongBits(precision);
    int precisionHash = (int) (precisionBits ^ (precisionBits >>> 32));

    int hash = 31 + conditionHash;
    hash = 31 * hash + labelHash;
    hash = 31 * hash + maximumScore;
    hash = 31 * hash + precisionHash;
    return hash;
}
/**
 * Two rules are equal when they are of the same runtime class and agree on
 * {@code conditionSet}, {@code label}, {@code maximumScore} and {@code precision}
 * (the latter compared by its bit pattern).
 *
 * @param obj the object to compare with, may be {@code null}
 * @return {@code true} if {@code obj} is an equal rule
 */
@Override
public boolean equals(Object obj) {
    if (obj == this) {
        return true;
    }
    if (obj == null || obj.getClass() != getClass()) {
        return false;
    }
    RepresentationRule that = (RepresentationRule) obj;

    boolean sameConditions = (conditionSet == null)
            ? that.conditionSet == null
            : conditionSet.equals(that.conditionSet);
    if (!sameConditions) {
        return false;
    }

    boolean sameLabel = (label == null)
            ? that.label == null
            : label.equals(that.label);
    if (!sameLabel) {
        return false;
    }

    return maximumScore == that.maximumScore
            && Double.doubleToLongBits(precision) == Double.doubleToLongBits(that.precision);
}
......
......@@ -39,7 +39,7 @@ public class RulePass {
* @param ruleSet
* List of Rules, ordered by score
*/
public RulePass(List<RepresentationRule> ruleSet,int label) {
public RulePass(List<RepresentationRule> ruleSet, int label) {
super();
this.ruleSet = ruleSet;
this.label = label;
......@@ -52,7 +52,9 @@ public class RulePass {
* rule to be added to the pass
*/
public void addRule(RepresentationRule rule) {
    // Only append the rule when no equal rule (as defined by RepresentationRule.equals)
    // is already part of this pass. The check has to run BEFORE the add; adding first
    // makes the rule always "contained" and lets duplicates through.
    if (!this.ruleSet.contains(rule)) {
        this.ruleSet.add(rule);
    }
}
/**
......@@ -88,8 +90,8 @@ public class RulePass {
}
return false;
}
public RepresentationRule apply(Instance ins){
public RepresentationRule apply(Instance ins) {
for (RepresentationRule rule : ruleSet) {
if (rule.isApplicable(ins))
return rule;
......@@ -104,7 +106,5 @@ public class RulePass {
/**
 * Replaces the label stored for this pass.
 *
 * @param label the new label value
 */
public void setLabel(int label) {
    this.label = label;
}
}
......@@ -7,9 +7,11 @@ public abstract class ALabelling {
protected int label;
protected double score;
protected String stringLabel;
/**
 * Creates a labelling from a numeric label and its score. The string form of the label is
 * obtained from {@code LabelAlphabet.getFeatureToId(label)}.
 *
 * @param label the numeric label
 * @param score the score attached to the label
 */
public ALabelling(int label, double score) {
    String resolvedLabel = LabelAlphabet.getFeatureToId(label);
    this.label = label;
    this.score = score;
    this.stringLabel = resolvedLabel;
}
public int getLabel() {
return label;
......@@ -23,6 +25,13 @@ public abstract class ALabelling {
/**
 * Replaces the score stored for this labelling.
 *
 * @param score the new score
 */
public void setScore(double score) {
    this.score = score;
}
/**
 * Returns the string form of the label as currently stored in {@code stringLabel}.
 *
 * @return the stored string label
 */
public String getStringLabel() {
    return this.stringLabel;
}
/**
 * Overwrites the stored string label. The numeric {@code label} field is not touched.
 *
 * @param stringLabel the new string label
 */
public void setStringLabel(String stringLabel) {
    this.stringLabel = stringLabel;
}
......
package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentHashMap.KeySetView;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.MatrixEntry;
......@@ -155,16 +160,17 @@ public class Instance {
FlexCompColMatrix expanedMatrix = new FlexCompColMatrix(kroneckerDimension, kroneckerDimension);
outer: for (MatrixEntry e1 : denseInstanceMatrix) {
int colE1 = e1.column() * denseDimension;
int rowE1 = e1.row() * denseDimension;
// iterate over all non-sparse elements and expand
for (MatrixEntry e2 : denseInstanceMatrix) {
// // skip diagonal expansion TODO ??
// calculate the new indices
int kroneckerCol = e1.column() * denseDimension + e2.column();
int kroneckerRow = e1.row() * denseDimension + e2.row();
int kroneckerCol = colE1 + e2.column();
int kroneckerRow = rowE1 + e2.row();
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
if (e1.row() == e2.row() && e1.column() == e2.column())
continue outer;
......
......@@ -2,6 +2,7 @@ package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
......@@ -12,7 +13,8 @@ public class MatrixMapping {
// x is col and y is row
HashMap<Point, Point> mappingMap;
HashMap<Point, Point> inverseMappingMap;
HashMap<Point,Set<Point>> denseIndexToFeaturesMapping;
HashMap<Point, Set<Point>> denseIndexToFeaturesMapping;
HashSet<Point> forbiddenIndexes;
//
private int kroneckerDimension;
......@@ -30,6 +32,7 @@ public class MatrixMapping {
this.inverseMappingMap = inverseMap;
this.kroneckerDimension = kroneckerDimension;
this.denseIndexToFeaturesMapping = new HashMap<>();
this.forbiddenIndexes = new HashSet<>();
}
public MatrixMapping(int kroneckerDimension) {
......@@ -38,6 +41,7 @@ public class MatrixMapping {
this.inverseMappingMap = new HashMap<Point, Point>();
this.kroneckerDimension = kroneckerDimension;
this.denseIndexToFeaturesMapping = new HashMap<>();
this.forbiddenIndexes = new HashSet<>();
}
public HashMap<Point, Point> getMappingMap() {
......@@ -68,6 +72,10 @@ public class MatrixMapping {
this.mappingMap.put(key, null);
}
/**
 * Tells whether the given point is contained in the set of forbidden indexes.
 *
 * @param p the index to look up
 * @return {@code true} if {@code p} is in {@code forbiddenIndexes}
 */
public boolean isForbiddenIndex(Point p) {
    boolean forbidden = forbiddenIndexes.contains(p);
    return forbidden;
}
// this method generates all values based on the keys
public void inferDenseMapValues(List<MatrixMapping> mappings) {
int numCols = (int) Math.ceil(Math.sqrt(mappingMap.keySet().size()));
......@@ -78,10 +86,10 @@ public class MatrixMapping {
inverseMappingMap.put(value, key);
index++;
}
//also infer the features
for(Point p : inverseMappingMap.keySet()){
denseIndexToFeaturesMapping.put(p,MatrixUtil.determineFeaturesForIndex(p, mappings,false));
// also infer the features
for (Point p : inverseMappingMap.keySet()) {
denseIndexToFeaturesMapping.put(p, MatrixUtil.determineFeaturesForIndex(p, mappings, false));
}
}
......@@ -98,8 +106,8 @@ public class MatrixMapping {
/**
 * Returns the Kronecker dimension this mapping was constructed with.
 *
 * @return the value of {@code kroneckerDimension}
 */
public int getKroneckerMatrixDimension() {
    return this.kroneckerDimension;
}
public Set<Point> getFeaturesForDenseIndex(Point densePoint){
public Set<Point> getFeaturesForDenseIndex(Point densePoint) {
return denseIndexToFeaturesMapping.get(densePoint);
}
......@@ -121,4 +129,9 @@ public class MatrixMapping {
return sb.toString();
}
/**
 * Registers a point in the set of forbidden indexes; adding the same point again has no
 * further effect because the indexes are kept in a set.
 *
 * @param point the index to mark as forbidden
 */
public void addForbiddenIndex(Point point) {
    forbiddenIndexes.add(point);
}
}
......@@ -85,7 +85,7 @@ public class MatrixMcMatrixFace {
*/
public int getMaximumScore() {
double maxScore = 0;
double maxScore = Double.MIN_VALUE;
for (MatrixEntry entry : tpMatrix) {
double scoreCurrent = entry.get() - fpMatrix.get(entry.row(), entry.column());
......@@ -109,8 +109,10 @@ public class MatrixMcMatrixFace {
}
// TODO this needs a speedup because it calculates all features backwards
// which is unnecessary since it is stored in the mapping
public MatrixPoint getLocationOfMaximum(List<MatrixMapping> mappings) {
double maxScore = 0;
double maxScore = Double.MIN_VALUE;
MatrixPoint bestEntry = null;
Set<Point> mostSimpleRule = null;
for (MatrixEntry entry : tpMatrix) {
......@@ -120,13 +122,13 @@ public class MatrixMcMatrixFace {
if (scoreCurrent > maxScore) {
maxScore = scoreCurrent;
bestEntry = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
mostSimpleRule = MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(),
mappings, mappings.size() > 0 ? true : false);
mostSimpleRule = MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(), mappings,
mappings.size() > 0 ? true : false);
} else if (scoreCurrent == maxScore && scoreCurrent > 0) {
// keep the simpler rule
// keep the simpler rule TODO
MatrixPoint loc = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
Set<Point> featuresForIndex = MatrixUtil.determineFeaturesForIndex(loc.getLocation(),
mappings, mappings.size() > 0 ? true : false);
Set<Point> featuresForIndex = MatrixUtil.determineFeaturesForIndex(loc.getLocation(), mappings,
mappings.size() > 0 ? true : false);
if (mostSimpleRule == null) {
bestEntry = loc;
mostSimpleRule = featuresForIndex;
......
......@@ -20,6 +20,7 @@ import java.util.stream.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
......@@ -71,25 +72,26 @@ public class MatrixUtil {
}
}
// // get the features of this point
// Set<Point> featuresOfPoints = determineFeaturesForIndex(new Point(entry.column(), entry.row()),
// mappings, mappings.size() > 0 ? true : false);
//
// // only keep each feature combination exactly once!
// if (uniqueFeatureCombinations.contains(featuresOfPoints)) {
// continue;
// }
// uniqueFeatureCombinations.add(featuresOfPoints);
//
// // furthermore we can filter all those featurecombinations that
// // resemble exactly the same instance set
// // because the algrotihm can not differ between those
// Set<Instance> instancesForFeatureSet = determineInstancesForFeatures(featuresOfPoints, indexMap);
//
// if (uniqueInstance.contains(instancesForFeatureSet)) {
// continue;
// }
// uniqueInstance.add(instancesForFeatureSet);
// // get the features of this point
Set<Point> featuresOfPoints = determineFeaturesForIndex(new Point(entry.column(), entry.row()),
mappings, mappings.size() > 0 ? true : false);
// only keep each feature combination exactly once!
if (uniqueFeatureCombinations.contains(featuresOfPoints)) {
continue;
}
uniqueFeatureCombinations.add(featuresOfPoints);
// furthermore, we can filter out all feature combinations that
// cover exactly the same instance set, because the algorithm
// cannot distinguish between them
Set<Instance> instancesForFeatureSet = determineInstancesForFeatures(featuresOfPoints, indexMap);
if (uniqueInstance.contains(instancesForFeatureSet)) {
continue;
}
uniqueInstance.add(instancesForFeatureSet);
// finally we can decide to keep our feature for the next
// iteration
......@@ -115,16 +117,17 @@ public class MatrixUtil {
for (MatrixPoint entry : entries) {
if (matrixMapping.getMappingMap().size() > beamSize)
break;
if (entry.getScore() > 0) {
matrixMapping.addEntry(new Point(entry.getX(), entry.getY()));
}
// TODO this is extremely greedy!! most features have a negative
// score
// if (entry.getScore() > 0) {
matrixMapping.addEntry(new Point(entry.getX(), entry.getY()));
// }
}
// infer the -> righthandside
ArrayList<MatrixMapping> arrayList = new ArrayList<MatrixMapping>(mappings);
arrayList.add(matrixMapping);
matrixMapping.inferDenseMapValues(arrayList);
// System.out.println(entries.size()+"==");
// debug
......@@ -145,18 +148,28 @@ public class MatrixUtil {
public static Set<Instance> determineInstancesForFeatures(Set<Point> featuresOfPoints,
Map<Point, Set<Instance>> indexMap) {
// inverted list intersection!
List<Set<Instance>> instanceSetList = new ArrayList<>();
Set<Instance> intersectionSet = new HashSet<>();
Set<Instance> smallestSet = null;
for (Point feature : featuresOfPoints) {
Set<Instance> set = indexMap.get(feature);
if (intersectionSet.isEmpty()) {
intersectionSet.addAll(set);
} else {
// keep only the elements that are in both sets
intersectionSet.retainAll(set);
instanceSetList.add(set);
if (smallestSet == null)
smallestSet = set;
else if (set.size() < smallestSet.size())
smallestSet = set;
}
outer: for (Instance i : smallestSet) {
for (Set<Instance> set : instanceSetList) {
if (!set.contains(i)) {
continue outer;
}
}
intersectionSet.add(i);
}
return intersectionSet;
}
public static MatrixMcMatrixFace performKroneckerExpansionByIndex(List<MatrixMapping> mappings,
......@@ -200,8 +213,8 @@ public class MatrixUtil {
fullFeats.addAll(featE1);
fullFeats.addAll(featE2);
if (uniqueFeatures.contains(fullFeats))
continue;
if (uniqueFeatures.contains(fullFeats))
continue;
uniqueFeatures.add(fullFeats);
Set<Instance> fullInstances = new HashSet<>();
......@@ -211,8 +224,8 @@ public class MatrixUtil {
fullInstances.addAll(insE1);
fullInstances.retainAll(insE2);
if (uniqueInstances.contains(fullInstances))
continue;
if (uniqueInstances.contains(fullInstances))
continue;
uniqueInstances.add(fullInstances);
......@@ -247,9 +260,14 @@ public class MatrixUtil {
MatrixMapping lastMapping = mappings.get(mappings.size() - 1);
int dimension = lastMapping.getDenseMatrixDimension();
int kroneckerDim = dimension * dimension;
// add the forbidden expansion indexes
Supplier<MatrixMcMatrixFace> matrixConstructor = () -> new MatrixMcMatrixFace(dimension * dimension,
dimension * dimension, label);
//does not speed up
//addForbiddenIndexesToMapping(lastMapping, dimension);
Supplier<MatrixMcMatrixFace> matrixConstructor = () -> new MatrixMcMatrixFace(kroneckerDim, kroneckerDim,
label);
BiConsumer<MatrixMcMatrixFace, Instance> accumulator = (MatrixMcMatrixFace expandedMatrix, Instance inst) -> {
FlexCompColMatrix expandedInstance = inst.expand(mappings);
if (inst.getLabel() == label) {
......@@ -268,6 +286,42 @@ public class MatrixUtil {
.collect(Collector.of(matrixConstructor, accumulator, join, Collector.Characteristics.UNORDERED));
}
/**
 * Walks over pairs of dense indexes of the given mapping and registers, via
 * {@code lastMapping.addForbiddenIndex}, every expanded (Kronecker) index whose combined
 * feature set has already been produced by an earlier pair.
 *
 * @param lastMapping the mapping whose dense indexes are paired and which receives the
 *        forbidden indexes
 * @param dimension the dense matrix dimension, used to size the helper matrix and to compute
 *        the expanded indexes
 */
private static void addForbiddenIndexesToMapping(MatrixMapping lastMapping, int dimension) {
    // put a 1 at every dense index known to the mapping
    // (presumably Point.x is the column and Point.y the row, matching the add(y, x) call)
    FlexCompColMatrix occupied = new FlexCompColMatrix(dimension, dimension);
    for (Point dense : lastMapping.getInverseMappingMap().keySet()) {
        occupied.add(dense.y, dense.x, 1);
    }

    Set<Set<Point>> seenCombinations = new HashSet<>();
    for (MatrixEntry first : occupied) {
        Set<Point> firstFeatures = lastMapping
                .getFeaturesForDenseIndex(new Point(first.column(), first.row()));
        int colOffset = first.column() * dimension;
        int rowOffset = first.row() * dimension;

        // iterate over all non-sparse elements and expand
        for (MatrixEntry second : occupied) {
            Set<Point> secondFeatures = lastMapping
                    .getFeaturesForDenseIndex(new Point(second.column(), second.row()));
            Set<Point> union = new HashSet<>(firstFeatures);
            union.addAll(secondFeatures);

            // Set.add reports false when the combination was already recorded
            if (!seenCombinations.add(union)) {
                int expandedCol = colOffset + second.column();
                int expandedRow = rowOffset + second.row();
                lastMapping.addForbiddenIndex(new Point(expandedCol, expandedRow));
            }

            // once the pair of an entry with itself is handled, move on to the next first entry
            if (first.row() == second.row() && first.column() == second.column()) {
                break;
            }
        }
    }
}
// TODO speed this up by a simple lookup of the indices in the previous
// mapping
public static Set<Point> determineFeaturesForIndex(Point index, List<MatrixMapping> mappings,
boolean revertKroneckerFirst) {
......
package de.uniwue.ls6.rulelearning.evaluation.eval;
/**
 * Tagging schemes that can be selected for the entity-level evaluation.
 * NOTE(review): the constant names presumably refer to the usual inside/outside/begin style
 * chunk encodings - confirm the exact meaning of each constant with the author.
 */
public enum EEntityEvaluationsScheme {
    IOB,
    BIO,
    BE
}
package de.uniwue.ls6.rulelearning.evaluation.eval;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import de.uniwue.ls6.datastructure.ALabelling;
public class EntityAccuracyEvaluation implements IEvaluation {
private EEntityEvaluationsScheme scheme;
private String typeSplitter;
/**
 * Creates an evaluation for the given tagging scheme.
 *
 * @param scheme the scheme used to turn label sequences into entities
 * @param typeSplitter the expression passed to {@code String.split} to separate the tag
 *        part of a label from its type part
 */
public EntityAccuracyEvaluation(EEntityEvaluationsScheme scheme, String typeSplitter) {
    this.typeSplitter = typeSplitter;
    this.scheme = scheme;
}
// TODO does not respect any doubles
/**
 * Computes labelled entity-level precision, recall and F1 of the system labelling against
 * the gold labelling and renders them as text.
 * <p>
 * Both label sequences are first converted into entity maps (location to type). A system
 * entity is a true positive when the gold map holds an entity at an equal location with an
 * equal type string. A metric whose denominator is zero is reported as {@code 0.0} rather
 * than {@code NaN}.
 *
 * @param goldLabels the reference labelling
 * @param systemLabels the labelling produced by the system
 * @return a multi-line report with labelled precision, recall and F1
 */
@Override
public String evaluateToString(ALabelling[] goldLabels, ALabelling[] systemLabels) {
    Map<Point, String> labelMapGold = convertLabellingToEntityLabelling(goldLabels);
    Map<Point, String> labelMapSystem = convertLabellingToEntityLabelling(systemLabels);

    // map keys are unique, so every system entity can match at most one gold entity and a
    // direct lookup replaces the scan over all gold keys
    double tpLabelled = 0;
    for (Map.Entry<Point, String> systemEntity : labelMapSystem.entrySet()) {
        String goldType = labelMapGold.get(systemEntity.getKey());
        if (goldType != null && goldType.equals(systemEntity.getValue())) {
            tpLabelled++;
        }
    }

    double fpLabelled = labelMapSystem.size() - tpLabelled;
    double fnLabelled = labelMapGold.size() - tpLabelled;

    // guard every division: with no gold or no system entities the ratios would be 0/0
    double recallDenominator = tpLabelled + fnLabelled;
    double precisionDenominator = tpLabelled + fpLabelled;
    double recall = recallDenominator == 0 ? 0 : tpLabelled / recallDenominator;
    double prec = precisionDenominator == 0 ? 0 : tpLabelled / precisionDenominator;
    double f1 = (prec + recall) == 0 ? 0 : 2 * prec * recall / (prec + recall);

    String entityEval = "Entity Evaluation\n";
    entityEval += "Labelled-Precision: " + prec + "\n";
    entityEval += "Labelled-Recall: " + recall + "\n";
    entityEval += "Labelled-F1: " + f1;
    return entityEval;
}
/**
 * Converts a label sequence into an entity map according to the configured scheme.
 * Only the IOB scheme is handled; for any other scheme the returned map stays empty.
 *
 * @param labels the label sequence to convert
 * @return a new map filled by {@code createIOBEntities} for the IOB scheme, otherwise empty
 */
private Map<Point, String> convertLabellingToEntityLabelling(ALabelling[] labels) {
    Map<Point, String> entities = new HashMap<>();
    if (scheme != EEntityEvaluationsScheme.IOB) {
        // TODO currently unsupported
        return entities;
    }
    createIOBEntities(labels, entities);
    return entities;
}
private void createIOBEntities(ALabelling[] labels, Map<Point, String> entityMap) {
int beg = -1;
boolean inEntity = false;
for (int i = 0; i < labels.length; i++) {
ALabelling labelling = labels[i];
String stringLabel = labelling.getStringLabel();
if(stringLabel.equals("DEFAULT")){
stringLabel="O-";
}
// split between BIO and type
String[] split = stringLabel.split(typeSplitter);
// it is assumed that the BIO is in split[0] and the type in
// split[1]
if (split[0].startsWith("B")) {
if (inEntity) {