gather

package
v1.15.1 Latest
Warning

This package is not in the latest version of its module.

Published: Mar 25, 2026 License: Apache-2.0 Imports: 8 Imported by: 0

Documentation

Overview

Package gather provides the Gather layer for the Zerfoo ML framework, including embedding-table lookup.

Stability: stable


Constants

This section is empty.

Variables

This section is empty.

Functions

func BuildGather

func BuildGather[T tensor.Numeric](
	engine compute.Engine[T],
	_ numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attrs map[string]interface{},
) (graph.Node[T], error)

BuildGather constructs a Gather layer. For embedding-style nodes whose name maps to a known weight parameter, the weights are embedded in the layer. For "gather from shape" nodes whose indices are constant, the indices are embedded in the layer. All other Gather nodes behave as a general ONNX Gather, indexing along axis 0.
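The general (axis-0) case can be illustrated without the framework. The following is a minimal plain-Go sketch, assuming a row-major float32 tensor represented as a flat slice plus a shape; `gatherAxis0` is a hypothetical helper, not part of the zerfoo API. Each index selects one full slice along the first dimension.

```go
package main

import "fmt"

// gatherAxis0 is a hypothetical helper (not the zerfoo API). It follows
// ONNX Gather semantics for axis 0 on a row-major tensor: data has shape
// dataShape, and each index copies the whole slice data[idx, ...].
func gatherAxis0(data []float32, dataShape []int, indices []int) ([]float32, error) {
	if len(dataShape) == 0 {
		return nil, fmt.Errorf("gather: data must have rank >= 1")
	}
	rows := dataShape[0]
	// rowSize is the number of elements in one axis-0 slice.
	rowSize := 1
	for _, d := range dataShape[1:] {
		rowSize *= d
	}
	out := make([]float32, 0, len(indices)*rowSize)
	for _, idx := range indices {
		if idx < 0 || idx >= rows {
			return nil, fmt.Errorf("gather: index %d out of range [0, %d)", idx, rows)
		}
		start := idx * rowSize
		out = append(out, data[start:start+rowSize]...)
	}
	return out, nil
}
```

For data of shape [3, 2] holding 1..6, indices [2, 0] yield the rows [5, 6] and [1, 2].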

Types

type Gather

type Gather[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Gather is a layer that gathers slices from a tensor along axis 0.

func New

func New[T tensor.Numeric](engine compute.Engine[T]) *Gather[T]

New creates a new Gather layer.

func NewWithIndices added in v0.2.1

func NewWithIndices[T tensor.Numeric](engine compute.Engine[T], indices *tensor.TensorNumeric[int]) *Gather[T]

NewWithIndices creates a new Gather layer with embedded constant indices. At forward time, input[0] is the data tensor; indices come from the layer.
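A typical "gather from shape" pattern in exported ONNX graphs applies Gather to the 1-D output of a Shape node to pick out individual dimensions. The sketch below is a hypothetical plain-Go model of a layer whose indices are fixed at construction, as with NewWithIndices; `constGather` and its `forward` method are illustrative names, not the zerfoo API, and out-of-range indices simply panic here.

```go
package main

// constGather is a hypothetical model of a Gather layer built with
// embedded constant indices. Only the data tensor is supplied at
// forward time; the indices were captured when the layer was created.
type constGather[T any] struct {
	indices []int // embedded constant indices
}

// forward returns data[indices[i]] for each embedded index, e.g. selecting
// the batch and hidden dimensions from a 1-D shape vector.
func (g constGather[T]) forward(data []T) []T {
	out := make([]T, len(g.indices))
	for i, idx := range g.indices {
		out[i] = data[idx]
	}
	return out
}
```

With indices [0, 2] and a shape vector [8, 128, 768], the result is [8, 768].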

func NewWithWeights

func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]

NewWithWeights creates a new Gather layer with embedded weights.
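With embedded weights the layer acts as an embedding table. The documentation does not say which forward input carries the indices, so this plain-Go sketch simply assumes the caller passes token IDs and the table is a row-major [vocab, dim] matrix captured at construction; `embedding` and `lookup` are hypothetical names, not the zerfoo API.

```go
package main

import "fmt"

// embedding is a hypothetical stand-in for a Gather layer with embedded
// weights: a row-major [vocab, dim] table captured at construction.
type embedding struct {
	weights []float32
	vocab   int
	dim     int
}

// lookup returns one dim-sized row per token ID, concatenated, so the
// result has shape [len(tokenIDs), dim].
func (e embedding) lookup(tokenIDs []int) ([]float32, error) {
	out := make([]float32, 0, len(tokenIDs)*e.dim)
	for _, id := range tokenIDs {
		if id < 0 || id >= e.vocab {
			return nil, fmt.Errorf("token id %d out of range [0, %d)", id, e.vocab)
		}
		out = append(out, e.weights[id*e.dim:(id+1)*e.dim]...)
	}
	return out, nil
}
```

Repeated token IDs simply copy the same row again; the table itself is never modified.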

func (*Gather[T]) Attributes

func (g *Gather[T]) Attributes() map[string]interface{}

Attributes returns nil for the Gather layer.

func (*Gather[T]) Backward

func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients for the Gather layer.
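The documentation does not state which gradients Backward returns. For reference, the standard gradient of an axis-0 gather with respect to its data tensor is a scatter-add: each output-gradient row is added back into the data row it was gathered from. The following is a hypothetical plain-Go sketch of that rule (`gatherBackward` is not the zerfoo API); integer indices receive no gradient.

```go
package main

// gatherBackward is a hypothetical sketch of the data gradient of an
// axis-0 gather. Output row i came from data row indices[i], so its
// gradient is accumulated into that row. Rows gathered several times
// receive several contributions; rows never gathered stay zero.
func gatherBackward(outGrad []float32, indices []int, rows, rowSize int) []float32 {
	grad := make([]float32, rows*rowSize)
	for i, idx := range indices {
		src := outGrad[i*rowSize : (i+1)*rowSize]
		dst := grad[idx*rowSize : (idx+1)*rowSize]
		for j := range src {
			dst[j] += src[j]
		}
	}
	return grad
}
```

Accumulating with `+=` rather than assigning is what makes duplicate indices correct.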

func (*Gather[T]) EmbeddedFrozen added in v0.2.1

func (g *Gather[T]) EmbeddedFrozen() []*tensor.TensorNumeric[T]

EmbeddedFrozen returns the embedded frozen tensors (weights) that should be registered as frozen slots during compilation. Returns nil if no weights are embedded. Implements graph.EmbeddedFrozenProvider.
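Optional capabilities like this are usually discovered in Go with a type assertion against a small interface. The sketch below assumes a compiler pass that walks nodes and collects frozen tensors; `frozenProvider` is a hypothetical local mirror of graph.EmbeddedFrozenProvider (whose exact definition is not shown here), and plain float32 slices stand in for tensors.

```go
package main

// frozenProvider is a hypothetical local mirror of an interface such as
// graph.EmbeddedFrozenProvider; the real definition may differ.
type frozenProvider interface {
	EmbeddedFrozen() [][]float32
}

// gatherNode models a Gather layer that may hold embedded weights.
type gatherNode struct{ weights []float32 }

// EmbeddedFrozen returns the embedded weights, or nil when none are
// embedded, mirroring the documented behaviour.
func (g gatherNode) EmbeddedFrozen() [][]float32 {
	if g.weights == nil {
		return nil
	}
	return [][]float32{g.weights}
}

// plainNode does not implement frozenProvider.
type plainNode struct{}

// collectFrozen shows how a compilation pass might register frozen slots:
// only nodes that implement the interface contribute tensors.
func collectFrozen(nodes []any) [][]float32 {
	var frozen [][]float32
	for _, n := range nodes {
		if p, ok := n.(frozenProvider); ok {
			frozen = append(frozen, p.EmbeddedFrozen()...)
		}
	}
	return frozen
}
```

Returning nil for a layer without weights lets the caller append unconditionally without registering empty slots.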

func (*Gather[T]) Forward

func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the gather operation.

func (*Gather[T]) HasEmbeddedWeights

func (g *Gather[T]) HasEmbeddedWeights() bool

HasEmbeddedWeights returns true if this Gather layer has embedded weights.

func (*Gather[T]) OpType

func (g *Gather[T]) OpType() string

OpType returns the operation type of the Gather layer.

func (*Gather[T]) OutputShape

func (g *Gather[T]) OutputShape() []int

OutputShape returns the output shape of the Gather layer.
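OutputShape takes no arguments, so it presumably reports a shape recorded earlier (for example during Forward); that is an assumption. For an axis-0 ONNX Gather, the output shape is the indices' shape followed by every data dimension except the first. A hypothetical plain-Go helper (`gatherOutputShape`, not the zerfoo API) makes the rule concrete:

```go
package main

// gatherOutputShape is a hypothetical helper computing the output shape of
// an axis-0 ONNX Gather: indicesShape followed by dataShape[1:].
// dataShape must have rank >= 1. The output rank is
// len(indicesShape) + len(dataShape) - 1.
func gatherOutputShape(dataShape, indicesShape []int) []int {
	out := make([]int, 0, len(indicesShape)+len(dataShape))
	out = append(out, indicesShape...)
	return append(out, dataShape[1:]...)
}
```

For an embedding table of shape [32000, 4096] and indices of shape [2, 7], the output shape is [2, 7, 4096]; a scalar index into a 1-D shape vector yields a scalar.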

func (*Gather[T]) Parameters

func (g *Gather[T]) Parameters() []*graph.Parameter[T]

Parameters returns no trainable parameters for the Gather layer.
