Unverified Commit ed5b5451 authored by Konstantin Baierer's avatar Konstantin Baierer Committed by GitHub
Browse files

Merge pull request #265 from amitdo/einsum

Replace native code with regular functions
parents d6a65700 4d20ae65
......@@ -35,7 +35,7 @@ from pylab import (clf, cm, figure, ginput, imshow, newaxis, rand, subplot,
from collections import defaultdict
from ocrolib.exceptions import RecognitionError
from ocrolib.edist import levenshtein
import nutils
import utils
import unicodedata
from scipy.ndimage import measurements,filters
......@@ -471,13 +471,13 @@ def backward_py(n,N,ni,ns,na,deltas,
sourceerr[t] += dot(gferr[t],WGF)
sourceerr[t] += dot(goerr[t],WGO)
sourceerr[t] += dot(cierr[t],WCI)
DWIP = nutils.sumprod(gierr[1:n],state[:n-1],out=DWIP)
DWFP = nutils.sumprod(gferr[1:n],state[:n-1],out=DWFP)
DWOP = nutils.sumprod(goerr[:n],state[:n],out=DWOP)
DWGI = nutils.sumouter(gierr[:n],source[:n],out=DWGI)
DWGF = nutils.sumouter(gferr[1:n],source[1:n],out=DWGF)
DWGO = nutils.sumouter(goerr[:n],source[:n],out=DWGO)
DWCI = nutils.sumouter(cierr[:n],source[:n],out=DWCI)
DWIP = utils.sumprod(gierr[1:n],state[:n-1],out=DWIP)
DWFP = utils.sumprod(gferr[1:n],state[:n-1],out=DWFP)
DWOP = utils.sumprod(goerr[:n],state[:n],out=DWOP)
DWGI = utils.sumouter(gierr[:n],source[:n],out=DWGI)
DWGF = utils.sumouter(gferr[1:n],source[1:n],out=DWGF)
DWGO = utils.sumouter(goerr[:n],source[:n],out=DWGO)
DWCI = utils.sumouter(cierr[:n],source[:n],out=DWCI)
class LSTM(Network):
"""A standard LSTM network. This is a direct implementation of all the forward
import numpy as np
def sumouter(u,v,out=None):
if out is None:
m = u.shape[1]
n = v.shape[1]
out = np.zeros((m,n))
return np.einsum('ki,kj->ij',u,v,out=out)
def sumprod(u,v,out=None):
if out is None:
n = u.shape[1]
out = np.zeros(n)
return np.einsum('ki,ki->i',u,v,out=out)
def test():
from pylab import randn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment