candy

v0.3.0
Published: Oct 29, 2025 License: BSD-3-Clause Imports: 4 Imported by: 0

README

Candy: ML framework for Go

Install

go get github.com/gocnn/candy

License

Candy is released under the BSD 3-Clause License; see LICENSE. Logo icon: Fire Bold by Phosphor (MIT License).

Documentation

Constants

Logo is the ASCII-art logo for the candy framework.

Variables

This section is empty.

Functions

func ResolveAxes

func ResolveAxes(axes []int, s *Shape) ([]int, error)

ResolveAxes resolves a list of axis indices, supporting negative indices. It checks for duplicates and out-of-range values.

func ResolveAxis

func ResolveAxis(axis, rank int) (int, error)

ResolveAxis resolves a single axis index, supporting negative values.
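
A minimal sketch of how negative axes are commonly normalized, assuming the usual convention that axis -1 names the last dimension. The helpers resolveAxis and resolveAxes are hypothetical local stand-ins (the second takes a rank instead of a *Shape); they are not the package's implementation.

```go
package main

import "fmt"

// resolveAxis maps an axis in [-rank, rank) onto [0, rank).
// Hypothetical helper, written for illustration only.
func resolveAxis(axis, rank int) (int, error) {
	if axis < -rank || axis >= rank {
		return 0, fmt.Errorf("axis %d out of range for rank %d", axis, rank)
	}
	if axis < 0 {
		axis += rank
	}
	return axis, nil
}

// resolveAxes normalizes every axis and rejects duplicates, comparing
// axes after normalization so that -1 and rank-1 count as the same axis.
func resolveAxes(axes []int, rank int) ([]int, error) {
	seen := make(map[int]bool, len(axes))
	out := make([]int, 0, len(axes))
	for _, a := range axes {
		r, err := resolveAxis(a, rank)
		if err != nil {
			return nil, err
		}
		if seen[r] {
			return nil, fmt.Errorf("duplicate axis %d", r)
		}
		seen[r] = true
		out = append(out, r)
	}
	return out, nil
}

func main() {
	fmt.Println(resolveAxes([]int{-1, 0}, 3)) // [2 0] <nil>
	fmt.Println(resolveAxes([]int{-1, 2}, 3)) // [] duplicate axis 2
}
```

Checking duplicates after normalization is the design choice that matters here: -1 and 2 look different but select the same axis of a rank-3 tensor.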

Types

type BackendDevice

type BackendDevice[T D] interface {
	// Location returns the device location (e.g., CPU, GPU, device ID).
	Location() DeviceLocation

	// IsSame checks if two devices are the same.
	IsSame(BackendDevice[T]) bool

	// StorageFromSlice creates storage from a slice of values.
	StorageFromSlice([]T) (BackendStorage[T], error)

	// SetSeed sets the random seed for the device.
	SetSeed(uint64) error

	// RandUniform creates storage with random values from a uniform distribution.
	RandUniform(*Shape, DType, float64, float64) (BackendStorage[T], error)

	// RandNormal creates storage with random values from a normal distribution.
	RandNormal(*Shape, DType, float64, float64) (BackendStorage[T], error)

	// Alloc allocates a zero-initialized storage for the given shape.
	Alloc(*Shape, DType) (BackendStorage[T], error)

	// Zeros creates a storage filled with zeros.
	Zeros(*Shape, DType) (BackendStorage[T], error)

	// Ones creates a storage filled with ones.
	Ones(*Shape, DType) (BackendStorage[T], error)

	// Full creates a storage filled with a specific value.
	Full(*Shape, DType, float64) (BackendStorage[T], error)

	// Synchronize blocks until all operations on the device are complete.
	Synchronize() error
}

BackendDevice defines operations for device management.

type BackendStorage

type BackendStorage[T D] interface {
	// Clone creates a deep copy of the storage.
	Clone() (BackendStorage[T], error)

	// Data returns a copy of the underlying data.
	Data() []T

	// Device returns the associated device.
	Device() BackendDevice[T]

	// DType returns the data type of the storage.
	DType() DType

	// Affine applies an affine transformation (scale * x + bias) to the storage.
	Affine(layout *Layout, scale, bias T) (BackendStorage[T], error)

	// Add performs element-wise addition between this and another storage.
	Add(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Sub performs element-wise subtraction between this and another storage.
	Sub(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Mul performs element-wise multiplication between this and another storage.
	Mul(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Div performs element-wise division between this and another storage.
	Div(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Maximum performs element-wise maximum of two tensors.
	Maximum(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Minimum performs element-wise minimum of two tensors.
	Minimum(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Eq performs element-wise equality comparison of two tensors.
	Eq(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Ne performs element-wise not-equal comparison of two tensors.
	Ne(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Lt performs element-wise less-than comparison of two tensors.
	Lt(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Le performs element-wise less-than-or-equal comparison of two tensors.
	Le(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Gt performs element-wise greater-than comparison of two tensors.
	Gt(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// Ge performs element-wise greater-than-or-equal comparison of two tensors.
	Ge(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[T], error)

	// EqU8 performs element-wise equality comparison, returning the result as uint8 storage.
	EqU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// NeU8 performs element-wise not-equal comparison, returning the result as uint8 storage.
	NeU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// LtU8 performs element-wise less-than comparison, returning the result as uint8 storage.
	LtU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// LeU8 performs element-wise less-than-or-equal comparison, returning the result as uint8 storage.
	LeU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// GtU8 performs element-wise greater-than comparison, returning the result as uint8 storage.
	GtU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// GeU8 performs element-wise greater-than-or-equal comparison, returning the result as uint8 storage.
	GeU8(rhs BackendStorage[T], lhsLayout, rhsLayout, resLayout *Layout) (BackendStorage[uint8], error)

	// ToDtype performs type conversion to the specified target type.
	ToDtype(layout *Layout, dtype DType) (any, error)

	// MatMul performs matrix multiplication: C = A * B
	MatMul(lhsLayout *Layout, rhs BackendStorage[T], rhsLayout *Layout, b, m, n, k int) (BackendStorage[T], error)

	// Conv1d performs 1D convolution using im2col + BLAS for supported types.
	Conv1d(layout *Layout, kernel BackendStorage[T], kernelLayout *Layout, params *Conv1DParams) (BackendStorage[T], error)

	// ConvTranspose1d performs 1D transposed convolution (deconvolution) for supported types.
	ConvTranspose1d(layout *Layout, kernel BackendStorage[T], kernelLayout *Layout, params *ConvT1DParams) (BackendStorage[T], error)

	// Conv2d performs 2D convolution using im2col + BLAS for supported types.
	Conv2d(layout *Layout, kernel BackendStorage[T], kernelLayout *Layout, params *Conv2DParams) (BackendStorage[T], error)

	// ConvTranspose2d performs 2D transposed convolution (deconvolution) for supported types.
	ConvTranspose2d(layout *Layout, kernel BackendStorage[T], kernelLayout *Layout, params *ConvT2DParams) (BackendStorage[T], error)

	// AvgPool2d performs 2D average pooling for supported types.
	AvgPool2d(layout *Layout, kH, kW, sH, sW int) (BackendStorage[T], error)

	// MaxPool2d performs 2D max pooling for supported types.
	MaxPool2d(layout *Layout, kH, kW, sH, sW int) (BackendStorage[T], error)

	// UpsampleNearest2d performs 2D nearest neighbor upsampling for supported types.
	UpsampleNearest2d(layout *Layout, targetH, targetW int) (BackendStorage[T], error)

	// ConstSet sets all elements to a constant value for supported types.
	ConstSet(layout *Layout, val T) error

	// Gather performs gather operation along a specified dimension with same-type indices.
	Gather(layout *Layout, ids BackendStorage[T], idsLayout *Layout, dim int) (BackendStorage[T], error)

	// Scatter performs scatter operation along a specified dimension with same-type indices.
	Scatter(layout *Layout, ids BackendStorage[T], idsLayout *Layout, src BackendStorage[T], srcLayout *Layout, dim int) (BackendStorage[T], error)

	// ScatterAdd performs scatter-add operation along a specified dimension.
	ScatterAdd(layout *Layout, ids BackendStorage[T], idsLayout *Layout, src BackendStorage[T], srcLayout *Layout, dim int) (BackendStorage[T], error)

	// Copy2d copies a 2D region from source to destination for supported types.
	Copy2d(dst BackendStorage[T], d1, d2, srcStride1, dstStride1, srcOffset, dstOffset int) error

	// FastSum computes the sum over the last dimension.
	FastSum(layout *Layout) (BackendStorage[T], error)

	// FastMin computes the minimum over the last dimension.
	FastMin(layout *Layout) (BackendStorage[T], error)

	// FastMax computes the maximum over the last dimension.
	FastMax(layout *Layout) (BackendStorage[T], error)

	// FastArgmin computes the indices of minimum values over the last dimension.
	FastArgmin(layout *Layout) (BackendStorage[uint32], error)

	// FastArgmax computes the indices of maximum values over the last dimension.
	FastArgmax(layout *Layout) (BackendStorage[uint32], error)

	// Sum performs summation along specified dimensions.
	Sum(layout *Layout, dims []int) (BackendStorage[T], error)

	// Min computes the minimum over the specified dimension.
	Min(layout *Layout, dim int) (BackendStorage[T], error)

	// Max computes the maximum over the specified dimension.
	Max(layout *Layout, dim int) (BackendStorage[T], error)

	// Argmin computes the index of minimum over the specified dimension.
	Argmin(layout *Layout, dim int) (BackendStorage[uint32], error)

	// Argmax computes the index of maximum over the specified dimension.
	Argmax(layout *Layout, dim int) (BackendStorage[uint32], error)

	// FastSoftmax performs softmax along the last dimension.
	FastSoftmax(layout *Layout) (BackendStorage[T], error)

	// FastRmsNorm performs RMS normalization along the last dimension.
	FastRmsNorm(layout *Layout, alpha BackendStorage[T], alphaLayout *Layout, eps T) (BackendStorage[T], error)

	// FastLayerNorm performs Layer normalization along the last dimension.
	FastLayerNorm(layout *Layout, alpha BackendStorage[T], alphaLayout *Layout, beta BackendStorage[T], betaLayout *Layout, eps T) (BackendStorage[T], error)

	// RopeI performs rotary position embedding (rope_i variant).
	RopeI(layout *Layout, cos BackendStorage[T], cosLayout *Layout, sin BackendStorage[T], sinLayout *Layout) (BackendStorage[T], error)

	// Rope performs rotary position embedding (rope variant).
	Rope(layout *Layout, cos BackendStorage[T], cosLayout *Layout, sin BackendStorage[T], sinLayout *Layout) (BackendStorage[T], error)

	// RopeThd performs rotary position embedding (rope_thd variant).
	RopeThd(layout *Layout, cos BackendStorage[T], cosLayout *Layout, sin BackendStorage[T], sinLayout *Layout) (BackendStorage[T], error)

	// WhereCond performs element-wise selection based on condition.
	WhereCond(condLayout *Layout, t BackendStorage[T], tLayout *Layout, f BackendStorage[T], fLayout *Layout) (BackendStorage[T], error)

	// Copy performs element-wise copy operation.
	Copy(layout *Layout, src BackendStorage[T]) (BackendStorage[T], error)

	// Neg performs element-wise negation operation.
	Neg(layout *Layout) (BackendStorage[T], error)

	// Recip performs element-wise reciprocal operation.
	Recip(layout *Layout) (BackendStorage[T], error)

	// Exp performs element-wise exponential operation.
	Exp(layout *Layout) (BackendStorage[T], error)

	// Log performs element-wise logarithm operation.
	Log(layout *Layout) (BackendStorage[T], error)

	// Sin performs element-wise sine operation.
	Sin(layout *Layout) (BackendStorage[T], error)

	// Cos performs element-wise cosine operation.
	Cos(layout *Layout) (BackendStorage[T], error)

	// Tanh performs element-wise hyperbolic tangent operation.
	Tanh(layout *Layout) (BackendStorage[T], error)

	// Erf performs element-wise error function operation.
	Erf(layout *Layout) (BackendStorage[T], error)

	// Ceil performs element-wise ceiling operation.
	Ceil(layout *Layout) (BackendStorage[T], error)

	// Floor performs element-wise floor operation.
	Floor(layout *Layout) (BackendStorage[T], error)

	// Round performs element-wise round operation.
	Round(layout *Layout) (BackendStorage[T], error)

	// Normcdf performs element-wise normal CDF operation.
	Normcdf(layout *Layout) (BackendStorage[T], error)

	// Abs performs element-wise absolute value operation.
	Abs(layout *Layout) (BackendStorage[T], error)

	// Sqr performs element-wise square operation.
	Sqr(layout *Layout) (BackendStorage[T], error)

	// Sqrt performs element-wise square root operation.
	Sqrt(layout *Layout) (BackendStorage[T], error)

	// Gelu performs element-wise GELU activation operation.
	Gelu(layout *Layout) (BackendStorage[T], error)

	// GeluErf performs element-wise GELU (ERF-based) activation operation.
	GeluErf(layout *Layout) (BackendStorage[T], error)

	// Relu performs element-wise ReLU activation operation.
	Relu(layout *Layout) (BackendStorage[T], error)

	// Elu performs element-wise ELU activation operation with parameter alpha.
	Elu(layout *Layout, alpha T) (BackendStorage[T], error)

	// Silu performs element-wise SiLU (Swish) activation operation.
	Silu(layout *Layout) (BackendStorage[T], error)

	// Powf performs element-wise power operation with parameter param.
	Powf(layout *Layout, param T) (BackendStorage[T], error)

	// Sign performs element-wise sign operation.
	Sign(layout *Layout) (BackendStorage[T], error)

	// Sigmoid performs element-wise sigmoid activation operation.
	Sigmoid(layout *Layout) (BackendStorage[T], error)
}

BackendStorage defines operations for tensor storage management.

type ContiguousOffsetsWithBroadcast

type ContiguousOffsetsWithBroadcast struct {
	Start          int
	Len            int
	LeftBroadcast  int
	RightBroadcast int
}

ContiguousOffsetsWithBroadcast represents contiguous storage with broadcasted dimensions.

type Conv1DParams

type Conv1DParams struct {
	Batch  int
	InLen  int // Input length
	OutCh  int // Output channels
	InCh   int // Input channels
	KSize  int // Kernel size
	Pad    int
	Stride int
	Dilate int
	Algo   *FwdAlgo // Optional cuDNN forward algorithm
}

Conv1DParams holds parameters for 1D convolution.

func (Conv1DParams) OutDims

func (p Conv1DParams) OutDims() []int

OutDims returns the output dimensions [batch, out_channels, out_length].

func (Conv1DParams) OutLen

func (p Conv1DParams) OutLen() int

OutLen computes the output length for 1D convolution.
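
The page does not show the formula behind OutLen. The sketch below uses the standard convolution output-length formula, floor((InLen + 2*Pad - Dilate*(KSize-1) - 1) / Stride) + 1, and it is an assumption that the package follows the same convention. The conv1DParams type is a hypothetical local copy of the relevant fields.

```go
package main

import "fmt"

// conv1DParams holds only the fields that influence the output length.
// Hypothetical local type; field names follow Conv1DParams.
type conv1DParams struct {
	InLen, KSize, Pad, Stride, Dilate int
}

// outLen applies the standard formula. It assumes Stride >= 1, Dilate >= 1,
// and a padded input at least as long as the dilated kernel, so that Go's
// truncating integer division agrees with floor.
func (p conv1DParams) outLen() int {
	effectiveK := p.Dilate*(p.KSize-1) + 1 // kernel extent once dilation is applied
	return (p.InLen+2*p.Pad-effectiveK)/p.Stride + 1
}

func main() {
	p := conv1DParams{InLen: 10, KSize: 3, Pad: 1, Stride: 2, Dilate: 1}
	fmt.Println(p.outLen()) // 5
}
```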

type Conv2DParams

type Conv2DParams struct {
	Batch  int // Batch size
	InH    int // Input height
	InW    int // Input width
	KH     int // Kernel height
	KW     int // Kernel width
	OutCh  int // Output channels
	InCh   int // Input channels
	Pad    int
	Stride int
	Dilate int
	Algo   *FwdAlgo // Optional cuDNN forward algorithm
}

Conv2DParams holds parameters for 2D convolution. Assumes uniform padding, stride, and dilation for height and width.

func (Conv2DParams) OutDims

func (p Conv2DParams) OutDims() []int

OutDims returns the output dimensions [batch, out_channels, out_height, out_width].

func (Conv2DParams) OutH

func (p Conv2DParams) OutH() int

OutH computes the output height for 2D convolution.

func (Conv2DParams) OutW

func (p Conv2DParams) OutW() int

OutW computes the output width for 2D convolution.

type ConvT1DParams

type ConvT1DParams struct {
	Batch  int
	InLen  int // Input length
	OutCh  int // Output channels
	InCh   int // Input channels
	KSize  int // Kernel size
	Pad    int
	OutPad int
	Stride int
	Dilate int
}

ConvT1DParams holds parameters for 1D transposed convolution.

func (ConvT1DParams) OutDims

func (p ConvT1DParams) OutDims() []int

OutDims returns the output dimensions [batch, out_channels, out_length].

func (ConvT1DParams) OutLen

func (p ConvT1DParams) OutLen() int

OutLen computes the output length for 1D transposed convolution.
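
For transposed convolution the usual output-length formula is (InLen-1)*Stride - 2*Pad + Dilate*(KSize-1) + OutPad + 1. The sketch below assumes the package uses this convention; convT1DParams is a hypothetical local type carrying the relevant fields.

```go
package main

import "fmt"

// convT1DParams holds only the fields that influence the output length.
// Hypothetical local type; field names follow ConvT1DParams.
type convT1DParams struct {
	InLen, KSize, Pad, OutPad, Stride, Dilate int
}

// outLen applies the standard transposed-convolution formula.
func (p convT1DParams) outLen() int {
	return (p.InLen-1)*p.Stride - 2*p.Pad + p.Dilate*(p.KSize-1) + p.OutPad + 1
}

func main() {
	p := convT1DParams{InLen: 5, KSize: 3, Pad: 1, OutPad: 1, Stride: 2, Dilate: 1}
	fmt.Println(p.outLen()) // 10
}
```

In this example OutPad = 1 makes the stride-2 transposed convolution exactly double the input length.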

type ConvT2DParams

type ConvT2DParams struct {
	Batch  int // Batch size
	InH    int // Input height
	InW    int // Input width
	KH     int // Kernel height
	KW     int // Kernel width
	OutCh  int // Output channels
	InCh   int // Input channels
	Pad    int
	OutPad int
	Stride int
	Dilate int
}

ConvT2DParams holds parameters for 2D transposed convolution. Assumes uniform padding, output_padding, stride, and dilation for height and width.

func (ConvT2DParams) OutDims

func (p ConvT2DParams) OutDims() []int

OutDims returns the output dimensions [batch, out_channels, out_height, out_width].

func (ConvT2DParams) OutH

func (p ConvT2DParams) OutH() int

OutH computes the output height for 2D transposed convolution.

func (ConvT2DParams) OutW

func (p ConvT2DParams) OutW() int

OutW computes the output width for 2D transposed convolution.

type D

type D interface {
	float32 | float64 | uint8 | uint32 | int64
}

D is the type constraint for matrices defined in this package.

type DType

type DType int
const (
	F32 DType = iota
	F64
	F16
	BF16
	U8
	U32
	I64
)

func DTypeOf

func DTypeOf[T D]() DType

DTypeOf returns the DType corresponding to the given type parameter.
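
One way such a mapping can be written is a type switch on the zero value of T. The sketch below redeclares D and DType locally so it compiles on its own; dtypeOf is a hypothetical stand-in, not the package's code. Note that the constraint D lists no Go type for F16 or BF16, so a function of this shape cannot return them.

```go
package main

import "fmt"

// D and DType are local copies of the declarations shown on this page.
type D interface {
	float32 | float64 | uint8 | uint32 | int64
}

type DType int

const (
	F32 DType = iota
	F64
	F16
	BF16
	U8
	U32
	I64
)

// dtypeOf maps a type parameter to its DType by switching on the zero value.
func dtypeOf[T D]() DType {
	var zero T
	switch any(zero).(type) {
	case float32:
		return F32
	case float64:
		return F64
	case uint8:
		return U8
	case uint32:
		return U32
	default: // int64 is the only remaining member of D
		return I64
	}
}

func main() {
	fmt.Println(dtypeOf[float32](), dtypeOf[uint8](), dtypeOf[int64]()) // 0 4 6
}
```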

func (DType) IsFloat

func (d DType) IsFloat() bool

func (DType) IsInteger

func (d DType) IsInteger() bool

func (DType) String

func (d DType) String() string

type Device

type Device int
const (
	CPU Device = iota
	CUDA
	Metal
)

func (Device) String

func (d Device) String() string

String returns the string representation of the device.

type DeviceLocation

type DeviceLocation int

DeviceLocation represents the location of a device (e.g., CPU, GPU, device ID).

const (
	CpuLocation DeviceLocation = iota
	GpuLocation
	MetalLocation
)

func (DeviceLocation) String

func (dl DeviceLocation) String() string

String implements fmt.Stringer for DeviceLocation.

type FwdAlgo

type FwdAlgo int

FwdAlgo represents forward convolution algorithms supported by cuDNN.

const (
	FwdAlgoImplicitGEMM        FwdAlgo = iota // 0
	FwdAlgoImplicitPrecompGEMM                // 1
	FwdAlgoGEMM                               // 2
	FwdAlgoDirect                             // 3
	FwdAlgoFFT                                // 4
	FwdAlgoFFTTiling                          // 5
	FwdAlgoWinograd                           // 6
	FwdAlgoWinogradNonfused                   // 7
	FwdAlgoCount                              // 8
)

type Layout

type Layout struct {
	// contains filtered or unexported fields
}

Layout represents the layout of a tensor, including shape, strides, and starting offset. Strides are in number of elements, not bytes.

func Contiguous

func Contiguous(shape *Shape) *Layout

Contiguous creates a contiguous (row-major) layout starting at offset 0.

func ContiguousWithOffset

func ContiguousWithOffset(shape *Shape, startOffset int) *Layout

ContiguousWithOffset creates a contiguous (row-major) layout with the given start offset.

func NewLayout

func NewLayout(shape *Shape, stride []int, startOffset int) *Layout

NewLayout creates a new Layout with the given shape, stride, and start offset. It clones the inputs to ensure immutability.

func (*Layout) BroadcastAs

func (l *Layout) BroadcastAs(target *Shape) (*Layout, error)

BroadcastAs returns a new layout broadcasted to the target shape.
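
Broadcasting a layout is usually done without copying data: every stretched or newly added dimension gets stride 0, so the same element is read repeatedly. The sketch below illustrates that idea on plain slices. broadcastStrides is a hypothetical helper, and whether BroadcastAs uses exactly this scheme is an assumption.

```go
package main

import "fmt"

// broadcastStrides returns strides that view storage laid out as (dims, strides)
// under the larger shape target. Hypothetical helper for illustration.
func broadcastStrides(dims, strides, target []int) ([]int, error) {
	if len(dims) != len(strides) || len(dims) > len(target) {
		return nil, fmt.Errorf("cannot broadcast rank %d to rank %d", len(dims), len(target))
	}
	out := make([]int, len(target))
	lead := len(target) - len(dims) // leading dimensions absent from the source
	for i := range target {
		if i < lead {
			out[i] = 0 // new leading dimension: repeat the whole source
			continue
		}
		src := dims[i-lead]
		switch {
		case src == target[i]:
			out[i] = strides[i-lead]
		case src == 1:
			out[i] = 0 // size-1 dimension stretched: re-read the same element
		default:
			return nil, fmt.Errorf("dimension %d: cannot broadcast %d to %d", i, src, target[i])
		}
	}
	return out, nil
}

func main() {
	// A [3, 1] column with contiguous strides [1, 1], viewed as [2, 3, 4].
	s, err := broadcastStrides([]int{3, 1}, []int{1, 1}, []int{2, 3, 4})
	fmt.Println(s, err) // [0 1 0] <nil>
}
```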

func (*Layout) Clone

func (l *Layout) Clone() *Layout

Clone returns a deep copy of the layout.

func (*Layout) ContiguousOffsets

func (l *Layout) ContiguousOffsets() (start, end int, ok bool)

ContiguousOffsets returns the start and end offsets if the layout is contiguous, along with a boolean indicating if it is contiguous.

func (*Layout) Dim

func (l *Layout) Dim(dim int) int

Dim returns the size of the specified dimension, supporting negative indices.

func (*Layout) Dims

func (l *Layout) Dims() []int

Dims returns the dimensions of the shape.

func (*Layout) Dims0

func (s *Layout) Dims0() error

Dims0 checks if the shape has 0 dimensions (scalar).

func (*Layout) Dims1

func (s *Layout) Dims1() (int, error)

Dims1 extracts the single dimension from a 1D shape.

func (*Layout) Dims2

func (s *Layout) Dims2() (int, int, error)

Dims2 extracts the two dimensions from a 2D shape.

func (*Layout) Dims3

func (s *Layout) Dims3() (int, int, int, error)

Dims3 extracts the three dimensions from a 3D shape.

func (*Layout) Dims4

func (s *Layout) Dims4() (int, int, int, int, error)

Dims4 extracts the four dimensions from a 4D shape.

func (*Layout) Dims5

func (s *Layout) Dims5() (int, int, int, int, int, error)

Dims5 extracts the five dimensions from a 5D shape.

func (*Layout) IsContiguous

func (l *Layout) IsContiguous() bool

IsContiguous returns true if the strides represent a C-contiguous (row-major) layout.

func (*Layout) IsFortranContiguous

func (l *Layout) IsFortranContiguous() bool

IsFortranContiguous returns true if the strides represent a Fortran-contiguous (column-major) layout.
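
The two contiguity checks differ only in the direction in which the expected stride accumulates. The sketch below is a strict version that requires every stride to match; an implementation may choose to ignore the strides of size-1 dimensions, and this page does not say which the package does. Both helpers are hypothetical.

```go
package main

import "fmt"

// isCContiguous reports whether strides describe a row-major layout of dims:
// the last dimension has stride 1 and each earlier stride is the product of
// all later dimension sizes.
func isCContiguous(dims, strides []int) bool {
	if len(dims) != len(strides) {
		return false
	}
	acc := 1
	for i := len(dims) - 1; i >= 0; i-- {
		if strides[i] != acc {
			return false
		}
		acc *= dims[i]
	}
	return true
}

// isFContiguous is the column-major mirror image: the first dimension has
// stride 1 and strides grow toward the last dimension.
func isFContiguous(dims, strides []int) bool {
	if len(dims) != len(strides) {
		return false
	}
	acc := 1
	for i := 0; i < len(dims); i++ {
		if strides[i] != acc {
			return false
		}
		acc *= dims[i]
	}
	return true
}

func main() {
	dims := []int{2, 3}
	fmt.Println(isCContiguous(dims, []int{3, 1}), isFContiguous(dims, []int{3, 1})) // true false
	fmt.Println(isCContiguous(dims, []int{1, 2}), isFContiguous(dims, []int{1, 2})) // false true
}
```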

func (*Layout) Narrow

func (l *Layout) Narrow(dim, start, len int) (*Layout, error)

Narrow returns a new layout narrowed along the specified dimension from start to start+len.
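
Narrowing can be expressed purely as layout arithmetic: the chosen dimension shrinks to len, the strides stay the same, and the start offset advances by start times the stride of that dimension. The sketch below uses a hypothetical layout struct with exported-style fields, since the real Layout hides its fields, and it does not handle negative dim values.

```go
package main

import "fmt"

// layout is a hypothetical, simplified stand-in for Layout.
type layout struct {
	dims, strides []int
	offset        int
}

// narrow restricts dimension dim to the half-open range [start, start+length).
func narrow(l layout, dim, start, length int) (layout, error) {
	if dim < 0 || dim >= len(l.dims) {
		return layout{}, fmt.Errorf("dim %d out of range for rank %d", dim, len(l.dims))
	}
	if start < 0 || length < 0 || start+length > l.dims[dim] {
		return layout{}, fmt.Errorf("range [%d, %d) exceeds dimension size %d", start, start+length, l.dims[dim])
	}
	dims := append([]int(nil), l.dims...) // copy so the input layout is untouched
	dims[dim] = length
	return layout{
		dims:    dims,
		strides: append([]int(nil), l.strides...),
		offset:  l.offset + start*l.strides[dim],
	}, nil
}

func main() {
	l := layout{dims: []int{4, 5}, strides: []int{5, 1}}
	n, err := narrow(l, 0, 1, 2) // keep rows 1 and 2
	fmt.Println(n.dims, n.strides, n.offset, err) // [2 5] [5 1] 5 <nil>
}
```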

func (*Layout) Numel

func (l *Layout) Numel() int

Numel returns the total number of elements in the layout.

func (*Layout) OffsetsB

func (l *Layout) OffsetsB() (ContiguousOffsetsWithBroadcast, bool)

OffsetsB returns contiguous offsets with broadcast dimensions if applicable, along with a boolean indicating success.

func (*Layout) Permute

func (l *Layout) Permute(idxs []int) (*Layout, error)

Permute returns a new layout with dimensions reordered according to the permutation indices.
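
Like other view operations, a permutation only needs to reorder the dimension sizes and strides together; the underlying storage is untouched. The sketch below illustrates this on plain slices, assuming idxs[i] names the source dimension that ends up at position i. permute is a hypothetical helper.

```go
package main

import "fmt"

// permute reorders dims and strides so that output dimension i is
// input dimension idxs[i]. It rejects anything that is not a permutation.
func permute(dims, strides, idxs []int) ([]int, []int, error) {
	if len(strides) != len(dims) || len(idxs) != len(dims) {
		return nil, nil, fmt.Errorf("rank mismatch: dims %d, strides %d, idxs %d", len(dims), len(strides), len(idxs))
	}
	seen := make([]bool, len(dims))
	newDims := make([]int, len(dims))
	newStrides := make([]int, len(dims))
	for i, src := range idxs {
		if src < 0 || src >= len(dims) || seen[src] {
			return nil, nil, fmt.Errorf("invalid permutation %v", idxs)
		}
		seen[src] = true
		newDims[i] = dims[src]
		newStrides[i] = strides[src]
	}
	return newDims, newStrides, nil
}

func main() {
	d, s, err := permute([]int{2, 3, 4}, []int{12, 4, 1}, []int{2, 0, 1})
	fmt.Println(d, s, err) // [4 2 3] [1 12 4] <nil>
}
```

Swapping two dimensions is the special case where idxs differs from the identity in exactly two positions.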

func (*Layout) Rank

func (l *Layout) Rank() int

Rank returns the number of dimensions (rank) of the layout.

func (*Layout) Shape

func (l *Layout) Shape() *Shape

Shape returns a copy of the shape.

func (*Layout) StartOffset

func (l *Layout) StartOffset() int

StartOffset returns the starting offset.

func (*Layout) Stride

func (l *Layout) Stride() []int

Stride returns a copy of the stride slice.

func (*Layout) String

func (l *Layout) String() string

String returns a string representation of the layout.

func (*Layout) Transpose

func (l *Layout) Transpose(dim1, dim2 int) (*Layout, error)

Transpose returns a new layout with the two specified dimensions swapped.

type Shape

type Shape struct {
	// contains filtered or unexported fields
}

Shape represents the dimensions of a tensor.

func NewShape

func NewShape(dims ...int) *Shape

NewShape creates a new Shape from the given dimensions.

func NewShapeFrom

func NewShapeFrom(dims []int) *Shape

NewShapeFrom creates a new Shape from a slice of dimensions.

func (*Shape) BroadcastShapeBinaryOp

func (s *Shape) BroadcastShapeBinaryOp(rhs *Shape) (*Shape, error)

BroadcastShapeBinaryOp computes the broadcasted shape for binary operations.

Broadcasting rules (NumPy-compatible):

  - Align shapes from the rightmost dimension
  - Dimensions are compatible if they are equal or one of them is 1
  - Missing dimensions are treated as 1

Examples:

[3, 1, 4] + [2, 4] -> [3, 2, 4]  (missing dim treated as 1)
[5, 1, 3] * [1, 4, 1] -> [5, 4, 3]  (1s broadcast to larger dims)
[3, 4] + [2, 5] -> error (incompatible: 4≠5 and neither is 1)
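
The rules above can be sketched as a loop that walks both shapes from the right. broadcastShape is a hypothetical helper operating on plain []int rather than *Shape. It reproduces the three examples, reporting the incompatible case through its error return.

```go
package main

import "fmt"

// broadcastShape computes the NumPy-style broadcast of two shapes.
func broadcastShape(lhs, rhs []int) ([]int, error) {
	n := len(lhs)
	if len(rhs) > n {
		n = len(rhs)
	}
	out := make([]int, n)
	for i := 1; i <= n; i++ { // i counts dimensions from the right
		l, r := 1, 1 // a missing dimension behaves like size 1
		if i <= len(lhs) {
			l = lhs[len(lhs)-i]
		}
		if i <= len(rhs) {
			r = rhs[len(rhs)-i]
		}
		switch {
		case l == r, r == 1:
			out[n-i] = l
		case l == 1:
			out[n-i] = r
		default:
			return nil, fmt.Errorf("incompatible dimensions %d and %d", l, r)
		}
	}
	return out, nil
}

func main() {
	fmt.Println(broadcastShape([]int{3, 1, 4}, []int{2, 4}))    // [3 2 4] <nil>
	fmt.Println(broadcastShape([]int{5, 1, 3}, []int{1, 4, 1})) // [5 4 3] <nil>
	fmt.Println(broadcastShape([]int{3, 4}, []int{2, 5}))       // [] incompatible dimensions 4 and 5
}
```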

func (*Shape) BroadcastShapeMatmul

func (s *Shape) BroadcastShapeMatmul(rhs *Shape) (*Shape, *Shape, error)

BroadcastShapeMatmul returns the broadcasted shapes for matrix multiplication. It broadcasts the batch dimensions and checks the inner dimensions for compatibility.

func (*Shape) Clone

func (s *Shape) Clone() *Shape

Clone returns a deep copy of the Shape.

func (*Shape) Dim

func (s *Shape) Dim(dim int) int

Dim returns the size of the dimension at the given index. Negative indices count from the end (-1 is the last dimension).

func (*Shape) Dims

func (s *Shape) Dims() []int

Dims returns a copy of the dimensions slice.

func (*Shape) Dims0

func (s *Shape) Dims0() error

Dims0 checks if the shape has 0 dimensions (scalar).

func (*Shape) Dims1

func (s *Shape) Dims1() (int, error)

Dims1 extracts the single dimension from a 1D shape.

func (*Shape) Dims2

func (s *Shape) Dims2() (int, int, error)

Dims2 extracts the two dimensions from a 2D shape.

func (*Shape) Dims3

func (s *Shape) Dims3() (int, int, int, error)

Dims3 extracts the three dimensions from a 3D shape.

func (*Shape) Dims4

func (s *Shape) Dims4() (int, int, int, int, error)

Dims4 extracts the four dimensions from a 4D shape.

func (*Shape) Dims5

func (s *Shape) Dims5() (int, int, int, int, int, error)

Dims5 extracts the five dimensions from a 5D shape.

func (*Shape) Equal

func (s *Shape) Equal(other *Shape) bool

Equal checks if two shapes are equal.

func (*Shape) Extend

func (s *Shape) Extend(add ...int) *Shape

Extend returns a new Shape with additional dimensions appended.

func (*Shape) IsContiguous

func (s *Shape) IsContiguous(strides []int) bool

IsContiguous checks if the given strides are C-contiguous (row-major).

func (*Shape) IsFortranContiguous

func (s *Shape) IsFortranContiguous(strides []int) bool

IsFortranContiguous checks if the given strides are Fortran-contiguous (column-major).

func (*Shape) IsMatrix

func (s *Shape) IsMatrix() bool

IsMatrix returns true if the shape represents a matrix (2 dimensions).

func (*Shape) IsScalar

func (s *Shape) IsScalar() bool

IsScalar returns true if the shape represents a scalar (0 dimensions).

func (*Shape) IsVector

func (s *Shape) IsVector() bool

IsVector returns true if the shape represents a vector (1 dimension).

func (*Shape) Numel

func (s *Shape) Numel() int

Numel returns the total number of elements (product of all dimensions).

func (*Shape) Rank

func (s *Shape) Rank() int

Rank returns the number of dimensions (rank) of the shape.

func (*Shape) Reshape

func (s *Shape) Reshape(newDims ...int) (*Shape, error)

Reshape returns a new shape with the given dimensions, inferring one dimension if -1 is provided. The total element count must match.
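
A minimal sketch of the -1 inference described above: multiply the known dimensions, then divide the element count by that product. reshape is a hypothetical helper that takes the element count directly. Rejecting more than one -1 and rejecting an inference against a zero product are assumptions about reasonable behavior, not documented behavior of the package.

```go
package main

import "fmt"

// reshape validates newDims against numel and fills in at most one -1.
func reshape(numel int, newDims ...int) ([]int, error) {
	known, infer := 1, -1
	for i, d := range newDims {
		switch {
		case d == -1:
			if infer >= 0 {
				return nil, fmt.Errorf("only one dimension may be -1")
			}
			infer = i
		case d < 0:
			return nil, fmt.Errorf("invalid dimension %d", d)
		default:
			known *= d
		}
	}
	out := append([]int(nil), newDims...)
	if infer >= 0 {
		if known == 0 || numel%known != 0 {
			return nil, fmt.Errorf("cannot infer dimension: %d elements, known product %d", numel, known)
		}
		out[infer] = numel / known
	} else if known != numel {
		return nil, fmt.Errorf("element count mismatch: have %d, want %d", numel, known)
	}
	return out, nil
}

func main() {
	fmt.Println(reshape(24, 2, -1, 4)) // [2 3 4] <nil>
	fmt.Println(reshape(24, 5, -1))    // [] cannot infer dimension: 24 elements, known product 5
}
```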

func (*Shape) StrideContiguous

func (s *Shape) StrideContiguous() []int

StrideContiguous returns the strides for a contiguous (row-major) tensor with this shape.
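
Row-major strides follow directly from the dimension sizes: the last dimension has stride 1, and each earlier stride is the product of all later sizes. The sketch below computes them and uses them to locate an element. Both helpers are hypothetical and operate on plain slices; strides are counted in elements, as the Layout documentation states.

```go
package main

import "fmt"

// strideContiguous returns row-major strides (in elements) for dims.
func strideContiguous(dims []int) []int {
	strides := make([]int, len(dims))
	acc := 1
	for i := len(dims) - 1; i >= 0; i-- {
		strides[i] = acc
		acc *= dims[i]
	}
	return strides
}

// flatIndex converts a multi-dimensional index into a storage offset.
// It assumes len(index) == len(strides).
func flatIndex(index, strides []int, startOffset int) int {
	off := startOffset
	for i, ix := range index {
		off += ix * strides[i]
	}
	return off
}

func main() {
	strides := strideContiguous([]int{2, 3, 4})
	fmt.Println(strides)                                // [12 4 1]
	fmt.Println(flatIndex([]int{1, 2, 3}, strides, 0)) // 23
}
```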

func (*Shape) String

func (s *Shape) String() string

String returns a string representation of the shape.

Directories

Path Synopsis
alexnet command
autodiff command
lenet command
numpy/read command
numpy/write command
resnet command
internal
cpu
nn
