Commit 8e4542fe authored by Markus Krug's avatar Markus Krug
Browse files

Added Binary Algorithm for the rule learning

parent 14747a34
......@@ -12,15 +12,11 @@
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5">
<attributes>
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
<attributes>
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="output" path="target/classes"/>
</classpath>
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.5
org.eclipse.jdt.core.compiler.compliance=1.5
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.methodParameters=do not generate
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.source=1.5
org.eclipse.jdt.core.compiler.source=1.8
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5">
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8">
<attributes>
<attribute name="maven.pomderived" value="true"/>
</attributes>
......
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.5
org.eclipse.jdt.core.compiler.compliance=1.5
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.source=1.5
org.eclipse.jdt.core.compiler.source=1.8
package de.uniwue.ls6.algorithm.datastructure;
import java.awt.Point;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
/**
 * A rule made of a condition and a label. The condition stores, for every
 * position index, the set of feature ids that were registered for that
 * position.
 */
public class RepresentationRule {
    // condition: element i is the set of feature ids registered for position i
    private List<Set<Integer>> conditionSet;
    // label of the rule
    private Integer label;

    /**
     * Builds a rule from a feature matrix. Every value found in column
     * {@code col} of {@code instanceArray} is added to the condition set of
     * position {@code col}.
     *
     * @param windowSize    number of (initially empty) condition sets to create;
     *                      the condition grows if the matrix has more columns
     * @param instanceArray feature matrix indexed as {@code [row][col]}; an
     *                      empty matrix yields a rule with empty conditions
     * @param label         label of the rule
     */
    public RepresentationRule(int windowSize, int[][] instanceArray, int label) {
        this.label = label;
        this.conditionSet = createEmptyCondition(windowSize);
        if (instanceArray.length == 0) {
            // nothing to read; instanceArray[0] would be out of bounds
            return;
        }
        for (int col = 0; col < instanceArray[0].length; col++) {
            Set<Integer> set = conditionAt(col);
            for (int row = 0; row < instanceArray.length; row++) {
                set.add(instanceArray[row][col]);
            }
        }
    }

    /**
     * Builds a rule from a list of points. For every point, {@code p.y} is
     * added to the condition set of position {@code p.x}.
     *
     * @param windowSize number of (initially empty) condition sets to create;
     *                   the condition grows if a point lies beyond it
     * @param features   points whose {@code x} is the position index and whose
     *                   {@code y} is the feature id
     * @param label      label of the rule
     */
    public RepresentationRule(int windowSize, List<Point> features, int label) {
        this.label = label;
        this.conditionSet = createEmptyCondition(windowSize);
        for (Point p : features) {
            conditionAt(p.x).add(p.y);
        }
    }

    /**
     * Creates the condition list with one empty set per position. The
     * {@code ArrayList(int)} constructor only reserves capacity, so the sets
     * have to be added explicitly before they can be fetched by index.
     */
    private static List<Set<Integer>> createEmptyCondition(int windowSize) {
        List<Set<Integer>> condition = new ArrayList<Set<Integer>>(windowSize);
        for (int i = 0; i < windowSize; i++) {
            condition.add(new HashSet<Integer>());
        }
        return condition;
    }

    /**
     * Returns the condition set of the given position, appending empty sets
     * first if the position lies beyond the current end of the list.
     */
    private Set<Integer> conditionAt(int position) {
        while (conditionSet.size() <= position) {
            conditionSet.add(new HashSet<Integer>());
        }
        return conditionSet.get(position);
    }

    /**
     * @return the same text as {@link #toString()}
     */
    public String verbalizeRule() {
        return toString();
    }

    /**
     * Renders the rule as {@code position:features} pairs joined by
     * {@code " & "}, followed by {@code " => "} and the label.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < conditionSet.size(); i++) {
            if (i > 0) {
                sb.append(" & ");
            }
            sb.append(i).append(':').append(conditionSet.get(i));
        }
        sb.append(" => ").append(label);
        return sb.toString();
    }

    // TODO this could also be made faster
    /**
     * Delegates to {@code ins.containsFeature(...)} with all (position,
     * feature) points of this rule.
     *
     * @param ins instance to test
     * @return whatever {@code containsFeature} reports for the rule's points
     */
    public boolean isApplicable(Instance ins) {
        return ins.containsFeature(asPointList());
    }

    /**
     * Flattens the condition into points with x = position index and
     * y = feature id.
     */
    private List<Point> asPointList() {
        List<Point> pointList = new ArrayList<>();
        for (int i = 0; i < conditionSet.size(); i++) {
            for (Integer feature : conditionSet.get(i)) {
                pointList.add(new Point(i, feature));
            }
        }
        return pointList;
    }
}
package de.uniwue.ls6.algorithm.datastructure;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import de.uniwue.ls6.datastructure.Instance;
/**
 * A group of rules that are applied together within one pass over the
 * instances.
 */
public class RulePass {
    /**
     * contains a set of rules that are applied within a pass over the instances
     */
    private List<RepresentationRule> ruleSet;

