nru

package
v0.7.0

This package is not in the latest version of its module.

Published: May 24, 2021 | License: BSD-2-Clause | Imports: 6 | Imported by: 0

Documentation

Overview

Package nru provides an implementation of the NRU (Non-Saturating Recurrent Units) recurrent network as described in "Towards Non-Saturating Recurrent Units for Modelling Long-Term Dependencies" by Chandar et al., 2019. (https://www.aaai.org/ojs/index.php/AAAI/article/view/4200/4078)

Unfortunately, this implementation is extremely inefficient because the auto-grad (ag) package currently lacks some of the functionality it needs.
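What makes the unit "non-saturating" is that the memory vector is updated additively, by adding weighted write vectors and subtracting weighted erase vectors, rather than through saturating gates. The sketch below illustrates only that update in plain Go, as we understand it from the paper: m_t = m_{t-1} + sum_i alpha_i * w_i - sum_i beta_i * e_i. It is not the package's code. The function updateMemory is hypothetical, the coefficients alpha/beta and the vectors are passed in directly, and Config.K is assumed to be the number of write/erase vectors. In the package these quantities are presumably computed from the hidden state and memory, which is what parameter names such as Whm2alpha and Whm2beta suggest.

```go
package main

import "fmt"

// updateMemory is a hypothetical helper showing the additive memory step:
// m_t = m_{t-1} + sum_i alpha[i]*writeVecs[i] - sum_i beta[i]*eraseVecs[i].
// Here alpha, beta and the vectors are given; the real model computes them.
func updateMemory(mem, alpha, beta []float64, writeVecs, eraseVecs [][]float64) []float64 {
	out := make([]float64, len(mem))
	copy(out, mem) // leave the previous memory untouched
	for i, a := range alpha {
		for j, v := range writeVecs[i] {
			out[j] += a * v // write: add the weighted vector
		}
	}
	for i, b := range beta {
		for j, v := range eraseVecs[i] {
			out[j] -= b * v // erase: subtract the weighted vector
		}
	}
	return out
}

func main() {
	mem := []float64{0, 0, 0, 0}
	alpha := []float64{1.0} // one write vector (K = 1 in this sketch)
	beta := []float64{0.5}  // one erase vector
	write := [][]float64{{1, 0, 1, 0}}
	erase := [][]float64{{0, 0, 1, 0}}
	fmt.Println(updateMemory(mem, alpha, beta, write, erase)) // [1 0 0.5 0]
}
```

Because nothing in this update squashes the memory through a sigmoid or tanh, information can persist across many steps without gradients vanishing at saturated gates. That is the motivation the paper gives for the design.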

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Config

type Config struct {
	InputSize    int
	HiddenSize   int
	MemorySize   int
	K            int
	UseReLU      bool
	UseLayerNorm bool
	States       []*State `spago:"scope:processor"`
}

Config provides configuration settings for an NRU Model.

type Model

type Model struct {
	nn.BaseModel
	Config
	SqrtMemK        int
	Wx              nn.Param `spago:"type:weights"`
	Wh              nn.Param `spago:"type:weights"`
	Wm              nn.Param `spago:"type:weights"`
	B               nn.Param `spago:"type:biases"`
	Whm2alpha       nn.Param `spago:"type:weights"`
	Bhm2alpha       nn.Param `spago:"type:biases"`
	Whm2alphaVec    nn.Param `spago:"type:weights"`
	Bhm2alphaVec    nn.Param `spago:"type:biases"`
	Whm2beta        nn.Param `spago:"type:weights"`
	Bhm2beta        nn.Param `spago:"type:biases"`
	Whm2betaVec     nn.Param `spago:"type:weights"`
	Bhm2betaVec     nn.Param `spago:"type:biases"`
	HiddenLayerNorm *layernorm.Model
}

Model contains the serializable parameters.

func New

func New(config Config) *Model

New returns a new model with parameters initialized to zeros.
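The paper's hidden-state update has the form y_t = f(Wx·x_t + Wh·y_{t-1} + Wm·m_{t-1} + B), which is consistent with the Wx, Wh, Wm and B fields of Model. The sketch below shows what zero-initializing those four parameters would look like with plain slices. The row/column shapes are an assumption inferred from that equation, and hiddenParams and newHiddenParams are hypothetical names. The alpha/beta parameters are omitted because their shapes are not documented here.

```go
package main

import "fmt"

// hiddenParams is a hypothetical plain-Go view of the parameters used in
// y_t = f(Wx·x_t + Wh·y_{t-1} + Wm·m_{t-1} + B).
type hiddenParams struct {
	Wx, Wh, Wm [][]float64
	B          []float64
}

// zeros allocates a rows x cols matrix filled with zeros.
func zeros(rows, cols int) [][]float64 {
	m := make([][]float64, rows)
	for i := range m {
		m[i] = make([]float64, cols)
	}
	return m
}

// newHiddenParams allocates zero-valued parameters; the shapes are an
// assumption derived from the update equation above.
func newHiddenParams(inputSize, hiddenSize, memorySize int) hiddenParams {
	return hiddenParams{
		Wx: zeros(hiddenSize, inputSize),
		Wh: zeros(hiddenSize, hiddenSize),
		Wm: zeros(hiddenSize, memorySize),
		B:  make([]float64, hiddenSize),
	}
}

func main() {
	p := newHiddenParams(3, 4, 5)
	fmt.Println(len(p.Wx), len(p.Wx[0]), len(p.Wm[0]), len(p.B)) // 4 3 5 4
}
```

Since New leaves every parameter at zero, you would typically initialize the weights (for example, randomly) before training. With all-zero weights, every hidden unit receives identical gradients and the units never differentiate.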

func (*Model) Forward

func (m *Model) Forward(xs ...ag.Node) []ag.Node

Forward performs the forward step for each input node and returns the result.
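Forward processes the inputs in order and threads the state from one step to the next. The sketch below reproduces that control flow with plain slices instead of ag.Node. Each step reads the last stored state (or nil if there is none), the new state is appended to the States slice (mirroring Config.States), and the Y of every new state is collected. recurrentModel, state and cumulativeStep are hypothetical, and the toy step stands in for the real NRU cell.

```go
package main

import "fmt"

// state is a plain-slice stand-in for nru.State (illustrative only).
type state struct {
	Y      []float64
	Memory []float64
}

// recurrentModel is a hypothetical stand-in for the state bookkeeping.
type recurrentModel struct {
	States []*state
	step   func(x []float64, prev *state) *state
}

// Forward feeds each input and the previous state (nil on the first step
// unless an initial state was set) into step, and returns every new Y.
func (m *recurrentModel) Forward(xs ...[]float64) [][]float64 {
	ys := make([][]float64, len(xs))
	for i, x := range xs {
		var prev *state
		if n := len(m.States); n > 0 {
			prev = m.States[n-1]
		}
		s := m.step(x, prev)
		m.States = append(m.States, s)
		ys[i] = s.Y
	}
	return ys
}

// cumulativeStep is a toy cell: Y is the running element-wise sum of inputs.
func cumulativeStep(x []float64, prev *state) *state {
	y := make([]float64, len(x))
	for j := range x {
		y[j] = x[j]
		if prev != nil {
			y[j] += prev.Y[j]
		}
	}
	return &state{Y: y}
}

func main() {
	m := &recurrentModel{step: cumulativeStep}
	ys := m.Forward([]float64{1, 2}, []float64{3, 4})
	fmt.Println(ys, len(m.States)) // [[1 2] [4 6]] 2
}
```

Storing every state, rather than only the last one, keeps the full sequence available after the call, at the cost of memory that grows with sequence length.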

func (*Model) LastState

func (m *Model) LastState() *State

LastState returns the last state of the recurrent network. It returns nil if there are no states.

func (*Model) SetInitialState

func (m *Model) SetInitialState(state *State)

SetInitialState sets the initial state of the recurrent network. It panics if one or more states are already present.
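The documented contracts of LastState and SetInitialState can be illustrated on a plain slice of states. In the sketch below, LastState returns nil when the slice is empty, and SetInitialState panics if any state already exists, because the initial state must come before every processed input. stateTracker and state are hypothetical types, and the panic message is illustrative rather than the package's own.

```go
package main

import "fmt"

// state is a plain-slice stand-in for nru.State (illustrative only).
type state struct {
	Y      []float64
	Memory []float64
}

// stateTracker is a hypothetical type reproducing the documented contract
// of LastState and SetInitialState.
type stateTracker struct {
	States []*state
}

// LastState returns the most recent state, or nil when none exist.
func (t *stateTracker) LastState() *state {
	if len(t.States) == 0 {
		return nil
	}
	return t.States[len(t.States)-1]
}

// SetInitialState seeds the sequence and panics if any state is already
// present. The panic message here is illustrative, not the package's.
func (t *stateTracker) SetInitialState(s *state) {
	if len(t.States) > 0 {
		panic("states already present")
	}
	t.States = append(t.States, s)
}

func main() {
	t := &stateTracker{}
	fmt.Println(t.LastState() == nil) // true
	t.SetInitialState(&state{Y: []float64{0}, Memory: []float64{0, 0}})
	fmt.Println(len(t.LastState().Memory)) // 2
	defer func() { fmt.Println("recovered:", recover()) }()
	t.SetInitialState(&state{}) // panics: a state already exists
}
```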

type State

type State struct {
	Y      ag.Node
	Memory ag.Node
}

State represents a state of the NRU recurrent network.
