mx

package
v0.0.0-...-814de94 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Apr 26, 2020 License: Apache-2.0 Imports: 13 Imported by: 0

Documentation

Index

Constants

View Source
// Named axis indices, numbered 0 through 3.
// NOTE(review): presumably these are positions within Dimension.Shape
// (and the argument of NDArray.Len) — confirm against callers.
const (
	DimRow    = 0
	DimColumn = 1
	DimDepth  = 2
	DimDepth3 = 3
)
View Source
// Components of the version number. The Version constant declared in this
// package combines them as VersionMajor*10000 + VersionMinor*100 + VersionPatch.
const (
	VersionMajor = 1
	VersionMinor = 5
	VersionPatch = 0
)
View Source
// Negative capi.MxnetOp codes. The numbering is not contiguous: -3 and -6
// are not assigned in this block.
// NOTE(review): presumably these mark pseudo-operations (variable, input,
// scalar, no-grad variable, group) rather than real library operators —
// confirm against the capi package.
const (
	OpVar_    capi.MxnetOp = -1
	OpInput_  capi.MxnetOp = -2
	OpScalar_ capi.MxnetOp = -4
	OpNogVar_ capi.MxnetOp = -5
	OpGroup_  capi.MxnetOp = -7
)
View Source
const MaxDimensionCount = 4

Do not change this constant: code may assume exactly this value.

Variables

This section is empty.

Functions

func AdamUpdate

func AdamUpdate(params, grads, mean, variance *NDArray, lr, beta1, beta2, epsilon, wd float32) error

func GpuCount

func GpuCount() int

func NextSymbolId

func NextSymbolId() int

func Nograd

func Nograd(_hidden_nograd_)

func RandomSeed

func RandomSeed(seed int)

func ResetSymbolId

func ResetSymbolId(first int)

func SgdMomUpdate

func SgdMomUpdate(params, grads, mom *NDArray, lr, momentum, wd float32) error

func SgdUpdate

func SgdUpdate(params, grads *NDArray, lr, wd float32) error

Types

type ActivationType

// ActivationType is the integer-coded activation kind passed to Activation.
type ActivationType int
// Activation kinds, numbered consecutively from 0 via iota.
const (
	ReLU ActivationType = iota // 0
	SoftReLU // 1
	SoftSign // 2
	ActivSigmoid // 3
	ActivTanh // 4
)

type Context

// Context is an integer code; the predefined values below name a null
// context, the CPU and two GPUs.
// NOTE(review): GPU0 = 2 and GPU1 = 1002 suggest the code packs
// deviceNumber*1000 + deviceType — confirm against Gpu, DevNo and DevType
// before adding new values.
type Context int
const (
	NullContext Context = 0
	CPU         Context = 1
	GPU0        Context = 2
	GPU1        Context = 1002
)

func Gpu

func Gpu(no int) Context

func (Context) Array

func (c Context) Array(tp Dtype, d Dimension, vals ...interface{}) *NDArray

func (Context) CopyAs

func (c Context) CopyAs(a *NDArray, dtype Dtype) *NDArray

func (Context) DevNo

func (c Context) DevNo() int

func (Context) DevType

func (c Context) DevType() int

func (Context) IsGPU

func (c Context) IsGPU() bool

func (Context) RandomSeed

func (c Context) RandomSeed(seed int)

func (Context) String

func (c Context) String() string

type Dimension

// Dimension holds an array dimension in a fixed-size value: Shape always has
// room for MaxDimensionCount entries, so the struct contains no slices.
type Dimension struct {
	// Shape is the fixed-capacity storage for the dimension entries.
	Shape [MaxDimensionCount]int
	// NOTE(review): presumably the number of Shape entries actually in use —
	// confirm against Dim and Good.
	Len   int
}

Dimension describes the dimensions of an array.

func Dim

func Dim(a ...int) Dimension

Dim creates a new Dimension object.

func (Dimension) Good

func (dim Dimension) Good() bool

