Commit ed67780f authored by Sven
Browse files

Multithreading for brute-force wrapper

parent 6c966c92
Pipeline #11810 failed with stages
package tools.descartes.dml.empirical.extract.dependencies.identification;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
......@@ -8,6 +9,7 @@ import java.util.Collections;
import java.util.List;
import tools.descartes.dml.empirical.extract.dependencies.helper.MethodSignatureFormatHelper;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.M5P;
import weka.core.Attribute;
......@@ -17,6 +19,7 @@ public class BruteForceWrapperM5Extractor implements ParametricDependencyExtract
private boolean logging = false;
private String path;
private static int fileCounter = 0;
private static final int NUM_SUBSETS_SIMULTANIOUSLY = 15;
public BruteForceWrapperM5Extractor(Boolean logging, String path) {
this.logging = logging;
......@@ -32,9 +35,38 @@ public class BruteForceWrapperM5Extractor implements ParametricDependencyExtract
throw new IllegalArgumentException("Number of features too high for evaluating all subsets!");
int max = 1 << attributes.size();
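// calculate the subset index range handled by each parallel task
// NOTE(review): increment is 0 whenever max < NUM_SUBSETS_SIMULTANIOUSLY, so the loop below would never advance - confirm and guard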
List<int[]> subsetRanges = new ArrayList<>();
int increment = max / NUM_SUBSETS_SIMULTANIOUSLY;
for (int i = 0; i < max; i += increment) {
int[] range = new int[2];
range[0] = i;
int limit = i + increment - 1;
if (limit > max)
range[1] = max;
else
range[1] = limit;
subsetRanges.add(range);
}
// multi-thread the subset computation
subsetRanges.parallelStream().forEach(range -> evaluateSubsets(range[0], range[1], attributes, data));
if (logging)
System.out.println("Evaluated subproblem with " + max + " subsets.");
fileCounter += 1;
// dummy
int[] ret = new int[1];
ret[0] = 1;
return ret;
}
/**
 * Evaluates the attribute subsets whose index lies between {@code from} and {@code to} and appends
 * one result line per subset via {@code writeToFile}.
 *
 * NOTE(review): this listing appears to be a rendered diff. Pre-commit and post-commit lines are
 * both present and hunk headers hide parts of the body, so the comments below describe only what
 * is visible.
 *
 * @param from       first subset index to evaluate (inclusive)
 * @param to         upper bound of the subset indices; the visible loop condition is {@code i < to},
 *                   i.e. exclusive
 * @param attributes candidate attributes; bit {@code j} of the subset index is tested for each
 *                   position {@code j} (presumably selecting {@code attributes.get(j)} - the
 *                   selecting code is elided here)
 * @param data       instances handed to {@code Evaluation}; the name of its class attribute is
 *                   used to build the output file name
 */
private void evaluateSubsets(int from, int to, List<Attribute> attributes, Instances data) {
// iterate over all subsets of non-target features and execute M5 regression for
// each subset
// The next two loop headers look like the pre-commit (0..max) and post-commit (from..to)
// variants of the same statement.
// NOTE(review): the caller builds range[1] as i + increment - 1 while this loop stops before
// 'to', so the last subset index of each range seems never to be evaluated - confirm.
for (int i = 0; i < max; i++) {
for (int i = from; i < to; i++) {
ArrayList<Attribute> curSubset = new ArrayList<>();
for (int j = 0; j < attributes.size(); j++) {
// bit j of i decides whether position j belongs to the current subset
if (((i >> j) & 1) == 1) {
......@@ -50,10 +82,21 @@ public class BruteForceWrapperM5Extractor implements ParametricDependencyExtract
}
}
// Pre-commit variant: evaluates an M5P that was never built and reads the root mean squared error.
M5P m5 = new M5P();
Evaluation ev = new Evaluation(data);
ev.evaluateModel(m5, subsetInstance);
double error = ev.rootMeanSquaredError();
// Post-commit variant: builds the classifier first and reads the mean absolute error.
// exchange for respective classifier, e.g., LinearRegression
AbstractClassifier classifier = new M5P();
// stays null when building or evaluating the classifier throws
Double error = null;
try {
classifier.buildClassifier(subsetInstance);
Evaluation ev = new Evaluation(data);
ev.evaluateModel(classifier, subsetInstance);
error = ev.meanAbsoluteError();
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
// no error value available: skip this subset without writing a line
if (error == null)
continue;
// generate subset identifier
StringBuilder sb = new StringBuilder();
......@@ -62,22 +105,12 @@ public class BruteForceWrapperM5Extractor implements ParametricDependencyExtract
sb.append(",");
}
// drop the trailing comma appended after the last entry
sb.deleteCharAt(sb.length() - 1);
// filename is the target variable, each subset with corresponding accuracy is
// written as a line in this file
// Two variants of the same call follow; the second additionally replaces ':' with ' ' in the
// formatted name before fileCounter is appended.
writeToFile(
MethodSignatureFormatHelper.formatParameterIdentifier(data.classAttribute().name()) + fileCounter,
sb.toString() + ";" + String.valueOf(error) + ";");
writeToFile(MethodSignatureFormatHelper.formatParameterIdentifier(data.classAttribute().name()).replace(':',
' ') + fileCounter, sb.toString() + ";" + String.valueOf(error) + ";");
}
// The statements below look like the removed tail of the pre-commit method; they do not fit a
// void method.
// NOTE(review): ret has length 1, so ret[1] would be out of bounds.
if (logging)
System.out.println("Evaluated subproblem with " + max + " subsets.");
fileCounter += 1;
// dummy
int[] ret = new int[1];
ret[1] = 1;
return ret;
}
@Override
......@@ -99,10 +132,10 @@ public class BruteForceWrapperM5Extractor implements ParametricDependencyExtract
}
/**
 * Appends one line of text to the CSV file named {@code name + ".csv"} under {@code path}.
 * The writer is opened in append mode and closed by the try-with-resources statement.
 * An {@code IOException} is not propagated; only its message is printed.
 *
 * NOTE(review): this listing appears to be a rendered diff, so two variants of the {@code try}
 * header and of the error print are both present.
 *
 * @param name file name without the ".csv" extension
 * @param line text to append; a newline character is written after it
 */
private synchronized void writeToFile(String name, String line) {
// variant 1: file location built by plain string concatenation of path and name
try (BufferedWriter fw = new BufferedWriter(new FileWriter(path + name + ".csv", true))) {
// variant 2: path is used as the parent directory of the file
try (BufferedWriter fw = new BufferedWriter(new FileWriter(new File(path, name + ".csv"), true))) {
fw.write(line + "\n");// appends the string to the file
} catch (IOException ioe) {
// two variants again: the message goes to standard error in one, standard output in the other
System.err.println("IOException: " + ioe.getMessage());
System.out.println("IOException: " + ioe.getMessage());
}
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment