Commit d44049cf authored by Alexander Gehrke

Update for ocr4all-pixel-classifier -> ocr4all-pylib split

parent 20802eca
import argparse
import multiprocessing
import os
from dataclasses import dataclass
from functools import partial
from typing import Tuple
import numpy as np
from tqdm import tqdm
from ocr4all.files import imread_bin, match_filenames
from ocr4all.image_map import ImageMap
from ocr4all_pixel_classifier.lib.evaluation import count_matches, total_accuracy, f1_measures, ConnectedComponentEval, \
cc_equal, cc_matching
from ocr4all_pixel_classifier.lib.image_map import rgb_to_label
from ocr4all_pixel_classifier.lib.util import imread, imread_bin, match_filenames
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file
cc_matching
from tqdm import tqdm
def main():
......@@ -32,17 +29,17 @@ def main():
parser.add_argument("--csv", action="store_true", help="enable csv output")
cceval_args = parser.add_argument_group("Connected Component Evaluation")
cceval_args.add_argument("-T", "--cc-threshold-tp", type=float, default=1.0,
help="ratio of pixels required for a true positive")
help="ratio of pixels required for a true positive")
cceval_args.add_argument("-F", "--cc-threshold-fp", type=float, default=0.1,
help="ratio of pixels required for a false positive")
help="ratio of pixels required for a false positive")
cceval_args.add_argument("-M", "--cc-threshold-mask", type=float, default=1.0,
help="ratio of pixels required for mask component to be considered text")
help="ratio of pixels required for mask component to be considered text")
parser.add_argument("--verify-filenames", action="store_true")
parser.add_argument("--singleclass", action="store_true", help="evaluate as"
"binary classificator by treating background and image class as same")
"binary classificator by treating background and image class as same")
args = parser.parse_args()
#if args.csv and args.verbose:
# if args.csv and args.verbose:
# parser.error("--csv and --verbose are currently not compatible")
if bool(args.color_map_model) != bool(args.color_map_eval):
......@@ -60,12 +57,12 @@ def main():
# image_tpfpfn_cc = np.zeros([3])
if args.color_map_model and args.color_map_eval:
model_map = load_image_map_from_file(args.color_map_model)
eval_map = load_image_map_from_file(args.color_map_eval)
model_map = ImageMap.load(args.color_map_model)
eval_map = ImageMap.load(args.color_map_eval)
else:
model_map = eval_map = {(255, 255, 255): [0, 'background'],
(0, 255, 0): [1, 'text'],
(255, 0, 255): [2, 'image']}
model_map = eval_map = ImageMap({(255, 255, 255): (0, 'background'),
(0, 255, 0): (1, 'text'),
(255, 0, 255): (2, 'image')})
# for mask_p, pred_p, bin_p in tqdm(zip(args.masks, args.preds, args.binary)):
......@@ -75,15 +72,15 @@ def main():
text_tpfpfn_cc = np.zeros([3])
parfunc = partial(eval_page, eval_map=eval_map, model_map=model_map,
verbose=args.verbose, csv=args.csv, singleclass=args.singleclass,
args=args)
verbose=args.verbose, csv=args.csv, singleclass=args.singleclass,
args=args)
if args.csv and args.verbose:
print('Image,Category,TP,FP,FN,Precision,Recall,F1')
with multiprocessing.Pool(processes=args.threads) as p:
for match in tqdm(p.imap(parfunc, zip(args.masks, args.preds, args.binary)), total=len(args.masks)):
#for page in zip(args.masks, args.preds, args.binary):
# match = eval_page(page, eval_map=eval_map, model_map=model_map, verbose=args.verbose)
# for page in zip(args.masks, args.preds, args.binary):
# match = eval_page(page, eval_map=eval_map, model_map=model_map, verbose=args.verbose)
text_tpfpfn += match.text
image_tpfpfn += match.image
correct_total += match.accuracy
......@@ -129,22 +126,21 @@ def csv_total(category: str, counts: np.ndarray):
.format(category, ttp, tfp, tfn, *f1_measures(ttp, tfp, tfn))
def eval_page(page, eval_map, model_map, verbose, csv, singleclass, args):
def eval_page(page: Tuple[str, str, str], eval_map: ImageMap, model_map: ImageMap, verbose, csv, singleclass, args):
mask_p, pred_p, bin_p = page
mask = rgb_to_label(imread(mask_p), eval_map)
pred = rgb_to_label(imread(pred_p), model_map)
mask = eval_map.imread_labels(mask_p)
pred = model_map.imread_labels(pred_p)
if singleclass:
pred[pred == 0] = 2
fg = imread_bin(bin_p)
cceval = ConnectedComponentEval(mask, pred, fg).only_label(1,
args.cc_threshold_mask)
args.cc_threshold_mask)
text_cc_eval = list(cceval.run_per_component(cc_matching(1,
threshold_tp=args.cc_threshold_tp,
threshold_fp=args.cc_threshold_fp,
threshold_mask=args.cc_threshold_mask,
assume_filtered=True)))
threshold_tp=args.cc_threshold_tp,
threshold_fp=args.cc_threshold_fp,
threshold_mask=args.cc_threshold_mask)))
if len(text_cc_eval) == 0:
text_matches_cc = [0, 0, 0]
else:
......
import argparse
import multiprocessing
from ocr4all_pixel_classifier.lib.image_map import compute_image_map
from ocr4all.image_map import compute_from_images
def main():
......@@ -21,7 +21,7 @@ def main():
help="Number of threads to use")
args = parser.parse_args()
compute_image_map(args.input_dir, args.output_dir, args.max_image, args.threads)
compute_from_images(args.input_dir, args.output_dir, args.max_image, args.threads)
if __name__ == '__main__':
......
import tensorflow as tf
import logging
import tensorflow as tf
logger = logging.getLogger(__name__)
......@@ -22,7 +23,7 @@ def migrate_model(path_to_meta, n_classes, l_rate, output_path):
for var in vars_global:
model_vars[var.name] = var.eval()
from ocr4all_pixel_classifier.lib.metrics import fgpa, accuracy, loss
from ocr4all_pixel_classifier.lib.metrics import accuracy, loss
input_image = tf.keras.layers.Input((None, None, 1))
input_binary = tf.keras.layers.Input((None, None, 1))
......
import argparse
import json
import os
import numpy as np
from typing import Generator, List, Callable, Optional, Union
from tqdm import tqdm
import numpy as np
from ocr4all_pixel_classifier.lib.dataset import DatasetLoader, SingleData
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
from ocr4all_pixel_classifier.lib.image_ops import compute_char_height
from ocr4all_pixel_classifier.lib.output import output_data, scale_to_original_shape
from ocr4all_pixel_classifier.lib.postprocess import find_postprocessor, postprocess_help
from ocr4all_pixel_classifier.lib.predictor import Predictor, PredictSettings, Prediction
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
from ocr4all_pixel_classifier.lib.util import glob_all, preserving_resize
from ocr4all_pixel_classifier.lib.util import glob_all
from tqdm import tqdm
def main():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment