Documentation
Index
- Constants
- Variables
- func CnnEmbeddings(ctx *context.Context, images *Node) *Node
- func CnnModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node
- func CreateDatasets(backend backends.Backend, config *DatasetsConfiguration) (trainDS, trainEvalDS, validationEvalDS train.Dataset)
- func CreateDefaultContext() *context.Context
- func Download(baseDir string) error
- func LinearModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node
- func NewDataset(backend backends.Backend, name, baseDir, mode string, dtype dtypes.DType) (ds *data.InMemoryDataset, err error)
- func TrainModel(ctx *context.Context, dataDir, checkpointPath string, paramsSet []string) error
- type ContextFn
- type DatasetsConfiguration
- type Image
- type Label
Constants
const (
	DownloadURL = "https://storage.googleapis.com/cvdf-datasets/mnist"

	TrainImagesFilename = "train-images-idx3-ubyte.gz"
	TrainLabelsFilename = "train-labels-idx1-ubyte.gz"
	TestImagesFilename  = "t10k-images-idx3-ubyte.gz"
	TestLabelsFilename  = "t10k-labels-idx1-ubyte.gz"

	TrainImagesHash = "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
	TrainLabelsHash = "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
	TestImagesHash  = "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
	TestLabelsHash  = "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"

	Width      = 28
	Height     = 28
	Depth      = 1
	NumClasses = 10

	TrainExamples = 60000
	TestExamples  = 10000

	ImageMagic = 0x00000803
	LabelMagic = 0x00000801
)
const (
	ImageFileType fileType = iota
	LabelFileType
)
Variables
var ModelList = []string{"linear", "cnn"}
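`ModelList` names the model types that can be selected. The sketch below shows one way a caller might validate a chosen name against it before dispatching to a builder. The registry and the `modelBuilder` and `pickModel` names are hypothetical. The real `LinearModelGraph` and `CnnModelGraph` take a `*context.Context` and graph nodes, which this standard-library-only sketch replaces with stand-in functions.

```go
package main

import "fmt"

// ModelList copied from the variable above.
var ModelList = []string{"linear", "cnn"}

// modelBuilder is a stand-in for a real graph builder; it only returns the
// name of the function it represents.
type modelBuilder func() string

// builders is a hypothetical registry keyed by the names in ModelList.
var builders = map[string]modelBuilder{
	"linear": func() string { return "LinearModelGraph" },
	"cnn":    func() string { return "CnnModelGraph" },
}

// pickModel rejects names that are not in ModelList.
func pickModel(name string) (modelBuilder, error) {
	for _, m := range ModelList {
		if m == name {
			return builders[name], nil
		}
	}
	return nil, fmt.Errorf("unknown model %q, choose one of %v", name, ModelList)
}

func main() {
	b, err := pickModel("cnn")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(b())
}
```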
Functions
func CnnEmbeddings
func CnnEmbeddings(ctx *context.Context, images *Node) *Node
func CnnModelGraph
func CnnModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node
CnnModelGraph builds the CNN model for our demo. It returns the logits (not the predictions) with shape `[batch_size, NumClasses]`, which works with most losses. inputs: a single tensor with shape `[batch_size, width, height, depth]`.
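func CreateDatasets
func CreateDatasets(backend backends.Backend, config *DatasetsConfiguration) (trainDS, trainEvalDS, validationEvalDS train.Dataset)
CreateDatasets creates the datasets used for training and evaluation: trainDS for training, and trainEvalDS and validationEvalDS for evaluating on the training and validation data.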
func CreateDefaultContext
func CreateDefaultContext() *context.Context
func LinearModelGraph
func LinearModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node
LinearModelGraph builds a simple logistic regression model. It returns the logits (not the predictions) with shape `[batch_size, NumClasses]`, which works with most losses. inputs: a single tensor with shape `[batch_size, width, height, depth]`.
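A logistic (linear) model flattens each image and applies one affine map to produce `NumClasses` logits. The standard-library sketch below spells out that arithmetic for a single example. It is not the package's graph code. `linearLogits` is a hypothetical name, and the sketch assumes one grayscale channel per pixel.

```go
package main

import "fmt"

const (
	Width      = 28
	Height     = 28
	NumClasses = 10
	channels   = 1 // assumption: MNIST pixels are single-channel grayscale
)

// linearLogits computes logits = x·W + b for one flattened image x.
// W has one row of NumClasses weights per input feature.
func linearLogits(x []float32, W [][NumClasses]float32, b [NumClasses]float32) [NumClasses]float32 {
	logits := b // arrays are values, so this copies the bias
	for i, xi := range x {
		for c := range logits {
			logits[c] += xi * W[i][c]
		}
	}
	return logits
}

func main() {
	features := Width * Height * channels // 784 inputs after flattening
	x := make([]float32, features)
	W := make([][NumClasses]float32, features)
	var b [NumClasses]float32
	b[7] = 1 // with zero weights, the bias alone sets the logits
	fmt.Println(features, linearLogits(x, W, b))
}
```

With 784 features and 10 classes, the model has 784×10 weights plus 10 biases, or 7850 parameters.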
func NewDataset
func NewDataset(backend backends.Backend, name, baseDir, mode string, dtype dtypes.DType) (ds *data.InMemoryDataset, err error)
NewDataset creates an in-memory train.Dataset that yields images from the MNIST dataset.
It takes the following arguments:
- name: a name for the dataset.
- baseDir: the directory where the MNIST files are stored.
- mode: choose between 'train' and 'test'.
Types
type DatasetsConfiguration
type DatasetsConfiguration struct {
	// DataDir is where downloaded and generated data is stored.
	DataDir string
	// BatchSize and EvalBatchSize are the training and evaluation batch sizes.
	BatchSize, EvalBatchSize int
	// UseParallelism enables parallel batch preparation (data.ParallelDataset).
	UseParallelism bool
	// BufferSize is used by data.ParallelDataset to cache intermediary batches.
	// This value is used for each dataset.
	BufferSize int
	// Dtype is the data type of the generated tensors.
	Dtype dtypes.DType
}
DatasetsConfiguration holds various parameters on how to transform the input images.
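`BatchSize` and `EvalBatchSize` determine how many batches one pass over the data produces. The arithmetic below uses illustrative batch sizes, not the package defaults. Whether a final, smaller batch is kept or dropped is not stated in this documentation, so the hypothetical `stepsPerEpoch` helper takes it as a parameter.

```go
package main

import "fmt"

// Values copied from the constants above.
const (
	TrainExamples = 60000
	TestExamples  = 10000
)

// stepsPerEpoch returns how many batches one pass over the data yields.
// dropRemainder decides whether a final, smaller batch is discarded.
func stepsPerEpoch(examples, batchSize int, dropRemainder bool) int {
	steps := examples / batchSize
	if !dropRemainder && examples%batchSize != 0 {
		steps++
	}
	return steps
}

func main() {
	// Illustrative sizes, not the package defaults.
	batchSize, evalBatchSize := 64, 1000
	fmt.Println(stepsPerEpoch(TrainExamples, batchSize, true))     // 937
	fmt.Println(stepsPerEpoch(TrainExamples, batchSize, false))    // 938
	fmt.Println(stepsPerEpoch(TestExamples, evalBatchSize, false)) // 10
}
```

With a batch size of 64, the 60000 training examples leave a remainder of 32, so keeping the last batch adds one extra step per epoch.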
func NewDatasetsConfigurationFromContext
func NewDatasetsConfigurationFromContext(ctx *context.Context, dataDir string) *DatasetsConfiguration
NewDatasetsConfigurationFromContext creates a DatasetsConfiguration based on the hyperparameters set in the context.
type Image
Image represents an MNIST image. It is an array of bytes, one per pixel, holding the grayscale intensity: 0 is black (the background) and 255 is white (the digit color).
var (
AssertImageIsImageImage *Image
)
func (Image) ColorModel
ColorModel implements the image.Image interface.
Directories
| Synopsis |
|---|
| Package classifier is an MNIST-based digit classifier. |
| Demo for the mnist library. |