
Feed forward/backpropagation neural network implementation. Currently supports:

  • Activation functions: sigmoid, hyperbolic, ReLU
  • Solvers: SGD, SGD with momentum/nesterov, Adam
  • Classification modes: regression, multi-class, multi-label, binary
  • Supports batch training in parallel
  • Bias nodes

Networks are modeled as a set of neurons connected through synapses. There are no GPU computations, so don't use this for large-scale applications.


Todo:

  • Dropout
  • Batch normalization


go get -u


Import the go-deep package

import (
	deep ""
)

Define some data...

var data = training.Examples{
	{[]float64{2.7810836, 2.550537003}, []float64{0}},
	{[]float64{1.465489372, 2.362125076}, []float64{0}},
	{[]float64{3.396561688, 4.400293529}, []float64{0}},
	{[]float64{1.38807019, 1.850220317}, []float64{0}},
	{[]float64{7.627531214, 2.759262235}, []float64{1}},
	{[]float64{5.332441248, 2.088626775}, []float64{1}},
	{[]float64{6.922596716, 1.77106367}, []float64{1}},
	{[]float64{8.675418651, -0.242068655}, []float64{1}},
}

Create a network with two hidden layers of size 2 and 2 respectively:

n := deep.NewNeural(&deep.Config{
	/* Input dimensionality */
	Inputs: 2,
	/* Two hidden layers consisting of two neurons each, and a single output */
	Layout: []int{2, 2, 1},
	/* Activation functions: Sigmoid, Tanh, ReLU, Linear */
	Activation: deep.ActivationSigmoid,
	/* Determines output layer activation & loss function: 
	ModeRegression: linear outputs with MSE loss
	ModeMultiClass: softmax output with Cross Entropy loss
	ModeMultiLabel: sigmoid output with Cross Entropy loss
	ModeBinary: sigmoid output with binary CE loss */
	Mode: deep.ModeBinary,
	/* Weight initializers: {deep.NewNormal(σ, μ), deep.NewUniform(σ, μ)} */
	Weight: deep.NewNormal(1.0, 0.0),
	/* Apply bias */
	Bias: true,
})


// params: learning rate, momentum, alpha decay, nesterov
optimizer := training.NewSGD(0.05, 0.1, 1e-6, true)
// params: optimizer, verbosity (print stats at every 50th iteration)
trainer := training.NewTrainer(optimizer, 50)

trainSet, heldout := data.Split(0.5)
trainer.Train(n, trainSet, heldout, 1000) // training, validation, iterations

resulting in:

Epochs        Elapsed       Error         
---           ---           ---           
5             12.938µs      0.36438       
10            125.691µs     0.02261       
15            177.194µs     0.00404       
...
1000          10.703839ms   0.00000
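The SGD optimizer above is configured with a learning rate, a momentum coefficient, a decay term and a Nesterov flag. As a rough illustration of what momentum and Nesterov momentum do to a single weight, here is a standalone sketch using one common formulation: v = mu*v - lr*g, then w = w + v, or w = w + mu*v - lr*g with Nesterov. It is not the `training` package's code; `sgdState` and `step` are hypothetical names, and the decay parameter is deliberately left out because its exact use is not documented here.

```go
package main

import "fmt"

// sgdState holds the velocity of a single weight.
// Hypothetical type for illustration; not part of go-deep's training package.
type sgdState struct {
	velocity float64
}

// step applies one SGD-with-momentum update to weight w given gradient g.
// With nesterov=true it uses the common look-ahead reformulation
// w += mu*v_new - lr*g, where v_new = mu*v - lr*g.
// Learning-rate decay is intentionally not modelled (assumption of the sketch).
func (s *sgdState) step(w, g, lr, mu float64, nesterov bool) float64 {
	s.velocity = mu*s.velocity - lr*g
	if nesterov {
		return w + mu*s.velocity - lr*g
	}
	return w + s.velocity
}

func main() {
	// Minimise f(w) = (w-3)^2, whose gradient is 2(w-3),
	// using the same lr and momentum as the quick start.
	w := 0.0
	s := &sgdState{}
	for i := 0; i < 200; i++ {
		w = s.step(w, 2*(w-3), 0.05, 0.1, true)
	}
	fmt.Printf("w = %.4f\n", w)
}
```

Nesterov momentum evaluates the gradient's effect at the looked-ahead position, which typically damps oscillation compared with plain momentum.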

Finally, make some predictions:

fmt.Println(data[0].Input, "=>", n.Predict(data[0].Input))
fmt.Println(data[5].Input, "=>", n.Predict(data[5].Input))

Alternatively, batch training can be performed in parallel:

// params: learning rate, beta1, beta2, epsilon
optimizer := training.NewAdam(0.001, 0.9, 0.999, 1e-8)
// params: optimizer, verbosity (print info at every n:th iteration), batch-size, number of workers
trainer := training.NewBatchTrainer(optimizer, 1, 200, 4)

trainSet, heldout := data.Split(0.75)
trainer.Train(n, trainSet, heldout, 1000) // training, validation, iterations
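With a batch trainer, each mini-batch can be split across workers that compute partial gradients concurrently before they are combined into one update. The sketch below shows that pattern for a one-parameter toy model y = w*x with squared error. `example` and `batchGradient` are hypothetical, and the even chunking plus plain averaging are assumptions, not a description of how go-deep's batch trainer distributes work.

```go
package main

import (
	"fmt"
	"sync"
)

// example is one (input, target) pair for the toy model y = w*x (hypothetical).
type example struct{ x, y float64 }

// batchGradient returns the mean gradient of 0.5*(w*x - y)^2 over batch,
// splitting the batch into contiguous chunks handled by `workers` goroutines.
// The chunking/averaging scheme is an assumption of this sketch.
func batchGradient(w float64, batch []example, workers int) float64 {
	if len(batch) == 0 {
		return 0
	}
	if workers < 1 {
		workers = 1
	}
	partial := make([]float64, workers) // one slot per worker, so no locking is needed
	chunk := (len(batch) + workers - 1) / workers
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		lo, hi := i*chunk, (i+1)*chunk
		if lo >= len(batch) {
			break
		}
		if hi > len(batch) {
			hi = len(batch)
		}
		wg.Add(1)
		go func(i, lo, hi int) {
			defer wg.Done()
			for _, e := range batch[lo:hi] {
				partial[i] += (w*e.x - e.y) * e.x // d/dw of 0.5*(w*x - y)^2
			}
		}(i, lo, hi)
	}
	wg.Wait()
	sum := 0.0
	for _, p := range partial {
		sum += p
	}
	return sum / float64(len(batch))
}

func main() {
	// Data generated from y = 2x, so w should approach 2.
	var data []example
	for i := 1; i <= 200; i++ {
		x := float64(i) / 100
		data = append(data, example{x, 2 * x})
	}
	w := 0.0
	for epoch := 0; epoch < 500; epoch++ {
		w -= 0.5 * batchGradient(w, data, 4)
	}
	fmt.Printf("w = %.4f\n", w)
}
```

Because the partial sums are combined before the update, the number of workers changes only how the work is scheduled, not the resulting gradient (up to floating-point summation order).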


See training/trainer_test.go for a variety of toy examples of regression, multi-class classification, binary classification, etc.

See examples/ for more realistic examples:

Dataset       Topology      Epochs        Accuracy
---           ---           ---           ---
wines         [5 5]         10000         ~98%
mnist         [50]          25            ~97%


func ArgMax

func ArgMax(xx []float64) int

    ArgMax returns the index of the largest element of xx

func Dot

func Dot(xx, yy []float64) float64

    Dot returns the dot product of xx and yy

func Logistic

func Logistic(x, a float64) float64

    Logistic is the logistic function

func Max

func Max(xx []float64) float64

    Max returns the largest element of xx

func Mean

func Mean(xx []float64) float64

    Mean returns the mean of xx

func Min

func Min(xx []float64) float64

    Min returns the smallest element of xx

func Normal

func Normal(stdDev, mean float64) float64

    Normal samples a value from N(μ, σ)

func Normalize

func Normalize(xx []float64)

    Normalize scales xx in place to (0,1)

func Round

func Round(x float64) float64

    Round rounds x to the nearest integer

func Sgn

func Sgn(x float64) float64

    Sgn is the signum function

func Softmax

func Softmax(xx []float64) []float64

    Softmax is the softmax function

func StandardDeviation

func StandardDeviation(xx []float64) float64

    StandardDeviation returns the standard deviation of xx

func Standardize

func Standardize(xx []float64)

    Standardize (z-score) shifts xx in place to μ=0, σ=1

func Sum

func Sum(xx []float64) (sum float64)

    Sum returns the sum of the elements of xx

func Uniform

func Uniform(stdDev, mean float64) float64

    Uniform samples a value from U(mean-stdDev/2, mean+stdDev/2)

func Variance

func Variance(xx []float64) float64

    Variance returns the variance of xx
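To make the numeric helpers above concrete, here is a standalone sketch of what Softmax, ArgMax and Standardize compute. It is not go-deep's source. The softmax subtracts the maximum before exponentiating, a standard trick for numerical stability, and the standardize step assumes the population standard deviation (dividing by n), which the documentation above does not specify.

```go
package main

import (
	"fmt"
	"math"
)

// softmax returns exp(x_i) / sum_j exp(x_j). Subtracting the maximum first
// does not change the result but avoids overflow for large inputs.
func softmax(xx []float64) []float64 {
	hi := math.Inf(-1)
	for _, x := range xx {
		hi = math.Max(hi, x)
	}
	out := make([]float64, len(xx))
	sum := 0.0
	for i, x := range xx {
		out[i] = math.Exp(x - hi)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// argMax returns the index of the largest element (the first one on ties).
func argMax(xx []float64) int {
	best := 0
	for i, x := range xx {
		if x > xx[best] {
			best = i
		}
	}
	return best
}

// standardize shifts and scales xx in place to mean 0 and standard deviation 1.
// Assumptions: population standard deviation, xx non-empty and not constant.
func standardize(xx []float64) {
	mean := 0.0
	for _, x := range xx {
		mean += x
	}
	mean /= float64(len(xx))
	variance := 0.0
	for _, x := range xx {
		variance += (x - mean) * (x - mean)
	}
	std := math.Sqrt(variance / float64(len(xx)))
	for i := range xx {
		xx[i] = (xx[i] - mean) / std
	}
}

func main() {
	p := softmax([]float64{1, 2, 3})
	fmt.Println(p, argMax(p))

	xs := []float64{2, 4, 4, 4, 5, 5, 7, 9}
	standardize(xs)
	fmt.Println(xs)
}
```

Combining softmax with argMax is how a multi-class prediction is typically turned into a single class index.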


type ActivationType

type ActivationType int

    ActivationType represents a neuron activation function

const (
	// ActivationNone is no activation
	ActivationNone ActivationType = 0
	// ActivationSigmoid is a sigmoid activation
	ActivationSigmoid ActivationType = 1
	// ActivationTanh is hyperbolic activation
	ActivationTanh ActivationType = 2
	// ActivationReLU is rectified linear unit activation
	ActivationReLU ActivationType = 3
	// ActivationLinear is linear activation
	ActivationLinear ActivationType = 4
	// ActivationSoftmax is a softmax activation (per layer)
	ActivationSoftmax ActivationType = 5
)

func OutputActivation

func OutputActivation(c Mode) ActivationType

    OutputActivation returns the activation corresponding to a prediction mode
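Per the Mode constants and the quick-start comments, each mode implies an output-layer activation: softmax for multi-class, linear for regression, and sigmoid for binary and multi-label. The sketch below expresses that mapping as a switch, re-declaring the constants locally so it compiles without go-deep. `outputActivation` is a hypothetical stand-in for OutputActivation, and returning ActivationNone for ModeDefault is an assumption.

```go
package main

import "fmt"

// Local copies of the constants documented above, so the sketch is self-contained.
type ActivationType int

const (
	ActivationNone    ActivationType = 0
	ActivationSigmoid ActivationType = 1
	ActivationTanh    ActivationType = 2
	ActivationReLU    ActivationType = 3
	ActivationLinear  ActivationType = 4
	ActivationSoftmax ActivationType = 5
)

type Mode int

const (
	ModeDefault    Mode = 0
	ModeMultiClass Mode = 1
	ModeRegression Mode = 2
	ModeBinary     Mode = 3
	ModeMultiLabel Mode = 4
)

// outputActivation maps a prediction mode to its output-layer activation,
// following the Mode documentation. The ModeDefault fallback is an assumption.
func outputActivation(m Mode) ActivationType {
	switch m {
	case ModeMultiClass:
		return ActivationSoftmax
	case ModeRegression:
		return ActivationLinear
	case ModeBinary, ModeMultiLabel:
		return ActivationSigmoid
	default:
		return ActivationNone
	}
}

func main() {
	for _, m := range []Mode{ModeMultiClass, ModeRegression, ModeBinary, ModeMultiLabel} {
		fmt.Println(m, "->", outputActivation(m))
	}
}
```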

type BinaryCrossEntropy

type BinaryCrossEntropy struct{}

    BinaryCrossEntropy is binary CE loss

func (BinaryCrossEntropy) Df

func (l BinaryCrossEntropy) Df(estimate, ideal, activation float64) float64

    Df is CE'(...)

func (BinaryCrossEntropy) F

func (l BinaryCrossEntropy) F(estimate, ideal [][]float64) float64

    F is CE(...)

type Config

type Config struct {
	// Number of inputs
	Inputs int
	// Defines topology:
	// For instance, [5 3 3] signifies a network with two hidden layers
	// containing 5 and 3 nodes respectively, followed by an output layer
	// containing 3 nodes.
	Layout []int
	// Activation functions: {ActivationTanh, ActivationReLU, ActivationSigmoid}
	Activation ActivationType
	// Solver modes: {ModeRegression, ModeBinary, ModeMultiClass, ModeMultiLabel}
	Mode Mode
	// Initializer for weights: {NewNormal(σ, μ), NewUniform(σ, μ)}
	Weight WeightInitializer `json:"-"`
	// Loss functions: {LossCrossEntropy, LossBinaryCrossEntropy, LossMeanSquared}
	Loss LossType
	// Apply bias nodes
	Bias bool
}

    Config defines the network topology, activations, losses etc.
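Layout lists the hidden-layer sizes followed by the output size, and each layer is fully connected to the next (see Layer.Connect). Under that reading, the number of trainable weights for a given Inputs/Layout/Bias combination can be counted as below. `countWeights` is a hypothetical helper; it assumes one bias weight per hidden or output neuron when Bias is true, which may not match exactly how NumWeights counts bias synapses.

```go
package main

import "fmt"

// countWeights returns the number of connection weights in a fully connected
// network with the given input size and layout, plus (assumption) one bias
// weight per hidden/output neuron when bias is true.
func countWeights(inputs int, layout []int, bias bool) int {
	total := 0
	prev := inputs
	for _, size := range layout {
		total += prev * size // every neuron in the previous layer feeds every neuron here
		if bias {
			total += size
		}
		prev = size
	}
	return total
}

func main() {
	// The quick-start network: Inputs: 2, Layout: []int{2, 2, 1}, Bias: true.
	fmt.Println(countWeights(2, []int{2, 2, 1}, true)) // 15
}
```

For the quick-start network this gives 2*2 + 2*2 + 2*1 = 10 connection weights plus 2 + 2 + 1 = 5 bias weights.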

type CrossEntropy

type CrossEntropy struct{}

    CrossEntropy is CE loss

func (CrossEntropy) Df

func (l CrossEntropy) Df(estimate, ideal, activation float64) float64

    Df is CE'(...)

func (CrossEntropy) F

func (l CrossEntropy) F(estimate, ideal [][]float64) float64

    F is CE(...)

type Differentiable

type Differentiable interface {
	F(float64) float64
	Df(float64) float64
}

    Differentiable is an activation function and its first order derivative, where the latter is expressed as a function of the former for efficiency

func GetActivation

func GetActivation(act ActivationType) Differentiable

    GetActivation returns the concrete activation given an ActivationType

type Dump

type Dump struct {
	Config  *Config
	Weights [][][]float64
}

    Dump is a neural network dump

type Layer

type Layer struct {
	Neurons []*Neuron
	A       ActivationType
}

    Layer is a set of neurons and corresponding activation

func NewLayer

func NewLayer(n int, activation ActivationType) *Layer

    NewLayer creates a new layer with n nodes

func (*Layer) ApplyBias

func (l *Layer) ApplyBias(weight WeightInitializer) []*Synapse

    ApplyBias creates and returns a bias synapse for each neuron in l

func (*Layer) Connect

func (l *Layer) Connect(next *Layer, weight WeightInitializer)

    Connect fully connects layer l to next, and initializes each synapse with the given weight function

func (Layer) String

func (l Layer) String() string
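Connect links every neuron in a layer to every neuron in the next, and ApplyBias adds one bias synapse per neuron. A forward pass through such a layer computes, for each neuron, activation(sum of w_i*in_i + b). The sketch below does that with plain slices instead of the Neuron/Synapse graph; `denseForward` is a hypothetical helper rather than a go-deep function, the weights are made up, and modelling the bias as a weight on a constant input of 1 is an assumption.

```go
package main

import (
	"fmt"
	"math"
)

// denseForward computes one fully connected layer:
// out[j] = act(sum_i weights[j][i]*in[i] + bias[j]),
// where weights[j] holds the incoming weights of neuron j.
func denseForward(in []float64, weights [][]float64, bias []float64, act func(float64) float64) []float64 {
	out := make([]float64, len(weights))
	for j, w := range weights {
		sum := bias[j] // bias synapse: weight times a constant input of 1 (assumption)
		for i, x := range in {
			sum += w[i] * x
		}
		out[j] = act(sum)
	}
	return out
}

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

func main() {
	// First example from the data above; the weights are made up for illustration.
	in := []float64{2.7810836, 2.550537003}
	hidden := denseForward(in, [][]float64{{0.1, -0.2}, {0.4, 0.3}}, []float64{0, 0.1}, sigmoid)
	out := denseForward(hidden, [][]float64{{0.5, -0.5}}, []float64{0.2}, sigmoid)
	fmt.Println(hidden, out)
}
```

Chaining one call per layer, with the output layer using the mode's activation, gives the same shape of computation as a full forward pass.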

type Linear

type Linear struct{}

    Linear is a linear activator

func (Linear) Df

func (a Linear) Df(x float64) float64

    Df is constant

func (Linear) F

func (a Linear) F(x float64) float64

    F is the identity function

type Loss

type Loss interface {
	F(estimate, ideal [][]float64) float64
	Df(estimate, ideal, activation float64) float64
}

    Loss is satisfied by loss functions

func GetLoss

func GetLoss(loss LossType) Loss

    GetLoss returns a loss function given a LossType

type LossType

type LossType int

    LossType represents a loss function

const (
	// LossNone signifies unspecified loss
	LossNone LossType = 0
	// LossCrossEntropy is cross entropy loss
	LossCrossEntropy LossType = 1
	// LossBinaryCrossEntropy is the special case of binary cross entropy loss
	LossBinaryCrossEntropy LossType = 2
	// LossMeanSquared is MSE
	LossMeanSquared LossType = 3
)

func (LossType) String

func (l LossType) String() string

type MeanSquared

type MeanSquared struct{}

    MeanSquared is MSE loss

func (MeanSquared) Df

func (l MeanSquared) Df(estimate, ideal, activation float64) float64

    Df is MSE'(...)

func (MeanSquared) F

func (l MeanSquared) F(estimate, ideal [][]float64) float64

    F is MSE(...)
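MeanSquared.F takes the same [][]float64 layout as the cross-entropy losses. The sketch below computes mean squared error over that layout. `meanSquared` is hypothetical, and averaging over every entry (rather than summing per example) without a 1/2 factor is an assumption, since the documentation only says MSE(...).

```go
package main

import "fmt"

// meanSquared returns the mean of (estimate - ideal)^2 over all entries,
// where each row is one example and each column one output.
// Averaging over all entries is an assumption of this sketch.
func meanSquared(estimate, ideal [][]float64) float64 {
	sum, n := 0.0, 0
	for i := range estimate {
		for j := range estimate[i] {
			d := estimate[i][j] - ideal[i][j]
			sum += d * d
			n++
		}
	}
	if n == 0 {
		return 0
	}
	return sum / float64(n)
}

func main() {
	estimate := [][]float64{{1.5, 2}, {0, 1}}
	ideal := [][]float64{{1, 2}, {0, 3}}
	fmt.Println(meanSquared(estimate, ideal)) // 1.0625
}
```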

type Mode

type Mode int

    Mode denotes inference mode

const (
	// ModeDefault is unspecified mode
	ModeDefault Mode = 0
	// ModeMultiClass is for one-hot encoded classification, applies softmax output layer
	ModeMultiClass Mode = 1
	// ModeRegression is regression, applies linear output layer
	ModeRegression Mode = 2
	// ModeBinary is binary classification, applies sigmoid output layer
	ModeBinary Mode = 3
	// ModeMultiLabel is for multilabel classification, applies sigmoid output layer
	ModeMultiLabel Mode = 4
)

type Neural

type Neural struct {
	Layers []*Layer
	Biases [][]*Synapse
	Config *Config
}

    Neural is a neural network

func FromDump

func FromDump(dump *Dump) *Neural

    FromDump restores a Neural from a dump

func NewNeural

func NewNeural(c *Config) *Neural

    NewNeural returns a new neural network

func Unmarshal

func Unmarshal(bytes []byte) (*Neural, error)

    Unmarshal restores a network from a JSON blob

func (*Neural) ApplyWeights

func (n *Neural) ApplyWeights(weights [][][]float64)

    ApplyWeights sets the weights from a three-dimensional slice

func (Neural) Dump

func (n Neural) Dump() *Dump

    Dump generates a network dump

func (*Neural) Forward

func (n *Neural) Forward(input []float64) error

    Forward computes a forward pass

func (Neural) Marshal

func (n Neural) Marshal() ([]byte, error)

    Marshal marshals the network to JSON

func (*Neural) NumWeights

func (n *Neural) NumWeights() (num int)

    NumWeights returns the number of weights in the network

func (*Neural) Predict

func (n *Neural) Predict(input []float64) []float64

    Predict computes a forward pass and returns a prediction

func (*Neural) String

func (n *Neural) String() string

func (Neural) Weights

func (n Neural) Weights() [][][]float64

    Weights returns all weights in sequence

type Neuron

type Neuron struct {
	A     ActivationType `json:"-"`
	In    []*Synapse
	Out   []*Synapse
	Value float64 `json:"-"`
}

    Neuron is a neural network node

func NewNeuron

func NewNeuron(activation ActivationType) *Neuron

    NewNeuron returns a neuron with the given activation

func (*Neuron) Activate

func (n *Neuron) Activate(x float64) float64

    Activate applies the neuron's activation

func (*Neuron) DActivate

func (n *Neuron) DActivate(x float64) float64

    DActivate applies the derivative of the neuron's activation

type ReLU

type ReLU struct{}

    ReLU is a rectified linear unit activator

func (ReLU) Df

func (a ReLU) Df(y float64) float64

    Df is ReLU'(y), where y = ReLU(x)

func (ReLU) F

func (a ReLU) F(x float64) float64

    F is ReLU(x)

type Sigmoid

type Sigmoid struct{}

    Sigmoid is a logistic activator in the special case of a = 1

func (Sigmoid) Df

func (a Sigmoid) Df(y float64) float64

    Df is Sigmoid'(y), where y = Sigmoid(x)

func (Sigmoid) F

func (a Sigmoid) F(x float64) float64

    F is Sigmoid(x)

type Synapse

type Synapse struct {
	Weight  float64
	In, Out float64 `json:"-"`
	IsBias  bool
}

    Synapse is an edge between neurons

func NewSynapse

func NewSynapse(weight float64) *Synapse

    NewSynapse returns a synapse with the specified initialized weight

type Tanh

type Tanh struct{}

    Tanh is a hyperbolic activator

func (Tanh) Df

func (a Tanh) Df(y float64) float64

    Df is Tanh'(y), where y = Tanh(x)

func (Tanh) F

func (a Tanh) F(x float64) float64

    F is Tanh(x)

type WeightInitializer

type WeightInitializer func() float64

    A WeightInitializer returns a (random) weight

func NewNormal

func NewNormal(stdDev, mean float64) WeightInitializer

    NewNormal returns a normal weight generator

func NewUniform

func NewUniform(stdDev, mean float64) WeightInitializer

    NewUniform returns a uniform weight generator

