lstm

package
v0.21.0
Warning

This package is not in the latest version of its module.

Published: Jul 1, 2025 License: Apache-2.0 Imports: 6 Imported by: 2

Documentation

Overview

Package lstm provides a minimal "Long Short-Term Memory RNN" (LSTM) [1] implementation.

An LSTM is a type of recurrent neural network that addresses the vanishing gradient problem of vanilla RNNs with an additional cell state, regulated by input, forget and output gates. Intuitively, the cell state is updated additively, scaled by the forget gate activations, which allows gradients to flow through many steps without vanishing as quickly.
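
To make the gating concrete, the following plain-Go sketch computes forward LSTM steps for a single hidden unit with scalar inputs, using the sigmoid/tanh formulation of the ONNX specification [3] without peepholes. It is an illustration, not this package's code: the weights struct and the step function are hypothetical, and the package itself operates on [batchSize, hiddenSize] tensors.

```go
package main

import (
	"fmt"
	"math"
)

// weights holds hypothetical per-gate parameters for a single hidden unit:
// w* multiply the input x_t, r* multiply the previous hidden state h_{t-1},
// and b* are biases.
type weights struct {
	wi, wf, wo, wc float64
	ri, rf, ro, rc float64
	bi, bf, bo, bc float64
}

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// step computes (h_t, c_t) from the input x and the previous states (h, c).
func step(p weights, x, h, c float64) (float64, float64) {
	i := sigmoid(p.wi*x + p.ri*h + p.bi)   // input gate
	f := sigmoid(p.wf*x + p.rf*h + p.bf)   // forget gate
	o := sigmoid(p.wo*x + p.ro*h + p.bo)   // output gate
	g := math.Tanh(p.wc*x + p.rc*h + p.bc) // candidate cell update
	// The cell state is updated additively: the old state is scaled by the
	// forget gate instead of being squashed through a nonlinearity.
	cNew := f*c + i*g
	hNew := o * math.Tanh(cNew)
	return hNew, cNew
}

func main() {
	var p weights // all-zero weights: every gate outputs sigmoid(0) = 0.5
	h, c := 0.0, 1.0
	for t, x := range []float64{1, 2, 3} {
		h, c = step(p, x, h, c)
		fmt.Printf("t=%d h=%.4f c=%.4f\n", t, h, c)
	}
}
```

With all-zero weights every gate outputs 0.5 and the candidate update is 0, so the cell state simply halves at every step (1, 0.5, 0.25, 0.125): the forget gate alone decides how quickly the memory decays, and the direct gradient path through c is scaled by f at each step rather than by a squashing nonlinearity.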

Since GoMLX doesn't implement loops, the LSTM is unrolled: each step is instantiated as its own set of graph nodes, so the size of the graph is O(N) in the sequence length N.
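
The effect of unrolling can be illustrated with a toy graph builder in plain Go (the node and graph types and the unroll function are hypothetical, not GoMLX API): the loop runs while the graph is being built, so every time step adds its own nodes and the node count grows linearly with the sequence length.

```go
package main

import "fmt"

// node is a toy stand-in for a computation-graph node.
type node struct {
	op     string
	inputs []*node
}

// graph records every node created, like a graph builder would.
type graph struct{ nodes []*node }

func (g *graph) add(op string, inputs ...*node) *node {
	n := &node{op: op, inputs: inputs}
	g.nodes = append(g.nodes, n)
	return n
}

// unroll builds the recurrence h_t = cell(x_t, h_{t-1}) over seqLen steps.
// Because the loop runs at graph-construction time, each step gets its own
// copy of the cell's nodes.
func unroll(seqLen int) *graph {
	g := &graph{}
	h := g.add("initial_state")
	for t := 0; t < seqLen; t++ {
		x := g.add(fmt.Sprintf("slice_x[%d]", t))
		h = g.add("cell", x, h) // one node per step here; a real LSTM cell adds many more
	}
	return g
}

func main() {
	for _, n := range []int{1, 10, 100} {
		fmt.Printf("seqLen=%d -> %d nodes\n", n, len(unroll(n).nodes))
	}
}
```

Here each step adds two nodes, giving 1 + 2N nodes in total (21 for N = 10); a real LSTM cell adds many more nodes per step, but the growth is still linear.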

Unless you need an LSTM for educational or historical reasons, consider using transformer or (dilated) convolution layers instead.

It was created to allow the conversion of ONNX models, but it is fully differentiable and can be used to train models.

See the discussion in [2], and the specification of the ONNX LSTM operator, which this package was created to support, in [3].

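[1] Hochreiter & Schmidhuber, 1997, "Long Short-Term Memory": https://www.bioinf.jku.at/publications/older/2604.pdf
[2] Olah, 2015, "Understanding LSTM Networks": https://colah.github.io/posts/2015-08-Understanding-LSTMs/
[3] ONNX LSTM operator specification: https://onnx.ai/onnx/operators/onnx__LSTM.html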

Constants

This section is empty.

Variables

This section is empty.

Functions

func DirectionTypeStrings

func DirectionTypeStrings() []string

DirectionTypeStrings returns a slice with the String values of all the enum constants.

Types

type ActivationFn

type ActivationFn func(x *Node) *Node

ActivationFn defines an activation function used by the LSTM.
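
ActivationFn maps a graph node to a graph node. The scalar analogue below (plain Go; the activationFn type and the applyAll helper are hypothetical) shows the shape of such a function, using the activations that the ONNX specification [3] uses by default: sigmoid for the gates and tanh for the cell candidate and the output. This documentation does not state which defaults the package uses.

```go
package main

import (
	"fmt"
	"math"
)

// activationFn is a scalar analogue of the package's ActivationFn type.
type activationFn func(x float64) float64

var (
	sigmoid activationFn = func(x float64) float64 { return 1 / (1 + math.Exp(-x)) }
	tanh    activationFn = math.Tanh
)

// applyAll applies fn element-wise, the way a graph activation acts on a tensor.
func applyAll(fn activationFn, xs []float64) []float64 {
	out := make([]float64, len(xs))
	for i, x := range xs {
		out[i] = fn(x)
	}
	return out
}

func main() {
	xs := []float64{-2, 0, 2}
	fmt.Println("sigmoid:", applyAll(sigmoid, xs)) // gates: values in (0, 1)
	fmt.Println("tanh:   ", applyAll(tanh, xs))    // candidate and output: values in (-1, 1)
}
```

Sigmoid keeps the gates in (0, 1), so they act as soft switches, while tanh keeps the cell candidate and the output centered around zero.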

type DirectionType

type DirectionType int

DirectionType defines the direction to run the LSTM.

const (
	DirForward DirectionType = iota
	DirReverse
	DirBidirectional
)

func DirectionTypeString

func DirectionTypeString(s string) (DirectionType, error)

DirectionTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the string is not the name of any enum value.

func DirectionTypeValues

func DirectionTypeValues() []DirectionType

DirectionTypeValues returns all values of the enum.

func (DirectionType) IsADirectionType

func (i DirectionType) IsADirectionType() bool

IsADirectionType returns true if the value is listed in the enum definition, and false otherwise.
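
DirectionType follows the usual generated-enum pattern: String maps a value to its name, DirectionTypeString maps a name back to a value, and IsADirectionType checks that a value is one of the declared constants. The plain-Go sketch below mimics that round trip on a hypothetical local copy of the enum; the names in directionNames are an assumption, since the actual String values are not listed in this documentation.

```go
package main

import "fmt"

// direction is a hypothetical local copy of DirectionType.
type direction int

const (
	dirForward direction = iota
	dirReverse
	dirBidirectional
)

// directionNames is assumed to hold the String values; the real ones are
// generated and not listed in this documentation.
var directionNames = []string{"DirForward", "DirReverse", "DirBidirectional"}

// isADirection plays the role of IsADirectionType.
func (d direction) isADirection() bool { return d >= dirForward && d <= dirBidirectional }

func (d direction) String() string {
	if !d.isADirection() {
		return fmt.Sprintf("direction(%d)", int(d))
	}
	return directionNames[d]
}

// parseDirection plays the role of DirectionTypeString: it returns an error
// for names that are not part of the enum.
func parseDirection(s string) (direction, error) {
	for i, name := range directionNames {
		if name == s {
			return direction(i), nil
		}
	}
	return 0, fmt.Errorf("%q is not a valid direction", s)
}

func main() {
	for _, d := range []direction{dirForward, dirReverse, dirBidirectional} {
		parsed, err := parseDirection(d.String()) // round trip through the name
		fmt.Println(d, parsed == d, err)
	}
	_, err := parseDirection("Sideways")
	fmt.Println(err)
	fmt.Println(direction(7).isADirection(), direction(7))
}
```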

func (DirectionType) MarshalJSON

func (i DirectionType) MarshalJSON() ([]byte, error)

MarshalJSON implements the json.Marshaler interface for DirectionType

func (DirectionType) MarshalText

func (i DirectionType) MarshalText() ([]byte, error)

MarshalText implements the encoding.TextMarshaler interface for DirectionType

func (DirectionType) MarshalYAML

func (i DirectionType) MarshalYAML() (interface{}, error)

MarshalYAML implements a YAML Marshaler for DirectionType

func (DirectionType) String

func (i DirectionType) String() string

func (*DirectionType) UnmarshalJSON

func (i *DirectionType) UnmarshalJSON(data []byte) error

UnmarshalJSON implements the json.Unmarshaler interface for DirectionType

func (*DirectionType) UnmarshalText

func (i *DirectionType) UnmarshalText(text []byte) error

UnmarshalText implements the encoding.TextUnmarshaler interface for DirectionType

func (*DirectionType) UnmarshalYAML

func (i *DirectionType) UnmarshalYAML(unmarshal func(interface{}) error) error

UnmarshalYAML implements a YAML Unmarshaler for DirectionType

func (DirectionType) Values

func (DirectionType) Values() []string

type LSTM

type LSTM struct {
	// contains filtered or unexported fields
}

LSTM holds an LSTM configuration. Create it with New (or NewWithWeights), configure it, and then apply it to x with Done.

func New

func New(ctx *context.Context, x *Node, hiddenSize int) *LSTM

New creates a new LSTM layer to be configured and then applied to x. x should be shaped [batchSize, sequenceSize, featuresSize].

If the sequences in x do not all use the full sequenceSize, see LSTM.Ragged: it is a more compact alternative to padding or masking.

Once finished configuring, call LSTM.Done: it applies the LSTM and returns all the hidden states, as well as the final hidden and cell states.

func NewWithWeights

func NewWithWeights(x *Node, inputsW, recurrentW, biases, peepholeW *Node) *LSTM

NewWithWeights creates a new LSTM layer using the given weights, instead of creating them on the fly.

Args:

  • x: shaped [batchSize, sequenceSize, featuresSize]
  • inputsW: shaped [numDirections, 4, hiddenSize, featuresSize]
  • recurrentW: shaped [numDirections, 4, hiddenSize, hiddenSize]
  • biases: for the gates and the cell update, shaped [numDirections, 8, hiddenSize] (4 input-side and 4 recurrent-side bias vectors, as in [3]).
  • peepholeW: optional (can be nil), shaped [numDirections, 3, hiddenSize].

See details in [3]
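
The shapes expected by NewWithWeights follow directly from the list above. This plain-Go sketch (the weightShapes and numParams helpers are hypothetical) computes them and the resulting parameter count. The 4 in inputsW and recurrentW covers the input, output and forget gates plus the cell update; the 8 in biases is assumed to be split, as in the ONNX B input [3], into 4 input-side and 4 recurrent-side biases.

```go
package main

import "fmt"

// weightShapes returns the shapes NewWithWeights expects, as documented.
// The helper itself is hypothetical; only the shapes come from the docs.
func weightShapes(numDirections, hiddenSize, featuresSize int) map[string][]int {
	return map[string][]int{
		"inputsW":    {numDirections, 4, hiddenSize, featuresSize}, // 3 gates + cell update
		"recurrentW": {numDirections, 4, hiddenSize, hiddenSize},
		"biases":     {numDirections, 8, hiddenSize}, // input-side + recurrent-side biases
		"peepholeW":  {numDirections, 3, hiddenSize}, // one row per gate
	}
}

// numParams counts the scalars in a tensor of the given shape.
func numParams(shape []int) int {
	n := 1
	for _, d := range shape {
		n *= d
	}
	return n
}

func main() {
	shapes := weightShapes(2, 16, 8) // bidirectional, hiddenSize=16, featuresSize=8
	total := 0
	for _, name := range []string{"inputsW", "recurrentW", "biases", "peepholeW"} {
		s := shapes[name]
		total += numParams(s)
		fmt.Println(name, s, numParams(s))
	}
	fmt.Println("total:", total)
}
```

For a bidirectional LSTM with hiddenSize=16 and featuresSize=8 this gives 1024 + 2048 + 256 + 96 = 3424 scalars; the peephole weights are only needed when peepholes are used.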

func (*LSTM) Direction

func (l *LSTM) Direction(dir DirectionType) *LSTM

Direction configures the direction in which to run the LSTM: DirForward, DirReverse, or DirBidirectional (both).

func (*LSTM) Done

func (l *LSTM) Done() (allHiddenStates, lastHiddenState, lastCellState *Node)

Done should be called once the LSTM is configured. It applies the LSTM layer to the sequence in x and returns:

  • allHiddenStates: shaped [sequenceSize, numDirections, batchSize, hiddenSize]
  • lastHiddenState and lastCellState: shaped [numDirections, batchSize, hiddenSize]
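
Note that allHiddenStates is time-major, while x is batch-major. The plain-Go sketch below (the doneShapes and toBatchMajor helpers are hypothetical) spells out the shapes returned by Done, and the shape obtained after transposing allHiddenStates to [batchSize, sequenceSize, numDirections, hiddenSize] and merging the last two axes, a common layout for a following layer.

```go
package main

import "fmt"

// doneShapes returns the shapes of the three values returned by Done,
// as documented. The helper is hypothetical.
func doneShapes(sequenceSize, numDirections, batchSize, hiddenSize int) (all, lastH, lastC []int) {
	all = []int{sequenceSize, numDirections, batchSize, hiddenSize}
	lastH = []int{numDirections, batchSize, hiddenSize}
	lastC = []int{numDirections, batchSize, hiddenSize}
	return
}

// toBatchMajor returns the shape after transposing allHiddenStates to
// [batchSize, sequenceSize, numDirections, hiddenSize] and merging the last
// two axes: [batchSize, sequenceSize, numDirections*hiddenSize].
func toBatchMajor(all []int) []int {
	return []int{all[2], all[0], all[1] * all[3]}
}

func main() {
	// x shaped [batchSize=32, sequenceSize=10, featuresSize=8], hiddenSize=16, bidirectional.
	all, lastH, lastC := doneShapes(10, 2, 32, 16)
	fmt.Println(all, lastH, lastC) // [10 2 32 16] [2 32 16] [2 32 16]
	fmt.Println(toBatchMajor(all)) // [32 10 32]
}
```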

func (*LSTM) InitialStates

func (l *LSTM) InitialStates(initialHiddenState, initialCellState *Node) *LSTM

InitialStates configures the LSTM initial hidden state and cell state (h_0 and c_0 in the literature). If not set, both default to zero.

Both must be shaped [numDirections, batchSize, hiddenSize].

This is useful to chain LSTM calls, for instance when a sequence is processed in chunks: feed here the lastHiddenState and lastCellState returned by LSTM.Done of a previous call.
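
The chaining use case can be checked with a plain-Go, single-unit LSTM with fixed illustrative weights (the run function and its weights are hypothetical): running a sequence in two chunks, passing the final states of the first call as the initial states of the second, gives exactly the same result as one call over the whole sequence.

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// run applies a single-unit forward LSTM (fixed, illustrative weights) to xs,
// starting from (h0, c0), and returns the final hidden and cell states.
func run(xs []float64, h0, c0 float64) (h, c float64) {
	h, c = h0, c0
	for _, x := range xs {
		i := sigmoid(0.5*x + 0.1*h)        // input gate
		f := sigmoid(0.3*x - 0.2*h + 1)    // forget gate
		o := sigmoid(-0.4*x + 0.6*h)       // output gate
		g := math.Tanh(0.8*x + 0.5*h)      // candidate cell update
		c = f*c + i*g
		h = o * math.Tanh(c)
	}
	return h, c
}

func main() {
	seq := []float64{0.2, -1, 0.7, 1.5, -0.3}
	hFull, cFull := run(seq, 0, 0)
	// Process the first 3 steps, then feed the final states into a second call.
	h1, c1 := run(seq[:3], 0, 0)
	h2, c2 := run(seq[3:], h1, c1)
	fmt.Println(hFull == h2, cFull == c2) // true true
}
```

This holds for the forward direction. For DirReverse (and the reverse half of DirBidirectional), each call starts from the end of its own chunk, so chunked processing is generally not equivalent to a single call.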

func (*LSTM) NumDirections

func (l *LSTM) NumDirections() int

NumDirections returns the number of directions (passes) for the selected direction: 1 for DirForward or DirReverse, and 2 for DirBidirectional. See LSTM.Direction to configure the direction.
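
The relation between the direction and numDirections, and the order in which each pass visits the time steps, can be sketched in plain Go. The direction type, numDirections and stepOrder below are hypothetical; numDirections being 2 only for DirBidirectional is inferred from the documented shapes and matches num_directions in the ONNX specification [3].

```go
package main

import "fmt"

type direction int

const (
	forward direction = iota
	reverse
	bidirectional
)

// numDirections mirrors what LSTM.NumDirections reports: one pass for
// forward or reverse, two for bidirectional.
func numDirections(d direction) int {
	if d == bidirectional {
		return 2
	}
	return 1
}

// stepOrder lists, for each pass, the order in which time steps are visited.
func stepOrder(d direction, seqLen int) [][]int {
	fwd := make([]int, seqLen)
	rev := make([]int, seqLen)
	for t := 0; t < seqLen; t++ {
		fwd[t] = t
		rev[t] = seqLen - 1 - t
	}
	switch d {
	case forward:
		return [][]int{fwd}
	case reverse:
		return [][]int{rev}
	default:
		return [][]int{fwd, rev} // one entry per direction
	}
}

func main() {
	for _, d := range []direction{forward, reverse, bidirectional} {
		fmt.Println(numDirections(d), stepOrder(d, 4))
	}
}
```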

func (*LSTM) Ragged

func (l *LSTM) Ragged(sequencesLengths *Node) *LSTM

Ragged indicates that x is "ragged": the sequences do not necessarily use all sequenceSize steps, and their actual lengths are given by sequencesLengths, which must be shaped [batchSize]. It is a more compact alternative to padding.

By default, all sequences are assumed to be dense, that is, to use all sequenceSize steps.
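
One way to think about ragged input in a fixed-size, unrolled graph is masking: every step is computed, but for a batch element whose length has been reached the state is left unchanged, so its final state is the one after its last valid step. The plain-Go sketch below shows this with a toy recurrence standing in for the LSTM cell; raggedLast is hypothetical, and this is not necessarily how the package implements Ragged.

```go
package main

import "fmt"

// raggedLast runs the toy recurrence h_t = 0.5*h_{t-1} + x_t over each row of
// xs. Steps at or beyond lengths[b] are computed but discarded, so the
// returned value for row b is its state after the last valid step.
func raggedLast(xs [][]float64, lengths []int) []float64 {
	last := make([]float64, len(xs))
	for b, row := range xs {
		h := 0.0
		for t, x := range row {
			newH := 0.5*h + x // toy stand-in for one LSTM cell step
			if t < lengths[b] {
				h = newH // valid step: accept the update
			} // padded step: keep the previous state
		}
		last[b] = h
	}
	return last
}

func main() {
	xs := [][]float64{
		{1, 2, 3},
		{4, 5, 99}, // 99 is padding: this row has length 2
	}
	fmt.Println(raggedLast(xs, []int{3, 2})) // [4.25 7]
}
```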

func (*LSTM) UsePeephole

func (l *LSTM) UsePeephole(usePeephole bool) *LSTM

UsePeephole configures whether to use a "peephole" to the "cell state" (c_i) when calculating the gates, which otherwise depend only on the input and the hidden state (h_i).

Defaults to false.
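
The plain-Go sketch below shows where the peephole terms enter for a single hidden unit, following the ONNX formulation [3]: the input and forget gates see the previous cell state, and the output gate sees the updated one. The peephole struct and the step function are hypothetical, and the field order does not imply the ordering of the three rows of peepholeW.

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// peephole holds hypothetical per-unit peephole weights for the three gates.
type peephole struct{ i, f, o float64 }

// step runs one LSTM step for a single hidden unit. zi, zf, zo and zg are the
// pre-activations of the input, forget and output gates and of the cell
// candidate, already computed from x_t and h_{t-1}.
func step(zi, zf, zo, zg, cPrev float64, p peephole, usePeephole bool) (h, c float64) {
	if usePeephole {
		zi += p.i * cPrev // input gate also sees the previous cell state
		zf += p.f * cPrev // forget gate also sees the previous cell state
	}
	c = sigmoid(zf)*cPrev + sigmoid(zi)*math.Tanh(zg)
	if usePeephole {
		zo += p.o * c // output gate sees the updated cell state
	}
	h = sigmoid(zo) * math.Tanh(c)
	return h, c
}

func main() {
	p := peephole{i: 1, f: 1, o: 1}
	for _, use := range []bool{false, true} {
		h, c := step(0, 0, 0, 0, 2, p, use)
		fmt.Printf("usePeephole=%v h=%.4f c=%.4f\n", use, h, c)
	}
}
```

With zero pre-activations and a previous cell state of 2, the cell state is 1 without peepholes; with peephole weights of 1 the input and forget gates open further and the cell state stays above 1.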
