checkpoints

package
v0.19.4
Published: May 24, 2025 License: Apache-2.0 Imports: 26 Imported by: 1

Documentation

Overview

Package checkpoints implements checkpoint management: saving and loading of checkpoints to file, or loading a checkpoint from an embedded checkpoint.

The main object is the Handler, which is created by calling Build, followed by the various option setters, and finally Config.Done. Once created, if a previously saved checkpoint exists, it automatically loads the variables and parameters for your model into the Context. As the model trains, one can call Handler.Save() at any time to save a new checkpoint -- typically one will do that inside train.EveryNSteps().

Example: After creating the Context, the code checks whether a checkpoint directory was set (`*flagCheckpoint`) and, if so, creates a checkpoints.Handler that saves a checkpoint every 100 steps, keeping the last `*flagCheckpointKeep` checkpoints.

…
ctx := context.New()
ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate)

var checkpoint *checkpoints.Handler
if *flagCheckpoint != "" {
	var err error
	checkpoint, err = checkpoints.Build(ctx).Dir(*flagCheckpoint).Keep(*flagCheckpointKeep).Done()
	Must(err)  // Panics if err != nil.
}
…
// Build training loop.
loop := train.NewLoop(trainer)
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
if checkpoint != nil {
	const priority = 100  // Large number here, means it runs last.
	train.EveryNSteps(loop, 100, "checkpointing", priority, checkpoint.OnStepFn)
}
…

Example 2: Loading a checkpoint from an embedded checkpoint, which is commonly used to distribute a model for inference:

//go:embed "my_model/checkpoint.json"
var myModelJson string

//go:embed "my_model/checkpoint.bin"
var myModelBin []byte

...
// Done returns (*Handler, error); MustDone returns only the Handler and panics on error.
_ = checkpoints.Build(ctx).
	FromEmbed(myModelJson, myModelBin).
	Immediate().
	MustDone()

TODO:

  1. Compress checkpoints.
  2. Allow to specify parts of the model to load / scope where they should be loaded to, for transfer learning.

Constants

const (

	// JsonNameSuffix for the JSON files returned by Handler.ListCheckpoints.
	JsonNameSuffix = ".json"

	// BinDataSuffix for the data files (holding the tensor values) returned by Handler.ListCheckpoints.
	BinDataSuffix = ".bin"

	// BackupDir is the name of the (sub-)directory under the model checkpoints directory that holds
	// the backups. See Handler.Backup.
	BackupDir = "backup"
)

Variables

var (
	// DirPermMode is the default directory creation permission (before umask) used.
	DirPermMode = os.FileMode(0770)
)

Types

type Config

type Config struct {
	// contains filtered or unexported fields
}

Config for the checkpoints Handler to be created. This is created with Build() and configured with the various methods. Once finished, call Done() and it will output a checkpoints.Handler that loads (if there are any previously saved checkpoints) and saves checkpoints.

func Build

func Build(ctx *context.Context) *Config

Build a configuration for building a checkpoints.Handler. After configuring the Config object returned, call `Done` to get the configured checkpoints.Handler.

The new checkpoints.Handler will load ("lazy" by default) a checkpoint to the context (see Config.Dir, Config.DirFromBase or Config.FromEmbed to specify where to load/save) if it exists, otherwise it creates a new directory and can simply be used to save checkpoints.

When a checkpoint is "lazy loaded", its variables are not listed by default (if one uses Context.EnumerateVariables or Context.IterVariables). But if they are accessed directly, they are loaded on the fly. This is convenient when loading only part of the variables for inference (if one doesn't care about the training/optimizer variables), or when transfer learning part of a model. It also works for continuing the training of a model loaded from a checkpoint. But if you need the variables to be loaded immediately, use Config.Immediate() -- an inspection tool, like gomlx_checkpoints, will want to do that.

See Config.Dir, Config.DirFromBase or Config.FromEmbed to specify where to load/save.
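
The lazy-loading behavior described above can be illustrated with a small standalone sketch. It does not use the GoMLX API: `lazyStore`, `Get`, `Listed` and `LoadAll` are hypothetical names, and plain `[]float32` slices stand in for tensors. The point is only that values stay hidden from enumeration until they are first accessed, unless everything is loaded up front.

```go
package main

import "fmt"

// lazyStore is a hypothetical stand-in for a checkpoint: it holds saved values
// and hands each one over only when it is first requested.
type lazyStore struct {
	pending map[string][]float32 // read from the checkpoint, not yet consumed
	live    map[string][]float32 // variables that have actually been accessed
}

func newLazyStore(saved map[string][]float32) *lazyStore {
	return &lazyStore{pending: saved, live: map[string][]float32{}}
}

// Get returns the variable, moving it from pending to live on first access.
func (s *lazyStore) Get(name string) ([]float32, bool) {
	if v, ok := s.live[name]; ok {
		return v, true
	}
	v, ok := s.pending[name]
	if !ok {
		return nil, false
	}
	delete(s.pending, name)
	s.live[name] = v
	return v, true
}

// Listed mimics enumeration: only variables already accessed are counted.
func (s *lazyStore) Listed() int { return len(s.live) }

// LoadAll mimics an "immediate" load: every pending variable becomes live.
func (s *lazyStore) LoadAll() {
	for name, v := range s.pending {
		s.live[name] = v
		delete(s.pending, name)
	}
}

func main() {
	s := newLazyStore(map[string][]float32{
		"/dense/weights":    {1, 2},
		"/optimizer/adam_m": {0, 0},
	})
	fmt.Println(s.Listed()) // 0
	s.Get("/dense/weights")
	fmt.Println(s.Listed()) // 1
	s.LoadAll()
	fmt.Println(s.Listed()) // 2
}
```

An inference-only program in this sketch would never touch the optimizer entry, so it would never leave the pending set.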

func Load added in v0.15.2

func Load(ctx *context.Context) *Config

Load creates a configuration to load a checkpoint. It's identical to Build, except it will fail if the checkpoint does not already exist.

Use Dir or DirFromBase to configure the location of the checkpoint. Once configured, call Config.Done to actually load it.

func (*Config) Dir

func (c *Config) Dir(dir string) *Config

Dir sets the directory where to save / load the checkpoints.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.

func (*Config) DirFromBase added in v0.5.0

func (c *Config) DirFromBase(dir, baseDir string) *Config

DirFromBase sets the directory where to save / load the checkpoints. If `dir` is not an absolute path, assumes it is a subdirectory of baseDir.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.
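
The path rule described for DirFromBase can be sketched with the standard library alone. `resolveDir` is a hypothetical helper written for illustration, not the package's implementation, and the paths are made up; it assumes the relative case is resolved with a plain join.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveDir mirrors the documented rule (hypothetical helper): an absolute
// dir is used as is, a relative dir is treated as a subdirectory of baseDir.
func resolveDir(dir, baseDir string) string {
	if filepath.IsAbs(dir) {
		return dir
	}
	return filepath.Join(baseDir, dir)
}

func main() {
	// Outputs shown for Unix-like systems.
	fmt.Println(resolveDir("my_model", "/data/checkpoints"))      // /data/checkpoints/my_model
	fmt.Println(resolveDir("/tmp/my_model", "/data/checkpoints")) // /tmp/my_model
}
```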

func (*Config) Done

func (c *Config) Done() (*Handler, error)

Done creates a Handler with the current configuration. It returns an error if the configuration is invalid, or if it's missing information.

func (*Config) ExcludeAllParams added in v0.11.0

func (c *Config) ExcludeAllParams() *Config

ExcludeAllParams configures Handler to exclude Context parameters (values usually read/written by Context.GetParam and Context.SetParam) from being read.

By default, Params are loaded and set into Context the moment Handler is created (when Done() is called), overriding values already present in the Context.

See also ExcludeParams to exclude specific params from being read.

func (*Config) ExcludeParams

func (c *Config) ExcludeParams(paramsToExclude ...string) *Config

ExcludeParams configures Handler to exclude certain Context parameters (values usually read/written by Context.GetParam and Context.SetParam) from being read. It can be called multiple times; each call adds new parameters to be excluded.

For values in paramsToExclude that don't include a preceding scope (separated by "/"), the exclusion applies to all scopes. Otherwise, it applies only to the specific scope. See context.JoinScope to merge scope and name.

By default, no parameters are excluded.

See also ExcludeAllParams to exclude all params from being read.
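
The scoping rule for ExcludeParams can be sketched as a small matcher. Everything here is an assumption made for illustration: `isExcluded` is a hypothetical helper, the parameter names are made up, and the way the scope and name are joined (scope, then "/", then name) only approximates what context.JoinScope produces.

```go
package main

import (
	"fmt"
	"strings"
)

// isExcluded reports whether parameter `name` in `scope` matches an exclusion.
// Hypothetical rule: an exclusion containing "/" must equal the full scoped
// key; an exclusion without "/" matches the parameter name in any scope.
func isExcluded(scope, name string, exclusions []string) bool {
	full := strings.TrimSuffix(scope, "/") + "/" + name
	for _, e := range exclusions {
		if strings.Contains(e, "/") {
			if e == full {
				return true
			}
		} else if e == name {
			return true
		}
	}
	return false
}

func main() {
	exclusions := []string{"learning_rate", "/encoder/dropout_rate"}
	fmt.Println(isExcluded("/", "learning_rate", exclusions))        // true
	fmt.Println(isExcluded("/decoder", "learning_rate", exclusions)) // true
	fmt.Println(isExcluded("/encoder", "dropout_rate", exclusions))  // true
	fmt.Println(isExcluded("/decoder", "dropout_rate", exclusions))  // false
}
```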

func (*Config) ExcludeVars added in v0.11.0

func (c *Config) ExcludeVars(vars ...*context.Variable) *Config

ExcludeVars enumerates variables to be excluded from saving. The function can be called multiple times, adding variables to be excluded from saving.

Variables can also be excluded after the Handler object is built, as new variables are created; see Handler.ExcludeVarsFromSaving.

func (*Config) FromEmbed added in v0.19.0

func (c *Config) FromEmbed(json string, binary []byte) *Config

FromEmbed allows one to load a checkpoint from an embedded checkpoint (using the go:embed tag).

Set either Dir (or DirFromBase) or FromEmbed, but not both.

Notice that after Done() is called, it releases the references to the passed json and binary blobs, potentially freeing the resources.

Example:

//go:embed "my_model/checkpoint.json"
var myModelJson string

//go:embed "my_model/checkpoint.bin"
var myModelBin []byte

...
_ = checkpoints.Build(ctx).FromEmbed(myModelJson, myModelBin).MustDone()

func (*Config) Immediate added in v0.9.0

func (c *Config) Immediate() *Config

Immediate forces immediate loading of all variables, as opposed to dynamically loading variables from the checkpoint as they are used when building the model.

Not normally needed, but may be handy for testing. See also context.Context.InspectVariableIfLoaded.

It may use more memory if not all variables are used by the model -- not all training data (e.g. optimizer variables) is used for inference.

func (*Config) Keep

func (c *Config) Keep(n int) *Config

Keep configures the number of checkpoint files to keep. If set to -1, it will never erase older checkpoints. The default is 1.
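
The retention rule can be sketched as follows. `checkpointsToDelete` is a hypothetical helper and the checkpoint names are made up; the sketch assumes the list is ordered oldest first (as ListCheckpoints returns it) and treats any negative value like -1.

```go
package main

import "fmt"

// checkpointsToDelete returns the checkpoints that fall outside the newest
// `keep` entries. The input must be ordered oldest first. A negative keep
// means nothing is ever deleted (assumption: any negative value acts like -1).
func checkpointsToDelete(checkpoints []string, keep int) []string {
	if keep < 0 || len(checkpoints) <= keep {
		return nil
	}
	return checkpoints[:len(checkpoints)-keep]
}

func main() {
	all := []string{"step-100", "step-200", "step-300"}
	fmt.Println(checkpointsToDelete(all, 1))  // [step-100 step-200]
	fmt.Println(checkpointsToDelete(all, -1)) // []
}
```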

func (*Config) MustDone

func (c *Config) MustDone() *Handler

MustDone constructs the checkpoints.Handler. It panics if there was an error.

func (*Config) TakeMean added in v0.4.1

func (c *Config) TakeMean(n int, backend backends.Backend) *Config

TakeMean loads the mean of the last `n` checkpoints. If `n <= 0`, it takes the mean of all available checkpoints. Notice that only trainable variables are averaged. Variables that have integer values or are not marked as trainable (e.g. the global step) are taken from the most recent checkpoint instead.

The default is 1, so only load the most recent checkpoint.

If n != 1, it requires a backend that will be used to calculate the mean. If n == 1, the backend argument is ignored and can be nil.

Notice the mean is taken one tensor at a time, so at any time there is only one copy of the model weights in memory, plus the tensor being merged.

If a mean is calculated, the values of the variables will be stored on-device. This can be good -- if their values are going to be used on-device anyway -- or bad -- if they are not needed on-device, and it's using the limited on-device space. Consider *tensors.Tensor.MaterializeLocal and *tensors.Tensor.InvalidateOnDevice to have them moved locally if so desired.
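
The averaging described for TakeMean can be sketched without a backend. This is not the package's implementation: `variable`, `checkpoint` and `mergeMean` are hypothetical, `[]float32` slices stand in for tensors, and a single `trainable` flag stands in for the "trainable and non-integer" condition. It assumes every checkpoint holds the same variables with the same lengths.

```go
package main

import "fmt"

// variable is a hypothetical stand-in for a saved tensor.
type variable struct {
	values    []float32
	trainable bool
}

// checkpoint maps a variable name to its saved value.
type checkpoint map[string]variable

// mergeMean averages trainable variables across checkpoints (ordered oldest
// first) and copies non-trainable ones from the most recent checkpoint.
func mergeMean(checkpoints []checkpoint) checkpoint {
	if len(checkpoints) == 0 {
		return nil
	}
	latest := checkpoints[len(checkpoints)-1]
	merged := make(checkpoint, len(latest))
	for name, v := range latest {
		if !v.trainable {
			merged[name] = v
			continue
		}
		// Accumulate one variable at a time, so only one extra buffer is alive.
		acc := make([]float32, len(v.values))
		for _, c := range checkpoints {
			for i, x := range c[name].values {
				acc[i] += x
			}
		}
		for i := range acc {
			acc[i] /= float32(len(checkpoints))
		}
		merged[name] = variable{values: acc, trainable: true}
	}
	return merged
}

func main() {
	older := checkpoint{
		"weights":     {values: []float32{1, 2}, trainable: true},
		"global_step": {values: []float32{100}},
	}
	newer := checkpoint{
		"weights":     {values: []float32{3, 4}, trainable: true},
		"global_step": {values: []float32{200}},
	}
	merged := mergeMean([]checkpoint{older, newer})
	fmt.Println(merged["weights"].values)     // [2 3]
	fmt.Println(merged["global_step"].values) // [200]
}
```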

func (*Config) TempDir

func (c *Config) TempDir(dir, pattern string) *Config

TempDir creates a temporary directory under dir, named according to pattern, and uses this directory to load / save checkpoints. It's a convenience wrapper around os.MkdirTemp.

If dir is the empty string, MkdirTemp uses the default directory for temporary files, as returned by os.TempDir.

The new directory's name is generated by adding a random string to the end of pattern. If `pattern` includes a "*", the random string replaces the last "*" instead (see os.MkdirTemp).

Any errors are reported when the method Done is called.

One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.
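
The naming rule for `pattern` comes from os.MkdirTemp, so it can be tried with the standard library directly. In this sketch `tempDirName` is a hypothetical helper and the patterns are made up; it creates a directory under the default temporary location, removes it again, and returns only the generated name. The random part differs on every run, so no output is shown.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// tempDirName creates a temporary directory with os.MkdirTemp, removes it
// again, and returns only the generated base name.
func tempDirName(pattern string) (string, error) {
	dir, err := os.MkdirTemp("", pattern) // "" selects os.TempDir()
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(dir)
	return filepath.Base(dir), nil
}

func main() {
	// Without "*" the random string is appended; with "*" it replaces the last "*".
	for _, pattern := range []string{"ckpt_", "ckpt_*_run"} {
		name, err := tempDirName(pattern)
		if err != nil {
			fmt.Println("error:", err)
			return
		}
		fmt.Println(pattern, "->", name)
	}
}
```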

type Handler

type Handler struct {
	// contains filtered or unexported fields
}

Handler handles saving and loading of checkpoints for a context.Context. See example in package documentation.

It is created and configured using Build(), followed by options setting and then calling Config.Done().

Loading data into the Handler happens at its creation time: it loads from the latest checkpoint. (Hyper-)parameters are then loaded into the context immediately (unless Config.ExcludeAllParams was set), but the loaded variable values are only "consumed" (used) one at a time, as the variables are created during graph building (e.g. when building the model).

Saving of checkpoints is explicit, by calling Handler.Save(). Usually this is done by configuring train.Loop to call it using train.EveryNSteps or train.NTimesDuringLoop. When saving, all variables in the Context are saved, along with any variables previously loaded by the Handler that were not used by the Context, and with the `Params` for all scopes (including changed values).

There can be more than one Handler attached to a Context -- they are used for loading in the order they were created (so the first one created takes priority). A setup with multiple Handlers can be used, for instance, for transfer learning, where parts of the model are loaded from somewhere else.

A Handler can only be "attached" to one context.Context. If one wants to load the same checkpoint into two different contexts, another Handler object needs to be created. This is because once a variable is loaded, it is transferred to the Context, and the Handler does not keep it.
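
The priority rule for multiple Handlers can be pictured as a first-match lookup. The sketch below is standalone and hypothetical: `loader` and `loadFirst` are not part of the package, single `float32` values stand in for tensors, and the variable names are made up.

```go
package main

import "fmt"

// loader is a hypothetical stand-in for one Handler's set of saved variables.
type loader map[string]float32

// loadFirst queries the loaders in creation order; the first one that has the
// variable wins, later ones are not consulted.
func loadFirst(loaders []loader, name string) (float32, bool) {
	for _, l := range loaders {
		if v, ok := l[name]; ok {
			return v, true
		}
	}
	return 0, false
}

func main() {
	pretrained := loader{"/encoder/weights": 0.5}
	own := loader{"/encoder/weights": 0.1, "/head/weights": 0.9}
	loaders := []loader{pretrained, own} // pretrained was created first

	fmt.Println(loadFirst(loaders, "/encoder/weights")) // 0.5 true
	fmt.Println(loadFirst(loaders, "/head/weights"))    // 0.9 true
}
```

In a transfer-learning setup of this shape, the encoder comes from the first source while everything else falls through to the second.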

func (*Handler) Backup added in v0.11.0

func (h *Handler) Backup() error

Backup links (or copies) the latest checkpoint to a separate sub-directory under the model directory called "backup" (constant in checkpoints.BackupDir).

This way the backed up checkpoint doesn't get automatically deleted as the model training progresses.

This is useful, for instance, to back up the checkpoints used to collect the plot points, but it can be used for any other reason.
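
A "link, or copy if linking fails" step can be sketched with the standard library. This is not the package's code: `linkOrCopy` and `backupDemo` are hypothetical, the file name is made up, and the sketch assumes a hard link is attempted first. The demo shows why a backup survives deletion of the original file.

```go
package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// linkOrCopy tries a hard link first and falls back to copying the bytes,
// e.g. when the file system does not support links.
func linkOrCopy(src, dst string) error {
	if err := os.Link(src, dst); err == nil {
		return nil
	}
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()
	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		out.Close()
		return err
	}
	return out.Close()
}

// backupDemo writes a fake checkpoint file, backs it up into a "backup"
// sub-directory, deletes the original and returns the backup's content.
func backupDemo() (string, error) {
	dir, err := os.MkdirTemp("", "ckpt_demo_")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(dir)
	src := filepath.Join(dir, "checkpoint-example.bin")
	if err := os.WriteFile(src, []byte("weights"), 0o660); err != nil {
		return "", err
	}
	backupDir := filepath.Join(dir, "backup")
	if err := os.MkdirAll(backupDir, 0o770); err != nil {
		return "", err
	}
	dst := filepath.Join(backupDir, filepath.Base(src))
	if err := linkOrCopy(src, dst); err != nil {
		return "", err
	}
	if err := os.Remove(src); err != nil { // simulate automatic clean-up
		return "", err
	}
	data, err := os.ReadFile(dst)
	return string(data), err
}

func main() {
	content, err := backupDemo()
	fmt.Println(content, err) // weights <nil>
}
```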

func (*Handler) DeleteVariable added in v0.11.0

func (h *Handler) DeleteVariable(ctx *context.Context, scope, name string)

DeleteVariable implements context.Loader. It is called whenever Context.DeleteVariable is called. The deletion should cascade to the loader, otherwise the variable will reappear after deletion.

func (*Handler) Dir

func (h *Handler) Dir() string

Dir returns the directory the Handler is configured to. It cannot be changed once the Handler was created.

It returns "" (empty) if the Handler is `nil`.

func (*Handler) ExcludeVarsFromSaving added in v0.10.0

func (h *Handler) ExcludeVarsFromSaving(vars ...*context.Variable)

ExcludeVarsFromSaving enumerates variables to be excluded from saving. The function can be called multiple times, adding variables to be excluded from saving.

func (*Handler) HasCheckpoints added in v0.4.0

func (h *Handler) HasCheckpoints() (bool, error)

HasCheckpoints returns whether there are any checkpoints saved.

func (*Handler) ListCheckpoints

func (h *Handler) ListCheckpoints() (checkpoints []string, err error)

ListCheckpoints returns the base file paths of the checkpoints in the directory in time order (older first).

The actual paths are these base file paths suffixed with JsonNameSuffix and BinDataSuffix.
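
Turning a base path into the two file paths is plain string concatenation. In this sketch the constants are local copies of the package's JsonNameSuffix and BinDataSuffix values, `checkpointFiles` is a hypothetical helper, and the base path is made up.

```go
package main

import "fmt"

// Local copies of the suffix values documented for this package.
const (
	jsonNameSuffix = ".json"
	binDataSuffix  = ".bin"
)

// checkpointFiles derives the metadata and data file paths from a base path
// such as those returned by Handler.ListCheckpoints.
func checkpointFiles(base string) (jsonPath, binPath string) {
	return base + jsonNameSuffix, base + binDataSuffix
}

func main() {
	jsonPath, binPath := checkpointFiles("/models/demo/checkpoint-example")
	fmt.Println(jsonPath) // /models/demo/checkpoint-example.json
	fmt.Println(binPath)  // /models/demo/checkpoint-example.bin
}
```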

func (*Handler) LoadVariable

func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value *tensors.Tensor, found bool)

LoadVariable implements context.Loader. This is called by context.Context when the variable is used for the first time. The user may want to use this function to inspect loaded values for testing.

func (*Handler) LoadedVariables added in v0.4.1

func (h *Handler) LoadedVariables() map[string]*tensors.Tensor

LoadedVariables for inspection. These are the values loaded -- but not necessarily immediately available in context, since they are actually used only when a model asks for the variable.

The Handler owns the returned map; don't change it, as the behavior is undefined if you do.

func (*Handler) OnStepFn added in v0.4.0

func (h *Handler) OnStepFn(_ *train.Loop, _ []*tensors.Tensor) error

OnStepFn implements `train.OnStepFn`, making it convenient to attach the Handler to a training loop. It simply calls Save.

func (*Handler) Save

func (h *Handler) Save() error

Save creates a new checkpoint and saves the context variables and (optionally) Params.

All variables in the context are saved, as well as those previously loaded -- this allows one to load the variables only for a part of the model, update that part and save again with everything.

Params are (de-)serialized with package json.

If the handler is nil, this is a no-op, so it is safe to call Save even if the user hasn't configured a checkpoint.
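
The nil-safe behavior relies on a general Go feature: a method with a pointer receiver can be called on a nil pointer as long as it checks the receiver before dereferencing it. The sketch below uses a hypothetical `handler` type, not the package's Handler, to show the pattern for both Save and Dir.

```go
package main

import "fmt"

// handler is a hypothetical stand-in for a checkpoint handler.
type handler struct {
	dir   string
	saves int
}

// Save is a no-op on a nil receiver, so callers need no nil check.
func (h *handler) Save() error {
	if h == nil {
		return nil
	}
	h.saves++ // a real implementation would write files here
	return nil
}

// Dir returns "" on a nil receiver.
func (h *handler) Dir() string {
	if h == nil {
		return ""
	}
	return h.dir
}

func main() {
	var checkpoint *handler // checkpointing was never configured

	fmt.Println(checkpoint.Save() == nil) // true
	fmt.Printf("%q\n", checkpoint.Dir())  // ""
}
```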

func (*Handler) String

func (h *Handler) String() string

String implements Stringer.
