Commit e26adf52 authored by Norbert Fischer's avatar Norbert Fischer
Browse files

initial commit

parents
import numpy as np
import cv2
import xml.etree.ElementTree as ET
import re
import os
import enum
from dataclasses import dataclass
import random
from math import floor
import sys
from collections import namedtuple
def get_all_files_in_directory(dir):
return sorted([f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))])
def get_all_files_in_directory_with_path(dir):
files = get_all_files_in_directory(dir)
return [os.path.join(dir, f) for f in files]
@dataclass
class PageXMLBinaryPair:
binary_file: str
page_xml_file: str
class PageXMLDataset:
def __init__(self, binary_dir, page_xml_dir, pairs=None):
if pairs:
self.pairs = pairs
self.binary_dir = None
self.page_xml_dir = None
else:
self.binary_dir = binary_dir
self.page_xml_dir = page_xml_dir
bins = get_all_files_in_directory(binary_dir)
pagexmls = filter(lambda x: x.endswith(".xml"), get_all_files_in_directory(page_xml_dir))
self.pairs = []
for pagexml in pagexmls:
# get the basename
pagexml_basename = os.path.basename(pagexml)
# remove the extension
pagexml_without_ext = re.sub(".xml$", "", pagexml_basename)
# find a matching binary image
pagexml_without_ext_with_dot = pagexml_without_ext + "."
for bin in bins:
if os.path.basename(bin).startswith(pagexml_without_ext_with_dot):
self.pairs.append(PageXMLBinaryPair(os.path.join(binary_dir, bin), os.path.join(page_xml_dir, pagexml)))
break
else:
print("Found no match for {}".format(pagexml), file=sys.stderr)
def __len__(self):
return len(self.pairs)
def __iter__(self):
return iter(self.pairs)
def shuffle(self, rng=random.Random()):
rng.shuffle(self.pairs)
return self
def split(self, perc):
assert perc >= 0.0 and perc <= 1.0
training_amount = floor(perc * len(self.pairs))
tr_ds = PageXMLDataset(None, None, self.pairs[:training_amount])
te_ds = PageXMLDataset(None, None, self.pairs[training_amount:])
print("Splitting Dataset: {} training files, {} test files".format(len(tr_ds), len(te_ds)))
return tr_ds, te_ds
def indiv_files(self):
for x in self.pairs:
yield PageXMLDataset(None, None, [x])
@dataclass
class PageXMLPredictionFilePair:
prediction_file: str
page_xml_file: str
@staticmethod
def from_dirs(prediction_dir: str, pagexml_dir: str):
bins = get_all_files_in_directory(prediction_dir)
pagexmls = filter(lambda x: x.endswith(".xml"), get_all_files_in_directory(pagexml_dir))
pairs = []
for pagexml in pagexmls:
# get the basename
pagexml_basename = os.path.basename(pagexml)
# remove the extension
pagexml_without_ext = re.sub(".xml$", "", pagexml_basename)
# find a matching binary image
pagexml_without_ext_with_dot = pagexml_without_ext + "."
for bin in bins:
if os.path.basename(bin).startswith(pagexml_without_ext_with_dot):
pairs.append(PageXMLPredictionFilePair(os.path.join(prediction_dir, bin), os.path.join(pagexml_dir, pagexml)))
break
else:
print("Found no match for {}".format(pagexml), file=sys.stderr)
return pairs
class PageXMLTypes(enum.Enum):
PARAGRAPH ='paragraph'
IMAGE = 'ImageRegion'
HEADING = 'heading'
HEADER = 'header'
CATCH_WORD = 'catch-word'
PAGE_NUMBER = 'page-number'
SIGNATURE_MARK = 'signature-mark'
MARGINALIA = 'marginalia'
OTHER = 'other'
DROP_CAPITAL = 'drop-capital'
FLOATING = 'floating'
CAPTION = 'caption'
ENDNOTE = 'endnote'
IGNORE = 'ignore'
TOCENTRY = 'toc-entry'
FOOTNOTE = 'footnote'
FOOTNOTE_CONTINUED = 'footnote-continued'
FOOTER = 'footer'
EMPTY = ''
def color(self):
return {
PageXMLTypes.PARAGRAPH: (255, 0, 0),
PageXMLTypes.IMAGE: (0, 255, 0),
PageXMLTypes.HEADING: (0, 0, 255),
PageXMLTypes.HEADER: (0, 255, 255),
PageXMLTypes.CATCH_WORD: (255, 255, 0),
PageXMLTypes.PAGE_NUMBER: (255, 0, 255),
PageXMLTypes.SIGNATURE_MARK: (128, 0, 128),
PageXMLTypes.MARGINALIA: (128, 128, 0),
PageXMLTypes.OTHER: (0, 128, 128),
PageXMLTypes.DROP_CAPITAL: (255, 128, 0),
PageXMLTypes.FLOATING: (255, 0, 128),
PageXMLTypes.CAPTION: (128, 255, 0),
PageXMLTypes.ENDNOTE: (0, 255, 128),
PageXMLTypes.IGNORE: (0, 128, 0),
PageXMLTypes.TOCENTRY: (0, 127, 0),
PageXMLTypes.FOOTNOTE: (0, 126, 0),
PageXMLTypes.FOOTNOTE_CONTINUED: (0, 125, 0),
PageXMLTypes.FOOTER: (0, 123, 0),
PageXMLTypes.EMPTY: (0, 124, 0),
}[self]
def is_text(self):
return self is not PageXMLTypes.IMAGE and \
self is not PageXMLTypes.DROP_CAPITAL and \
self is not PageXMLTypes.IGNORE and \
self is not PageXMLTypes.CAPTION
PageRegionBB = namedtuple('PageRegionBB', 'x1 y1 x2 y2 type')
@dataclass
class PageRegion:
polygon: np.array
type: PageXMLTypes
is_text: bool
def bounding_box(self):
xmin = int(np.min(self.polygon[:,0]))
xmax = int(np.max(self.polygon[:,0]))
ymin = int(np.min(self.polygon[:,1]))
ymax = int(np.max(self.polygon[:,1]))
return PageRegionBB(xmin, ymin, xmax, ymax, type)
def shifted(self, x, y):
# shift the region points by x and y and return as new region
new_polygon = self.polygon + np.array([x, y])
return PageRegion(new_polygon, self.type, self.is_text)
def coords_str(self):
# return the coords string as it is written in the pagexml
parts = []
for row in self.polygon:
parts.append("{},{}".format(row[0], row[1]))
return " ".join(parts)
def has_negative_coords(self):
return np.min(self.polygon) < 0
"""
def plot_region(img, coords, color):
points = []
for match in _coords_regex.finditer(coords):
x, y = int(match.group(1)), int(match.group(2))
points.append((x, y))
# fill the poly
return cv2.fillPoly(img, np.array([points]), color)
"""
class PageXMLParser:
_coords_regex = re.compile(r"([0-9]+),([0-9]+)")
def _polygon_from_coords_str(self, coords_str: str) -> np.array:
points = []
for match in PageXMLParser._coords_regex.finditer(coords_str):
x, y = int(match.group(1)), int(match.group(2))
if self.do_rescale:
x = int(x * self.rescale_factor_x) # TODO: maybe use round() here !?!
y = int(y * self.rescale_factor_y)
points.append((x, y))
return np.array(points)
def _parse_region(self, element, namespace):
coords = element.find(namespace + "Coords")
polygon = None
if "points" in coords.attrib:
polygon = self._polygon_from_coords_str(str(coords.attrib["points"]))
else:
# hopefully there is a
point_elems = coords.findall(namespace + "Point")
points = []
if self.do_rescale:
for elem in point_elems:
points.append((int(float(elem.attrib["x"]) * self.rescale_factor_x), int(float(elem.attrib["y"]) * self.rescale_factor_y)))
# TODO: maybe round here !?!
else:
for elem in point_elems:
points.append((int(elem.attrib["x"]), int(elem.attrib["y"])))
polygon = np.array(points)
polygon.shape[0] != 0 and polygon.shape[1] == 2 and \
len(polygon.shape) == 2, \
"Invalid polygon. Maybe wrong xml format?"
is_text = False
if element.tag.endswith("TextRegion"):
if str(element.get("type")) not in set(["drop-capital"]):
is_text = True
# deduct the type from the element
if element.get("type") is None:
if not is_text:
type = PageXMLTypes.IMAGE
else:
print("Region without type attribute.. Assuming base type")
type = PageXMLTypes.PARAGRAPH
else:
type = PageXMLTypes(element.get('type'))
if is_text and not type.is_text():
is_text = False
print(is_text, type, element.get("type"))
return PageRegion(polygon=polygon, type=type, is_text=is_text)
@staticmethod
def from_file(filename, rescale=None):
with open(filename) as f:
return PageXMLParser(f.read(), filename, rescale=rescale)
def __init__(self, pagexml, filename=None, rescale=None):
tree = ET.ElementTree(ET.fromstring(pagexml))
root = tree.getroot()
namespace = root.tag.split('}')[0] + "}"
page = root.find(namespace + "Page")
# gather some things from the PageXML header
self.image_filename, self.image_height, self.image_width = \
str(page.attrib["imageFilename"]), \
int(page.attrib["imageHeight"]), \
int(page.attrib["imageWidth"])
# really stupid hack, but source data is wrong :(
if self.image_width > self.image_height:
print("Image width greater than height.. Skipping rotate.. but output data might be wrong")
#self.image_width, self.image_height = self.image_height, self.image_width
if rescale is None or (rescale[0] == self.image_width and rescale[1] == self.image_height):
self.do_rescale = False
self.rescale_factor_x = 1
self.rescale_factor_y = 1
else:
self.do_rescale = True
self.rescale_factor_x = rescale[0] / self.image_width
self.rescale_factor_y = rescale[1] / self.image_height
self.image_width = rescale[0]
self.image_height = rescale[1]
self.regions = []
for text_region in page.findall(namespace + "TextRegion"):
try:
self.regions.append(self._parse_region(text_region, namespace))
except AssertionError as e:
print("Error: {}".format(e))
print("Cannot parse TextRegion in file. Continuing. Filename: {}".format(filename))
for image_region in page.findall(namespace + "ImageRegion"):
try:
self.regions.append(self._parse_region(image_region, namespace))
except Exception as e:
print("Error: {}".format(e))
print("Cannot parse ImageRegion in file. Continuing. Filename: {}".format(filename if filename else "?"))
def regions_with_type(self, type: PageXMLTypes):
# filter by type
return list(filter(lambda x: x.type is type, self.regions))
def get_mask_image(self, binary=False, glyph_color=(0,255,0), other_color=(255,0,255)):
target_image = np.full((self.image_height, self.image_width, 3), 255, dtype=np.uint8)
for region in self.regions:
color = glyph_color if region.is_text else other_color
cv2.fillPoly(target_image, np.array([region.polygon]),color)
return target_image
def get_labeled_image(self):
# label the image for each region
target_image = np.zeros((self.image_height,self.image_width),dtype=np.uint16)
for i, region in enumerate(self.regions):
cv2.fillPoly(target_image, np.array([region.polygon]), [i+1])
return target_image
def contains_images(self):
for x in self.regions:
if x.is_text is False:
return True
return False
def __len__(self):
return len(self.regions)
def __iter__(self):
return iter(self.regions)
def build_pagexml(binary_filename, image_shape, regions):
header = """<?xml version="1.0" encoding="UTF-8"?>
<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15/pagecontent.xsd">
<Metadata>
<Creator />
<Created>2019-12-08T01:36:01</Created>
<LastChange>1970-01-01T01:00:00</LastChange>
<Comments />
</Metadata>
"""
page_section = '<Page imageFilename="{}" imageHeight="{}" imageWidth="{}">'.format(binary_filename, image_shape[0], image_shape[1])
text_region_template = '<TextRegion id="{}" type="{}"><Coords points="{}" /><TextEquiv><Unicode /></TextEquiv></TextRegion>'
image_region_template = '<ImageRegion id="{}"><Coords points="{}" /></ImageRegion>'
region_strs = []
for rid, region in enumerate(regions, start=1):
if region.is_text:
region_strs.append(text_region_template.format("r{}".format(rid), str(region.type.value), region.coords_str()))
else:
region_strs.append(image_region_template.format("r{}".format(rid), region.coords_str()))
footer = '</Page></PcGts>'
return header + page_section + " ".join(region_strs) + footer
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment