package model
Version: v0.0.0-...-f83957e
Warning

This package is not in the latest version of its module.

Published: Jan 30, 2021 License: Apache-2.0 Imports: 8 Imported by: 6

Documentation

Overview

Package model provides an interface for machine learning models.

Index

Constants

This section is empty.

Variables

AllMetrics is the set of all available metrics.

var CrossEntropy = &CrossEntropyLoss{}

CrossEntropy is the standard cross-entropy loss.

var MSE = &MSELoss{}

MSE is standard mean squared error loss.
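As a point of reference, mean squared error is the average of the squared differences between predictions and targets. The sketch below computes it over plain `[]float32` slices using only the standard library; `meanSquaredError` is a hypothetical helper, and the package itself computes the loss on Gorgonia graph nodes rather than on slices.

```go
package main

import (
	"errors"
	"fmt"
)

// meanSquaredError is a hypothetical helper that computes
// mean((yHat[i] - y[i])^2) over two equal-length slices.
func meanSquaredError(yHat, y []float32) (float32, error) {
	if len(yHat) != len(y) {
		return 0, errors.New("length mismatch")
	}
	if len(y) == 0 {
		return 0, errors.New("empty input")
	}
	var sum float32
	for i := range y {
		d := yHat[i] - y[i]
		sum += d * d
	}
	return sum / float32(len(y)), nil
}

func main() {
	// Only the last element differs, by 2, so the loss is 4/3.
	loss, err := meanSquaredError([]float32{1, 2, 3}, []float32{1, 2, 5})
	if err != nil {
		panic(err)
	}
	fmt.Println(loss)
}
```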

var PseudoCrossEntropy = &PseudoCrossEntropyLoss{}

PseudoCrossEntropy is the default pseudo cross-entropy loss.

var PseudoHuber = &PseudoHuberLoss{
	Delta: 1.0,
}

PseudoHuber is the pseudo-Huber loss function with Delta set to 1.0.

Functions

func AsBatch

func AsBatch(size int) func(*Input)

AsBatch returns a clone option that sets a batch size on the cloned input.

func AsType

func AsType(dtype t.Dtype) func(*Input)

AsType explicitly sets the type of the input. Defaults to Float32.

func NameAsBatch

func NameAsBatch(name string) string

NameAsBatch takes an input name and converts it to its batch name.

func WithBatchSize

func WithBatchSize(size int) func(Model)

WithBatchSize sets the batch size for the model. Defaults to 32.

func WithGraphLogger

func WithGraphLogger(log *golog.Logger) func(Model)

WithGraphLogger adds a logger to the model which will print out the graph operations as they occur.

func WithLogger

func WithLogger(logger *log.Logger) func(Model)

WithLogger adds a logger to the model.

func WithLoss

func WithLoss(loss Loss) func(Model)

WithLoss uses a specific loss function with the model. Defaults to MSE.

func WithMetrics

func WithMetrics(metrics ...Metric) func(Model)

WithMetrics sets the metrics that the model should track. Defaults to AllMetrics.

func WithOptimizer

func WithOptimizer(optimizer g.Solver) func(Model)

WithOptimizer uses a specific optimizer function. Defaults to Adam.

func WithTracker

func WithTracker(tracker *track.Tracker) func(Model)

WithTracker adds a tracker to the model. If none is provided, one will be created.

func WithoutTracker

func WithoutTracker() func(Model)

WithoutTracker uses no tracking with the model.
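The `With*` and `Without*` functions above follow Go's functional-options pattern: each returns an `Opt`, a `func(Model)` that configures the model when applied. The sketch below reproduces the pattern with a hypothetical `toyModel` struct and string-valued loss and optimizer fields so that it runs without Gorgonia; only the defaults it uses (batch size 32, MSE loss, Adam optimizer, tracking enabled) are taken from this page.

```go
package main

import "fmt"

// toyModel is a hypothetical stand-in for a Model implementation.
type toyModel struct {
	batchSize int
	loss      string
	optimizer string
	tracking  bool
}

// opt mirrors the shape of model.Opt, but over the concrete toy type.
type opt func(*toyModel)

func withBatchSize(size int) opt { return func(m *toyModel) { m.batchSize = size } }

func withLoss(loss string) opt { return func(m *toyModel) { m.loss = loss } }

func withoutTracker() opt { return func(m *toyModel) { m.tracking = false } }

// newToyModel sets the documented defaults first, then applies each option
// in order, so later options override earlier ones.
func newToyModel(opts ...opt) *toyModel {
	m := &toyModel{batchSize: 32, loss: "mse", optimizer: "adam", tracking: true}
	for _, o := range opts {
		o(m)
	}
	return m
}

func main() {
	m := newToyModel(withBatchSize(64), withoutTracker())
	fmt.Printf("%+v\n", *m) // {batchSize:64 loss:mse optimizer:adam tracking:false}
}
```

In the package itself, options are passed to `Compile(x InputOr, y *Input, opts ...Opt)`, for example `model.WithBatchSize(64)` or `model.WithLoss(model.CrossEntropy)`.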

Types

type CloneOpt

type CloneOpt func(*Input)

CloneOpt is a clone option.

type CrossEntropyLoss

type CrossEntropyLoss struct{}

CrossEntropyLoss is standard cross entropy loss.

func (*CrossEntropyLoss) CloneTo

func (c *CrossEntropyLoss) CloneTo(graph *g.ExprGraph, opts ...CloneOpt) Loss

CloneTo another graph.

func (*CrossEntropyLoss) Compute

func (c *CrossEntropyLoss) Compute(yHat, y *g.Node) (loss *g.Node, err error)

Compute the loss.

func (*CrossEntropyLoss) Inputs

func (c *CrossEntropyLoss) Inputs() Inputs

Inputs returns any inputs the loss function utilizes.

type Input

type Input struct {
	// contains filtered or unexported fields
}

Input into the model.

func NewInput

func NewInput(name string, shape t.Shape, opts ...InputOpt) *Input

NewInput returns a new input.

func (*Input) AsBatch

func (i *Input) AsBatch(size int) *Input

AsBatch converts an input to a batched representation.

func (*Input) AsLayer

func (i *Input) AsLayer() l.Layer

AsLayer converts the input to a layer.

func (*Input) Check

func (i *Input) Check(value g.Value) error

Check that the dimensions and type of the given value are congruent with the expected input.
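A minimal sketch of the kind of validation `Check` describes, assuming "congruent" means an exact match of data type and shape; the package may be more lenient, for example about batch dimensions. The `spec` type and its string dtypes are hypothetical stand-ins for `Input` and `t.Dtype`.

```go
package main

import "fmt"

// spec is a hypothetical stand-in for an Input's expected dtype and shape.
type spec struct {
	dtype string
	shape []int
}

// check returns an error unless dtype and shape match the spec exactly.
func (s spec) check(dtype string, shape []int) error {
	if dtype != s.dtype {
		return fmt.Errorf("dtype %s does not match expected %s", dtype, s.dtype)
	}
	if len(shape) != len(s.shape) {
		return fmt.Errorf("shape %v does not match expected %v", shape, s.shape)
	}
	for i := range shape {
		if shape[i] != s.shape[i] {
			return fmt.Errorf("shape %v does not match expected %v", shape, s.shape)
		}
	}
	return nil
}

func main() {
	x := spec{dtype: "float32", shape: []int{1, 4}}
	fmt.Println(x.check("float32", []int{1, 4})) // <nil>
	fmt.Println(x.check("float64", []int{1, 4}))
	fmt.Println(x.check("float32", []int{2, 4}))
}
```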

func (*Input) Clone

func (i *Input) Clone(opts ...CloneOpt) *Input

Clone an input.

func (*Input) CloneTo

func (i *Input) CloneTo(graph *g.ExprGraph, opts ...CloneOpt) *Input

CloneTo clones an input with the node value (if present) to another graph.

func (*Input) Compile

func (i *Input) Compile(graph *g.ExprGraph, opts ...InputOpt) *g.Node

Compile an input into a graph.

func (*Input) DType

func (i *Input) DType() t.Dtype

DType returns the data type of the input.

func (*Input) EnsureBatch

func (i *Input) EnsureBatch() *Input

EnsureBatch checks that the first dimension is 1 or reshapes it to be so.
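One reading of "reshapes it to be so" is that a leading dimension of 1 is prepended, which keeps the element count unchanged. The sketch below implements that reading on plain `[]int` shapes; both the helper name and this interpretation are assumptions, not taken from the package source.

```go
package main

import "fmt"

// ensureBatch is a hypothetical helper: if the leading dimension is already 1
// the shape is returned unchanged; otherwise a leading 1 is prepended, which
// preserves the total number of elements.
func ensureBatch(shape []int) []int {
	if len(shape) > 0 && shape[0] == 1 {
		return shape
	}
	out := make([]int, 0, len(shape)+1)
	out = append(out, 1)
	return append(out, shape...)
}

func main() {
	fmt.Println(ensureBatch([]int{4}))    // [1 4]
	fmt.Println(ensureBatch([]int{1, 4})) // [1 4]
	fmt.Println(ensureBatch([]int{3, 4})) // [1 3 4]
}
```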

func (*Input) Input

func (i *Input) Input() *Input

Input implements InputOr.

func (*Input) Inputs

func (i *Input) Inputs() Inputs

Inputs implements InputOr.

func (*Input) Name

func (i *Input) Name() string

Name of the input.

func (*Input) Node

func (i *Input) Node() *g.Node

Node returns the graph node.

func (*Input) OneOfMany

func (i *Input) OneOfMany() (err error)

OneOfMany normalizes the input shape to be one of many. Any incoming singular input will also be normalized to this shape.

func (*Input) Set

func (i *Input) Set(value g.Value) error

Set the value of the input.

func (*Input) Shape

func (i *Input) Shape() t.Shape

Shape is the shape of the input.

func (*Input) Squeeze

func (i *Input) Squeeze() t.Shape

Squeeze returns the shape of the input with any leading dimensions of size 1 removed.
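A minimal sketch of leading-dimension squeezing on a plain `[]int` shape. It strips every leading 1 and keeps any size-1 dimensions that follow the first non-1 dimension; `squeezeLeading` is a hypothetical name, and how the package handles a shape made entirely of ones is not documented here (this sketch returns an empty shape).

```go
package main

import "fmt"

// squeezeLeading drops every leading dimension of size 1. The result is a
// subslice of the input, so it shares the same backing array.
func squeezeLeading(shape []int) []int {
	i := 0
	for i < len(shape) && shape[i] == 1 {
		i++
	}
	return shape[i:]
}

func main() {
	fmt.Println(squeezeLeading([]int{1, 1, 3, 1})) // [3 1]
	fmt.Println(squeezeLeading([]int{4, 2}))       // [4 2]
}
```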

func (*Input) Validate

func (i *Input) Validate() error

Validate the input.

type InputLayer

type InputLayer struct {
	// contains filtered or unexported fields
}

InputLayer is an input layer to be used in a chain.

func (*InputLayer) Clone

func (i *InputLayer) Clone() l.Layer

Clone the layer.

func (*InputLayer) Compile

func (i *InputLayer) Compile(graph *g.ExprGraph, opts ...l.CompileOpt)

Compile the layer.

func (*InputLayer) Fwd

func (i *InputLayer) Fwd(x *g.Node) (*g.Node, error)

Fwd is a forward pass through the layer.

func (*InputLayer) Graph

func (i *InputLayer) Graph() *g.ExprGraph

Graph returns the graph for this layer.

func (*InputLayer) Learnables

func (i *InputLayer) Learnables() g.Nodes

Learnables returns all learnable nodes within this layer.

func (*InputLayer) Node

func (i *InputLayer) Node() *g.Node

Node returns the node of the input.

type InputOpt

type InputOpt func(*Input)

InputOpt is an input option.

type InputOr

type InputOr interface {
	// Input present.
	Input() *Input

	// Inputs present.
	Inputs() Inputs
}

InputOr is a sum type representing either a single Input or a set of Inputs.
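Go has no built-in sum types, so `InputOr` is expressed as an interface that both the single and the plural form satisfy. The sketch below mirrors that shape with hypothetical `input` and `inputs` types; returning a one-element slice from the single form's `Inputs()` and `nil` from an empty slice's `Input()` are assumptions of the sketch.

```go
package main

import "fmt"

// input and inputs are hypothetical stand-ins for model.Input and model.Inputs.
type input struct{ name string }

type inputs []*input

// inputOr mirrors the method set of model.InputOr.
type inputOr interface {
	Input() *input
	Inputs() inputs
}

// A single input returns itself, or itself wrapped in a one-element slice.
func (i *input) Input() *input { return i }

func (i *input) Inputs() inputs { return inputs{i} }

// A slice returns its first element (nil if empty), or itself.
func (is inputs) Input() *input {
	if len(is) == 0 {
		return nil
	}
	return is[0]
}

func (is inputs) Inputs() inputs { return is }

// names accepts either form through the shared interface.
func names(x inputOr) []string {
	var out []string
	for _, in := range x.Inputs() {
		out = append(out, in.name)
	}
	return out
}

func main() {
	x := &input{name: "x"}
	fmt.Println(names(x))                               // [x]
	fmt.Println(names(inputs{x, &input{name: "goal"}})) // [x goal]
}
```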

type Inputs

type Inputs []*Input

Inputs is a slice of inputs.

func (Inputs) Clone

func (i Inputs) Clone() Inputs

Clone the inputs.

func (Inputs) Compile

func (i Inputs) Compile(graph *g.ExprGraph, opts ...InputOpt) g.Nodes

Compile all inputs into the given graph.

func (Inputs) Contains

func (i Inputs) Contains(name string) bool

Contains tests whether the given input set contains an input.

func (Inputs) Get

func (i Inputs) Get(name string) (*Input, error)

Get an input by name.
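A minimal sketch of name-based lookup over a slice of inputs, with `contains` built on `get`. The linear scan and the error text are assumptions; `namedInput` and `inputSet` are hypothetical stand-ins for `Input` and `Inputs`.

```go
package main

import "fmt"

// namedInput and inputSet are hypothetical stand-ins for model.Input and model.Inputs.
type namedInput struct{ name string }

type inputSet []*namedInput

// get returns the first input whose name matches, or an error if none does.
func (s inputSet) get(name string) (*namedInput, error) {
	for _, in := range s {
		if in.name == name {
			return in, nil
		}
	}
	return nil, fmt.Errorf("input %q not found", name)
}

// contains reports whether an input with the given name is present.
func (s inputSet) contains(name string) bool {
	_, err := s.get(name)
	return err == nil
}

func main() {
	set := inputSet{{name: "x"}, {name: "goal"}}
	fmt.Println(set.contains("goal"), set.contains("y")) // true false
	if _, err := set.get("y"); err != nil {
		fmt.Println(err) // input "y" not found
	}
}
```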

func (Inputs) Input

func (i Inputs) Input() *Input

Input implements InputOr by returning the first element of the inputs.

func (Inputs) Inputs

func (i Inputs) Inputs() Inputs

Inputs implements InputOr.

func (Inputs) Set

func (i Inputs) Set(values Values) error

Set the values to the inputs.

type Loss

type Loss interface {
	// Compute the loss.
	Compute(yHat, y *g.Node) (loss *g.Node, err error)

	// Clone the loss to another graph.
	CloneTo(graph *g.ExprGraph, opts ...CloneOpt) Loss

	// Inputs returns any inputs the loss function utilizes.
	Inputs() Inputs
}

Loss is the loss of a model.

type MSELoss

type MSELoss struct{}

MSELoss is mean squared error loss.

func (*MSELoss) CloneTo

func (m *MSELoss) CloneTo(graph *g.ExprGraph, opts ...CloneOpt) Loss

CloneTo another graph.

func (*MSELoss) Compute

func (m *MSELoss) Compute(yHat, y *g.Node) (loss *g.Node, err error)

Compute the loss.

func (*MSELoss) Inputs

func (m *MSELoss) Inputs() Inputs

Inputs returns any inputs the loss function utilizes.

type Metric

type Metric string

Metric is a metric tracked by the model.

const (
	// TrainLossMetric is the metric for training loss.
	TrainLossMetric Metric = "train_loss"

	// TrainBatchLossMetric is the metric for batch training loss.
	TrainBatchLossMetric Metric = "train_batch_loss"
)

type Metrics

type Metrics []Metric

Metrics is a set of metrics.

func (Metrics) Contains

func (m Metrics) Contains(metric Metric) bool

Contains tells whether the set contains the given metric.
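A minimal sketch of set membership over `Metrics`, reusing the two metric constants documented above. The linear scan is an assumption about the implementation; for a handful of metrics it is as fast as a map in practice and keeps the type a plain slice.

```go
package main

import "fmt"

// Metric and its constants are copied from the documentation above.
type Metric string

const (
	TrainLossMetric      Metric = "train_loss"
	TrainBatchLossMetric Metric = "train_batch_loss"
)

type Metrics []Metric

// Contains reports whether metric is in the set, using a linear scan.
func (m Metrics) Contains(metric Metric) bool {
	for _, have := range m {
		if have == metric {
			return true
		}
	}
	return false
}

func main() {
	tracked := Metrics{TrainLossMetric}
	fmt.Println(tracked.Contains(TrainLossMetric))      // true
	fmt.Println(tracked.Contains(TrainBatchLossMetric)) // false
}
```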

type Model

type Model interface {
	// Compile the model.
	Compile(x InputOr, y *Input, opts ...Opt) error

	// Predict x.
	Predict(x g.Value) (prediction g.Value, err error)

	// Fit x to y.
	Fit(x ValueOr, y g.Value) error

	// FitBatch fits x to y as batches.
	FitBatch(x ValueOr, y g.Value) error

	// PredictBatch predicts x as a batch.
	PredictBatch(x g.Value) (prediction g.Value, err error)

	// ResizeBatch resizes the batch graphs.
	ResizeBatch(n int) error

	// Visualize the model by graph name.
	Visualize(name string)

	// Graphs returns the expression graphs for the model.
	Graphs() map[string]*g.ExprGraph

	// X is the inputs to the model.
	X() InputOr

	// Y is the expected output of the model.
	Y() *Input

	// Learnables for the model.
	Learnables() g.Nodes
}

Model is a prediction model.

type Opt

type Opt func(Model)

Opt is a model option.

type Opts

type Opts struct {
	// contains filtered or unexported fields
}

Opts are options for a model.

func NewOpts

func NewOpts() *Opts

NewOpts returns a new set of options for a model.

func (*Opts) Add

func (o *Opts) Add(opts ...Opt)

Add an option to the options.

func (*Opts) Values

func (o *Opts) Values() []Opt

Values are the options.
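`Opts` lets options be collected incrementally and applied later, in the order they were added. The sketch below mirrors the `NewOpts`/`Add`/`Values` trio with hypothetical lowercase names over a toy model, since the real `Opt` operates on the `Model` interface.

```go
package main

import "fmt"

// toyModel and opt are hypothetical stand-ins for a Model and model.Opt.
type toyModel struct{ batchSize int }

type opt func(*toyModel)

// opts accumulates options so they can be applied later, in insertion order.
type opts struct{ values []opt }

func newOpts() *opts { return &opts{} }

func (o *opts) add(more ...opt) { o.values = append(o.values, more...) }

func (o *opts) list() []opt { return o.values }

func main() {
	o := newOpts()
	o.add(func(m *toyModel) { m.batchSize = 16 })
	o.add(func(m *toyModel) { m.batchSize += 4 })

	m := &toyModel{batchSize: 32}
	for _, apply := range o.list() {
		apply(m)
	}
	fmt.Println(m.batchSize) // 20
}
```

Given the documented signatures, the accumulated options can be spread into a compile call, for example `s.Compile(x, y, opts.Values()...)`.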

type PseudoCrossEntropyLoss

type PseudoCrossEntropyLoss struct{}

PseudoCrossEntropyLoss is standard cross entropy loss.

func (*PseudoCrossEntropyLoss) CloneTo

func (c *PseudoCrossEntropyLoss) CloneTo(graph *g.ExprGraph, opts ...CloneOpt) Loss

CloneTo another graph.

func (*PseudoCrossEntropyLoss) Compute

func (c *PseudoCrossEntropyLoss) Compute(yHat, y *g.Node) (loss *g.Node, err error)

Compute the loss.

func (*PseudoCrossEntropyLoss) Inputs

func (c *PseudoCrossEntropyLoss) Inputs() Inputs

Inputs returns any inputs the loss function utilizes.

type PseudoHuberLoss

type PseudoHuberLoss struct {
	// Delta determines where the function switches behavior.
	Delta float32
}

PseudoHuberLoss is a loss that is less sensitive to outliers than squared error. It behaves like absolute error when the error is large and like squared error when it is small. The larger the Delta parameter, the steeper the loss for large errors.

!blocked on https://github.com/gorgonia/gorgonia/issues/373

func NewPseudoHuberLoss

func NewPseudoHuberLoss(delta float32) *PseudoHuberLoss

NewPseudoHuberLoss returns a new pseudo-Huber loss with the given delta.

func (*PseudoHuberLoss) CloneTo

func (h *PseudoHuberLoss) CloneTo(graph *g.ExprGraph, opts ...CloneOpt) Loss

CloneTo another graph.

func (*PseudoHuberLoss) Compute

func (h *PseudoHuberLoss) Compute(yHat, y *g.Node) (loss *g.Node, err error)

Compute the loss.

func (*PseudoHuberLoss) Inputs

func (h *PseudoHuberLoss) Inputs() Inputs

Inputs returns any inputs the loss function utilizes.

type Sequential

type Sequential struct {
	// Chain of layers in the model.
	Chain *layer.Chain

	// Tracker of values.
	Tracker *track.Tracker
	// contains filtered or unexported fields
}

Sequential is a model composed of a chain of layers applied in order.

func NewSequential

func NewSequential(name string) (*Sequential, error)

NewSequential returns a new sequential model.

func (*Sequential) AddLayer

func (s *Sequential) AddLayer(layer layer.Config)

AddLayer adds a layer.

func (*Sequential) AddLayers

func (s *Sequential) AddLayers(layers ...layer.Config)

AddLayers adds a number of layers.

func (*Sequential) CloneLearnablesTo

func (s *Sequential) CloneLearnablesTo(to *Sequential) error

CloneLearnablesTo clones the model's learnables to another model.

func (*Sequential) Compile

func (s *Sequential) Compile(x InputOr, y *Input, opts ...Opt) error

Compile the model.

func (*Sequential) Fit

func (s *Sequential) Fit(x ValueOr, y g.Value) error

Fit x to y.

func (*Sequential) FitBatch

func (s *Sequential) FitBatch(x ValueOr, y g.Value) error

FitBatch fits x to y as a batch.

func (*Sequential) Fwd

func (s *Sequential) Fwd(x *Input)

Fwd tells the model which input should be sent through the layers. If none is set, the first input is used.

func (*Sequential) Graphs

func (s *Sequential) Graphs() map[string]*g.ExprGraph

Graphs returns the expression graphs for the model.

func (*Sequential) Learnables

func (s *Sequential) Learnables() g.Nodes

Learnables are the model learnables.

func (*Sequential) Predict

func (s *Sequential) Predict(x g.Value) (prediction g.Value, err error)

Predict x.

func (*Sequential) PredictBatch

func (s *Sequential) PredictBatch(x g.Value) (prediction g.Value, err error)

PredictBatch predicts x as a batch.

func (*Sequential) ResizeBatch

func (s *Sequential) ResizeBatch(n int) (err error)

ResizeBatch will resize the batch graph. Note: this is expensive as it recompiles the graph.

func (*Sequential) SetLearnables

func (s *Sequential) SetLearnables(desired g.Nodes) error

SetLearnables sets the model's learnables to the desired nodes.

func (*Sequential) Visualize

func (s *Sequential) Visualize(name string)

Visualize the model by graph name.

func (*Sequential) X

func (s *Sequential) X() InputOr

X is the input to the model.

func (*Sequential) Y

func (s *Sequential) Y() *Input

Y is the expected output of the model.

type ValueOr

type ValueOr interface{}

ValueOr is a sum type that represents a gorgonia.Value or []gorgonia.Value.

type Values

type Values []g.Value

Values is a slice of values.

func ValuesFrom

func ValuesFrom(v ValueOr) Values

ValuesFrom returns the value as a slice of Gorgonia values.
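Because `ValueOr` is an empty interface, normalizing it into a slice comes down to a type switch. The sketch below uses a hypothetical `value` interface and `scalar` type in place of `gorgonia.Value`; panicking on any other type is an assumption of the sketch, not documented behavior.

```go
package main

import "fmt"

// value is a hypothetical stand-in for gorgonia.Value.
type value interface{ String() string }

type scalar float64

func (s scalar) String() string { return fmt.Sprintf("%g", float64(s)) }

// valuesFrom normalizes a single value or a slice of values into a slice.
func valuesFrom(v interface{}) []value {
	switch x := v.(type) {
	case value:
		return []value{x}
	case []value:
		return x
	default:
		panic(fmt.Sprintf("unsupported type %T", v))
	}
}

func main() {
	fmt.Println(len(valuesFrom(scalar(1))))                       // 1
	fmt.Println(len(valuesFrom([]value{scalar(1), scalar(2)}))) // 2
}
```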