Good checks whether the array dimension is valid.

func (Dimension) Push

func (dim Dimension) Push(i int) Dimension

func (Dimension) SizeOf

func (dim Dimension) SizeOf(dt Dtype) int

SizeOf returns the size of the whole array's data for the given Dtype.

func (Dimension) Skip

func (dim Dimension) Skip(n int) Dimension

func (Dimension) String

func (dim Dimension) String() string

String represents the array dimension as a string.

func (Dimension) Total

func (dim Dimension) Total() int

Total returns the total number of elements in the whole array.

type Dtype

// Dtype is an integer code for an element type.
type Dtype int
// Element type codes, explicitly numbered 0 through 6.
// NOTE(review): the explicit numbering looks like it mirrors the type flags
// of the underlying MXNet library — confirm before renumbering or inserting
// new values.
const (
	Float32 Dtype = 0
	Float64 Dtype = 1
	Float16 Dtype = 2
	Uint8   Dtype = 3
	Int32   Dtype = 4
	Int8    Dtype = 5
	Int64   Dtype = 6
)

func (Dtype) Size

func (tp Dtype) Size() int

func (Dtype) String

func (tp Dtype) String() string

type Graph

// Graph bundles a context and element type with the input, output, loss and
// label arrays, the named parameters, and the executor handle of a network.
// Values are produced by Compose.
type Graph struct {
	Ctx   Context
	Dtype Dtype

	Input   *NDArray   // network input
	Outputs []*NDArray // references to the executor outputs, excluding the loss
	Loss    *NDArray   // reference to the last executor output
	Label   *NDArray   // label fed to the loss function
	Params  map[string]Param // keyed by string; NOTE(review): presumably the parameter name — confirm

	Exec         capi.ExecutorHandle
	Initializers map[string]Inite // keyed by string; NOTE(review): presumably per-parameter initializers — confirm
	Initialized  bool
	// contains filtered or unexported fields
}

func Compose

func Compose(
	ctx Context,
	sym *Symbol,
	loss Loss,
	input Dimension,
	dtype Dtype) (*Graph, error)

func (*Graph) Backward

func (g *Graph) Backward() error

func (*Graph) Forward

func (g *Graph) Forward(train bool) error

func (*Graph) GetShapes

func (g *Graph) GetShapes(withLoss bool) (map[string][]int, error)

func (*Graph) Identity

func (g *Graph) Identity() GraphIdentity

func (*Graph) Initialize

func (g *Graph) Initialize(inite func(*NDArray, string) error) error

func (*Graph) LoadParams

func (g *Graph) LoadParams(reader io.Reader) error

func (*Graph) LogSummary

func (g *Graph) LogSummary(withLoss bool) error

func (*Graph) PrintSummary

func (g *Graph) PrintSummary(withLoss bool) error

func (*Graph) Release

func (g *Graph) Release()

func (*Graph) SaveParams

func (g *Graph) SaveParams(writer io.Writer) error

func (*Graph) Summary

func (g *Graph) Summary(withLoss bool) (Summary, error)

func (*Graph) SummaryOut

func (g *Graph) SummaryOut(withLoss bool, out func(string)) error

func (*Graph) ToJson

func (g *Graph) ToJson(withLoss bool) ([]byte, error)

type GraphIdentity

// GraphIdentity is a 20-byte value, the size of a SHA-1 digest.
type GraphIdentity [20]byte // SHA1

func (GraphIdentity) String

func (identity GraphIdentity) String() string

type GraphJs

// GraphJs is a list of nodes, each with an operation, a name, string
// attributes and an untyped input list.
// NOTE(review): presumably the decoded form of the JSON produced by
// Graph.ToJson — confirm against its users.
type GraphJs struct {
	Nodes []struct {
		Op     string
		Name   string
		Attrs  map[string]string
		Inputs []interface{} // element types are not fixed by this declaration
	}
}

type Inite

