Commit f60fb8df authored by Markus Krug's avatar Markus Krug
Browse files

fixed multiclass classification

parent 17bb7a2b
......@@ -7,6 +7,7 @@ import java.util.Set;
import de.uniwue.ls6.util.MatrixPoint;
import de.uniwue.ls6.util.MatrixUtil;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
......@@ -34,7 +35,6 @@ public class MatrixMcMatrixFace {
// l2r t2B
for (int col = 0; col < i.getNrCols(); col++) {
for (int row = 0; row < i.getNrRows(); row++) {
if (goldLabel == i.getLabel()) {
// add to TP matrix
addToMatrix(i.getValueAt(col, row), col, tpMatrix);
......@@ -66,12 +66,15 @@ public class MatrixMcMatrixFace {
}
private void removeFromMatrix(int feature, int windowcolumn, FlexCompColMatrix matrix) {
//nothing stored in this field
if (feature == 0)
return;
matrix.add(feature, windowcolumn, -1);
}
private void addToMatrix(int feature, int windowcolumn, FlexCompColMatrix matrix) {
// do not add the default label!!
// nothing stored in this field
if (feature == 0)
return;
matrix.add(feature, windowcolumn, 1);
......@@ -85,7 +88,7 @@ public class MatrixMcMatrixFace {
*/
public int getMaximumScore() {
double maxScore = Double.MIN_VALUE;
double maxScore = Double.NEGATIVE_INFINITY;
for (MatrixEntry entry : tpMatrix) {
double scoreCurrent = entry.get() - fpMatrix.get(entry.row(), entry.column());
......@@ -112,31 +115,41 @@ public class MatrixMcMatrixFace {
// TODO this needs a speedup because it calculates all features backwards
// which is unnecessary since it is stored in the mapping
public MatrixPoint getLocationOfMaximum(List<MatrixMapping> mappings) {
double maxScore = Double.MIN_VALUE;
double maxScore = Double.NEGATIVE_INFINITY;
MatrixPoint bestEntry = null;
Set<Point> mostSimpleRule = null;
//System.out.println("Card: " + Matrices.cardinality(tpMatrix));
for (MatrixEntry entry : tpMatrix) {
double fps = fpMatrix.get(entry.row(), entry.column());
double scoreCurrent = entry.get() - fps;
if (scoreCurrent > maxScore) {
maxScore = scoreCurrent;
bestEntry = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
mostSimpleRule = MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(), mappings,
mappings.size() > 0 ? true : false);
} else if (scoreCurrent == maxScore && scoreCurrent > 0) {
// keep the simpler rule TODO
MatrixPoint loc = new MatrixPoint(entry.column(), entry.row(), scoreCurrent, entry.get(), fps);
Set<Point> featuresForIndex = MatrixUtil.determineFeaturesForIndex(loc.getLocation(), mappings,
mappings.size() > 0 ? true : false);
if (mostSimpleRule == null) {
bestEntry = loc;
mostSimpleRule = featuresForIndex;
}
// mostSimpleRule =
// MatrixUtil.determineFeaturesForIndex(bestEntry.getLocation(),
// mappings,
// mappings.size() > 0 ? true : false);
}
// else if (scoreCurrent == maxScore && scoreCurrent > 0) {
// // keep the simpler rule TODO
// MatrixPoint loc = new MatrixPoint(entry.column(), entry.row(),
// scoreCurrent, entry.get(), fps);
// Set<Point> featuresForIndex =
// MatrixUtil.determineFeaturesForIndex(loc.getLocation(), mappings,
// mappings.size() > 0 ? true : false);
// if (mostSimpleRule == null) {
// bestEntry = loc;
// mostSimpleRule = featuresForIndex;
// }
//
// }
}
if (bestEntry == null) {
// System.out.println(MatrixUtil.prettyMatrixFormat(tpMatrix));
// System.out.println();
// System.out.println(MatrixUtil.prettyMatrixFormat(fpMatrix));
}
return bestEntry == null ? null : bestEntry;
return bestEntry;
}
public String asString(List<MatrixMapping> mappings) {
......@@ -170,8 +183,6 @@ public class MatrixMcMatrixFace {
}
return sb.toString();
}
public int getGoldLabel() {
return goldLabel;
......
package de.uniwue.ls6.datastructure;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class MultiClassMapping {
// saves the correspondences between matrices and a list of mappings
private Map<MatrixMcMatrixFace, List<MatrixMapping>> map;
private Map<Integer, List<MatrixMapping>> map;
public MultiClassMapping() {
map = new HashMap<>();
}
public void addMapping(MatrixMcMatrixFace matrix, MatrixMapping mapping) {
List<MatrixMapping> matrixMappings = map.get(matrix);
public void addMapping(int label, MatrixMapping mapping) {
List<MatrixMapping> matrixMappings = map.get(label);
if (matrixMappings == null) {
matrixMappings = new ArrayList<>();
}
matrixMappings.add(mapping);
map.put(matrix, matrixMappings);
map.put(label, matrixMappings);
}
public Collection<List<MatrixMapping>> getMatrixMappings() {
public List<MatrixMapping> getMatrixMapping(MatrixMcMatrixFace matrix) {
return new ArrayList<>(map.values());
}
public List<MatrixMapping> getMatrixMappings(MatrixMcMatrixFace matrix) {
List<MatrixMapping> list = map.get(matrix);
List<MatrixMapping> list = map.get(matrix.getGoldLabel());
if (list == null) {
return new ArrayList<>();
}
......@@ -36,11 +43,11 @@ public class MultiClassMapping {
}
public void add(MultiClassMapping mappingForMaximum) {
for (MatrixMcMatrixFace matrix : mappingForMaximum.map.keySet()) {
for (Integer label : mappingForMaximum.map.keySet()) {
List<MatrixMapping> list = mappingForMaximum.map.get(matrix);
List<MatrixMapping> list = mappingForMaximum.map.get(label);
this.map.get(matrix).addAll(list);
map.get(label).addAll(list);
}
}
......@@ -48,11 +55,11 @@ public class MultiClassMapping {
// remove all that are contained in mapping for max
public void remove(MultiClassMapping mappingForMaximum) {
for (MatrixMcMatrixFace matrix : mappingForMaximum.map.keySet()) {
for (Integer label : mappingForMaximum.map.keySet()) {
List<MatrixMapping> list = mappingForMaximum.map.get(matrix);
List<MatrixMapping> list = mappingForMaximum.map.get(label);
this.map.get(matrix).removeAll(list);
this.map.get(label).removeAll(list);
}
}
......@@ -66,4 +73,21 @@ public class MultiClassMapping {
return maxSize;
}
public void removeLastMapping() {
for(List<MatrixMapping> mappings : map.values()){
mappings.remove(mappings.size()-1);
}
}
public int minSize() {
int minSize = Integer.MAX_VALUE;
for (List<MatrixMapping> mappings : map.values()) {
if (mappings.size() < minSize)
minSize = mappings.size();
}
return minSize;
}
}
......@@ -6,6 +6,7 @@ import java.util.List;
import de.uniwue.ls6.util.MatrixPoint;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.sparse.SSOR;
public class MultiClassMatrixMcMatrixFace {
......@@ -21,6 +22,12 @@ public class MultiClassMatrixMcMatrixFace {
}
}
public MultiClassMatrixMcMatrixFace() {
matrices = new ArrayList<>();
}
public void addInstances(Instance... instances) {
for (MatrixMcMatrixFace matrix : matrices) {
......@@ -35,16 +42,12 @@ public class MultiClassMatrixMcMatrixFace {
}
}
public int getMaximumScore() {
double maxScore = Double.MIN_VALUE;
public List<MatrixMcMatrixFace> getMatrices() {
return matrices;
}
for (MatrixMcMatrixFace matrix : matrices) {
double scoreMa = matrix.getMaximumScore();
if (scoreMa > maxScore)
maxScore = scoreMa;
}
return (int) maxScore;
public void setMatrices(List<MatrixMcMatrixFace> matrices) {
this.matrices = matrices;
}
public int getWindowSize() {
......@@ -56,20 +59,25 @@ public class MultiClassMatrixMcMatrixFace {
}
public MatrixPoint getLocationOfMaximum(MultiClassMapping mapping) {
double maxScore = Integer.MIN_VALUE;
MatrixMcMatrixFace matrixFace = null;
double maxScore = Double.NEGATIVE_INFINITY;
MatrixPoint bestPoint = null;
for(MatrixMcMatrixFace matrix : matrices){
MatrixPoint locationOfMaximum = matrix.getLocationOfMaximum(mapping.getMatrixMapping(matrix));
if(locationOfMaximum.getScore()>maxScore){
maxScore= locationOfMaximum.getScore();
matrixFace = matrix;
bestPoint=locationOfMaximum;
for (MatrixMcMatrixFace matrix : matrices) {
MatrixPoint locationOfMaximum = matrix.getLocationOfMaximum(mapping.getMatrixMappings(matrix));
if (locationOfMaximum == null)
continue;
if (locationOfMaximum.getScore() > maxScore) {
maxScore = locationOfMaximum.getScore();
bestPoint = locationOfMaximum;
bestPoint.setAccordingMatrix(matrix);
}
}
bestPoint.setAccordingMatrix(matrixFace);
return bestPoint;
}
public void addMatrix(MatrixMcMatrixFace expandedMatrix) {
this.matrices.add(expandedMatrix);
}
}
......@@ -15,11 +15,11 @@ import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.MatrixMapping;
......@@ -71,7 +71,8 @@ public class MatrixUtil {
if (amountTP == maximum && matrixface.getFpMatrix().get(entry.row(), entry.column()) == 0) {
if (!(new Point(entry.column(), entry.row()).equals(maxEntryLocation))) {
// no potential to improve left!
continue;
//continue;
//TODO
}
}
......@@ -136,17 +137,14 @@ public class MatrixUtil {
// System.out.println(entries.size()+"==");
// debug
// for (MatrixPoint entry : entries) {
// Set<Point> featuresForDenseIndex = determineFeaturesForIndex(new
// Point(entry.getX(), entry.getY()),
// mappings, mappings.size() > 0 ? true : false);
// RepresentationRule representationRule = new RepresentationRule(6,
// featuresForDenseIndex, 1, 10,0);
// System.out.println(entry.getScore() + "\t" + maximum + "\tTP " +
// entry.getTp() + "\tFP" + entry.getFp()
// + "\t" + representationRule.toString());
// }
// logger.info("Best possible coverage: " + maxCoverage);
// for (MatrixPoint entry : entries) {
// Set<Point> featuresForDenseIndex = determineFeaturesForIndex(new Point(entry.getX(), entry.getY()),
// mappings, mappings.size() > 0 ? true : false);
// RepresentationRule representationRule = new RepresentationRule(6, featuresForDenseIndex, 1, 10, 0);
// System.out.println(entry.getScore() + "\t" + maximum + "\tTP " + entry.getTp() + "\tFP" + entry.getFp()
// + "\t" + representationRule.toString());
// }
// logger.info("Best possible coverage: " + maxCoverage);
return matrixMapping;
}
......
......@@ -20,12 +20,16 @@ import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.sun.org.apache.xerces.internal.util.SynchronizedSymbolTable;
import de.uniwue.ls6.algorithm.datastructure.RepresentationRule;
import de.uniwue.ls6.datastructure.Instance;
import de.uniwue.ls6.datastructure.LabelAlphabet;
import de.uniwue.ls6.datastructure.MatrixMapping;
import de.uniwue.ls6.datastructure.MatrixMcMatrixFace;
import de.uniwue.ls6.datastructure.MultiClassMapping;
import de.uniwue.ls6.datastructure.MultiClassMatrixMcMatrixFace;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
......@@ -33,20 +37,54 @@ public class MulticlassMatrixUtil {
static final Logger logger = LoggerFactory.getLogger(MulticlassMatrixUtil.class);
public static MultiClassMapping getMappingForMaximum(MultiClassMatrixMcMatrixFace iterationMatrix, int maximum, Point maxEntryLocation,
MultiClassMapping multiClassMapping, Set<Instance> instances, Map<Point, Set<Instance>> indexMap, int beamSize) {
public static void getMappingForMaximum(MultiClassMatrixMcMatrixFace iterationMatrix, int maximum,
Point maxEntryLocation, MultiClassMapping multiClassMapping, Set<Instance> instances,
Map<Point, Set<Instance>> indexMap, int beamSize) {
// this methods creates a mapping for every matrix
for (MatrixMcMatrixFace matrix : iterationMatrix.getMatrices()) {
MatrixMapping mappingForMaximum = MatrixUtil.getMappingForMaximum(matrix, maximum, null,
multiClassMapping.getMatrixMappings(matrix), instances, indexMap, beamSize);
multiClassMapping.addMapping(matrix.getGoldLabel(), mappingForMaximum);
}
// TODO
return null;
}
public static MultiClassMatrixMcMatrixFace performKroneckerExpansion(MultiClassMapping multiClassMapping,
Collection<Instance> instances) {
Collection<Instance> instances, MultiClassMatrixMcMatrixFace multiClass) {
// TODO
return null;
}
// expand all matrices (if necessary!)
MultiClassMatrixMcMatrixFace multiClassMatrixMcMatrixFace = new MultiClassMatrixMcMatrixFace();
multiClassMatrixMcMatrixFace.setWindowSize(multiClass.getWindowSize());
for (MatrixMcMatrixFace matrix : multiClass.getMatrices()) {
List<MatrixMapping> matrixMappings = multiClassMapping.getMatrixMappings(matrix);
//System.out.println("Before kron max score: " + matrix.getMaximumScore() + "Label: " + LabelAlphabet.getFeatureToId(matrix.getGoldLabel()));
//System.out.println(matrixMappings.get(matrixMappings.size()-1).getMappingMap().size());
// if(matrixMappings.get(matrixMappings.size()-1).getMappingMap().size()>0){
// HashMap<Point, Set<Point>> denseIndexToFeaturesMapping =
// matrixMappings.get(matrixMappings.size()-1).getDenseIndexToFeaturesMapping();
//
// //create a rule
// for(Set<Point> feats : denseIndexToFeaturesMapping.values()){
// RepresentationRule representationRule = new RepresentationRule(1,
// feats, matrix.getGoldLabel(), 1, 1);
// System.out.println(representationRule);
// }
// }
MatrixMcMatrixFace expandedMatrix = MatrixUtil.performKroneckerExpansion(matrixMappings, instances,
matrix.getGoldLabel());
//System.out.println("After kron max score: " + expandedMatrix.getMaximumScore());
//System.out.println();
multiClassMatrixMcMatrixFace.addMatrix(expandedMatrix);
}
return multiClassMatrixMcMatrixFace;
}
}
package de.uniwue.ls6.rulelearning.instanceloading.featuregenerator;
import org.apache.uima.cas.FSIterator;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.text.AnnotationFS;
public class CONLL_LabelFeatureGen extends AFeatureGenerator {
int prefixLen;
String typeS;
String featureId;
public CONLL_LabelFeatureGen(String typeS, String labelId, String featureId) {
super(labelId);
this.typeS = typeS;
this.featureId = featureId;
}
@Override
public String[] generateFeatures(AnnotationFS token) {
Type type = token.getCAS().getTypeSystem().getType(typeS);
// get the end of the previous token
AnnotationFS lastTok = null;
FSIterator<AnnotationFS> iterator = token.getCAS().getAnnotationIndex(token.getType()).iterator();
iterator.moveTo(token);
iterator.moveToPrevious();
if (iterator.hasNext()) {
lastTok = iterator.get();
}
AnnotationFS lastGoldAnno = null;
for (AnnotationFS anno : token.getCAS().getAnnotationIndex(type)) {
String label = anno.getFeatureValueAsString(anno.getType().getFeatureByBaseName(featureId));
if (anno.getBegin() == token.getBegin() && anno.getEnd() == token.getEnd()) {
// here it is either B or I
if (lastGoldAnno == null || lastTok == null) {
// I
return new String[] { "LABEL=I-" + label };
} else {
if (lastGoldAnno.getEnd() == lastTok.getEnd()) {
// B
return new String[] { "LABEL=B-" + label };
}
}
return new String[] { "LABEL=I-" + label };
} else if (token.getBegin() > anno.getBegin() && token.getEnd() <= anno.getEnd()) {
return new String[] { "LABEL=I-" + label };
}
lastGoldAnno = anno;
}
return new String[] { "LABEL=O-" };
}
}
......@@ -12,44 +12,99 @@ import java.util.Set;
import org.apache.uima.cas.text.AnnotationFS;
public class InListFeatureGenerator extends AFeatureGenerator {
/*
* only returns true if the token is contained exactly
*/
Set<String> listEntries;
Set<String> longEntries;
public InListFeatureGenerator(String featureIdentidier,InputStream inStream) {
public InListFeatureGenerator(String featureIdentidier, InputStream inStream) {
super(featureIdentidier);
listEntries = new HashSet<String>();
BufferedReader reader = new BufferedReader(new InputStreamReader(inStream));
String line;
try {
while ((line = reader.readLine()) != null) {
listEntries.add(line.trim());
longEntries = new HashSet<>();
BufferedReader reader = new BufferedReader(new InputStreamReader(inStream));
String line;
try {
while ((line = reader.readLine()) != null) {
listEntries.add(line.trim());
String[] split = line.trim().split(" ");
if (split.length > 1) {
longEntries.add(line.trim());
}
reader.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
reader.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
@Override
public String[] generateFeatures(AnnotationFS token) {
if(listEntries.contains(token.getCoveredText())){
List<String> list = new ArrayList<>();
String id2 = token.getCoveredText().substring(0,Math.min(2, token.getCoveredText().length()));
list.add(super.featureIdentifier+id2+"IN_LIST");
String id1= token.getCoveredText().substring(0,Math.min(2, token.getCoveredText().length()));
list.add(super.featureIdentifier+id1+"IN_LIST");
String idFull = super.featureIdentifier+"=IN_LIST";
List<String> list = new ArrayList<>();
if (listEntries.contains(token.getCoveredText())) {
String id2 = token.getCoveredText().substring(0, Math.min(2, token.getCoveredText().length()));
list.add(super.featureIdentifier + id2 + "IN_LIST");
String id1 = token.getCoveredText().substring(0, Math.min(1, token.getCoveredText().length()));
list.add(super.featureIdentifier + id1 + "IN_LIST");
String idFull = super.featureIdentifier + "=IN_LIST";
list.add(idFull);
return list.toArray(new String[0]);
// perform a no case match
if(list.contains(token.getCoveredText().toLowerCase())|| list.contains(token.getCoveredText().toUpperCase())){
String idCase = super.featureIdentifier + "IGNORE_CASE=IN_LIST";
list.add(idCase);
}
// also add some more entries based on the border of longer entries
}
return new String[]{super.featureIdentifier+"=NOT_IN_LIST" };
// get substring around token
String documentText = token.getCAS().getDocumentText();
String surround = documentText.substring(Math.max(0, token.getBegin() - 50),
Math.min(documentText.length(), token.getEnd() + 50));
for (String longEntry : longEntries) {