model

package
v1.38.1 Latest
Published: Mar 30, 2026 | License: Apache-2.0 | Imports: 15 | Imported by: 1

Documentation

Overview

Package model provides adapter implementations for bridging existing and new model interfaces.

Package model provides the abstraction layer for representing, loading, serializing, and managing neural network models in the Zerfoo framework. (Stability: stable)

Model Representation

The core Model struct pairs a token embedding layer with a computation graph. Call NewModel to create one, then invoke its Forward method to run the embedding lookup followed by a full graph evaluation.

The ModelInstance interface generalises this to any implementation that supports forward inference, backpropagation, parameter access, and training/inference mode toggling. StandardModelInstance adapts Model to this interface and is the default implementation used throughout Zerfoo.
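The relationship between Model and its adapter can be sketched without the real tensor and graph types. The following is a minimal, self-contained illustration, not the package's implementation: the reduced Model, Instance, and StandardInstance types, the float64 "embedding table", and the sum function standing in for the graph are all assumptions made for the example.

```go
package main

import "fmt"

// Model is a hypothetical, reduced stand-in for model.Model: an embedding
// table plus a "graph" collapsed into a single function.
type Model struct {
	Embedding [][]float64                      // one row per token ID
	Graph     func(embedded []float64) float64 // stand-in for the computation graph
}

// Forward looks up the embedding for a token ID, then evaluates the graph.
func (m *Model) Forward(token int) (float64, error) {
	if token < 0 || token >= len(m.Embedding) {
		return 0, fmt.Errorf("token %d out of range [0, %d)", token, len(m.Embedding))
	}
	return m.Graph(m.Embedding[token]), nil
}

// Instance is a hypothetical, reduced stand-in for ModelInstance.
type Instance interface {
	Forward(token int) (float64, error)
	SetTrainingMode(training bool)
	IsTraining() bool
}

// StandardInstance adapts Model to Instance by delegating Forward and
// tracking the training flag itself.
type StandardInstance struct {
	model    *Model
	training bool
}

func NewStandardInstance(m *Model) *StandardInstance { return &StandardInstance{model: m} }

func (s *StandardInstance) Forward(token int) (float64, error) { return s.model.Forward(token) }
func (s *StandardInstance) SetTrainingMode(training bool)      { s.training = training }
func (s *StandardInstance) IsTraining() bool                   { return s.training }

// sum is the toy "graph": it adds up the embedded vector.
func sum(xs []float64) float64 {
	total := 0.0
	for _, x := range xs {
		total += x
	}
	return total
}

func main() {
	m := &Model{Embedding: [][]float64{{1, 2}, {3, 4}}, Graph: sum}
	var inst Instance = NewStandardInstance(m)
	inst.SetTrainingMode(false)
	out, err := inst.Forward(1)
	fmt.Println(out, err, inst.IsTraining())
}
```

Callers that depend only on the interface can then swap in another implementation without touching the code that drives Forward.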

Provider and Registry

ModelProvider is the factory interface for creating model instances. StandardModelProvider is the built-in implementation that creates models from an existing computation graph.

ModelRegistry is a thread-safe, generic registry that stores factory functions for every pluggable component kind: providers, serializers, loaders, exporters, validators, and optimizers. Pre-instantiated registries for common numeric types are available as Float32ModelRegistry and Float64ModelRegistry. Components are registered by name and retrieved via their corresponding Get* methods (e.g. ModelRegistry.GetModelProvider).
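The register-by-name, create-on-lookup pattern described above can be sketched for a single component kind. This is an illustrative reduction, not the ModelRegistry source: the Registry and Factory names are hypothetical, the factory signature omits the context argument, and rejecting duplicate names is an assumption (the documented Register* methods return an error, but the conditions are not stated here).

```go
package main

import (
	"fmt"
	"sync"
)

// Factory is a hypothetical factory signature: it builds a component of
// type C from a free-form configuration map.
type Factory[C any] func(config map[string]any) (C, error)

// Registry is a hypothetical, reduced registry for one component kind.
type Registry[C any] struct {
	mu        sync.RWMutex
	factories map[string]Factory[C]
}

func NewRegistry[C any]() *Registry[C] {
	return &Registry[C]{factories: make(map[string]Factory[C])}
}

// Register stores a factory under name. Rejecting duplicates is an
// assumption made for this sketch.
func (r *Registry[C]) Register(name string, f Factory[C]) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, exists := r.factories[name]; exists {
		return fmt.Errorf("component %q already registered", name)
	}
	r.factories[name] = f
	return nil
}

// Get looks up the factory and uses it to create a fresh instance.
// The lock is released before the factory runs so that slow factories
// do not block other readers or writers.
func (r *Registry[C]) Get(name string, config map[string]any) (C, error) {
	r.mu.RLock()
	f, ok := r.factories[name]
	r.mu.RUnlock()
	if !ok {
		var zero C
		return zero, fmt.Errorf("component %q not registered", name)
	}
	return f(config)
}

func main() {
	reg := NewRegistry[string]()
	if err := reg.Register("greeter", func(cfg map[string]any) (string, error) {
		return fmt.Sprintf("hello, %v", cfg["who"]), nil
	}); err != nil {
		fmt.Println("register:", err)
		return
	}
	s, err := reg.Get("greeter", map[string]any{"who": "zerfoo"})
	fmt.Println(s, err)
}
```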

Layer Builder Registry

RegisterLayer and GetLayerBuilder manage a global map from op-type strings to LayerBuilder functions. During GGUF model loading (see package github.com/zerfoo/zerfoo/inference), the loader looks up each operation by its op_type and calls the corresponding builder to reconstruct the graph node with the correct parameters and attributes.
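One way a global op-type map can serve generic builder functions is to store each builder untyped and recover its type on lookup. The sketch below assumes that approach; it is not taken from the Zerfoo source. Numeric, Node, Builder, Register, Get, and the "Scale" op are hypothetical stand-ins, and the builder signature is reduced to a name and an attribute map.

```go
package main

import (
	"fmt"
	"sync"
)

// Numeric is a hypothetical stand-in for tensor.Numeric.
type Numeric interface{ ~float32 | ~float64 }

// Node is a hypothetical stand-in for graph.Node[T].
type Node[T Numeric] interface{ Apply(x T) T }

// Builder is a reduced stand-in for LayerBuilder[T].
type Builder[T Numeric] func(name string, attributes map[string]any) (Node[T], error)

var (
	mu       sync.RWMutex
	builders = map[string]any{} // op type -> Builder[T], stored untyped
)

// Register records a builder for opType, replacing any earlier entry.
func Register[T Numeric](opType string, b Builder[T]) {
	mu.Lock()
	defer mu.Unlock()
	builders[opType] = b
}

// Get returns the builder for opType if one exists for numeric type T.
func Get[T Numeric](opType string) (Builder[T], error) {
	mu.RLock()
	defer mu.RUnlock()
	raw, ok := builders[opType]
	if !ok {
		return nil, fmt.Errorf("no builder registered for op_type %q", opType)
	}
	b, ok := raw.(Builder[T])
	if !ok {
		return nil, fmt.Errorf("builder for op_type %q uses a different numeric type", opType)
	}
	return b, nil
}

// scale is a toy layer that multiplies its input by a constant factor.
type scale[T Numeric] struct{ factor T }

func (s scale[T]) Apply(x T) T { return s.factor * x }

// Registration happens at initialization time, as the package recommends.
func init() {
	Register[float32]("Scale", func(name string, attrs map[string]any) (Node[float32], error) {
		f, ok := attrs["factor"].(float64)
		if !ok {
			return nil, fmt.Errorf("%s: missing float attribute %q", name, "factor")
		}
		return scale[float32]{factor: float32(f)}, nil
	})
}

func main() {
	build, err := Get[float32]("Scale")
	if err != nil {
		fmt.Println(err)
		return
	}
	node, err := build("scale0", map[string]any{"factor": 2.0})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(node.Apply(3))
}
```

A loader following this pattern would read each op_type from the file, call the lookup, and pass the stored parameters and attributes to the returned builder.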

Parameter Resolution

ParamResolver maps architecture-specific weight names to canonical names so that the same graph-building code works across architectures. Call NewParamResolver with an architecture string (e.g. "phi") to obtain the appropriate resolver, then use ResolveAll to produce a parameter map that supports lookup by both original and canonical names.

Serialization and Export

ModelSerializer, ModelLoader, and ModelExporter define generic interfaces for model persistence. The Exporter interface provides a simpler single-method contract for writing a Model to a file path.

Validation and Optimization

ModelValidator checks model correctness, including graph consistency, parameter integrity, and input shape compatibility. BasicModelValidator is the default implementation. ModelOptimizer applies performance or memory optimizations to a model instance.

Memory-Mapped File Access

MmapReader memory-maps a model file for zero-copy access to its contents. It is used during GGUF loading to avoid buffering large weight tensors into heap memory.

Integration

Models built by this package are consumed by two other packages. The inference pipeline (package github.com/zerfoo/zerfoo/inference) loads GGUF files and constructs architecture-specific graphs. The text generation layer (package github.com/zerfoo/zerfoo/generate) drives autoregressive token generation over a model's forward pass.

Package model provides generic interfaces for model management and abstraction.

Package model provides a comprehensive model registry for managing pluggable model components.

Package model provides the core structures and loading mechanisms for Zerfoo models.

Constants

This section is empty.

Variables

var (
	Float32ModelRegistry = NewModelRegistry[float32]()
	Float64ModelRegistry = NewModelRegistry[float64]()
)

Float32ModelRegistry and Float64ModelRegistry are global registry instances for common numeric types.

Functions

func RegisterLayer

func RegisterLayer[T tensor.Numeric](opType string, builder LayerBuilder[T])

RegisterLayer adds a new layer builder to the registry. It is intended to be called at initialization time (e.g., in an init() function).

func ResolveAll added in v0.2.1

func ResolveAll[V any](r ParamResolver, params map[string]V) map[string]V

ResolveAll takes a resolver and a map keyed by model-specific names, and returns a new map containing both the original names and any canonical aliases produced by the resolver. This allows parameter lookups by either the original or canonical name.
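The behaviour described for ResolveAll can be sketched as follows. This is an illustration under stated assumptions rather than the package's code: resolveAll is a local re-implementation of the documented contract, prefixResolver and the parameter names are hypothetical, and letting an original name win when an alias collides with it is a choice made for the sketch.

```go
package main

import (
	"fmt"
	"strings"
)

// ParamResolver mirrors the interface documented in this package.
type ParamResolver interface {
	Resolve(name string) string
}

// prefixResolver is a hypothetical resolver that rewrites one name prefix.
type prefixResolver struct{ from, to string }

func (p prefixResolver) Resolve(name string) string {
	if strings.HasPrefix(name, p.from) {
		return p.to + strings.TrimPrefix(name, p.from)
	}
	return name // no mapping applies: return the input unchanged
}

// resolveAll returns a new map holding every original name plus any
// canonical alias the resolver produces for it.
func resolveAll[V any](r ParamResolver, params map[string]V) map[string]V {
	out := make(map[string]V, len(params))
	for name, v := range params {
		out[name] = v
	}
	// Add aliases in a second pass so that an original name is never
	// overwritten by an alias (an assumption made for this sketch).
	for name, v := range params {
		canonical := r.Resolve(name)
		if _, taken := out[canonical]; !taken {
			out[canonical] = v
		}
	}
	return out
}

func main() {
	r := prefixResolver{from: "transformer.h.", to: "model.layers."}
	params := map[string]int{"transformer.h.0.weight": 1, "lm_head.weight": 2}
	resolved := resolveAll(r, params)
	fmt.Println(len(resolved), resolved["model.layers.0.weight"])
}
```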

func SetLogger added in v0.2.1

func SetLogger(l log.Logger)

SetLogger sets the package-level logger for model operations.

func UnregisterLayer added in v0.2.1

func UnregisterLayer(opType string)

UnregisterLayer removes a layer builder from the registry.

Types

type BasicModelValidator added in v0.2.1

type BasicModelValidator[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

BasicModelValidator is the default ModelValidator implementation. It provides basic model validation.

func NewBasicModelValidator added in v0.2.1

func NewBasicModelValidator[T tensor.Numeric]() *BasicModelValidator[T]

NewBasicModelValidator creates a new BasicModelValidator.

func (*BasicModelValidator[T]) GetValidatorInfo added in v0.2.1

func (v *BasicModelValidator[T]) GetValidatorInfo() ValidatorInfo

GetValidatorInfo implements ModelValidator.GetValidatorInfo.

func (*BasicModelValidator[T]) ValidateArchitecture added in v0.2.1

func (v *BasicModelValidator[T]) ValidateArchitecture(ctx context.Context, model ModelInstance[T]) error

ValidateArchitecture implements ModelValidator.ValidateArchitecture.

func (*BasicModelValidator[T]) ValidateInputs added in v0.2.1

func (v *BasicModelValidator[T]) ValidateInputs(ctx context.Context, model ModelInstance[T], inputs ...*tensor.TensorNumeric[T]) error

ValidateInputs implements ModelValidator.ValidateInputs.

func (*BasicModelValidator[T]) ValidateModel added in v0.2.1

func (v *BasicModelValidator[T]) ValidateModel(ctx context.Context, model ModelInstance[T]) (*ValidationResult, error)

ValidateModel implements ModelValidator.ValidateModel.

type Exporter

type Exporter[T tensor.Numeric] interface {
	// Export saves the given model to the specified path.
	Export(model *Model[T], path string) error
}

Exporter defines the interface for saving a zerfoo model to an external format.

type ExporterInfo added in v0.2.1

type ExporterInfo struct {
	Name             string   `json:"name"`
	Version          string   `json:"version"`
	Description      string   `json:"description"`
	SupportedFormats []string `json:"supported_formats"`
	Optimization     bool     `json:"supports_optimization"`
	Quantization     bool     `json:"supports_quantization"`
}

ExporterInfo contains metadata about a model exporter.

type LayerBuilder

type LayerBuilder[T tensor.Numeric] func(
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)

LayerBuilder is a function that constructs a graph.Node (a layer) from serialized parameters.

func GetLayerBuilder

func GetLayerBuilder[T tensor.Numeric](opType string) (LayerBuilder[T], error)

GetLayerBuilder retrieves a layer builder from the registry for a given op_type.

type LoaderInfo added in v0.2.1

type LoaderInfo struct {
	Name             string   `json:"name"`
	Version          string   `json:"version"`
	Description      string   `json:"description"`
	SupportedFormats []string `json:"supported_formats"`
	StreamingLoad    bool     `json:"supports_streaming_load"`
	LazyLoad         bool     `json:"supports_lazy_load"`
}

LoaderInfo contains metadata about a model loader.

type MmapReader added in v0.2.1

type MmapReader struct {
	// contains filtered or unexported fields
}

MmapReader memory-maps a file and provides access to its contents as a byte slice backed by the OS page cache.

func NewMmapReader added in v0.2.1

func NewMmapReader(path string) (*MmapReader, error)

NewMmapReader memory-maps the file at path and returns an MmapReader. The caller must call Close when done to release the mapping.

func (*MmapReader) Bytes added in v0.2.1

func (r *MmapReader) Bytes() []byte

Bytes returns the memory-mapped file contents.

func (*MmapReader) Close added in v0.2.1

func (r *MmapReader) Close() error

Close releases the memory mapping. It is safe to call multiple times.

type Model

type Model[T tensor.Numeric] struct {
	Embedding *embeddings.TokenEmbedding[T]
	Graph     *graph.Graph[T]
}

Model represents a complete model, including a token embedding layer and a computation graph.

func NewModel

func NewModel[T tensor.Numeric](embedding *embeddings.TokenEmbedding[T], g *graph.Graph[T]) *Model[T]

NewModel creates a Model from a token embedding layer and a computation graph.

func (*Model[T]) Forward

func (m *Model[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward performs the forward pass of the model: an embedding lookup followed by a full graph evaluation.

type ModelCapabilities added in v0.2.1

type ModelCapabilities struct {
	SupportedTypes      []string `json:"supported_types"`
	SupportedPrecisions []string `json:"supported_precisions"`
	SupportsTraining    bool     `json:"supports_training"`
	SupportsInference   bool     `json:"supports_inference"`
	SupportsBatching    bool     `json:"supports_batching"`
	SupportsStreaming   bool     `json:"supports_streaming"`
	MaxBatchSize        int      `json:"max_batch_size"`
	MaxSequenceLength   int      `json:"max_sequence_length"`
}

ModelCapabilities describes what a model provider can do.

type ModelConfig added in v0.2.1

type ModelConfig struct {
	// Core configuration
	Type         string                 `json:"type"`         // "standard", "hrm", "ensemble", etc.
	Architecture map[string]interface{} `json:"architecture"` // Architecture-specific parameters
	Parameters   map[string]interface{} `json:"parameters"`   // Model parameters

	// Behavior configuration
	TrainingMode bool `json:"training_mode"` // Whether to initialize in training mode
	BatchSize    int  `json:"batch_size"`    // Default batch size for inference

	// Format and compatibility
	InputFormat  string `json:"input_format"`  // Expected input format
	OutputFormat string `json:"output_format"` // Expected output format
	Version      string `json:"version"`       // Model format version

	// Extension point for domain-specific configuration
	Extensions map[string]interface{} `json:"extensions"` // Domain-specific extensions
}

ModelConfig configures model creation and behavior.
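Since ModelConfig carries JSON tags, a configuration can be decoded from a JSON document. The sketch below copies the struct as documented above and decodes a sample; the parseConfig helper and every value in the sample JSON are hypothetical. Note that encoding/json decodes numbers inside map[string]interface{} as float64.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// ModelConfig is copied from the type documented above.
type ModelConfig struct {
	Type         string                 `json:"type"`
	Architecture map[string]interface{} `json:"architecture"`
	Parameters   map[string]interface{} `json:"parameters"`
	TrainingMode bool                   `json:"training_mode"`
	BatchSize    int                    `json:"batch_size"`
	InputFormat  string                 `json:"input_format"`
	OutputFormat string                 `json:"output_format"`
	Version      string                 `json:"version"`
	Extensions   map[string]interface{} `json:"extensions"`
}

// parseConfig is a hypothetical helper that decodes a JSON document.
func parseConfig(data []byte) (ModelConfig, error) {
	var cfg ModelConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		return ModelConfig{}, fmt.Errorf("parse model config: %w", err)
	}
	return cfg, nil
}

func main() {
	// All values below are made up for illustration.
	raw := []byte(`{
		"type": "standard",
		"architecture": {"hidden_size": 64},
		"training_mode": false,
		"batch_size": 4,
		"version": "1"
	}`)
	cfg, err := parseConfig(raw)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Numbers in the free-form maps arrive as float64.
	hidden, _ := cfg.Architecture["hidden_size"].(float64)
	fmt.Println(cfg.Type, cfg.BatchSize, hidden)
}
```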

type ModelExporter added in v0.2.1

type ModelExporter[T tensor.Numeric] interface {
	// ExportToPath exports a model to a file path
	ExportToPath(ctx context.Context, model ModelInstance[T], path string) error

	// ExportToWriter exports a model to an io.Writer
	ExportToWriter(ctx context.Context, model ModelInstance[T], writer io.Writer) error

	// ExportToBytes exports a model to byte data
	ExportToBytes(ctx context.Context, model ModelInstance[T]) ([]byte, error)

	// SupportsFormat returns whether the exporter supports the given format
	SupportsFormat(format string) bool

	// GetExporterInfo returns metadata about this exporter
	GetExporterInfo() ExporterInfo
}

ModelExporter provides a generic interface for exporting models to various formats.

type ModelExporterFactory added in v0.2.1

type ModelExporterFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelExporter[T], error)

ModelExporterFactory creates ModelExporter instances.

type ModelInstance added in v0.2.1

type ModelInstance[T tensor.Numeric] interface {
	// Forward performs model inference
	Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Backward performs backpropagation (for training)
	Backward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) error

	// GetGraph returns the underlying computation graph
	GetGraph() *graph.Graph[T]

	// GetMetadata returns model metadata
	GetMetadata() ModelMetadata

	// Parameters returns model parameters for optimization
	Parameters() []*graph.Parameter[T]

	// SetTrainingMode sets the model to training or inference mode
	SetTrainingMode(training bool)

	// IsTraining returns whether the model is in training mode
	IsTraining() bool
}

ModelInstance represents a specific model instance with inference and training capabilities.

type ModelLoader added in v0.2.1

type ModelLoader[T tensor.Numeric] interface {
	// LoadFromPath loads a model from a file path
	LoadFromPath(ctx context.Context, path string) (ModelInstance[T], error)

	// LoadFromReader loads a model from an io.Reader
	LoadFromReader(ctx context.Context, reader io.Reader) (ModelInstance[T], error)

	// LoadFromBytes loads a model from byte data
	LoadFromBytes(ctx context.Context, data []byte) (ModelInstance[T], error)

	// SupportsFormat returns whether the loader supports the given format
	SupportsFormat(format string) bool

	// GetLoaderInfo returns metadata about this loader
	GetLoaderInfo() LoaderInfo
}

ModelLoader provides a generic interface for loading models from various sources.

type ModelLoaderFactory added in v0.2.1

type ModelLoaderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelLoader[T], error)

ModelLoaderFactory creates ModelLoader instances.

type ModelMetadata added in v0.2.1

type ModelMetadata struct {
	Name         string                 `json:"name"`
	Version      string                 `json:"version"`
	Architecture string                 `json:"architecture"`
	Framework    string                 `json:"framework"`
	CreatedAt    string                 `json:"created_at"`
	ModifiedAt   string                 `json:"modified_at"`
	Parameters   int64                  `json:"parameter_count"`
	InputShape   [][]int                `json:"input_shapes"`
	OutputShape  []int                  `json:"output_shape"`
	Tags         []string               `json:"tags"`
	Extensions   map[string]interface{} `json:"extensions"`
}

ModelMetadata contains information about a model instance.

type ModelOptimizer added in v0.2.1

type ModelOptimizer[T tensor.Numeric] interface {
	// OptimizeModel applies optimizations to improve performance
	OptimizeModel(ctx context.Context, model ModelInstance[T], config OptimizationConfig) (ModelInstance[T], error)

	// GetOptimizations returns available optimization strategies
	GetOptimizations() []OptimizationStrategy

	// GetOptimizerInfo returns metadata about this optimizer
	GetOptimizerInfo() OptimizerInfo
}

ModelOptimizer provides model optimization capabilities.

type ModelOptimizerFactory added in v0.2.1

type ModelOptimizerFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelOptimizer[T], error)

ModelOptimizerFactory creates ModelOptimizer instances.

type ModelProvider added in v0.2.1

type ModelProvider[T tensor.Numeric] interface {
	// CreateModel creates a new model instance from configuration
	CreateModel(ctx context.Context, config ModelConfig) (ModelInstance[T], error)

	// CreateFromGraph creates a model instance from an existing graph
	CreateFromGraph(ctx context.Context, g *graph.Graph[T], config ModelConfig) (ModelInstance[T], error)

	// GetCapabilities returns the capabilities supported by this provider
	GetCapabilities() ModelCapabilities

	// GetProviderInfo returns metadata about this provider
	GetProviderInfo() ProviderInfo
}

ModelProvider creates and manages model instances with pluggable architectures. This interface allows different model implementations to be used interchangeably while maintaining consistent creation and management patterns.

type ModelProviderFactory added in v0.2.1

type ModelProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelProvider[T], error)

ModelProviderFactory creates ModelProvider instances.

type ModelRegistry added in v0.2.1

type ModelRegistry[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ModelRegistry manages registered model components and provides factory functions. This registry enables runtime component selection and supports multiple implementations of each model interface.

func NewModelRegistry added in v0.2.1

func NewModelRegistry[T tensor.Numeric]() *ModelRegistry[T]

NewModelRegistry creates a new model registry.

func (*ModelRegistry[T]) Clear added in v0.2.1

func (r *ModelRegistry[T]) Clear()

Clear removes all registrations.

func (*ModelRegistry[T]) FindProviderByCapability added in v0.2.1

func (r *ModelRegistry[T]) FindProviderByCapability(ctx context.Context, requirement string) ([]string, error)

FindProviderByCapability finds providers that support specific capabilities.

func (*ModelRegistry[T]) GetAllRegistrations added in v0.2.1

func (r *ModelRegistry[T]) GetAllRegistrations() map[string][]string

GetAllRegistrations returns all registered component names by type.

func (*ModelRegistry[T]) GetModelExporter added in v0.2.1

func (r *ModelRegistry[T]) GetModelExporter(ctx context.Context, name string, config map[string]interface{}) (ModelExporter[T], error)

GetModelExporter retrieves a registered model exporter factory and creates an instance.

func (*ModelRegistry[T]) GetModelLoader added in v0.2.1

func (r *ModelRegistry[T]) GetModelLoader(ctx context.Context, name string, config map[string]interface{}) (ModelLoader[T], error)

GetModelLoader retrieves a registered model loader factory and creates an instance.

func (*ModelRegistry[T]) GetModelOptimizer added in v0.2.1

func (r *ModelRegistry[T]) GetModelOptimizer(ctx context.Context, name string, config map[string]interface{}) (ModelOptimizer[T], error)

GetModelOptimizer retrieves a registered model optimizer factory and creates an instance.

func (*ModelRegistry[T]) GetModelProvider added in v0.2.1

func (r *ModelRegistry[T]) GetModelProvider(ctx context.Context, name string, config map[string]interface{}) (ModelProvider[T], error)

GetModelProvider retrieves a registered model provider factory and creates an instance.

func (*ModelRegistry[T]) GetModelSerializer added in v0.2.1

func (r *ModelRegistry[T]) GetModelSerializer(ctx context.Context, name string, config map[string]interface{}) (ModelSerializer[T], error)

GetModelSerializer retrieves a registered model serializer factory and creates an instance.

func (*ModelRegistry[T]) GetModelValidator added in v0.2.1

func (r *ModelRegistry[T]) GetModelValidator(ctx context.Context, name string, config map[string]interface{}) (ModelValidator[T], error)

GetModelValidator retrieves a registered model validator factory and creates an instance.

func (*ModelRegistry[T]) ListModelExporters added in v0.2.1

func (r *ModelRegistry[T]) ListModelExporters() []string

ListModelExporters returns all registered model exporter names.

func (*ModelRegistry[T]) ListModelLoaders added in v0.2.1

func (r *ModelRegistry[T]) ListModelLoaders() []string

ListModelLoaders returns all registered model loader names.

func (*ModelRegistry[T]) ListModelOptimizers added in v0.2.1

func (r *ModelRegistry[T]) ListModelOptimizers() []string

ListModelOptimizers returns all registered model optimizer names.

func (*ModelRegistry[T]) ListModelProviders added in v0.2.1

func (r *ModelRegistry[T]) ListModelProviders() []string

ListModelProviders returns all registered model provider names.

func (*ModelRegistry[T]) ListModelSerializers added in v0.2.1

func (r *ModelRegistry[T]) ListModelSerializers() []string

ListModelSerializers returns all registered model serializer names.

func (*ModelRegistry[T]) ListModelValidators added in v0.2.1

func (r *ModelRegistry[T]) ListModelValidators() []string

ListModelValidators returns all registered model validator names.

func (*ModelRegistry[T]) RegisterModelExporter added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelExporter(name string, factory ModelExporterFactory[T]) error

RegisterModelExporter registers a model exporter factory.

func (*ModelRegistry[T]) RegisterModelLoader added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelLoader(name string, factory ModelLoaderFactory[T]) error

RegisterModelLoader registers a model loader factory.

func (*ModelRegistry[T]) RegisterModelOptimizer added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelOptimizer(name string, factory ModelOptimizerFactory[T]) error

RegisterModelOptimizer registers a model optimizer factory.

func (*ModelRegistry[T]) RegisterModelProvider added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelProvider(name string, factory ModelProviderFactory[T]) error

RegisterModelProvider registers a model provider factory.

func (*ModelRegistry[T]) RegisterModelSerializer added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelSerializer(name string, factory ModelSerializerFactory[T]) error

RegisterModelSerializer registers a model serializer factory.

func (*ModelRegistry[T]) RegisterModelValidator added in v0.2.1

func (r *ModelRegistry[T]) RegisterModelValidator(name string, factory ModelValidatorFactory[T]) error

RegisterModelValidator registers a model validator factory.

func (*ModelRegistry[T]) Summary added in v0.2.1

func (r *ModelRegistry[T]) Summary() map[string]int

Summary returns a summary of all registered components.

func (*ModelRegistry[T]) UnregisterModelExporter added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelExporter(name string)

UnregisterModelExporter removes a model exporter registration.

func (*ModelRegistry[T]) UnregisterModelLoader added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelLoader(name string)

UnregisterModelLoader removes a model loader registration.

func (*ModelRegistry[T]) UnregisterModelOptimizer added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelOptimizer(name string)

UnregisterModelOptimizer removes a model optimizer registration.

func (*ModelRegistry[T]) UnregisterModelProvider added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelProvider(name string)

UnregisterModelProvider removes a model provider registration.

func (*ModelRegistry[T]) UnregisterModelSerializer added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelSerializer(name string)

UnregisterModelSerializer removes a model serializer registration.

func (*ModelRegistry[T]) UnregisterModelValidator added in v0.2.1

func (r *ModelRegistry[T]) UnregisterModelValidator(name string)

UnregisterModelValidator removes a model validator registration.

type ModelSerializer added in v0.2.1

type ModelSerializer[T tensor.Numeric] interface {
	// Save serializes a model to the specified path or writer
	Save(ctx context.Context, model ModelInstance[T], destination interface{}) error

	// Load deserializes a model from the specified path or reader
	Load(ctx context.Context, source interface{}) (ModelInstance[T], error)

	// GetSupportedFormats returns the file formats supported by this serializer
	GetSupportedFormats() []string

	// GetSerializerInfo returns metadata about this serializer
	GetSerializerInfo() SerializerInfo
}

ModelSerializer handles model persistence in various formats.

type ModelSerializerFactory added in v0.2.1

type ModelSerializerFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelSerializer[T], error)

ModelSerializerFactory creates ModelSerializer instances.

type ModelValidator added in v0.2.1

type ModelValidator[T tensor.Numeric] interface {
	// ValidateModel performs comprehensive model validation
	ValidateModel(ctx context.Context, model ModelInstance[T]) (*ValidationResult, error)

	// ValidateInputs checks if inputs are compatible with the model
	ValidateInputs(ctx context.Context, model ModelInstance[T], inputs ...*tensor.TensorNumeric[T]) error

	// ValidateArchitecture checks model architecture consistency
	ValidateArchitecture(ctx context.Context, model ModelInstance[T]) error

	// GetValidatorInfo returns metadata about this validator
	GetValidatorInfo() ValidatorInfo
}

ModelValidator validates model correctness and compatibility.

type ModelValidatorFactory added in v0.2.1

type ModelValidatorFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelValidator[T], error)

ModelValidatorFactory creates ModelValidator instances.

type OptimizationConfig added in v0.2.1

type OptimizationConfig struct {
	Strategies   []string               `json:"strategies"`    // Optimization strategies to apply
	TargetDevice string                 `json:"target_device"` // Target device for optimization
	Precision    string                 `json:"precision"`     // Target precision (fp32, fp16, int8)
	MaxMemory    int64                  `json:"max_memory"`    // Memory constraints
	Extensions   map[string]interface{} `json:"extensions"`    // Strategy-specific options
}

OptimizationConfig configures model optimization.

type OptimizationStrategy added in v0.2.1

type OptimizationStrategy struct {
	Name         string                 `json:"name"`
	Description  string                 `json:"description"`
	Category     string                 `json:"category"`     // "performance", "memory", "accuracy"
	Impact       string                 `json:"impact"`       // "low", "medium", "high"
	Requirements []string               `json:"requirements"` // Prerequisites for this optimization
	Options      map[string]interface{} `json:"options"`      // Strategy-specific options
}

OptimizationStrategy describes an optimization approach.

type OptimizerInfo added in v0.2.1

type OptimizerInfo struct {
	Name          string   `json:"name"`
	Version       string   `json:"version"`
	Description   string   `json:"description"`
	Strategies    []string `json:"available_strategies"`
	TargetDevices []string `json:"target_devices"`
}

OptimizerInfo contains metadata about a model optimizer.

type ParamResolver added in v0.2.1

type ParamResolver interface {
	// Resolve returns the canonical name for a model-specific parameter name.
	// Returns the input unchanged if no mapping applies.
	Resolve(name string) string
}

ParamResolver maps architecture-specific parameter names to canonical names used by Zerfoo layers during model building. Canonical names follow the Llama/Gemma convention (the most common HuggingFace naming pattern).

func NewParamResolver added in v0.2.1

func NewParamResolver(arch string) ParamResolver

NewParamResolver returns a resolver for the given architecture type. Architecture types match the model_type field from HuggingFace config.json.

type ProviderInfo added in v0.2.1

type ProviderInfo struct {
	Name         string   `json:"name"`
	Version      string   `json:"version"`
	Description  string   `json:"description"`
	SupportedOps []string `json:"supported_operations"`
	Website      string   `json:"website"`
	License      string   `json:"license"`
}

ProviderInfo contains metadata about a model provider.

type SerializerInfo added in v0.2.1

type SerializerInfo struct {
	Name             string   `json:"name"`
	Version          string   `json:"version"`
	Description      string   `json:"description"`
	SupportedFormats []string `json:"supported_formats"`
	Compression      bool     `json:"supports_compression"`
	Encryption       bool     `json:"supports_encryption"`
}

SerializerInfo contains metadata about a model serializer.

type StandardModelInstance added in v0.2.1

type StandardModelInstance[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

StandardModelInstance adapts the existing Model struct to implement the ModelInstance interface.

func NewStandardModelInstance added in v0.2.1

func NewStandardModelInstance[T tensor.Numeric](model *Model[T]) *StandardModelInstance[T]

NewStandardModelInstance creates a new StandardModelInstance adapter.

func (*StandardModelInstance[T]) Backward added in v0.2.1

func (s *StandardModelInstance[T]) Backward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) error

Backward implements ModelInstance.Backward. Exactly one tensor must be provided: the gradient of the loss with respect to the model output (the initial gradient).

func (*StandardModelInstance[T]) Forward added in v0.2.1

func (s *StandardModelInstance[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward implements ModelInstance.Forward.

func (*StandardModelInstance[T]) GetGraph added in v0.2.1

func (s *StandardModelInstance[T]) GetGraph() *graph.Graph[T]

GetGraph implements ModelInstance.GetGraph.

func (*StandardModelInstance[T]) GetMetadata added in v0.2.1

func (s *StandardModelInstance[T]) GetMetadata() ModelMetadata

GetMetadata implements ModelInstance.GetMetadata.

func (*StandardModelInstance[T]) IsTraining added in v0.2.1

func (s *StandardModelInstance[T]) IsTraining() bool

IsTraining implements ModelInstance.IsTraining.

func (*StandardModelInstance[T]) Parameters added in v0.2.1

func (s *StandardModelInstance[T]) Parameters() []*graph.Parameter[T]

Parameters implements ModelInstance.Parameters.

func (*StandardModelInstance[T]) SetTrainingMode added in v0.2.1

func (s *StandardModelInstance[T]) SetTrainingMode(training bool)

SetTrainingMode implements ModelInstance.SetTrainingMode.

type StandardModelProvider added in v0.2.1

type StandardModelProvider[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

StandardModelProvider is the built-in ModelProvider implementation. It creates models from an existing computation graph.

func NewStandardModelProvider added in v0.2.1

func NewStandardModelProvider[T tensor.Numeric]() *StandardModelProvider[T]

NewStandardModelProvider creates a new StandardModelProvider.

func (*StandardModelProvider[T]) CreateFromGraph added in v0.2.1

func (p *StandardModelProvider[T]) CreateFromGraph(ctx context.Context, g *graph.Graph[T], config ModelConfig) (ModelInstance[T], error)

CreateFromGraph implements ModelProvider.CreateFromGraph.

func (*StandardModelProvider[T]) CreateModel added in v0.2.1

func (p *StandardModelProvider[T]) CreateModel(ctx context.Context, config ModelConfig) (ModelInstance[T], error)

CreateModel implements ModelProvider.CreateModel.

func (*StandardModelProvider[T]) GetCapabilities added in v0.2.1

func (p *StandardModelProvider[T]) GetCapabilities() ModelCapabilities

GetCapabilities implements ModelProvider.GetCapabilities.

func (*StandardModelProvider[T]) GetProviderInfo added in v0.2.1

func (p *StandardModelProvider[T]) GetProviderInfo() ProviderInfo

GetProviderInfo implements ModelProvider.GetProviderInfo.

type ValidationError added in v0.2.1

type ValidationError struct {
	Type       string `json:"type"`
	Message    string `json:"message"`
	Component  string `json:"component"`
	Severity   string `json:"severity"`
	Suggestion string `json:"suggestion"`
}

ValidationError represents a validation error.

type ValidationResult added in v0.2.1

type ValidationResult struct {
	IsValid    bool                   `json:"is_valid"`
	Errors     []ValidationError      `json:"errors"`
	Warnings   []ValidationWarning    `json:"warnings"`
	Metrics    map[string]float64     `json:"metrics"`
	Summary    string                 `json:"summary"`
	Extensions map[string]interface{} `json:"extensions"`
}

ValidationResult contains model validation results.

type ValidationWarning added in v0.2.1

type ValidationWarning struct {
	Type       string `json:"type"`
	Message    string `json:"message"`
	Component  string `json:"component"`
	Suggestion string `json:"suggestion"`
}

ValidationWarning represents a validation warning.

type ValidatorInfo added in v0.2.1

type ValidatorInfo struct {
	Name        string   `json:"name"`
	Version     string   `json:"version"`
	Description string   `json:"description"`
	CheckTypes  []string `json:"check_types"`
	Strictness  string   `json:"strictness"`
}

ValidatorInfo contains metadata about a model validator.

Directories

Path         Synopsis
gguf         Package gguf provides GGUF file format parsing and writing.
hrm          Package hrm provides experimental Hierarchical Reasoning Model types.
huggingface  Package huggingface provides HuggingFace model configuration parsing.
