Commit f2dcc9c0 authored by Markus Krug's avatar Markus Krug
Browse files

* added entity-based evaluation

*started the CONLL learning and made a main for it ExperimentCONLL.java
* made more feature generators and contributed resources
* quality is still pretty bad
parent fd59e265
......@@ -7,9 +7,11 @@ public abstract class ALabelling {
protected int label;
protected double score;
protected String stringLabel;
public ALabelling(int label,double score){
this.label = label;
this.score = score;
this.stringLabel = LabelAlphabet.getFeatureToId(label);
}
public int getLabel() {
return label;
......@@ -23,6 +25,13 @@ public abstract class ALabelling {
public void setScore(double score) {
this.score = score;
}
public String getStringLabel() {
return stringLabel;
}
public void setStringLabel(String stringLabel) {
this.stringLabel = stringLabel;
}
......
package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentHashMap.KeySetView;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.MatrixEntry;
......@@ -155,16 +160,17 @@ public class Instance {
FlexCompColMatrix expanedMatrix = new FlexCompColMatrix(kroneckerDimension, kroneckerDimension);
outer: for (MatrixEntry e1 : denseInstanceMatrix) {
int colE1 = e1.column() * denseDimension;
int rowE1 = e1.row() * denseDimension;
// iterate over all no sparse elements and expand
for (MatrixEntry e2 : denseInstanceMatrix) {
// // skip diagonal expansion TODO ??
// calculate the new indices
int kroneckerCol = e1.column() * denseDimension + e2.column();
int kroneckerRow = e1.row() * denseDimension + e2.row();
int kroneckerCol = colE1 + e2.column();
int kroneckerRow = rowE1 + e2.row();
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
expanedMatrix.add(kroneckerRow, kroneckerCol, 1);
if (e1.row() == e2.row() && e1.column() == e2.column())
continue outer;
......
......@@ -2,6 +2,7 @@ package de.uniwue.ls6.datastructure;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
......@@ -12,7 +13,8 @@ public class MatrixMapping {
// x is col and y is row
HashMap<Point, Point> mappingMap;
HashMap<Point, Point> inverseMappingMap;
HashMap<Point,Set<Point>> denseIndexToFeaturesMapping;
HashMap<Point, Set<Point>> denseIndexToFeaturesMapping;
HashSet<Point> forbiddenIndexes;
//
private int kroneckerDimension;
......@@ -30,6 +32,7 @@ public class MatrixMapping {
this.inverseMappingMap = inverseMap;
this.kroneckerDimension = kroneckerDimension;
this.denseIndexToFeaturesMapping = new HashMap<>();
this.forbiddenIndexes = new HashSet<>();
}
public MatrixMapping(int kroneckerDimension) {
......@@ -38,6 +41,7 @@ public class MatrixMapping {
this.inverseMappingMap = new HashMap<Point, Point>();
this.kroneckerDimension = kroneckerDimension;
this.denseIndexToFeaturesMapping = new HashMap<>();
this.forbiddenIndexes = new HashSet<>();
}
public HashMap<Point, Point> getMappingMap() {
......@@ -68,6 +72,10 @@ public class MatrixMapping {
this.mappingMap.put(key, null);
}
public boolean isForbiddenIndex(Point p) {
return this.forbiddenIndexes.contains(p);
}
// this method generates all values based on the keys
public void inferDenseMapValues(List<MatrixMapping> mappings) {
int numCols = (int) Math.ceil(Math.sqrt(mappingMap.keySet().size()));
......@@ -78,10 +86,10 @@ public class MatrixMapping {
inverseMappingMap.put(value, key);
index++;
}
//also infer the features
for(Point p : inverseMappingMap.keySet()){
denseIndexToFeaturesMapping.put(p,MatrixUtil.determineFeaturesForIndex(p, mappings,false));
// also infer the features
for (Point p : inverseMappingMap.keySet()) {
denseIndexToFeaturesMapping.put(p, MatrixUtil.determineFeaturesForIndex(p, mappings, false));
}
}
......@@ -98,8 +106,8 @@ public class MatrixMapping {
public int getKroneckerMatrixDimension() {
return kroneckerDimension;
}
public Set<Point> getFeaturesForDenseIndex(Point densePoint){
public Set<Point> getFeaturesForDenseIndex(Point densePoint) {
return denseIndexToFeaturesMapping.get(densePoint);
}
......@@ -121,4 +129,9 @@ public class MatrixMapping {
return sb.toString();
}
public void addForbiddenIndex(Point point) {
this.forbiddenIndexes.add(point);
}
}
......@@ -20,6 +20,7 @@ import java.util.stream.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
......@@ -259,9 +260,14 @@ public class MatrixUtil {
MatrixMapping lastMapping = mappings.get(mappings.size() - 1);
int dimension = lastMapping.getDenseMatrixDimension();
int kroneckerDim = dimension * dimension;
// add the forbidden expansion indexes
//does not speed up
//addForbiddenIndexesToMapping(lastMapping, dimension);
Supplier<MatrixMcMatrixFace> matrixConstructor = () -> new MatrixMcMatrixFace(dimension * dimension,
dimension * dimension, label);
Supplier<MatrixMcMatrixFace> matrixConstructor = () -> new MatrixMcMatrixFace(kroneckerDim, kroneckerDim,
label);
BiConsumer<MatrixMcMatrixFace, Instance> accumulator = (MatrixMcMatrixFace expandedMatrix, Instance inst) -> {
FlexCompColMatrix expandedInstance = inst.expand(mappings);
if (inst.getLabel() == label) {
......@@ -280,6 +286,40 @@ public class MatrixUtil {
.collect(Collector.of(matrixConstructor, accumulator, join, Collector.Characteristics.UNORDERED));
}
/**
 * Pre-computes "forbidden" Kronecker-expansion indexes on {@code lastMapping}: any
 * expanded (col, row) position whose combined feature set was already produced by an
 * earlier pair of dense entries is redundant and is flagged via
 * {@code MatrixMapping.addForbiddenIndex}.
 *
 * NOTE(review): currently dead code — the only call site above is commented out with
 * the remark "does not speed up".
 *
 * @param lastMapping the most recent mapping, providing dense index -> feature sets
 * @param dimension   edge length of the dense matrix; the expanded matrix has
 *                    dimension * dimension rows and columns
 */
private static void addForbiddenIndexesToMapping(MatrixMapping lastMapping, int dimension) {
	// Rebuild a dense 0/1 matrix containing every dense index known to the mapping.
	FlexCompColMatrix denseInstanceMatrix = new FlexCompColMatrix(dimension, dimension);
	for (Point denseIndices : lastMapping.getInverseMappingMap().keySet()) {
		// Point.x is the column, Point.y is the row (see the MatrixMapping comment).
		denseInstanceMatrix.add(denseIndices.y, denseIndices.x, 1);
	}
	// Feature-set unions seen so far; a repeat marks a redundant expanded cell.
	Set<Set<Point>> uniqueFeatureCombinations = new HashSet<>();
	outer: for (MatrixEntry e1 : denseInstanceMatrix) {
		Set<Point> featuresE1 = lastMapping.getFeaturesForDenseIndex(new Point(e1.column(), e1.row()));
		// Offsets of e1's block inside the Kronecker-expanded matrix.
		int colE1 = e1.column() * dimension;
		int rowE1 = e1.row() * dimension;
		// iterate over all no sparse elements and expand
		for (MatrixEntry e2 : denseInstanceMatrix) {
			// calculate the new indices
			int kroneckerCol = colE1 + e2.column();
			int kroneckerRow = rowE1 + e2.row();
			Set<Point> featuresE2 = lastMapping.getFeaturesForDenseIndex(new Point(e2.column(), e2.row()));
			// The union of both entries' feature sets identifies the expanded cell.
			Set<Point> combinedSet = new HashSet<>(featuresE1);
			combinedSet.addAll(featuresE2);
			if (uniqueFeatureCombinations.contains(combinedSet)) {
				lastMapping.addForbiddenIndex(new Point(kroneckerCol, kroneckerRow));
			}
			uniqueFeatureCombinations.add(combinedSet);
			// Stop pairing e1 once the diagonal pair (e1, e1) has been processed;
			// mirrors the loop structure of Instance.expand.
			if (e1.row() == e2.row() && e1.column() == e2.column())
				continue outer;
		}
	}
}
// TODO speed this up by a simple lookup of the indices in the previous
// mapping
public static Set<Point> determineFeaturesForIndex(Point index, List<MatrixMapping> mappings,
......
package de.uniwue.ls6.rulelearning.evaluation.eval;
/**
 * Tagging schemes for the entity-based evaluation. Only IOB is currently handled by
 * EntityAccuracyEvaluation; BIO and BE are declared but not yet supported there.
 */
public enum EEntityEvaluationsScheme {
	IOB, BIO, BE;
}
package de.uniwue.ls6.rulelearning.evaluation.eval;
import java.awt.Point;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import de.uniwue.ls6.datastructure.ALabelling;
/**
 * Entity-level accuracy evaluation. Token-wise labellings are first merged into
 * entity spans (begin/end token indices plus a type string) according to the
 * configured tagging scheme; precision, recall and F1 are then computed over exact
 * span-and-type matches between gold and system entities.
 */
public class EntityAccuracyEvaluation implements IEvaluation {

	private EEntityEvaluationsScheme scheme;
	private String typeSplitter;

	/**
	 * @param scheme       tagging scheme of the labels; only IOB is implemented
	 * @param typeSplitter regex separating the scheme prefix from the entity type,
	 *                     e.g. "-" for labels like "B-PER"
	 */
	public EntityAccuracyEvaluation(EEntityEvaluationsScheme scheme, String typeSplitter) {
		this.scheme = scheme;
		this.typeSplitter = typeSplitter;
	}

	// TODO does not respect any doubles
	@Override
	public String evaluateToString(ALabelling[] goldLabels, ALabelling[] systemLabels) {
		double tp_labelled = 0;
		Map<Point, String> labelMapGold = convertLabellingToEntityLabelling(goldLabels);
		Map<Point, String> labelMapSystem = convertLabellingToEntityLabelling(systemLabels);
		// Entity spans are unique map keys, so exact-span matching is one-to-one and a
		// direct lookup suffices (replaces the former O(n^2) scan over all gold spans;
		// the unused unlabelled counters were removed as dead code).
		for (Map.Entry<Point, String> sysEntity : labelMapSystem.entrySet()) {
			String goldType = labelMapGold.get(sysEntity.getKey());
			if (goldType != null && goldType.equals(sysEntity.getValue())) {
				tp_labelled++;
			}
		}
		double fp_labelled = labelMapSystem.size() - tp_labelled;
		double fn_labelled = labelMapGold.size() - tp_labelled;
		// NOTE: when either side contains no entities these ratios become NaN,
		// exactly as in the previous implementation.
		double recall = tp_labelled / (tp_labelled + fn_labelled);
		double prec = tp_labelled / (tp_labelled + fp_labelled);
		double f1 = 2 * prec * recall / (prec + recall);
		String entityEval = "Entity Evaluation\n";
		entityEval += "Labelled-Precision: " + prec + "\n";
		entityEval += "Labelled-Recall: " + recall + "\n";
		entityEval += "Labelled-F1: " + f1;
		return entityEval;
	}

	/**
	 * Converts a token-wise labelling into entity spans. Keys are (begin, end) token
	 * indices encoded as Point(x = begin, y = end); values are the entity types.
	 */
	private Map<Point, String> convertLabellingToEntityLabelling(ALabelling[] labels) {
		Map<Point, String> entityMap = new HashMap<>();
		if (scheme == EEntityEvaluationsScheme.IOB) {
			createIOBEntities(labels, entityMap);
		}
		// TODO BIO and BE currently unsupported — those schemes yield an empty map
		return entityMap;
	}

	/**
	 * Extracts the type part of a labelling's string label; returns "" instead of
	 * throwing when the label has no type component after the splitter.
	 */
	private String typeOf(ALabelling labelling) {
		String[] split = labelling.getStringLabel().split(typeSplitter);
		return split.length > 1 ? split[1] : "";
	}

	/**
	 * Merges IOB-labelled tokens into entity spans and stores them in entityMap.
	 * "B" opens a new entity (closing any open one), "I" continues (or opens, if no
	 * entity is open), "O" closes. The label "DEFAULT" is treated as outside.
	 */
	private void createIOBEntities(ALabelling[] labels, Map<Point, String> entityMap) {
		int beg = -1;
		boolean inEntity = false;
		for (int i = 0; i < labels.length; i++) {
			String stringLabel = labels[i].getStringLabel();
			if (stringLabel.equals("DEFAULT")) {
				stringLabel = "O-";
			}
			// split between BIO and type; the prefix is split[0], the type split[1]
			String[] split = stringLabel.split(typeSplitter);
			if (split[0].startsWith("B")) {
				if (inEntity) {
					// close the previous entity using the type of its last token
					entityMap.put(new Point(beg, i - 1), typeOf(labels[i - 1]));
				}
				inEntity = true;
				beg = i;
			} else if (split[0].startsWith("I")) {
				if (!inEntity) {
					// an I without a preceding B still opens an entity
					beg = i;
				}
				inEntity = true;
			} else if (split[0].startsWith("O")) {
				if (inEntity) {
					// close the previous entity
					entityMap.put(new Point(beg, i - 1), typeOf(labels[i - 1]));
				}
				inEntity = false;
			}
		}
		// add a possibly still-open entity at the end of the sequence
		if (inEntity) {
			entityMap.put(new Point(beg, labels.length - 1), typeOf(labels[labels.length - 1]));
		}
	}

	@Override
	public Evaluation evaluate(ALabelling[] goldLabels, ALabelling[] systemLabels) {
		// TODO Auto-generated method stub
		return null;
	}
}
package de.uniwue.ls6.rulelearning.instanceloading.featuregenerator;
import org.apache.uima.cas.FSIterator;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.text.AnnotationFS;
/**
 * Gold-label feature generator producing IOB-style CONLL NE labels from UIMA
 * annotations: for each token it emits "B-&lt;type&gt;", "I-&lt;type&gt;" or "O-"
 * depending on how the token overlaps the gold annotations of the configured type.
 */
public class IOB_CONNL_NE_UIMA_FeatureGen extends AFeatureGenerator {

	// NOTE(review): never read anywhere in this class — candidate for removal
	int prefixLen;
	// fully qualified UIMA type name of the gold annotation type
	String typeS;
	// name of the feature on that type that holds the entity label
	String featureId;

	/**
	 * @param typeS     fully qualified name of the gold annotation type
	 * @param labelId   label identifier passed through to AFeatureGenerator
	 * @param featureId base name of the feature holding the entity label
	 */
	public IOB_CONNL_NE_UIMA_FeatureGen(String typeS, String labelId, String featureId) {
		super(labelId);
		this.typeS = typeS;
		this.featureId = featureId;
	}

	@Override
	public String[] generateFeatures(AnnotationFS token) {
		Type type = token.getCAS().getTypeSystem().getType(typeS);
		// get the end of the previous token
		AnnotationFS lastTok = null;
		FSIterator<AnnotationFS> iterator = token.getCAS().getAnnotationIndex(token.getType()).iterator();
		iterator.moveTo(token);
		iterator.moveToPrevious();
		// NOTE(review): hasNext() after moveToPrevious() is used here as a validity
		// check; FSIterator.isValid() would state that intent directly — confirm the
		// behaviour when token is the first annotation in the index.
		if (iterator.hasNext()) {
			lastTok = iterator.get();
		}
		AnnotationFS lastGoldAnno = null;
		// walk the gold annotations in index order; lastGoldAnno trails one behind
		for (AnnotationFS anno : token.getCAS().getAnnotationIndex(type)) {
			String label = anno.getFeatureValueAsString(anno.getType().getFeatureByBaseName(featureId));
			if (anno.getBegin() == token.getBegin() && anno.getEnd() == token.getEnd()) {
				// token exactly covers this annotation: label is either B or I
				if (lastGoldAnno == null || lastTok == null) {
					// no previous annotation or token -> cannot be adjacent, use I
					return new String[] { "I-" + label };
				} else {
					if (lastGoldAnno.getEnd() == lastTok.getEnd()) {
						// the previous gold annotation ends at the previous token:
						// this token starts a new, directly adjacent entity
						return new String[] { "B-" + label };
					}
				}
				return new String[] { "I-" + label };
			} else if (token.getBegin() > anno.getBegin() && token.getEnd() <= anno.getEnd()) {
				// token lies inside the annotation (not at its begin) -> continuation
				return new String[] { "I-" + label };
			}
			lastGoldAnno = anno;
		}
		// token is not covered by any gold annotation
		return new String[] { "O-" };
	}
}
package de.uniwue.ls6.rulelearning.instanceloading.featuregenerator;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashSet;
import java.util.Set;
import org.apache.uima.cas.text.AnnotationFS;
/**
 * Binary feature generator flagging whether a token's covered text occurs in a
 * fixed word list read from an input stream (one entry per line, trimmed).
 * Only exact matches of the covered text count.
 */
public class InListFeatureGenerator extends AFeatureGenerator {

	/** the trimmed lines of the supplied list; membership is tested exactly */
	Set<String> listEntries;

	/**
	 * @param featureIdentidier identifier prefix used in the emitted feature strings
	 * @param inStream          list source, one entry per line (read with the
	 *                          platform default charset, as before — TODO consider
	 *                          pinning UTF-8)
	 */
	public InListFeatureGenerator(String featureIdentidier, InputStream inStream) {
		super(featureIdentidier);
		listEntries = new HashSet<String>();
		// try-with-resources: the reader is now closed even when readLine() throws
		// (previously close() was skipped on any IOException, leaking the stream)
		try (BufferedReader reader = new BufferedReader(new InputStreamReader(inStream))) {
			String line;
			while ((line = reader.readLine()) != null) {
				listEntries.add(line.trim());
			}
		} catch (IOException e) {
			// best effort: an unreadable list simply yields an empty set
			e.printStackTrace();
		}
	}

	@Override
	public String[] generateFeatures(AnnotationFS token) {
		if (listEntries.contains(token.getCoveredText())) {
			return new String[] { super.featureIdentifier + "=IN_LIST" };
		}
		return new String[] { super.featureIdentifier + "=NOT_IN_LIST" };
	}
}
package de.uniwue.ls6.rulelearning.instanceloading.featuregenerator;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.apache.uima.cas.text.AnnotationFS;
/**
 * Feature generator mapping a token's covered text to a word-cluster id read from a
 * two-column source (word &lt;segmenter&gt; clusterId, one pair per line). Emits
 * "&lt;featureId&gt;=Cluster_id_&lt;id&gt;" for known words and
 * "&lt;featureId&gt;=NOT_IN_CLUSTER" otherwise. Only exact matches count.
 */
public class WordClusteringFeatureGenerator extends AFeatureGenerator {

	/** word -> cluster id, both trimmed */
	Map<String, String> clusteringEntries;

	/**
	 * @param featureIdentidier identifier prefix used in the emitted feature strings
	 * @param inStream          clustering source, one "word&lt;segmenter&gt;id" per line
	 * @param colSegmenter      regex separating the word from its cluster id
	 */
	public WordClusteringFeatureGenerator(String featureIdentidier, InputStream inStream, String colSegmenter) {
		super(featureIdentidier);
		clusteringEntries = new HashMap<String, String>();
		// try-with-resources closes the reader even when readLine() throws
		// (previously close() was skipped on any IOException, leaking the stream)
		try (BufferedReader reader = new BufferedReader(new InputStreamReader(inStream))) {
			String line;
			while ((line = reader.readLine()) != null) {
				// split once per line instead of twice; skip malformed lines instead
				// of failing with an uncaught ArrayIndexOutOfBoundsException
				String[] cols = line.split(colSegmenter);
				if (cols.length >= 2) {
					clusteringEntries.put(cols[0].trim(), cols[1].trim());
				}
			}
		} catch (IOException e) {
			// best effort: an unreadable clustering yields an empty map
			e.printStackTrace();
		}
	}

	@Override
	public String[] generateFeatures(AnnotationFS token) {
		// single lookup instead of keySet().contains(...) followed by get(...)
		String clusterId = clusteringEntries.get(token.getCoveredText());
		if (clusterId != null) {
			return new String[] { super.featureIdentifier + "=Cluster_id_" + clusterId };
		}
		return new String[] { super.featureIdentifier + "=NOT_IN_CLUSTER" };
	}
}
......@@ -25,11 +25,12 @@ import de.uniwue.ls6.rulelearning.instanceloading.featuregenerator.AFeatureGener
public class InstanceCreationFactory {
public static List<Instance> createWindowedInstancesFromUIMA(File fileToDocument,int startIndex, int leftWindowsize,
int rightWindowSize, String tokentypeS, TypeSystemDescription typesystem, AFeatureGenerator goldGenerator,
AFeatureGenerator... generators) throws ResourceInitializationException, SAXException, IOException {
public static List<Instance> createWindowedInstancesFromUIMA(File fileToDocument, int startIndex,
int leftWindowsize, int rightWindowSize, String tokentypeS, TypeSystemDescription typesystem,
AFeatureGenerator goldGenerator, AFeatureGenerator... generators)
throws ResourceInitializationException, SAXException, IOException {
//get the windowsize of the experiment
// get the windowsize of the experiment
int windowSize = leftWindowsize + 1 + rightWindowSize;
// deserialize
CAS cas = deserializeCAS(fileToDocument, typesystem);
......@@ -52,7 +53,8 @@ public class InstanceCreationFactory {
// update the queue
if (windowQueue.size() >= windowSize) {
Instance inst = generateInstanceFromQueue(windowQueue, windowSize, labelList.get(tokenIndex),tokenIndex+startIndex+1);
Instance inst = generateInstanceFromQueue(windowQueue, windowSize, labelList.get(tokenIndex),
tokenIndex + startIndex + 1);
instances.add(inst);
windowQueue.poll();
tokenIndex++;
......@@ -63,11 +65,11 @@ public class InstanceCreationFactory {
for (int i = 0; i < rightWindowSize; i++) {
windowQueue.poll();
windowQueue.add(new LinkedList<String>());
Instance inst = generateInstanceFromQueue(windowQueue, windowSize, labelList.get(tokenIndex),tokenIndex);
Instance inst = generateInstanceFromQueue(windowQueue, windowSize, labelList.get(tokenIndex), tokenIndex);
instances.add(inst);
tokenIndex++;
}
System.out.println(cas.getAnnotationIndex(tokentype).size() + " Tokens and " + instances.size() + " instances");
return instances;
}
......@@ -90,14 +92,14 @@ public class InstanceCreationFactory {
}
int idToFeature = LabelAlphabet.getIdToFeature(goldfeature);