    /**
     * Creates a pass without rules. The rule list is initialised here so that
     * rules can be added and queried right away.
     */
    public RulePass() {
        this(new ArrayList<RepresentationRule>());
    }

    /**
     * Creates a pass backed by the given rule list.
     *
     * @param ruleSet rules of this pass; the list is used directly (not
     *                copied). {@code null} is treated as "no rules yet".
     */
    public RulePass(List<RepresentationRule> ruleSet) {
        super();
        this.ruleSet = ruleSet == null ? new ArrayList<RepresentationRule>() : ruleSet;
    }

    /**
     * Appends a rule to this pass.
     *
     * @param rule rule to add
     */
    public void addRule(RepresentationRule rule) {
        this.ruleSet.add(rule);
    }

    /**
     * Removes a rule from this pass if it is present.
     *
     * @param rule rule to remove
     */
    public void removeRule(RepresentationRule rule) {
        this.ruleSet.remove(rule);
    }

    /**
     * @return a read-only view of the rules of this pass
     */
    public List<RepresentationRule> getRuleset() {
        return Collections.unmodifiableList(ruleSet);
    }

    /**
     * Checks whether at least one rule of this pass is applicable to the
     * instance.
     *
     * @param ins instance to test
     * @return {@code true} as soon as one rule reports it is applicable,
     *         {@code false} if none does (including when the pass is empty)
     */
    public boolean isApplicable(Instance ins) {
        for (RepresentationRule rule : ruleSet) {
            if (rule.isApplicable(ins)) {
                return true;
            }
        }
        return false;
    }
}
package de.uniwue.ls6.datastructure;
/**
 * Abstract base type for labellings. It currently declares no fields or
 * methods.
 */
public abstract class ALabelling {
}
package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.Arrays;
import java.util.List;
import de.uniwue.ls6.util.MatrixUtil;
......@@ -108,4 +109,31 @@ public class Instance {
return expanedMatrix;
}
/**
 * Hash code derived from the deep contents of {@code featureArray} and from
 * {@code label}, combined with the multiplier 31.
 */
@Override
public int hashCode() {
    int featureHash = Arrays.deepHashCode(featureArray);
    // 31 * (31 * 1 + featureHash) + label, written as one expression
    return 31 * (31 + featureHash) + label;
}
/**
 * Two instances are equal when they are of the same class, their feature
 * arrays are deeply equal and their labels match.
 */
@Override
public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
        return false;
    }
    Instance that = (Instance) obj;
    // the feature arrays are compared first, the labels only if they match
    return Arrays.deepEquals(featureArray, that.featureArray) && label == that.label;
}
}
......@@ -45,4 +45,8 @@ public class LabelAlphabet {
idToFeatureMap.put(value, feature);
}
/**
 * @return the number of entries currently held in {@code featureToIdMap}
 */
public static int getSize() {
    final int registeredEntries = featureToIdMap.size();
    return registeredEntries;
}
}
package de.uniwue.ls6.datastructure;
import java.awt.Point;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
......@@ -77,6 +79,20 @@ public class MatrixMcMatrixFace {
}
/**
 * Finds the entry of {@code tpMatrix} with the highest score. The score of an
 * entry is its own value minus the value stored at the same row and column in
 * {@code fpMatrix}. Only scores strictly greater than 0 are considered; on
 * ties the entry visited first wins.
 *
 * @return a point with x = column and y = row of the best entry, or
 *         {@code null} if no visited entry has a score greater than 0
 */
public Point getLocationOfMaximum() {
    double maxScore = 0;
    boolean found = false;
    int bestRow = 0;
    int bestColumn = 0;
    for (MatrixEntry entry : tpMatrix) {
        double scoreCurrent = entry.get() - fpMatrix.get(entry.row(), entry.column());
        if (scoreCurrent > maxScore) {
            maxScore = scoreCurrent;
            // Copy the coordinates right away instead of keeping the entry:
            // the iterator may hand out one entry object that it updates on
            // every step, in which case a stored reference would end up
            // pointing at the last visited entry.
            bestRow = entry.row();
            bestColumn = entry.column();
            found = true;
        }
    }
    return found ? new Point(bestColumn, bestRow) : null;
}
/*
* 1. Differenzmatrix ausrechnen => Best Score ist Max(Matrix) 2. Setze
* check! Werte auf Sparse wenn Summe von TP un FP <=? MaxScore 3. Expansion
......
......@@ -35,7 +35,7 @@ public class MatrixUtil {
return matrixMapping;
}
public static MatrixMcMatrixFace performKroneckerExpansion(List<MatrixMapping> mappings, List<Instance> instances,
public static MatrixMcMatrixFace performKroneckerExpansion(List<MatrixMapping> mappings, Collection<Instance> instances,
int label) {
MatrixMapping lastMapping = mappings.get(mappings.size() - 1);
......
package de.uniwue.ls6.rulelearning.algorithm;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
/**
 * Contract for rule learning algorithms: they are trained on instances and
 * can afterwards be asked to label a single instance.
 */
public interface IRepresentationRuleLearningAlgorithm {

    /**
     * Trains the algorithm.
     *
     * @param instances the instances to learn from
     */
    void learn(Instance... instances);

    /**
     * Labels one instance.
     *
     * @param instanceToClassify the instance to label
     * @return the labelling produced for the instance
     */
    ALabelling apply(Instance instanceToClassify);
}
package de.uniwue.ls6.rulelearning.algorithm.impl;
import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.algorithm.datastructure.RulePass;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.MatrixMapping;
import de.uniwue.ls6.datastructure.MatrixMcMatrixFace;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
import de.uniwue.ls6.util.MatrixUtil;
public class BinaryRepresentationRuleLearningAlgorithm implements IRepresentationRuleLearningAlgorithm {
List<RulePass> passes;
private int goldLabel;
private int otherLabel;
public BinaryRepresentationRuleLearningAlgorithm(int goldLabel, int otherLabel) {
passes = new LinkedList<RulePass>();
this.goldLabel = goldLabel;
this.otherLabel = otherLabel;
}
public void learn(Instance... instances) {
if (instances.length == 0)
throw new IllegalArgumentException("Plz give data to train!");
Set<Instance> instancesForPass = new HashSet<Instance>(Arrays.asList(instances));
int passIndex = 0;
int currentGoldIndex = goldLabel;
while (morePasses(instancesForPass, currentGoldIndex)) {
// update the learning objective
if (passIndex % 2 == 0) {
currentGoldIndex = goldLabel;
} else {
currentGoldIndex = otherLabel;
}
// create a new pass
learnRulePass(currentGoldIndex, instancesForPass);
// adapt the instances
instancesForPass = keepClassifiableInstances(passes, instancesForPass);
passIndex++;
}
}
private Set<Instance> keepClassifiableInstances(List<RulePass> passes, Set<Instance> instancesForPass) {
Set<Instance> remainingInstances = new HashSet<Instance>();
for (Instance ins : instancesForPass) {
if (passes.get(passes.size() - 1).isApplicable(ins)) {
remainingInstances.add(ins);
}
}
return remainingInstances;
}
private void learnRulePass(int goldIndex, Set<Instance> instancesForPass) {
Set<Instance> temporaryCopy = new HashSet<Instance>(instancesForPass);
RulePass pass = new RulePass();
RepresentationRule learnedRule = learnRule(goldIndex, temporaryCopy);
if (learnedRule != null) {
pass.addRule(learnedRule);
// modify instances so that already classified instances are removed
// from the trainingdata
temporaryCopy = removeAlreadyClassifiable(temporaryCopy, pass);
} else {
if (!pass.getRuleset().isEmpty()) {
passes.add(pass);
}
return;
}
}
private Set<Instance> removeAlreadyClassifiable(Set<Instance> temporaryCopy, RulePass pass) {
Set<Instance> remainingInstances = new HashSet<Instance>();
for (Instance ins : temporaryCopy) {
if (!pass.isApplicable(ins)) {
remainingInstances.add(ins);
}
}
return remainingInstances;
}
private boolean morePasses(Set<Instance> instancesForPass, int currentGoldIndex) {
for (Instance in : instancesForPass) {
if (in.getLabel() == currentGoldIndex)
return true;
}
return false;
}
private RepresentationRule learnRule(int goldLabel, Collection<Instance> instances) {
List<MatrixMapping> mappings = new ArrayList<MatrixMapping>();
// perform the training
int windowSize = instances.iterator().next().getFeatureArray()[0].length;
MatrixMcMatrixFace matrixInFocus = new MatrixMcMatrixFace(LabelAlphabet.getSize(),
windowSize, goldLabel);
int maximumScore = 0;
while (betterRuleCanBeLearned(maximumScore, matrixInFocus)) {
maximumScore = matrixInFocus.getMaximumScore();
MatrixMapping mappingForMaximum = MatrixUtil.getMappingForMaximum(matrixInFocus, maximumScore);
mappings.add(mappingForMaximum);
matrixInFocus = MatrixUtil.performKroneckerExpansion(mappings, instances, goldLabel);
}
// determine the best rule that has not been used before! TODO is this even necessary???
Point maxEntryLocation = matrixInFocus.getLocationOfMaximum();
List<Point> featuresAtMax = MatrixUtil.determineFeaturesForIndex(maxEntryLocation, mappings);
return new RepresentationRule(windowSize, featuresAtMax, goldLabel);
}
private boolean betterRuleCanBeLearned(int maximumScore, MatrixMcMatrixFace matrixInFocus) {
return matrixInFocus.getMaximumScore() > maximumScore;
}
public ALabelling apply(Instance instanceToClassify) {
// TODO Auto-generated method stub
return null;
}
}
package de.uniwue.ls6.rulelearning.algorithm.impl;
import de.uniwue.ls6.datastructure.ALabelling;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.rulelearning.algorithm.IRepresentationRuleLearningAlgorithm;
public class MultiClassRepresentationRuleAlgorithm implements IRepresentationRuleLearningAlgorithm {
public void learn(Instance... instances) {
// TODO Auto-generated method stub
}
public ALabelling apply(Instance instanceToClassify) {
// TODO Auto-generated method stub
return null;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment