cifar

package
v0.14.0
Warning

This package is not in the latest version of its module.

Published: Oct 24, 2024 License: Apache-2.0 Imports: 31 Imported by: 0

Documentation

Overview

Package cifar provides a library of tools to download and manipulate the Cifar-10 and Cifar-100 datasets. More information about them is available at https://www.cs.toronto.edu/~kriz/cifar.html


Constants

const (
	C10Url     = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
	C10TarName = "cifar-10-binary.tar.gz"
	C10SubDir  = "cifar-10-batches-bin"

	C100Url     = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
	C100TarName = "cifar-100-binary.tar.gz"
	C100SubDir  = "cifar-100-binary"

	// NumExamples is the total number of examples, including training and testing.
	// The value is the same for both Cifar-10 and Cifar-100.
	NumExamples = 60000

	// NumTrainExamples is the number of examples reserved for training: the first ones.
	// The value is the same for both Cifar-10 and Cifar-100.
	NumTrainExamples = 50000

	// NumTestExamples is the number of examples reserved for testing: the last ones.
	// The value is the same for both Cifar-10 and Cifar-100.
	NumTestExamples = 10000
)
const (
	Width  int = 32
	Height int = 32
	Depth  int = 3
)

Width, Height and Depth are the dimensions of the images, the same for Cifar-10 and Cifar-100.
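These dimensions fix the size of every record in the binary files. The worked arithmetic below is a sketch, assuming the layout described on the CIFAR page (one label byte per CIFAR-10 record, a coarse and a fine label byte per CIFAR-100 record, each followed by the 3072 pixel bytes); the helper functions are hypothetical and the constants are local copies so the block compiles on its own.

```go
package main

import "fmt"

// Local copies of the package constants so the sketch compiles on its own.
const (
	Width              = 32
	Height             = 32
	Depth              = 3
	C10ExamplesPerFile = 10000
)

// bytesPerImage is one byte per channel value: 32 * 32 * 3.
func bytesPerImage() int { return Width * Height * Depth }

// c10RecordSize assumes one label byte followed by the image bytes.
func c10RecordSize() int { return 1 + bytesPerImage() }

// c100RecordSize assumes a coarse and a fine label byte followed by the image bytes.
func c100RecordSize() int { return 2 + bytesPerImage() }

func main() {
	fmt.Println(bytesPerImage())                      // 3072
	fmt.Println(c10RecordSize())                      // 3073
	fmt.Println(c10RecordSize() * C10ExamplesPerFile) // 30730000 bytes per CIFAR-10 batch file
	fmt.Println(c100RecordSize())                     // 3074
}
```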

const C10ExamplesPerFile = 10000
const ParamCNNNormalization = "cnn_normalization"

Variables

var (
	C10Labels = []string{"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"}

	C100CoarseLabels = []string{"aquatic_mammals", "fish", "flowers", "food_containers", "fruit_and_vegetables",
		"household_electrical_devices", "household_furniture", "insects", "large_carnivores",
		"large_man-made_outdoor_things", "large_natural_outdoor_scenes", "large_omnivores_and_herbivores",
		"medium_mammals", "non-insect_invertebrates", "people", "reptiles", "small_mammals", "trees", "vehicles_1",
		"vehicles_2"}
	C100FineLabels = []string{"apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle",
		"bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle",
		"chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
		"dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard", "lamp",
		"lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse",
		"mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree", "plain",
		"plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal",
		"shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower",
		"sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulip",
		"turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"}
)
var (
	// DType used in the model.
	DType = dtypes.Float32

	// C10ValidModels is the list of model types supported.
	C10ValidModels = []string{"fnn", "kan", "cnn"}

	// ParamsExcludedFromSaving is the list of parameters (see CreateDefaultContext) that shouldn't be saved
	// along with the model checkpoints, and may be overwritten in further training sessions.
	ParamsExcludedFromSaving = []string{
		"data_dir", "train_steps", "num_checkpoints", "plots",
	}
)

Backend is created once and reused if train is called multiple times.
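To make the role of ParamsExcludedFromSaving and C10ValidModels concrete, here is a minimal sketch that filters a hyperparameter map before it would be written to a checkpoint and validates a model name. It assumes hyperparameters can be viewed as a plain `map[string]any`; the helpers `paramsToSave` and `validateModel` and the parameter names `model` and `learning_rate` are hypothetical, not the package's API.

```go
package main

import (
	"fmt"
	"slices"
)

// Local copies of the package variables.
var (
	C10ValidModels           = []string{"fnn", "kan", "cnn"}
	ParamsExcludedFromSaving = []string{"data_dir", "train_steps", "num_checkpoints", "plots"}
)

// paramsToSave (hypothetical) returns a copy of params without the entries
// listed in ParamsExcludedFromSaving.
func paramsToSave(params map[string]any) map[string]any {
	out := make(map[string]any, len(params))
	for k, v := range params {
		if slices.Contains(ParamsExcludedFromSaving, k) {
			continue // session-specific, not part of the model's identity
		}
		out[k] = v
	}
	return out
}

// validateModel (hypothetical) rejects model types not in C10ValidModels.
func validateModel(model string) error {
	if !slices.Contains(C10ValidModels, model) {
		return fmt.Errorf("unknown model %q, valid values are %v", model, C10ValidModels)
	}
	return nil
}

func main() {
	params := map[string]any{"data_dir": "~/work/cifar", "train_steps": 3000, "model": "cnn", "learning_rate": 0.001}
	fmt.Println(paramsToSave(params)) // map[learning_rate:0.001 model:cnn]
	fmt.Println(validateModel("cnn"), validateModel("rnn") != nil)
}
```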

Functions

func C10ConvolutionModelGraph added in v0.11.0

func C10ConvolutionModelGraph(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node

C10ConvolutionModelGraph implements train.ModelFn and returns the logit Node, given the input image. It's a straightforward CNN (Convolutional Neural Network) model.

This is modeled after the Keras example in Kaggle: https://www.kaggle.com/code/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy (Thanks @ektasharma)

func C10PlainModelGraph added in v0.11.0

func C10PlainModelGraph(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node

C10PlainModelGraph implements train.ModelFn and returns the logit Node, given the input image. It's a basic FNN (Feedforward Neural Network) with no convolutions, meant only as an example.
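Both model functions return logits: one unnormalized score per class. Once one example's logits have been copied into a Go slice (how you fetch them from the graph is outside this sketch), the predicted class is simply the argmax, and its name is the matching entry of C10Labels. Softmax is not needed for this step because it does not change which score is largest. `argmax` and `predictLabel` are hypothetical helpers.

```go
package main

import "fmt"

// Local copy of the package variable.
var C10Labels = []string{"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"}

// argmax returns the index of the largest value; ties go to the lowest index.
func argmax(logits []float32) int {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return best
}

// predictLabel maps one example's logits (one per class) to a class name.
func predictLabel(logits []float32) (string, error) {
	if len(logits) != len(C10Labels) {
		return "", fmt.Errorf("expected %d logits, got %d", len(C10Labels), len(logits))
	}
	return C10Labels[argmax(logits)], nil
}

func main() {
	logits := []float32{0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, 0.7, -2.0}
	label, err := predictLabel(logits)
	fmt.Println(label, err) // cat <nil>
}
```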

func ConvertToGoImage

func ConvertToGoImage(images *tensors.Tensor, exampleNum int) *image.NRGBA

func CreateDatasets added in v0.11.0

func CreateDatasets(backend backends.Backend, dataDir string, batchSize, evalBatchSize int) (trainDS, trainEvalDS, validationEvalDS train.Dataset)

func DownloadCifar10

func DownloadCifar10(baseDir string) error

func DownloadCifar100

func DownloadCifar100(baseDir string) error

func NewDataset

func NewDataset(backend backends.Backend, name, baseDir string, source DataSource, dtype dtypes.DType, partition Partition) *data.InMemoryDataset

NewDataset returns a Dataset for the given data source and partition, which implements train.Dataset and hence can be used by train.Trainer methods.

It automatically downloads the data from the web and loads it into memory, if that hasn't been done yet. The result is cached, so multiple Datasets can be created without extra time or memory cost.

func ResetCache

func ResetCache()
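The caching described for NewDataset can be pictured as a mutex-protected map keyed by what was loaded; ResetCache, judging by its name, presumably empties that cache so the next call reloads from disk. The sketch below is an assumption about the general technique, not the package's code: `cacheKey`, `loadCached`, `resetCache` and the `loads` counter are hypothetical, and a real loader would read the batch files instead of allocating a slice.

```go
package main

import (
	"fmt"
	"sync"
)

// cacheKey identifies one loaded dataset; the fields are hypothetical.
type cacheKey struct {
	source  int // e.g. C10 or C100
	baseDir string
	dtype   string
}

var (
	cacheMu sync.Mutex
	cache   = map[cacheKey][]float32{}
	loads   int // counts real loads, to show the cache at work
)

// loadCached returns the cached data for key, calling load only on the first request.
// The lock is held during load, so concurrent first requests wait instead of loading twice.
func loadCached(key cacheKey, load func() []float32) []float32 {
	cacheMu.Lock()
	defer cacheMu.Unlock()
	if data, ok := cache[key]; ok {
		return data
	}
	data := load()
	loads++
	cache[key] = data
	return data
}

// resetCache drops every cached entry so the next request reloads.
func resetCache() {
	cacheMu.Lock()
	defer cacheMu.Unlock()
	cache = map[cacheKey][]float32{}
}

func main() {
	key := cacheKey{source: 0, baseDir: "~/work/cifar", dtype: "float32"}
	load := func() []float32 { return make([]float32, 4) } // stands in for reading the files
	loadCached(key, load)
	loadCached(key, load)
	fmt.Println(loads) // 1
	resetCache()
	loadCached(key, load)
	fmt.Println(loads) // 2
}
```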

func TrainCifar10Model added in v0.11.0

func TrainCifar10Model(ctx *context.Context, dataDir, checkpointPath string, evaluateOnEnd bool, verbosity int, paramsSet []string)

TrainCifar10Model trains a Cifar-10 model with the hyperparameters given in ctx.

Types

type DataSource

type DataSource int

DataSource refers to Cifar-10 (C10) or Cifar-100 (C100).

const (
	C10 DataSource = iota
	C100
)
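Each DataSource corresponds to one set of the download constants above. The sketch below maps a DataSource to its URL, the sub-directory the archive unpacks into, and the binary files found there (as distributed on the CIFAR page), listed training files first so that concatenating them yields the documented order: first 50k for training, last 10k for testing. `sourceInfo` is a hypothetical helper, and the file lists come from the public archives, not from this package.

```go
package main

import (
	"fmt"
	"path/filepath"
)

type DataSource int

const (
	C10 DataSource = iota
	C100
)

// sourceInfo (hypothetical) maps a DataSource to its download URL, the
// sub-directory created when unpacking, and its binary files, training first.
func sourceInfo(s DataSource) (url, subDir string, files []string, err error) {
	switch s {
	case C10:
		return "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", "cifar-10-batches-bin",
			[]string{"data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin",
				"data_batch_4.bin", "data_batch_5.bin", "test_batch.bin"}, nil
	case C100:
		return "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz", "cifar-100-binary",
			[]string{"train.bin", "test.bin"}, nil
	}
	return "", "", nil, fmt.Errorf("unknown DataSource %d", s)
}

func main() {
	_, subDir, files, err := sourceInfo(C10)
	if err != nil {
		panic(err)
	}
	for _, f := range files {
		fmt.Println(filepath.Join("/data", subDir, f))
	}
}
```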

type ImagesAndLabels

type ImagesAndLabels struct {
	// contains filtered or unexported fields
}

type Partition

type Partition int

Partition refers to the train or test partitions of the datasets.

const (
	Train Partition = iota
	Test
)

type PartitionedImagesAndLabels added in v0.5.0

type PartitionedImagesAndLabels [2]ImagesAndLabels

PartitionedImagesAndLabels holds one set of images and labels for each partition (Train, Test).
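Because Partition is an integer type starting at Train = 0, it can index a `[2]` array directly, which is how a PartitionedImagesAndLabels value is addressed. The sketch below splits per-example labels this way, using the documented ordering (NumTrainExamples first, then NumTestExamples); `exampleRange` and `labelsByPartition` are hypothetical helpers working on plain slices rather than tensors.

```go
package main

import "fmt"

const (
	NumTrainExamples = 50000
	NumTestExamples  = 10000
)

type Partition int

const (
	Train Partition = iota
	Test
)

// exampleRange (hypothetical) returns the half-open range [start, end) of
// partition p within all 60000 examples: training examples first, then testing.
func exampleRange(p Partition) (start, end int) {
	if p == Train {
		return 0, NumTrainExamples
	}
	return NumTrainExamples, NumTrainExamples + NumTestExamples
}

// labelsByPartition splits per-example labels into a [2] array indexed by
// Partition, the same indexing PartitionedImagesAndLabels uses.
func labelsByPartition(all []int64) ([2][]int64, error) {
	var out [2][]int64
	if len(all) != NumTrainExamples+NumTestExamples {
		return out, fmt.Errorf("got %d labels, want %d", len(all), NumTrainExamples+NumTestExamples)
	}
	for _, p := range []Partition{Train, Test} {
		start, end := exampleRange(p)
		out[p] = all[start:end] // shares memory with all, no copy
	}
	return out, nil
}

func main() {
	parts, err := labelsByPartition(make([]int64, NumTrainExamples+NumTestExamples))
	if err != nil {
		panic(err)
	}
	fmt.Println(len(parts[Train]), len(parts[Test])) // 50000 10000
}
```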

func LoadCifar10

func LoadCifar10(backend backends.Backend, baseDir string, dtype dtypes.DType) (partitioned PartitionedImagesAndLabels)

LoadCifar10 loads the data into 2 tensors: images of the given dtype, shaped [NumExamples=60000, Height=32, Width=32, Depth=3], and Int64 labels shaped [NumExamples=60000, 1]. The first 50k examples are for training and the last 10k for testing. Only Float32 and Float64 dtypes are supported for now.
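The CIFAR-10 binary files store each image channel by channel (all red values, then green, then blue, each plane row-major), while the images tensor above is shaped [..., Height, Width, Depth], with the channel last. A minimal sketch of decoding one record into that order follows, assuming the one-label-byte record layout from the CIFAR page. `decodeC10Record` is hypothetical, and it keeps raw 0..255 values as float32; whether the package rescales pixel values is not stated here.

```go
package main

import "fmt"

const (
	Width      = 32
	Height     = 32
	Depth      = 3
	planeSize  = Width * Height      // bytes per color channel
	recordSize = 1 + Depth*planeSize // label byte + R, G and B planes
)

// decodeC10Record converts one CIFAR-10 binary record into its label and the
// pixels flattened in [Height][Width][Depth] order, as float32 values 0..255.
func decodeC10Record(rec []byte) (label int64, hwc []float32, err error) {
	if len(rec) != recordSize {
		return 0, nil, fmt.Errorf("record has %d bytes, want %d", len(rec), recordSize)
	}
	label = int64(rec[0])
	pixels := rec[1:]
	hwc = make([]float32, Height*Width*Depth)
	for c := 0; c < Depth; c++ {
		for y := 0; y < Height; y++ {
			for x := 0; x < Width; x++ {
				// Source is channel-major (CHW); destination is channel-last (HWC).
				hwc[(y*Width+x)*Depth+c] = float32(pixels[c*planeSize+y*Width+x])
			}
		}
	}
	return label, hwc, nil
}

func main() {
	rec := make([]byte, recordSize)
	rec[0] = 3                // "cat" in C10Labels
	rec[1] = 255              // red value of pixel (x=0, y=0)
	rec[1+planeSize] = 128    // green value of pixel (x=0, y=0)
	rec[1+2*planeSize+1] = 64 // blue value of pixel (x=1, y=0)
	label, hwc, err := decodeC10Record(rec)
	if err != nil {
		panic(err)
	}
	fmt.Println(label, hwc[0], hwc[1], hwc[2], hwc[5]) // 3 255 128 0 64
}
```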

func LoadCifar100

func LoadCifar100(backend backends.Backend, baseDir string, dtype dtypes.DType) (partitioned PartitionedImagesAndLabels)

LoadCifar100 loads the data into 2 tensors: images of the given dtype, shaped [NumExamples=60000, Height=32, Width=32, Depth=3], and Int64 labels shaped [NumExamples=60000, 1]. The first 50k examples are for training and the last 10k for testing. Only Float32 and Float64 dtypes are supported for now.
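CIFAR-100 records differ from CIFAR-10 ones only in their header: a coarse label byte (an index into C100CoarseLabels) and a fine label byte (an index into C100FineLabels) precede the same 3072 pixel bytes. Since the labels tensor above has a single column, the loader keeps one label per example; which one is not stated here, so the sketch below simply hands both to a callback. `forEachC100Record` is a hypothetical helper working on the raw contents of `train.bin` or `test.bin`.

```go
package main

import "fmt"

const (
	pixelsPerImage = 32 * 32 * 3        // Height * Width * Depth
	c100RecordSize = 2 + pixelsPerImage // coarse label byte, fine label byte, pixels
	numCoarse      = 20                 // len(C100CoarseLabels)
	numFine        = 100                // len(C100FineLabels)
)

// forEachC100Record walks a train.bin or test.bin buffer record by record, passing
// the coarse label, the fine label and the channel-major pixel bytes to fn.
func forEachC100Record(data []byte, fn func(coarse, fine int64, pixels []byte)) error {
	if len(data)%c100RecordSize != 0 {
		return fmt.Errorf("%d bytes is not a multiple of the %d-byte record size", len(data), c100RecordSize)
	}
	for off := 0; off < len(data); off += c100RecordSize {
		rec := data[off : off+c100RecordSize]
		coarse, fine := int64(rec[0]), int64(rec[1])
		if coarse >= numCoarse || fine >= numFine {
			return fmt.Errorf("record at offset %d has labels out of range: coarse=%d fine=%d", off, coarse, fine)
		}
		fn(coarse, fine, rec[2:])
	}
	return nil
}

func main() {
	data := make([]byte, 2*c100RecordSize) // two zero-filled records
	data[c100RecordSize], data[c100RecordSize+1] = 19, 99
	err := forEachC100Record(data, func(coarse, fine int64, pixels []byte) {
		fmt.Println(coarse, fine, len(pixels))
	})
	fmt.Println(err)
}
```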

Directories

Path Synopsis
CIFAR-10 demo trainer.
CIFAR-10 demo trainer.