// Inite is the value type of Graph.Initializers: anything that can be applied
// to an NDArray and report an error.
type Inite interface {
	// Inite operates on the given array and returns an error on failure.
	// NOTE(review): presumably it fills the array with initial values — confirm.
	Inite(*NDArray) error
}

type Loss

// Loss builds a loss symbol; it is the loss argument accepted by Compose.
type Loss interface {
	// Loss maps (out, label) to (loss, sparse).
	// sparse means the label dimensions are reduced by the last one.
	Loss(*Symbol, *Symbol) (*Symbol, bool)
}

type NDArray

type NDArray struct {
	// contains filtered or unexported fields
}

func Array

func Array(tp Dtype, d Dimension) *NDArray

func Errayf

func Errayf(s string, a ...interface{}) *NDArray

func (*NDArray) Cast

func (a *NDArray) Cast(dt Dtype) *NDArray

func (*NDArray) Context

func (a *NDArray) Context() Context

func (*NDArray) CopyValuesTo

func (a *NDArray) CopyValuesTo(dst interface{}) error

func (*NDArray) Depth

func (a *NDArray) Depth() int

func (*NDArray) Dim

func (a *NDArray) Dim() Dimension

func (*NDArray) Dtype

func (a *NDArray) Dtype() Dtype

func (NDArray) Err

func (a NDArray) Err() error

func (*NDArray) Fill

func (a *NDArray) Fill(value float32) *NDArray

func (*NDArray) Len

func (a *NDArray) Len(d int) int

func (*NDArray) NewLikeThis

func (a *NDArray) NewLikeThis() *NDArray

func (*NDArray) Normal

func (a *NDArray) Normal(mean float32, scale float32) *NDArray

func (*NDArray) Ones

func (a *NDArray) Ones() *NDArray

func (*NDArray) Raw

func (a *NDArray) Raw() []byte

func (*NDArray) Release

func (a *NDArray) Release()

func (*NDArray) Reshape

func (a *NDArray) Reshape(dim Dimension) *NDArray

func (*NDArray) SetValues

func (a *NDArray) SetValues(vals ...interface{}) error

func (*NDArray) Size

func (a *NDArray) Size() int

func (*NDArray) String

func (a *NDArray) String() string

func (*NDArray) Uniform

func (a *NDArray) Uniform(low float32, high float32) *NDArray

func (*NDArray) Values

func (a *NDArray) Values(dtype Dtype) (interface{}, error)

func (*NDArray) ValuesF32

func (a *NDArray) ValuesF32() []float32

func (*NDArray) Xavier

func (a *NDArray) Xavier(gaussian bool, factor int, magnitude float32) *NDArray

func (*NDArray) Zeros

func (a *NDArray) Zeros() *NDArray

type Param

// Param is the value type of Graph.Params: a pair of arrays plus a flag.
type Param struct {
	Data     *NDArray // NOTE(review): presumably the parameter values — confirm
	Grad     *NDArray // NOTE(review): presumably the gradient of Data — confirm
	Autograd bool // NOTE(review): presumably whether a gradient is tracked — confirm
}

type Summary

// Summary is a list of summary rows, as returned by Graph.Summary.
type Summary []*SummaryRow

type SummaryRow

// SummaryRow is one entry of a Summary.
type SummaryRow struct {
	No        int
	Name      string
	Operation string
	Params    int // NOTE(review): presumably a parameter count — confirm
	Dim       Dimension
	Args      []SummryArg // argument references, each a (No, Name) pair
}

type SummryArg

// SummryArg is the element type of SummaryRow.Args.
// NOTE(review): the name is spelled "Summry" (sic); it is exported API, so
// renaming would break callers.
type SummryArg struct {
	No   int
	Name string
}

type Symbol

type Symbol struct {
	// contains filtered or unexported fields
}

func Abs

func Abs(a *Symbol) *Symbol

func Activation

func Activation(a *Symbol, actType ActivationType) *Symbol

func Add

func Add(lv interface{}, rv interface{}) *Symbol

func BatchNorm

func BatchNorm(a, gamma, beta, rmean, rvar *Symbol, mom, eps float32, axis ...int) *Symbol

func BlockGrad

func BlockGrad(s *Symbol) *Symbol

func Concat

func Concat(a ...*Symbol) *Symbol

func Conv

func Conv(a, weight, bias *Symbol, channels int, kernel, stride, padding Dimension, groups bool) *Symbol

func Cosh

func Cosh(a *Symbol) *Symbol

func Div

func Div(lv interface{}, rv interface{}) *Symbol

func Dot

func Dot(lv interface{}, rv interface{}) *Symbol

func Flatten

func Flatten(a *Symbol) *Symbol

func FullyConnected

func FullyConnected(a, weight, bias *Symbol, size int, flatten bool) *Symbol

func GenericOp2

func GenericOp2(op, opScalar, opScalarR capi.MxnetOp, lv interface{}, rv interface{}) *Symbol

func Group

func Group(a ...*Symbol) *Symbol

func Input

func Input(..._hidden_input_) *Symbol

func Log

func Log(a *Symbol) *Symbol

func LogCosh

func LogCosh(a *Symbol) *Symbol

func LogSoftmax

func LogSoftmax(a *Symbol, axis ...int) *Symbol

func MakeLoss

func MakeLoss(s *Symbol) *Symbol

func Mean

func Mean(a *Symbol, axis ...int) *Symbol

func MeanXl

func MeanXl(a *Symbol, axis ...int) *Symbol

func Minus

func Minus(a *Symbol) *Symbol

func Mul

func Mul(lv interface{}, rv interface{}) *Symbol

func Not

func Not(a *Symbol) *Symbol

func Pick

func Pick(a *Symbol, label *Symbol) *Symbol

func Pool

func Pool(a *Symbol, kernel, stride, padding Dimension, ceil bool, maxpool bool) *Symbol

func Pow

func Pow(lv interface{}, rv interface{}) *Symbol

func Sigmoid

func Sigmoid(a *Symbol) *Symbol

func Sin

func Sin(a *Symbol) *Symbol

func Softmax

func Softmax(a *Symbol, axis ...int) *Symbol

func SoftmaxActivation

func SoftmaxActivation(a *Symbol, channel bool) *Symbol

func SoftmaxCrossEntropy

func SoftmaxCrossEntropy(a, b *Symbol, axis ...int) *Symbol

func SoftmaxOutput

func SoftmaxOutput(a *Symbol, l *Symbol, multiOutput bool) *Symbol

func Square

func Square(a *Symbol) *Symbol

func Stack

func Stack(a ...*Symbol) *Symbol

func Stack1

func Stack1(a ...*Symbol) *Symbol

func Sub

func Sub(lv interface{}, rv interface{}) *Symbol

func Sum

func Sum(a *Symbol, axis ...int) *Symbol

func SumXl

func SumXl(a *Symbol, axis ...int) *Symbol

func SymbolCast

func SymbolCast(i interface{}) (*Symbol, error)

func Tanh

func Tanh(a *Symbol) *Symbol

func Var

func Var(name string, opt ...interface{}) *Symbol

func (*Symbol) SetName

func (s *Symbol) SetName(name string) *Symbol

type VersionType

// VersionType is a version number packed into a single int.
type VersionType int
// Version packs the version components as major*10000 + minor*100 + patch;
// with VersionMajor=1, VersionMinor=5, VersionPatch=0 it equals 10500.
const Version VersionType = VersionMajor*10000 + VersionMinor*100 + VersionPatch

func LibVersion

func LibVersion() VersionType

func MakeVersion

func MakeVersion(major, minor, patch int) VersionType

func (VersionType) Major

func (v VersionType) Major() int

func (VersionType) Minor

func (v VersionType) Minor() int

func (VersionType) Patch

func (v VersionType) Patch() int

func (VersionType) String

func (v VersionType) String() string

Directories

Path Synopsis

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL