Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Norbert Fischer
pagexml
Commits
e26adf52
Commit
e26adf52
authored
Aug 07, 2020
by
Norbert Fischer
Browse files
initial commit
parents
Changes
1
Hide whitespace changes
Inline
Side-by-side
pagexml.py
0 → 100644
View file @
e26adf52
import
numpy
as
np
import
cv2
import
xml.etree.ElementTree
as
ET
import
re
import
os
import
enum
from
dataclasses
import
dataclass
import
random
from
math
import
floor
import
sys
from
collections
import
namedtuple
def get_all_files_in_directory(dir):
    """Return the sorted names of the regular files located directly in ``dir``.

    Only bare file names are returned (no directory prefix); entries for
    which ``os.path.isfile`` is false, such as sub-directories, are left out.
    """
    names = []
    for entry in os.listdir(dir):
        candidate = os.path.join(dir, entry)
        if os.path.isfile(candidate):
            names.append(entry)
    names.sort()
    return names
def get_all_files_in_directory_with_path(dir):
    """Like ``get_all_files_in_directory`` but with ``dir`` joined onto each name.

    The result keeps the sorted order produced by ``get_all_files_in_directory``.
    """
    joined = []
    for name in get_all_files_in_directory(dir):
        joined.append(os.path.join(dir, name))
    return joined
@dataclass
class PageXMLBinaryPair:
    """Associates one binary image file with the PAGE XML file of the same base name."""
    # Path of the image file (binary directory joined with the file name).
    binary_file: str
    # Path of the PAGE XML file (page-xml directory joined with the file name).
    page_xml_file: str
class PageXMLDataset:
    """A list of ``PageXMLBinaryPair`` items, built from two directories or given directly.

    Attributes:
        pairs: the list of pairs held by this dataset.
        binary_dir: directory that was scanned for images, or ``None`` when
            the pairs were passed in explicitly.
        page_xml_dir: directory that was scanned for ``*.xml`` files, or
            ``None`` when the pairs were passed in explicitly.
    """

    def __init__(self, binary_dir, page_xml_dir, pairs=None):
        """Create a dataset.

        Args:
            binary_dir: directory containing the image files. Ignored when
                ``pairs`` is given.
            page_xml_dir: directory containing the PAGE XML files. Ignored
                when ``pairs`` is given.
            pairs: optional ready-made list of pairs. It is stored as-is
                (not copied).

        For every ``<name>.xml`` in ``page_xml_dir`` the first file (in sorted
        order) in ``binary_dir`` whose name starts with ``<name>.`` is paired
        with it; XML files without a match are reported on stderr and skipped.
        """
        # Compare against None instead of relying on truthiness: an explicitly
        # passed empty list (as produced by split(0.0) / split(1.0)) must not
        # fall through to the directory scan with None directories.
        if pairs is not None:
            self.pairs = pairs
            self.binary_dir = None
            self.page_xml_dir = None
            return

        self.binary_dir = binary_dir
        self.page_xml_dir = page_xml_dir

        image_names = get_all_files_in_directory(binary_dir)
        xml_names = [
            name for name in get_all_files_in_directory(page_xml_dir)
            if name.endswith(".xml")
        ]

        self.pairs = []
        for xml_name in xml_names:
            # "<name>.xml" -> "<name>." ; the trailing dot prevents "page1"
            # from matching "page10.png".
            stem = re.sub(r"\.xml$", "", os.path.basename(xml_name))
            wanted_prefix = stem + "."
            for image_name in image_names:
                if os.path.basename(image_name).startswith(wanted_prefix):
                    self.pairs.append(PageXMLBinaryPair(
                        os.path.join(binary_dir, image_name),
                        os.path.join(page_xml_dir, xml_name)))
                    break
            else:
                print("Found no match for {}".format(xml_name), file=sys.stderr)

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __iter__(self):
        """Iterate over the pairs in their current order."""
        return iter(self.pairs)

    def shuffle(self, rng=random.Random()):
        """Shuffle the pairs in place with ``rng`` and return ``self``.

        NOTE: the default ``rng`` is a single ``random.Random`` instance
        created once when the class body is executed and shared by all calls
        that do not pass their own generator.
        """
        rng.shuffle(self.pairs)
        return self

    def split(self, perc):
        """Split into two datasets at ``floor(perc * len(self))``.

        Args:
            perc: fraction in ``[0.0, 1.0]`` of pairs that go into the first
                (training) dataset.

        Returns:
            ``(training_dataset, test_dataset)``; either may be empty.
        """
        assert 0.0 <= perc <= 1.0
        training_amount = floor(perc * len(self.pairs))
        tr_ds = PageXMLDataset(None, None, self.pairs[:training_amount])
        te_ds = PageXMLDataset(None, None, self.pairs[training_amount:])
        print("Splitting Dataset: {} training files, {} test files".format(len(tr_ds), len(te_ds)))
        return tr_ds, te_ds

    def indiv_files(self):
        """Yield one single-pair dataset per pair, in order."""
        for pair in self.pairs:
            yield PageXMLDataset(None, None, [pair])
@dataclass
class PageXMLPredictionFilePair:
    """Associates one prediction file with the PAGE XML file of the same base name."""
    # Path of the prediction file (prediction directory joined with the file name).
    prediction_file: str
    # Path of the PAGE XML file (pagexml directory joined with the file name).
    page_xml_file: str

    @staticmethod
    def from_dirs(prediction_dir: str, pagexml_dir: str):
        """Build the list of pairs for two directories.

        For every ``<name>.xml`` in ``pagexml_dir`` the first file (in sorted
        order) of ``prediction_dir`` whose name starts with ``<name>.`` is
        paired with it. XML files without a match are reported on stderr and
        skipped.

        Returns:
            A list of ``PageXMLPredictionFilePair``.
        """
        candidates = get_all_files_in_directory(prediction_dir)
        xml_names = get_all_files_in_directory(pagexml_dir)

        matched = []
        for xml_name in xml_names:
            if not xml_name.endswith(".xml"):
                continue
            # strip the extension, then re-append the dot so that only
            # "<name>.<something>" files are accepted as a match
            prefix = re.sub(".xml$", "", os.path.basename(xml_name)) + "."
            hit = next(
                (c for c in candidates if os.path.basename(c).startswith(prefix)),
                None,
            )
            if hit is None:
                print("Found no match for {}".format(xml_name), file=sys.stderr)
                continue
            matched.append(PageXMLPredictionFilePair(
                os.path.join(prediction_dir, hit),
                os.path.join(pagexml_dir, xml_name)))
        return matched
class PageXMLTypes(enum.Enum):
    """Region types used in this module; each value is the string found in the XML.

    ``IMAGE`` carries the element name ``'ImageRegion'``; ``EMPTY`` stands
    for an empty ``type`` attribute.
    """
    PARAGRAPH = 'paragraph'
    IMAGE = 'ImageRegion'
    HEADING = 'heading'
    HEADER = 'header'
    CATCH_WORD = 'catch-word'
    PAGE_NUMBER = 'page-number'
    SIGNATURE_MARK = 'signature-mark'
    MARGINALIA = 'marginalia'
    OTHER = 'other'
    DROP_CAPITAL = 'drop-capital'
    FLOATING = 'floating'
    CAPTION = 'caption'
    ENDNOTE = 'endnote'
    IGNORE = 'ignore'
    TOCENTRY = 'toc-entry'
    FOOTNOTE = 'footnote'
    FOOTNOTE_CONTINUED = 'footnote-continued'
    FOOTER = 'footer'
    EMPTY = ''

    def color(self):
        """Return the 3-tuple of ints assigned to this type (distinct per type)."""
        palette = {
            'PARAGRAPH': (255, 0, 0),
            'IMAGE': (0, 255, 0),
            'HEADING': (0, 0, 255),
            'HEADER': (0, 255, 255),
            'CATCH_WORD': (255, 255, 0),
            'PAGE_NUMBER': (255, 0, 255),
            'SIGNATURE_MARK': (128, 0, 128),
            'MARGINALIA': (128, 128, 0),
            'OTHER': (0, 128, 128),
            'DROP_CAPITAL': (255, 128, 0),
            'FLOATING': (255, 0, 128),
            'CAPTION': (128, 255, 0),
            'ENDNOTE': (0, 255, 128),
            'IGNORE': (0, 128, 0),
            'TOCENTRY': (0, 127, 0),
            'FOOTNOTE': (0, 126, 0),
            'FOOTNOTE_CONTINUED': (0, 125, 0),
            'FOOTER': (0, 123, 0),
            'EMPTY': (0, 124, 0),
        }
        return palette[self.name]

    def is_text(self):
        """Return ``False`` for IMAGE, DROP_CAPITAL, IGNORE and CAPTION, else ``True``."""
        non_text = (
            PageXMLTypes.IMAGE,
            PageXMLTypes.DROP_CAPITAL,
            PageXMLTypes.IGNORE,
            PageXMLTypes.CAPTION,
        )
        return self not in non_text
# Axis-aligned bounding box of a region: (x1, y1) is the minimum corner,
# (x2, y2) the maximum corner, `type` the type of the region it was taken from.
PageRegionBB = namedtuple('PageRegionBB', 'x1 y1 x2 y2 type')


@dataclass
class PageRegion:
    """A polygonal page region together with its type.

    The methods index ``polygon`` as ``polygon[:, 0]`` (x) and
    ``polygon[:, 1]`` (y), i.e. they expect one ``(x, y)`` row per vertex.
    """
    # Vertex array, one (x, y) row per point.
    polygon: np.ndarray
    # Region type (quoted: forward reference to the enum defined in this module).
    type: "PageXMLTypes"
    # Whether the region is treated as text.
    is_text: bool

    def bounding_box(self):
        """Return the ``PageRegionBB`` enclosing the polygon, as Python ints.

        The box carries this region's ``type``.
        """
        xs = self.polygon[:, 0]
        ys = self.polygon[:, 1]
        # self.type, not the bare name `type`: inside a method the bare name
        # resolves to the builtin `type`, not to the dataclass field.
        return PageRegionBB(
            int(np.min(xs)), int(np.min(ys)),
            int(np.max(xs)), int(np.max(ys)),
            self.type,
        )

    def shifted(self, x, y):
        """Return a new region whose points are moved by ``(x, y)``.

        The original region is left unchanged.
        """
        new_polygon = self.polygon + np.array([x, y])
        return PageRegion(new_polygon, self.type, self.is_text)

    def coords_str(self):
        """Return the points as ``"x1,y1 x2,y2 ..."``, the pagexml coords format."""
        return " ".join("{},{}".format(row[0], row[1]) for row in self.polygon)

    def has_negative_coords(self):
        """Return ``True`` if any coordinate of the polygon is below zero."""
        return bool(np.min(self.polygon) < 0)
"""
def plot_region(img, coords, color):
points = []
for match in _coords_regex.finditer(coords):
x, y = int(match.group(1)), int(match.group(2))
points.append((x, y))
# fill the poly
return cv2.fillPoly(img, np.array([points]), color)
"""
class PageXMLParser:
    """Parse a PAGE XML document into a list of ``PageRegion`` objects.

    Attributes set by ``__init__``:
        image_filename, image_height, image_width: taken from the ``Page``
            element; width/height are replaced by the target size when
            rescaling is active.
        do_rescale, rescale_factor_x, rescale_factor_y: whether and by how
            much parsed coordinates are scaled.
        regions: parsed ``TextRegion`` elements followed by parsed
            ``ImageRegion`` elements (direct children of ``Page`` only).
    """

    # One "x,y" pair of unsigned integers inside a Coords "points" attribute.
    # NOTE: a leading minus sign is not part of the match.
    _coords_regex = re.compile(r"([0-9]+),([0-9]+)")

    def _polygon_from_coords_str(self, coords_str: str) -> np.ndarray:
        """Convert a ``"x1,y1 x2,y2 ..."`` string into an array of (x, y) rows.

        Coordinates are multiplied by the rescale factors and truncated with
        ``int()`` when ``self.do_rescale`` is set.
        """
        points = []
        for match in PageXMLParser._coords_regex.finditer(coords_str):
            x, y = int(match.group(1)), int(match.group(2))
            if self.do_rescale:
                # TODO: maybe use round() here !?!
                x = int(x * self.rescale_factor_x)
                y = int(y * self.rescale_factor_y)
            points.append((x, y))
        return np.array(points)

    def _parse_region(self, element, namespace):
        """Build a ``PageRegion`` from a ``TextRegion``/``ImageRegion`` element.

        Args:
            element: the region element.
            namespace: ``"{uri}"`` prefix (possibly empty) used for child lookups.

        Raises:
            AssertionError: if the region has no usable polygon (no points,
                or points that are not 2-D pairs).
            ValueError: if the ``type`` attribute is not a ``PageXMLTypes`` value.
        """
        coords = element.find(namespace + "Coords")
        if "points" in coords.attrib:
            polygon = self._polygon_from_coords_str(str(coords.attrib["points"]))
        else:
            # Fallback: coordinates given as individual <Point x= y=> children.
            points = []
            for elem in coords.findall(namespace + "Point"):
                if self.do_rescale:
                    # TODO: maybe round here !?!
                    px = int(float(elem.attrib["x"]) * self.rescale_factor_x)
                    py = int(float(elem.attrib["y"]) * self.rescale_factor_y)
                else:
                    px = int(elem.attrib["x"])
                    py = int(elem.attrib["y"])
                points.append((px, py))
            polygon = np.array(points)

        # Explicit raise (not an `assert` statement) so the check also runs
        # under `python -O`; AssertionError is what __init__ catches to skip
        # a broken region. The rank is tested first because an empty point
        # list yields a 1-D array without a second axis.
        if polygon.ndim != 2 or polygon.shape[0] == 0 or polygon.shape[1] != 2:
            raise AssertionError("Invalid polygon. Maybe wrong xml format?")

        type_attr = element.get("type")

        # A TextRegion counts as text unless it is explicitly a drop capital.
        is_text = False
        if element.tag.endswith("TextRegion"):
            if str(type_attr) not in {"drop-capital"}:
                is_text = True

        # deduce the type from the element
        if type_attr is None:
            if not is_text:
                region_type = PageXMLTypes.IMAGE
            else:
                print("Region without type attribute.. Assuming base type")
                region_type = PageXMLTypes.PARAGRAPH
        else:
            region_type = PageXMLTypes(type_attr)
            # the type itself may declare the region to be non-text
            if is_text and not region_type.is_text():
                is_text = False

        return PageRegion(polygon=polygon, type=region_type, is_text=is_text)

    @staticmethod
    def from_file(filename, rescale=None):
        """Parse the PAGE XML file ``filename``; see ``__init__`` for ``rescale``."""
        # Read bytes, not text: the XML parser then applies the encoding named
        # in the document's own declaration instead of the platform default.
        with open(filename, "rb") as f:
            return PageXMLParser(f.read(), filename, rescale=rescale)

    def __init__(self, pagexml, filename=None, rescale=None):
        """Parse a PAGE XML document.

        Args:
            pagexml: the XML document (``str`` or ``bytes``).
            filename: only used in diagnostic messages.
            rescale: optional ``(width, height)`` target size. When it differs
                from the size stored in the document, all coordinates are
                scaled to it and ``image_width``/``image_height`` are set to
                the target size.
        """
        root = ET.fromstring(pagexml)
        # ElementTree spells qualified tags "{uri}local"; keep the "{uri}"
        # part as lookup prefix, or use no prefix for a namespace-less root.
        namespace = (root.tag.split('}')[0] + "}") if '}' in root.tag else ""

        page = root.find(namespace + "Page")

        # gather some things from the PageXML header
        self.image_filename = str(page.attrib["imageFilename"])
        self.image_height = int(page.attrib["imageHeight"])
        self.image_width = int(page.attrib["imageWidth"])

        # really stupid hack, but source data is wrong :(
        if self.image_width > self.image_height:
            print("Image width greater than height.. Skipping rotate.. but output data might be wrong")
            #self.image_width, self.image_height = self.image_height, self.image_width

        same_size = rescale is not None and (
            rescale[0] == self.image_width and rescale[1] == self.image_height)
        if rescale is None or same_size:
            self.do_rescale = False
            self.rescale_factor_x = 1
            self.rescale_factor_y = 1
        else:
            self.do_rescale = True
            self.rescale_factor_x = rescale[0] / self.image_width
            self.rescale_factor_y = rescale[1] / self.image_height
            self.image_width = rescale[0]
            self.image_height = rescale[1]

        self.regions = []

        # Text regions: only invalid polygons are skipped, other errors propagate.
        for text_region in page.findall(namespace + "TextRegion"):
            try:
                self.regions.append(self._parse_region(text_region, namespace))
            except AssertionError as e:
                print("Error: {}".format(e))
                print("Cannot parse TextRegion in file. Continuing. Filename: {}".format(filename))

        # Image regions: best effort, any failure skips the region.
        for image_region in page.findall(namespace + "ImageRegion"):
            try:
                self.regions.append(self._parse_region(image_region, namespace))
            except Exception as e:
                print("Error: {}".format(e))
                print("Cannot parse ImageRegion in file. Continuing. Filename: {}".format(
                    filename if filename else "?"))

    def regions_with_type(self, type: "PageXMLTypes"):
        """Return the list of regions whose ``type`` is exactly ``type``."""
        # filter by type
        return [region for region in self.regions if region.type is type]

    def get_mask_image(self, binary=False, glyph_color=(0, 255, 0), other_color=(255, 0, 255)):
        """Render all regions as filled polygons on a white 3-channel uint8 image.

        Text regions are filled with ``glyph_color``, all others with
        ``other_color``. ``binary`` is accepted but currently not used.
        """
        target_image = np.full((self.image_height, self.image_width, 3), 255, dtype=np.uint8)
        for region in self.regions:
            color = glyph_color if region.is_text else other_color
            cv2.fillPoly(target_image, np.array([region.polygon]), color)
        return target_image

    def get_labeled_image(self):
        """Return a uint16 image in which region ``i`` is filled with ``i + 1``.

        Pixels outside every region stay 0; later regions overwrite earlier ones.
        """
        # label the image for each region
        target_image = np.zeros((self.image_height, self.image_width), dtype=np.uint16)
        for i, region in enumerate(self.regions):
            cv2.fillPoly(target_image, np.array([region.polygon]), [i + 1])
        return target_image

    def contains_images(self):
        """Return ``True`` if at least one region has ``is_text`` set to ``False``."""
        return any(region.is_text is False for region in self.regions)

    def __len__(self):
        """Return the number of parsed regions."""
        return len(self.regions)

    def __iter__(self):
        """Iterate over the parsed regions."""
        return iter(self.regions)
def build_pagexml(binary_filename, image_shape, regions):
    """Serialise regions into a PAGE XML document and return it as a string.

    Args:
        binary_filename: value written to the ``imageFilename`` attribute.
        image_shape: indexable whose item 0 is written as ``imageHeight`` and
            item 1 as ``imageWidth``.
        regions: iterable of objects providing ``is_text``, ``coords_str()``
            and, for text regions, ``type.value``. Text regions become
            ``TextRegion`` elements, all others ``ImageRegion`` elements;
            ids are ``r1``, ``r2``, ... in iteration order.

    The metadata block (including the ``Created``/``LastChange`` timestamps)
    is a fixed template.
    """
    # Local import keeps this function self-contained.
    from xml.sax.saxutils import escape

    header = """<?xml version="1.0" encoding="UTF-8"?>
<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2017-07-15/pagecontent.xsd">
<Metadata>
<Creator />
<Created>2019-12-08T01:36:01</Created>
<LastChange>1970-01-01T01:00:00</LastChange>
<Comments />
</Metadata>
"""

    # The file name is free text: escape the XML specials and the double quote
    # so that it cannot break out of the attribute value.
    safe_filename = escape(str(binary_filename), {'"': "&quot;"})
    page_section = '<Page imageFilename="{}" imageHeight="{}" imageWidth="{}">'.format(
        safe_filename, image_shape[0], image_shape[1])

    text_region_template = '<TextRegion id="{}" type="{}"><Coords points="{}" /><TextEquiv><Unicode /></TextEquiv></TextRegion>'
    image_region_template = '<ImageRegion id="{}"><Coords points="{}" /></ImageRegion>'

    region_strs = []
    for rid, region in enumerate(regions, start=1):
        region_id = "r{}".format(rid)
        if region.is_text:
            region_strs.append(text_region_template.format(
                region_id, str(region.type.value), region.coords_str()))
        else:
            region_strs.append(image_region_template.format(
                region_id, region.coords_str()))

    footer = '</Page></PcGts>'
    return header + page_section + " ".join(region_strs) + footer
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment