Documentation
Overview
Package lstm provides a minimal "Long Short-Term Memory RNN" (LSTM) [1] implementation.
An LSTM is a type of recurrent neural network that mitigates the vanishing gradient problem of vanilla RNNs by adding a cell state regulated by input, output and forget gates. Intuitively, the cell state is updated additively and scaled by the forget gate activations. This lets gradients flow across many time steps without vanishing as quickly.
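To make the gating concrete, here is a minimal scalar sketch in plain Go (hiddenSize = featuresSize = 1). It is not GoMLX code. It implements the standard gate equations, as in the ONNX spec [3], and the names gateParams, cellParams and step are hypothetical. Holding h fixed, the derivative of the new cell state with respect to the previous one is just the forget gate activation f. Gradients along that path are therefore not repeatedly squashed by a tanh.

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// gateParams holds the scalar parameters of one gate (or of the candidate
// cell update): w scales the input x, r scales the previous hidden state h,
// and b is the bias.
type gateParams struct{ w, r, b float64 }

// pre returns the gate's pre-activation w*x + r*h + b.
func (g gateParams) pre(x, h float64) float64 { return g.w*x + g.r*h + g.b }

// cellParams groups the parameters of a single-unit, single-feature LSTM cell.
type cellParams struct{ input, forget, output, cell gateParams }

// step computes one LSTM time step and returns the new hidden and cell states.
func step(p cellParams, x, h, c float64) (hNew, cNew float64) {
	i := sigmoid(p.input.pre(x, h))  // input gate
	f := sigmoid(p.forget.pre(x, h)) // forget gate
	o := sigmoid(p.output.pre(x, h)) // output gate
	g := math.Tanh(p.cell.pre(x, h)) // candidate cell update
	cNew = f*c + i*g                 // additive update: d(cNew)/dc is just f
	hNew = o * math.Tanh(cNew)
	return hNew, cNew
}

func main() {
	// Forget gate biased open (f close to 1), input gate biased closed (i close to 0):
	// the cell state survives 100 steps almost unchanged.
	p := cellParams{
		input:  gateParams{b: -10},
		forget: gateParams{b: 10},
		cell:   gateParams{w: 1},
	}
	h, c := 0.0, 1.0
	for t := 0; t < 100; t++ {
		h, c = step(p, 0.5, h, c)
	}
	fmt.Printf("h=%.4f c=%.4f\n", h, c)
}
```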
Since GoMLX doesn't implement loops, the graph size is O(N) in the sequence length N: each step of the LSTM is instantiated as its own set of graph nodes.
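The unrolling can be pictured with a plain-Go sketch. The names step and unroll are hypothetical, and step uses arbitrary illustrative weights. Here the Go loop runs once per time step. When building a loop-free graph, the same kind of loop runs at graph-construction time and emits one copy of the cell per step.

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// step is a scalar LSTM cell with fixed, illustrative weights (the gates share
// one pre-activation to keep the sketch short).
func step(x, h, c float64) (float64, float64) {
	pre := 0.5*x + 0.3*h
	i, f, o := sigmoid(pre), sigmoid(pre+1), sigmoid(pre)
	c = f*c + i*math.Tanh(pre)
	return o * math.Tanh(c), c
}

// unroll applies step once per element of xs, the way a loop-free graph
// instantiates one copy of the cell per time step. It returns one hidden
// state per step, so the work (and graph size) is O(len(xs)).
func unroll(xs []float64) (all []float64, lastH, lastC float64) {
	h, c := 0.0, 0.0
	for _, x := range xs {
		h, c = step(x, h, c)
		all = append(all, h)
	}
	return all, h, c
}

func main() {
	all, h, c := unroll([]float64{1, 0.5, -0.25, 2})
	fmt.Println(len(all), h, c)
}
```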
In any case, unless you need an LSTM for educational or historical reasons, consider using transformer or (dilated) convolution layers instead.
The package was created to allow conversion of ONNX models, but it's fully differentiable and can be used to train models.
See the discussion in [2], and the specification of the ONNX LSTM operator, which this package was created to support, in [3].
[1] Hochreiter & Schmidhuber, 1997: https://www.bioinf.jku.at/publications/older/2604.pdf
[2] Christopher Olah, "Understanding LSTM Networks", 2015: https://colah.github.io/posts/2015-08-Understanding-LSTMs/
[3] ONNX operator specification, LSTM: https://onnx.ai/onnx/operators/onnx__LSTM.html
Index
- func DirectionTypeStrings() []string
- type ActivationFn
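- type DirectionType
- func DirectionTypeString(s string) (DirectionType, error)
- func DirectionTypeValues() []DirectionType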
- func (i DirectionType) IsADirectionType() bool
- func (i DirectionType) MarshalJSON() ([]byte, error)
- func (i DirectionType) MarshalText() ([]byte, error)
- func (i DirectionType) MarshalYAML() (interface{}, error)
- func (i DirectionType) String() string
- func (i *DirectionType) UnmarshalJSON(data []byte) error
- func (i *DirectionType) UnmarshalText(text []byte) error
- func (i *DirectionType) UnmarshalYAML(unmarshal func(interface{}) error) error
- func (DirectionType) Values() []string
- type LSTM
- func NewWithWeights(x *Node, inputsW, recurrentW, biases, peepholeW *Node) *LSTM
- func (l *LSTM) Direction(dir DirectionType) *LSTM
- func (l *LSTM) Done() (allHiddenStates, lastHiddenState, lastCellState *Node)
- func (l *LSTM) InitialStates(initialHiddenState, initialCellState *Node) *LSTM
- func (l *LSTM) NumDirections() int
- func (l *LSTM) Ragged(sequencesLengths *Node) *LSTM
- func (l *LSTM) UsePeephole(usePeephole bool) *LSTM
Constants
This section is empty.
Variables
This section is empty.
Functions
func DirectionTypeStrings
func DirectionTypeStrings() []string
DirectionTypeStrings returns a slice of all String values of the enum.
Types
type ActivationFn
type ActivationFn func(x *Node) *Node
ActivationFn defines an activation function used by the LSTM.
type DirectionType
type DirectionType int
DirectionType defines the direction to run the LSTM.
const (
	DirForward DirectionType = iota
	DirReverse
	DirBidirectional
)
func DirectionTypeString
func DirectionTypeString(s string) (DirectionType, error)
DirectionTypeString retrieves an enum value from the enum constant's string name. It returns an error if the string is not part of the enum.
func DirectionTypeValues
func DirectionTypeValues() []DirectionType
DirectionTypeValues returns all values of the enum.
func (DirectionType) IsADirectionType
func (i DirectionType) IsADirectionType() bool
IsADirectionType returns true if the value is listed in the enum definition, false otherwise.
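These enum helpers follow a common Go pattern. The self-contained sketch below uses a hypothetical lowercase mirror type and is not the package's code. It assumes that String returns the constant names. It shows the round trip between values and names, with errors returned rather than panics.

```go
package main

import "fmt"

// directionType is a hypothetical mirror of DirectionType.
type directionType int

const (
	dirForward directionType = iota
	dirReverse
	dirBidirectional
)

// directionNames is indexed by the enum value (assumed names).
var directionNames = []string{"DirForward", "DirReverse", "DirBidirectional"}

func (d directionType) String() string {
	if d < 0 || int(d) >= len(directionNames) {
		return fmt.Sprintf("directionType(%d)", int(d))
	}
	return directionNames[d]
}

// isValid reports whether d is one of the declared constants.
func (d directionType) isValid() bool { return d >= dirForward && d <= dirBidirectional }

// parseDirection converts a constant name back to its value, returning an
// error for unknown names.
func parseDirection(s string) (directionType, error) {
	for i, name := range directionNames {
		if name == s {
			return directionType(i), nil
		}
	}
	return 0, fmt.Errorf("%q does not belong to directionType values", s)
}

func main() {
	d, err := parseDirection("DirReverse")
	fmt.Println(d, int(d), err, d.isValid(), directionType(7).isValid())
}
```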
func (DirectionType) MarshalJSON
func (i DirectionType) MarshalJSON() ([]byte, error)
MarshalJSON implements the json.Marshaler interface for DirectionType.
func (DirectionType) MarshalText
func (i DirectionType) MarshalText() ([]byte, error)
MarshalText implements the encoding.TextMarshaler interface for DirectionType.
func (DirectionType) MarshalYAML
func (i DirectionType) MarshalYAML() (interface{}, error)
MarshalYAML implements a YAML Marshaler for DirectionType.
func (DirectionType) String
func (i DirectionType) String() string
func (*DirectionType) UnmarshalJSON
func (i *DirectionType) UnmarshalJSON(data []byte) error
UnmarshalJSON implements the json.Unmarshaler interface for DirectionType.
func (*DirectionType) UnmarshalText
func (i *DirectionType) UnmarshalText(text []byte) error
UnmarshalText implements the encoding.TextUnmarshaler interface for DirectionType.
func (*DirectionType) UnmarshalYAML
func (i *DirectionType) UnmarshalYAML(unmarshal func(interface{}) error) error
UnmarshalYAML implements a YAML Unmarshaler for DirectionType.
func (DirectionType) Values
func (DirectionType) Values() []string
type LSTM
type LSTM struct {
	// contains filtered or unexported fields
}
LSTM holds an LSTM configuration. It is created with New (or NewWithWeights) and, once configured, is applied to x by calling Done.
func New
New creates a new LSTM layer to be configured and then applied to x, which should be shaped [batchSize, sequenceSize, featuresSize].
See LSTM.Ragged if not all sequences in x use the full sequenceSize: it is a more compact alternative to padding or masking.
Once finished configuring, call LSTM.Done: it applies the LSTM and returns all hidden states along with the final hidden and cell states.
func NewWithWeights
func NewWithWeights(x *Node, inputsW, recurrentW, biases, peepholeW *Node) *LSTM
NewWithWeights creates a new LSTM layer using the given weights, as opposed to creating them on-the-fly.
Args:
- x: shaped [batchSize, sequenceSize, featuresSize].
- inputsW: shaped [numDirections, 4, hiddenSize, featuresSize].
- recurrentW: shaped [numDirections, 4, hiddenSize, hiddenSize].
- biases: for both the gates and the cell update, shaped [numDirections, 8, hiddenSize].
- peepholeW: optional (can be nil), shaped [numDirections, 3, hiddenSize].
See details in [3].
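The bias row count of 8 is easiest to understand against the ONNX layout [3]. There, B concatenates the biases of the input weights (Wb) and of the recurrent weights (Rb), with the gates in the order input, output, forget, cell. This sketch assumes the package keeps that layout on the axis of size 8, which is not confirmed here. Under that assumption, both biases of a gate end up in the same pre-activation, so only their sum matters. The helper combineBiases is hypothetical and uses plain Go slices.

```go
package main

import "fmt"

// gateNames follows the ONNX LSTM ordering; that this package uses the same
// ordering is an assumption.
var gateNames = [4]string{"input", "output", "forget", "cell"}

// combineBiases folds biases shaped [numDirections][8][hiddenSize] into
// [numDirections][4][hiddenSize], assuming rows 0..3 are Wb and rows 4..7 are Rb.
func combineBiases(biases [][][]float64) ([][][]float64, error) {
	out := make([][][]float64, len(biases))
	for d, dirBiases := range biases {
		if len(dirBiases) != 8 {
			return nil, fmt.Errorf("direction %d: want 8 bias rows, got %d", d, len(dirBiases))
		}
		out[d] = make([][]float64, 4)
		for k := 0; k < 4; k++ {
			wb, rb := dirBiases[k], dirBiases[4+k]
			if len(wb) != len(rb) {
				return nil, fmt.Errorf("direction %d, %s gate: mismatched hiddenSize", d, gateNames[k])
			}
			out[d][k] = make([]float64, len(wb))
			for j := range wb {
				out[d][k][j] = wb[j] + rb[j] // both biases add to the same pre-activation
			}
		}
	}
	return out, nil
}

func main() {
	// One direction, hiddenSize = 2: rows 0..3 are Wb, rows 4..7 are Rb.
	biases := [][][]float64{{
		{1, 1}, {2, 2}, {3, 3}, {4, 4},
		{10, 10}, {20, 20}, {30, 30}, {40, 40},
	}}
	combined, err := combineBiases(biases)
	if err != nil {
		panic(err)
	}
	for k, name := range gateNames {
		fmt.Println(name, combined[0][k])
	}
}
```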
func (*LSTM) Direction
func (l *LSTM) Direction(dir DirectionType) *LSTM
Direction configures in which direction to run the LSTM: DirForward, DirReverse or DirBidirectional (both).
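The following hypothetical plain-Go sketch shows what each direction means for the order of time steps, together with the count LSTM.NumDirections reports. DirReverse visits the sequence from the last step to the first. DirBidirectional runs both passes, each with its own weights along the numDirections axis.

```go
package main

import "fmt"

// directionType is a hypothetical mirror of DirectionType.
type directionType int

const (
	dirForward directionType = iota
	dirReverse
	dirBidirectional
)

// numDirections mirrors what LSTM.NumDirections reports: 2 for a
// bidirectional LSTM, 1 otherwise.
func numDirections(dir directionType) int {
	if dir == dirBidirectional {
		return 2
	}
	return 1
}

// stepOrders returns, for each directional pass, the order in which the time
// steps 0..seqLen-1 are visited.
func stepOrders(dir directionType, seqLen int) [][]int {
	forward := make([]int, seqLen)
	reverse := make([]int, seqLen)
	for t := 0; t < seqLen; t++ {
		forward[t] = t
		reverse[t] = seqLen - 1 - t
	}
	switch dir {
	case dirForward:
		return [][]int{forward}
	case dirReverse:
		return [][]int{reverse}
	default:
		return [][]int{forward, reverse}
	}
}

func main() {
	for _, dir := range []directionType{dirForward, dirReverse, dirBidirectional} {
		fmt.Println(numDirections(dir), stepOrders(dir, 4))
	}
}
```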
func (*LSTM) Done
func (l *LSTM) Done() (allHiddenStates, lastHiddenState, lastCellState *Node)
Done should be called once the LSTM is configured. It applies the LSTM layer to the sequence in x and returns:
- allHiddenStates: shaped [sequenceSize, numDirections, batchSize, hiddenSize].
- lastHiddenState and lastCellState: shaped [numDirections, batchSize, hiddenSize].
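Note that allHiddenStates is sequence-major, while x is batch-major. A next layer expecting [batchSize, sequenceSize, numDirections*hiddenSize] therefore needs a transpose and a reshape. The hypothetical helper below does the equivalent on nested Go slices to make the axis bookkeeping explicit. In a graph, the same would be expressed with transpose and reshape operations.

```go
package main

import "fmt"

// toBatchMajor rearranges allHiddenStates from
// [sequenceSize][numDirections][batchSize][hiddenSize] into
// [batchSize][sequenceSize][numDirections*hiddenSize], concatenating the
// directions' hidden states at each time step.
func toBatchMajor(all [][][][]float64) [][][]float64 {
	if len(all) == 0 || len(all[0]) == 0 {
		return nil
	}
	seqLen, numDirs, batch := len(all), len(all[0]), len(all[0][0])
	out := make([][][]float64, batch)
	for b := range out {
		out[b] = make([][]float64, seqLen)
		for t := 0; t < seqLen; t++ {
			for d := 0; d < numDirs; d++ {
				out[b][t] = append(out[b][t], all[t][d][b]...)
			}
		}
	}
	return out
}

func main() {
	// sequenceSize=2, numDirections=2, batchSize=1, hiddenSize=1.
	all := [][][][]float64{
		{{{1}}, {{2}}}, // t=0: forward, reverse
		{{{3}}, {{4}}}, // t=1: forward, reverse
	}
	fmt.Println(toBatchMajor(all)) // shape [batchSize][sequenceSize][2*hiddenSize]
}
```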
func (*LSTM) InitialStates
func (l *LSTM) InitialStates(initialHiddenState, initialCellState *Node) *LSTM
InitialStates configures the LSTM initial hidden state and cell state (h_0 and c_0 in the literature). If not set, they default to 0.
Both must be shaped [numDirections, batchSize, hiddenSize].
This is useful for continuing from the output of a previous LSTM call, typically one with the same weights: feed here the lastHiddenState and lastCellState returned by LSTM.Done of that call.
func (*LSTM) NumDirections
func (l *LSTM) NumDirections() int
NumDirections returns the number of directions (2 for DirBidirectional, 1 otherwise) based on the selected direction. See LSTM.Direction to configure the direction.
func (*LSTM) Ragged
func (l *LSTM) Ragged(sequencesLengths *Node) *LSTM
Ragged indicates that x is "ragged" (not all sequences use the full sequenceSize), with the length of each sequence given by sequencesLengths, which must be shaped [batchSize]. It is a more compact alternative to padding.
The default is to assume all sequences are dense, that is, used to the end.
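The plain-Go sketch below shows one way to honor per-example lengths, using hypothetical names. Past an example's length, its states are frozen, so the last states are those of its last valid step. The padded outputs are left at zero. That zero-filling is an assumption about the padding convention, not something this documentation states.

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// step is a scalar LSTM cell with fixed, illustrative weights.
func step(x, h, c float64) (float64, float64) {
	i, f, o := sigmoid(0.5*x), sigmoid(0.5*x+1), sigmoid(0.5*x+0.2*h)
	c = f*c + i*math.Tanh(x+0.3*h)
	return o * math.Tanh(c), c
}

// runRagged processes a padded batch xs[batch][time] in which only the first
// lengths[b] steps of example b are valid. Steps past an example's length are
// skipped: its outputs stay 0 and its states stay frozen.
func runRagged(xs [][]float64, lengths []int) (all [][]float64, lastH, lastC []float64) {
	all = make([][]float64, len(xs))
	lastH, lastC = make([]float64, len(xs)), make([]float64, len(xs))
	for b, seq := range xs {
		all[b] = make([]float64, len(seq))
		h, c := 0.0, 0.0
		for t := 0; t < lengths[b] && t < len(seq); t++ {
			h, c = step(seq[t], h, c)
			all[b][t] = h
		}
		lastH[b], lastC[b] = h, c
	}
	return all, lastH, lastC
}

func main() {
	xs := [][]float64{{1, 2, 3}, {1, 2, 99}} // the 99 is padding
	all, lastH, _ := runRagged(xs, []int{3, 2})
	fmt.Println(all, lastH)
}
```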
func (*LSTM) UsePeephole
func (l *LSTM) UsePeephole(usePeephole bool) *LSTM
UsePeephole configures whether to use a "peephole" to the cell state (c_t) when calculating the gate values, which usually depend only on the input and the hidden state (h_t).
Defaults to false.