fm

package
v0.17.1

This package is not in the latest version of its module.
Published: Feb 26, 2025 License: Apache-2.0 Imports: 47 Imported by: 0

Documentation

Constants

This section is empty.

Variables

var (
	DType = dtypes.Float32
)

Functions

func BuildTrainingModelGraph

func BuildTrainingModelGraph(config *diffusion.Config) train.ModelFn

BuildTrainingModelGraph builds the ModelFn for training and evaluation.

It generates random noise as the "source distribution" for each example image, as well as random values of t ∈ [0, 1) used for training.

func CreateDefaultContext

func CreateDefaultContext() *context.Context

CreateDefaultContext creates a context with the default hyperparameters to use with TrainModel.

func DisplayImagesAcrossTime

func DisplayImagesAcrossTime(cfg *diffusion.Config, numImages int, numSteps int, displayEveryNSteps int)

DisplayImagesAcrossTime creates numImages series of images, each showing Gaussian noise being transported step by step into a generated image.

It transports the random noise to generated images in numSteps steps, displaying the intermediate images every displayEveryNSteps steps.

Plotting the results only works in a Jupyter notebook with the GoNB kernel.

func DropdownFlowerTypes

func DropdownFlowerTypes(cfg *diffusion.Config, cacheKey string, numImages, numDiffusionSteps int, htmlId string) *xsync.Latch

DropdownFlowerTypes creates a drop-down that shows images at different diffusion steps.

If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache, if available, and otherwise saves the generated images in the cache for future use.

func GenerateFlowerIds

func GenerateFlowerIds(cfg *diffusion.Config, numImages int) *tensors.Tensor

GenerateFlowerIds generates random flower IDs: each ID is a flower type, one of the 102 types.

func GenerateImagesOfAllFlowerTypes

func GenerateImagesOfAllFlowerTypes(cfg *diffusion.Config, numDiffusionSteps int) (predictedImages *tensors.Tensor)

GenerateImagesOfAllFlowerTypes takes a single random noise sample and generates a flower from it for each of the 102 types.

paramsSet are overridden hyperparameters that should not be loaded from the checkpoint (see commandline.ParseContextSettings).

func GenerateImagesOfFlowerType

func GenerateImagesOfFlowerType(cfg *diffusion.Config, numImages int, flowerType int32, numDiffusionSteps int) (predictedImages *tensors.Tensor)

GenerateImagesOfFlowerType is similar to DisplayImagesAcrossTime, but it limits itself to generating images of only one flower type.

paramsSet are overridden hyperparameters that should not be loaded from the checkpoint (see commandline.ParseContextSettings).

func GenerateNoise

func GenerateNoise(cfg *diffusion.Config, numImages int) *tensors.Tensor

GenerateNoise generates random noise that can be used to generate images.

func ImagesToHtml

func ImagesToHtml(images []image.Image) string

ImagesToHtml converts a slice of images into an HTML snippet that lays them out side by side, so they can be easily displayed.

func MakeMoons

func MakeMoons(ctx *context.Context, g *Graph, n int) *Node

MakeMoons returns a collection of n points sampled from two interleaved half circles. This is a toy dataset to visualize clustering and classification algorithms.

Modeled after scikit-learn make_moons function.

It returns a tensor shaped [n, 2].

func MidPointODEStep

func MidPointODEStep(ctx *context.Context, noisyImages, flowerIds, startTime, endTime *Node) *Node

MidPointODEStep advances the images by one ODE integration step using the "Midpoint Method" (https://en.wikipedia.org/wiki/Midpoint_method).

Parameters:

  • noisyImages: the X_t being integrated from t=0 to t=1. Shaped [numImages, width, height, channels].
  • flowerIds: shaped [numImages, 1].
  • startTime and endTime can either be scalars or shaped [numImages, 1]. They must be constrained to 0 <= startTime < 1 and startTime < endTime <= 1.

It returns the sample images moved by ΔT = endTime - startTime towards the target distribution.

func PlotImages

func PlotImages(images []image.Image)

PlotImages plots the images, all in one row. The image size in the HTML is set to the value given.

This only works in a Jupyter (GoNB kernel) notebook.

func PlotImagesTensor

func PlotImagesTensor(imagesT *tensors.Tensor)

PlotImagesTensor plots images in tensor format, all in one row. It assumes a maximum pixel value (MaxValue) of 255.

This only works in a Jupyter (GoNB kernel) notebook.

func PlotModelEvolution

func PlotModelEvolution(cfg *diffusion.Config, imagesPerSample int, animate bool, globalStepLimits ...int)

PlotModelEvolution plots the sample images that were generated and saved during training of the model in the currently configured checkpoint.

If animate is true it will do an animation from first to last image, staying a few seconds on the last image.

If one globalStepLimits value is given, it takes the latest image whose global step is <= the given value.

If two globalStepLimits are given, they are considered a range (start, end) of global step limits.

It outputs at most imagesPerSample images per sampled checkpoint.

func SliderDiffusionSteps

func SliderDiffusionSteps(cfg *diffusion.Config, cacheKey string, numImages int, numDiffusionSteps int, htmlId string) *xsync.Latch

SliderDiffusionSteps creates and animates a slider that shows images at different diffusion steps. It handles the slider on a separate goroutine. Trigger the returned latch to stop it.

If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache, if available, and otherwise saves the generated images in the cache for future use.

func TrainModel

func TrainModel(config *diffusion.Config, checkpointPath string, evaluateOnEnd bool, verbosity int)

TrainModel trains the model with the given config, which includes the context with the hyperparameters.

func TrainingMonitor

func TrainingMonitor(checkpoint *checkpoints.Handler, loop *train.Loop, metrics []*tensors.Tensor,
	plotter stdplots.Plotter, evalDatasets []train.Dataset, generator *ImagesGenerator, kid *KidGenerator) error

TrainingMonitor is periodically called during training, and is used to report metrics and generate sample images at the current training step.

Types

type ImagesGenerator

type ImagesGenerator struct {
	// contains filtered or unexported fields
}

ImagesGenerator generates images from the given noise and flowerIds. Create it with NewImagesGenerator.

func NewImagesGenerator

func NewImagesGenerator(cfg *diffusion.Config, noise, flowerIds *tensors.Tensor, numSteps int) *ImagesGenerator

NewImagesGenerator generates flowers given initial `noise` and `flowerIds`, in `numSteps`.

func (*ImagesGenerator) Generate

func (g *ImagesGenerator) Generate() (batchedImages *tensors.Tensor)

Generate images from the original noise.

It can be called multiple times: if the context changed (for example, because the model was further trained), it generates new images; otherwise, it always returns the same images.

func (*ImagesGenerator) GenerateEveryN

func (g *ImagesGenerator) GenerateEveryN(n int) (predictedImages []*tensors.Tensor, times []float64)

GenerateEveryN generates images from the original noise. They are generated by transporting the random noise to the distribution of the flower images in numSteps steps. It always returns the last generated image, plus every n-th intermediate image.

It can be called more than once: if the context changed (for example, because the model was further trained), it generates new images; otherwise, it always returns the same images.

It returns a slice of batches of images, one batch per returned diffusion step, and another slice with the "time" of each step, 0 <= time <= 1, where time = 1 is the last step.
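The selection of returned steps and their times can be illustrated on a scalar state. The hypothetical `generateEveryN` below integrates from t = 0 to t = 1 in numSteps equal steps, using a plain Euler update for brevity (the package documents MidPointODEStep instead), and records every n-th state plus the final one together with its time.

```go
package main

import "fmt"

// generateEveryN integrates dx/dt = velocity(x, t) from t=0 to t=1 in
// numSteps equal Euler steps, keeping the state after every n-th step and
// always the final state, along with the time reached at each kept step.
func generateEveryN(x0 float64, velocity func(x, t float64) float64, numSteps, n int) (states, times []float64) {
	x := x0
	dt := 1.0 / float64(numSteps)
	for step := 1; step <= numSteps; step++ {
		t := float64(step-1) / float64(numSteps)
		x += dt * velocity(x, t)
		if step%n == 0 || step == numSteps {
			states = append(states, x)
			times = append(times, float64(step)/float64(numSteps))
		}
	}
	return states, times
}

func main() {
	constant := func(x, t float64) float64 { return 1 } // x(t) = x0 + t
	states, times := generateEveryN(0, constant, 10, 4)
	fmt.Println(times)  // [0.4 0.8 1]
	fmt.Println(states) // approximately [0.4 0.8 1]
}
```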

type KidGenerator

type KidGenerator struct {
	// contains filtered or unexported fields
}

KidGenerator generates the [Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) metric.
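KID is the unbiased squared maximum mean discrepancy (MMD) between feature vectors of real and generated images (InceptionV3 activations), using the cubic polynomial kernel k(x, y) = (x·y/d + 1)³, where d is the feature dimension. The sketch below computes that estimator on plain slices; it is an illustration of the metric, not the package's graph implementation, and it omits the averaging over random subsets that is often applied in practice.

```go
package main

import "fmt"

// polyKernel is the cubic polynomial kernel used by KID:
// k(x, y) = (x·y/d + 1)^3, with d the feature dimension.
func polyKernel(x, y []float64) float64 {
	dot := 0.0
	for i := range x {
		dot += x[i] * y[i]
	}
	v := dot/float64(len(x)) + 1
	return v * v * v
}

// kid returns the unbiased MMD^2 estimate between feature sets xs and ys:
// mean k over distinct pairs within xs, plus the same within ys,
// minus twice the mean k over all cross pairs.
func kid(xs, ys [][]float64) float64 {
	m, n := float64(len(xs)), float64(len(ys))
	var kxx, kyy, kxy float64
	for i := range xs {
		for j := range xs {
			if i != j {
				kxx += polyKernel(xs[i], xs[j])
			}
		}
	}
	for i := range ys {
		for j := range ys {
			if i != j {
				kyy += polyKernel(ys[i], ys[j])
			}
		}
	}
	for i := range xs {
		for j := range ys {
			kxy += polyKernel(xs[i], ys[j])
		}
	}
	return kxx/(m*(m-1)) + kyy/(n*(n-1)) - 2*kxy/(m*n)
}

func main() {
	realFeats := [][]float64{{0, 0}, {1, 0}, {0, 1}}
	near := [][]float64{{0.1, 0}, {1, 0.1}, {0, 1.1}}
	far := [][]float64{{5, 5}, {6, 5}, {5, 6}}
	fmt.Printf("KID(real, near) = %.4f\n", kid(realFeats, near))
	fmt.Printf("KID(real, far)  = %.4f\n", kid(realFeats, far))
}
```

Lower values indicate generated features that are statistically closer to the real ones; because the estimator is unbiased, small values can be slightly negative.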

func NewKidGenerator

func NewKidGenerator(cfg *diffusion.Config, evalDS train.Dataset, numDiffusionStep int) *KidGenerator

NewKidGenerator creates a generator for the KID metric. The context in cfg is the context of the diffusion model. A separate context is used for the InceptionV3 model used by the KID metric, so that its weights are not included in the generator model.

func (*KidGenerator) Eval

func (kg *KidGenerator) Eval() (metric *tensors.Tensor)

func (*KidGenerator) EvalStepGraph

func (kg *KidGenerator) EvalStepGraph(ctx *context.Context, allImages []*Node) (metric *Node)
