Documentation
Overview ¶
Package kan implements generic Kolmogorov–Arnold Networks (KAN), as described in https://arxiv.org/pdf/2404.19756
Start with New(ctx, x, numOutputNodes). Configure further as desired. When finished, call Done, and it will return KAN(x), per configuration.
It is highly customizable, but the defaults try to follow the description in section "2.2 KAN architecture" of the paper.
TODO: Implement "the good" variants from https://github.com/mintisan/awesome-kan
Index ¶
- Constants
- Variables
- func PiecewiseConstantFunction(input, controlPoints, splitPoints *Node) *Node
- func PiecewiseConstantFunctionWithInputPerturbation(input, controlPoints, splitPoints, softness *Node) *Node
- type Config
- func (c *Config) Activation(activation activations.Type) *Config
- func (c *Config) BSpline(numControlPoints int) *Config
- func (c *Config) Discrete(numControlPoints int) *Config
- func (c *Config) DiscreteSoftness(softness float64) *Config
- func (c *Config) Done() *Node
- func (c *Config) NumHiddenLayers(numLayers, numHiddenNodes int) *Config
- func (c *Config) Regularizer(regularizer regularizers.Regularizer) *Config
- func (c *Config) WithBSplineMagnitudeRegularizer(regularizer regularizers.Regularizer) *Config
Constants ¶
const (
	// ParamNumHiddenLayers is the hyperparameter that defines the default number of hidden layers.
	// The default is 0, so no hidden layers.
	ParamNumHiddenLayers = "kan_num_hidden_layers"

	// ParamNumHiddenNodes is the hyperparameter that defines the default number of hidden nodes for KAN hidden layers.
	// Default is 10.
	ParamNumHiddenNodes = "kan_num_hidden_nodes"

	// ParamNumControlPoints is the hyperparameter that defines the default number of control points
	// for the bsplines used in the univariate KAN functions.
	//
	// If Discrete-KAN is used, it also defines the number of control points.
	//
	// Default is 20.
	ParamNumControlPoints = "kan_num_points"

	// ParamBSplineDegree is the hyperparameter that defines the default value for the bspline degree used in
	// the univariate KAN functions.
	// Default is 2.
	ParamBSplineDegree = "kan_bspline_degree"

	// ParamBSplineMagnitudeL2 is the hyperparameter that defines the default L2 regularization amount for the bspline
	// learned magnitude parameters.
	// Default is 0.
	ParamBSplineMagnitudeL2 = "kan_bspline_magnitude_l2"

	// ParamBSplineMagnitudeL1 is the hyperparameter that defines the default L1 regularization amount for the bspline
	// learned magnitude parameters.
	// Default is 0.
	ParamBSplineMagnitudeL1 = "kan_bspline_magnitude_l1"
)
Variables ¶
var (
	// ParamDiscrete indicates whether to use Discrete-KAN as the univariate function to learn.
	ParamDiscrete = "kan_discrete"

	// ParamDiscreteSoftness indicates whether to soften the PCF (piecewise constant functions) during training,
	// and by how much.
	ParamDiscreteSoftness = "kan_discrete_softness"

	// ParamDiscreteSplitPointsTrainable indicates whether the split points are trainable and can move around.
	// Default is true.
	ParamDiscreteSplitPointsTrainable = "kan_discrete_splits_trainable"
)
Functions ¶
func PiecewiseConstantFunction ¶
func PiecewiseConstantFunction(input, controlPoints, splitPoints *Node) *Node
PiecewiseConstantFunction (PCF) generates a PCF output for a cross of numInputNodes x numOutputNodes, defined as the dimensions of its inputs, as follows:
- input to be transformed, shaped [batchSize, numInputNodes]
- controlPoints are the output values of the PCFs, and should be shaped [numOutputNodes, numInputNodes, NumControlPoints]
- splitPoints are the splitting values of the inputs, shaped [numOutputNodes or 1, numInputNodes or 1, NumControlPoints-1], but if the first or the second axes are set to 1, they are broadcast accordingly.
The output will be shaped [batchSize, numOutputNodes, numInputNodes]. Presumably, the caller will apply graph.ReduceSum on the last axis (after the residual value is added) to get a shape of [batchSize, numOutputNodes].
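To make the shapes above concrete, here is a minimal plain-Go sketch of a piecewise-constant evaluation over slices. It is an illustration, not the package's implementation: the helper names (pcf, pcfLayer, reduceLastAxis) are hypothetical, it does not broadcast splitPoints axes of size 1, and it assumes an input equal to a split point falls in the interval to its right.

```go
package main

import "fmt"

// pcf evaluates one piecewise-constant function at x. splitPoints must be
// sorted ascending and hold len(controlPoints)-1 values.
// Assumption: an x equal to a split point belongs to the interval on its right.
func pcf(x float64, controlPoints, splitPoints []float64) float64 {
	idx := 0
	for idx < len(splitPoints) && x >= splitPoints[idx] {
		idx++
	}
	return controlPoints[idx]
}

// pcfLayer maps input [batchSize][numInputNodes] to an output shaped
// [batchSize][numOutputNodes][numInputNodes], one PCF per (output, input) pair.
// controlPoints is [numOutputNodes][numInputNodes][numControlPoints] and
// splitPoints is [numOutputNodes][numInputNodes][numControlPoints-1].
func pcfLayer(input [][]float64, controlPoints, splitPoints [][][]float64) [][][]float64 {
	out := make([][][]float64, len(input))
	for b, row := range input {
		out[b] = make([][]float64, len(controlPoints))
		for o := range controlPoints {
			out[b][o] = make([]float64, len(row))
			for i, x := range row {
				out[b][o][i] = pcf(x, controlPoints[o][i], splitPoints[o][i])
			}
		}
	}
	return out
}

// reduceLastAxis sums over the input-node axis, giving [batchSize][numOutputNodes].
func reduceLastAxis(t [][][]float64) [][]float64 {
	out := make([][]float64, len(t))
	for b := range t {
		out[b] = make([]float64, len(t[b]))
		for o := range t[b] {
			for _, v := range t[b][o] {
				out[b][o] += v
			}
		}
	}
	return out
}

func main() {
	input := [][]float64{{-1, 0.5}} // batchSize=1, numInputNodes=2
	controlPoints := [][][]float64{{{10, 20, 30}, {1, 2, 3}}}
	splitPoints := [][][]float64{{{0, 1}, {0, 1}}}

	perInput := pcfLayer(input, controlPoints, splitPoints)
	fmt.Println(perInput)                 // shaped [1][1][2]
	fmt.Println(reduceLastAxis(perInput)) // shaped [1][1]
}
```

In the real package the same computation is expressed on graph nodes, and the final reduction is left to the caller.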
func PiecewiseConstantFunctionWithInputPerturbation ¶
func PiecewiseConstantFunctionWithInputPerturbation(input, controlPoints, splitPoints, softness *Node) *Node
PiecewiseConstantFunctionWithInputPerturbation works similarly to PiecewiseConstantFunction, but adds a "perturbation" of the inputs by a triangular distribution of the value, controlled by softness.
The shapes and inputs are the same as for PiecewiseConstantFunction, with the added softness parameter, which should be a scalar with suggested values from 0 to 1.0.
The softness parameter smooths the function by perturbing the input using a triangular distribution, whose base is given by softness * 2 * (splitPoint[-1] - splitPoint[0]). If softness is 0, the function is back to being piecewise constant.
The softening makes it differentiable with respect to the splitPoints, and hence can be used for training. One can control the softness as a form of annealing on the split points. As it reaches 0, the split points are no longer changed (only the control points).
The output will be shaped [batchSize, numOutputNodes, numInputNodes]. Presumably, the caller will apply graph.ReduceSum on the last axis (after the residual value is added) to get a shape of [batchSize, numOutputNodes].
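One plausible reading of the softening, sketched below in plain Go for a single univariate function, is to take the expected value of the piecewise-constant function when the input is perturbed by symmetric triangular noise. This is an assumption, not a description of the package internals: the names triangularCDF and softPCF are hypothetical, and "base" is read as the full width of the triangle, so its half-width is softness * (splitPoint[-1] - splitPoint[0]).

```go
package main

import "fmt"

// triangularCDF returns P(e < t) for noise e drawn from a symmetric
// triangular distribution on [-h, h]. With h == 0 it degenerates to a step.
func triangularCDF(t, h float64) float64 {
	switch {
	case t <= -h:
		return 0
	case t >= h:
		return 1
	case t <= 0:
		return (t + h) * (t + h) / (2 * h * h)
	default:
		return 1 - (h-t)*(h-t)/(2*h*h)
	}
}

// softPCF returns E[pcf(x + e)]: each control point is weighted by the
// probability that the perturbed input lands in its interval.
// splitPoints must be sorted and hold len(controlPoints)-1 values (at least one).
func softPCF(x float64, controlPoints, splitPoints []float64, softness float64) float64 {
	// Assumption: the documented base is the full width, so h is half of it.
	h := softness * (splitPoints[len(splitPoints)-1] - splitPoints[0])
	var value, prev float64 // prev: probability mass below the previous split point
	for i, cp := range controlPoints {
		next := 1.0
		if i < len(splitPoints) {
			next = triangularCDF(splitPoints[i]-x, h)
		}
		value += cp * (next - prev)
		prev = next
	}
	return value
}

func main() {
	controlPoints := []float64{0, 1, 2}
	splitPoints := []float64{-1, 1}
	for _, softness := range []float64{0, 0.25} {
		fmt.Printf("softness=%.2f value at x=-1: %g\n",
			softness, softPCF(-1, controlPoints, splitPoints, softness))
	}
}
```

With softness 0 the sketch reduces to the hard step, and with a positive softness the value near a split point blends the two neighboring control points, which is what gives the split points a usable gradient.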
Types ¶
type Config ¶
type Config struct {
// contains filtered or unexported fields
}
Config is created with New and can be configured with its methods, or simply by setting the corresponding hyperparameters in the context.
func New ¶
New returns the configuration for the KAN layer(s) to be applied to the input x. See methods for optional configurations. When finished configuring, call Done, and it will return "KAN(x)".
The input is expected to have shape `[<batch dimensions...>, <featureDimension>]`, the output will have shape `[<batch dimensions...>, <numOutputNodes>]`.
It will apply KAN-like transformations to the last axis (the "feature axis") of x, while preserving all leading axes (we'll call them "batch axes").
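The shape rule described above can be written down as a small plain-Go helper. The function name kanOutputDims is hypothetical and only illustrates the rule; it is not part of the package.

```go
package main

import "fmt"

// kanOutputDims keeps all leading (batch) axes of the input and replaces the
// last (feature) axis with numOutputNodes. inputDims must not be empty.
func kanOutputDims(inputDims []int, numOutputNodes int) []int {
	out := make([]int, len(inputDims))
	copy(out, inputDims)
	out[len(out)-1] = numOutputNodes
	return out
}

func main() {
	// Two batch axes (32 and 10) and 7 features, mapped to 3 output nodes.
	fmt.Println(kanOutputDims([]int{32, 10, 7}, 3)) // prints [32 10 3]
}
```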
func (*Config) Activation ¶
func (c *Config) Activation(activation activations.Type) *Config
Activation sets the activation for the KAN, which is applied after the sum.
Following the paper, it defaults to "silu" (== "swish"), but it will be overridden if the hyperparameter layers.ParamActivation (="activation") is set in the context.
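For reference, SiLU (swish) is x * sigmoid(x). The plain-Go sketch below shows it applied to the sum of the univariate function outputs for one node, as described above; the helper names and sample values are illustrative only.

```go
package main

import (
	"fmt"
	"math"
)

// silu computes x * sigmoid(x), also known as swish.
func silu(x float64) float64 {
	return x / (1 + math.Exp(-x))
}

// activatedNode sums the univariate function outputs feeding one node and
// then applies the activation to that sum.
func activatedNode(univariateOutputs []float64) float64 {
	var sum float64
	for _, v := range univariateOutputs {
		sum += v
	}
	return silu(sum)
}

func main() {
	fmt.Println(silu(0))
	fmt.Println(activatedNode([]float64{0.5, -0.25, 1.0}))
}
```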
func (*Config) BSpline ¶
BSpline configures the KAN to use b-splines to model \phi(x), the univariate function described in the KAN paper. It also sets the number of control points to use to model the function.
The numControlPoints must be greater than or equal to 3. It defaults to 20 and can also be set with the hyperparameter ParamNumControlPoints ("kan_num_points").
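As a reminder of what a b-spline \phi(x) looks like, here is a minimal plain-Go sketch using the Cox-de Boor recursion. It assumes a uniform knot vector 0, 1, 2, ... and uses degree 2 (the documented default of ParamBSplineDegree); how the package actually places its knots is not stated here, so treat the layout as an assumption. All function names are hypothetical.

```go
package main

import "fmt"

// basis returns the b-spline basis function B_{i,d}(x) over knots, computed
// with the Cox-de Boor recursion.
func basis(i, d int, knots []float64, x float64) float64 {
	if d == 0 {
		if knots[i] <= x && x < knots[i+1] {
			return 1
		}
		return 0
	}
	var left, right float64
	if den := knots[i+d] - knots[i]; den > 0 {
		left = (x - knots[i]) / den * basis(i, d-1, knots, x)
	}
	if den := knots[i+d+1] - knots[i+1]; den > 0 {
		right = (knots[i+d+1] - x) / den * basis(i+1, d-1, knots, x)
	}
	return left + right
}

// uniformKnots returns numControlPoints+degree+1 evenly spaced knots.
// Assumption: uniform spacing, chosen here only for illustration.
func uniformKnots(numControlPoints, degree int) []float64 {
	knots := make([]float64, numControlPoints+degree+1)
	for i := range knots {
		knots[i] = float64(i)
	}
	return knots
}

// bspline evaluates phi(x) = sum_i controlPoints[i] * B_{i,degree}(x).
func bspline(x float64, controlPoints []float64, degree int) float64 {
	knots := uniformKnots(len(controlPoints), degree)
	var y float64
	for i, c := range controlPoints {
		y += c * basis(i, degree, knots, x)
	}
	return y
}

func main() {
	controlPoints := []float64{0, 1, 2, 3, 4}
	// With these knots and degree 2, all basis functions needed are present
	// for x in [2, 5).
	for _, x := range []float64{2, 3.5, 4.5} {
		fmt.Printf("phi(%g) = %g\n", x, bspline(x, controlPoints, 2))
	}
}
```

In a KAN the control points are the learned weights, which is why the number of control points sets the capacity of each univariate function.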
func (*Config) Discrete ¶
Discrete configures the KAN to use "piecewise-constant" functions (as opposed to splines) to model \phi(x), the univariate function used in the paper, and sets the number of control points to use for the function.
The numControlPoints must be greater than or equal to 2. It defaults to 20 and can also be set with the hyperparameter ParamNumControlPoints ("kan_num_points").
func (*Config) DiscreteSoftness ¶
DiscreteSoftness sets how much softness to use during training. If set to 0, softness is disabled.
The default is 0.1, and it can be set with the hyperparameter ParamDiscreteSoftness ("kan_discrete_softness").
func (*Config) Done ¶
func (c *Config) Done() *Node
Done takes the configuration and applies the configured KAN layer(s).
func (*Config) NumHiddenLayers ¶
NumHiddenLayers configures the number of hidden layers between the input and the output. Each hidden layer will have numHiddenNodes nodes.
The default is 0 (no hidden layers), but it will be overridden if the hyperparameter ParamNumHiddenLayers is set in the context (ctx). The value for numHiddenNodes can also be configured with the hyperparameter ParamNumHiddenNodes.
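To see how these two values shape the network, the plain-Go sketch below lists the node count per layer and counts the univariate functions involved. It assumes, following the paper's architecture, one univariate function per (input node, output node) pair in each layer; the helper names are hypothetical.

```go
package main

import "fmt"

// layerSizes lists the number of nodes per layer: the input features, then
// numLayers hidden layers of numHiddenNodes each, then the output nodes.
func layerSizes(numInputNodes, numLayers, numHiddenNodes, numOutputNodes int) []int {
	sizes := []int{numInputNodes}
	for i := 0; i < numLayers; i++ {
		sizes = append(sizes, numHiddenNodes)
	}
	return append(sizes, numOutputNodes)
}

// numUnivariateFunctions counts one function per (input node, output node)
// pair for every consecutive pair of layers.
// Assumption: every layer is fully connected in this sense.
func numUnivariateFunctions(sizes []int) int {
	total := 0
	for i := 0; i+1 < len(sizes); i++ {
		total += sizes[i] * sizes[i+1]
	}
	return total
}

func main() {
	sizes := layerSizes(7, 2, 10, 3)
	fmt.Println(sizes)                         // prints [7 10 10 3]
	fmt.Println(numUnivariateFunctions(sizes)) // prints 200
}
```

With the default of 0 hidden layers, the same helper gives a single layer of numInputNodes * numOutputNodes univariate functions.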
func (*Config) Regularizer ¶
func (c *Config) Regularizer(regularizer regularizers.Regularizer) *Config
Regularizer sets the regularizer to be applied to the learned weights. Default is none.
To use more than one type of Regularizer, use regularizers.Combine, and set the returned combined regularizer here.
For BSpline models it applies the regularizer to the control-points.
The default is no regularizer, but it can be configured by regularizers.ParamL1 and regularizers.ParamL2.
func (*Config) WithBSplineMagnitudeRegularizer ¶
func (c *Config) WithBSplineMagnitudeRegularizer(regularizer regularizers.Regularizer) *Config
WithBSplineMagnitudeRegularizer sets the regularizer to be applied to the magnitude weights for BSpline. The default is none, but it can be changed with the hyperparameters ParamBSplineMagnitudeL2 and ParamBSplineMagnitudeL1.
For BSpline models it applies the regularizer to the control-points.