Commit d1f9e116 authored by Alexander Gehrke

Made loading of input in predict and find-segments identical

parent 574ac535
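The gist of the change: both the predict and find-segments entry points now build their network input the same way, by handing a SingleData with file paths to DatasetLoader instead of reading and preparing the images by hand. A minimal sketch of that shared loading path, assembled from the hunks below (the wrapper function and its name are mine, not part of the commit):

```python
# Sketch only: the wrapper function is hypothetical; the DatasetLoader /
# SingleData usage is taken from the diff below.
from ocr4all_pixel_classifier.lib.dataset import DatasetLoader, SingleData


def load_page(image_path: str, binary_path: str, target_line_height: int, image_map: dict):
    # prediction=True loads the page for inference using the given color map
    loader = DatasetLoader(target_line_height, prediction=True, color_map=image_map)
    dataset = loader.load_data(
        [SingleData(binary_path=binary_path, image_path=image_path,
                    line_height_px=target_line_height)]
    )
    return dataset.data[0]  # one loaded SingleData per input page
```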
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from pypagexml.ds import TextRegionTypeSub, CoordsTypeSub, ImageRegionTypeSub
 from tqdm import tqdm
-from ocr4all_pixel_classifier.lib.dataset import SingleData, color_to_label, label_to_colors
+from ocr4all_pixel_classifier.lib.dataset import SingleData, color_to_label, label_to_colors, DatasetLoader
 from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
 from ocr4all_pixel_classifier.lib.pc_segmentation import find_segments, get_text_contours
 from ocr4all_pixel_classifier.lib.predictor import PredictSettings, Predictor
@@ -207,16 +207,15 @@ def post_process_args(args, parser):
     return image_map, rev_image_map

 def predict_masks(output: Optional[str],
-                  image: np.ndarray,
-                  binary: np.ndarray,
+                  data: SingleData,
                   color_map: dict,
                   line_height: int,
                   model: str,
                   post_processors: Optional[List[Callable[[np.ndarray, SingleData], np.ndarray]]] = None,
                   gpu_allow_growth: bool = False,
                   ) -> Masks:
-    data = SingleData(binary=binary, image=image, original_shape=binary.shape, line_height_px=line_height)
     settings = PredictSettings(
         network=os.path.abspath(model),
@@ -240,12 +239,14 @@ def create_predictions(model, image_path, binary_path, char_height, target_line_
     image = imread(image_path)
     binary = imread_bin(binary_path)
-    from ocr4all_pixel_classifier.lib.dataset import prepare_images
-    img, bin = prepare_images(image, binary, target_line_height, char_height)
+    dataset_loader = DatasetLoader(target_line_height, prediction=True, color_map=image_map)
+    data = dataset_loader.load_data(
+        [SingleData(binary_path=binary_path, image_path=image_path, line_height_px=target_line_height)]
+    ).data[0]
     return predict_masks(None,
-                         img,
-                         bin,
+                         data,
                          image_map,
                          char_height,
                          model=model,
......
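Taken together, the hunks above mean create_predictions no longer calls prepare_images, and predict_masks no longer assembles a SingleData itself: the loaded data is passed straight through. Roughly, the resulting call reads as follows (a sketch based on the diff; the inline comments are my assumptions, not code from the commit):

```python
# Hypothetical rendering of the updated call site in create_predictions.
masks = predict_masks(None,         # output target (None here, as in the diff)
                      data,         # SingleData produced by DatasetLoader.load_data(...)
                      image_map,    # color map used to decode the network output
                      char_height,  # line height in pixels
                      model=model)
```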
@@ -4,13 +4,14 @@ import os
 import numpy as np
 from typing import Generator, List, Callable, Optional, Union
-import tqdm
+from tqdm import tqdm
 from ocr4all_pixel_classifier.lib.dataset import DatasetLoader, SingleData
 from ocr4all_pixel_classifier.lib.image_ops import compute_char_height
 from ocr4all_pixel_classifier.lib.output import output_data, scale_to_original_shape
 from ocr4all_pixel_classifier.lib.postprocess import find_postprocessor, postprocess_help
 from ocr4all_pixel_classifier.lib.predictor import Predictor, PredictSettings, Prediction
-from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file
+from ocr4all_pixel_classifier.lib.image_map import load_image_map_from_file, DEFAULT_IMAGE_MAP
 from ocr4all_pixel_classifier.lib.util import glob_all, preserving_resize
@@ -28,8 +29,8 @@ def main():
                         help="directory name of the binary images")
     parser.add_argument("--images", type=str, required=True, nargs="+",
                         help="directory name of the images on which to train")
-    parser.add_argument("--norm", type=str, required=False, nargs="+",
-                        help="directory name of the norms on which to train")
+    parser.add_argument("-n", "--norm", type=str, required=False, nargs="+",
+                        help="use normalization files for input char height")
     parser.add_argument("--keep_low_res", action="store_true",
                         help="keep low resolution prediction instead of rescaling output to original image size")
     parser.add_argument("--cc_majority", action="store_true",
@@ -37,8 +38,7 @@
     parser.add_argument("--postprocess", type=str, nargs="+", default=[],
                         choices=["cc_majority", "bounding_boxes"],
                         help="add postprocessor functions to run on the prediction. use 'list' or 'help' to show available postprocessors")
-    parser.add_argument("--color_map", type=str, required=True,
-                        help="color_map to load")
+    parser.add_argument("--color_map", type=str, default=None, help="color_map to load")
     parser.add_argument("--gpu_allow_growth", action="store_true")
     args = parser.parse_args()
@@ -56,17 +56,25 @@
     print("Loading {} files with character height {}".format(len(image_file_paths), args.char_height))
-    if not args.char_height and len(norm_file_paths) == 0:
-        parser.error("either --norm or --char_height must be specified")
+    num_files = len(image_file_paths)
     if args.char_height:
-        line_heights = [args.char_height] * len(image_file_paths)
-    elif len(norm_file_paths) == 1:
-        line_heights = [json.load(open(norm_file_paths[0]))["char_height"]] * len(image_file_paths)
+        line_heights = [args.char_height] * num_files
+    elif args.norm:
+        print(f'norm: {args.norm}')
+        norm_file_paths = sorted(glob_all(args.norm)) if args.norm else []
+        if len(norm_file_paths) == 1:
+            line_heights = [json.load(open(norm_file_paths[0]))["char_height"]] * num_files
+        else:
+            if len(norm_file_paths) != num_files:
+                parser.error("Number of norm files must be one or equal the number of image files")
+            line_heights = [json.load(open(n))["char_height"] for n in norm_file_paths]
     else:
-        if len(norm_file_paths) != len(image_file_paths):
-            raise Exception("Number of norm files must be one or equals the number of image files")
-        line_heights = [json.load(open(n))["char_height"] for n in norm_file_paths]
+        if not args.binary:
+            parser.error("No binary files given, cannot auto-detect char height")
+        line_heights = [compute_char_height(image, True)
+                        for image in tqdm(args.binary, desc="Auto-detecting char height", unit="pages")]
     post_processors = [find_postprocessor(p) for p in args.postprocess]
     if args.cc_majority:
@@ -75,7 +83,10 @@
     os.makedirs(args.output, exist_ok=True)
-    image_map = load_image_map_from_file(args.color_map)
+    if args.color_map:
+        image_map = load_image_map_from_file(args.color_map)
+    else:
+        image_map = DEFAULT_IMAGE_MAP
     predictions = predict(args.output,
                           binary_file_paths,
@@ -89,7 +100,7 @@
                           gpu_allow_growth=args.gpu_allow_growth,
                           )
-    for _, prediction in tqdm.tqdm(enumerate(predictions)):
+    for _, prediction in tqdm(enumerate(predictions)):
         output_data(args.output, prediction.labels, prediction.data, image_map)
......
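For the prediction script, the changes above boil down to the following input resolution: an explicit --char_height wins; otherwise --norm files supply the character height (either one file for all pages or one per page); otherwise it is auto-detected from the binary images. The color map likewise falls back to DEFAULT_IMAGE_MAP when --color_map is omitted. A condensed sketch of that flow after this commit (reassembled from the hunks above, so treat the exact structure as approximate):

```python
# Condensed from the diff above; names and messages follow the new code.
num_files = len(image_file_paths)
if args.char_height:
    line_heights = [args.char_height] * num_files
elif args.norm:
    norm_file_paths = sorted(glob_all(args.norm))
    if len(norm_file_paths) == 1:
        line_heights = [json.load(open(norm_file_paths[0]))["char_height"]] * num_files
    elif len(norm_file_paths) == num_files:
        line_heights = [json.load(open(n))["char_height"] for n in norm_file_paths]
    else:
        parser.error("Number of norm files must be one or equal the number of image files")
else:
    if not args.binary:
        parser.error("No binary files given, cannot auto-detect char height")
    line_heights = [compute_char_height(image, True)
                    for image in tqdm(args.binary, desc="Auto-detecting char height", unit="pages")]

image_map = load_image_map_from_file(args.color_map) if args.color_map else DEFAULT_IMAGE_MAP
```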