Documentation
Overview
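Package hrm implements the Hierarchical Reasoning Model (HRM), a recurrent architecture that couples a high-level module (HModule) with a low-level module (LModule).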
Stability: alpha
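The two modules feed each other: LModule.Forward consumes the H-module state plus the projected input, and HModule.Forward consumes the L-module state. A minimal sketch of that data flow, assuming a simple alternation of one L step followed by one H step per cycle and assuming the H step is just Block(lState). Plain []float64 slices stand in for tensor.TensorNumeric, and toy functions stand in for transformer.Block; block, add, and runCycles are hypothetical names, not part of this package.

```go
package main

import "fmt"

// block is a hypothetical stand-in for a transformer.Block: any function
// that maps a state vector to a new state vector of the same length.
type block func(x []float64) []float64

// add returns the element-wise sum of a and b (lengths must match).
func add(a, b []float64) []float64 {
	out := make([]float64, len(a))
	for i := range a {
		out[i] = a[i] + b[i]
	}
	return out
}

// runCycles alternates the two modules for the given number of cycles.
// Each cycle runs an L step on (hState, projected), as LModule.Forward does,
// then an H step on the new lState, as HModule.Forward does.
func runCycles(lBlock, hBlock block, hState, projected []float64, cycles int) (h, l []float64) {
	h = hState
	for c := 0; c < cycles; c++ {
		l = lBlock(add(h, projected)) // L step: Block(hState + projected)
		h = hBlock(l)                 // H step: assumed to be Block(lState)
	}
	return h, l
}

func main() {
	// Toy block that halves its input, so each step's effect is visible.
	half := func(x []float64) []float64 {
		out := make([]float64, len(x))
		for i, v := range x {
			out[i] = v / 2
		}
		return out
	}
	h, l := runCycles(half, half, []float64{0, 0}, []float64{4, 8}, 2)
	fmt.Println(h, l)
}
```

Because hState changes every cycle, each L step sees a different input even though the projected input is fixed.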
Index
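- type HModule
- func NewHModule[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, ffnDim int, ...) (*HModule[T], error)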
- func (m *HModule[T]) Attributes() map[string]any
- func (m *HModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (m *HModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *HModule[T]) OpType() string
- func (m *HModule[T]) OutputShape() []int
- func (m *HModule[T]) Parameters() []*graph.Parameter[T]
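- type LModule
- func NewLModule[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, ffnDim int, ...) (*LModule[T], error)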
- func (m *LModule[T]) Attributes() map[string]any
- func (m *LModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (m *LModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *LModule[T]) OpType() string
- func (m *LModule[T]) OutputShape() []int
- func (m *LModule[T]) Parameters() []*graph.Parameter[T]
Constants
This section is empty.
Variables
This section is empty.
Functions
This section is empty.
Types
type HModule
type HModule[T tensor.Numeric] struct {
	Block       *transformer.Block[T]
	HiddenState *tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}
HModule represents the high-level recurrent module of the HRM. It implements graph.Node so it can be used in a computation graph.
func NewHModule
func NewHModule[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, ffnDim int,
	attention graph.Node[T],
	opts ...transformer.BlockOption[T],
) (*HModule[T], error)
NewHModule creates a new HModule.
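NewHModule takes its required arguments positionally and then accepts variadic transformer.BlockOption[T] values, which has the shape of Go's functional-options pattern. A hypothetical sketch of that pattern: blockConfig, blockOption, withEpsilon, and newModule are illustrative names only, and the dimension check is an assumption about what might trigger the error return, not a description of this package.

```go
package main

import (
	"errors"
	"fmt"
)

// blockConfig is a hypothetical stand-in for the settings a
// transformer.BlockOption might adjust.
type blockConfig struct {
	modelDim, ffnDim int
	epsilon          float64
}

// blockOption is a hypothetical functional option with the same shape
// as transformer.BlockOption[T]: a function that mutates a config.
type blockOption func(*blockConfig)

// withEpsilon is a hypothetical option that overrides a default value.
func withEpsilon(eps float64) blockOption {
	return func(c *blockConfig) { c.epsilon = eps }
}

// newModule mirrors the shape of NewHModule: required arguments first,
// then options applied in order over the defaults.
func newModule(modelDim, ffnDim int, opts ...blockOption) (*blockConfig, error) {
	if modelDim <= 0 || ffnDim <= 0 {
		return nil, errors.New("modelDim and ffnDim must be positive")
	}
	cfg := &blockConfig{modelDim: modelDim, ffnDim: ffnDim, epsilon: 1e-5}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg, nil
}

func main() {
	cfg, err := newModule(64, 256, withEpsilon(1e-6))
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(cfg.modelDim, cfg.ffnDim, cfg.epsilon)
}
```

The pattern keeps the constructor signature stable: new block settings can be added as new options without breaking existing callers.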
func (*HModule[T]) Attributes
func (m *HModule[T]) Attributes() map[string]any
Attributes returns the attributes of the HModule.
func (*HModule[T]) Backward
func (m *HModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], _ ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients of the HModule.
func (*HModule[T]) Forward
func (m *HModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward performs a single step of the HModule's computation. inputs[0] is the L-module state (lState).
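func (*HModule[T]) OpType
func (m *HModule[T]) OpType() string
func (*HModule[T]) OutputShape
func (m *HModule[T]) OutputShape() []int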
OutputShape returns the output shape of the HModule.
func (*HModule[T]) Parameters
func (m *HModule[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the HModule.
type LModule
type LModule[T tensor.Numeric] struct {
	Block       *transformer.Block[T]
	HiddenState *tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}
LModule represents the low-level recurrent module of the HRM. It implements graph.Node so it can be used in a computation graph.
func NewLModule
func NewLModule[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, ffnDim int,
	attention graph.Node[T],
	opts ...transformer.BlockOption[T],
) (*LModule[T], error)
NewLModule creates a new LModule.
func (*LModule[T]) Attributes
func (m *LModule[T]) Attributes() map[string]any
Attributes returns the attributes of the LModule.
func (*LModule[T]) Backward
func (m *LModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], _ ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients of the LModule. It returns one gradient per input, in the order [dHState, dProjectedInput]. Because Forward adds the two inputs before passing the sum through Block, both gradients equal the gradient with respect to the block's input.
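The fan-out follows from the chain rule: for out = Block(a + b), the derivative of a + b with respect to a and to b is the identity, so both inputs receive the block's input gradient unchanged. A minimal sketch with a hypothetical element-wise scale block y = w*x standing in for transformer.Block, plus a central finite-difference check on the scalar loss sum(dOut * out). scaleBlock, lBackward, and loss are illustrative names, not part of this package.

```go
package main

import "fmt"

// scaleBlock is a hypothetical toy block y = w*x (element-wise).
type scaleBlock struct{ w float64 }

func (b scaleBlock) forward(x []float64) []float64 {
	out := make([]float64, len(x))
	for i, v := range x {
		out[i] = b.w * v
	}
	return out
}

// backward returns the block's input gradient. For y = w*x, dy/dx = w,
// so the input gradient is w*dOut, which is forward applied to dOut.
func (b scaleBlock) backward(dOut []float64) []float64 {
	return b.forward(dOut)
}

// lBackward returns [dHState, dProjectedInput] for out = block(hState + projected).
// The Add passes the block's input gradient through to both inputs; each
// gradient gets its own copy so callers can modify them independently.
func lBackward(b scaleBlock, dOut []float64) [][]float64 {
	dIn := b.backward(dOut)
	dH := append([]float64(nil), dIn...)
	dP := append([]float64(nil), dIn...)
	return [][]float64{dH, dP}
}

// loss is the scalar sum(dOut * block(hState + projected)); its gradient
// with respect to each input is what lBackward should return.
func loss(b scaleBlock, dOut, hState, projected []float64) float64 {
	s := 0.0
	for i := range hState {
		s += dOut[i] * b.w * (hState[i] + projected[i])
	}
	return s
}

func main() {
	b := scaleBlock{w: 3}
	dOut := []float64{1, -2}
	h, p := []float64{0.5, 1}, []float64{2, -1}
	grads := lBackward(b, dOut)

	// Central finite difference on hState[0] as a numerical cross-check.
	const eps = 1e-6
	fd := (loss(b, dOut, []float64{h[0] + eps, h[1]}, p) -
		loss(b, dOut, []float64{h[0] - eps, h[1]}, p)) / (2 * eps)
	fmt.Println(grads, fd)
}
```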
func (*LModule[T]) Forward
func (m *LModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward performs a single step of the LModule's computation. inputs[0] is the H-module state (hState) and inputs[1] is the projected input.
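A minimal generic sketch of one low-level step on flat vectors: the two inputs are summed (the Add referred to under Backward) and the sum is passed through a block. Numeric is a hypothetical constraint standing in for tensor.Numeric, block stands in for transformer.Block, and lForward is an illustrative name; the sketch only shows how a type parameter like T tensor.Numeric lets one implementation serve several element types.

```go
package main

import (
	"errors"
	"fmt"
)

// Numeric is a hypothetical constraint standing in for tensor.Numeric.
type Numeric interface {
	~float32 | ~float64 | ~int32 | ~int64
}

// lForward sums hState and the projected input element-wise, then applies
// block to the sum. It expects exactly two inputs of equal length.
func lForward[T Numeric](block func([]T) []T, inputs ...[]T) ([]T, error) {
	if len(inputs) != 2 {
		return nil, fmt.Errorf("lForward: expected 2 inputs (hState, projected), got %d", len(inputs))
	}
	hState, projected := inputs[0], inputs[1]
	if len(hState) != len(projected) {
		return nil, errors.New("lForward: hState and projected input lengths differ")
	}
	sum := make([]T, len(hState))
	for i := range hState {
		sum[i] = hState[i] + projected[i]
	}
	return block(sum), nil
}

func main() {
	// T is inferred as float32 from the block's type.
	identity := func(x []float32) []float32 { return x }
	out, err := lForward(identity, []float32{1, 2}, []float32{10, 20})
	fmt.Println(out, err)
}
```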
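func (*LModule[T]) OpType
func (m *LModule[T]) OpType() string
func (*LModule[T]) OutputShape
func (m *LModule[T]) OutputShape() []int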
OutputShape returns the output shape of the LModule.
func (*LModule[T]) Parameters
func (m *LModule[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the LModule.