graph

package
v0.19.2
Published: May 1, 2025 License: Apache-2.0 Imports: 24 Imported by: 4

Documentation

Overview

Package graph is the core package for GoMLX. It is used to create and run computation graphs on XLA/PJRT (using github.com/gomlx/gopjrt) -- a just-in-time compiler that allows for very efficient numerical computations.

It requires PJRT plugins corresponding to accelerators ("cpu", "cuda", "tpu", etc.) to be available (see github.com/gomlx/gopjrt) to compile and execute programs. Gopjrt is distributed with a "cpu" plugin, and it describes how to install a "cuda" plugin for Nvidia graphics cards.

It also includes an autograd system and many useful higher level machine learning tools.

The main elements in the package (or related) are:

  • Backend: manages an XLA/PJRT (through gopjrt) connection: a PJRT plugin and a client. The whole computation building, compilation and execution runs within the scope of a Backend.

  • Graph: created by the Backend, this is used to construct a computation graph that can then be "just-in-time" compiled and executed efficiently. To construct a `Graph` one puts together nodes or "ops" defining the desired sequence of operations.

  • Node: represents the result of an operation ("op" for short). E.g.: Add, Sub, Mul, Sigmoid, Reshape, etc. Each node has a fixed shape that is known at "graph building time" (see discussion below).

  • context.Context: created by the Backend, a higher-level abstraction convenient when building gradient-descent-based machine learning (ML) models (like neural networks). It organizes Variable objects into "scopes", which usually hold the learnable weights for ML. It also allows for loading/saving of these values.

## Error Handling

Graph (and its Nodes) and context.Context methods "throw" errors with `panic`. This avoids having to handle error returns on every function call. Thrown errors carry meaningful messages and the full stack trace, to ease tracking down and fixing issues.

Notice that, unfortunately, there is no way to check at compile time for many of the errors that would be relatively easy for a human to spot without running the program: Go has no way to run arbitrary logic at compile time.
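
If a caller prefers Go-style error values, the panics can be converted back with the standard recover mechanism. A minimal sketch (the wrapper name safeBuild is hypothetical, not part of the package):

// safeBuild converts the panics thrown during graph building into a
// regular Go error value (requires "fmt").
func safeBuild(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("graph building failed: %v", r)
		}
	}()
	fn()
	return
}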

## Delayed Execution

ML frameworks usually require the user to think about the different "times" at which things happen. The same is true for GoMLX, and it helps to keep them in mind upfront, to have the right mental model:

  • **Compile time**: this is during Go compilation. Some amount of type checking is done here, but unfortunately most tensor-shape compatibility cannot be checked statically. Even when it would be obvious to a human, without running the program, that an operation among differently shaped tensors shouldn't be allowed, there is no way in Go to run arbitrary logic at compile time to validate it. So most of the checking is left to "graph building time". Maybe one day someone will write a gomlx_linter that runs before the compiler and catches some of these.

  • **Graph building time**: this is when one is building a computation Graph, using the various ops (Add, Sub, Mul, ReduceSum, etc.). No actual computation happens here, just the building of the Graph (a kind of program) that will be executed later. This happens at "runtime", that is, after the Go program is compiled. Only at graph building time are shapes properly checked, and good errors (with stack traces) reported back. This means development often involves coding, then running the Graph building to verify the shapes are correct and are what one wants -- Graph building is very fast, since no data is actually manipulated. Creating tests that just build the graph is the recommended way to develop. Quick sketching of code can be done on a Jupyter Notebook -- see github.com/janpfeifer/gonb for Jupyter notebook support for the Go language. Once the model is built, it is usually "just-in-time" (JIT) compiled, and can be run (see the sketch after this list).

  • **Computation/Training/Evaluation time**: this happens after the Graph is built and compiled, and all one does is feed values in and get the computation out -- using very fast just-in-time compiled code. Error reports here are terser and harder to debug (they come from the underlying C++ library), but most issues are usually caught at graph building time. In particular, there is a `nanlogger` library that helps identify where `NaN` or `Inf` first appears in the middle of a computation -- handy to debug the math.
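
A minimal sketch of the three stages in code, assuming a working backend (MulScalar is one of this package's scalar helpers, like AddScalar documented below):

// Graph building time: doubleFn only assembles ops; no data flows yet.
doubleFn := func(x *Node) *Node { return MulScalar(x, 2) }

// JIT compilation happens lazily, on the first Call for a given input shape.
exec := NewExec(backends.New(), doubleFn)

// Computation time: feed concrete values in, get tensors out.
fmt.Println(exec.Call([]float32{1, 2, 3})[0].Value()) // => [2 4 6]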

Index

Constants

View Source
const AliasScopeSeparator = "/"

AliasScopeSeparator is the string used to join the individual alias scope parts as well as the alias itself. So if the scope is currently ["a", "b"] and an alias "output" is created, its full alias will be "/a/b/output".

View Source
const InvalidNodeId = NodeId(-1)

InvalidNodeId indicates a node that failed to be created.

View Source
const InvalidParameterHandle = ParameterHandle(-1)

InvalidParameterHandle represents an invalid (or non-existent) parameter.

View Source
const (
	MaxSizeToPrint = 5
)
View Source
const NoInterpolation = int(-1)

NoInterpolation can be used for the outputSizes of the Interpolation call.

Variables

View Source
var DefaultExecMaxCacheSize = 32

DefaultExecMaxCacheSize is the value used to initialize new Exec objects.

View Source
var MinConstValueSizeToKeep = 32

MinConstValueSizeToKeep defines a size below which constant values (see Const, ConstTensor) are kept in the Node/Graph for printing/debugging purposes.

If set to 0, no value is kept.

View Source
var ReduceAndKeepMasked = MaskedReduceAndKeep

ReduceAndKeepMasked is an alias for MaskedReduceAndKeep.

Deprecated: all functions that take mask are prefixed with `Masked...`

View Source
var (
	// RngStateShape is the shapes of the random number generator state, used
	// in all Random* functions.
	// This is dependent on the algorithm, that for now is fixed.
	RngStateShape = shapes.Make(dtypes.Uint64, 3)
)
View Source
var VJPRegistration = map[NodeType]VJP{
	NodeTypeInvalid:              vjpForSingleOutput(noOpVJP),
	NodeTypeConstant:             vjpForSingleOutput(nilVJP),
	NodeTypeParameter:            vjpForSingleOutput(nilVJP),
	NodeTypeConvertDType:         vjpForSingleOutput(convertDTypeVJP),
	NodeTypeWhere:                vjpForSingleOutput(whereVJP),
	NodeTypeNeg:                  vjpForSingleOutput(negVJP),
	NodeTypeAbs:                  vjpForSingleOutput(absVJP),
	NodeTypeExp:                  vjpForSingleOutput(expVJP),
	NodeTypeLog:                  vjpForSingleOutput(logVJP),
	NodeTypeLog1p:                vjpForSingleOutput(log1pVJP),
	NodeTypeTanh:                 vjpForSingleOutput(tanhVJP),
	NodeTypeAdd:                  vjpForSingleOutput(addVJP),
	NodeTypeSub:                  vjpForSingleOutput(subVJP),
	NodeTypeMul:                  vjpForSingleOutput(mulVJP),
	NodeTypeDiv:                  vjpForSingleOutput(divVJP),
	NodeTypePow:                  vjpForSingleOutput(powVJP),
	NodeTypeSqrt:                 vjpForSingleOutput(sqrtVJP),
	NodeTypeErf:                  vjpForSingleOutput(erfVJP),
	NodeTypeBatchNormForTraining: batchNormForTrainingVJP,

	NodeTypeLogicalAnd: vjpForSingleOutput(zeroVJP),
	NodeTypeLogicalOr:  vjpForSingleOutput(zeroVJP),
	NodeTypeLogicalXor: vjpForSingleOutput(zeroVJP),
	NodeTypeLogicalNot: vjpForSingleOutput(zeroVJP),
	NodeTypeBitwiseAnd: vjpForSingleOutput(zeroVJP),
	NodeTypeBitwiseOr:  vjpForSingleOutput(zeroVJP),
	NodeTypeBitwiseXor: vjpForSingleOutput(zeroVJP),
	NodeTypeBitwiseNot: vjpForSingleOutput(zeroVJP),

	NodeTypeReal:    vjpForSingleOutput(realVJP),
	NodeTypeImag:    vjpForSingleOutput(imagVJP),
	NodeTypeConj:    vjpForSingleOutput(conjVJP),
	NodeTypeComplex: vjpForSingleOutput(complexVJP),

	NodeTypeMax:                vjpForSingleOutput(minMaxVJP),
	NodeTypeMin:                vjpForSingleOutput(minMaxVJP),
	NodeTypeReshape:            vjpForSingleOutput(reshapeVJP),
	NodeTypeReduceSum:          vjpForSingleOutput(reduceSumVJP),
	NodeTypeReduceMax:          vjpForSingleOutput(reduceMaxVJP),
	NodeTypeReduceMin:          vjpForSingleOutput(reduceMinVJP),
	NodeTypeLogistic:           vjpForSingleOutput(logisticVJP),
	NodeTypeDot:                vjpForSingleOutput(dotVJP),
	NodeTypeDotGeneral:         vjpForSingleOutput(dotGeneralVJP),
	NodeTypeSlice:              vjpForSingleOutput(sliceVJP),
	NodeTypeGather:             vjpForSingleOutput(gatherVJP),
	NodeTypeScatterSum:         vjpForSingleOutput(scatterSumVJP),
	NodeTypeScatterMax:         vjpForSingleOutput(scatterMaxOrMinVJP),
	NodeTypeScatterMin:         vjpForSingleOutput(scatterMaxOrMinVJP),
	NodeTypeConcatenate:        vjpForSingleOutput(concatenateVJP),
	NodeTypeConvGeneralDilated: vjpForSingleOutput(convGeneralDilatedVJP),
	NodeTypeReduceWindow:       vjpForSingleOutput(reduceWindowVJP),
	NodeTypeTranspose:          vjpForSingleOutput(transposeVJP),
	NodeTypeBroadcastInDim:     vjpForSingleOutput(broadcastInDimVJP),
	NodeTypeFFT:                vjpForSingleOutput(fftVJP),
	NodeTypeDynamicSlice:       vjpForSingleOutput(dynamicSliceVJP),
	NodeTypeDynamicUpdateSlice: vjpForSingleOutput(dynamicUpdateSliceVJP),
}

VJPRegistration maps each node type to its implementation of VJP. If implementing a new op, or for experimentation, one can dynamically change this table.

Notice xla.GetTupleElementNode is specialized inside the main reverse autodiff code, and is not in the table here.

Functions

func AdjustAxisToOperandRank added in v0.11.0

func AdjustAxisToOperandRank(operand *Node, axis int) int

AdjustAxisToOperandRank returns the positive axis relative to the operand's shape, adjusting it in case the given axis is negative.

It panics if the given axis is not within the operand's rank range.
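
For instance, hypothetically assuming operand x has rank 3:

AdjustAxisToOperandRank(x, -1) // == 2, the last axis.
AdjustAxisToOperandRank(x, 1)  // == 1, positive axes are returned unchanged.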

func DefaultNodeLogger added in v0.4.1

func DefaultNodeLogger(g *Graph, messages []string, values []*tensors.Tensor, nodes []NodeId)

DefaultNodeLogger for nodes marked to be logged. It prints the message and the node value for each logged node.

It accepts special prefixes on message name that affects the printing:

  • #full : prints full tensor value (as opposed to abbreviated).

func DonateTensorBuffer added in v0.11.0

func DonateTensorBuffer(t *tensors.Tensor, backend backends.Backend, deviceNum ...backends.DeviceNum) any

DonateTensorBuffer can be used by Graph.Run, Graph.RunWithMap or as input to Exec.Call, and it marks the Tensor to donate its on-device buffer to the execution.

This allows the accelerator (GPU) to reuse the space of the donated buffer, which saves space if the original value is no longer used. Particularly useful when updating some state in a loop.

This doesn't work if the tensor shares the buffer with the device (usually CPU plugins). You can check that with IsShared().

Example:

myState := myExec.Call(DonateTensorBuffer(myState, backend))[0]

It requires the backend and the deviceNum (defaults to 0) of the device buffer to donate.

Notice that after this, t's value in the device becomes invalid.

func ExecOnce added in v0.11.1

func ExecOnce[F ExecGraphFnOneOutput](backend backends.Backend, graphFn F, args ...any) *tensors.Tensor

ExecOnce builds the graph, executes it with the given arguments, and returns the single output.

It's short for a call to NewExec, Exec.Call and Exec.Finalize for functions that return only one output.

See ExecOnceN if you have multiple outputs.
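
A minimal sketch, assuming a backend created as in the Exec examples below:

sum := ExecOnce(backend, func(x *Node) *Node {
	return ReduceAllSum(x)
}, []float32{1, 2, 3})
fmt.Println(sum.Value()) // => 6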

func ExecOnceN added in v0.11.1

func ExecOnceN[F ExecGraphFn](backend backends.Backend, graphFn F, args ...any) []*tensors.Tensor

ExecOnceN builds the graph, executes it with the given arguments, and returns the various outputs.

It's short for a call to NewExec, Exec.Call and Exec.Finalize.

See ExecOnce for a more convenient version if you have only one output.

func NodeTypeStrings added in v0.11.0

func NodeTypeStrings() []string

NodeTypeStrings returns a slice of all String values of the enum.

func RngState added in v0.4.0

func RngState() *tensors.Tensor

RngState creates a random number generator (RNG) state initialized from the nanosecond clock at the time of the graph creation.

Notice it returns a concrete tensor value that can be used to set a variable or constant to be used in a graph.

A typical use case:

rngState := Const(g, RngState())

func RngStateFromSeed added in v0.4.0

func RngStateFromSeed(seed int64) *tensors.Tensor

RngStateFromSeed creates a random number generator (RNG) state based on the static seed.

Notice it returns a concrete tensor value that can be used to set a variable or constant to be used in a graph.

A typical use case:

rngState := Const(g, RngStateFromSeed(42))

Types

type ConvolutionBuilder

type ConvolutionBuilder struct {
	// contains filtered or unexported fields
}

ConvolutionBuilder is a helper to build a convolution computation. Create it with Convolve, set the desired parameters, and when done, call `Done()`.

func Convolve

func Convolve(x, kernel *Node) *ConvolutionBuilder

Convolve prepares a convolution on x with the given kernel for arbitrary number of spatial dimensions (1D, 2D, 3D, etc.).

It returns a ConvolutionBuilder object that can be further configured. Once the configuration is finished, call ConvolutionBuilder.Done and it will return the convolved x. Browse through ConvolutionBuilder to see its capabilities and defaults.

The shape of x should be [batch, <spatial_dimensions...>, input_channels] if configured with ConvolutionBuilder.ChannelsAxis(timage.ChannelsLast), the default. If one sets ConvolutionBuilder.ChannelsAxis(timage.ChannelsFirst), then the shape should be [batch, input_channels, <spatial_dimensions...>] instead.

Note: package timage refers to package github.com/gomlx/gomlx/types/tensor/image.

The shape of kernel should be [<spatial_dimensions...>, input_channels, output_channels] if configured with ConvolutionBuilder.ChannelsAxis(timage.ChannelsLast), the default. If one sets ConvolutionBuilder.ChannelsAxis(timage.ChannelsFirst), the shape should be [input_channels, <spatial_dimensions...>, output_channels] instead.

Notice x and kernel must have the same rank.

We follow the Keras convention of calling the "depth" or "feature" or "channels" dimension "channels". Likewise, we use "kernel" instead of "filters" -- but they mean the same.

Additional features:

  • Group operations: Use ConvolutionBuilder.FeatureGroupCount to split channels or BatchGroupCount to split batches into independent processing groups. When using either feature, the kernel shape changes and back-propagation is not yet supported.
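
A sketch of a typical 2D convolution with this builder (all shapes illustrative):

// x: [batch=32, height=28, width=28, channels=1], ChannelsLast (the default).
// kernel: [height=3, width=3, input_channels=1, output_channels=8].
output := Convolve(x, kernel).
	Strides(1).
	PadSame().
	Done()
// output: [32, 28, 28, 8]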

func (*ConvolutionBuilder) AxesConfig

AxesConfig specifies the exact configuration of the axes on the input (x/input and kernel) and output of the Convolve operation. This is advanced (and may not be supported by every backend), but it's powerful. Consider using `ConvolutionBuilder.ChannelsAxis` instead.

The default is `ChannelsAxis(timage.ChannelsLast)`.

func (*ConvolutionBuilder) BatchGroupCount added in v0.18.0

func (conv *ConvolutionBuilder) BatchGroupCount(groupCount int) *ConvolutionBuilder

BatchGroupCount splits batches into independent processing groups. Used for cross-batch interactions like ShuffleNet's channel shuffle.

When BatchGroupCount != 1, the kernel shape changes: the batch dimension of the input is divided by the group count, creating separate convolution groups where each group processes a subset of the batch.

The output shape will have the same spatial dimensions as a regular convolution but with batch dimension affected by the grouping.

Note: Back-propagation is not yet implemented for this feature.

Reference: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch

func (*ConvolutionBuilder) ChannelsAxis added in v0.3.0

func (conv *ConvolutionBuilder) ChannelsAxis(channelsAxisConfig timage.ChannelsAxisConfig) *ConvolutionBuilder

ChannelsAxis configures the axis for the channels (aka. "depth" or "features") dimension. The default is `timage.ChannelsLast`, meaning the "channels" dimension comes last.

Note: `timage` refers to package `github.com/gomlx/gomlx/types/tensor/image`.

For more fine-control, see AxesConfig.

It returns the modified Config object, so calls can be cascaded.

func (*ConvolutionBuilder) DilationPerDim

func (conv *ConvolutionBuilder) DilationPerDim(dilations ...int) *ConvolutionBuilder

DilationPerDim sets the kernel dilations for each spatial dimension of the convolution. The default is 1 for every dimension.

It specifies the kernel up-sampling rate. In the literature, the same parameter is sometimes called input stride or dilation. The effective kernel size used for the convolution will be `kernel_shape + (kernel_shape - 1) * (dilation - 1)`, obtained by inserting (dilation-1) zeros between consecutive elements of the original filter in the spatial dimension.

One cannot use strides and dilation at the same time.
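
For example, a kernel of spatial size 3 with dilation 2 has an effective size of 3 + (3-1)*(2-1) = 5.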

func (*ConvolutionBuilder) Dilations

func (conv *ConvolutionBuilder) Dilations(dilation int) *ConvolutionBuilder

Dilations sets the dilations of the convolution: the same value is used for every dimension.

The default is 1.

It specifies the kernel up-sampling rate. In the literature, the same parameter is sometimes called input stride or dilation. The effective kernel size used for the convolution will be `kernel_shape + (kernel_shape - 1) * (dilation - 1)`, obtained by inserting (dilation-1) zeros between consecutive elements of the original filter in the spatial dimension.

One cannot use strides and dilation at the same time.

func (*ConvolutionBuilder) Done

func (conv *ConvolutionBuilder) Done() *Node

Done indicates that the convolve operation has finished being configured; it updates the computation graph with the convolution and returns the resulting Node.

func (*ConvolutionBuilder) FeatureGroupCount added in v0.18.0

func (conv *ConvolutionBuilder) FeatureGroupCount(groupCount int) *ConvolutionBuilder

FeatureGroupCount splits input/output channels into independent groups. Equivalent to TensorFlow's "groups" parameter in tf.nn.convNd operations.

When FeatureGroupCount != 1, the kernel shape changes: the input channels dimension of the kernel must equal (input_channels / group_count). This effectively creates separate convolution groups where each group processes a subset of input channels and produces a subset of output channels.

For depthwise convolution, set groups = input_channels (see tf.nn.depthwise_conv2d). The output shape will have the same spatial dimensions as a regular convolution but with channel dimensions affected by the grouping.

Side effects:

  • Kernel shape: The kernel's input channel dimension becomes (input_channels / group_count)
  • Output shape: The output maintains the same spatial dimensions as regular convolution but each group independently maps its input channels to output channels
  • Performance: Can reduce computation cost by limiting connections between channels
  • Memory usage: Reduces the number of parameters in the kernel

Note: Back-propagation is not yet implemented for this feature.

Reference: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#group_by_window

func (*ConvolutionBuilder) InputDilationPerDim

func (conv *ConvolutionBuilder) InputDilationPerDim(dilations ...int) *ConvolutionBuilder

InputDilationPerDim is used when generating the gradient of a convolution with strides. It inserts zeros in the input, making it effectively larger than it actually is. Careful: the gradient of Convolve with input dilation is not implemented yet.

func (*ConvolutionBuilder) NoPadding

func (conv *ConvolutionBuilder) NoPadding() *ConvolutionBuilder

NoPadding removes any padding, so if the kernel spatial dimensions are > 1, the output shape will be reduced at the edges. This is the default.

See also PadSame and PaddingPerDim.

func (*ConvolutionBuilder) PadSame

func (conv *ConvolutionBuilder) PadSame() *ConvolutionBuilder

PadSame adds padding to the edges of x such that the output of the convolution has the same shape as the input (assuming strides=1).

The default is no padding. See also NoPadding and PaddingPerDim.

func (*ConvolutionBuilder) PaddingPerDim

func (conv *ConvolutionBuilder) PaddingPerDim(paddings [][2]int) *ConvolutionBuilder

PaddingPerDim specifies the paddings at the start and at the end of each spatial dimension; that is, one pair ([2]int) per spatial dimension.

If a nil value for paddings is given, this has no effect.

The default is no padding. See also NoPadding and PadSame.

func (*ConvolutionBuilder) StridePerDim

func (conv *ConvolutionBuilder) StridePerDim(strides ...int) *ConvolutionBuilder

StridePerDim sets the strides for each spatial dimension of the convolution. The default is 1 for every dimension.

The stride is how many steps to move after a convolution. A value of 2 will halve the input size, since a convolution will be done at every other position, and so on. It can be defined separately per dimension.

One cannot use strides and dilation at the same time.

func (*ConvolutionBuilder) Strides

func (conv *ConvolutionBuilder) Strides(strides int) *ConvolutionBuilder

Strides sets the strides of the convolution. It sets the same value for every dimension. The default is 1.

The stride is how many steps to move after a convolution. A value of 2 will halve the input size, since a convolution will be done at every other position, and so on. It can be defined separately per dimension.

One cannot use strides and dilation at the same time.

type ConvolveAxesConfig

type ConvolveAxesConfig = backends.ConvolveAxesConfig

ConvolveAxesConfig defines the interpretation of the input/kernel/output tensor axes. There must be the same number of spatial dimensions (axes) for each of the 3 tensors. Input and output have batch and channel axes. Kernel has inputChannel and outputChannel axes.

type Exec

type Exec struct {
	// contains filtered or unexported fields
}

Exec creates and executes computation graphs as needed, based on the input shapes.

It simplifies the process of executing a graph building function with real values. For example, assume you wrote:

func L2Norm(x *Node) *Node {
	return Sqrt(ReduceAllSum(Mul(x, x)))
}

To use it with actual values (tensors.Tensor's), one needs to build the computation graph for the specific shape of x and then execute it. While this is straightforward, it's lots of boilerplate code -- JIT compilation makes things faster, but it imposes some bureaucracy.

With Exec one can do:

var l2NormExec = NewExec(backends.New(), L2Norm)
x0 := []float32{2}
fmt.Printf("L2Norm(%v) = %v\n", x0, l2NormExec.Call(x0)[0].Value())
x1 := []float64{4, 3}
fmt.Printf("L2Norm(%v) = %v\n", x1, l2NormExec.Call(x1)[0].Value())

Notice that both calls to l2NormExec.Call need to create different graphs (for the different shapes of the input), but the graphs are cached, and if the same shapes are used in Call again, the cached compiled graph is reused.

Also, Call outputs a slice with all the outputs, even when there is only one output.

If there are no inputs (for instance for some initialization function), then one needs to take a *Graph as the first parameter of the graph function (graphFn). Example:

iotaMatrixExec := NewExec(backend, func (g *Graph) *Node {
	return IotaFull(g, shapes.Make(dtype.Float32, 3, 3))
})
fmt.Printf("IotaFull(3x3 matrix, float32)=%v\n", iotaMatrixExec.Call()[0].Value())

It also provides a short-form version that executes and then frees the compiled program:

iotaMatrix := ExecOnce(backend, func (g *Graph) *Node { return IotaFull(g, shapes.Make(dtype.Float32, 3, 3)) })

The need to build different graphs for different shapes can be expensive when the shapes of the inputs vary a lot. The usual solution is to use shapes with dimensions on a power scale (for instance, powers of 2) and mask the unused slices of the tensors. As a safety measure, there is a maximum number of different instantiations of the graph. It can be set or disabled with SetMaxCache.

Errors are returned as panic. See Panicf.

func NewExec

func NewExec[F ExecGraphFn](backend backends.Backend, graphFn F) *Exec

NewExec constructs an Exec object that uses the given graphFn to build computation graphs.

graphFn should take *Node as input and return a *Node -- except if there are no (Node) inputs, in which case it should take a single *Graph input.

It's a wrapper for NewExecAny, but uses generics to type check that graphFn is valid.

func NewExecAny

func NewExecAny(backend backends.Backend, graphFn any) *Exec

NewExecAny constructs an Exec object that uses the given graphFn to build computation graphs.

`graphFn` can only take *Node parameters as input and must return one or more *Node -- except if there are no inputs, in which case graphFn needs to take a *Graph as its first parameter.

It will panic if the inputs are invalid.

See also the generic NewExec, which checks for a valid graphFn at compile time.

func (*Exec) Call

func (e *Exec) Call(args ...any) []*tensors.Tensor

Call parses the arguments into tensors (if they aren't already) and executes the graph corresponding to the shapes of the arguments. If a graph does not yet exist for those shapes, one is created, compiled and cached.

It returns the outputs in a slice, even if there is only one output.

Errors (with full stack-traces) are raised with `panic`.

func (*Exec) CallWithGraph

func (e *Exec) CallWithGraph(args ...any) (results []*tensors.Tensor, g *Graph)

CallWithGraph is similar to Call, but it also returns the computation graph used in the call. Since Exec creates different computation graphs for different sets of parameters, this can help disambiguate in case the user needs to use the Graph for something else.

It returns the outputs in a slice, even if there is only one output, and the graph used to execute the computation.

Errors (with full stack-traces) are raised with `panic`.

func (*Exec) DeviceNum added in v0.11.0

func (e *Exec) DeviceNum() backends.DeviceNum

DeviceNum returns the device being used by this Exec. It defaults to 0 and can be changed with Exec.InDevice.

func (*Exec) Finalize

func (e *Exec) Finalize()

Finalize clears the cache, finalizing the compiled graphs. The Exec object shouldn't be used after that.

func (*Exec) GetNodeLogger

func (e *Exec) GetNodeLogger() LoggerFn

GetNodeLogger returns the currently registered LoggerFn.

func (*Exec) InDevice

func (e *Exec) InDevice(deviceNum backends.DeviceNum) *Exec

InDevice sets the device num to be used by graphs constructed by Exec. This should be called before any invocations of Call(). It returns a reference to itself so calls can be cascaded.

func (*Exec) Name

func (e *Exec) Name() string

Name returns the Exec name, a string used as prefix for Graph construction.

func (*Exec) PreCompile added in v0.10.0

func (e *Exec) PreCompile(args ...any)

PreCompile will build the computation graph and compile it, but does not yet execute it. Useful when one wants to measure graph compilation time separately from execution time.

Notice this includes the time to convert args to tensors. If you want to isolate that time, pre-convert args to tensors first.

func (*Exec) SetMaxCache

func (e *Exec) SetMaxCache(maxCacheSize int) *Exec

SetMaxCache sets the maximum size of the cache. Set it to -1 to have unlimited cache size. It returns a reference to itself so calls can be cascaded.

func (*Exec) SetName

func (e *Exec) SetName(name string) *Exec

SetName sets the name of Exec, used to provide the name to graphs created. This should be called before any invocations of Call(). It returns a reference to itself so calls can be cascaded.

func (*Exec) SetNodeLogger

func (e *Exec) SetNodeLogger(loggerFn LoggerFn)

SetNodeLogger sets the function to be called for the nodes marked for logging during execution. If set to nil, nothing is logged.

func (*Exec) SetSideParamsHook

func (e *Exec) SetSideParamsHook(fn SideParamsFn) *Exec

SetSideParamsHook configures a function to be called just before executing a graph, so it can set extra parameters.

Mostly, this is for internal use and end-users will not likely need this. The context.Exec object uses this to pass the variable values as side inputs to the graph.

Exec takes care of creating parameters (with graph.Parameter) for every value passed to Call before calling the graph building function. The graph building function is executed only the first time; once the graph is compiled, it is reused for future executions.

But a graph building function may want to create extra parameters itself (with graph.Parameter), which we call "side parameters".

The values to feed these "side parameters" are not passed to Exec.Call, but instead set with a SideParamsFn, which is configured here.

SideParamsFn is called after the graph is built, just before execution. It is passed a slice of backends.Buffer to be fed to the graph execution. The side parameters in this slice are left nil, and SideParamsFn is expected to set them to the appropriate inputs.

It also includes the boolean map of the inputs to donate, which SideParamsFn can set accordingly (for the side parameters).

type ExecGraphFn

type ExecGraphFn interface {
	func(*Graph) *Node |
		func([]*Node) *Node |
		func(*Node) *Node |
		func(*Node, *Node) *Node |
		func(*Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node, *Node, *Node) *Node |
		func(*Graph) (*Node, *Node) |
		func([]*Node) (*Node, *Node) |
		func(*Node) (*Node, *Node) |
		func(*Node, *Node) (*Node, *Node) |
		func(*Node, *Node, *Node) (*Node, *Node) |
		func(*Node, *Node, *Node, *Node) (*Node, *Node) |
		func(*Node, *Node, *Node, *Node, *Node) (*Node, *Node) |
		func(*Node, *Node, *Node, *Node, *Node, *Node) (*Node, *Node) |
		func(*Graph) (*Node, *Node, *Node) |
		func([]*Node) (*Node, *Node, *Node) |
		func(*Node) (*Node, *Node, *Node) |
		func(*Node, *Node) (*Node, *Node, *Node) |
		func(*Node, *Node, *Node) (*Node, *Node, *Node) |
		func(*Node, *Node, *Node, *Node) (*Node, *Node, *Node) |
		func(*Node, *Node, *Node, *Node, *Node) (*Node, *Node, *Node) |
		func(*Node, *Node, *Node, *Node, *Node, *Node) (*Node, *Node, *Node) |
		func(*Graph) []*Node |
		func([]*Node) []*Node |
		func(*Node) []*Node |
		func(*Node, *Node) []*Node |
		func(*Node, *Node, *Node) []*Node |
		func(*Node, *Node, *Node, *Node) []*Node |
		func(*Node, *Node, *Node, *Node, *Node) []*Node |
		func(*Node, *Node, *Node, *Node, *Node, *Node) []*Node
}

ExecGraphFn is a type parameter for accepted function types for NewExec constructor.

type ExecGraphFnOneOutput added in v0.11.1

type ExecGraphFnOneOutput interface {
	func(*Graph) *Node |
		func([]*Node) *Node |
		func(*Node) *Node |
		func(*Node, *Node) *Node |
		func(*Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node, *Node) *Node |
		func(*Node, *Node, *Node, *Node, *Node, *Node) *Node
}

ExecGraphFnOneOutput are ExecGraphFn functions that return only one result. See ExecOnce.

type Graph

type Graph struct {
	// contains filtered or unexported fields
}

Graph with the operations and dependencies needed to run a computation.

func NewGraph added in v0.11.0

func NewGraph(backend backends.Backend, name string) *Graph

NewGraph constructs an empty Graph.

Empty Graphs can still be further configured (e.g. Graph.WithName) until one starts building a computation with them.

After building a computation they can be compiled (see Graph.Compile) at which point they can only be executed.

If they are finalized (see Graph.Finalize) resources are released immediately (instead of waiting for the GC) and the Graph can no longer be used.

func (*Graph) AssertBuilding added in v0.11.0

func (g *Graph) AssertBuilding()

AssertBuilding panics if the graph is nil, has been finalized, or has already been compiled and is therefore immutable. If the Graph was in a configuring state (just after creation), this triggers it to enter a "building" state.

func (*Graph) AssertCompiled added in v0.11.0

func (g *Graph) AssertCompiled()

AssertCompiled panics if graph is nil, if it has already been finalized or if it is not yet compiled.

func (*Graph) AssertConfiguring added in v0.11.0

func (g *Graph) AssertConfiguring()

AssertConfiguring panics if graph is not in "configuring" phase: that is, if one already started building a computation with it, or if it has already been compiled. It also panics if it is not valid (e.g.: if it has been finalized).

func (*Graph) AssertValid added in v0.5.0

func (g *Graph) AssertValid()

AssertValid panics if graph is nil, or if it has already been finalized.

func (*Graph) Backend added in v0.11.0

func (g *Graph) Backend() backends.Backend

Backend this Graph is using.

func (*Graph) Compile

func (g *Graph) Compile(outputs ...*Node)

Compile just-in-time (JIT) compiles the Graph into a Computation that can be executed.

At least one output must be given.

func (*Graph) Finalize

func (g *Graph) Finalize()

Finalize frees the associated data with the compiled graph (if it is compiled) and all the nodes. The graph is left in an unusable state. It is safe to call it more than once — subsequent calls are no-ops.

func (*Graph) GetNodeByAlias added in v0.17.0

func (g *Graph) GetNodeByAlias(alias string) *Node

GetNodeByAlias returns the node with the given alias, or nil if it isn't found.

If the search alias has an absolute scope (path), meaning if it starts with AliasScopeSeparator, it is searched as is. If not, it is prefixed with the current scope before searching.

See Node.WithAlias to create node aliases, and Graph.PushAliasScope and Graph.PopAliasScope to manipulate the scope of the aliases created.

func (*Graph) GetParameterByHandle added in v0.11.0

func (g *Graph) GetParameterByHandle(handle ParameterHandle) *Node

GetParameterByHandle returns the parameter registered for this graph with the given handle (parameters are numbered in order of creation).

func (*Graph) GetParameterByName added in v0.11.0

func (g *Graph) GetParameterByName(name string) (node *Node)

GetParameterByName returns the parameter registered with the given name. Returns nil if the parameter with the given name hasn't been registered (see Parameter method).

func (*Graph) GraphId

func (g *Graph) GraphId() GraphId

GraphId is a globally unique id (even across Backends) of the graph. It's a counter that starts at 0.

func (*Graph) IsValid added in v0.11.0

func (g *Graph) IsValid() bool

IsValid returns whether the Graph is in a valid state: it is valid if it is in a configuring, building or compiled state.

func (*Graph) IterAliasedNodes added in v0.17.0

func (g *Graph) IterAliasedNodes() iter.Seq2[string, *Node]

IterAliasedNodes provides an iterator over all aliased nodes. It yields pairs (alias, node). The aliases are sorted before iteration.

func (*Graph) LastNode added in v0.11.0

func (g *Graph) LastNode() *Node

LastNode returns the last node created. It returns nil if no node has been created for this graph yet.

func (*Graph) LoggedNodes

func (g *Graph) LoggedNodes() (nodes []*Node)

LoggedNodes returns all nodes from the graph marked to be logged. Exec object makes use of this information and logs those values when executing the graph.

func (*Graph) Name

func (g *Graph) Name() string

Name of the computation this Graph defines, set during its construction.

func (*Graph) NodeById

func (g *Graph) NodeById(id NodeId) *Node

NodeById returns the node for the given id.

func (*Graph) Nodes added in v0.11.0

func (g *Graph) Nodes() []*Node

Nodes returns a slice of all nodes. The slice is owned by Graph and shouldn't be changed.

func (*Graph) NumParameters

func (g *Graph) NumParameters() int

NumParameters returns the number of parameters created for this graph.

func (*Graph) PopAliasScope added in v0.17.0

func (g *Graph) PopAliasScope()

PopAliasScope removes the scope previously pushed with PushAliasScope.

It panics if there are no scopes pushed.

func (*Graph) PushAliasScope added in v0.17.0

func (g *Graph) PushAliasScope(scope string)

PushAliasScope pushes another scope to the current alias scope for new aliases.

For instance, for an image model, one may want to push a scope per layer and create an alias "output" for the node with the output of the layer. The different scopes help differentiate the different "output" alias nodes.

Notice this is orthogonal to the context.Context scope used for variables. For instance, one may reuse a model multiple times for different inputs (e.g., triplet loss uses the same model for the "anchor", "positive" and "negative" examples, and a style-transfer model uses the same embedding model for the "source", "style" and "target" images), so the variables' context scope is the same, but one wants a different alias scope, to access the per-layer outputs of each type of example separately.

Each call to Graph.PushAliasScope should be matched by a call to Graph.PopAliasScope, usually using defer.
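
A sketch of the usual pattern (the scope and alias names are illustrative; Node.WithAlias is assumed to return the node itself, for chaining):

func buildLayer(g *Graph, x *Node) *Node {
	g.PushAliasScope("layer1")
	defer g.PopAliasScope()
	return Tanh(x).WithAlias("output") // Registered as "/layer1/output".
}

// Later, retrievable with the absolute alias path:
// out := g.GetNodeByAlias("/layer1/output")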

func (*Graph) Run

func (g *Graph) Run(inputs ...any) (outputs []*tensors.Tensor)

Run the compiled Graph with the inputs given in order -- same order as the parameters were created.

The values for inputs can be:

 1. A tensors.Tensor.
 2. Any multi-dimensional slice (e.g.: [][]float32 for a 2D float32 value) that is dynamically converted to a temporary tensor.
 3. The output of DonateTensorBuffer, which then donates the device buffer being used by a tensor -- if there are any.

This is a very "bare bones" way of running the Graph. Typically, one would use the Exec object instead (which dynamically generates a new Graph for inputs of different shapes when needed).

To donate the input buffers (if they are no longer used, e.g., when updating a state), consider using DonateTensorBuffer.
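
A bare-bones sketch using NewGraph, Parameter, Compile and Run directly (assuming a backend as in the Exec examples; Parameter is this package's graph-parameter constructor, mentioned under SetSideParamsHook):

g := NewGraph(backend, "sum")
x := Parameter(g, "x", shapes.Make(dtypes.Float32, 3))
g.Compile(ReduceAllSum(x))
fmt.Println(g.Run([]float32{1, 2, 3})[0].Value()) // => 6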

func (*Graph) RunWithBuffers added in v0.11.0

func (g *Graph) RunWithBuffers(inputs []backends.Buffer, donate []bool) (outputs []*tensors.Tensor)

RunWithBuffers executes the graph using as inputs the on-device buffers.

For the normal user, consider using the Exec wrapper, or Graph.Run.

The donate slice indicates which buffers can be donated to the execution -- they are immediately finalized after the execution is finished.

Notice that for repeated output nodes in the graph (the same output node returned in more than one position), the returned tensors are shared.

func (*Graph) RunWithMap added in v0.11.0

func (g *Graph) RunWithMap(inputs ParamsMap) (outputs []*tensors.Tensor)

RunWithMap runs the compiled graph with the inputs given as a map of the corresponding parameter node to tensor value to use.

The params can use Go values, Local tensors or Device tensors. Go values and Local tensors will be transferred to Device tensors (located in the Backend's accelerator memory) before the graph is executed.

This is a very "bare bones" way of running the Graph. Typically, one would use the Exec object instead (which dynamically generates a new Graph for inputs of different shapes when needed).

To donate the input buffers (if they are no longer used, e.g., when updating a state), consider using DonateTensorBuffer.

func (*Graph) SetTraced

func (g *Graph) SetTraced(traced bool)

SetTraced defines whether each node creation is traced. If true, every node will save a stack-trace of where it was created, which is helpful for debugging. See Node.Track().

This is expensive, but can be handy for debugging.

func (*Graph) String

func (g *Graph) String() string

String converts the Graph to a multiline string with a description of the full graph.

func (*Graph) WithName added in v0.11.0

func (g *Graph) WithName(name string) *Graph

WithName sets the name of the Graph.

It can only be called just after the creation of the graph with NewGraph. Once any operation is created with the graph, the graph configuration (e.g.: Graph.WithName) becomes immutable, and changing it will panic.

It returns the graph passed, so configuring methods can be cascaded.

type GraphId

type GraphId int

GraphId is a unique Graph id within a manager.

type InterpolationConfig added in v0.2.0

type InterpolationConfig struct {
	// contains filtered or unexported fields
}

InterpolationConfig is created with Interpolate and then actually executed with a call to Done.

Between its construction and execution one can set the various parameters for the interpolation.

func Interpolate added in v0.2.0

func Interpolate(input *Node, outputSizes ...int) *InterpolationConfig

Interpolate will interpolate the tensor to the given output sizes. The outputSizes should have the same rank as the image, and can have values set to NoInterpolation (`== -1`) for dimensions that shouldn't be changed.

Example:

image.AssertDims(100, 72, 72, 3)  // Shape `[batch_size, height=72, width=72, depth]`
image = Interpolate(image, -1, 64, 64, -1).Bilinear().Done()
image.AssertDims(100, 64, 64, 3)

Interpolate returns an InterpolationConfig that can be configured. When Done() is called, it builds the graph for the interpolation and returns the interpolated tensor. The default setup uses Bilinear interpolation and HalfPixelCenters set to true.

This can be used for images (2D), but also for anything volumetric (3D or higher) or for time-series (1D).

The implementation is based on the Tensorflow `tf2xla` one.

func (*InterpolationConfig) AlignCorner added in v0.2.0

func (c *InterpolationConfig) AlignCorner(alignCorner bool) *InterpolationConfig

AlignCorner configures the "align corner" behavior of the interpolation: if set to true, the corner pixels of the input and output tensors are aligned at their center points, preserving the values at the corner pixels. If set to false, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge-value padding for out-of-boundary values. Default is false.

One cannot select both HalfPixelCenters(true) and AlignCorner(true).

It returns the InterpolationConfig passed, to allow cascaded method calls.

func (*InterpolationConfig) Bilinear added in v0.2.0

func (c *InterpolationConfig) Bilinear() *InterpolationConfig

Bilinear configures the interpolation to be bilinear (as opposed to nearest). This is the default. See also Nearest.

It returns the InterpolationConfig passed, to allow cascaded method calls.

Note: there is a bug that makes the Bilinear gradient fail if the input dimensions to interpolate are <= 3. Use Nearest instead for now.

func (*InterpolationConfig) Done added in v0.2.0

func (c *InterpolationConfig) Done() (output *Node)

Done finishes the configuration of the interpolation and creates the computation graph that resizes the input to the given output sizes. It returns the resized input.

Any errors are returned in the graph.

func (*InterpolationConfig) HalfPixelCenters added in v0.2.0

func (c *InterpolationConfig) HalfPixelCenters(halfPixelCenters bool) *InterpolationConfig

HalfPixelCenters configures whether half-pixel centers are used. Defaults to true.

One cannot select both HalfPixelCenters(true) and AlignCorner(true).

It returns the InterpolationConfig passed, to allow cascaded method calls.

func (*InterpolationConfig) Nearest added in v0.2.0

func (c *InterpolationConfig) Nearest() *InterpolationConfig

Nearest configures the interpolation to be nearest-neighbor (as opposed to bilinear). The default is Bilinear. See also Bilinear.

It returns the InterpolationConfig passed, to allow cascaded method calls.

type LoggerFn

type LoggerFn func(graph *Graph, messages []string, values []*tensors.Tensor, nodes []NodeId)

LoggerFn is the function used to log nodes marked for logging. It is called after the Call method, with the list of messages and corresponding values of the evaluated nodes.

type Node

type Node struct {
	// contains filtered or unexported fields
}

Node represents the result of an operation in the computation graph, and can be used as input to further operations.

Internally, it keeps track of all the parameters used for the computation: this is later used for auto-differentiation (see Gradient).

It also stores meta-information: see Node.SetLogged, Node.StopGradient.

Notice some complex methods offered in this package may be implemented with several instances of simpler operations and yield several nodes in the graph; that's normal.

Node.String allows for pretty-printing of a node. To see the full graph with all nodes, use Graph.String.

func Abs

func Abs(x *Node) (node *Node)

Abs returns the element-wise absolute value of x.

func Add

func Add(x0 *Node, x1 *Node) (node *Node)

Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func AddScalar

func AddScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

AddScalar converts scalar to a constant with x's DType and returns `x + scalar` with proper broadcasting.

func And

func And(lhs, rhs *Node) *Node

And is an alias for LogicalAnd.

func ArgMax added in v0.4.0

func ArgMax(x *Node, axis int, outputDType ...dtypes.DType) (output *Node)

ArgMax returns the index of the largest element across the given axis.

The selected axis is reduced, and the output has one fewer axis (rank `x.Rank() - 1`). The output `DType`, if not given, is `dtypes.Int32`.

Ties are resolved by returning the smallest index.
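
For example, hypothetically for x with shape [2, 3]:

// x = [[1, 9, 2],
//      [7, 3, 5]]
indices := ArgMax(x, -1) // Shape [2], dtype Int32: [1, 0].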

func ArgMin added in v0.4.0

func ArgMin(x *Node, axis int, outputDType ...dtypes.DType) (output *Node)

ArgMin returns the index of the smallest element across the given axis.

The selected axis is reduced, and the output has one fewer axis (rank `x.Rank() - 1`). The output `DType`, if not given, is `dtypes.Int32`.

Ties are resolved by returning the smallest index.

func BackendGather added in v0.19.0

func BackendGather(operand *Node, startIndices *Node, indexVectorAxis int, offsetAxes []int, collapsedSliceAxes []int, startIndexMap []int, sliceSizes []int, indicesAreSorted bool) (node *Node)

BackendGather exposes the raw backend Gather operator.

This should be internal and is exposed only for debugging purposes, so please don't rely on it. If it turns out you need some functionality here that is not provided in Gather or GatherSlices, open an issue in GoMLX and we'll figure out a better API.

See convoluted and circular description in https://openxla.org/xla/operation_semantics#gather

func BackendScatterMax added in v0.19.0

func BackendScatterMax(operand, indices, updates *Node, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) *Node

BackendScatterMax exposes the raw backend ScatterMax operator.

This should be internal, and it is exposed only for testing and debugging purposes, so please don't rely on it. If it turns out you need some functionality here that is not provided in ScatterMax, open an issue in GoMLX and we'll figure out a better API.

Description in https://openxla.org/xla/operation_semantics#scatter

func BackendScatterMin added in v0.19.0

func BackendScatterMin(operand, indices, updates *Node, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) *Node

BackendScatterMin exposes the raw backend ScatterMin operator.

This should be internal, and it is exposed only for testing and debugging purposes, so please don't rely on it. If it turns out you need some functionality here that is not provided in ScatterMin, open an issue in GoMLX and we'll figure out a better API.

Description in https://openxla.org/xla/operation_semantics#scatter

func BackendScatterSum added in v0.19.0

func BackendScatterSum(operand, indices, updates *Node, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) *Node

BackendScatterSum exposes the raw backend ScatterSum operator.

This should be internal, and it is exposed only for testing and debugging purposes, so please don't rely on it. If it turns out you need some functionality here that is not provided in ScatterSum, open an issue in GoMLX and we'll figure out a better API.

Description in https://openxla.org/xla/operation_semantics#scatter

func BitCount added in v0.13.0

func BitCount(operand *Node) (node *Node)

BitCount returns the number of bits that are set to one.

func Bitcast added in v0.17.1

func Bitcast(x *Node, targetDType dtypes.DType) (node *Node)

Bitcast performs an elementwise bit-cast operation from one dtype to another. The bitcast doesn't "convert" anything; it just reinterprets the bits from x.DType() to the targetDType.

If x.DType() and targetDType use the same number of bytes (targetDType.Size() == x.DType().Size()), the dimensions are not changed; simply the dtype is changed.

If targetDType.Size() > x.DType().Size(), it requires x's last axis to have a dimension of targetDType.Size() / x.DType().Size(), and the returned shape will trim the last axis.

If targetDType.Size() < x.DType().Size(), the returned shape will have an extra axis at the end, with a dimension of x.DType().Size() / targetDType.Size().

E.g.: Bitcast([1]uint32{0xdeadbeef}, dtypes.UInt16) -> [1][2]uint16{{0xdead, 0xbeef}}

func BitwiseAnd added in v0.17.0

func BitwiseAnd(x0 *Node, x1 *Node) (node *Node)

BitwiseAnd returns the element-wise bitwise AND operation. The op is created on the same XlaBuilder as used for x0 and x1.

func BitwiseNot added in v0.17.0

func BitwiseNot(x *Node) (node *Node)

BitwiseNot returns the element-wise bitwise NOT operation.

func BitwiseOr added in v0.17.0

func BitwiseOr(x0 *Node, x1 *Node) (node *Node)

BitwiseOr returns the element-wise bitwise OR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func BitwiseShiftLeft added in v0.17.0

func BitwiseShiftLeft(x, n *Node) *Node

BitwiseShiftLeft n bits of integer values. It implicitly preserves the sign bit, if there is no overflow. So BitwiseShiftLeft(-1, 1) = -2.

func BitwiseShiftLeftScalar added in v0.17.0

func BitwiseShiftLeftScalar[T dtypes.NumberNotComplex](x *Node, n T) *Node

BitwiseShiftLeftScalar is an alias to BitwiseShiftLeft, but takes n as a scalar.

func BitwiseShiftRightArithmetic added in v0.17.0

func BitwiseShiftRightArithmetic(x, n *Node) *Node

BitwiseShiftRightArithmetic shifts right n bits of integer values, preserving the sign bit. So ShiftRight(-2, 1) = -1. See also BitwiseShiftRightLogical for a version that ignores the sign bit.

func BitwiseShiftRightArithmeticScalar added in v0.17.0

func BitwiseShiftRightArithmeticScalar[T dtypes.NumberNotComplex](x *Node, n T) *Node

BitwiseShiftRightArithmeticScalar is an alias to BitwiseShiftRightArithmetic, but takes n as a scalar. It shifts n bits of integer values, preserving the sign bit. So ShiftRight(-2, 1) = -1.

func BitwiseShiftRightLogical added in v0.17.0

func BitwiseShiftRightLogical(x, n *Node) *Node

BitwiseShiftRightLogical n bits of integer values, ignoring the sign bit. See also BitwiseShiftRightArithmetic for a version that preserves the sign bit.
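
A sketch of the difference on a negative value (int8 used for illustration):

// For x = int8(-2), bits 0b11111110, and n = 1:
arith := BitwiseShiftRightArithmetic(x, n) // == -1  (0b11111111, sign bit preserved)
logical := BitwiseShiftRightLogical(x, n)  // == 127 (0b01111111, sign bit shifted in as data)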

func BitwiseShiftRightLogicalScalar added in v0.17.0

func BitwiseShiftRightLogicalScalar[T dtypes.NumberNotComplex](x *Node, n T) *Node

BitwiseShiftRightLogicalScalar is an alias to BitwiseShiftRightLogical, but takes n as a scalar. It shifts right n bits of integer values, ignoring the sign bit.

func BitwiseXor added in v0.17.0

func BitwiseXor(x0 *Node, x1 *Node) (node *Node)

BitwiseXor returns the element-wise bitwise XOR operator. The op is created on the same XlaBuilder as used for x0 and x1.

func BroadcastPrefix

func BroadcastPrefix(x *Node, dims ...int) *Node

BroadcastPrefix adds dimensions to an array by duplicating the data in the array.

The new dimensions dims are inserted on the left, i.e., if dims has values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}`, then the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.

The new dimensions index into copies of the operand, i.e.

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
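
For instance, hypothetically:

x := IotaFull(g, shapes.Make(dtypes.Float32, 3)) // Shape [3]: [0, 1, 2].
y := BroadcastPrefix(x, 2, 4)                    // Shape [2, 4, 3]: each row a copy of x.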

func BroadcastToDims

func BroadcastToDims(x *Node, dimensions ...int) *Node

BroadcastToDims broadcasts x to the given dimensions. x must have an equal or lower rank than the given dimensions, and if there are more dimensions than x rank, x will be expanded at the end (so new axes will be appended to x). Dimensions of x must either match the corresponding value in dimensions, or they must be 1, in which case they are broadcast.

It works as expected if x is a scalar.

See also the equivalent BroadcastToShape.

func BroadcastToShape

func BroadcastToShape(x *Node, shape shapes.Shape) *Node

BroadcastToShape broadcasts x to the given shape. x must have an equal or lower rank than shape, and if shape has higher rank, x will be expanded at the end (so new axes will be appended to x). Dimensions of x must either match the corresponding dimension in shape, or they must be 1, in which case they are broadcast.

It works as expected if x is a scalar.

Notice that the dtype of shape is ignored, the returned value preserves the dtype of x.

This is equivalent to BroadcastToDims(x, shape.Dimensions...).

func Ceil

func Ceil(x *Node) (node *Node)

Ceil returns the element-wise ceiling (round up) of x.

func Clip

func Clip(x, min, max *Node) *Node

Clip is a shortcut to `Min(max, Max(x, min))`, which returns the values of x clipped between min and max.

func ClipScalar added in v0.4.0

func ClipScalar(x *Node, min, max float64) *Node

ClipScalar is a shortcut to `Min(max, Max(x, min))`, which returns the values of x clipped between min and max. The values min and max are given as scalar values -- the float64 is converted to the `DType` of x.

func Clz

func Clz(x *Node) (node *Node)

Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values.

func Complex added in v0.6.0

func Complex(x0 *Node, x1 *Node) (node *Node)

Complex returns the complex number taking x0 as the real part and x1 as the imaginary part. The real (x0) and imaginary (x1) must have the same dtype, and they must be either `dtypes.Float32` or `dtypes.Float64`. The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on x0 and x1 dtypes. The shapes of `real` or `imaginary` must be the same, or one must be a scalar, in which case the value is broadcast to every other value. The op is created on the same XlaBuilder as used for x0 and x1.

func Concatenate

func Concatenate(operands []*Node, axis int) *Node

Concatenate results on the given axis. A negative axis will be counted from the end -- so `axis==-1` means the last axis.

If operands are scalars, they will be concatenated to a vector (just use `axis=0`).

func Conj added in v0.6.0

func Conj(x *Node) (node *Node)

Conj returns the conjugate of a complex number. E.g.: Conj(1+3i) = 1-3i

func ConsecutiveDifference added in v0.13.0

func ConsecutiveDifference(x *Node, axis int, preserveShape bool) *Node

ConsecutiveDifference is the inverse of CumSum: it outputs the difference from each number to its previous one on the selected axis.

If preserveShape is true, the first element is preserved, and the shape is preserved, in which case we have ConsecutiveDifference(CumSum(x)) == x.

If preserveShape is false, just the differences are returned, and the resulting shape has the selected axis shrunk by 1.

Examples:

ConsecutiveDifference([2, 4, 8], 0, true) = [2, 2, 4]
ConsecutiveDifference([2, 4, 8], 0, false) = [2, 4]
ConsecutiveDifference([[1, 3, 6], [4, 9, 15]], -1, true) = [[1, 2, 3], [4, 5, 6]]
ConsecutiveDifference([[1, 2, 3], [5, 7, 9]], 0, true) = [[1, 2, 3], [4, 5, 6]]

func Const

func Const(g *Graph, x any) *Node

Const creates constant nodes in the Graph. It can take a tensor as well as multidimensional slices (or scalars).

It uses tensor.FromAnyValue to figure out the shape given a Go scalar/slice/array. If the value is unsupported, it panics.

A tensor.Device (e.g., generated by another computation) will be converted to local first. If you are creating very large constants that don't need to be materialized locally, consider instead storing them as variables in the context, or as a side parameter.

func ConstAs

func ConstAs(base *Node, x any) *Node

ConstAs creates a constant (slice or scalar) of the same DType and on the same Graph as the given base.

func ConstAsDType

func ConstAsDType(g *Graph, dtype dtypes.DType, x any) *Node

ConstAsDType creates a constant of the given DType. It adds the convenience of converting x (slice or scalar) to the appropriate type. E.g.:

Pi := ConstAsDType(g, myDType, math.Pi)
PiAndE := ConstAsDType(g, myDType, []float64{math.Pi, math.E})

func ConstCachedTensor added in v0.13.0

func ConstCachedTensor(g *Graph, t *tensors.Tensor) *Node

ConstCachedTensor returns a constant node for the tensor t. If it's the first time the tensor is used in this graph, a new node is created. Otherwise, a previously created node is reused.

The caching of the tensor has the side effect of keeping the tensor alive (and its memory resources) until the Graph itself is garbage collected. If this is a concern, use ConstTensor instead.

TODO: this can be made the default (ConstTensor) once weak references land in Go and the issue of keeping the tensor alive is resolved. See discussion in https://github.com/golang/go/issues/67552 and a cache-with-weak-references example in https://github.com/golang/go/issues/67552#issuecomment-2200755798

func ConstTensor added in v0.11.0

func ConstTensor(g *Graph, t *tensors.Tensor) (node *Node)

ConstTensor returns a newly created constant node for the tensor t.

The value of t is copied into the graph. For very large tensors, even if constant, it's recommended to pass them as side inputs (or variables, see the context package) instead.

See also ConstCachedTensor if you think you'll use the same tensor multiple times in the same graph.

func ConvGeneralDilated added in v0.13.0

func ConvGeneralDilated(input, kernel *Node, axes ConvolveAxesConfig,
	strides []int, paddings [][2]int, inputDilation, filterDilation []int,
	filterGroupCount, batchGroupCount int) *Node

ConvGeneralDilated is a generic Convolution operation. See Convolve for the simpler version. featureAxisAfter defines whether the features (aka. channels or depth) axis comes after the spatial dimensions. For example, a 2D input can be one of the two:

  • featureAxisAfter=false: input=[batch_size, features, height, width], filter=[output_features, input_features, height, width]
  • featureAxisAfter=true: input=[batch_size, height, width, features], filter=[output_features, height, width, input_features]

Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution (the XLA documentation is rather sparse here, so much is guesswork). Also useful is https://arxiv.org/pdf/1603.07285v1.pdf.

filterGroupCount and batchGroupCount are not supported yet for backpropagation. Please create an issue if you come to need that.

func ConvertDType added in v0.11.0

func ConvertDType(x *Node, dtype dtypes.DType) (node *Node)

ConvertDType of x to dtype. If x is already of the given dtype, it's a no-op.

func ConvertType

func ConvertType(x *Node, dtype dtypes.DType) *Node

ConvertType is an alias to ConvertDType. Deprecated: use ConvertDType instead.

func Cos

func Cos(x *Node) (node *Node)

Cos returns the element-wise cosine of x.

func CosineSimilarity added in v0.17.1

func CosineSimilarity(lhs *Node, rhs *Node, axis int) *Node

CosineSimilarity calculates the cosine similarity between the lhs and rhs nodes along the given axis. A typical value for axis is -1, which computes the similarity over the last dimension.

The output will have the same rank, but the axis is contracted to 1, and will hold the similarity.

func CumSum added in v0.11.1

func CumSum(x *Node, axis int) *Node

CumSum returns the cumulative sum along the given axis.

Example:

CumSum([[1, 2, 3], [4, 5, 6]], -1) = [[1, 3, 6], [4, 9, 15]]
CumSum([[1, 2, 3], [4, 5, 6]], 0) = [[1, 2, 3], [5, 7, 9]]

func Diagonal

func Diagonal(g *Graph, dim int) *Node

Diagonal returns a diagonal boolean square matrix of shape `[dim, dim]`.

This can be combined with `Where` to select values of any arbitrary other matrix.

func DiagonalWithValue

func DiagonalWithValue(scalar *Node, dim int) *Node

DiagonalWithValue returns a diagonal matrix of shape `[dim, dim]` with scalar in the diagonal and zero elsewhere. Set scalar to `ScalarOne()` and you get an identity matrix.

func Div

func Div(x0 *Node, x1 *Node) (node *Node)

Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func DivScalar added in v0.3.0

func DivScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

DivScalar converts scalar to a constant with x's DType and returns `x / scalar` with proper broadcasting.

For float DType's, DivScalar instead uses MulScalar(x, 1/scalar).

func Dot

func Dot(x0 *Node, x1 *Node) (node *Node)

Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:

| Input | Output | Semantics |
| vector [n] dot vector [n] | scalar | vector dot product |
| matrix [m x k] dot vector [k] | vector [m] | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation performs the sum of products over the second dimension of x0 (or the first if it has rank 1) and the first dimension of x1. These are the "contracted" dimensions. The contracted dimensions of x0 and x1 must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or matrix/matrix multiplications. The op is created on the same XlaBuilder as used for x0 and x1.
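
For example, a small sketch of a matrix-vector product (assuming a Graph `g` being built):

m := Const(g, [][]float32{{1, 2}, {3, 4}})  // shape [2, 2]
v := Const(g, []float32{10, 20})            // shape [2]
Dot(m, v)                                   // shape [2], values {50, 110}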

func DotGeneral added in v0.11.0

func DotGeneral(lhs *Node, lhsContractingAxes, lhsBatchAxes []int, rhs *Node, rhsContractingAxes, rhsBatchAxes []int) *Node

DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:

  • Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
  • Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
  • Contracted (contracting axes), in which case the values are multiplied and reduce-summed over those dimensions.

It follows that the resulting shape starts with the batch dimensions, then the `lhs` non-contracting/non-batch dimensions, and finally the `rhs` non-contracting/non-batch dimensions. It provides the basic means of implementing Einsum.
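
For example, a sketch of a batched matrix multiplication, equivalent to `Einsum("bij,bjk->bik", lhs, rhs)` (the shapes are illustrative assumptions):

lhs := IotaFull(g, shapes.Make(dtypes.Float32, 4, 2, 3))  // [batch=4, i=2, j=3]
rhs := IotaFull(g, shapes.Make(dtypes.Float32, 4, 3, 5))  // [batch=4, j=3, k=5]
out := DotGeneral(lhs, []int{2}, []int{0}, rhs, []int{1}, []int{0})
out.AssertDims(4, 2, 5)  // batch axes first, then lhs cross axes, then rhs cross axes.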

func DynamicSlice added in v0.11.1

func DynamicSlice(operand *Node, startIndices []*Node, sliceDims []int) (node *Node)

DynamicSlice extracts a sub-array from the input array at dynamic start_indices. The size of the slice in each axis is passed in sliceDims, which specify the slice intervals for each axis: [start, start + size). The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand. See description in https://openxla.org/xla/operation_semantics#dynamicslice

func DynamicUpdateSlice added in v0.11.1

func DynamicUpdateSlice(operand *Node, update *Node, startIndices []*Node) (node *Node)

DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten at startIndices. The shape of update determines the shape of the sub-array of the result which is updated. The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand. See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice

func Einsum

func Einsum(equation string, lhs, rhs *Node) *Node

Einsum evaluates the "Einstein summation", which supports various types of products (inner/outer/batched) between 2 tensors, on arbitrary dimensions. This version uses a textual description of how to manipulate the axes. See EinsumAxes for a version where the axes are given numerically.

This is inspired by numpy's Einsum, a description of which can be seen in https://stackoverflow.com/questions/26089893/understanding-numpys-einsum/33641428#33641428.

The equation string describes what to do with each dimension, for each operand, separated by ",", and the format of the result after the "->" describes what is to be made for each dimension.

Examples:

* `Einsum("ij,jk->ik", matrixA, matrixB)` performs the usual matrix multiplication. * `Einsum("bij,bjk->bik", batchedMatrixA, batchedMatrixB)` performs a batched matrix multiplication. * `Einsum("i,i->", vectorA, vectorB)` performs a dot product. * `Einsum("i,j->ij", vectorA, vectorB)` performs an outer (cross) product between two vectors.

It also works for higher dimension tensors. Dimensions missing on the output (after "->") are reduce-summed.

More examples in TensorFlow documentation: https://www.tensorflow.org/api_docs/python/tf/einsum

Notice though that this Einsum is only defined for operations between 2 operands:

  • `lhs`: left-hand-side operand.
  • `rhs`: right-hand-side operand.

Important note: the order of the operands can have a dramatic impact on the speed of the multiplications. Consider trying both sides.

func EinsumAxes added in v0.2.0

func EinsumAxes(lhs, rhs *Node, contractingAxes, batchAxes [][2]int) (output *Node)

EinsumAxes evaluates the "Einstein summation", which supports various types of products (inner/outer/batched) between 2 tensors, on arbitrary dimensions. Similar to Einsum, but it uses explicit numeric axes, as opposed to a textual description.

There are two operands: `lhs` (left-hand-side) and `rhs` (right-hand-side). The default for every axis is to do a cross-product, and the resulting tensor will have the concatenated shape (`lhs` dimensions first then `rhs` dimensions).

One can specify contractingAxes, pairs of axes (each pair with one index in the lhs and rhs operands) to be contracted: these dimensions will be multiplied and summed one at a time. That's what happens in the usual "dot product."

One can also specify batchAxes, pairs of axes (each pair with one index in the lhs and rhs operands) to be considered independently, as batch dimensions. These dimensions will show up in the same position as in `lhs`.

Examples:

  • `EinsumAxes(matrixA, matrixB, [][2]int{{1, 0}}, nil)` performs the usual matrix multiplication, where we contract axis 1 of `matrixA` with axis 0 of `matrixB`.
  • `EinsumAxes(batchedMatrixA, batchedMatrixB, [][2]int{{2, 1}}, [][2]int{{0, 0}})` is similar, but we use axis 0 of both inputs as a batch, and the following 2 axes as a matrix multiplication.
  • `EinsumAxes(vectorA, vectorB, nil, nil)` performs an outer (cross) product -- no contractions, no batch.
  • `EinsumAxes(vectorA, vectorB, [][2]int{{0, 0}}, nil)` performs a dot product and returns a scalar.

Important note: the order of the operands can have a dramatic impact on the speed of the multiplications. Consider trying both sides.

func Equal

func Equal(x0 *Node, x1 *Node) (node *Node)

Equal performs element-wise equality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func EqualTotalOrder

func EqualTotalOrder(x0 *Node, x1 *Node) (node *Node)

EqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func Erf added in v0.12.0

func Erf(x *Node) (node *Node)

Erf returns the "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}.

func Exp

func Exp(x *Node) (node *Node)

Exp returns the element-wise exponential of x, that is, e^x.

func ExpandAndBroadcast added in v0.2.0

func ExpandAndBroadcast(x *Node, newDimensions []int, expandedAxes []int) (output *Node)

ExpandAndBroadcast combines ExpandAxes and broadcasting of the axes of `x`; the returned shape will be newDimensions. Only the newly expanded axes can be broadcast.

  • newDimensions should have a rank larger than the rank of x, and the new axes in newDimensions should be listed in expandedAxes. In other words: `x.Rank() + len(expandedAxes) == len(newDimensions)`.

  • expandedAxes refer to the axes in newDimensions that are expanded and going to be broadcast. The remaining dimensions in newDimensions must match the corresponding ones in x.

Example:

x = Const(g, []int32{10, 20})
ExpandAndBroadcast(x, []int{2, 2}, []int{0})  // -> [][]int32{{10, 20}, {10, 20}}
ExpandAndBroadcast(x, []int{2, 2}, []int{1})  // -> [][]int32{{10, 10}, {20, 20}}

func ExpandAxes added in v0.15.0

func ExpandAxes(x *Node, newAxes ...int) *Node

ExpandAxes expands x, creating new axes at the positions given by newAxes -- the positions are given in the target shape.

The list newAxes represent the positions in the returned shape. If newAxes[ii] < 0, then they are counted from the end of the new shape — -1 represents the last axis in the new shape.

There should be no repeated values in newAxes -- since they represent the positions in the returned shape, it wouldn't make sense.

See also InsertAxes, where the new axes are given as positions in the original shape.

func ExpandDims deprecated

func ExpandDims(x *Node, beforeAxes ...int) *Node

ExpandDims is an alias to InsertAxes.

Deprecated: this will be removed at the next release! Notice this has a different semantics than the more common numpy.expand_dims (which is matched by ExpandAxes). Please use InsertAxes or ExpandAxes instead.

func ExpandLeftToRank added in v0.4.0

func ExpandLeftToRank(x *Node, newRank int) (output *Node)

ExpandLeftToRank prepends axes of dimension 1 to x, until it reaches rank `newRank`.

func Expm1

func Expm1(x *Node) (node *Node)

Expm1 returns the element-wise e^x - 1, accurate even for small values of x.

func FFT added in v0.6.0

func FFT(operand *Node) *Node

FFT computes a forward 1D fast-fourier transformation of the operand, which is expected to be complex. The FFT is computed on the last dimension, in case `operand.Rank() > 1`.

The resulting tensor (Node) has the same shape as the input, with the values in the frequency domain. Use InverseFFT to reverse the result.

func FillScalar added in v0.2.0

func FillScalar(g *Graph, shape shapes.Shape, value float64) *Node

FillScalar creates a Node with the given shape, filled with the given value. It's implemented indirectly using other nodes.

func Floor

func Floor(x *Node) (node *Node)

Floor returns the element-wise floor of x, that is, the largest integer value less than or equal to each element.

func Gather

func Gather(params, indices *Node, indicesAreSorted ...bool) *Node

Gather values in params from the pointers in indices. The outputs are slices of `params` selected by `indices`, stitched together.

Let's assume params has shapes `[i_1, ..., i_N, s_1, ..., s_S]`, where:

  • `i_1, ..., i_N` are the N "indexed axes", that is, the axes that are indexed by `indices`.
  • `s_1, ..., s_S` are the S dimensions of the slices that are going to be "gathered" (copied over).

And let's assume indices has shapes `[o_1,...,o_O, N]`, where:

  • `o_1, ..., o_O` are "batch dimensions" of the slices from `params` to gather, that will be included in the output. E.g.: let's say O=1, and o_1=3, that means there will be 3 slices to gather.
  • Last dimension `N`: this is the number of indices in `params` to point to. `N` is the number of dimensions indexed `i_1, ..., i_N` in `params` above.

The output will have shapes `[o_1,...,o_O, s_1, ... s_S]`, where:

  • `o_1, ..., o_O` come from indices, and are enumerations of the slices from params to gather.
  • `s_1, ..., s_S` are the slice sizes copied from params.

indicesAreSorted can be set if you know the indices are sorted; on some backends this allows for optimizations. If not set, it defaults to false.

For example:

params := [][]float32{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}
indices := [][]int{{1}, {0}}
Gather(params, indices) would return {{3, 4, 5}, {0, 1, 2}}

In the example above, params' shape is interpreted as `[i_1=3, s_1=3]`, and indices' shape as `[o_1=2, N=1]`. The output shape is `[o_1=2, s_1=3]`.

func GatherSlices added in v0.2.0

func GatherSlices(input *Node, slicedAxes []int, start *Node, sizes []int, indicesAreSorted bool) (gathered *Node)

GatherSlices from the input. Each axis listed in slicedAxes has a corresponding start position and size for each slice, indexed by `start` (a graph Node, which can be dynamically generated in the graph) and `sizes`, which define the output's final shape and must be statically given.

Axes in slicedAxes can be given as negative numbers, which are counted from the end of the input rank -- that is, axis -1 means the last axis in the input. Axes not given in slicedAxes (and in `start` and `sizes`) are taken in full length.

Axes in slicedAxes must be given sorted in increasing order.

The output has a rank equal to the prefix rank of `start` (== `start.Rank()-1`) plus the rank of `input`. The dimensions will depend on the sizes of the slices.

  • TODO: Add an option to support batch axes, present in both the input and in the start indices. This will need to automatically concatenate the batch index in the start Node as a iota of each batch example, and add the size 1 slice. This can be done manually today.

indicesAreSorted can be set if you know the indices in start are sorted; on some backends this allows for optimizations.

Example:

x := IotaFull(g, shapes.Make(dtypes.Float64, 3, 10, 10))  // 300 values in total.
start := Const(g, [][]int32{{0, 3}, {1, 2}})  // 2 slices.
sizes := []int{1, 3}
slices := GatherSlices(x, []int{1, 2}, start, sizes, true)  // Axis 0 is taken in full.
slices.AssertDims(2, 3, 1, 3)  // 2 slices, axis 0 taken in full (3), and each slice of dimensions (1, 3).
// The first slice (start={0, 3}) yields {{3, 4, 5}}, {{103, 104, 105}} and {{203, 204, 205}} for the 3
// values of axis 0; the second slice (start={1, 2}) yields {{12, 13, 14}}, {{112, 113, 114}} and {{212, 213, 214}}.

func Gradient

func Gradient(output *Node, gradientNodes ...*Node) []*Node

Gradient creates new nodes for the gradients of the output with respect to each node in gradientNodes. The output must be a scalar -- otherwise this would be called Jacobian. TODO: Define a Jacobian.
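
A minimal sketch of its use, assuming a Backend `backend` created elsewhere and using the NewExec helper:

gradFn := NewExec(backend, func(x *Node) *Node {
	loss := ReduceAllSum(Mul(x, x))  // loss = sum(x^2), a scalar.
	return Gradient(loss, x)[0]      // d(loss)/dx = 2*x.
})
grad := gradFn.Call([]float32{1, 2, 3})[0]  // -> [2, 4, 6]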

func GreaterOrEqual

func GreaterOrEqual(x0 *Node, x1 *Node) (node *Node)

GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func GreaterOrEqualTotalOrder

func GreaterOrEqualTotalOrder(x0 *Node, x1 *Node) (node *Node)

GreaterOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func GreaterThan

func GreaterThan(x0 *Node, x1 *Node) (node *Node)

GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func GreaterThanTotalOrder

func GreaterThanTotalOrder(x0 *Node, x1 *Node) (node *Node)

GreaterThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func GrowLeft added in v0.11.0

func GrowLeft(x *Node, axis int, n int, fillValue float64) *Node

GrowLeft will grow the dimension of the given axis by concatenating n elements to the left (start). Those elements are filled with fillValue (converted to the corresponding dtype).

func GrowRight added in v0.11.0

func GrowRight(x *Node, axis int, n int, fillValue float64) *Node

GrowRight will grow the dimension of the given axis by concatenating n elements to the right (end). Those elements are filled with fillValue (converted to the corresponding dtype).

func Identity added in v0.11.0

func Identity(x *Node) (node *Node)

Identity returns an Op whose output is the same as its input. It's a no-op that can serve as a place-holder.

func IdentityWithCustomGradient added in v0.9.0

func IdentityWithCustomGradient(x *Node, gradientFn func(x, v *Node) *Node) *Node

IdentityWithCustomGradient returns x unchanged, but sets a custom gradient function to be applied when doing the reverse autograd (gradient) calculation.

The `gradientFn` will be called during auto-grad and will be passed `x` and `v`, the "adjoint", which represents the gradient of the loss (or, more generally, of whatever is being differentiated) with respect to `x`. It should return the updated `v`, that is, the customized gradient with respect to `x`.
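
For example, a sketch that clips the gradient flowing through x to [-1, 1], while leaving the forward value unchanged:

y := IdentityWithCustomGradient(x, func(x, v *Node) *Node {
	return ClipScalar(v, -1, 1)  // clip the incoming adjoint.
})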

func Imag added in v0.6.0

func Imag(x *Node) (node *Node)

Imag returns the imaginary part of a complex number. It returns 0 if x is a float number.

func IndicesForShape

func IndicesForShape(g *Graph, shape shapes.Shape) *Node

IndicesForShape enumerates a list of indices for all elements of the given shape. It will always return a node shaped [shape.Size(), shape.Rank()]. E.g: if shape=[3, 2], it returns `[[0 0] [0 1] [1 0] [1 1] [2 0] [2 1]]`.

func Infinity added in v0.11.0

func Infinity(g *Graph, dtype dtypes.DType, sign int) *Node

Infinity returns positive or negative infinity (depending on the value of sign, which must be 1 or -1) for the given dtype. For integer dtypes, it returns the highest/lowest values.

func InsertAxes added in v0.15.0

func InsertAxes(x *Node, beforeAxes ...int) *Node

InsertAxes expands x creating new axes just before the axes given -- beforeAxes points to positions on the original tensor x, and they can be repeated, in case one wants to insert more than one new axis in the given position.

If beforeAxes[ii] < 0, then they are counted from the end — -1 represents a new axis after the end of the original shape.

The new axes will be of dimension 1 (so the total size and contents of the tensor remain the same), and the rank is increased by `len(axes)`.

See also ExpandAxes, where the new axes are given as positions in the target shape.
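
A small sketch contrasting the two conventions (assuming a Graph `g` being built):

x := Const(g, []float32{1, 2, 3})  // shape [3]
InsertAxes(x, 0)   // insert before axis 0 of x: shape [1, 3]
InsertAxes(x, -1)  // insert after the end of x: shape [3, 1]
ExpandAxes(x, 0)   // new axis at position 0 of the output: shape [1, 3]
ExpandAxes(x, -1)  // new axis at the last position of the output: shape [3, 1]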

func InternalBatchNormForInference added in v0.11.0

func InternalBatchNormForInference(operand *Node, scale *Node, offset *Node, mean *Node, variance *Node, epsilon float32, axis int) (node *Node)

InternalBatchNormForInference is a wrapper to the backend function. Don't use this directly, instead use layers.BatchNormalization.

func InternalBatchNormForTraining added in v0.11.0

func InternalBatchNormForTraining(operand *Node, scale *Node, offset *Node, epsilon float32, axis int) (normalized, batchMean, batchVariance *Node)

InternalBatchNormForTraining is a wrapper to the backend function. Don't use this directly, instead use layers.BatchNormalization.

func InternalBatchNormGradient added in v0.11.0

func InternalBatchNormGradient(operand *Node, scale *Node, mean *Node, variance *Node, gradOutput *Node, epsilon float32, axis int) (gradOperand, gradScale, gradOffset *Node)

InternalBatchNormGradient is a wrapper to the backend function. Don't use this directly, instead use layers.BatchNormalization.

func Inverse

func Inverse(x *Node) *Node

Inverse returns (1/x), the multiplicative inverse. Also known as the reciprocal.

func InverseFFT added in v0.6.0

func InverseFFT(operand *Node) *Node

InverseFFT computes an inverse fast-fourier transformation of the operand, which is expected to be complex. The InverseFFT is computed on the last dimension, in case `operand.Rank() > 1`.

The resulting tensor (Node) has the same shape as the input, with the values back in the time domain.

func InverseRealFFT added in v0.6.0

func InverseRealFFT(operand *Node) *Node

InverseRealFFT computes the inverse of a forward 1D fast-fourier transformation. The inverse FFT is computed on the last dimension, in case `operand.Rank() > 1`.

The resulting tensor (Node) has a shape equal to the input's, except for the last dimension (where the FFT is computed), which is restored to the original `(dim-1)*2`, where `dim` is the last dimension of `operand`.

Note that because of the last dimension change in `RealFFT`, this cannot be perfectly reversed if `operand.Shape().Dimensions[-1]` is odd. Preferably use with even numbers.

func Iota

func Iota(g *Graph, shape shapes.Shape, iotaAxis int) *Node

Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].

See also IotaFull.

func IotaFull

func IotaFull(g *Graph, shape shapes.Shape) *Node

IotaFull creates a constant of the given shape with increasing numbers for all values. So `IotaFull([2,2])` returns `[[0 1][2 3]]`.

func IsFinite added in v0.13.0

func IsFinite(x *Node) (node *Node)

IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN. It returns an array of boolean values with the same shape as the input, where each element is true if and only if the corresponding input element is finite.

func IsZero added in v0.17.1

func IsZero(x *Node) *Node

IsZero returns a Bool tensor that is true where x is zero, and false otherwise. A shortcut to Equal(x, ScalarZero(x.Graph(), x.DType())).

func L1Norm

func L1Norm(x *Node, reduceAxes ...int) *Node

L1Norm returns the L1 norm (same as Manhattan length) over the given axes of x (defaults to all), given by \Sum{|x_i|}.

If no axes are given, it returns a scalar. Otherwise, the returned value has the same rank as `x`, but the reduce axes will have dimension 1.

func L2Norm

func L2Norm(x *Node, reduceAxes ...int) *Node

L2Norm returns the L2 norm (same as Euclidean length) over the given axes of x (defaults to all), given by Sqrt(\Sum{x_i^2}).

If no axes are given, it returns a scalar. Otherwise, the returned value has the same rank as `x`, but the reduce axes will have dimension 1.
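
For example:

x := Const(g, [][]float32{{3, 4}, {0, 0}})
L2Norm(x, -1)  // shape [2, 1], values {{5}, {0}}
L2Norm(x)      // scalar, value 5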

func L2NormSquare added in v0.8.0

func L2NormSquare(x *Node, reduceAxes ...int) *Node

L2NormSquare returns the L2 norm square (same as square of the Euclidean length) over the given axes of x (defaults to all). Same as `\Sum_{reduceAxes}{x_i^2}`.

If no axes are given, it returns a scalar. Otherwise, the returned value has the same rank as `x`, but the reduce axes will have dimension 1.

func L2Normalize added in v0.8.0

func L2Normalize(x *Node, reduceAxis int, moreReduceAxes ...int) *Node

L2Normalize returns `x/L2Norm(x)` on the given reduce axes, making the reduced axes unit-length vectors.

It will return `inf` for values of x that are near zero-length.

For elements that have L2Norm zero, it returns 0 and 1s for the gradients, so no NaNs are generated.

See L2NormalizeWithEpsilon for a version that adds an epsilon to the denominator to avoid that.

func L2NormalizeWithEpsilon added in v0.8.0

func L2NormalizeWithEpsilon(x *Node, epsilon float64, reduceAxis int, moreReduceAxes ...int) *Node

L2NormalizeWithEpsilon returns `x/(L2Norm(x)+epsilon)` on the given reduce axes, making the reduced axes unit-length vectors.

func LessOrEqual

func LessOrEqual(x0 *Node, x1 *Node) (node *Node)

LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func LessOrEqualTotalOrder

func LessOrEqualTotalOrder(x0 *Node, x1 *Node) (node *Node)

LessOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func LessThan

func LessThan(x0 *Node, x1 *Node) (node *Node)

LessThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func LessThanTotalOrder

func LessThanTotalOrder(x0 *Node, x1 *Node) (node *Node)

LessThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func Log

func Log(x *Node) (node *Node)

Log returns the element-wise natural logarithm of x.

func Log1P added in v0.9.0

func Log1P(x *Node) *Node

Log1P is an alias to Log1p. It returns log(1+x).

func Log1p

func Log1p(x *Node) (node *Node)

Log1p returns the expression log(x+1).

func LogAddExp added in v0.17.1

func LogAddExp(x, y *Node) *Node

LogAddExp returns the logarithm of the sum of exponentials of the inputs: log(exp(x) + exp(y)). This function is useful in statistics, where the calculated probabilities of events may be so small as to exceed the range of normal floating-point numbers. In such cases, the logarithm of the calculated probability is stored; this function allows adding probabilities stored in that fashion.

func LogSoftmax added in v0.11.0

func LogSoftmax(logits *Node, axes ...int) *Node

LogSoftmax computes the logarithm of the Softmax function, which rescales elements to the range $[-\infty, 0)$.

$$
\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
\right)
$$

The axes define over which axes the LogSoftmax should be computed. If missing it is assumed to be -1.

If any input values are "+inf", the result will be all "NaN": this reflects the fact that "inf / inf" is not well-defined in the context of floating-point math.

func LogicalAll added in v0.11.1

func LogicalAll(x *Node, reduceAxes ...int) *Node

LogicalAll returns true if all values of x (converted to boolean) evaluate to true. It's a "ReduceLogicalAnd" equivalent.

If reduceAxes is empty, it will reduce over all dimensions.

func LogicalAnd added in v0.17.0

func LogicalAnd(x0 *Node, x1 *Node) (node *Node)

LogicalAnd returns the element-wise logical AND operation. The op is created on the same XlaBuilder as used for x0 and x1.

func LogicalAny added in v0.11.1

func LogicalAny(x *Node, reduceAxes ...int) *Node

LogicalAny returns true if any values of x (converted to boolean) evaluate to true. It's a "ReduceLogicalOr" equivalent.

If reduceAxes is empty, it will reduce over all dimensions.

func LogicalNot added in v0.11.0

func LogicalNot(x *Node) (node *Node)

LogicalNot returns the element-wise logical NOT of x.

func LogicalOr added in v0.17.0

func LogicalOr(x0 *Node, x1 *Node) (node *Node)

LogicalOr returns the element-wise logical OR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func LogicalXor added in v0.17.0

func LogicalXor(x0 *Node, x1 *Node) (node *Node)

LogicalXor returns the element-wise logical XOR operator. The op is created on the same XlaBuilder as used for x0 and x1.

func Logistic

func Logistic(x *Node) (node *Node)

Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.

func LowerTriangular

func LowerTriangular(g *Graph, dim int) *Node

LowerTriangular returns a lower-triangular boolean square matrix of shape `[dim, dim]`.

This can be combined with `Where` to select values of any arbitrary other matrix.

func MaskedLogSoftmax added in v0.11.0

func MaskedLogSoftmax(logits, mask *Node, axes ...int) *Node

MaskedLogSoftmax computes the logarithm of the MaskedSoftmax function, which rescales elements to the range $[-\infty, 0)$.

It takes a mask that is true on the values to be considered, and false for the values not to be considered. If mask is nil, it behaves like LogSoftmax.

See LogSoftmax for details.

func MaskedReduceAllMax added in v0.9.0

func MaskedReduceAllMax(x, mask *Node) *Node

MaskedReduceAllMax reduces all dimensions to a scalar by taking the max.

It ignores values for which the corresponding mask is false. The shapes of `mask` and `x` must be the same.

func MaskedReduceAllMean added in v0.9.0

func MaskedReduceAllMean(x, mask *Node) *Node

MaskedReduceAllMean reduces all dimensions to a scalar by taking the mean. It ignores entries where mask is false.

func MaskedReduceAllMin added in v0.11.0

func MaskedReduceAllMin(x, mask *Node) *Node

MaskedReduceAllMin reduces all dimensions to a scalar by taking the min.

It ignores values for which the corresponding mask is false. The shapes of `mask` and `x` must be the same. If mask is nil, it behaves like ReduceAllMin.

func MaskedReduceAllSum added in v0.9.0

func MaskedReduceAllSum(x, mask *Node) *Node

MaskedReduceAllSum reduces all dimensions to a scalar by summing.

It ignores values for which the corresponding mask is false. The `mask` and `x` values must have the same shape.

func MaskedReduceAndKeep added in v0.9.0

func MaskedReduceAndKeep(x, mask *Node, reduceFn func(x, mask *Node, reduceAxes ...int) *Node, reduceAxes ...int) *Node

MaskedReduceAndKeep applies the given masked reduction function but regenerates the reduced dimensions with size 1.
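
For example, a masked mean over the last axis that keeps the reduced axis with dimension 1 (a sketch; `x` and `mask` are assumed to have the same shape):

mean := MaskedReduceAndKeep(x, mask, MaskedReduceMean, -1)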

func MaskedReduceMax added in v0.9.0

func MaskedReduceMax(x, mask *Node, reduceAxes ...int) *Node

MaskedReduceMax reduces by taking the max of `x` elements over the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

It ignores values for which the corresponding mask is false. The shapes of `mask` and `x` must be the same. If mask is nil, it behaves like ReduceMax.

func MaskedReduceMean added in v0.9.0

func MaskedReduceMean(x, mask *Node, reduceAxes ...int) *Node

MaskedReduceMean reduces by taking the mean over the elements of the selected axes.

The reduced axes of `x` are removed in the output -- so the rank is reduced.

It first applies a mask to x, converting masked values to the neutral value of the operation (0). For reduction dimensions that are completely masked, it returns 0. If mask is nil, it behaves like ReduceMean.

func MaskedReduceMin added in v0.11.0

func MaskedReduceMin(x, mask *Node, reduceAxes ...int) *Node

MaskedReduceMin reduces by taking the min of `x` elements over the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

It ignores values for which the corresponding mask is false. The shapes of `mask` and `x` must be the same. If mask is nil, it behaves like ReduceMin.

func MaskedReduceSum added in v0.9.0

func MaskedReduceSum(x, mask *Node, reduceAxes ...int) *Node

MaskedReduceSum reduces by summing the `x` elements over the selected axes. If `reduceAxes` is nil, reduce over all dimensions to a scalar.

The reduced axes of `x` are removed in the output -- so the rank is reduced.

It ignores values for which the corresponding mask is false. The `mask` and `x` values must have the same shape. If mask is nil, it behaves like ReduceSum.

func MaskedSoftmax

func MaskedSoftmax(logits, mask *Node, axes ...int) *Node

MaskedSoftmax computes softmax activations. It's equivalent to

Exp(logits) / InsertAxes(ReduceSum(Exp(logits), -1), -1)

but implemented in a numerically stable way.

It takes a mask that is true on the values to be considered, and false for the values not to be considered.

The list axes defines over which axes the softmax should be computed (the axes that will be summed over). If no axes are given, it is assumed to be [-1], meaning, the last axis.

It ignores values for which the corresponding mask is false, and will return 0 for those fields. mask and logits must have the same shape.

func MatMul added in v0.15.0

func MatMul(lhs, rhs *Node) *Node

MatMul is the `numpy.matmul` equivalent, for those used to that.

It is similar to Dot but extends to allow for more batch dimensions in lhs or rhs operand, and does broadcasting (of all but the last 2 axes) according to the numpy broadcasting rules.

It's popular, hence it is here, but it is full of edge cases; consider using DotGeneral instead.

func Max

func Max(x0 *Node, x1 *Node) (node *Node)

Max returns the element-wise highest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.

func MaxScalar added in v0.2.0

func MaxScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

MaxScalar converts scalar to a constant with x's DType and returns element-wise `Max(x, scalar)`.

func Min

func Min(x0 *Node, x1 *Node) (node *Node)

Min returns the element-wise smallest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.

func MinScalar added in v0.2.0

func MinScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

MinScalar converts scalar to a constant with x's DType and returns element-wise `Min(x, scalar)`.

func MinusOne

func MinusOne(x *Node) *Node

MinusOne returns (x-1).

func MirroredLog1p added in v0.9.0

func MirroredLog1p(x *Node) *Node

MirroredLog1p is similar to Log1p, but it is mirrored to negative numbers. It returns Log(Abs(x)+1)*Sign(x).

func Mod

func Mod(x, y *Node) *Node

Mod adds to the graph the modulo (remainder) operation on the two input nodes x and y. It's an alias to Rem. Standard broadcasting rules apply (see documentation).

func ModScalar added in v0.6.0

func ModScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

ModScalar converts scalar to a constant with x's DType and returns `x % scalar` with proper broadcasting.

func Mul

func Mul(x0 *Node, x1 *Node) (node *Node)

Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func MulScalar

func MulScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

MulScalar converts scalar to a constant with x's DType and returns `x * scalar` with proper broadcasting.

func Neg

func Neg(x *Node) (node *Node)

Neg returns the element-wise negation of x, that is, `-x`.

func NegativeIndicator added in v0.17.1

func NegativeIndicator(x *Node) *Node

NegativeIndicator returns 1 where x < 0, 0 otherwise. E.g: NegativeIndicator({1.0, 0.0001, 0, -0.2, -3.0}) -> [0, 0, 0, 1, 1], with the same shape/dtype as x.

func NonNegativeIndicator added in v0.17.1

func NonNegativeIndicator(x *Node) *Node

NonNegativeIndicator returns 1 where x >= 0, 0 otherwise. See also PositiveIndicator. E.g: NonNegativeIndicator({1.0, 0.0001, 0, -0.2, -3.0}) -> [1, 1, 1, 0, 0], with the same shape/dtype as x.

func NonPositiveIndicator added in v0.17.1

func NonPositiveIndicator(x *Node) *Node

NonPositiveIndicator returns 1 where x <= 0, 0 otherwise. See also NegativeIndicator. E.g: NonPositiveIndicator({1.0, 0.0001, 0, -0.2, -3.0}) -> [0, 0, 1, 1, 1], with the same shape/dtype as x.

func Not

func Not(x *Node) *Node

Not is an alias for LogicalNot.

func NotEqual

func NotEqual(x0 *Node, x1 *Node) (node *Node)

NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func NotEqualTotalOrder

func NotEqualTotalOrder(x0 *Node, x1 *Node) (node *Node)

NotEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func OneHot

func OneHot(indices *Node, depth int, dtype dtypes.DType) *Node

OneHot converts integer numbers representing indices into their "one-hot" representation, that is, an expanded tensor where the index positions are set to 1 and the other positions are set to 0. The returned tensor has one extra dimension at the end. For example, `OneHot([]int64{1, 0, 3}, 4, dtypes.Float32)` returns `[][]float32{{0, 1, 0, 0}, {1, 0, 0, 0}, {0, 0, 0, 1}}`.

func OneMinus

func OneMinus(x *Node) *Node

OneMinus returns (1-x).

func OnePlus

func OnePlus(x *Node) *Node

OnePlus returns (1+x).

func Ones

func Ones(g *Graph, shape shapes.Shape) *Node

Ones creates a node with the given shape, filled with the value 1. It's implemented indirectly using other nodes.

func OnesLike

func OnesLike(x *Node) *Node

OnesLike returns a tensor with the same shape of x, filled with 1's.

func Or

func Or(lhs, rhs *Node) *Node

Or is an alias for LogicalOr.

func Pad

func Pad(x *Node, fillValue *Node, axesConfig ...backends.PadAxis) (node *Node)

Pad injects padding on the start, end or interior (in between each element) of the given operand. There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros, that is, no padding for those axes.

func Parameter added in v0.11.0

func Parameter(g *Graph, name string, shape shapes.Shape) (node *Node)

Parameter registers an input parameter for a computation Graph (e.g: a feature used as input).

When created, parameters get a handle (a plain index). A parameter can be used in two different ways: as a Node when building the Graph (when defining a function that uses the parameter), or as a key in the map of inputs when executing the computation Graph (see Backend.RunWithMap).
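
A minimal sketch of manual graph building with a parameter, assuming a Backend `backend` created elsewhere (most users rely instead on the Exec helpers, which create parameters automatically from the graph function's arguments):

g := NewGraph(backend, "double")
x := Parameter(g, "x", shapes.Make(dtypes.Float32, 3))
y := MulScalar(x, 2)
// ... then compile the graph with y as output and execute it, binding a value to "x".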

func PositiveIndicator

func PositiveIndicator(x *Node) *Node

PositiveIndicator returns 1 where x > 0, 0 otherwise. E.g: PositiveIndicator({1.0, 0.0001, 0, -0.2, -3.0}) -> [1, 1, 0, 0, 0], with the same shape/dtype as x.

func Pow

func Pow(x0 *Node, x1 *Node) (node *Node)

Pow returns the element-wise power operation x0^x1. The op is created on the same XlaBuilder as used for x0 and x1.

func PowScalar added in v0.4.0

func PowScalar[N dtypes.NumberNotComplex](x *Node, scalar N) *Node

PowScalar converts scalar to a constant with x's DType and returns `Pow(x, scalar)` (or `x ** scalar`) with proper broadcasting.

func RandomIntN added in v0.15.1

func RandomIntN[IntT interface{ *Node | constraints.Integer }](
	rngState *Node, N IntT, shape shapes.Shape) (newRngState, values *Node)

RandomIntN generates random numbers uniformly from 0 to N-1. It only works for integer types; see RandomUniform for float or complex data types. N can be given as a Node, or as a static scalar integer value.

Example:

rngState := Const(g, RngStateFromSeed(42))
rngState, D10 := RandomIntN(rngState, 10, shapes.Make(dtypes.Int32))

It uses and updates the random number generator (RNG) state in `rngState`. See RngStateFromSeed or RngState to generate a random state tensor (that can be fed to the computation graph).

Alternatively, if you don't want to worry about carrying around the rngState, use the context.Context.RandomIntN version, which stores the rngState as a variable.

func RandomNormal added in v0.4.0

func RandomNormal(rngState *Node, shape shapes.Shape) (newRngState, values *Node)

RandomNormal generates random numbers from a normal distribution, with mean 0.0 and standard deviation 1.0. It generates values with the given shape, each value pseudo-randomly generated.

If you need a different mean and standard deviation, just do something like the example below, where `mean` and `stddev` are the desired mean and standard deviation:

rngState := Const(g, RngStateFromSeed(42))
rngState, values := RandomNormal(rngState, shapes.Make(dtypes.Float32, 3, 2))
numbers := AddScalar(MulScalar(values, stddev), mean)

It uses the Box-Muller algorithm (see https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform). It has some numeric limitations, but works well for most purposes.

It will signal an error if the dtype is not float. See also RandomIntN for random integers.

It uses and updates the random number generator (RNG) state in `rngState`. See RngStateFromSeed or RngState to generate a random state tensor (that can be fed to the computation graph).

Alternatively, if you don't want to worry about carrying around the rngState, use the context.Context.RandomNormal version, which stores the rngState as a variable.

func RandomUniform added in v0.4.0

func RandomUniform(rngState *Node, shape shapes.Shape) (newRngState, values *Node)

RandomUniform generates random uniform values from 0.0 to 1.0 (half-open `[0.0, 1.0)`, so 1.0 is never returned) for float numbers in the given shape.

It will signal an error if the dtype is not float -- see RandomIntN for random integers.

For complex numbers, both the real and the imaginary part are independently sampled from `[0.0, 1.0)`.

It uses and updates the random number generator (RNG) state in `rngState`. See RngStateFromSeed or RngState to generate a random state tensor (that can be fed to the computation graph).

Alternatively, if you don't want to worry about carrying around the rngState, use the context.Context.RandomUniform version, which stores the rngState as a variable.

Example:

rngState := Const(g, RngStateFromSeed(42))
rngState, values := RandomUniform(rngState, shapes.Make(dtypes.Float32, 3, 2))

func Real added in v0.6.0

func Real(x *Node) (node *Node)

Real returns the real part of a complex number. It returns x if x is a float number.

func RealFFT added in v0.6.0

func RealFFT(operand *Node) *Node

RealFFT computes a forward 1D fast-fourier transformation on a real (float) input. The FFT is computed on the last dimension, in case `operand.Rank() > 1`.

The resulting tensor (Node) has a shape equal to the input's, except for the last dimension (where the FFT is computed), which has dimension `dim/2 + 1`, where `dim` is the last dimension of `operand`.

Note that because of the last dimension change in `RealFFT`, this cannot be perfectly reversed if `operand.Shape().Dimensions[-1]` is odd. Preferably use with even numbers.

func ReduceAllMax

func ReduceAllMax(x *Node) *Node

ReduceAllMax reduces all dimensions to a scalar by taking the max.

func ReduceAllMean

func ReduceAllMean(x *Node) *Node

ReduceAllMean reduces all dimensions to a scalar by taking the mean.

func ReduceAllMin added in v0.11.0

func ReduceAllMin(x *Node) *Node

ReduceAllMin reduces all dimensions to a scalar by taking the min.

func ReduceAllMultiply

func ReduceAllMultiply(x *Node) *Node

ReduceAllMultiply reduces all dimensions to a scalar by multiplying.

func ReduceAllSum

func ReduceAllSum(x *Node) *Node

ReduceAllSum reduces all dimensions to a scalar by summing.

func ReduceAndKeep

func ReduceAndKeep(x *Node, reduceFn func(x *Node, reduceAxes ...int) *Node, reduceAxes ...int) *Node

ReduceAndKeep applies the given reduction function but regenerates the reduced dimensions with size 1. If len(reduceAxes) is 0 (no axes given), it is assumed to reduce over all axes.

func ReduceBitwiseAnd added in v0.17.0

func ReduceBitwiseAnd(x *Node, reduceAxes ...int) *Node

ReduceBitwiseAnd returns the bitwise AND of the values across the given axes. Only defined for integer values. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceBitwiseOr added in v0.17.0

func ReduceBitwiseOr(x *Node, reduceAxes ...int) *Node

ReduceBitwiseOr returns the bitwise OR of the values across the given axes. Only defined for integer values. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceBitwiseXor added in v0.17.0

func ReduceBitwiseXor(x *Node, reduceAxes ...int) *Node

ReduceBitwiseXor returns the bitwise XOR of the values across the given axes. Only defined for integer values. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceLogicalAnd added in v0.17.0

func ReduceLogicalAnd(x *Node, reduceAxes ...int) *Node

ReduceLogicalAnd returns true if all values of x evaluate to true across the given axes. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceLogicalOr added in v0.17.0

func ReduceLogicalOr(x *Node, reduceAxes ...int) *Node

ReduceLogicalOr returns true if any values of x evaluate to true across the given axes. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceLogicalXor added in v0.17.0

func ReduceLogicalXor(x *Node, reduceAxes ...int) *Node

ReduceLogicalXor returns the xor of the values across the given axes. No gradients are defined.

If reduceAxes is empty, it will reduce over all dimensions.

func ReduceMax

func ReduceMax(x *Node, reduceAxes ...int) *Node

ReduceMax reduces by taking the max over the elements of the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

The reduced axes of `x` are removed in the output -- so the rank is reduced. See ReduceAndKeep for a version to preserve the reduced axes.

func ReduceMean

func ReduceMean(x *Node, reduceAxes ...int) *Node

ReduceMean reduces by taking the mean over the elements of the selected axes.

The reduced axes of `x` are removed in the output -- so the rank is reduced. See ReduceAndKeep for a version to preserve the reduced axes.

func ReduceMin added in v0.11.0

func ReduceMin(x *Node, reduceAxes ...int) *Node

ReduceMin reduces by taking the min over the elements of the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

The reduced axes of `x` are removed in the output -- so the rank is reduced. See ReduceAndKeep for a version to preserve the reduced axes.

func ReduceMultiply

func ReduceMultiply(x *Node, reduceAxes ...int) *Node

ReduceMultiply reduces by multiplying over the elements of the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

The reduced axes of `x` are removed in the output -- so the rank is reduced. See ReduceAndKeep for a version to preserve the reduced axes.

func ReduceSkewness added in v0.17.0

func ReduceSkewness(x *Node, axes ...int) *Node

ReduceSkewness calculates the skewness (the 3rd standardized moment of a distribution) across the given axes.

If no axes are given, it assumes it should reduce all axes and returns a scalar.

func ReduceSum

func ReduceSum(x *Node, reduceAxes ...int) *Node

ReduceSum reduces by summing over the elements of the selected axes. If reduceAxes is nil, reduce over all dimensions to a scalar.

The reduced axes of `x` are removed in the output -- so the rank is reduced. See ReduceAndKeep for a version to preserve the reduced axes.

func ReduceVariance added in v0.15.2

func ReduceVariance(x *Node, axes ...int) *Node

ReduceVariance calculates the variance across the given axes.

If no axes are given, it assumes it should reduce all axes and returns a scalar.

func Rem added in v0.11.0

func Rem(x0 *Node, x1 *Node) (node *Node)

Rem returns the remainder operation, also known as modulo (or Mod for short). Notice that, despite the name, XLA implements Mod, not the IEEE 754 Remainder operation. The op is created on the same XlaBuilder as used for x0 and x1.

func Reshape

func Reshape(x *Node, dimensions ...int) *Node

Reshape x to the given dimensions. Total size cannot change. One dimension can be left as -1, in which case it will be set to match the size, if possible.

func ReshapeWithShape

func ReshapeWithShape(x *Node, shape shapes.Shape) *Node

ReshapeWithShape reshapes x to the dimensions given by shape. The total size cannot change, nor is the DType allowed to change. Conceptually, this is a limited form of "shape casting."

func Reverse

func Reverse(x *Node, axes ...int) *Node

Reverse returns x with the values for the given axes reversed: the value at index `i` is swapped with the value at index `(dimension_size - 1 - i)`. The shape remains the same.

func RngStateSplit added in v0.4.0

func RngStateSplit(rngState *Node) (newRngState1, newRngState2 *Node)

RngStateSplit splits the current state into 2 different states that can be used separately and will lead to different random numbers.
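
For example, a sketch drawing two independent samples from a single seed:

rngState := Const(g, RngStateFromSeed(42))
state1, state2 := RngStateSplit(rngState)
_, uniform := RandomUniform(state1, shapes.Make(dtypes.Float32, 2))
_, normal := RandomNormal(state2, shapes.Make(dtypes.Float32, 2))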

func Round

func Round(x *Node) (node *Node)

Round returns the element-wise rounding of x to the nearest integer.

func Rsqrt added in v0.11.0

func Rsqrt(x *Node) (node *Node)

Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).

func Scalar

func Scalar[N dtypes.NumberNotComplex](g *Graph, dtype dtypes.DType, value N) *Node

Scalar returns a constant scalar with the given value.

The value is first converted to float64 to serve as index to a cache, and later converted to the requested dtype. This may lose bits of precision for very large integers. If you are worried about any of these conversions, use Const instead.

func ScalarOne

func ScalarOne(g *Graph, dtype dtypes.DType) *Node

ScalarOne returns a scalar constant 1 for the given DType.

func ScalarZero

func ScalarZero(g *Graph, dtype dtypes.DType) *Node

ScalarZero returns a scalar constant 0 for the given DType.

func Scatter

func Scatter(indices, updates *Node, shape shapes.Shape) *Node

Scatter sums up the slices in updates into a new tensor of the given shape, at the locations pointed to by indices. It does the opposite of Gather.

In the simplest form, [indices] is shaped `[num_updates, 1]`, [updates] is shaped `[num_updates, update_size]` and shapes is of the form `[output_size, update_size]`. The indices values should be in between 0 and `output_size-1`.
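
For example, a sketch of the simplest form described above:

indices := Const(g, [][]int32{{1}, {3}})                // shape [num_updates=2, 1]
updates := Const(g, [][]float32{{1, 1, 1}, {2, 2, 2}})  // shape [2, update_size=3]
out := Scatter(indices, updates, shapes.Make(dtypes.Float32, 4, 3))
// out -> {{0, 0, 0}, {1, 1, 1}, {0, 0, 0}, {2, 2, 2}}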

func ScatterAdd deprecated

func ScatterAdd(operand, indices, updates *Node, sorted, unique bool) *Node

ScatterAdd is a deprecated alias to ScatterSum.

Deprecated: Please use ScatterSum instead.

func ScatterMax added in v0.11.0

func ScatterMax(operand, indices, updates *Node, sorted, unique bool) *Node

ScatterMax updates the max value of operand, from the values in updates pointed by indices.

The operand provides the initial values for the operation, and typically will be initialized with -inf. See Infinity and BroadcastToDims to create an arbitrarily shaped node filled with infinity.

Args:

  • [sorted]: the indices must be in order. In some cases it is faster, but if indices are not in order, results may be unstable.
  • [unique]: the indices must be unique. In some cases it is faster, but if indices are not unique, results may be unstable.

func ScatterMin added in v0.11.0

func ScatterMin(operand, indices, updates *Node, sorted, unique bool) *Node

ScatterMin updates the min value of operand, from the values in updates pointed by indices.

The operand provides the initial values for the operation, and typically will be initialized with +inf. See Infinity and BroadcastToDims to create an arbitrarily shaped node filled with infinity.

Args:

  • [sorted]: the indices must be in order. In some cases it is faster, but if indices are not in order, results may be unstable.
  • [unique]: the indices must be unique. In some cases it is faster, but if indices are not unique, results may be unstable.

func ScatterSum added in v0.18.0

func ScatterSum(operand, indices, updates *Node, sorted, unique bool) *Node

ScatterSum adds up the slices in updates into the given operand tensor, at the locations pointed by indices. It does the opposite of Gather.

Args:

  • [sorted]: the indices must be in order. In some cases it is faster, but if indices are not in order, results may be unstable.
  • [unique]: the indices must be unique. In some cases it is faster, but if indices are not unique, results may be unstable.

func ShapedLowerTriangular added in v0.12.0

func ShapedLowerTriangular(g *Graph, rows, column, k int) *Node

ShapedLowerTriangular returns a triangular boolean matrix of shape (rows x columns) (not necessarily rows == columns), where the lower triangular part (including the diagonal) is set to true, and the upper triangular part is set to false.

The k value shifts the boundary up or down: k < 0 moves the true values below the diagonal; conversely, k > 0 extends the true values above the diagonal.

Examples:

ShapedLowerTriangular(g, 3, 3, k=0) => [][]bool{{true, false, false}, {true, true, false}, {true, true, true}}
ShapedLowerTriangular(g, 3, 3, k=-1) => [][]bool{{false, false, false}, {true, false, false}, {true, true, false}}
ShapedLowerTriangular(g, 2, 3, k=1) => [][]bool{{true, true, false}, {true, true, true}}

func Shift added in v0.10.0

func Shift(x *Node, axis int, shiftDir ShiftDirection, n int) *Node

Shift a given [axis] of [x] by [n] positions ([n] is a static value). The [shiftDir] defines the direction: left towards lower values or right towards higher values. The spaces left open keep the edge value. Example:

Shift([0, 1, 2, 3], axis=-1, ShiftDirLeft, n=2)

Will return `[2, 3, 3, 3]`.

func ShiftLeft added in v0.10.0

func ShiftLeft(x *Node, n int, fill float64) *Node

ShiftLeft the last axis of [x] by [n] positions ([n] is a static value) and fill the new value with [fill]. The value of [fill] is converted to [x]'s dtypes.DType. For boolean dtype, use 1.0 or 0.0.

See ShiftWithScalar and ShiftWithValue for a more generic shift function.

func ShiftRight added in v0.10.0

func ShiftRight(x *Node, n int, fill float64) *Node

ShiftRight the last axis of [x] by [n] positions ([n] is a static value) and fill the new value with [fill]. The value of [fill] is converted to [x]'s dtypes.DType. For boolean dtype, use 1.0 or 0.0.

See ShiftWithScalar and ShiftWithValue for a more generic shift function.

func ShiftWithScalar added in v0.10.0

func ShiftWithScalar(x *Node, axis int, shiftDir ShiftDirection, n int, fill float64) *Node

ShiftWithScalar shifts a given [axis] of [x] by [n] positions ([n] is a static value) and fills the new positions with [fill], a **static** scalar value. The [shiftDir] defines the direction: left towards lower values or right towards higher values. The value of [fill] is converted to [x]'s dtypes.DType. For boolean dtype, use 1.0 or 0.0.

func ShiftWithValue added in v0.10.0

func ShiftWithValue(x *Node, axis int, shiftDir ShiftDirection, n int, value *Node) *Node

ShiftWithValue shifts a given [axis] of [x] by [n] positions ([n] is a static value) and fills the new positions with a dynamic (graph) [value]. The [shiftDir] defines the direction: left towards lower values or right towards higher values. The filling [value] must be "broadcast-able" (see [BroadcastToDim]) to the space it's going to fill with the shift -- a scalar can always be broadcast.
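For illustration, a hedged sketch (assuming x shaped `[batch, steps]`; ReduceAndKeep and ReduceMean are from this package):

rowMean := ReduceAndKeep(x, ReduceMean, 1)                 // Shape [batch, 1], broadcast-able to the opened space.
shifted := ShiftWithValue(x, 1, ShiftDirRight, 1, rowMean) // Shift right by 1, filling with each row's mean.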

func Sigmoid

func Sigmoid(x *Node) *Node

Sigmoid returns the expression $1/(1+\exp(-x))$. It is an alias to the Logistic function.

func Sign

func Sign(x *Node) *Node

Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN. The gradient of Sign is assumed to be zero everywhere.

func SignPlusOrMinus

func SignPlusOrMinus(x *Node) *Node

SignPlusOrMinus returns +1 or -1 depending on whether x >= 0 or x < 0. It's similar to Sign, but 0s are considered positive.

func Sin

func Sin(x *Node) (node *Node)

Sin returns the element-wise sine of x.

func Skewness added in v0.17.0

func Skewness(x *Node, axes ...int) *Node

Skewness calculates the skewness (the 3rd standardized moment of a distribution) across the given axes. It's just an alias to ReduceSkewness.

It's a form of reduction function, and the returned rank will be x.Rank() - len(axes).

If no axes are given, it reduces over all axes and returns a scalar.

func Slice

func Slice(x *Node, axesSpec ...SliceAxisSpec) *Node

Slice takes slices of the operand.

Each axis can have a range defined as (start, end) pairs. Any axis for which a range is not specified is assumed to be taken in full. Consider using the shortcut AxisRange to define the ranges.

Examples:

- For `x = {10, 20, 30, 40}`:

  • `Slice(x) = {10, 20, 30, 40}` // SliceAxisSpec not given is taken in full.
  • `Slice(x, AxisRange()) = {10, 20, 30, 40}` // Default for AxisRange is the full range.
  • `Slice(x, AxisRange(1,-1)) = {20, 30}` // Negative values are taken from the end of the axis dimension.
  • `Slice(x, AxisRangeFromStart(-2)) = {10, 20}` // Negative values are taken from the end of the axis dimension.
  • `Slice(x, AxisRangeToEnd(2)) = {30, 40}` // Negative values are taken from the end of the axis dimension.
  • `Slice(x, AxisElem(2)) = {30}` // Take only one element of an axis.

- For `x = {{1, 2, 3}, {4, 5, 6}}`:

  • `Slice(x, AxisRange(), AxisElem(0)) = {{1}, {4}}` // First axis taken in full, second axis only the first element.
  • `Slice(x, AxisElem(1)) = {{4, 5, 6}}` // Missing second SliceAxisSpec, assumed to be taken in full.

If Slice is called with `x.shape = [5, 5, 5, 5]` and `axesRanges=AxisElem(1), AxisRange(), AxisRange(2), AxisRange(0,2)`, it returns a node shaped `[1, 5, 3, 2]`.

It also supports "spacers" (like "*" in paths) that fill in the unknown axes. Example: let's say we want to get just the last example of a batch, and just the first element of the embedding. Assume x is shaped `[batch_size, ..., embedding_size]` and we want something like `x[-1, ..., 0:1]`.

sample := Slice(x, AxisElem(-1), AxisRange().Spacer(), AxisElem(0))

It also works with strides; use the SliceAxisSpec.Stride() method to conveniently set them.

Example:

- For `x = {1, 2, 3, 4}`:

  • `Slice(x, AxisRange().Stride(2)) = {1, 3}` // The whole range, but with a stride of 2.

- For `x = {{1, 2, 3}, {4, 5, 6}}`:

  • `Slice(x, AxisRange().Stride(2), AxisRange(-1)) = {{3}}` // Take every 2nd row (so only the 1st here), the last column.

func SliceAxis added in v0.13.0

func SliceAxis(x *Node, axis int, axisSpec SliceAxisSpec) *Node

SliceAxis is similar to Slice, but takes a slice of only one axis, preserving all the others.

Example:

x.Shape() == [5, 4, 3]
SliceAxis(x, 1, AxisElem(1)) -> shape [5, 1 (sliced axis), 3]

func Softmax

func Softmax(logits *Node, axes ...int) *Node

Softmax computes softmax activations. It's the equivalent to

Exp(logits) / ReduceAndKeep(Exp(logits), ReduceSum, axes...)

But implemented in a numerically stable way.

The axes list defines over which axes the softmax is computed (the axes that will be summed over).

If no axes are given, it defaults to [-1], meaning the last axis.
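For illustration, a minimal sketch (the logits shape is illustrative):

// logits shaped [batch_size, num_classes]: softmax over the classes axis.
probabilities := Softmax(logits) // Defaults to axis -1.
same := Softmax(logits, -1)      // Equivalent, with the axis made explicit.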

func Softplus added in v0.17.1

func Softplus(x *Node) *Node

Softplus activation function $\log(1+\exp(x))$. It is equivalent to Log1P(Exp(x)), but implemented in a numerically stable way.

func Sqrt

func Sqrt(x *Node) (node *Node)

Sqrt returns the element-wise square root of x.

func Square

func Square(x *Node) *Node

Square returns x^2 point-wise. Same as `Mul(x, x)`.

func Squeeze

func Squeeze(x *Node, axes ...int) *Node

Squeeze removes `axes` of dimension 1. If `axes` is not set, all axes of dimension 1 are removed. Otherwise, only the provided `axes` are removed. If any of the given `axes` is not of dimension 1, an error is raised in the Graph and an invalid node is returned.

If all dimensions are reduced, it returns a scalar.

func Stack added in v0.15.2

func Stack(operands []*Node, axis int) *Node

Stack puts together many values (*Node) with the exact same shape by creating a new axis and concatenating them.

Axis is relative to the returned shape.

The returned value has its rank increased by 1: output.Rank() == 1 + operands[i].Rank().
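For illustration, a minimal sketch (a, b and c are assumed to be same-shaped nodes):

// If a, b and c are each shaped [4], the result is shaped [3, 4].
stacked := Stack([]*Node{a, b, c}, 0)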

func StopGradient

func StopGradient(x *Node) *Node

StopGradient creates an identity node (see Identity), through which gradients don't back-propagate.

No new XLA op is created, so there is no cost to the computation execution speed.
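For illustration, a hedged sketch of a mean-squared error where the target is derived from the model itself and should be treated as a constant (ReduceAllMean is assumed to be the mean-reduction over all axes):

target = StopGradient(target) // Gradients won't flow into the target branch.
loss := ReduceAllMean(Square(Sub(prediction, target)))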

func Sub

func Sub(x0 *Node, x1 *Node) (node *Node)

Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func TakeLowerTriangular added in v0.12.0

func TakeLowerTriangular(x *Node, k int) *Node

TakeLowerTriangular takes the lower triangular of the last 2 dimensions of x (x.Rank() must be >= 2), and sets the other values to 0. The returned shape is the same as x.

The k value shifts the triangle up or down: k < 0 keeps only values further below the diagonal. Conversely, k > 0 extends the kept values above the diagonal.

It uses ShapedLowerTriangular to calculate the mask.

Examples:

input = AddScalar(IotaFull(g, shapes.Make(dtypes.Float64, 2, 2)), 1)
TakeLowerTriangular(input, 0) => [][]float64{{1, 0}, {3, 4}}

input = AddScalar(IotaFull(g, shapes.Make(dtypes.Float32, 1, 2, 3, 4)), 1)
TakeLowerTriangular(input, 0)
// -> [][][][]float32{{{{1, 0, 0, 0}, {5, 6, 0, 0}, {9, 10, 11, 0}}, {{13, 0, 0, 0}, {17, 18, 0, 0}, {21, 22, 23, 0}}}}

TakeLowerTriangular(input, -1)
// -> [][][][]float32{{{{0, 0, 0, 0}, {5, 0, 0, 0}, {9, 10, 0, 0}}, {{0, 0, 0, 0}, {17, 0, 0, 0}, {21, 22, 0, 0}}}}

TakeLowerTriangular(input, 1)
// -> [][][][]float32{{{{1, 2, 0, 0}, {5, 6, 7, 0}, {9, 10, 11, 12}}, {{13, 14, 0, 0}, {17, 18, 19, 0}, {21, 22, 23, 24}}}}

func TakeUpperTriangular added in v0.12.0

func TakeUpperTriangular(x *Node, k int) *Node

TakeUpperTriangular takes the upper triangular of the last 2 dimensions of x (x.Rank() must be >= 2), and sets the other values to 0. The returned shape is the same as x.

The k value shifts the triangle up or down: k < 0 extends the kept values further below the diagonal. Conversely, k > 0 restricts them further above the diagonal.

It uses ShapedLowerTriangular to calculate the mask.

Examples:

input = AddScalar(IotaFull(g, shapes.Make(dtypes.Float64, 2, 2)), 1)
TakeUpperTriangular(input, 0) => [][]float64{{1, 2}, {0, 4}}

input = AddScalar(IotaFull(g, shapes.Make(dtypes.Float32, 1, 2, 3, 4)), 1)
TakeUpperTriangular(input, 0)
// -> [][][][]float32{{{{1, 2, 3, 4}, {0, 6, 7, 8}, {0, 0, 11, 12}}, {{13, 14, 15, 16}, {0, 18, 19, 20}, {0, 0, 23, 24}}}}

TakeUpperTriangular(input, -1)
// -> [][][][]float32{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {0, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {0, 22, 23, 24}}}}

TakeUpperTriangular(input, 1)
// -> [][][][]float32{{{{0, 2, 3, 4}, {0, 0, 7, 8}, {0, 0, 0, 12}}, {{0, 14, 15, 16}, {0, 0, 19, 20}, {0, 0, 0, 24}}}}

func Tanh

func Tanh(x *Node) (node *Node)

Tanh returns the element-wise hyperbolic tangent of x.

func Transpose

func Transpose(x *Node, axisA, axisB int) *Node

Transpose returns x with the axes axisA and axisB transposed.

func TransposeAllDims

func TransposeAllDims(x *Node, permutations ...int) *Node

TransposeAllDims allows one to transpose any or all dimensions. It permutes the operand axes with the given permutation, so ∀ i, 0 ≤ i < rank ⇒ input_dimensions[permutations[i]] = output_dimensions[i].
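For illustration, a minimal sketch converting a channels-last image node to channels-first:

// x shaped [batch, height, width, channels]; output axis i takes input axis permutations[i].
x = TransposeAllDims(x, 0, 3, 1, 2) // -> [batch, channels, height, width]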

func UpperTriangular

func UpperTriangular(g *Graph, dim int) *Node

UpperTriangular returns an upper-triangular boolean square matrix of shape `[dim, dim]`.

This can be combined with `Where` to select values of any arbitrary other matrix.

func Variance added in v0.15.2

func Variance(x *Node, axes ...int) *Node

Variance calculates the variance across the given axes. It's just an alias to ReduceVariance.

It's a form of reduction function, and the returned rank will be x.Rank() - len(axes).

If no axes are given, it reduces over all axes and returns a scalar.

func Where

func Where(condition, onTrue, onFalse *Node) *Node

Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).

Usual implicit broadcasting rules don't apply. But it will broadcast in the following cases:

  1. If either onTrue or onFalse is a scalar, it is broadcast to the other (onFalse or onTrue respectively). If both are scalars, they will be broadcast to the shape of condition.
  2. If condition is a prefix to the shapes of onTrue/onFalse then condition is expanded to match. This is useful for masking of embeddings for instance.
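For illustration, a minimal sketch of case 2 (shapes are illustrative):

// mask shaped [batch, seq_len] (booleans), embeddings shaped [batch, seq_len, embed_dim]:
// the condition is expanded to match, zeroing the embeddings of masked-out positions.
masked := Where(mask, embeddings, ZerosLike(embeddings))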

func Zeros

func Zeros(g *Graph, shape shapes.Shape) *Node

Zeros creates a node of the given shape, filled with the value 0. It's implemented indirectly, using other nodes.

func ZerosLike

func ZerosLike(x *Node) *Node

ZerosLike returns a tensor with the same shape as x, filled with 0's.

func (*Node) AssertDims

func (n *Node) AssertDims(dimensions ...int)

AssertDims checks whether the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

If the shape is not what was expected, it panics with an error message.

This often serves as documentation when implementing complex computation graphs: it allows the reader of the code to corroborate the expected shape of a node.

Example:

batchSize := inputs[0].Shape().Dimensions[0]
…
layer := Concatenate(allEmbeddings, -1)
layer.AssertDims(batchSize, -1) // 2D tensor, with batch size as the leading dimension.

func (*Node) AssertRank

func (n *Node) AssertRank(rank int)

AssertRank checks whether the shape has the given rank.

If the rank is not what was expected, it panics with an error message.

This often serves as documentation when implementing complex computation graphs: it allows the reader of the code to corroborate the expected shape of a node.

It can be used in a similar fashion as AssertDims.

func (*Node) AssertScalar

func (n *Node) AssertScalar()

AssertScalar checks whether the shape is a scalar.

If the rank is not what was expected, it panics with an error message.

It can be used in a similar fashion as AssertDims.

func (*Node) AssertValid added in v0.5.0

func (n *Node) AssertValid()

AssertValid panics if `n` is nil, or if its graph is invalid.

func (*Node) ConstantValue added in v0.15.0

func (n *Node) ConstantValue() *tensors.Tensor

ConstantValue returns the value assigned to a constant node. It's an "introspection" method. It returns nil if n.Type() != NodeTypeConstant.

func (*Node) CustomGradient added in v0.11.0

func (n *Node) CustomGradient() VJP

CustomGradient returns a registered custom gradient for the Node. See IdentityWithCustomGradient.
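For illustration, a hedged sketch using IdentityWithCustomGradient to clip the adjoint flowing through a node (the gradient function signature shown is an assumption for the example; ClipScalar is from this package):

x = IdentityWithCustomGradient(x, func(x, v *Node) *Node {
	// Clip the incoming adjoint to [-1, 1] before it continues to back-propagate.
	return ClipScalar(v, -1, 1)
})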

func (*Node) DType

func (n *Node) DType() dtypes.DType

DType returns the DType of the node's shape.

func (*Node) GetAlias added in v0.17.0

func (n *Node) GetAlias() string

GetAlias returns the alias (with the absolute path) of the current node, if one was registered with Node.WithAlias, otherwise returns "".

func (*Node) GetParameterHandle added in v0.11.0

func (n *Node) GetParameterHandle() ParameterHandle

GetParameterHandle returns the parameter id in the graph. It panics if node is not a parameter.

func (*Node) GetParameterName added in v0.11.0

func (n *Node) GetParameterName() string

GetParameterName returns the parameter name. If node is not a parameter, it panics.

func (*Node) Graph

func (n *Node) Graph() *Graph

Graph that holds this Node.

func (*Node) Id

func (n *Node) Id() NodeId

Id is the unique id of this node within the Graph.

func (*Node) Inputs

func (n *Node) Inputs() []*Node

Inputs are the other nodes that are direct inputs to this node. This doesn't include static inputs for some operations that are not given by other Graph nodes.

func (*Node) IsConstantExpression added in v0.15.0

func (n *Node) IsConstantExpression() bool

IsConstantExpression returns whether the Node is a Constant or an expression that depends only on constant values. It traverses all the node dependencies and checks that all leaf nodes are constants.

func (*Node) IsLogged

func (n *Node) IsLogged() bool

IsLogged returns whether node is marked to be logged.

func (*Node) IsScalar added in v0.11.0

func (n *Node) IsScalar() bool

IsScalar returns whether the node's shape is a scalar.

func (*Node) LogMessage

func (n *Node) LogMessage() string

LogMessage associated with node, if any.

func (*Node) NumOutputs added in v0.11.0

func (n *Node) NumOutputs() int

NumOutputs returns the number of outputs for a node.

Almost every node has only one output. But a few (like "RngBitGenerator") produce multiple outputs that are split before usage. These nodes are marked with an invalid dtype.

Used internally only; all public Graph operations return nodes with only one output.

func (*Node) Rank

func (n *Node) Rank() int

Rank returns the rank of the node's shape.

func (*Node) SetLogged

func (n *Node) SetLogged(message string)

SetLogged indicates that a node should be logged by executors, with the given message.

func (*Node) SetLoggedf added in v0.11.0

func (n *Node) SetLoggedf(format string, args ...any)

SetLoggedf indicates that a node should be logged by executors, with the given formatted message. See SetLogged.

func (*Node) Shape

func (n *Node) Shape() shapes.Shape

Shape of the Node's output. It can be `nil`, for nodes that simply have a side effect, like a "Print" Node.

func (*Node) StopGradient added in v0.11.0

func (n *Node) StopGradient() bool

StopGradient returns whether the node is a StopGradient.

func (*Node) String

func (n *Node) String() (str string)

String implements the `fmt.Stringer` interface. Logged nodes are marked with (*).

func (*Node) Trace

func (n *Node) Trace() error

Trace returns stack-trace in form of an error, of when the node was created. Only available if enabled by `Graph.SetTraced(true)`.

func (*Node) Type added in v0.11.0

func (n *Node) Type() NodeType

Type identifies the operation performed by the node. It's an "introspection" method.

func (*Node) WithAlias added in v0.17.0

func (n *Node) WithAlias(alias string) *Node

WithAlias sets an alias in the Graph for the node. It allows it to be retrieved with Graph.GetNodeByAlias.

The alias is automatically prefixed with the Graph current "alias scope", see Graph.PushAliasScope and Graph.PopAliasScope. Except if the alias starts with AliasScopeSeparator ("/"), in which case it is assumed to be given with an "absolute scope path".

It returns the Node itself, to allow cascading method calls.

It panics if the exact same alias already exists.
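For illustration, a minimal sketch (the alias name is arbitrary):

logits = logits.WithAlias("/output/logits") // Absolute scope path, since it starts with "/".
// ... later, possibly in another part of the graph-building function:
retrieved := g.GetNodeByAlias("/output/logits")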

type NodeId

type NodeId int

NodeId is a unique id for a Node within a Graph.

type NodeInputs added in v0.11.0

type NodeInputs interface {
	Type() NodeType

	// String prints a descriptive representation of the node, using its parameters.
	String() string
}

NodeInputs represents the inputs to a node. The common interface is to return the type of the node. For the input parameters themselves, the pointer needs to be cast to the corresponding type, usually named nodeInputs<backend_operation_name>; see the generated gen_backend_ops.go.

type NodeType added in v0.11.0

type NodeType int
const (
	NodeTypeInvalid NodeType = iota
	NodeTypeSplitNode
	NodeTypeAbs
	NodeTypeAdd
	NodeTypeArgMinMax
	NodeTypeBatchNormForInference
	NodeTypeBatchNormForTraining
	NodeTypeBatchNormGradient
	NodeTypeBitCount
	NodeTypeBitcast
	NodeTypeBitwiseAnd
	NodeTypeBitwiseNot
	NodeTypeBitwiseOr
	NodeTypeBitwiseXor
	NodeTypeBroadcast
	NodeTypeBroadcastInDim
	NodeTypeCeil
	NodeTypeClz
	NodeTypeComplex
	NodeTypeConcatenate
	NodeTypeConj
	NodeTypeConstant
	NodeTypeConvGeneralDilated
	NodeTypeConvertDType
	NodeTypeCos
	NodeTypeDiv
	NodeTypeDot
	NodeTypeDotGeneral
	NodeTypeDynamicSlice
	NodeTypeDynamicUpdateSlice
	NodeTypeEqual
	NodeTypeEqualTotalOrder
	NodeTypeErf
	NodeTypeExp
	NodeTypeExpm1
	NodeTypeFFT
	NodeTypeFloor
	NodeTypeGather
	NodeTypeGreaterOrEqual
	NodeTypeGreaterOrEqualTotalOrder
	NodeTypeGreaterThan
	NodeTypeGreaterThanTotalOrder
	NodeTypeIdentity
	NodeTypeImag
	NodeTypeIota
	NodeTypeIsFinite
	NodeTypeLessOrEqual
	NodeTypeLessOrEqualTotalOrder
	NodeTypeLessThan
	NodeTypeLessThanTotalOrder
	NodeTypeLog
	NodeTypeLog1p
	NodeTypeLogicalAnd
	NodeTypeLogicalNot
	NodeTypeLogicalOr
	NodeTypeLogicalXor
	NodeTypeLogistic
	NodeTypeMax
	NodeTypeMin
	NodeTypeMul
	NodeTypeNeg
	NodeTypeNotEqual
	NodeTypeNotEqualTotalOrder
	NodeTypePad
	NodeTypeParameter
	NodeTypePow
	NodeTypeReal
	NodeTypeReduceBitwiseAnd
	NodeTypeReduceBitwiseOr
	NodeTypeReduceBitwiseXor
	NodeTypeReduceLogicalAnd
	NodeTypeReduceLogicalOr
	NodeTypeReduceLogicalXor
	NodeTypeReduceMax
	NodeTypeReduceMin
	NodeTypeReduceProduct
	NodeTypeReduceSum
	NodeTypeReduceWindow
	NodeTypeRem
	NodeTypeReshape
	NodeTypeReverse
	NodeTypeRngBitGenerator
	NodeTypeRound
	NodeTypeRsqrt
	NodeTypeScatterMax
	NodeTypeScatterMin
	NodeTypeScatterSum
	NodeTypeSelectAndScatterMax
	NodeTypeSelectAndScatterMin
	NodeTypeSelectAndScatterSum
	NodeTypeShiftLeft
	NodeTypeShiftRightArithmetic
	NodeTypeShiftRightLogical
	NodeTypeSign
	NodeTypeSin
	NodeTypeSlice
	NodeTypeSqrt
	NodeTypeSub
	NodeTypeTanh
	NodeTypeTranspose
	NodeTypeWhere
)

func NodeTypeString added in v0.11.0

func NodeTypeString(s string) (NodeType, error)

NodeTypeString retrieves an enum value from the enum constants string name. It returns an error if the string is not part of the enum.

func NodeTypeValues added in v0.11.0

func NodeTypeValues() []NodeType

NodeTypeValues returns all values of the enum.

func (NodeType) IsANodeType added in v0.11.0

func (i NodeType) IsANodeType() bool

IsANodeType returns "true" if the value is listed in the enum definition, "false" otherwise.

func (NodeType) MarshalJSON added in v0.11.0

func (i NodeType) MarshalJSON() ([]byte, error)

MarshalJSON implements the json.Marshaler interface for NodeType

func (NodeType) MarshalText added in v0.11.0

func (i NodeType) MarshalText() ([]byte, error)

MarshalText implements the encoding.TextMarshaler interface for NodeType

func (NodeType) MarshalYAML added in v0.11.0

func (i NodeType) MarshalYAML() (interface{}, error)

MarshalYAML implements a YAML Marshaler for NodeType

func (NodeType) String added in v0.11.0

func (i NodeType) String() string

func (*NodeType) UnmarshalJSON added in v0.11.0

func (i *NodeType) UnmarshalJSON(data []byte) error

UnmarshalJSON implements the json.Unmarshaler interface for NodeType

func (*NodeType) UnmarshalText added in v0.11.0

func (i *NodeType) UnmarshalText(text []byte) error

UnmarshalText implements the encoding.TextUnmarshaler interface for NodeType

func (*NodeType) UnmarshalYAML added in v0.11.0

func (i *NodeType) UnmarshalYAML(unmarshal func(interface{}) error) error

UnmarshalYAML implements a YAML Unmarshaler for NodeType

func (NodeType) Values added in v0.11.0

func (NodeType) Values() []string

type NodeXlaHandle

type NodeXlaHandle int

NodeXlaHandle is used by the underlying XLA implementation.

type PadAxis

type PadAxis = backends.PadAxis

PadAxis defines the amount of padding preceding one axis (Start), at the end of the axis (End) or in between the inputs (Interior). This is used as a parameter for the Pad function. This is an alias to backends.PadAxis.

type ParameterHandle

type ParameterHandle int

ParameterHandle is a key to be used by Graph implementations to refer to its internal parameters.

type ParamsMap

type ParamsMap map[*Node]any

ParamsMap is a shortcut for the map of parameters and their values passed to a graph execution. The values are anything that is accepted by tensor.FromAnyValue().

type PoolBuilder

type PoolBuilder struct {
	// contains filtered or unexported fields
}

PoolBuilder is a helper to build a pool computation. Create it with MaxPool, MinPool, SumPool, MeanPool or ConcatPool, set the desired parameters and, once everything is set, call `Done()`.

func ConcatPool added in v0.11.0

func ConcatPool(x *Node) *PoolBuilder

ConcatPool pools over the spatial dimensions by concatenating the channels of all positions within the window, increasing the channels dimension.

Example: with x.shape=[batch_size, height, width, 3] and Window(3), the output channels dimension will be 9x3=27, the concatenation of the channels of all the pixels around.

The implementation actually uses a convolution with a fixed kernel, but it can be seen as a concatenating pool operation.

func MaxPool

func MaxPool(x *Node) *PoolBuilder

MaxPool prepares a max pooling on x with the given kernel for arbitrary number of spatial dimensions (1D, 2D, 3D, etc.). It returns the max value for the selected window, on given strides.

It is very flexible, and to ease configuring its parameters it returns a PoolBuilder for configuration. Once it is set up, call `PoolBuilder.Done` and it will return the pooled x. Browse through PoolBuilder to see the capabilities and the defaults.

The window sizes must be set with PoolBuilder.Window or PoolBuilder.WindowPerAxis.

The shape of x should be `[batch, <spatial_dimensions...>, input_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[batch, input_channels, <spatial_dimensions...>]` instead.

The "channels" axis is also known as depth or feature axis.

Note: `images` refers to package `github.com/gomlx/gomlx/types/tensors/images`.

The shape of the kernel should be `[<spatial_dimensions...>, input_channels, output_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[input_channels, <spatial_dimensions...>, output_channels]` instead.
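For illustration, a minimal sketch (x is assumed to be a channels-last image node shaped `[batch, height, width, channels]`):

// 2x2 max-pooling; Strides defaults to the window size, shown here only for clarity.
pooled := MaxPool(x).Window(2).Strides(2).Done() // Roughly halves the spatial dimensions.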

func MeanPool added in v0.3.0

func MeanPool(x *Node) *PoolBuilder

MeanPool prepares a mean pooling on x with the given kernel for arbitrary number of spatial dimensions (1D, 2D, 3D, etc.). It returns the mean value for the selected window, on given strides.

It is very flexible, and to ease configuring its parameters it returns a PoolBuilder for configuration. Once it is set up, call `PoolBuilder.Done` and it will return the pooled x. Browse through PoolBuilder to see the capabilities and the defaults.

The window sizes must be set with PoolBuilder.Window or PoolBuilder.WindowPerAxis.

The shape of x should be `[batch, <spatial_dimensions...>, input_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[batch, input_channels, <spatial_dimensions...>]` instead.

The "channels" axis is also known as depth or feature axis.

Note: `images` refers to package `github.com/gomlx/gomlx/types/tensors/images`.

The shape of the kernel should be `[<spatial_dimensions...>, input_channels, output_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[input_channels, <spatial_dimensions...>, output_channels]` instead.

func MinPool added in v0.11.0

func MinPool(x *Node) *PoolBuilder

MinPool prepares a min pooling on x with the given kernel for arbitrary number of spatial dimensions (1D, 2D, 3D, etc.). It returns the min value for the selected window, on given strides.

It is very flexible, and to ease configuring its parameters it returns a PoolBuilder for configuration. Once it is set up, call `PoolBuilder.Done` and it will return the pooled x. Browse through PoolBuilder to see the capabilities and the defaults.

The window sizes must be set with PoolBuilder.Window or PoolBuilder.WindowPerAxis.

The shape of x should be `[batch, <spatial_dimensions...>, input_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[batch, input_channels, <spatial_dimensions...>]` instead.

The "channels" axis is also known as depth or feature axis.

Note: `images` refers to package `github.com/gomlx/gomlx/types/tensors/images`.

The shape of the kernel should be `[<spatial_dimensions...>, input_channels, output_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[input_channels, <spatial_dimensions...>, output_channels]` instead.

func SumPool added in v0.3.0

func SumPool(x *Node) *PoolBuilder

SumPool prepares a sum pooling on x with the given kernel for arbitrary number of spatial dimensions (1D, 2D, 3D, etc.). It returns the sum value for the selected window, on given strides.

It is very flexible, and to ease configuring its parameters it returns a PoolBuilder for configuration. Once it is set up, call `PoolBuilder.Done` and it will return the pooled x. Browse through PoolBuilder to see the capabilities and the defaults.

The window sizes must be set with PoolBuilder.Window or PoolBuilder.WindowPerAxis.

The shape of x should be `[batch, <spatial_dimensions...>, input_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[batch, input_channels, <spatial_dimensions...>]` instead.

The "channels" axis is also known as depth or feature axis.

Note: `images` refers to package `github.com/gomlx/gomlx/types/tensors/images`.

The shape of the kernel should be `[<spatial_dimensions...>, input_channels, output_channels]` if configured with `PoolBuilder.ChannelsAxis(images.ChannelsLast)`, the default. If one sets `PoolBuilder.ChannelsAxis(images.ChannelsFirst)`, the shape should be `[input_channels, <spatial_dimensions...>, output_channels]` instead.

func (*PoolBuilder) ChannelsAxis added in v0.3.0

func (pool *PoolBuilder) ChannelsAxis(channelsAxisConfig images.ChannelsAxisConfig) *PoolBuilder

ChannelsAxis configures the axis for the channels (aka. "depth" or "features") dimension. The default is `images.ChannelsLast`, meaning the "channels" dimension comes last.

Note: `images` refers to package `github.com/gomlx/gomlx/types/tensors/images`.

If you don't want to exclude the batch size and channels from the pooling, use FullShape instead.

It returns the modified PoolBuilder object, so calls can be cascaded.

func (*PoolBuilder) Done

func (pool *PoolBuilder) Done() *Node

Done indicates that the pooling operation has finished being configured; it updates the computation graph with the pooling and returns the resulting Node.

func (*PoolBuilder) FullShape added in v0.11.1

func (pool *PoolBuilder) FullShape() *PoolBuilder

FullShape configures the pooling operation to consider all its axes as part of the pooling, with no special considerations for the batch or channel axes.

See ChannelsAxis to handle batch and channels specially.

The default is configured with ChannelsAxis(images.ChannelsLast).

func (*PoolBuilder) NoPadding

func (pool *PoolBuilder) NoPadding() *PoolBuilder

NoPadding removes any paddings, so if the kernel spatial dimensions > 1, the output shapes will be reduced on the edges. This is the default.

func (*PoolBuilder) PadSame

func (pool *PoolBuilder) PadSame() *PoolBuilder

PadSame adds paddings on the edges of x such that the output of the pooling has the same shape as the input.

This changes the default value of Strides to 1, if it is not set -- if set to something else, PadSame won't really output the same shape as the input.

The default is NoPadding.

func (*PoolBuilder) PaddingPerDim

func (pool *PoolBuilder) PaddingPerDim(paddings [][2]int) *PoolBuilder

PaddingPerDim specifies the paddings at the start and at the end to use per spatial dimension, that is, one pair ([2]int) per spatial dimension. The default is NoPadding.

func (*PoolBuilder) StridePerAxis added in v0.11.0

func (pool *PoolBuilder) StridePerAxis(strides ...int) *PoolBuilder

StridePerAxis sets the strides for each spatial dimension of the pooling.

The default is the same value as the window size (set with Window or WindowPerAxis) -- except if one uses PadSame, then the default changes to 1.

The stride is how many steps to move after a pooling. A value of 2 will halve the input size, since a pooling will be done at every other position, and so on. It can be defined separately per dimension.

One cannot use strides and dilation at the same time.

func (*PoolBuilder) Strides

func (pool *PoolBuilder) Strides(strides int) *PoolBuilder

Strides sets the strides of the pooling. It sets the same value for every spatial dimension.

The default is the same value as the window size (set with Window or WindowPerAxis) -- except if one uses PadSame, then the default changes to 1.

The stride is how many steps to move after the pooling of a window. A value of 2 will halve the input size, since the pooling will be done at every other position, and so on. It can be defined separately per dimension with StridePerAxis.

One cannot use strides and dilation at the same time.

func (*PoolBuilder) Window

func (pool *PoolBuilder) Window(windowSize int) *PoolBuilder

Window sets the pooling window size for all spatial dimensions to the same windowSize.

There is no default, and this must be set either with Window or WindowPerAxis.

func (*PoolBuilder) WindowPerAxis added in v0.11.0

func (pool *PoolBuilder) WindowPerAxis(sizes ...int) *PoolBuilder

WindowPerAxis sets the pooling window size for each spatial dimension.

There is no default, and this must be set either with Window or WindowPerAxis.

type Ragged2D added in v0.18.0

type Ragged2D struct {
	Dim0         int
	Flat, RowIDs *Node
}

Ragged2D is a 2D ragged representation with the first dimension dense and the second axis being ragged. It can be interpreted as an array (fixed size) of variable length lists.

A "ragged" representation is a special type of sparse representation where values are only defined on the start of the axis, and the tail of the axis is assumed to be irrelevant, or in some cases zero.

For now, if the user has ragged tensors with larger rank, it's up to them to reshape and transpose around to get to a 2D representation. A generic ragged representation using fixed-shape tensors is a TODO.

To store a "ragged" representation we use a flat, compact representation of the data (without the irrelevant parts) and the RowIDs for each element: which row it is part of. Because GoMLX doesn't (yet) support dynamic shapes, it also takes the static dimension of the first axis, Dim0, which must be known at graph compile time.

See TensorFlow's more generic RaggedTensor in https://www.tensorflow.org/guide/ragged_tensor, which was used as a source of inspiration -- this is a much simpler implementation, but covers many of the use cases.

A note on padding: because of the static shape requirements, it's common practice to use the last row as a "padding" row, and assign the RowIDs of the padding values to that extra row.

func MakeRagged2D added in v0.18.0

func MakeRagged2D(dim0 int, flat, rowIDs *Node) Ragged2D

MakeRagged2D creates a new Ragged2D using rowIDs.

The rowIDs _must be sorted_, meaning the flat values must come in row, col order; otherwise many of the operations will have undefined behaviour.

Example:

MakeRagged2D(dim0=4, flat=[1, 2, 3, 4, 5], rowIDs=[0, 0, 0, 1, 3]) represents the following 4x3 ragged 2D tensor:

{ {1, 2, 3},
  {4},
  {},
  {5} }
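The same example as a hedged graph-building sketch (assuming a graph `g` in scope and an integer dtype for the row ids; Const is from this package):

flat := Const(g, []float32{1, 2, 3, 4, 5})
rowIDs := Const(g, []int32{0, 0, 0, 1, 3})
ragged := MakeRagged2D(4, flat, rowIDs)
sums := ragged.ReduceSumCols() // -> [6, 4, 0, 5], shape [4].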

func (Ragged2D) DType added in v0.18.0

func (r Ragged2D) DType() dtypes.DType

DType returns the dtype of the flat values.

func (Ragged2D) Graph added in v0.18.0

func (r Ragged2D) Graph() *Graph

Graph associated with the Flat and RowIDs nodes.

func (Ragged2D) ReduceMaxCols added in v0.18.0

func (r Ragged2D) ReduceMaxCols() *Node

ReduceMaxCols returns the max over the ragged axis (columns). It returns a 1D tensor of shape [Ragged2D.Dim0] with one value per row.

Rows with no ragged values will have -Inf values.

func (Ragged2D) ReduceMinCols added in v0.18.0

func (r Ragged2D) ReduceMinCols() *Node

ReduceMinCols returns the min over the ragged axis (columns). It returns a 1D tensor of shape [Ragged2D.Dim0] with one value per row.

Rows with no ragged values will have Inf values.

func (Ragged2D) ReduceSumCols added in v0.18.0

func (r Ragged2D) ReduceSumCols() *Node

ReduceSumCols returns the sum over the ragged axis (columns). It returns a 1D tensor of shape [Ragged2D.Dim0].

func (Ragged2D) Softmax added in v0.18.0

func (r Ragged2D) Softmax() Ragged2D

Softmax computes the softmax of the Ragged2D matrix and returns a Ragged2D with the values converted to probabilities.

Notice that values not represented (the tail of each row) do not participate in the Softmax. It is as if they were -inf.

type ReduceOpType added in v0.11.0

type ReduceOpType = backends.ReduceOpType

type ShiftDirection added in v0.10.0

type ShiftDirection bool

ShiftDirection used by ShiftWithScalar and ShiftWithValue. See ShiftDirLeft and ShiftDirRight.

const (
	ShiftDirLeft  ShiftDirection = false
	ShiftDirRight                = true
)

func (ShiftDirection) String added in v0.10.0

func (s ShiftDirection) String() string

String implements the stringer interface.

type SideParamsFn

type SideParamsFn func(graph *Graph, inputBuffers []backends.Buffer, donate []bool)

SideParamsFn is the function that sets side parameters during execution for Graphs that define those. Typically, this is used to set the variables of a model.

type SingleOutputVJP added in v0.11.0

type SingleOutputVJP func(node, v *Node, outputShape shapes.Shape) []*Node

SingleOutputVJP for VJP of ops that have a single output (most of them).

type SliceAxisSpec added in v0.6.0

type SliceAxisSpec struct {
	Start, End, StrideValue int
	Full, NoEnd             bool
	IsSpacer                bool
}

SliceAxisSpec specifies the range and stride of an axis to include in a Slice.

The recommendation is to use AxisRange or AxisElem (defined below) to create it.

Full means to include the whole range (and ignore Start/End), and NoEnd means from Start to the full dimension of the axis.

Optionally (if StrideValue != 0), it can set the stride for the axis as well.

Spacer means this AxisRange should be the generic definition for all undefined axes -- useful when the rank of the node is not known.

Consider using function AxisRange below to construct SliceAxisSpec values.

TODO: Add strides.

func AxisElem added in v0.6.0

func AxisElem(index int) SliceAxisSpec

AxisElem defines a range of one element to take for an axis in Slice. It returns a `SliceAxisSpec` object.

func AxisRange

func AxisRange(indices ...int) SliceAxisSpec

AxisRange defines a range to take for an axis in Slice. It returns a `SliceAxisSpec` object.

The indices can have 0, 1 or 2 elements:

  - If `len(indices) == 0`, it's assumed to be the full range of the axis.
  - If `len(indices) == 1`, it's assumed to be the start, and the range is taken to the end.
  - If `len(indices) == 2`, they are the start and end indices for the axis.
  - If `len(indices) > 2`, an error is raised with panic.

See also AxisElem if you want to define only one element of the range.

func AxisRangeFromStart added in v0.11.0

func AxisRangeFromStart(to int) SliceAxisSpec

AxisRangeFromStart defines a range from the start (0) to the given value for the axis. Its return value is to be used by Slice to specify one axis.

func AxisRangeToEnd added in v0.11.0

func AxisRangeToEnd(from int) SliceAxisSpec

AxisRangeToEnd defines a range from the given value to the end of the axis. Its return value is to be used by Slice to specify one axis.

func (SliceAxisSpec) Spacer added in v0.6.0

func (ar SliceAxisSpec) Spacer() SliceAxisSpec

Spacer marks this SliceAxisSpec to be a generic filler range to use on the undefined axes in Slice -- similar to a "*" in a path definition.

It works with any SliceAxisSpec, so it can be used with the return of any call to AxisRange or AxisElem.

Example: let's say we want to get just the last example of a batch, and just the first element of the embedding. Assume x is shaped `[batch_size, ..., embedding_size]` and we want something like `x[-1, ..., 0:1]`

sample := Slice(x, AxisElem(-1), AxisRange().Spacer(), AxisElem(0))

Notice that "spacer" ranges also matches zero dimensions. So if x is shaped `[5, 5]`, calling `Slice(x, AxisElem(0), AxisRange().Spacer(), AxisElem(0))` would return a node of shape `[1, 1]` and the spacer would be ignored.

func (SliceAxisSpec) Stride added in v0.6.0

func (ar SliceAxisSpec) Stride(stride int) SliceAxisSpec

Stride returns a copy of the SliceAxisSpec with Stride set to the given stride.

type VJP

type VJP func(node *Node, vjpOutputs []*Node, outputShape shapes.Shape) []*Node

VJP returns the vector-Jacobian product $v \cdot J$ of the given `node`, with respect to each of its inputs (given by `node.Inputs()`).

outputShape is the shape of the value for which we are calculating the gradient. For now this is only used for Gradient, so one can expect outputShape to be scalar, and `v.Shape()` to be the same as `output.Shape()`. But this won't hold once Jacobian functionality (like a Gradient whose output is a non-scalar tensor) is added.

Args:

node: the node for which we are calculating the backward gradient. An important part of `node` are its inputs,
   given by `node.Inputs()`. The VJP function must return one gradient per input.
v: the gradient (or jacobian) of what we care about with respect to the output of `node`, also known as the
   adjoint. VJP needs to calculate the next adjoint (`v`), but with respect to the inputs of `node` instead.
outputShape: for now fixed as scalar, and can be ignored. When we add support for Jacobian this will hold
   the shape of the value we are calculating the Jacobian for.

Returns:

The adjoint (the updated `v`) with respect to each of `node`'s inputs. That is, the gradient of the loss (typically, but of whatever we are calculating the gradient of) with respect to each of `node`'s inputs.

Directories

Path Synopsis
graphtest	Package graphtest holds test utilities for packages that depend on the graph package.
nanlogger	Package nanlogger collects `graph.Node` objects to monitor for `NaN` ("not-a-number") or `Inf` (infinity) values.
