Commit f33831a1 authored by Christoph Wick's avatar Christoph Wick
Browse files

Implementation of resizing codec

parent 3015bf43
......@@ -27,7 +27,7 @@
from __future__ import print_function
import common as ocrolib
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, dot,
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, delete, dot,
exp, isnan, log, maximum, mean, nan, ones, outer, roll, sum,
tanh, tile, vstack, zeros)
from pylab import (clf, cm, figure, ginput, imshow, newaxis, rand, subplot,
......@@ -288,6 +288,19 @@ class Softmax(Network):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = zeros((No,Nh+1))
self.initial_range = initial_range
def resizeOutput(self,No, deleted_positions):
"""resize all matrices to the new codec created by a given charset"""
# delete rows for chars that are not necessary
W2_temp = delete(self.W2, deleted_positions, axis=0)
# enlarge output and weights for extra chars
W2 = randu(No, self.Nh + 1) * initial_range
# use the trained weights (if --load was used)
W2[: len(W2_temp)] = W2_temp
self.W2 = W2
self.DW2 = zeros((No, self.Nh + 1))
self.No = No
self.deltas = None
def ninputs(self):
return self.Nh
def noutputs(self):
......@@ -578,6 +591,8 @@ class LSTM(Network):
self.DWGI,self.DWGF,self.DWGO,self.DWCI,
self.DWIP,self.DWFP,self.DWOP)
return [s[1:1+ni] for s in self.sourceerr[:n]]
def resizeOutput(self, No, deleted_positions):
pass
################################################################
# combination classifiers
......@@ -621,6 +636,9 @@ class Stacked(Network):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Stacked%d/%s"%(i,n)
def resizeOutput(self, nout, deleted_positions):
self.nets[-1].resizeOutput(nout, deleted_positions)
self.deltas = None
class Reversed(Network):
"""Run a network on the time-reversed input."""
......@@ -645,6 +663,8 @@ class Reversed(Network):
def weights(self):
for w,dw,n in self.net.weights():
yield w,dw,"Reversed/%s"%n
def resizeOutput(self, no, deleted_positions):
self.net.resizeOutput(no, deleted_positions)
class Parallel(Network):
"""Run multiple networks in parallel on the same input."""
......@@ -679,6 +699,9 @@ class Parallel(Network):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Parallel%d/%s"%(i,n)
def resizeOutput(self, no, deleted_positions):
for net in self.nets:
net.resizeOutput(no, deleted_positions)
def MLP1(Ni,Ns,No):
"""An MLP implementation by stacking two `Logreg` networks on top
......@@ -839,7 +862,7 @@ def ctc_align_targets(outputs,targets,threshold=100.0,verbose=0,debug=0,lo=1e-5)
return aligned
def normalize_nfkc(s):
return unicodedata.normalize('NFKC',s)
return unicodedata.normalize('NFC',s)
def add_training_info(network):
return network
......@@ -933,6 +956,17 @@ class SeqRecognizer:
"Predict output as a string. This uses codec and normalizer."
cs = self.predictSequence(xs)
return self.l2s(cs)
def resizeCodec(self, codec):
"""create a codec that exactly fits to ground truth/given codec as parameter"""
print("# creating a codec thas fits to the given charset")
# add all unknown and new chars to the codec
self.codec.extend(codec)
# search for chars that should not be in the codec anymore
deleted_positions = self.codec.shrink(codec)
# let the output fit to the new defined codec
self.lstm.resizeOutput(self.codec.size(), deleted_positions)
self.No = self.codec.size()
return self.codec
class Codec:
"""Translate between integer codes and characters."""
......@@ -957,6 +991,33 @@ class Codec:
"Decode a code sequence into a string."
s = [self.code2char.get(c,"~") for c in l]
return s
def extend(self, codec):
charset = self.code2char.values()
size = self.size()
counter = 0
for c in codec.code2char.values():
if not c in charset: # append chars that doesn't appear in the codec
self.code2char[size] = c
self.char2code[c] = size
size += 1
counter += 1
print("#", counter, " extra chars added")
def shrink(self, codec):
deleted_positions = []
positions = []
for number, char in self.code2char.iteritems():
if not char in codec.char2code and char != "~":
deleted_positions.append(number)
else:
positions.append(number)
charset = [self.code2char[c] for c in sorted(positions)]
self.code2char = {}
self.char2code = {}
for code, char in enumerate(charset):
self.code2char[code] = char
self.char2code[char] = code
print("#", len(deleted_positions), " unnecessary chars deleted")
return deleted_positions
ascii_labels = [""," ","~"] + [unichr(x) for x in range(33,126)]
......
......@@ -168,7 +168,7 @@ def save_lstm(fname,network):
for x in network.walk(): x.postLoad()
def load_lstm(fname):
def load_lstm(fname, codec):
if args.clstm:
network = lstm.SeqRecognizer(args.height,args.hiddensize,
codec=codec,
......@@ -178,17 +178,27 @@ def load_lstm(fname):
mylstm.init(network.No,args.hiddensize,network.Ni)
mylstm.load(fname)
network.lstm = clstm.CNetwork(mylstm)
return network
else:
network = ocrolib.load_object(last_save)
network.upgrade()
for x in network.walk(): x.postLoad()
return network
# if a model was loaded we must change the local codec in any case
# either resize the codec of the network if a codec is given
# or use the loaded codec directly
if args.codec != []:
# resize the network codec (including the network weights)
codec = network.resizeCodec(codec)
else:
# the local codec is simply the local codec
codec = network.codec
return network, codec
if args.load:
print("# loading", args.load)
last_save = args.load
network = load_lstm(args.load)
network, codec = load_lstm(args.load, codec)
else:
last_save = None
network = lstm.SeqRecognizer(args.height,args.hiddensize,
......@@ -296,8 +306,7 @@ for trial in range(start,args.ntrain):
except FloatingPointError as e:
print("# oops, got FloatingPointError", e)
traceback.print_exc()
network = load_lstm(last_save)
continue
network, codec = load_lstm(last_save, codec)
except lstm.RangeError as e:
continue
pred = "".join(codec.decode(pcs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment