Fix XML output path and add a --strip-extension option to choose which file extension is removed when generating the XML output file name

......@@ -58,6 +58,8 @@ def main():
help="target directory for PageXML output")
parser.add_argument("-O", "--render-output-dir", type=str, default=None,
help="target directory for rendered output")
parser.add_argument("--strip-extension", type=str, default=None,
help="Remove this extension from the file name to generate the xml output file name (default: everything from last dot)")
parser.add_argument("--load", "--model", type=str, default=None, dest="model",
help="load an existing model")
parser.add_argument("--image-map", "--color_map", type=str, default=None, help="color_map to load",
......@@ -78,7 +80,7 @@ def main():
results = segment_existing(args, image_map, rev_image_map)
for result in results:
create_pagexml(result, args.xml_output_dir)
create_pagexml(result, args.xml_output_dir, args.strip_extension)
if args.render:
render_method = {
'xycut': render_xycut,
......@@ -168,7 +170,7 @@ def post_process_args(args, parser):
and (args.existing_preds_color or args.existing_preds_inverted) \
and not (args.existing_preds_color and args.existing_preds_inverted and args.binary):
return parser.error("Morphology method requires binaries and both existing predictions.\n"
"If you want to create new predictions, do not pass -e or -c.")
"If you want to create new predictions, do not pass -e or -c.")
elif args.binary:
args.binary_paths = sorted(glob_all(args.binary))
args.image_paths = sorted(glob_all(args.images)) if args.images else args.binary_paths
......@@ -207,7 +209,6 @@ def post_process_args(args, parser):
return image_map, rev_image_map
def predict_masks(output: Optional[str],
data: SingleData,
color_map: dict,
......@@ -216,7 +217,6 @@ def predict_masks(output: Optional[str],
post_processors: Optional[List[Callable[[np.ndarray, SingleData], np.ndarray]]] = None,
gpu_allow_growth: bool = False,
) -> Masks:
settings = PredictSettings(
......@@ -255,7 +255,7 @@ def create_predictions(model, image_path, binary_path, char_height, target_line_
def create_pagexml(result: SegmentationResult, output_dir: Optional[str] = None):
def create_pagexml(result: SegmentationResult, output_dir: Optional[str] = None, strip_extension: Optional[str] = None):
import pypagexml as pxml
meta = pxml.ds.MetadataTypeSub(Creator="ocr4all_pixel_classifier_frontend", Created=pxml.ds.iso_now())
doc = pxml.new_document_from_image(result.path, meta)
......@@ -274,10 +274,14 @@ def create_pagexml(result: SegmentationResult, output_dir: Optional[str] = None)
add_segment(imageseg, i, "ir", ImageRegionTypeSub, doc.get_Page().add_ImageRegion)
if output_dir is None:
output_file = os.path.splitext(result.path)[0] + ".xml"
output_dir = os.path.dirname(result.path)
if strip_extension is None:
output_file = os.path.splitext(os.path.basename(result.path))[0] + ".xml"
output_file = os.path.join(output_dir, os.path.splitext(os.path.basename(output_dir))[0] + ".xml")
doc.saveAs(result.path + ".xml", level=0)
output_file = os.path.basename(result.path).replace(strip_extension, "") + ".xml"
output_path = os.path.join(output_dir, output_file)
doc.saveAs(output_path, level=0)
if __name__ == "__main__":
