kan

package
v0.11.0
Warning

This package is not in the latest version of its module.

Published: Aug 16, 2024 License: Apache-2.0 Imports: 15 Imported by: 0

Documentation

Overview

Package kan implements generic Kolmogorov–Arnold Networks (KAN), as described in https://arxiv.org/pdf/2404.19756

Start with New(ctx, x, numOutputNodes), then configure it further with the Config methods as desired. When finished, call Done, which returns KAN(x) as configured.

It is highly customizable, but the defaults try to follow the description in section "2.2 KAN architecture" of the paper.
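The core idea can be sketched in plain Go, independent of GoMLX: every input feature reaches every output node through its own univariate function, the contributions are summed per output node, and, as this package documents for its activation, an activation is applied after the sum. The closures below are hypothetical stand-ins for the learned functions; the sketch omits any residual or magnitude terms the package may add.

```go
package main

import "fmt"

// univariate is one learned function phi_{j,i}: R -> R. In a real KAN these
// would be b-splines or piecewise-constant functions; here they are plain closures.
type univariate func(float64) float64

// kanLayer computes, for each output node j, act(sum_i phi[j][i](x[i])).
// phi is indexed [numOutputNodes][numInputNodes], so every input feature
// feeds every output node through its own univariate function.
func kanLayer(x []float64, phi [][]univariate, act func(float64) float64) []float64 {
	out := make([]float64, len(phi))
	for j, row := range phi {
		sum := 0.0
		for i, f := range row {
			sum += f(x[i])
		}
		out[j] = act(sum)
	}
	return out
}

func main() {
	identity := func(v float64) float64 { return v }
	square := func(v float64) float64 { return v * v }
	phi := [][]univariate{
		{identity, square}, // output 0: x0 + x1^2
		{square, identity}, // output 1: x0^2 + x1
	}
	fmt.Println(kanLayer([]float64{2, 3}, phi, identity)) // [11 7]
}
```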

TODO: Implement "the good" variants from https://github.com/mintisan/awesome-kan


Constants

const (
	// ParamNumHiddenLayers is the hyperparameter that defines the default number of hidden layers.
	// The default is 0, so no hidden layers.
	ParamNumHiddenLayers = "kan_num_hidden_layers"

	// ParamNumHiddenNodes is the hyperparameter that defines the default number of hidden nodes for KAN hidden layers.
	// Default is 10.
	ParamNumHiddenNodes = "kan_num_hidden_nodes"

	// ParamNumControlPoints is the hyperparameter that defines the default number of control points
	// for the bsplines used in the univariate KAN functions.
	//
	// If Discrete-KAN is used, it also defines the number of control points of its piecewise-constant functions.
	//
	// Default is 20.
	ParamNumControlPoints = "kan_num_points"

	// ParamBSplineDegree is the hyperparameter that defines the default value for the bspline degree used in
	// the univariate KAN functions.
	// Default is 2.
	ParamBSplineDegree = "kan_bspline_degree"

	// ParamBSplineMagnitudeL2 is the hyperparameter that defines the default L2 regularization amount for the bspline
	// learned magnitude parameters.
	// Default is 0.
	ParamBSplineMagnitudeL2 = "kan_bspline_magnitude_l2"

	// ParamBSplineMagnitudeL1 is the hyperparameter that defines the default L1 regularization amount for the bspline
	// learned magnitude parameters.
	// Default is 0.
	ParamBSplineMagnitudeL1 = "kan_bspline_magnitude_l1"
)

Variables

var (
	// ParamDiscrete indicates whether to use Discrete-KAN as the univariate function to learn.
	ParamDiscrete = "kan_discrete"

	// ParamDiscreteSoftness indicates whether to soften the PCF (piecewise-constant functions) during training,
	// and by how much.
	// Default is 0.1.
	ParamDiscreteSoftness = "kan_discrete_softness"

	// ParamDiscreteSplitPointsTrainable indicates whether the split points are trainable and can move around.
	// Default is true.
	ParamDiscreteSplitPointsTrainable = "kan_discrete_splits_trainable"
)
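The defaults stated in the comments above can be collected in one place. The sketch below resolves each hyperparameter from a plain map[string]any that stands in for the hyperparameters stored in a context; the map and the resolve helper are hypothetical illustrations, not the GoMLX context API. ParamDiscrete is left out because its default is not documented here, and the 0.1 softness default comes from the DiscreteSoftness documentation.

```go
package main

import "fmt"

// defaults collects the default values documented for the kan hyperparameters.
var defaults = map[string]any{
	"kan_num_hidden_layers":         0,
	"kan_num_hidden_nodes":          10,
	"kan_num_points":                20,
	"kan_bspline_degree":            2,
	"kan_bspline_magnitude_l2":      0.0,
	"kan_bspline_magnitude_l1":      0.0,
	"kan_discrete_softness":         0.1,
	"kan_discrete_splits_trainable": true,
}

// resolve returns the user-set value for key if present, else the documented default.
// params is a hypothetical stand-in for the hyperparameters held by a context.
func resolve(params map[string]any, key string) any {
	if v, ok := params[key]; ok {
		return v
	}
	return defaults[key]
}

func main() {
	user := map[string]any{"kan_num_hidden_layers": 2}
	fmt.Println(resolve(user, "kan_num_hidden_layers")) // 2
	fmt.Println(resolve(user, "kan_num_points"))        // 20
}
```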

Functions

func PiecewiseConstantFunction

func PiecewiseConstantFunction(input, controlPoints, splitPoints *Node) *Node

PiecewiseConstantFunction (PCF) generates PCF outputs for every pair in the cross of numInputNodes x numOutputNodes, with the dimensions taken from its inputs as follows:

  • input to be transformed, shaped [batchSize, numInputNodes]
  • controlPoints are the output values of the PCFs, and should be shaped [numOutputNodes, numInputNodes, NumControlPoints]
  • splitPoints are the split values of the inputs, shaped [numOutputNodes or 1, numInputNodes or 1, NumControlPoints-1]; if the first or second axis is 1, it is broadcast accordingly.

The output will be shaped [batchSize, numOutputNodes, numInputNodes]. Typically, the caller will apply graph.ReduceSum on the last axis (after the residual value is added) to get the shape [batchSize, numOutputNodes].
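A plain-Go sketch of the same computation, ignoring broadcasting and the residual term, may help with the shapes. It assumes that an input equal to a split point falls into the interval to its right (the documentation does not state the boundary convention), so control point k is selected when exactly k split points are less than or equal to the input.

```go
package main

import (
	"fmt"
	"sort"
)

// pcf1 evaluates one piecewise-constant function: it returns controlPoints[k],
// where k is the number of split points <= x. splitPoints must be sorted
// ascending and len(controlPoints) == len(splitPoints)+1.
func pcf1(x float64, controlPoints, splitPoints []float64) float64 {
	k := sort.Search(len(splitPoints), func(i int) bool { return splitPoints[i] > x })
	return controlPoints[k]
}

// pcf applies pcf1 to every (batch, output, input) triple.
// input: [batchSize][numInputNodes]
// controlPoints: [numOutputNodes][numInputNodes][numControlPoints]
// splitPoints: [numOutputNodes][numInputNodes][numControlPoints-1] (no broadcasting)
// Returns [batchSize][numOutputNodes][numInputNodes].
func pcf(input [][]float64, controlPoints, splitPoints [][][]float64) [][][]float64 {
	out := make([][][]float64, len(input))
	for b, row := range input {
		out[b] = make([][]float64, len(controlPoints))
		for o := range controlPoints {
			out[b][o] = make([]float64, len(row))
			for i, x := range row {
				out[b][o][i] = pcf1(x, controlPoints[o][i], splitPoints[o][i])
			}
		}
	}
	return out
}

// reduceLast sums over the last axis, giving [batchSize][numOutputNodes].
func reduceLast(t [][][]float64) [][]float64 {
	out := make([][]float64, len(t))
	for b, outs := range t {
		out[b] = make([]float64, len(outs))
		for o, vals := range outs {
			for _, v := range vals {
				out[b][o] += v
			}
		}
	}
	return out
}

func main() {
	input := [][]float64{{-1, 0.5}}                // batchSize=1, numInputNodes=2
	cp := [][][]float64{{{10, 20, 30}, {1, 2, 3}}} // numOutputNodes=1, numControlPoints=3
	sp := [][][]float64{{{0, 1}, {0, 1}}}
	out := pcf(input, cp, sp)
	fmt.Println(out)             // [[[10 2]]]
	fmt.Println(reduceLast(out)) // [[12]]
}
```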

func PiecewiseConstantFunctionWithInputPerturbation

func PiecewiseConstantFunctionWithInputPerturbation(input, controlPoints, splitPoints, softness *Node) *Node

PiecewiseConstantFunctionWithInputPerturbation works like PiecewiseConstantFunction, but perturbs the inputs with a triangular distribution around their value, controlled by softness.

The shapes and inputs are the same as for PiecewiseConstantFunction, with the added softness parameter, which should be a scalar with suggested values from 0 to 1.0.

The softness smooths the function by perturbing the input with a triangular distribution whose base is softness * 2 * (splitPoint[-1] - splitPoint[0]), where splitPoint[-1] is the last split point. If softness is 0, the function is piecewise constant again.

The softening makes the function differentiable with respect to the splitPoints, so they can be trained. One can use the softness as a form of annealing on the split points: once it reaches 0, the split points no longer change and only the control points are trained.

The output will be shaped [batchSize, numOutputNodes, numInputNodes]. Typically, the caller will apply graph.ReduceSum on the last axis (after the residual value is added) to get the shape [batchSize, numOutputNodes].
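One deterministic way to read "perturbing the input" is to take the expected PCF output when a symmetric triangular noise with the documented base width is added to the input. The sketch below does that for a single function, assuming the triangle is centered on the input; whether the package computes exactly this expectation is an assumption. Writing the PCF as a sum of steps, each step contributes its jump times the probability that the perturbed input lands at or beyond that split point.

```go
package main

import "fmt"

// survival returns P(U >= t) for U drawn from a symmetric triangular
// distribution on [-h, h]. With h <= 0, U is always 0 (no perturbation).
func survival(t, h float64) float64 {
	switch {
	case h <= 0:
		if t <= 0 {
			return 1
		}
		return 0
	case t <= -h:
		return 1
	case t <= 0:
		return 1 - (t+h)*(t+h)/(2*h*h)
	case t < h:
		return (h - t) * (h - t) / (2 * h * h)
	default:
		return 0
	}
}

// softPCF returns E[f(x+U)] for the piecewise-constant function f defined by
// controlPoints and sorted splitPoints, where U is triangular with base
// softness*2*(last split - first split), i.e. half-width h below.
func softPCF(x, softness float64, controlPoints, splitPoints []float64) float64 {
	h := softness * (splitPoints[len(splitPoints)-1] - splitPoints[0])
	y := controlPoints[0]
	for j, s := range splitPoints {
		// Jump at split s, weighted by the probability that x+U >= s.
		y += (controlPoints[j+1] - controlPoints[j]) * survival(s-x, h)
	}
	return y
}

func main() {
	cp := []float64{0, 1, 2}
	sp := []float64{0, 2}
	fmt.Println(softPCF(0, 0, cp, sp))    // 1: plain PCF, x == 0 falls in [0, 2)
	fmt.Println(softPCF(0, 0.25, cp, sp)) // 0.5: half the noise mass lies below the split at 0
}
```

For softness above 0, survival is continuous in the split point s, which is what lets gradients flow to the split points; at softness 0 it becomes a step and the split points receive no gradient.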

Types

type Config

type Config struct {
	// contains filtered or unexported fields
}

Config is created with New and can be configured with its methods, or simply by setting the corresponding hyperparameters in the context.

func New

func New(ctx *context.Context, input *Node, numOutputNodes int) *Config

New returns the configuration for the KAN layer(s) to be applied to the input x. See the methods for optional configurations. When finished configuring, call Done, and it will return "KAN(x)".

The input is expected to have shape `[<batch dimensions...>, <featureDimension>]`, the output will have shape `[<batch dimensions...>, <numOutputNodes>]`.

It will apply KAN-like transformations to the last axis (the "feature axis") of x, while preserving all leading axes (we'll call them "batch axes").
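The shape contract can be illustrated in plain Go: every leading axis is kept and only the feature axis changes. The sketch also computes the [n, featureDimension] view obtained by merging the batch axes, a common way to apply a per-vector transform; whether the package reshapes internally like this is not stated, so flattenBatch is a hypothetical illustration.

```go
package main

import "fmt"

// outputShape returns the shape KAN(x) is documented to have: all leading
// (batch) axes of the input are kept, and the last (feature) axis is replaced
// by numOutputNodes.
func outputShape(inputShape []int, numOutputNodes int) []int {
	out := append([]int(nil), inputShape...) // copy, so the caller's slice is untouched
	out[len(out)-1] = numOutputNodes
	return out
}

// flattenBatch merges all leading axes into one, returning the sizes of the
// [n, featureDimension] view of the input.
func flattenBatch(inputShape []int) (n, featureDim int) {
	n = 1
	for _, d := range inputShape[:len(inputShape)-1] {
		n *= d
	}
	return n, inputShape[len(inputShape)-1]
}

func main() {
	in := []int{8, 4, 16}           // two batch axes, 16 features
	fmt.Println(outputShape(in, 3)) // [8 4 3]
	fmt.Println(flattenBatch(in))   // 32 16
}
```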

func (*Config) Activation

func (c *Config) Activation(activation activations.Type) *Config

Activation sets the activation for the KAN, which is applied after the sum.

Following the paper, it defaults to "silu" (== "swish"), but it will be overridden if the hyperparameter layers.ParamActivation (="activation") is set in the context.
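For reference, SiLU is x * sigmoid(x), and swish with its beta parameter set to 1 is the same function. A small Go sketch of both:

```go
package main

import (
	"fmt"
	"math"
)

// sigmoid is the logistic function 1 / (1 + e^-x).
func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// swish is x * sigmoid(beta*x); with beta = 1 it is SiLU.
func swish(x, beta float64) float64 { return x * sigmoid(beta*x) }

// silu is the default KAN activation: x * sigmoid(x).
func silu(x float64) float64 { return swish(x, 1) }

func main() {
	for _, x := range []float64{-2, 0, 2} {
		fmt.Printf("silu(%g) = %.4f\n", x, silu(x))
	}
}
```

Since sigmoid(x) + sigmoid(-x) = 1, silu(x) - silu(-x) = x, so SiLU behaves like the identity for large positive inputs and decays to 0 for large negative ones.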

func (*Config) BSpline

func (c *Config) BSpline(numControlPoints int) *Config

BSpline configures the KAN to use b-splines to model \phi(x), the univariate function described in the KAN paper. It also sets the number of control points used to model the function.

numControlPoints must be greater than or equal to 3. It defaults to 20 and can also be set with the hyperparameter ParamNumControlPoints ("kan_num_points").
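A b-spline \phi(x) is a weighted sum of basis functions, one per control point. The sketch below evaluates one with the Cox-de Boor recursion. The knot placement and input range are assumptions, since they are not described here: clamped, uniformly spaced knots on [0, 1], and constant extrapolation outside that interval. A clamped spline of degree p needs at least p+1 control points, which for the default degree 2 is consistent with the documented minimum of 3.

```go
package main

import "fmt"

// clampedUniformKnots returns n+p+1 knots on [0, 1]: p+1 zeros, n-p-1 evenly
// spaced interior knots, and p+1 ones.
func clampedUniformKnots(n, p int) []float64 {
	t := make([]float64, 0, n+p+1)
	for i := 0; i <= p; i++ {
		t = append(t, 0)
	}
	for i := 1; i < n-p; i++ {
		t = append(t, float64(i)/float64(n-p))
	}
	for i := 0; i <= p; i++ {
		t = append(t, 1)
	}
	return t
}

// basis is the Cox-de Boor recursion for the i-th basis function of degree p.
// Terms with a zero-width knot span are treated as 0.
func basis(i, p int, x float64, t []float64) float64 {
	if p == 0 {
		if t[i] <= x && x < t[i+1] {
			return 1
		}
		return 0
	}
	var left, right float64
	if d := t[i+p] - t[i]; d > 0 {
		left = (x - t[i]) / d * basis(i, p-1, x, t)
	}
	if d := t[i+p+1] - t[i+1]; d > 0 {
		right = (t[i+p+1] - x) / d * basis(i+1, p-1, x, t)
	}
	return left + right
}

// bspline evaluates sum_i cp[i] * B_{i,p}(x), clamping x to [0, 1].
func bspline(x float64, cp []float64, p int) float64 {
	n := len(cp)
	if n < p+1 {
		panic("need at least degree+1 control points")
	}
	if x <= 0 {
		return cp[0]
	}
	if x >= 1 {
		return cp[n-1]
	}
	t := clampedUniformKnots(n, p)
	y := 0.0
	for i := 0; i < n; i++ {
		y += cp[i] * basis(i, p, x, t)
	}
	return y
}

func main() {
	// 20 control points and degree 2, the documented defaults.
	cp := make([]float64, 20)
	for i := range cp {
		cp[i] = float64(i % 5)
	}
	for _, x := range []float64{0, 0.25, 0.5, 0.75, 1} {
		fmt.Printf("phi(%.2f) = %.4f\n", x, bspline(x, cp, 2))
	}
}
```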

func (*Config) Discrete

func (c *Config) Discrete(numControlPoints int) *Config

Discrete configures the KAN to use "piecewise-constant" functions (as opposed to splines) to model \phi(x), the univariate function used in the paper, and sets the number of control points to use for the function.

numControlPoints must be greater than or equal to 2. It defaults to 20 and can also be set with the hyperparameter ParamNumControlPoints ("kan_num_points").

func (*Config) DiscreteSoftness

func (c *Config) DiscreteSoftness(softness float64) *Config

DiscreteSoftness sets how much softness to use during training. If set to 0, softness is disabled.

The default is 0.1, and it can be set with the hyperparameter ParamDiscreteSoftness ("kan_discrete_softness").
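Because softness can be used as an annealing knob on the split points, a schedule that shrinks it over training is a natural fit. The linear schedule below, starting from the documented default of 0.1, is a hypothetical example; how the per-step value would be fed back into the Config or the context during training is not documented here.

```go
package main

import "fmt"

// annealedSoftness linearly decays the softness from start to 0 over
// totalSteps training steps and keeps it at 0 afterwards, at which point the
// split points stop being updated.
func annealedSoftness(start float64, step, totalSteps int) float64 {
	switch {
	case totalSteps <= 0 || step >= totalSteps:
		return 0
	case step <= 0:
		return start
	default:
		return start * (1 - float64(step)/float64(totalSteps))
	}
}

func main() {
	for _, step := range []int{0, 250, 500, 1000, 2000} {
		fmt.Printf("step %4d: softness %.3f\n", step, annealedSoftness(0.1, step, 1000))
	}
}
```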

func (*Config) Done

func (c *Config) Done() *Node

Done takes the configuration, applies the configured KAN layer(s), and returns the resulting KAN(x).

func (*Config) NumHiddenLayers

func (c *Config) NumHiddenLayers(numLayers, numHiddenNodes int) *Config

NumHiddenLayers configures the number of hidden layers between the input and the output. Each hidden layer will have numHiddenNodes nodes.

The default is 0 (no hidden layers), but it will be overridden if the hyperparameter ParamNumHiddenLayers is set in the context (ctx). The value for numHiddenNodes can also be configured with the hyperparameter ParamNumHiddenNodes.
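Since each layer learns one univariate function per (input node, output node) pair, model size grows with the product of adjacent layer widths. The sketch below lists the widths and counts the learned functions; the helper names are hypothetical.

```go
package main

import "fmt"

// layerWidths returns the node counts from input to output for a KAN with
// numLayers hidden layers of numHiddenNodes each.
func layerWidths(numInputs, numLayers, numHiddenNodes, numOutputs int) []int {
	w := []int{numInputs}
	for i := 0; i < numLayers; i++ {
		w = append(w, numHiddenNodes)
	}
	return append(w, numOutputs)
}

// numUnivariateFunctions counts one learned function per (input, output)
// pair of every consecutive pair of layers.
func numUnivariateFunctions(widths []int) int {
	n := 0
	for i := 0; i+1 < len(widths); i++ {
		n += widths[i] * widths[i+1]
	}
	return n
}

func main() {
	w := layerWidths(4, 2, 10, 1)          // 4 features, 2 hidden layers of 10, 1 output
	fmt.Println(w)                         // [4 10 10 1]
	fmt.Println(numUnivariateFunctions(w)) // 150
}
```

Assuming each function holds numControlPoints control points (as the [numOutputNodes, numInputNodes, NumControlPoints] controlPoints shape of PiecewiseConstantFunction suggests), the default of 20 would give 3000 control points for this example, before any magnitude or split-point parameters.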

func (*Config) Regularizer

func (c *Config) Regularizer(regularizer regularizers.Regularizer) *Config

Regularizer sets the regularizer to be applied to the learned weights. For BSpline models, it is applied to the control points.

To use more than one type of Regularizer, use regularizers.Combine and set the returned combined regularizer here.

The default is no regularizer, but one can be configured with the hyperparameters regularizers.ParamL1 and regularizers.ParamL2.
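The textbook form of a combined L1 and L2 penalty on a set of weights is sketched below. The exact scaling used by the GoMLX regularizers (for example, whether L2 includes a factor of 1/2) is not stated here, so treat this as an illustration of the idea rather than of regularizers.Combine itself.

```go
package main

import "fmt"

// l1l2Penalty returns l1*sum(|w|) + l2*sum(w^2) over the given weights, the
// textbook form of combined L1/L2 regularization.
func l1l2Penalty(weights []float64, l1, l2 float64) float64 {
	var sumAbs, sumSq float64
	for _, w := range weights {
		if w < 0 {
			sumAbs -= w
		} else {
			sumAbs += w
		}
		sumSq += w * w
	}
	return l1*sumAbs + l2*sumSq
}

func main() {
	controlPoints := []float64{1, -2, 3}          // hypothetical learned control points
	fmt.Println(l1l2Penalty(controlPoints, 1, 0)) // 6
	fmt.Println(l1l2Penalty(controlPoints, 0, 1)) // 14
}
```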

func (*Config) WithBSplineMagnitudeRegularizer

func (c *Config) WithBSplineMagnitudeRegularizer(regularizer regularizers.Regularizer) *Config

WithBSplineMagnitudeRegularizer sets the regularizer to be applied to the learned magnitude weights of BSpline models. The default is none, but it can be changed with the hyperparameters ParamBSplineMagnitudeL2 and ParamBSplineMagnitudeL1.

It only has an effect on BSpline models.
