Commit 4d20ae65 authored by amitdo's avatar amitdo
Browse files

Replace native code with regular functions

Use numpy.einsum() for sumouter and sumprod functions.
parent 61562ce9
......@@ -35,7 +35,7 @@ from pylab import (clf, cm, figure, ginput, imshow, newaxis, rand, subplot,
from collections import defaultdict
from ocrolib.exceptions import RecognitionError
from ocrolib.edist import levenshtein
import nutils
import utils
import unicodedata
from scipy.ndimage import measurements,filters
......@@ -471,13 +471,13 @@ def backward_py(n,N,ni,ns,na,deltas,
sourceerr[t] += dot(gferr[t],WGF)
sourceerr[t] += dot(goerr[t],WGO)
sourceerr[t] += dot(cierr[t],WCI)
DWIP = nutils.sumprod(gierr[1:n],state[:n-1],out=DWIP)
DWFP = nutils.sumprod(gferr[1:n],state[:n-1],out=DWFP)
DWOP = nutils.sumprod(goerr[:n],state[:n],out=DWOP)
DWGI = nutils.sumouter(gierr[:n],source[:n],out=DWGI)
DWGF = nutils.sumouter(gferr[1:n],source[1:n],out=DWGF)
DWGO = nutils.sumouter(goerr[:n],source[:n],out=DWGO)
DWCI = nutils.sumouter(cierr[:n],source[:n],out=DWCI)
DWIP = utils.sumprod(gierr[1:n],state[:n-1],out=DWIP)
DWFP = utils.sumprod(gferr[1:n],state[:n-1],out=DWFP)
DWOP = utils.sumprod(goerr[:n],state[:n],out=DWOP)
DWGI = utils.sumouter(gierr[:n],source[:n],out=DWGI)
DWGF = utils.sumouter(gferr[1:n],source[1:n],out=DWGF)
DWGO = utils.sumouter(goerr[:n],source[:n],out=DWGO)
DWCI = utils.sumouter(cierr[:n],source[:n],out=DWCI)
class LSTM(Network):
"""A standard LSTM network. This is a direct implementation of all the forward
import numpy as np
def sumouter(u,v,out=None):
if out is None:
m = u.shape[1]
n = v.shape[1]
out = np.zeros((m,n))
return np.einsum('ki,kj->ij',u,v,out=out)
def sumprod(u,v,out=None):
if out is None:
n = u.shape[1]
out = np.zeros(n)
return np.einsum('ki,ki->i',u,v,out=out)
def test():
from pylab import randn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment