Overview
Package checkpoints implements checkpoint management: saving and loading of checkpoints.
The main object is the Handler, which should be created by calling Build, followed by setting the various options, and finally calling Config.Done. Once created, if a previously saved checkpoint exists, it automatically loads the variables and parameters for your model into the Context. As the model trains, one can call Handler.Save() at any time to save a new checkpoint; typically this is done with train.EveryNSteps().
Example: after creating the Context, the code checks whether a checkpoint directory was set (`*flagCheckpoint`) and, if so, creates a checkpoints.Handler that saves checkpoints every 100 steps, keeping the last `*flagCheckpointKeep` checkpoints.
```
…
ctx := context.New()
ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate)
var checkpoint *checkpoints.Handler
if *flagCheckpoint != "" {
	var err error
	checkpoint, err = checkpoints.Build(ctx).Dir(*flagCheckpoint).Keep(*flagCheckpointKeep).Done()
	Must(err) // Panics if err != nil.
}
…
// Build the training loop.
loop := train.NewLoop(trainer)
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
if checkpoint != nil {
	const priority = 100 // A large number means it runs last.
	train.EveryNSteps(loop, 100, "checkpointing", priority, checkpoint.OnStepFn)
}
…
```
TODO:
- Compress checkpoints.
- Allow specifying which parts of the model to load, and the scope they should be loaded into, for transfer learning.
Index
- Constants
- Variables
- type Config
- func (c *Config) Dir(dir string) *Config
- func (c *Config) DirFromBase(dir, baseDir string) *Config
- func (c *Config) Done() (*Handler, error)
- func (c *Config) ExcludeAllParams() *Config
- func (c *Config) ExcludeParams(paramsToExclude ...string) *Config
- func (c *Config) ExcludeVars(vars ...*context.Variable) *Config
- func (c *Config) Immediate() *Config
- func (c *Config) Keep(n int) *Config
- func (c *Config) MustDone() *Handler
- func (c *Config) TakeMean(n int, backend backends.Backend) *Config
- func (c *Config) TempDir(dir, pattern string) *Config
- type Handler
- func (h *Handler) Backup() error
- func (h *Handler) DeleteVariable(ctx *context.Context, scope, name string)
- func (h *Handler) Dir() string
- func (h *Handler) ExcludeVarsFromSaving(vars ...*context.Variable)
- func (h *Handler) HasCheckpoints() (bool, error)
- func (h *Handler) ListCheckpoints() (checkpoints []string, err error)
- func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value *tensors.Tensor, found bool)
- func (h *Handler) LoadedVariables() map[string]*tensors.Tensor
- func (h *Handler) OnStepFn(_ *train.Loop, _ []*tensors.Tensor) error
- func (h *Handler) Save() error
- func (h *Handler) String() string
Constants
const (
	// JsonNameSuffix for the JSON files returned by Handler.ListCheckpoints.
	JsonNameSuffix = ".json"

	// VarDataSuffix for the data files (holding the tensor values) returned by Handler.ListCheckpoints.
	VarDataSuffix = ".bin"

	// BackupDir is the name of the (sub-)directory under the model checkpoints directory that holds
	// the backups. See Handler.Backup.
	BackupDir = "backup"
)
Variables
var (
	// DirPermMode is the default directory creation permission (before umask) used.
	DirPermMode = os.FileMode(0770)
)
Functions
This section is empty.
Types
type Config
type Config struct {
	// contains filtered or unexported fields
}
Config holds the configuration for the checkpoints Handler to be created. It is created with Build() and configured with its various methods. Once finished, call Done() to obtain a checkpoints.Handler that loads (if there are any previously saved checkpoints) and saves checkpoints.
func Build
Build creates a configuration for building a checkpoints.Handler. After configuring the returned Config object, call `Done` to get the configured checkpoints.Handler.
The new checkpoints.Handler loads a checkpoint into the context ("lazily" by default) if one exists in the configured directory (see Dir or DirFromBase); otherwise it creates a new directory and can simply be used to save checkpoints.
When a checkpoint is "lazy loaded", its variables are not listed by default (e.g. by Context.EnumerateVariables or Context.IterVariables), but they are loaded on the fly when accessed directly. This is convenient when loading only some of the variables for inference (when the training/optimizer variables are not needed), or for transfer learning of part of a model. It also works to continue training a model loaded from a checkpoint. If you need the variables to be loaded immediately, use Config.Immediate(); an inspection tool, like gomlx_checkpoints, will want to do that.
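To make the lazy-loading behavior concrete, here is a minimal, self-contained Go sketch that does not use GoMLX. The types `lazyLoader` and `varStore` and their methods are hypothetical stand-ins for the Handler and the Context; the sketch assumes variables are keyed by their scope-joined name and that a value is handed over to the store (and dropped by the loader) on first access.

```go
package main

import "fmt"

// lazyLoader is a hypothetical stand-in for a checkpoint Handler: it holds
// values read from disk that have not yet been claimed by the store.
type lazyLoader struct {
	pending map[string][]float32
}

// take hands a value over to the caller and forgets it, mirroring how a
// loaded variable is transferred to the Context on first use.
func (l *lazyLoader) take(name string) ([]float32, bool) {
	v, ok := l.pending[name]
	if ok {
		delete(l.pending, name)
	}
	return v, ok
}

// varStore is a hypothetical stand-in for a Context.
type varStore struct {
	vars   map[string][]float32
	loader *lazyLoader
}

// list returns only the variables already materialized in the store:
// lazily loaded variables that were never accessed do not show up.
func (s *varStore) list() []string {
	names := make([]string, 0, len(s.vars))
	for name := range s.vars {
		names = append(names, name)
	}
	return names
}

// get returns a variable, loading it on the fly from the loader if needed.
func (s *varStore) get(name string) ([]float32, bool) {
	if v, ok := s.vars[name]; ok {
		return v, true
	}
	if v, ok := s.loader.take(name); ok {
		s.vars[name] = v
		return v, true
	}
	return nil, false
}

func main() {
	loader := &lazyLoader{pending: map[string][]float32{
		"/model/dense/weights":   {0.5, -1},
		"/optimizer/adam/moment": {0.1, 0.2},
	}}
	store := &varStore{vars: map[string][]float32{}, loader: loader}
	fmt.Println(len(store.list())) // 0: nothing has been accessed yet
	w, _ := store.get("/model/dense/weights")
	fmt.Println(w, len(store.list()), len(loader.pending)) // [0.5 -1] 1 1
}
```

Listing shows only the variables that were already accessed, which is why an inspection tool would rather force immediate loading.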
func Load added in v0.15.2
Load creates a configuration to load a checkpoint. It's identical to Build, except that it fails if the checkpoint does not already exist.
Use Dir or DirFromBase to configure the location of the checkpoint. Once configured, call Config.Done to actually load it.
func (*Config) Dir
func (c *Config) Dir(dir string) *Config
Dir sets the directory where checkpoints are saved to and loaded from.
One must set either Dir, DirFromBase or TempDir before building the checkpoints.Handler.
func (*Config) DirFromBase added in v0.5.0
func (c *Config) DirFromBase(dir, baseDir string) *Config
DirFromBase sets the directory where checkpoints are saved to and loaded from. If `dir` is not an absolute path, it is assumed to be a subdirectory of baseDir.
One must set either Dir, DirFromBase or TempDir before building the checkpoints.Handler.
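The path rule above can be sketched with the standard library alone. `resolveDir` is a hypothetical helper, not part of the package, and it assumes a relative `dir` is simply joined onto `baseDir` with filepath.Join (the package may do additional processing). The example paths assume a Unix-like system.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveDir (hypothetical) mirrors the documented DirFromBase rule:
// absolute paths are used as-is, relative ones are placed under baseDir.
func resolveDir(dir, baseDir string) string {
	if filepath.IsAbs(dir) {
		return dir
	}
	return filepath.Join(baseDir, dir)
}

func main() {
	fmt.Println(resolveDir("run1", "/data/models"))      // relative: joined under baseDir
	fmt.Println(resolveDir("/tmp/run2", "/data/models")) // absolute: used as-is
}
```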
func (*Config) Done
func (c *Config) Done() (*Handler, error)
Done creates a Handler with the current configuration. It returns an error if the configuration is invalid, or if it's missing information.
func (*Config) ExcludeAllParams added in v0.11.0
func (c *Config) ExcludeAllParams() *Config
ExcludeAllParams configures the Handler to exclude all Context parameters (values usually read/written by Context.GetParam and Context.SetParam) from being read.
By default, Params are loaded and set into the Context the moment the Handler is created (when Done() is called), overriding values already present in the Context.
See also ExcludeParams to exclude specific params from being read.
func (*Config) ExcludeParams
func (c *Config) ExcludeParams(paramsToExclude ...string) *Config
ExcludeParams configures the Handler to exclude certain Context parameters (values usually read/written by Context.GetParam and Context.SetParam) from being read. It can be called multiple times; each call adds new parameters to be excluded.
For values in paramsToExclude that don't include a preceding scope (separated by "/"), the exclusion applies to all scopes. Otherwise, it applies only to the specific scope. See context.JoinScope to merge scope and name.
By default, no parameters are excluded.
See also ExcludeAllParams to exclude all params from being read.
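A minimal sketch of the documented matching rule for excluded params, assuming a param is identified by its scope plus its name. `joinScope` is a hypothetical helper written to behave the way context.JoinScope is described (joining scope and name with "/"), and `isParamExcluded` is likewise hypothetical; the package's actual matching code may differ.

```go
package main

import (
	"fmt"
	"strings"
)

// scopeSeparator is the separator mentioned in the ExcludeParams docs.
const scopeSeparator = "/"

// joinScope is a hypothetical helper assumed to behave like
// context.JoinScope: it appends name to scope with a single separator.
func joinScope(scope, name string) string {
	if strings.HasSuffix(scope, scopeSeparator) {
		return scope + name
	}
	return scope + scopeSeparator + name
}

// isParamExcluded (hypothetical) applies the documented rule: entries
// without a "/" match the param name in any scope, while scoped entries
// match only the exact scope-joined path.
func isParamExcluded(scope, name string, excluded []string) bool {
	full := joinScope(scope, name)
	for _, e := range excluded {
		if !strings.Contains(e, scopeSeparator) {
			if e == name {
				return true
			}
		} else if e == full {
			return true
		}
	}
	return false
}

func main() {
	excluded := []string{"learning_rate", "/model/dropout"}
	fmt.Println(isParamExcluded("/optimizer", "learning_rate", excluded)) // true
	fmt.Println(isParamExcluded("/model", "dropout", excluded))           // true
	fmt.Println(isParamExcluded("/other", "dropout", excluded))           // false
}
```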
func (*Config) ExcludeVars added in v0.11.0
func (c *Config) ExcludeVars(vars ...*context.Variable) *Config
ExcludeVars enumerates variables to be excluded from saving. The function can be called multiple times, each call adding more variables to be excluded.
It can also be called after the Handler object is built, as new variables are created.
func (*Config) Immediate added in v0.9.0
func (c *Config) Immediate() *Config
Immediate forces all variables to be loaded immediately, as opposed to loading them dynamically from the checkpoint as they are used when building the model.
Not normally needed, but it may be handy for testing. See also context.Context.InspectVariableIfLoaded.
It may use more memory if the model doesn't use all the variables: for instance, not all training variables (e.g. optimizer variables) are used for inference.
func (*Config) Keep
func (c *Config) Keep(n int) *Config
Keep configures the number of checkpoint files to keep. If set to -1, it will never erase older checkpoints. The default is 1.
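The retention rule can be sketched as a pure function over checkpoint base names ordered oldest first, the order in which ListCheckpoints returns them. `checkpointsToDelete` and the example names are hypothetical; the sketch only models the documented cases (a positive n, or -1 to keep everything) and treats any negative n like -1.

```go
package main

import "fmt"

// checkpointsToDelete (hypothetical) returns which checkpoints a Keep(n)
// policy would erase, given base names ordered from oldest to newest.
// A negative n (e.g. -1) keeps everything; n is otherwise assumed positive.
func checkpointsToDelete(ordered []string, n int) []string {
	if n < 0 || len(ordered) <= n {
		return nil
	}
	return ordered[:len(ordered)-n]
}

func main() {
	names := []string{"ckpt-0100", "ckpt-0200", "ckpt-0300"}
	fmt.Println(checkpointsToDelete(names, 1))  // [ckpt-0100 ckpt-0200]
	fmt.Println(checkpointsToDelete(names, -1)) // []
}
```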
func (*Config) MustDone
func (c *Config) MustDone() *Handler
MustDone constructs the checkpoints.Handler. It panics if there was an error.
func (*Config) TakeMean added in v0.4.1
func (c *Config) TakeMean(n int, backend backends.Backend) *Config
TakeMean loads the mean of the last `n` checkpoints. If `n <= 0`, it takes the mean of all available checkpoints. Notice that only trainable variables are averaged. Variables that have integer values or are not marked as trainable (e.g. the global step) are taken from the most recent checkpoint instead.
The default is 1, which loads only the most recent checkpoint.
If n != 1, it requires a backend that will be used to calculate the mean. If n == 1, the backend argument is ignored and can be nil.
Notice the mean is taken one tensor at a time, so at any time there is only one copy of the model weights in memory, plus the tensor being merged.
If a mean is calculated, the values of the variables will be stored on-device. This can be good, if their values are going to be used on-device anyway, or bad, if they are not needed on-device and take up limited on-device memory. Consider *tensors.Tensor.MaterializeLocal and *tensors.Tensor.InvalidateOnDevice to have them moved locally if so desired.
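A simplified sketch of the averaging rule, using plain float64 slices instead of tensors and ignoring the backend and on-device placement entirely. `savedVar` and `meanOfVariable` are hypothetical; the sketch assumes the versions of one variable from the selected checkpoints are passed in oldest first, and it processes one variable at a time with a single running-sum buffer, in the spirit of the memory note above.

```go
package main

import "fmt"

// savedVar is a hypothetical, simplified view of one variable in one
// checkpoint: flat float64 values plus the flags TakeMean cares about.
type savedVar struct {
	values    []float64
	trainable bool
	isInteger bool
}

// meanOfVariable (hypothetical) combines the same variable from several
// checkpoints, ordered oldest to newest. Only trainable, non-integer
// variables are averaged; anything else is taken from the newest checkpoint.
// It keeps a single running sum, so only one extra buffer is allocated.
func meanOfVariable(versions []savedVar) []float64 {
	newest := versions[len(versions)-1]
	if !newest.trainable || newest.isInteger {
		return newest.values
	}
	sum := make([]float64, len(newest.values))
	for _, v := range versions {
		for i, x := range v.values {
			sum[i] += x
		}
	}
	for i := range sum {
		sum[i] /= float64(len(versions))
	}
	return sum
}

func main() {
	weights := []savedVar{
		{values: []float64{1, 2}, trainable: true},
		{values: []float64{3, 6}, trainable: true},
	}
	globalStep := []savedVar{
		{values: []float64{100}, isInteger: true},
		{values: []float64{200}, isInteger: true},
	}
	fmt.Println(meanOfVariable(weights))    // [2 4]
	fmt.Println(meanOfVariable(globalStep)) // [200]
}
```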
func (*Config) TempDir
func (c *Config) TempDir(dir, pattern string) *Config
TempDir creates a temporary directory under dir, with the pattern name, and uses this directory to load / save checkpoints. It's a convenience wrapper around os.MkdirTemp.
If dir is the empty string, MkdirTemp uses the default directory for temporary files, as returned by os.TempDir.
The new directory's name is generated by adding a random string to the end of pattern. If `pattern` includes a "*", the random string replaces the last "*" instead (see os.MkdirTemp).
Any errors are reported when Done is called.
One must set either Dir, DirFromBase or TempDir before building the checkpoints.Handler.
type Handler
type Handler struct {
	// contains filtered or unexported fields
}
Handler handles saving and loading of checkpoints for a context.Context. See example in package documentation.
It is created and configured using Build(), followed by setting options, and then calling Config.Done().
Loading data into the Handler happens at creation time: it loads from the latest checkpoint. (Hyper-)parameters are immediately loaded into the context (unless Config.ExcludeAllParams is used), but the loaded variable values are only "consumed" (used) one at a time, as the variables are created during graph building (e.g. when building the model).
Saving of checkpoints is explicit, by calling Handler.Save(). Usually this is done by configuring train.Loop to call it using train.EveryNSteps or train.NTimesDuringLoop. When saving, all variables in the Context are saved, along with any variables previously loaded by the Handler that were not used by the Context, and the `Params` for all scopes (including changed values).
There can be more than one Handler attached to a Context: they are used for loading in the order they are created (so the first one created takes priority). Multiple Handlers can be used, for instance, for transfer learning, where parts of the model are loaded from somewhere else.
A Handler can only be "attached" to one context.Context. To load the same checkpoint into two different contexts, another Handler object needs to be created. This is because once a variable is loaded, it is transferred to the Context, and the Handler does not keep it.
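The load-order rule for multiple Handlers can be sketched with a hypothetical `loader` type standing in for a Handler and a hypothetical `firstLoaded` function standing in for the Context's lookup. It assumes, as described above, that loaders are consulted in creation order and that a value is dropped by its loader once handed over.

```go
package main

import "fmt"

// loader is a hypothetical minimal version of the role a Handler plays
// for a Context: it hands over a stored value the first time it is asked.
type loader struct {
	name   string
	values map[string]float64
}

func (l *loader) load(varName string) (float64, bool) {
	v, ok := l.values[varName]
	if ok {
		delete(l.values, varName) // transferred: the loader no longer keeps it
	}
	return v, ok
}

// firstLoaded (hypothetical) queries loaders in creation order, so the
// first loader that has the variable takes priority.
func firstLoaded(loaders []*loader, varName string) (value float64, from string, found bool) {
	for _, l := range loaders {
		if v, ok := l.load(varName); ok {
			return v, l.name, true
		}
	}
	return 0, "", false
}

func main() {
	pretrained := &loader{name: "pretrained", values: map[string]float64{"/encoder/w": 1.5}}
	current := &loader{name: "current", values: map[string]float64{"/encoder/w": 9, "/head/w": 2}}
	loaders := []*loader{pretrained, current} // pretrained was created first
	fmt.Println(firstLoaded(loaders, "/encoder/w")) // 1.5 pretrained true
	fmt.Println(firstLoaded(loaders, "/head/w"))    // 2 current true
}
```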
func (*Handler) Backup added in v0.11.0
func (h *Handler) Backup() error
Backup links (or copies) the latest checkpoint to a separate sub-directory under the model directory called "backup" (constant in checkpoints.BackupDir).
This way the backed up checkpoint doesn't get automatically deleted as the model training progresses.
Useful, for instance, to back up the checkpoints used to collect plot points, but it can be used for any other purpose.
func (*Handler) DeleteVariable added in v0.11.0
func (h *Handler) DeleteVariable(ctx *context.Context, scope, name string)
DeleteVariable implements context.Loader. It is called whenever Context.DeleteVariable is called. The deletion should cascade to the loader, otherwise the variable will reappear after deletion.
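Why the deletion must cascade can be shown with a small hypothetical model: `store` stands in for the Context and `pendingLoader` for a Handler's not-yet-consumed values. Neither is the package's API; the `cascade` flag exists only to contrast the two behaviors.

```go
package main

import "fmt"

// pendingLoader is a hypothetical loader holding not-yet-consumed values.
type pendingLoader map[string]float64

// store is a hypothetical Context-like container backed by a loader.
type store struct {
	vars   map[string]float64
	loader pendingLoader
}

// get returns a variable, falling back to the loader on first access.
func (s *store) get(name string) (float64, bool) {
	if v, ok := s.vars[name]; ok {
		return v, true
	}
	if v, ok := s.loader[name]; ok {
		delete(s.loader, name)
		s.vars[name] = v
		return v, true
	}
	return 0, false
}

// remove deletes a variable; when cascade is true it also tells the loader,
// which is the role DeleteVariable plays for the Handler.
func (s *store) remove(name string, cascade bool) {
	delete(s.vars, name)
	if cascade {
		delete(s.loader, name)
	}
}

func main() {
	s := &store{vars: map[string]float64{}, loader: pendingLoader{"/w": 3}}
	s.remove("/w", false)
	_, found := s.get("/w")
	fmt.Println(found) // true: the variable "reappears" from the loader
	s.remove("/w", true)
	_, found = s.get("/w")
	fmt.Println(found) // false
}
```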
func (*Handler) Dir
func (h *Handler) Dir() string
Dir returns the directory the Handler is configured with. It cannot be changed once the Handler has been created.
It returns "" (empty) if the Handler is `nil`.
func (*Handler) ExcludeVarsFromSaving added in v0.10.0
func (h *Handler) ExcludeVarsFromSaving(vars ...*context.Variable)
ExcludeVarsFromSaving enumerates variables to be excluded from saving. The function can be called multiple times, each call adding more variables to be excluded.
func (*Handler) HasCheckpoints added in v0.4.0
func (h *Handler) HasCheckpoints() (bool, error)
HasCheckpoints returns whether there are any checkpoints saved.
func (*Handler) ListCheckpoints
func (h *Handler) ListCheckpoints() (checkpoints []string, err error)
ListCheckpoints returns the base file paths of the checkpoints in the directory in time order (older first).
The actual paths are these base file paths suffixed with JsonNameSuffix and VarDataSuffix.
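A minimal sketch of how a base path maps to the two files of one checkpoint. The constants are local copies of the documented JsonNameSuffix and VarDataSuffix values; `checkpointFiles` and the example base path are hypothetical.

```go
package main

import "fmt"

// Local copies of the values documented for JsonNameSuffix and VarDataSuffix.
const (
	jsonNameSuffix = ".json"
	varDataSuffix  = ".bin"
)

// checkpointFiles (hypothetical) expands a base path, as returned by
// ListCheckpoints, into the two files that make up one checkpoint.
func checkpointFiles(base string) (jsonPath, dataPath string) {
	return base + jsonNameSuffix, base + varDataSuffix
}

func main() {
	j, d := checkpointFiles("/models/run1/checkpoint-00100") // hypothetical base path
	fmt.Println(j) // /models/run1/checkpoint-00100.json
	fmt.Println(d) // /models/run1/checkpoint-00100.bin
}
```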
func (*Handler) LoadVariable
func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value *tensors.Tensor, found bool)
LoadVariable implements context.Loader. This is called by context.Context when the variable is used for the first time. The user may want to use this function to inspect loaded values for testing.
func (*Handler) LoadedVariables added in v0.4.1
func (h *Handler) LoadedVariables() map[string]*tensors.Tensor
LoadedVariables returns the loaded variables, for inspection. These are the values loaded from the checkpoint, but they are not necessarily available in the context yet, since a variable is only used when a model asks for it.
The Handler owns the returned map: don't change it, as the behavior is undefined if you do.
func (*Handler) OnStepFn added in v0.4.0
func (h *Handler) OnStepFn(_ *train.Loop, _ []*tensors.Tensor) error
OnStepFn implements `train.OnStepFn`, making it convenient to attach the Handler to a training loop. It simply calls Save.
func (*Handler) Save
func (h *Handler) Save() error
Save creates a new checkpoint and saves the context variables and (optionally) Params.
All variables in the context are saved, as well as those previously loaded: this allows one to load the variables for only a part of the model, update that part, and save everything again.
Params are (de)serialized with package json.
If the handler is nil, this is a no-op, so it's safe to call even if the user hasn't configured a checkpoint.
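Two of the documented Save behaviors, the nil-receiver no-op and writing back loaded-but-unused variables, can be sketched in isolation. `handler` and `save` are hypothetical and operate on plain maps rather than tensors; they do not reproduce the actual file format.

```go
package main

import "fmt"

// handler is a hypothetical stand-in used to show two documented Save
// behaviors; it is not the package's Handler.
type handler struct {
	unusedLoaded map[string]float64 // loaded from disk, never claimed by the context
	saved        []map[string]float64
}

// save (hypothetical) is safe to call on a nil *handler: a method with a
// pointer receiver can check for nil before touching any field.
func (h *handler) save(contextVars map[string]float64) error {
	if h == nil {
		return nil // no checkpoint configured: nothing to do
	}
	snapshot := make(map[string]float64, len(contextVars)+len(h.unusedLoaded))
	// Loaded-but-unused variables are written back, so saving a partially
	// built model does not drop them from the next checkpoint.
	for k, v := range h.unusedLoaded {
		snapshot[k] = v
	}
	// Variables present in the context are saved with their current values.
	for k, v := range contextVars {
		snapshot[k] = v
	}
	h.saved = append(h.saved, snapshot)
	return nil
}

func main() {
	var none *handler
	fmt.Println(none.save(map[string]float64{"/w": 1})) // <nil>

	h := &handler{unusedLoaded: map[string]float64{"/head/w": 7}}
	_ = h.save(map[string]float64{"/encoder/w": 0.5})
	fmt.Println(len(h.saved[0])) // 2
}
```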