hrm

package
v1.25.2
Warning

This package is not in the latest version of its module.

Published: Mar 26, 2026 | License: Apache-2.0 | Imports: 8 | Imported by: 0

Documentation

Overview

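Package hrm implements the Hierarchical Reasoning Model (HRM), which pairs a high-level recurrent module (HModule) with a low-level recurrent module (LModule).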

Stability: alpha

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type HModule

type HModule[T tensor.Numeric] struct {
	Block       *transformer.Block[T]
	HiddenState *tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}

HModule represents the high-level recurrent module of the HRM. It implements graph.Node so it can be used in a computation graph.

func NewHModule

func NewHModule[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, ffnDim int,
	attention graph.Node[T],
	opts ...transformer.BlockOption[T],
) (*HModule[T], error)

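NewHModule creates a new HModule backed by a transformer block with model dimension modelDim, feed-forward dimension ffnDim, and the given attention node; opts are passed through as transformer.BlockOption values. It returns an error if the module cannot be constructed.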
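The variadic `opts ...transformer.BlockOption[T]` parameter follows what looks like Go's functional-options pattern. A minimal sketch of that pattern, assuming each option mutates a config struct before the block is built. `blockConfig`, `BlockOption`, `WithEpsilon`, `WithRoPE`, and `newModule` are hypothetical and do not reflect the transformer package's actual options or validation rules.

```go
package main

import (
	"errors"
	"fmt"
)

// blockConfig holds hypothetical settings for a transformer block.
type blockConfig struct {
	epsilon float64
	rope    bool
}

// BlockOption is a hypothetical functional option modifying blockConfig.
type BlockOption func(*blockConfig)

// WithEpsilon sets a normalization epsilon (hypothetical option).
func WithEpsilon(eps float64) BlockOption {
	return func(c *blockConfig) { c.epsilon = eps }
}

// WithRoPE enables rotary position embeddings (hypothetical option).
func WithRoPE() BlockOption {
	return func(c *blockConfig) { c.rope = true }
}

type module struct {
	modelDim, ffnDim int
	cfg              blockConfig
}

// newModule rejects non-positive dimensions, applies defaults,
// then applies each option in order.
func newModule(modelDim, ffnDim int, opts ...BlockOption) (*module, error) {
	if modelDim <= 0 || ffnDim <= 0 {
		return nil, errors.New("modelDim and ffnDim must be positive")
	}
	cfg := blockConfig{epsilon: 1e-5} // default before options
	for _, opt := range opts {
		opt(&cfg)
	}
	return &module{modelDim: modelDim, ffnDim: ffnDim, cfg: cfg}, nil
}

func main() {
	m, err := newModule(64, 256, WithEpsilon(1e-6), WithRoPE())
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(m.modelDim, m.ffnDim, m.cfg.epsilon, m.cfg.rope)
}
```

Later options override earlier ones, so callers can layer overrides on top of defaults without the constructor's signature growing.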

func (*HModule[T]) Attributes

func (m *HModule[T]) Attributes() map[string]any

Attributes returns the attributes of the HModule.

func (*HModule[T]) Backward

func (m *HModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], _ ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients of the HModule.

func (*HModule[T]) Forward

func (m *HModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward performs a single step of the HModule's computation. inputs[0] is the L-module state (lState).

func (*HModule[T]) OpType

func (m *HModule[T]) OpType() string

OpType returns the operation type of the HModule.

func (*HModule[T]) OutputShape

func (m *HModule[T]) OutputShape() []int

OutputShape returns the output shape of the HModule.

func (*HModule[T]) Parameters

func (m *HModule[T]) Parameters() []*graph.Parameter[T]

Parameters returns the parameters of the HModule.

type LModule

type LModule[T tensor.Numeric] struct {
	Block       *transformer.Block[T]
	HiddenState *tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}

LModule represents the low-level recurrent module of the HRM. It implements graph.Node so it can be used in a computation graph.

func NewLModule

func NewLModule[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, ffnDim int,
	attention graph.Node[T],
	opts ...transformer.BlockOption[T],
) (*LModule[T], error)

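NewLModule creates a new LModule backed by a transformer block with model dimension modelDim, feed-forward dimension ffnDim, and the given attention node; opts are passed through as transformer.BlockOption values. It returns an error if the module cannot be constructed.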

func (*LModule[T]) Attributes

func (m *LModule[T]) Attributes() map[string]any

Attributes returns the attributes of the LModule.

func (*LModule[T]) Backward

func (m *LModule[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], _ ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients of the LModule. It returns one gradient per input, in the order [dHState, dProjectedInput]. Because Forward combines the two inputs via Add, and the derivative of a sum with respect to each addend is the identity, both gradients equal the block's input gradient.

func (*LModule[T]) Forward

func (m *LModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward performs a single step of the LModule's computation. inputs[0] is the H-module state (hState), inputs[1] is the projected input.

func (*LModule[T]) OpType

func (m *LModule[T]) OpType() string

OpType returns the operation type of the LModule.

func (*LModule[T]) OutputShape

func (m *LModule[T]) OutputShape() []int

OutputShape returns the output shape of the LModule.

func (*LModule[T]) Parameters

func (m *LModule[T]) Parameters() []*graph.Parameter[T]

Parameters returns the parameters of the LModule.
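Both HModule and LModule expose Parameters, so a training loop can gather the parameters of both modules for one optimizer step. A hypothetical sketch using plain SGD; `Parameter`, `paramSource`, `toyModule`, and `sgdStep` are illustrative stand-ins and do not describe the fields or methods of graph.Parameter[T].

```go
package main

import "fmt"

// Parameter is a hypothetical, simplified stand-in for graph.Parameter[T].
type Parameter struct {
	Name     string
	Value    []float64
	Gradient []float64
}

// paramSource is anything with a Parameters method, like HModule and LModule.
type paramSource interface {
	Parameters() []*Parameter
}

type toyModule struct{ params []*Parameter }

func (m *toyModule) Parameters() []*Parameter { return m.params }

// sgdStep applies value -= lr * gradient to every parameter of every module.
func sgdStep(lr float64, modules ...paramSource) {
	for _, m := range modules {
		for _, p := range m.Parameters() {
			for i := range p.Value {
				p.Value[i] -= lr * p.Gradient[i]
			}
		}
	}
}

func main() {
	h := &toyModule{[]*Parameter{{Name: "h.w", Value: []float64{1, 2}, Gradient: []float64{0.5, -0.5}}}}
	l := &toyModule{[]*Parameter{{Name: "l.w", Value: []float64{3}, Gradient: []float64{1}}}}
	sgdStep(0.1, h, l)
	fmt.Println(h.params[0].Value, l.params[0].Value)
}
```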
