Commit 7ddd3dc5 authored by Alexander Gehrke's avatar Alexander Gehrke

Refactoring and cleanup in line with backend and pylib

parent 96f74c21
......@@ -22,7 +22,7 @@ for book in book*; do
# - text_nontext: red for text, green for non_text
# - baseline: draw baselines of text lines
# - textline: draw polygons of text lines
# we set --image_map_dir to the current directory, which will overwrite the
# we set --color-map-dir to the current directory, which will overwrite the
# output in each loop, but the image map generated is constant for each
# setting and only one file is needed for training
ocr4all-pixel-classifier gen-masks \
......@@ -30,7 +30,7 @@ for book in book*; do
--output-dir $book/masks \
--threads $(nproc) \
--setting text_nontext \
--image-map_dir ./
--color-map-dir ./
# Estimate the xheight for all pages based on connected components in binary
# image.
......
......@@ -21,7 +21,7 @@ ocr4all-pixel-classifier train \
--n-epoch 100 \
--early-stopping-max-performance-drops 30 \
--output my-model \
--color_map image_map.json
--color-map color_map.json
# if using a split file:
ocr4all-pixel-classifier train \
......@@ -29,7 +29,7 @@ ocr4all-pixel-classifier train \
-E 100 \
-S 30 \
--output my-model \
--color_map image_map.json
--color-map color_map.json
# you can also use --load to specify an existing model on which to continue
......
......@@ -3,7 +3,7 @@ import functools
import multiprocessing
import os
from ocr4all.image_map import ImageMap
from ocr4all.colors import ColorMap
from tqdm import tqdm
......@@ -27,8 +27,8 @@ def main():
args = parser.parse_args()
in_map = ImageMap.load(args.input_color_map)
out_map = ImageMap.load(args.output_color_map)
in_map = ColorMap.load(args.input_color_map)
out_map = ColorMap.load(args.output_color_map)
os.makedirs(args.output_dir, exist_ok=True)
......@@ -42,7 +42,7 @@ def main():
pass
def convert(img_path, outdir, in_map: ImageMap, out_map: ImageMap):
def convert(img_path, outdir, in_map: ColorMap, out_map: ColorMap):
labels = in_map.imread_labels(img_path)
res = out_map.to_image(labels)
res.save(os.path.join(outdir, os.path.basename(img_path)))
......
......@@ -6,7 +6,7 @@ from typing import Tuple
import numpy as np
from ocr4all.files import imread_bin, match_filenames
from ocr4all.image_map import ImageMap
from ocr4all.colors import ColorMap
from ocr4all_pixel_classifier.lib.evaluation import count_matches, total_accuracy, f1_measures, ConnectedComponentEval, \
cc_matching
from tqdm import tqdm
......@@ -57,10 +57,10 @@ def main():
# image_tpfpfn_cc = np.zeros([3])
if args.color_map_model and args.color_map_eval:
model_map = ImageMap.load(args.color_map_model)
eval_map = ImageMap.load(args.color_map_eval)
model_map = ColorMap.load(args.color_map_model)
eval_map = ColorMap.load(args.color_map_eval)
else:
model_map = eval_map = ImageMap({(255, 255, 255): (0, 'background'),
model_map = eval_map = ColorMap({(255, 255, 255): (0, 'background'),
(0, 255, 0): (1, 'text'),
(255, 0, 255): (2, 'image')})
......@@ -126,7 +126,7 @@ def csv_total(category: str, counts: np.ndarray):
.format(category, ttp, tfp, tfn, *f1_measures(ttp, tfp, tfn))
def eval_page(page: Tuple[str, str, str], eval_map: ImageMap, model_map: ImageMap, verbose, csv, singleclass, args):
def eval_page(page: Tuple[str, str, str], eval_map: ColorMap, model_map: ColorMap, verbose, csv, singleclass, args):
mask_p, pred_p, bin_p = page
mask = eval_map.imread_labels(mask_p)
pred = model_map.imread_labels(pred_p)
......
......@@ -2,24 +2,22 @@ import argparse
import json
import os
import os.path
from dataclasses import dataclass
from typing import Tuple, Optional, List, Callable, Generator, Type, Any
import numpy as np
from dataclasses import dataclass
from pypagexml.ds import TextRegionTypeSub, CoordsTypeSub, ImageRegionTypeSub
from tqdm import tqdm
from ocr4all_pixel_classifier.lib.dataset import SingleData, color_to_label, label_to_colors, DatasetLoader
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
from ocr4all.colors import ColorMap, DEFAULT_COLOR_MAPPING, DEFAULT_LABELS_BY_NAME
from ocr4all.files import glob_all, imread, imread_bin
from ocr4all_pixel_classifier.lib.dataset import SingleData, DatasetLoader
from ocr4all_pixel_classifier.lib.image_ops import compute_char_height
from ocr4all_pixel_classifier.lib.output import Masks
from ocr4all_pixel_classifier.lib.pc_segmentation import find_segments, get_text_contours
from ocr4all_pixel_classifier.lib.predictor import PredictSettings, Predictor
from ocr4all_pixel_classifier.lib.output import Masks
from ocr4all_pixel_classifier.lib.util import glob_all, imread, imread_bin
from ocr4all_pixel_classifier.lib.xycut import render_regions, \
render_morphological, render_xycut, AnyRegion
from ocr4all_pixel_classifier.lib.image_ops import compute_char_height
from ocr4all_pixel_classifier.lib.render import render_regions, \
render_morphological, render_xycut
from ocr4all_pixel_classifier.lib.xycut import AnyRegion
from pypagexml.ds import TextRegionTypeSub, CoordsTypeSub, ImageRegionTypeSub
from tqdm import tqdm
@dataclass
......@@ -72,12 +70,13 @@ def main():
help="load an existing model")
args = parser.parse_args()
image_map, rev_image_map = post_process_args(args, parser)
process_args(args, parser)
color_map, labels_by_name = process_color_map_args(args)
if not args.existing_preds_inverted:
results = predict_and_segment(args, image_map, rev_image_map)
results = predict_and_segment(args, color_map)
else:
results = segment_existing(args, image_map, rev_image_map)
results = segment_existing(args, color_map)
for result in results:
create_pagexml(result, args.xml_output_dir, args.strip_extension)
......@@ -90,31 +89,30 @@ def main():
render_regions(args.render_output_dir, args.render,
result.original_shape,
result.path,
rev_image_map,
render_method,
label_colors=color_map,
method=render_method,
segments_text=result.text_segments,
segments_image=result.image_segments)
def predict_and_segment(args, image_map, rev_image_map) -> Generator[SegmentationResult, None, None]:
def segment_new_predictions(binary_path, char_height, image_map, image_path, rev_image_map):
def predict_and_segment(args, color_map: ColorMap) -> Generator[SegmentationResult, None, None]:
def segment_new_predictions(binary_path, char_height, color_map: ColorMap, image_path):
masks = create_predictions(args.model, image_path, binary_path, char_height, args.target_line_height,
args.gpu_allow_growth, image_map)
args.gpu_allow_growth, color_map)
overlay = masks.inverted_overlay
if args.method == 'xycut':
text, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height, rev_image_map)
text, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height)
elif args.method == 'morph':
text = get_text_contours(masks.fg_color_mask, char_height, rev_image_map)
text = get_text_contours(masks.fg_color_mask, char_height, color_map)
_, image = find_segments(masks.inverted_overlay.shape[0], masks.inverted_overlay, char_height,
args.resize_height,
rev_image_map, only_images=True)
args.resize_height, color_map, only_images=True)
else:
raise Exception("unknown method")
return SegmentationResult(text, image, overlay.shape[0:2], image_path)
results = (
segment_new_predictions(binary_path, char_height, image_map, image_path, rev_image_map)
segment_new_predictions(binary_path, char_height, color_map, image_path)
for image_path, binary_path, char_height in
tqdm(
zip(args.image_paths, args.binary_paths, args.all_char_heights), unit='pages',
......@@ -122,32 +120,31 @@ def predict_and_segment(args, image_map, rev_image_map) -> Generator[Segmentatio
return results
def segment_existing(args, image_map, rev_image_map) -> Generator[SegmentationResult, None, None]:
def segment_existing_pred(binary_path, char_height, color_path, image_map, inverted_path, rev_image_map):
def segment_existing(args, color_map: ColorMap) -> Generator[SegmentationResult, None, None]:
def segment_existing_pred(binary_path, char_height, color_path, inverted_path):
overlay = imread(inverted_path)
if args.method == 'xycut':
text, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height, rev_image_map)
text, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height, color_map)
elif args.method == 'morph':
image, text = segment_existing_morph(binary_path, char_height, color_path, image_map, overlay,
rev_image_map)
image, text = segment_existing_morph(binary_path, char_height, color_path, overlay)
else:
raise Exception("unknown method")
return SegmentationResult(text, image, overlay.shape[0:2], inverted_path)
def segment_existing_morph(binary_path, char_height, color_path, image_map, overlay, rev_image_map):
def segment_existing_morph(binary_path: str, char_height: int, color_path: str, overlay):
binary = imread_bin(binary_path)
color_mask = imread(color_path)
label_mask = color_to_label(color_mask, image_map)
label_mask = color_map.to_labels(color_mask)
label_mask[binary == 0] = 0
fg_color_mask = label_to_colors(label_mask, image_map)
text = get_text_contours(fg_color_mask, char_height, rev_image_map)
_, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height,
rev_image_map, only_images=True)
fg_color_mask = color_map.to_rgb_array(label_mask)
text = get_text_contours(fg_color_mask, char_height, color_map)
_, image = find_segments(overlay.shape[0], overlay, char_height, args.resize_height, color_map,
only_images=True)
return image, text
results = (
segment_existing_pred(binary_path, char_height, color_path, image_map, inverted_path, rev_image_map)
segment_existing_pred(binary_path, char_height, color_path, inverted_path)
for binary_path, inverted_path, color_path, char_height in
tqdm(
zip(args.binary_paths, args.existing_inverted_path, args.existing_color_path, args.all_char_heights),
......@@ -156,33 +153,27 @@ def segment_existing(args, image_map, rev_image_map) -> Generator[SegmentationRe
return results
def post_process_args(args, parser):
if args.existing_preds_inverted and ((args.existing_preds_color and args.binary) or args.method == "xycut"):
args.existing_inverted_path = sorted(glob_all(args.existing_preds_inverted))
num_files = len(args.existing_inverted_path)
if args.method == "morph":
args.existing_color_path = sorted(glob_all(args.existing_preds_color))
args.binary_paths = sorted(glob_all(args.binary))
else:
args.existing_color_path = [None] * num_files
args.binary_paths = [None] * num_files
elif args.method == "morph" \
and (args.existing_preds_color or args.existing_preds_inverted) \
and not (args.existing_preds_color and args.existing_preds_inverted and args.binary):
return parser.error("Morphology method requires binaries and both existing predictions.\n"
"If you want to create new predictions, do not pass -e or -c.")
elif args.binary:
args.binary_paths = sorted(glob_all(args.binary))
args.image_paths = sorted(glob_all(args.images)) if args.images else args.binary_paths
num_files = len(args.binary_paths)
elif args.method == "morph":
return parser.error("Morphology method requires binary images.")
else:
return parser.error("Prediction requires binary images. Either supply binaries or existing preds")
def process_args(args, parser):
num_files = process_image_args(args, parser)
if not args.existing_preds_inverted and args.model is None:
return parser.error("Prediction requires a model")
process_normalization_args(args, num_files, parser)
def process_color_map_args(args):
if args.color_map:
color_map = ColorMap.load(args.color_map)
label_names = color_map.label_by_name
else:
color_map = ColorMap(DEFAULT_COLOR_MAPPING)
label_names = DEFAULT_LABELS_BY_NAME
return color_map, label_names
def process_normalization_args(args, num_files, parser):
if args.char_height:
args.all_char_heights = [args.char_height] * num_files
elif args.norm:
......@@ -199,20 +190,36 @@ def post_process_args(args, parser):
args.all_char_heights = [compute_char_height(image, True)
for image in tqdm(args.binary, desc="Auto-detecting char height", unit="pages")]
if args.color_map:
image_map = load_image_map_from_file(args.color_map)
else:
image_map = DEFAULT_IMAGE_MAP
rev_image_map = {v[1]: np.array(k) for k, v in image_map.items()}
return image_map, rev_image_map
def process_image_args(args, parser):
if args.existing_preds_inverted and ((args.existing_preds_color and args.binary) or args.method == "xycut"):
args.existing_inverted_path = sorted(glob_all(args.existing_preds_inverted))
num_files = len(args.existing_inverted_path)
if args.method == "morph":
args.existing_color_path = sorted(glob_all(args.existing_preds_color))
args.binary_paths = sorted(glob_all(args.binary))
else:
args.existing_color_path = [None] * num_files
args.binary_paths = [None] * num_files
elif args.method == "morph" \
and (args.existing_preds_color or args.existing_preds_inverted) \
and not (args.existing_preds_color and args.existing_preds_inverted and args.binary):
parser.error("Morphology method requires binaries and both existing predictions.\n"
"If you want to create new predictions, do not pass -e or -c.")
elif args.binary:
args.binary_paths = sorted(glob_all(args.binary))
args.image_paths = sorted(glob_all(args.images)) if args.images else args.binary_paths
num_files = len(args.binary_paths)
elif args.method == "morph":
parser.error("Morphology method requires binary images.")
else:
parser.error("Prediction requires binary images. Either supply binaries or existing preds")
return num_files
def predict_masks(output: Optional[str],
data: SingleData,
color_map: dict,
line_height: int,
color_map: ColorMap,
model: str,
post_processors: Optional[List[Callable[[np.ndarray, SingleData], np.ndarray]]] = None,
gpu_allow_growth: bool = False,
......@@ -231,24 +238,20 @@ def predict_masks(output: Optional[str],
return predictor.predict_masks(data)
def create_predictions(model, image_path, binary_path, char_height, target_line_height, gpu_allow_growth,
image_map=None):
if image_map is None:
image_map = DEFAULT_IMAGE_MAP
image = imread(image_path)
binary = imread_bin(binary_path)
def create_predictions(model: str, image_path: str, binary_path: str, char_height: int, target_line_height: int,
gpu_allow_growth: bool, color_map: ColorMap = None):
if color_map is None:
color_map = ColorMap(DEFAULT_COLOR_MAPPING)
dataset_loader = DatasetLoader(target_line_height, prediction=True, color_map=image_map)
dataset_loader = DatasetLoader(target_line_height, prediction=True, color_map=color_map)
data = dataset_loader.load_data(
[SingleData(binary_path=binary_path, image_path=image_path, line_height_px=target_line_height)]
[SingleData(binary_path=binary_path, image_path=image_path, line_height_px=char_height)]
).data[0]
return predict_masks(None,
data,
image_map,
char_height,
color_map,
model=model,
post_processors=None,
gpu_allow_growth=gpu_allow_growth,
......
......@@ -22,8 +22,8 @@ def main():
conf_args = parser.add_argument_group("optional arguments")
conf_args.add_argument("-h", "--help", action="help", help="show this help message and exit")
conf_args.add_argument("-M", "--image-map-dir", type=str, default=None,
help="location for writing the image map")
conf_args.add_argument("-M", "--color-map-dir", type=str, default=None,
help="location for writing the color map")
conf_args.add_argument("-s", '--setting',
default='all_types',
choices=[t.value for t in MaskType],
......@@ -63,10 +63,10 @@ def main():
for file in files:
mask_gen.save(file, args.output_dir)
if args.image_map_dir:
with open(os.path.join(args.image_map_dir, 'image_map.json'), 'w') as fp:
if args.color_map_dir:
with open(os.path.join(args.color_map_dir, 'color_map.json'), 'w') as fp:
import json
json.dump(PageXMLTypes.image_map(MaskType(args.setting)), fp)
json.dump(PageXMLTypes.color_map(MaskType(args.setting)), fp)
if __name__ == '__main__':
......
import argparse
import multiprocessing
from ocr4all.image_map import compute_from_images
from ocr4all.colors import compute_from_images
def main():
......
import argparse
from ocr4all.image_map import ImageMap
from ocr4all.colors import ColorMap
def main():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('vars', metavar='IMAGEMAP', type=str, nargs='+',
help='image map json files')
parser.add_argument('vars', metavar='COLORMAP', type=str, nargs='+',
help='color map json files')
opt_args = parser.add_argument_group("optional arguments")
opt_args.add_argument("-h", "--help", action="help", help="show this help message and exit")
args = parser.parse_args()
for imap in args.vars:
map = ImageMap.load(imap)
map = ColorMap.load(imap)
for k, v in map.mapping.items():
print(f"\x1b[48;2;{k[0]};{k[1]};{k[2]}m \x1b[0m \x1b[38;2;{k[0]};{k[1]};{k[2]}m{v}\x1b[0m\n")
......
......@@ -29,15 +29,15 @@ commands = {
'main': 'main',
'help': 'Compute image normalizations'
},
'compute-image-map': {
'script': 'ocr4all_pixel_classifier_frontend.generate_image_map',
'compute-color-map': {
'script': 'ocr4all_pixel_classifier_frontend.generate_color_map',
'main': 'main',
'help': 'Generates color map'
},
'inspect-image-map': {
'script': 'ocr4all_pixel_classifier_frontend.inspect_image_map',
'inspect-color-map': {
'script': 'ocr4all_pixel_classifier_frontend.inspect_color_map',
'main': 'main',
'help': 'Displays image map on a color-enabled terminal'
'help': 'Displays color map on a color-enabled terminal'
},
'convert-colors': {
'script': 'ocr4all_pixel_classifier_frontend.convert_colors',
......
......@@ -26,7 +26,6 @@ def migrate_model(path_to_meta, n_classes, l_rate, output_path):
from ocr4all_pixel_classifier.lib.metrics import accuracy, loss
input_image = tf.keras.layers.Input((None, None, 1))
input_binary = tf.keras.layers.Input((None, None, 1))
from ocr4all_pixel_classifier.lib.model import model_fcn_skip
model = model_fcn_skip([input_image], n_classes)
......@@ -34,9 +33,9 @@ def migrate_model(path_to_meta, n_classes, l_rate, output_path):
model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
keys = list(model_vars.keys())
counter = 0
for l in model.layers:
if len(l.get_weights()) > 0:
l.set_weights([model_vars[keys[counter]], model_vars[keys[counter + 1]]])
for layer in model.layers:
if len(layer.get_weights()) > 0:
layer.set_weights([model_vars[keys[counter]], model_vars[keys[counter + 1]]])
counter += 2
pass
model.save(output_path)
......
......@@ -5,12 +5,12 @@ from typing import Generator, List, Callable, Optional, Union
import numpy as np
from ocr4all_pixel_classifier.lib.dataset import DatasetLoader, SingleData
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
from ocr4all_pixel_classifier.lib.image_ops import compute_char_height
from ocr4all_pixel_classifier.lib.output import output_data, scale_to_original_shape
from ocr4all_pixel_classifier.lib.postprocess import find_postprocessor, postprocess_help
from ocr4all_pixel_classifier.lib.predictor import Predictor, PredictSettings, Prediction
from ocr4all_pixel_classifier.lib.util import glob_all
from ocr4all.files import glob_all
from ocr4all.colors import ColorMap, DEFAULT_COLOR_MAPPING
from tqdm import tqdm
......@@ -48,8 +48,6 @@ def main():
image_file_paths = sorted(glob_all(args.images))
binary_file_paths = sorted(glob_all(args.binary))
norm_file_paths = sorted(glob_all(args.norm)) if args.norm else []
if len(image_file_paths) != len(binary_file_paths):
parser.error("Got {} images but {} binary images".format(len(image_file_paths), len(binary_file_paths)))
......@@ -57,23 +55,7 @@ def main():
num_files = len(image_file_paths)
if args.char_height:
line_heights = [args.char_height] * num_files
elif args.norm:
print(f'norm: {args.norm}')
norm_file_paths = sorted(glob_all(args.norm)) if args.norm else []
if len(norm_file_paths) == 1:
line_heights = [json.load(open(norm_file_paths[0]))["char_height"]] * num_files
else:
if len(norm_file_paths) != num_files:
parser.error("Number of norm files must be one or equals the number of image files")
line_heights = [json.load(open(n))["char_height"] for n in norm_file_paths]
else:
if not args.binary:
parser.error("No binary files given, cannot auto-detect char height")
line_heights = [compute_char_height(image, True)
for image in tqdm(args.binary, desc="Auto-detecting char height", unit="pages")]
line_heights = parse_line_heights(args, num_files, parser)
post_processors = [find_postprocessor(p) for p in args.postprocess]
if args.cc_majority:
......@@ -83,14 +65,14 @@ def main():
os.makedirs(args.output, exist_ok=True)
if args.color_map:
image_map = load_image_map_from_file(args.color_map)
color_map = ColorMap.load(args.color_map)
else:
image_map = DEFAULT_IMAGE_MAP
color_map = ColorMap(DEFAULT_COLOR_MAPPING)
predictions = predict(args.output,
binary_file_paths,
image_file_paths,
image_map,
color_map,
line_heights,
target_line_height=args.target_line_height,
models=args.load,
......@@ -100,13 +82,32 @@ def main():
)
for _, prediction in tqdm(enumerate(predictions)):
output_data(args.output, prediction.labels, prediction.data, image_map)
output_data(args.output, prediction.labels, prediction.data, color_map)
def parse_line_heights(args, num_files, parser):
if args.char_height:
line_heights = [args.char_height] * num_files
elif args.norm:
norm_file_paths = sorted(glob_all(args.norm)) if args.norm else []
if len(norm_file_paths) == 1:
line_heights = [json.load(open(norm_file_paths[0]))["char_height"]] * num_files
else:
if len(norm_file_paths) != num_files:
parser.error("Number of norm files must be one or equals the number of image files")
line_heights = [json.load(open(n))["char_height"] for n in norm_file_paths]
else:
if not args.binary:
parser.error("No binary files given, cannot auto-detect char height")
line_heights = [compute_char_height(image, True)
for image in tqdm(args.binary, desc="Auto-detecting char height", unit="pages")]
return line_heights
def predict(output,
binary_file_paths: List[str],
image_file_paths: List[str],
color_map: dict,
color_map: ColorMap,
line_heights: Union[List[int], int],
target_line_height: int,
models: List[str],
......
......@@ -38,7 +38,7 @@ def main():
help="Generate tensorboard logs")
parser.add_argument("--reduce-lr-on-plateau", action="store_true",
help="Reduce learn rate when on plateau")
parser.add_argument("--color-map", type=str, default="image_map.json",
parser.add_argument("--color-map", type=str, default=None,
help="color map to load")
parser.add_argument('--architecture',
default=Architecture.FCN_SKIP,
......@@ -60,7 +60,7 @@ def main():
parser.add_argument("--split_file", type=str, help=argparse.SUPPRESS)
parser.add_argument("--foreground_masks", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--reduce_lr_on_plateau", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--color_map", type=str, default="image_map.json", help=argparse.SUPPRESS)
parser.add_argument("--color_map", type=str, default=None, help=argparse.SUPPRESS)
parser.add_argument("--gpu_allow_growth", action="store_true", help=argparse.SUPPRESS)
args = parser.parse_args()
......@@ -86,11 +86,17 @@ def main():
args.eval += relpaths(reldir, d["eval"])
from ocr4all_pixel_classifier.lib.dataset import DatasetLoader
from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file
from ocr4all.colors import ColorMap
from ocr4all_pixel_classifier.lib.metrics import Loss
image_map = load_image_map_from_file(args.color_map)
dataset_loader = DatasetLoader(args.target_line_height, image_map)
if args.color_map is None:
import os