Commit c922308d authored by Tom

Added dynamic programming line segmenter written in Python (obsoleting ocrolseg).

Added text line cleanup prior to text line recognition.
parent 5a4dbdbd
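For orientation, here is a minimal sketch of how the new pieces are wired together, assembled from the CmodelLineRecognizer changes further down in this diff. The import path is assumed and the loader read_line_image is hypothetical; the cleanup calls and the segmenter class come from the code below.

from pylab import *                          # amax etc., as in the new modules
from ocrolib import common, lineseg          # assumed package layout

image = read_line_image("line.png")          # hypothetical loader: gray line, dark text on white
image = image*1.0/amax(image)                # normalize to [0,1]
# text line cleanup prior to recognition (new in this commit)
image = 1-image
image = common.latin_filter(image)           # keep components that touch the text body
image = common.remove_noise(image,8)         # drop components smaller than 8 pixels
image = 1-image
# dynamic programming line segmenter (replaces ocrolseg.DpSegmenter)
seg = lineseg.DPSegmentLine().charseg(image)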
(The diff for one file is collapsed because it is too large to display.)
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from optparse import OptionParser
import sys,os,re,unicodedata
from math import log
import openfst
parser = OptionParser("""
usage: %prog [options] textfile ...

build a word n-gram FST (with character-labeled arcs) from UTF-8 text files; written to ngram.fst by default
""")
parser.add_option("-u","--utf8",help="words only",action="store_true")
parser.add_option("-w","--words",help="words only",action="store_true")
parser.add_option("-W","--wordfst",help="put words on arcs (for debugging)",action="store_true")
parser.add_option("-v","--verbose",help="verbose",action="store_true")
parser.add_option("-n","--n",help="n-gram",type=int,default=3)
parser.add_option("-o","--output",help="output file",default="ngram.fst")
parser.add_option("-m","--minimize",help="perform minimization",action="store_true")
(options,args) = parser.parse_args()
if len(args)<1:
parser.print_usage()
sys.exit(0)
n = options.n
class Counter(dict):
def __getitem__(self,index):
return self.setdefault(index,0)
words = Counter()
ngrams = Counter()
chars = Counter()
files = []
for file in args:
s = open(file).read().decode("utf-8")
print "# file:",file,"chars:",len(s)
s = s.strip()
files.append(s)
line = u" ".join(files)
line += u" "
line = re.sub(ur"\s*\n\*"," ",line)
line = re.sub(ur'\s+',' ',line)
line = re.sub(ur'[_#]','',line)
sep_re = re.compile(ur'([^A-Za-z0-9äöüÄÖÜß]+)')
ws = sep_re.split(line)
ngram = [u""]*n
for w in ws:
for c in w: chars[c] += 1
assert len(ngram)==n
del ngram[0]
ngram.append(w)
t = tuple(ngram)
if options.verbose:
print ("["+("|".join(t))+"]").encode("utf-8")
words[t[-1]] += 1
ngrams[t] += 1
print "words",len(words.keys())
print "ngrams",len(ngrams.keys())
print "# creating symbol table"
EPS = 0
symtab = openfst.SymbolTable("chars")
symtab.AddSymbol("EPS",EPS)
for c in chars.keys():
c = unichr(ord(c))
if c=='"':
symtab.AddSymbol("''",ord(c))
elif options.utf8:
symtab.AddSymbol(c.encode("utf-8"),ord(c))
else:
desc = str(c) if ord(c)<128 else "U+%04x"%ord(c)
symtab.AddSymbol(desc,ord(c))
fst = openfst.StdVectorFst()
nstates = 0
states = {}
initial = fst.AddState()
fst.SetStart(initial)
final = fst.AddState()
fst.SetFinal(final,0)
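# Each distinct (n-1)-gram context becomes one FST state.  Every context state
# is connected from the initial state and to the final state by epsilon arcs,
# so a path may begin and end in any context.  Note that the word/ngram counts
# collected above are not used as arc weights yet; all arcs below carry cost 0.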
for t in ngrams.keys():
for s in [t[:-1],t[1:]]:
if not s in states:
state = fst.AddState()
states[s] = state
fst.AddArc(initial,0,0,0,state)
fst.AddArc(state,0,0,0,final)
print "# adding transitions"
def add_sep(frm,to,s):
state = frm
for i in range(len(s)):
next = to if i==len(s)-1 else fst.AddState()
c = ord(s[i])
fst.AddArc(state,c,c,0,next)
# can start and stop anywhere inside a separator
fst.AddArc(initial,0,0,0,state)
fst.AddArc(state,0,0,0,final)
state = next
def add_word(frm,to,s):
state = frm
for i in range(len(s)):
next = to if i==len(s)-1 else fst.AddState()
c = ord(s[i])
cost = 0.0
fst.AddArc(state,c,c,cost,next)
state = next
for ngram in ngrams:
assert len(ngram)==options.n
w = ngram[-1]
s0 = tuple(ngram[:-1])
s1 = tuple(ngram[1:])
frm = states[s0]
to = states[s1]
if options.verbose: print s0,repr(w),s1
if sep_re.search(w):
add_sep(frm,to,w)
else:
add_word(frm,to,w)
if options.minimize:
print "# minimizing"
det = openfst.StdVectorFst()
openfst.Determinize(fst,det)
openfst.Minimize(det)
fst = det
fst.SetInputSymbols(symtab)
fst.SetOutputSymbols(symtab)
fst.Write(options.output)
@@ -49,24 +49,14 @@ def pyargsort(seq,cmp=cmp,key=lambda x:x):
function. Takes an optional cmp."""
return sorted(range(len(seq)),key=lambda x:key(seq.__getitem__(x)),cmp=cmp)
def renumber_labels_by_boxes(a,cmp=cmp,key=lambda x:x,correspondence=0):
"""Renumber the labels of the input array according to some
order on their bounding boxes. If you provide a cmp function,
it is passed the outputs of find_objects for sorting.
The default is lexicographic."""
if cmp=='rlex':
import __builtin__
cmp = lambda x,y: __builtin__.cmp(x[::-1],y[::-1])
assert a.dtype==dtype('B') or a.dtype==dtype('i')
labels = renumber_labels_ordered(a)
objects = flexible_find_objects(labels)
order = array(pyargsort(objects,cmp=cmp,key=key),'i')
assert len(objects)==len(order)
order = concatenate(([0],order+1))
if correspondence:
return order[labels],argsort(order)
else:
return order[labels]
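# Simpler replacement for renumber_labels_by_boxes: relabel the segments of seg
# in left-to-right order of the horizontal centers of their bounding boxes.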
def renumber_by_xcenter(seg):
objects = [(slice(0,0),slice(0,0))]+measurements.find_objects(seg)
def xc(o): return mean((o[1].start,o[1].stop))
xs = array([xc(o) for o in objects])
order = argsort(xs)
segmap = zeros(amax(seg)+1,'i')
for i,j in enumerate(order): segmap[j] = i
return segmap[seg]
def flexible_find_objects(image):
"""Like measurements.find_objects, but tries to
@@ -776,3 +766,84 @@ def simple_classify(model,inputs):
result.append(model.coutputs(inputs[i]))
return result
################################################################
# text line related utilities
################################################################
from scipy.ndimage import filters,morphology
from pylab import imshow
import psegutils
def estimate_baseline(line):
"""Compute the baseline by fitting a polynomial to the gradient.
TODO: use robust fitting, special case very short line, limit parameter ranges"""
line = line*1.0/amax(line)
vgrad = morphology.grey_closing(line,(1,40))
vgrad = filters.gaussian_filter(vgrad,(2,60),(1,0))
if amin(vgrad)>0 or amax(vgrad)<0: raise BadLine()
h,w = vgrad.shape
baseline = fitext(vgrad)
return baseline
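# Straighten a text line: pad the image on top, fit the baseline with
# estimate_baseline, then shift every column vertically so that the fitted
# baseline comes to rest on a fixed row (2*h/3 of the padded image).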
def dewarp_line(line,show=0):
line = line*1.0/amax(line)
line = r_[zeros(line.shape),line]
h,w = line.shape
baseline = estimate_baseline(line)
ys = polyval(baseline,arange(w))
base = 2*h/3
temp = zeros(line.shape)
for x in range(w):
temp[:,x] = interpolation.shift(line[:,x],(base-ys[x]),order=1)
return temp
def estimate_xheight(line,scale=1.0,debug=0):
"""Estimate the x-height of a text line and the row of its baseline."""
line = line*1.0/amax(line)
vgrad = morphology.grey_closing(line,(1,int(scale*40)))
vgrad = filters.gaussian_filter(vgrad,(2,int(scale*60)),(1,0))
if amin(vgrad)>0 or amax(vgrad)<0: raise Exception("bad line")
if debug: imshow(vgrad)
proj = sum(vgrad,1)
proj = filters.gaussian_filter(proj,0.5)
top = argmax(proj)
bottom = argmin(proj)
return bottom-top,bottom
def keep_marked(image,markers):
labels,_ = measurements.label(image)
# imshow(sin(17.1*labels),cmap=cm.jet)   # debugging display only
marked = unique(labels*(markers!=0))
# print marked                           # debugging output only
kept = in1d(labels.ravel(),marked)
return (image!=0)*kept.reshape(*labels.shape)
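# "Latin" text line cleanup: estimate the top and bottom of the text body per
# column from the smoothed vertical gradient, build a mask extending roughly
# r times the local x-height above the baseline, and (in latin_filter) keep
# only the connected components that touch this mask -- intended to strip
# clutter such as fragments of neighboring lines before recognition.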
def latin_kernel(line,scale=1.0,r=1.2,debug=0):
vgrad = morphology.grey_closing(1.0*line,(1,int(scale*40)))
vgrad = filters.gaussian_filter(vgrad,(2,int(scale*60)),(1,0))
tops = argmax(vgrad,0)
bottoms = argmin(vgrad,0)
mask = zeros(line.shape)
xheight = mean(bottoms-tops)
for i in range(len(bottoms)):
d = bottoms[i]-tops[i]
y0 = int(maximum(0,bottoms[i]-r*d))
mask[y0:bottoms[i],i] = 1
return mask
def latin_filter(line,scale=1.0,r=1.2,debug=0):
bin = (line>0.5*amax(line))
mask = latin_kernel(bin,scale=scale,r=r,debug=debug)
mask = psegutils.keep_marked(bin,mask)
mask = filters.maximum_filter(mask,3)
return line*mask
def remove_noise(line,minsize=8):
bin = (line>0.5*amax(line))
labels,n = measurements.label(bin)
sums = measurements.sum(bin,labels,range(n+1))
sums = sums[labels]
good = minimum(bin,1-(sums>0)*(sums<minsize))
return good
@@ -180,6 +180,60 @@ if __name__=="__main__":
draw()
raw_input()
def cairo_render_at(s,loc=None,shape=None,
fontname=None,fontfile=None,size=None,
slant=cairo.FONT_SLANT_NORMAL,
weight=cairo.FONT_WEIGHT_NORMAL,
bg=(0.0,0.0,0.0),fg=(0.9,0.9,0.9)):
"""Render a string using Cairo and the Cairo text rendering interface. Fonts can either be given
as a fontfile or as a fontname. Size should be in pixels (?). You can specify a background and
foreground color as RGB floating point triples. Images are padded by pad pixels on all sides."""
assert loc is not None
assert shape is not None
assert size is not None
w,h = shape
x,y = loc
face = None
if fontfile is not None:
# "/usr/share/fonts/truetype/msttcorefonts/comic.ttf"
if fontfile in facecache:
face = facecache[fontfile]
else:
face = create_cairo_font_face_for_file(fontfile,0)
facecache[fontfile] = face
surface = cairo.ImageSurface(cairo.FORMAT_ARGB32,w,h)
cr = cairo.Context(surface)
if face is not None:
cr.set_font_face(face)
else:
if fontname is None: fontname = "Helvetica"
if type(slant)==str:
if slant[0]=="i": slant = cairo.FONT_SLANT_ITALIC
elif slant[0]=="o": slant = cairo.FONT_SLANT_OBLIQUE
elif slant[0]=="n": slant = cairo.FONT_SLANT_NORMAL
else: raise Exception("bad font slant specification (use n/i/o)")
if type(weight)==str:
if weight[0]=="b": weight = cairo.FONT_WEIGHT_BOLD
elif weight[0]=="n": weight = cairo.FONT_WEIGHT_NORMAL
else: raise Exception("bad font weight specification (use b/n)")
cr.select_font_face(fontname,slant,weight)
if size is not None:
cr.set_font_size(size)
cr.set_source_rgb(*bg)
cr.rectangle(0,0,w,h)
cr.fill()
cr.move_to(x,y)
cr.set_source_rgb(*fg)
cr.show_text(s)
data = surface.get_data()
data = bytearray(data)
a = array(data,'B')
a.shape = (h,w,4)
a = a[:,:,:3]
a = a[:,:,::-1]
return a
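# Illustrative call (the string, geometry, and size are made-up values):
#   a = cairo_render_at(u"Hello",loc=(5,40),shape=(200,50),size=32)
# returns a (50,200,3) RGB array with the text baseline at y=40.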
if __name__=="x__main__":
s = u"hello, world: \u00E4\u0182\u03c0\u4eb0"
subplot(311)
@@ -60,9 +60,9 @@ class Grouper(PyComponent):
def setSegmentation(self,segmentation,cseg=0,preferred=None):
"""Set the line segmentation."""
# reorder the labels by the x center of bounding box
segmentation = common.renumber_labels_by_boxes(segmentation,key=lambda x:mean((x[1].start,x[1].stop)))
segmentation = common.renumber_by_xcenter(segmentation)
if preferred is not None:
preferred = common.renumber_labels_by_boxes(preferred,key=lambda x:mean((x[1].start,x[1].stop)))
preferred = common.renumber_by_xcenter(preferred)
assert amax(segmentation)<32000 and amax(preferred)<32000
combined = ((preferred<<16)|segmentation)
correspondences = [(k>>16,k&0xffff) for k,v in Counter(combined.ravel()).most_common() if v>5]
from pylab import *
from scipy.ndimage import filters,morphology,measurements
import psegutils
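# Dynamic programming cuts.  image holds a per-pixel cut cost, alpha penalizes
# sideways movement, and r bounds the horizontal step per row.  The recurrence is
#   costs[i,x] = min over |k|<=r of  costs[i-1,x-k] + image[i,x] + alpha*|k|
# sources[i,x] stores the signed column offset to the best predecessor in the
# previous row, which dptrack below follows to recover the cheapest cut paths.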
def dpcuts(image,alpha=0.5,r=2):
costs = 9999*ones(image.shape)
costs[0,:] = 0
sources = zeros(image.shape,'i')
for i in range(1,len(costs)):
for k in range(-r,r+1):
ncosts = roll(costs[i-1,:],k)+image[i,:]+alpha*abs(k)
sources[i,:] = where(ncosts<costs[i,:],-k,sources[i,:])
costs[i,:] = where(ncosts<costs[i,:],ncosts,costs[i,:])
return costs,sources
def between(u,v):
u,v = min(u,v),max(u,v)
for i in range(u,v+1):
yield i
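# Trace the cut paths: for each seed column in l, walk from the bottom row to
# the top row following the predecessor offsets in s, marking every visited
# column (between() fills in horizontal jumps so each cut stays connected).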
def dptrack(l,s):
result = zeros(s.shape)
for i in l:
x0 = i
x = i
y = len(s)-1
while y>-1:
x = clip(x,0,result.shape[1]-1)
for j in between(x0,x):
result[y,j] = 1
y -= 1
x0 = x
x += s[y,x]
return result
def dplineseg1(image,imweight=4,bweight=-1,diagweight=1):
cimage = imweight*image - bweight*maximum(0,roll(image,-1,1)-image)
c,s = dpcuts(cimage,alpha=diagweight)
costs = c[-1]
costs = filters.gaussian_filter(costs,1)
mins = find(filters.minimum_filter(costs,8)==costs)
tracks = dptrack(mins,s)
# combo = 3*tracks+cimage
return tracks
def centroid(image):
ys,xs = mgrid[:image.shape[0],:image.shape[1]]
yc = sum(image*ys)/sum(image)
xc = sum(image*xs)/sum(image)
return yc,xc
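# Two-sided DP segmentation: split the line at its vertical centroid, run dpcuts
# downward over the top half and upward over the flipped bottom half, combine
# the two terminal cost rows, negate and smooth them so the cheapest cut columns
# become local maxima, pick well-separated maxima as seeds, and stitch the top
# and bottom tracks back together.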
def dplineseg2(image,imweight=4,bweight=-1,diagweight=1,r=2,debug=0):
yc,xc = centroid(image)
half = int(yc)
cimage = imweight*image-bweight*maximum(0,roll(image,-1,1)-image)
tc,ts = dpcuts(cimage[:half],alpha=diagweight,r=r)
bc,bs = dpcuts(cimage[half:][::-1],alpha=diagweight,r=r)
costs = bc[-1]+tc[-1]
if debug:
clf()
subplot(311); imshow(tc)
subplot(312); imshow(bc)
costs = tc[-1]+bc[-1]
costs = -costs
costs -= amin(costs)
costs = filters.gaussian_filter(costs,1)
costs += 0.01*filters.gaussian_filter(costs,3.0)
mins = (filters.maximum_filter(costs,8)==costs)*(costs>0.3*amax(costs))
l = find(mins)
tt = dptrack(l,ts)
bt = dptrack(l,bs)
tracks = r_[tt,bt[::-1]]
if debug:
subplot(313)
imshow(tracks+0.5*image,interpolation='nearest')
return tracks
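# Connected-component segmentation (counterpart of SegmentLineByGCCS): binarize,
# smear with a wide Gaussian and mark the per-column maximum as a center ridge,
# keep only the ink components that touch the ridge, label those components, and
# spread the labels so that every remaining ink pixel gets a segment label.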
def ccslineseg(image):
image = 1.0*(image>0.3*amax(image))
sigma = 10.0
smooth = filters.gaussian_filter(image,(sigma,3.0*sigma))
center = (smooth==amax(smooth,axis=0)[newaxis,:])
center = filters.maximum_filter(center,(3,3))
center = psegutils.keep_marked(image>0.5,center)
center = filters.maximum_filter(center,(2,2))
center,_ = measurements.label(center)
center = psegutils.spread_labels(center)
center *= image
return center
import psegutils
class SimpleParams:
def info(self,depth=0):
"""Print information about this object."""
pass
def pexists(self,name):
"""Check whether parameter NAME exists."""
return name in dir(self)
def pset(self,name,value):
"""Set parameter NAME to VALUE."""
assert name in dir(self)
self.__dict__[name] = value
def pget(self,name):
"""Get the value of string parameter NAME."""
return self.__dict__.get(name)
def pgetf(self,name):
"""Get the value of floating point parameter NAME."""
return float(self.__dict__.get(name))
import common as ocrolib
class CCSSegmentLine(SimpleParams):
def charseg(self,line):
"""Segment a text line into potential character parts."""
line = (line<0.5*(amax(line)+amin(line)))
seg = ccslineseg(line)
seg = ocrolib.renumber_by_xcenter(seg)
return seg
class DPSegmentLine(SimpleParams):
def __init__(self,ledge=-0.1,imweight=4,bweight=-1,diagweight=0.3,r=1,debug=0):
self.r = r
self.imweight = imweight
self.bweight = bweight
self.diagweight = diagweight
self.debug = debug
self.ledge = ledge
def charseg(self,line):
"""Segment a text line into potential character parts."""
assert mean(line)>0.5*amax(line)
line = amax(line)-line
line = line+self.ledge*maximum(0,roll(line,-1,1)-line)
tracks = dplineseg2(line,imweight=self.imweight,bweight=self.bweight,
diagweight=self.diagweight,debug=self.debug,r=self.r)
tracks = array(tracks<0.5*amax(tracks),'i')
tracks,_ = measurements.label(tracks)
self.tracks = tracks
rsegs = psegutils.spread_labels(tracks)
rsegs = rsegs*(line>0.5*amax(line))
return ocrolib.renumber_by_xcenter(rsegs)
@@ -83,7 +83,7 @@ def keep_marked(image,markers):
labels,_ = measurements.label(image)
marked = unique(labels*(markers!=0))
kept = in1d(labels.ravel(),marked)
return kept.reshape(*labels.shape)
return (image!=0)*kept.reshape(*labels.shape)
def remove_marked(image,markers):
marked = keep_marked(image,markers)
@@ -17,6 +17,7 @@ import ocrolseg
import ocropreproc
import common
import grouper
import lineseg
from pycomp import PyComponent
from ocroio import renumber_labels
from pylab import *
@@ -217,8 +218,10 @@ class Classifier(PyComponent):
class SegWithCost:
def __init__(self):
self.segmenter0 = ocrolseg.SegmentLineByGCCS()
self.segmenter1 = ocrolseg.DpSegmenter()
# self.segmenter0 = ocrolseg.SegmentLineByGCCS()
# self.segmenter1 = ocrolseg.DpSegmenter()
self.segmenter0 = lineseg.CCSSegmentLine()
self.segmenter1 = lineseg.DPSegmentLine()
def segment(self,image):
seg0 = self.segmenter0.charseg(image)
assert amax(seg0)<32000
@@ -263,8 +266,15 @@ class CmodelLineRecognizer:
self.combined_cost = 0.0 # extra cost for combining connected components
self.split_cost = 0.0 # extra cost for splitting connected components
self.maxrange = 4
self.segmenter = ocrolseg.DpSegmenter()
self.segmenter0 = ocrolseg.SegmentLineByGCCS()
self.latin_cleaner = 1
self.min_xheight = 10
self.max_xheight = 40
self.check_white_on_black = 1
self.noise_threshold = 8
#self.segmenter = ocrolseg.DpSegmenter()
#self.segmenter0 = ocrolseg.SegmentLineByGCCS()
self.segmenter = lineseg.DPSegmentLine()
self.segmenter0 = lineseg.CCSSegmentLine()
common.set_params(self,kw)
if type(self.whitespace)==str:
self.whitespace = common.load_component(common.ocropus_find_file(self.whitespace))
@@ -286,16 +296,37 @@ class CmodelLineRecognizer:
if image.shape[1]>10000:
raise common.RecognitionError("line image too wide???",image=image)
# FIXME for some reason, something down below
# depends on this being a bytearray image, so
# we're normalizing it here to that type
image = array(image*255.0/amax(image),'B')
# convert to floating point image
image = image*1.0/amax(image)
if self.check_white_on_black:
if mean(image)<0.5*amax(image):
raise common.RecognitionError("image may not be white on black text (maybe invert?)",image=image)
# make sure the xheight is reasonable
xheight,_ = common.estimate_xheight(1-image)
self.xheight = xheight
if xheight<self.min_xheight:
raise common.RecognitionError("xheight %f too small (maybe rescale?)"%xheight,image=image)
if xheight>self.max_xheight:
raise common.RecognitionError("xheight %f too large (maybe rescale?)"%xheight,image=image)
# clean up connected components around the edges
if self.latin_cleaner:
image = 1-image
image = common.latin_filter(image)
image = common.remove_noise(image,self.noise_threshold)
image = 1-image
# keep a copy of the cleaned up image
self.image = image.copy()
# compute the raw segmentation
rseg = self.segmenter.charseg(image)
# if self.display:
# show_segmentation(rseg) # FIXME
rseg = renumber_labels(rseg,1) # FIXME
self.rseg = rseg
if amax(rseg)<self.minsegs:
raise common.RecognitionError("not enough segments in raw segmentation",rseg=rseg)
# self.grouper = grouper.Grouper()
@@ -460,7 +491,6 @@ class CmodelLineRecognizer:
if self.display:
title("waiting")
ginput(1,10000)
self.rseg = rseg
def bestpath(self):
"""Return the bestpath through the recognition lattice, as a string.
@@ -13,6 +13,7 @@ from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,die
from ocrolib.ligatures import lig
from ocrolib import lineseg