inference

package
v1.24.0 Latest
Published: Mar 26, 2026 | License: Apache-2.0 | Imports: 34 | Imported by: 0

Documentation

Overview

Package inference provides a high-level API for loading GGUF models and running text generation, chat, embedding, and speculative decoding with minimal boilerplate. (Stability: stable)

Loading Models

There are two entry points for loading a model:

  • Load resolves a model by name or HuggingFace repo ID, pulling it from the registry if not already cached, and returns a ready-to-use Model.
  • LoadFile loads a model directly from a local GGUF file path.

Both accept functional Option values to configure the compute device, cache directory, sequence length, and other parameters:

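ctx := context.Background()

// Resolve the alias, pull the model if it is not cached, and load it.
m, err := inference.Load("gemma-3-1b-q4",
	inference.WithDevice("cuda"),
	inference.WithMaxSeqLen(4096),
)
if err != nil {
	log.Fatal(err)
}
defer m.Close()

// Generate a completion with per-call sampling options.
text, err := m.Generate(ctx, "Explain gradient descent briefly.",
	inference.WithMaxTokens(256),
	inference.WithTemperature(0.7),
)
if err != nil {
	log.Fatal(err)
}
fmt.Println(text)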

Model Methods

A loaded Model exposes several generation methods:

  • Model.Generate produces text from a prompt and returns the full result.
  • Model.GenerateStream delivers tokens incrementally via a callback.
  • Model.GenerateBatch processes multiple prompts concurrently.
  • Model.Chat formats a slice of Message values using the model's chat template and generates a Response with token usage statistics.
  • Model.Embed returns an L2-normalized embedding vector for a text input by mean-pooling the model's token embedding table.
  • Model.SpeculativeGenerate runs speculative decoding with a smaller draft model to accelerate generation from a larger target model.

Load Options

The following Option functions configure model loading:

  • WithDevice — compute device: "cpu", "cuda", "cuda:N", "rocm", "opencl"
  • WithCacheDir — local directory for cached model files
  • WithMaxSeqLen — override the model's default maximum sequence length
  • WithRegistry — supply a custom model registry
  • WithBackend — select "tensorrt" for TensorRT-optimized inference
  • WithPrecision — set TensorRT compute precision ("fp16")
  • WithDType — set GPU compute precision ("fp16", "fp8")
  • WithKVDtype — set KV cache storage precision ("fp16")
  • WithMmap — enable memory-mapped model loading on unix

Generate Options

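The following GenerateOption functions configure sampling for generation methods:

  • WithTemperature sets the sampling temperature.
  • WithTopK sets the top-K sampling parameter.
  • WithTopP sets the top-P (nucleus) sampling parameter.
  • WithMaxTokens sets the maximum number of tokens to generate.
  • WithRepetitionPenalty sets the repetition penalty factor.
  • WithStopStrings sets strings that stop generation.
  • WithGrammar sets a grammar for constrained decoding.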

Model Aliases

Short aliases such as "gemma-3-1b-q4" and "llama-3-8b-q4" map to full HuggingFace repository IDs. Use ResolveAlias to look up the mapping and RegisterAlias to add custom aliases.

For lower-level control over text generation, KV caching, and sampling, see the github.com/zerfoo/zerfoo/generate package. For an OpenAI-compatible HTTP server built on top of this package, see github.com/zerfoo/zerfoo/serve.

Constants

This section is empty.

Variables

This section is empty.

Functions

func AutoBuild added in v1.8.0

AutoBuild reads GGUF metadata from cfg and constructs the appropriate computation graph automatically, without requiring a hand-written per-model builder. It detects architecture features from metadata and delegates to the shared buildTransformerGraph for standard decoder-only transformer architectures.

For non-transformer architectures (Mamba, Whisper, etc.) that have a registered ArchBuilder, AutoBuild falls back to that builder.

For completely unknown architectures with standard decoder-only tensor names, AutoBuild constructs a plain transformer graph.

func BuildArchGraph added in v1.5.0

func BuildArchGraph(
	arch string,
	tensors map[string]*tensor.TensorNumeric[float32],
	cfg *gguf.ModelConfig,
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

BuildArchGraph dispatches to the appropriate architecture-specific graph builder. Exported for benchmark and integration tests that construct synthetic weight maps without loading from GGUF files.

func BuildJamba added in v1.5.0

BuildJamba constructs a computation graph for the Jamba hybrid architecture.

Attention layers use tensor names:

blk.{i}.attn_norm.weight
blk.{i}.attn_q.weight, blk.{i}.attn_k.weight, blk.{i}.attn_v.weight, blk.{i}.attn_output.weight
blk.{i}.ffn_norm.weight
blk.{i}.ffn_gate.weight, blk.{i}.ffn_up.weight, blk.{i}.ffn_down.weight

SSM layers use tensor names:

blk.{i}.ssm_norm.weight
blk.{i}.ssm_in_proj.weight, blk.{i}.ssm_conv1d.weight, blk.{i}.ssm_x_proj.weight
blk.{i}.ssm_dt_proj.weight, blk.{i}.ssm_A_log, blk.{i}.ssm_D, blk.{i}.ssm_out_proj.weight

func BuildLLaVAModel added in v1.7.0

func BuildLLaVAModel(
	lc LLaVAConfig,
	tensors map[string]*tensor.TensorNumeric[float32],
	cfg *gguf.ModelConfig,
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

BuildLLaVAModel constructs the LLaVA computation graph from a weight map. Exported for benchmark and integration tests that construct synthetic weight maps.

func BuildMamba3 added in v1.5.0

BuildMamba3 constructs a computation graph for Mamba-3 from a weight map.

Expected tensor names:

token_embd.weight                 — [vocab_size, d_model]
output.weight                     — [vocab_size, d_model]
output_norm.weight                — [d_model]
mamba.{i}.norm.weight             — [d_model]
mamba.{i}.in_proj.weight          — [2*d_inner, d_model]
mamba.{i}.conv1d.weight           — [d_inner, 1, d_conv]
mamba.{i}.conv1d.bias             — [d_inner] (optional)
mamba.{i}.x_proj.weight           — [dt_rank + 2*d_state, d_inner]
mamba.{i}.dt_proj.weight          — [d_inner, dt_rank]
mamba.{i}.dt_proj.bias            — [d_inner] (optional)
mamba.{i}.A_log                   — [d_inner, d_state]
mamba.{i}.D                       — [d_inner]
mamba.{i}.out_proj.weight         — [d_model, d_inner]

func BuildMamba3MIMO added in v1.8.0

func BuildMamba3MIMO(
	mc Mamba3Config,
	tensors map[string]*tensor.TensorNumeric[float32],
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

BuildMamba3MIMO constructs a computation graph for Mamba 3 using MIMO SSM blocks with exponential-trapezoidal discretization.

Expected tensor names:

token_embd.weight                      — [vocab_size, d_model]
output.weight                          — [vocab_size, d_model]
output_norm.weight                     — [d_model]
mamba3.{i}.norm.weight                 — [d_model]
mamba3.{i}.in_proj.weight              — [2*d_inner, d_model]
mamba3.{i}.conv1d.weight               — [d_inner, 1, d_conv]
mamba3.{i}.x_proj.weight               — [dt_rank + 2*d_state*num_heads, d_inner]
mamba3.{i}.dt_proj.weight              — [d_inner, dt_rank]
mamba3.{i}.A_log.{h}                   — [head_dim, d_state] per head
mamba3.{i}.D.{h}                       — [head_dim] per head
mamba3.{i}.head_mix.weight             — [d_inner, d_inner]
mamba3.{i}.out_proj.weight             — [d_model, d_inner]

func BuildQwenVLModel added in v1.8.0

func BuildQwenVLModel(
	qc QwenVLConfig,
	tensors map[string]*tensor.TensorNumeric[float32],
	cfg *gguf.ModelConfig,
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

BuildQwenVLModel constructs the Qwen-VL computation graph from a weight map. Exported for benchmark and integration tests that construct synthetic weight maps.

func BuildRWKV added in v1.7.0

BuildRWKV constructs a computation graph for the RWKV-6/7 architecture.

Expected tensor names (GGUF RWKV convention):

token_embd.weight            — [vocab_size, hidden_size]
output.weight                — [vocab_size, hidden_size]
output_norm.weight           — [hidden_size]
output_norm.bias             — [hidden_size]
blocks.{i}.ln0.weight        — [hidden_size] (layer 0 only, pre-embedding norm)
blocks.{i}.ln0.bias          — [hidden_size] (layer 0 only)
blocks.{i}.ln1.weight        — [hidden_size] (time mixing norm)
blocks.{i}.ln1.bias          — [hidden_size]
blocks.{i}.ln2.weight        — [hidden_size] (channel mixing norm)
blocks.{i}.ln2.bias          — [hidden_size]
blocks.{i}.att.time_mix_r    — [1, 1, hidden_size]
blocks.{i}.att.time_mix_k    — [1, 1, hidden_size]
blocks.{i}.att.time_mix_v    — [1, 1, hidden_size]
blocks.{i}.att.time_mix_g    — [1, 1, hidden_size]
blocks.{i}.att.time_decay    — [num_heads, head_size]
blocks.{i}.att.time_faaaa    — [num_heads, head_size] (initial state)
blocks.{i}.att.receptance.weight — [hidden_size, hidden_size]
blocks.{i}.att.key.weight        — [hidden_size, hidden_size]
blocks.{i}.att.value.weight      — [hidden_size, hidden_size]
blocks.{i}.att.gate.weight       — [hidden_size, hidden_size]
blocks.{i}.att.output.weight     — [hidden_size, hidden_size]
blocks.{i}.att.ln_x.weight       — [hidden_size] (group norm)
blocks.{i}.att.ln_x.bias         — [hidden_size]
blocks.{i}.ffn.time_mix_k        — [1, 1, hidden_size]
blocks.{i}.ffn.time_mix_r        — [1, 1, hidden_size]
blocks.{i}.ffn.key.weight        — [ffn_size, hidden_size]
blocks.{i}.ffn.value.weight      — [hidden_size, ffn_size]
blocks.{i}.ffn.receptance.weight — [hidden_size, hidden_size]

func BuildResidualConnection added in v1.9.0

func BuildResidualConnection[T tensor.Numeric](config ResidualConfig, engine compute.Engine[T]) any

BuildResidualConnection returns a residual handler appropriate for the given config. For "standard" mode (the default), it returns nil, and callers should fall through to their existing residual-add logic. For "attnres" and "block_attnres" modes, it currently also returns nil as a placeholder; the actual handlers will be wired in once the layers/residual package ships AttnRes types.

func BuildWhisperEncoder added in v1.5.0

func BuildWhisperEncoder(
	wc WhisperConfig,
	tensors map[string]*tensor.TensorNumeric[float32],
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

BuildWhisperEncoder constructs a computation graph for Whisper encoder from a weight map. Exported for benchmark and integration tests that construct synthetic weight maps.

func ConvertGraphToTRT

func ConvertGraphToTRT(g *graph.Graph[float32], workspaceBytes int, fp16 bool, dynamicShapes *DynamicShapeConfig) (*trtConversionResult, error)

ConvertGraphToTRT walks a graph in topological order and maps each node to a TensorRT layer. Returns serialized engine bytes or an UnsupportedOpError if the graph contains operations that cannot be converted. If dynamicShapes is non-nil, an optimization profile is created with the specified min/opt/max dimensions for each input.
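
The topological walk described for ConvertGraphToTRT can be illustrated with a stdlib-only sketch. Everything here is a simplified stand-in: the node type, the supported-op table, and unsupportedOpError are hypothetical and do not reflect the package's actual graph or TensorRT types.

```go
package main

import (
	"errors"
	"fmt"
)

// node is a hypothetical simplified graph node: an op name plus the
// indices of the nodes it consumes.
type node struct {
	op     string
	inputs []int
}

// unsupportedOpError reports an op with no converter (illustrative type).
type unsupportedOpError struct{ op string }

func (e *unsupportedOpError) Error() string { return "unsupported op: " + e.op }

// supported is an assumed converter table; the real set is not documented here.
var supported = map[string]bool{"Input": true, "MatMul": true, "Add": true}

// convert visits nodes in topological order (Kahn's algorithm) and returns
// the visit order, or an error for cycles and unsupported ops.
func convert(nodes []node) ([]int, error) {
	indegree := make([]int, len(nodes))
	consumers := make([][]int, len(nodes))
	for i, n := range nodes {
		indegree[i] = len(n.inputs)
		for _, in := range n.inputs {
			consumers[in] = append(consumers[in], i)
		}
	}
	var queue, order []int
	for i, d := range indegree {
		if d == 0 {
			queue = append(queue, i)
		}
	}
	for len(queue) > 0 {
		i := queue[0]
		queue = queue[1:]
		if !supported[nodes[i].op] {
			return nil, &unsupportedOpError{op: nodes[i].op}
		}
		order = append(order, i) // a real converter would emit a layer here
		for _, c := range consumers[i] {
			if indegree[c]--; indegree[c] == 0 {
				queue = append(queue, c)
			}
		}
	}
	if len(order) != len(nodes) {
		return nil, errors.New("graph contains a cycle")
	}
	return order, nil
}

// unsupportedOp extracts the op name if err is an unsupportedOpError.
func unsupportedOp(err error) (string, bool) {
	var e *unsupportedOpError
	if errors.As(err, &e) {
		return e.op, true
	}
	return "", false
}

func main() {
	g := []node{
		{op: "Input"},
		{op: "Input"},
		{op: "MatMul", inputs: []int{0, 1}},
		{op: "Add", inputs: []int{2, 0}},
	}
	fmt.Println(convert(g))
}
```

Failing on the first unsupported op, rather than skipping it, matches the documented behavior of returning an error when the graph cannot be fully converted.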

func IsEncoderArchitecture added in v1.9.0

func IsEncoderArchitecture(arch string) bool

IsEncoderArchitecture reports whether the given architecture name is an encoder-only model (e.g., BERT, RoBERTa).

func ListArchitectures added in v1.5.0

func ListArchitectures() []string

ListArchitectures returns a sorted list of all registered architecture names.

func LoadTRTEngine

func LoadTRTEngine(key string) ([]byte, error)

LoadTRTEngine reads a serialized TensorRT engine from the cache. Returns nil, nil on cache miss (file not found).

func RegisterAlias

func RegisterAlias(shortName, repoID string)

RegisterAlias adds a custom short name -> HuggingFace repo ID mapping.

func RegisterArchitecture added in v1.5.0

func RegisterArchitecture(name string, builder ArchBuilder)

RegisterArchitecture registers an architecture builder under the given name. Names correspond to GGUF general.architecture values (e.g. "llama", "gemma"). Multiple names can map to the same builder (e.g. "gemma" and "gemma3"). Panics if name is empty or a builder is already registered for that name.
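
A minimal sketch of the registry semantics described above (panic on an empty or duplicate name, several names sharing one builder, sorted listing). The archRegistry type and the simplified builder signature are assumptions for illustration; the package's real registry is not shown in this documentation.

```go
package main

import (
	"fmt"
	"sort"
)

// builder is a simplified, hypothetical stand-in for ArchBuilder.
type builder func() string

// archRegistry maps architecture names to builders.
type archRegistry struct{ builders map[string]builder }

func newArchRegistry() *archRegistry {
	return &archRegistry{builders: map[string]builder{}}
}

// register panics on an empty name or a duplicate registration.
func (r *archRegistry) register(name string, b builder) {
	if name == "" {
		panic("architecture name must not be empty")
	}
	if _, dup := r.builders[name]; dup {
		panic("architecture already registered: " + name)
	}
	r.builders[name] = b
}

// get returns nil, false when no builder is registered.
func (r *archRegistry) get(name string) (builder, bool) {
	b, ok := r.builders[name]
	return b, ok
}

// list returns the registered names in sorted order.
func (r *archRegistry) list() []string {
	names := make([]string, 0, len(r.builders))
	for n := range r.builders {
		names = append(names, n)
	}
	sort.Strings(names)
	return names
}

// panics reports whether f panicked.
func panics(f func()) (p bool) {
	defer func() {
		if recover() != nil {
			p = true
		}
	}()
	f()
	return false
}

func main() {
	r := newArchRegistry()
	gemma := func() string { return "gemma graph" }
	r.register("llama", func() string { return "llama graph" })
	r.register("gemma", gemma)
	r.register("gemma3", gemma) // two names, one builder
	fmt.Println(r.list())
	fmt.Println(panics(func() { r.register("llama", gemma) }))
}
```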

func ResolveAlias

func ResolveAlias(name string) string

ResolveAlias returns the HuggingFace repo ID for a short alias. If the name is not an alias, it is returned unchanged.

func SaveTRTEngine

func SaveTRTEngine(key string, data []byte) error

SaveTRTEngine writes a serialized TensorRT engine to the cache directory.
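
The cache convention documented for LoadTRTEngine (nil, nil on a miss) pairs naturally with a save function. The sketch below assumes a flat directory layout and an ".engine" file extension; both are illustrative choices, not the package's actual on-disk format.

```go
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
)

// saveEngine writes data under dir using key as the file name (assumed layout).
func saveEngine(dir, key string, data []byte) error {
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}
	return os.WriteFile(filepath.Join(dir, key+".engine"), data, 0o644)
}

// loadEngine returns nil, nil when the file does not exist, so callers can
// tell a cache miss apart from a real I/O error.
func loadEngine(dir, key string) ([]byte, error) {
	data, err := os.ReadFile(filepath.Join(dir, key+".engine"))
	if errors.Is(err, fs.ErrNotExist) {
		return nil, nil
	}
	return data, err
}

// demo exercises a miss followed by a save and a hit in a temp directory.
func demo() (missWasNil bool, loaded string, err error) {
	dir, err := os.MkdirTemp("", "trt-cache-")
	if err != nil {
		return false, "", err
	}
	defer os.RemoveAll(dir)

	data, err := loadEngine(dir, "abc")
	if err != nil {
		return false, "", err
	}
	missWasNil = data == nil

	if err := saveEngine(dir, "abc", []byte("engine-bytes")); err != nil {
		return missWasNil, "", err
	}
	data, err = loadEngine(dir, "abc")
	return missWasNil, string(data), err
}

func main() {
	fmt.Println(demo())
}
```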

func TRTCacheKey

func TRTCacheKey(modelID, precision string) (string, error)

TRTCacheKey builds a deterministic cache key from model ID, precision, and GPU architecture. The key is a hex SHA-256 hash to avoid filesystem issues with long or special-character model IDs.
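
A sketch of how such a key might be derived. The field order, the NUL separator, and passing the GPU architecture as a parameter are assumptions; the real TRTCacheKey takes only modelID and precision and returns an error, presumably because it queries the GPU architecture itself.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// cacheKey hashes the three inputs into a fixed-length hex string, which is
// safe to use as a file name regardless of the characters in modelID.
func cacheKey(modelID, precision, gpuArch string) string {
	joined := strings.Join([]string{modelID, precision, gpuArch}, "\x00")
	sum := sha256.Sum256([]byte(joined))
	return hex.EncodeToString(sum[:])
}

func main() {
	// "sm_89" is an illustrative architecture string.
	fmt.Println(cacheKey("some-org/some-model", "fp16", "sm_89"))
}
```

Including the GPU architecture in the key matters because a serialized TensorRT engine is generally not portable across GPU generations.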

Types

type ArchBuilder added in v1.5.0

type ArchBuilder func(
	tensors map[string]*tensor.TensorNumeric[float32],
	cfg *gguf.ModelConfig,
	engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)

ArchBuilder builds a computation graph for a model architecture from pre-loaded GGUF tensors. It returns the graph and the embedding table tensor (needed by the generator for token lookup).

func GetArchitecture added in v1.5.0

func GetArchitecture(name string) (ArchBuilder, bool)

GetArchitecture returns the builder registered for the given architecture name. Returns nil, false if no builder is registered.

type ArchConfigRegistry

type ArchConfigRegistry struct {
	// contains filtered or unexported fields
}

ArchConfigRegistry maps model_type strings to config parsers.

func DefaultArchConfigRegistry

func DefaultArchConfigRegistry() *ArchConfigRegistry

DefaultArchConfigRegistry returns a registry with all built-in parsers registered.

func (*ArchConfigRegistry) Parse

func (r *ArchConfigRegistry) Parse(raw map[string]interface{}) (*ModelMetadata, error)

Parse dispatches to the registered parser for the model_type in raw, or falls back to generic field extraction for unknown types.
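
The dispatch-with-fallback behavior of Parse can be sketched with a small registry. The metadata struct is a two-field stand-in for ModelMetadata, and the generic extraction shown here (reading "hidden_size" as a JSON number) is an assumption about what the fallback might do.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// metadata is a reduced stand-in for ModelMetadata.
type metadata struct {
	Architecture string
	HiddenSize   int
}

type parser func(raw map[string]interface{}) (*metadata, error)

type configRegistry struct{ parsers map[string]parser }

func newConfigRegistry() *configRegistry {
	return &configRegistry{parsers: map[string]parser{}}
}

func (r *configRegistry) register(modelType string, p parser) {
	r.parsers[modelType] = p
}

// parse uses the parser registered for raw["model_type"], or falls back to
// generic field extraction when none is registered.
func (r *configRegistry) parse(raw map[string]interface{}) (*metadata, error) {
	modelType, _ := raw["model_type"].(string)
	if p, ok := r.parsers[modelType]; ok {
		return p(raw)
	}
	return genericParse(raw)
}

// genericParse reads common fields; encoding/json decodes numbers as float64.
func genericParse(raw map[string]interface{}) (*metadata, error) {
	m := &metadata{}
	m.Architecture, _ = raw["model_type"].(string)
	if v, ok := raw["hidden_size"].(float64); ok {
		m.HiddenSize = int(v)
	}
	return m, nil
}

func main() {
	r := newConfigRegistry()
	var raw map[string]interface{}
	doc := `{"model_type":"mystery","hidden_size":2048}`
	if err := json.Unmarshal([]byte(doc), &raw); err != nil {
		fmt.Println("decode error:", err)
		return
	}
	m, err := r.parse(raw)
	fmt.Println(m.Architecture, m.HiddenSize, err)
}
```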

func (*ArchConfigRegistry) Register

func (r *ArchConfigRegistry) Register(modelType string, parser ConfigParser)

Register adds a parser for the given model type.

type ConfigParser

type ConfigParser func(raw map[string]interface{}) (*ModelMetadata, error)

ConfigParser parses a raw JSON map (from config.json) into ModelMetadata.

type ConstantValueGetter

type ConstantValueGetter interface {
	GetValue() *tensor.TensorNumeric[float32]
}

ConstantValueGetter is an interface for nodes that hold constant tensor data.

type DTypeSetter

type DTypeSetter interface {
	SetDType(compute.DType)
}

DTypeSetter is implemented by engines that support setting compute precision.

type DynamicShapeConfig

type DynamicShapeConfig struct {
	// InputShapes maps input index (0-based) to its shape range.
	InputShapes []ShapeRange
}

DynamicShapeConfig specifies per-input shape ranges for TensorRT optimization profiles. When non-nil, the converter creates an optimization profile that allows variable-size inputs within the specified ranges.

type EncoderModel added in v1.9.0

type EncoderModel struct {
	// contains filtered or unexported fields
}

EncoderModel represents a loaded encoder-only model (BERT, RoBERTa, etc.). Unlike the decoder Model type, EncoderModel has no KV cache, no generator, and no autoregressive decoding loop. It runs a single forward pass over the full input sequence and returns classification logits.

func LoadEncoderFile added in v1.9.0

func LoadEncoderFile(path string, opts ...Option) (*EncoderModel, error)

LoadEncoderFile loads an encoder-only model from a GGUF file. It verifies the architecture is encoder-only and returns an EncoderModel instead of a Generator-based Model. Returns an error if the architecture is not an encoder type.

func (*EncoderModel) Close added in v1.9.0

func (m *EncoderModel) Close() error

Close releases resources held by the encoder model.

func (*EncoderModel) Config added in v1.9.0

func (m *EncoderModel) Config() *gguf.ModelConfig

Config returns the underlying model configuration.

func (*EncoderModel) Engine added in v1.9.0

func (m *EncoderModel) Engine() compute.Engine[float32]

Engine returns the compute engine.

func (*EncoderModel) Forward added in v1.9.0

func (m *EncoderModel) Forward(ctx context.Context, inputIDs []int) ([]float32, error)

Forward runs the encoder on input token IDs and returns classification logits. The input is a slice of integer token IDs. The returned slice contains logits of shape [1, numClasses] flattened to []float32.
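
Callers typically turn the flattened logits into a predicted class and a probability. The following sketch shows one way to do that with a numerically stable softmax; the classify helper is illustrative and not part of the package.

```go
package main

import (
	"fmt"
	"math"
)

// classify returns the index of the largest logit and its softmax
// probability. It expects a non-empty slice.
func classify(logits []float32) (class int, prob float64) {
	for i, l := range logits {
		if l > logits[class] {
			class = i
		}
	}
	// Subtract the maximum before exponentiating to avoid overflow.
	maxLogit := float64(logits[class])
	var sum float64
	for _, l := range logits {
		sum += math.Exp(float64(l) - maxLogit)
	}
	return class, 1 / sum
}

func main() {
	// Pretend these came from EncoderModel.Forward for a 3-class head.
	logits := []float32{0.1, 2.0, -1.0}
	fmt.Println(classify(logits))
}
```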

func (*EncoderModel) Graph added in v1.9.0

func (m *EncoderModel) Graph() *graph.Graph[float32]

Graph returns the computation graph.

func (*EncoderModel) OutputShape added in v1.9.0

func (m *EncoderModel) OutputShape() []int

OutputShape returns the expected output shape [batch, numClasses].

type GGUFModel

type GGUFModel struct {
	Config  *gguf.ModelConfig
	Tensors map[string]*tensor.TensorNumeric[float32]
	File    *gguf.File
}

GGUFModel holds a loaded GGUF model's configuration and tensors. This is an intermediate representation; full inference requires an architecture-specific graph builder to convert these into a computation graph.

func LoadGGUF

func LoadGGUF(path string) (*GGUFModel, error)

LoadGGUF loads a GGUF model file and returns its configuration and tensors. Tensor names are mapped from GGUF convention (blk.N.attn_q.weight) to Zerfoo canonical names (model.layers.N.self_attn.q_proj.weight).
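
The name mapping can be illustrated for the one pair the documentation gives. This sketch covers only that documented example; how other tensor names are mapped is not stated here, so the pass-through for everything else is an assumption.

```go
package main

import (
	"fmt"
	"regexp"
)

// attnQ matches the GGUF query-projection weight name for any layer index.
var attnQ = regexp.MustCompile(`^blk\.(\d+)\.attn_q\.weight$`)

// canonicalName rewrites the documented GGUF name to its canonical form and
// leaves every other name untouched (assumed behavior for this sketch).
func canonicalName(ggufName string) string {
	if attnQ.MatchString(ggufName) {
		return attnQ.ReplaceAllString(ggufName, "model.layers.${1}.self_attn.q_proj.weight")
	}
	return ggufName
}

func main() {
	fmt.Println(canonicalName("blk.3.attn_q.weight"))
	fmt.Println(canonicalName("token_embd.weight"))
}
```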

func (*GGUFModel) ToModelMetadata

func (m *GGUFModel) ToModelMetadata() *ModelMetadata

ToModelMetadata converts a GGUF model config to inference.ModelMetadata.

type GenerateOption

type GenerateOption func(*generate.SamplingConfig)

GenerateOption configures a generation call.

func WithGrammar

func WithGrammar(g *grammar.Grammar) GenerateOption

WithGrammar sets a grammar state machine for constrained decoding. When set, a token mask is applied at each sampling step to restrict output to tokens that are valid according to the grammar.
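
Token masking of this kind is commonly implemented by setting the logits of disallowed tokens to negative infinity before sampling. The sketch below assumes the grammar has already produced the set of allowed token IDs; the helper names are illustrative.

```go
package main

import (
	"fmt"
	"math"
)

// applyMask sets the logit of every token outside allowed to -Inf, so that
// it receives zero probability after softmax.
func applyMask(logits []float32, allowed map[int]bool) {
	negInf := float32(math.Inf(-1))
	for i := range logits {
		if !allowed[i] {
			logits[i] = negInf
		}
	}
}

// argmax returns the index of the largest logit (greedy sampling).
func argmax(logits []float32) int {
	best := 0
	for i, l := range logits {
		if l > logits[best] {
			best = i
		}
	}
	return best
}

func main() {
	logits := []float32{5, 1, 3} // token 0 is preferred but not valid here
	applyMask(logits, map[int]bool{1: true, 2: true})
	fmt.Println(argmax(logits))
}
```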

func WithMaxTokens

func WithMaxTokens(n int) GenerateOption

WithMaxTokens sets the maximum number of tokens to generate.

func WithRepetitionPenalty

func WithRepetitionPenalty(p float64) GenerateOption

WithRepetitionPenalty sets the repetition penalty factor.

func WithStopStrings

func WithStopStrings(ss ...string) GenerateOption

WithStopStrings sets strings that stop generation.
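
One plausible way to apply stop strings is to cut the output at the earliest occurrence of any of them. Whether the package excludes the stop string from the returned text is not documented here; this sketch assumes it does.

```go
package main

import (
	"fmt"
	"strings"
)

// truncateAtStop cuts text at the earliest occurrence of any stop string and
// reports whether a stop string was found. Empty stop strings are ignored.
func truncateAtStop(text string, stops ...string) (string, bool) {
	cut := -1
	for _, s := range stops {
		if s == "" {
			continue
		}
		if i := strings.Index(text, s); i >= 0 && (cut < 0 || i < cut) {
			cut = i
		}
	}
	if cut < 0 {
		return text, false
	}
	return text[:cut], true
}

func main() {
	out, stopped := truncateAtStop("Hello there.\nUser: next question", "\nUser:", "###")
	fmt.Printf("%q %v\n", out, stopped)
}
```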

func WithTemperature

func WithTemperature(t float64) GenerateOption

WithTemperature sets the sampling temperature.
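
Temperature conventionally divides the logits before the softmax: values below 1 sharpen the distribution and values above 1 flatten it. A minimal sketch of that standard formulation (not necessarily the package's exact code path):

```go
package main

import (
	"fmt"
	"math"
)

// softmaxT returns softmax(logits / t). It expects a non-empty slice and t > 0.
func softmaxT(logits []float64, t float64) []float64 {
	maxLogit := logits[0]
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}
	out := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		// Subtracting the maximum keeps the exponent from overflowing.
		out[i] = math.Exp((l - maxLogit) / t)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	logits := []float64{2, 1, 0}
	fmt.Println(softmaxT(logits, 1.0))
	fmt.Println(softmaxT(logits, 0.5)) // sharper: more mass on the top token
}
```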

func WithTopK

func WithTopK(k int) GenerateOption

WithTopK sets the top-K sampling parameter.
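
Top-K sampling keeps only the K highest-scoring tokens as candidates. The sketch below shows the usual filtering step; treating k <= 0 as "disabled" is an assumption of this example.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// topK returns a copy of logits in which everything except the k largest
// entries is set to -Inf. k <= 0 or k >= len(logits) leaves logits unchanged.
func topK(logits []float64, k int) []float64 {
	out := append([]float64(nil), logits...)
	if k <= 0 || k >= len(logits) {
		return out
	}
	idx := make([]int, len(logits))
	for i := range idx {
		idx[i] = i
	}
	// Sort token indices by descending logit.
	sort.SliceStable(idx, func(a, b int) bool { return logits[idx[a]] > logits[idx[b]] })
	for _, i := range idx[k:] {
		out[i] = math.Inf(-1)
	}
	return out
}

func main() {
	fmt.Println(topK([]float64{1, 5, 3, 2}, 2))
}
```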

func WithTopP

func WithTopP(p float64) GenerateOption

WithTopP sets the top-P (nucleus) sampling parameter.
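
Nucleus sampling keeps the smallest set of most-probable tokens whose cumulative probability reaches p, then renormalizes. A sketch of that standard filtering step, operating on probabilities rather than logits:

```go
package main

import (
	"fmt"
	"sort"
)

// topP zeroes every token outside the nucleus and renormalizes the rest.
func topP(probs []float64, p float64) []float64 {
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	// Visit tokens from most to least probable.
	sort.SliceStable(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })

	out := make([]float64, len(probs))
	var cum float64
	for _, i := range idx {
		out[i] = probs[i]
		cum += probs[i]
		if cum >= p {
			break // nucleus is complete
		}
	}
	if cum > 0 {
		for i := range out {
			out[i] /= cum
		}
	}
	return out
}

func main() {
	fmt.Println(topP([]float64{0.05, 0.5, 0.15, 0.3}, 0.7))
}
```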

type JambaConfig added in v1.5.0

type JambaConfig struct {
	NumLayers            int
	HiddenSize           int
	IntermediateSize     int
	AttnHeads            int
	KVHeads              int
	SSMHeads             int // number of SSM heads (maps to DState)
	AttentionLayerOffset int // attention layers at indices that are multiples of this value
	RMSEps               float32
	VocabSize            int
	MaxSeqLen            int
	RopeTheta            float64
	DConv                int // SSM convolution width (default 4)
}

JambaConfig holds Jamba-specific hybrid model configuration.
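
Read literally, the AttentionLayerOffset comment means a layer uses attention when its index is a multiple of the offset and an SSM block otherwise. The sketch below follows that literal reading, including treating index 0 as an attention layer; the builder's actual rule may differ.

```go
package main

import "fmt"

// isAttentionLayer reports whether layer i is an attention layer under the
// "multiples of offset" rule. A non-positive offset means no attention layers.
func isAttentionLayer(i, offset int) bool {
	return offset > 0 && i%offset == 0
}

// layerKinds lists the block type of every layer in a hybrid stack.
func layerKinds(numLayers, offset int) []string {
	kinds := make([]string, numLayers)
	for i := range kinds {
		if isAttentionLayer(i, offset) {
			kinds[i] = "attn"
		} else {
			kinds[i] = "ssm"
		}
	}
	return kinds
}

func main() {
	fmt.Println(layerKinds(8, 4))
}
```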

func JambaConfigFromGGUF added in v1.5.0

func JambaConfigFromGGUF(cfg *gguf.ModelConfig) JambaConfig

JambaConfigFromGGUF extracts Jamba configuration from GGUF ModelConfig.

type LLaVAConfig added in v1.7.0

type LLaVAConfig struct {
	// Vision encoder config.
	ImageSize       int
	PatchSize       int
	VisionHiddenDim int
	VisionNumHeads  int
	VisionNumLayers int
	NumChannels     int

	// Multi-modal projector config.
	ProjectorType string // "linear" or "mlp" (2-layer MLP is default for LLaVA 1.5+)
}

LLaVAConfig holds LLaVA-specific model configuration.

func LLaVAConfigFromGGUF added in v1.7.0

func LLaVAConfigFromGGUF(cfg *gguf.ModelConfig) LLaVAConfig

LLaVAConfigFromGGUF extracts LLaVA configuration from GGUF ModelConfig.

type Mamba3Config added in v1.8.0

type Mamba3Config struct {
	NumLayers  int
	DModel     int
	DState     int
	DConv      int
	DInner     int
	NumHeads   int
	VocabSize  int
	EOSTokenID int
	RMSNormEps float32
}

Mamba3Config holds Mamba 3-specific model configuration. Mamba 3 extends Mamba with multi-head MIMO SSM, exponential-trapezoidal discretization, and cross-head mixing.

func Mamba3ConfigFromGGUF added in v1.8.0

func Mamba3ConfigFromGGUF(cfg *gguf.ModelConfig) Mamba3Config

Mamba3ConfigFromGGUF extracts Mamba 3 configuration from GGUF ModelConfig. Fields are mapped as: HiddenSize -> DModel, NumKVHeads -> DState, IntermediateSize -> DInner, NumHeads -> NumHeads. DConv defaults to 4.

func Mamba3ConfigFromMetadata added in v1.8.0

func Mamba3ConfigFromMetadata(meta map[string]interface{}) Mamba3Config

Mamba3ConfigFromMetadata extracts Mamba 3 configuration from a raw metadata map.

type MambaConfig added in v1.5.0

type MambaConfig struct {
	NumLayers  int
	DModel     int
	DState     int
	DConv      int
	DInner     int
	VocabSize  int
	EOSTokenID int
	RMSNormEps float32
}

MambaConfig holds Mamba-specific model configuration.

func MambaConfigFromGGUF added in v1.5.0

func MambaConfigFromGGUF(cfg *gguf.ModelConfig) MambaConfig

MambaConfigFromGGUF extracts Mamba configuration from GGUF ModelConfig. Fields are mapped as: HiddenSize -> DModel, NumKVHeads -> DState, IntermediateSize -> DInner. DConv defaults to 4 if not specified.

func MambaConfigFromMetadata added in v1.5.0

func MambaConfigFromMetadata(meta map[string]interface{}) MambaConfig

MambaConfigFromMetadata extracts Mamba configuration from a raw metadata map.

type Message

type Message struct {
	Role    string // "system", "user", or "assistant"
	Content string
	Images  [][]byte // optional raw image data for vision models
}

Message represents a chat message.

type Model

type Model struct {
	// contains filtered or unexported fields
}

Model is a loaded model ready for generation.

func Load

func Load(modelID string, opts ...Option) (*Model, error)

Load resolves a model by name or HuggingFace repo ID, pulls it from the registry if it is not already cached, and returns a ready-to-use Model.

func LoadFile

func LoadFile(path string, opts ...Option) (*Model, error)

LoadFile loads a model from a local GGUF file and returns a ready-to-use Model.

func NewTestModel

func NewTestModel(
	gen *generate.Generator[float32],
	tok tokenizer.Tokenizer,
	eng compute.Engine[float32],
	meta ModelMetadata,
	info *registry.ModelInfo,
) *Model

NewTestModel constructs a Model from pre-built components. Intended for use in external test packages that need a Model without going through the full Load pipeline.

func (*Model) Chat

func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)

Chat formats messages using the model's chat template and generates a response. Sessions are pooled to preserve CUDA graph replay.

func (*Model) ChatStream added in v1.11.0

func (m *Model) ChatStream(ctx context.Context, messages []Message, handler generate.TokenStream, opts ...GenerateOption) error

ChatStream formats messages using the model's chat template and streams the response token-by-token via the provided handler. This is the streaming counterpart of Chat and ensures the same prompt formatting is applied.

func (*Model) Close

func (m *Model) Close() error

Close releases resources held by the model. If the model was loaded on a GPU, this frees the CUDA engine's handles, pool, and stream. If loaded with mmap, this releases the memory mapping.

func (*Model) Config

func (m *Model) Config() ModelMetadata

Config returns the model metadata.

func (*Model) Embed

func (m *Model) Embed(text string) ([]float32, error)

Embed returns an L2-normalized embedding vector for the given text by looking up token embeddings from the model's embedding table and mean-pooling them.
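
The mean-pooling and L2 normalization steps can be sketched on a toy embedding table. The function below assumes the flattened [vocabSize, hiddenSize] layout described for SetEmbeddingWeights and takes token IDs directly; tokenization is out of scope for the sketch.

```go
package main

import (
	"fmt"
	"math"
)

// meanPoolL2 averages the embedding rows selected by tokenIDs from a
// flattened [vocabSize, hidden] table and L2-normalizes the result.
func meanPoolL2(table []float32, hidden int, tokenIDs []int) []float32 {
	out := make([]float32, hidden)
	if len(tokenIDs) == 0 || hidden == 0 {
		return out
	}
	// Sum the rows for all tokens.
	for _, id := range tokenIDs {
		row := table[id*hidden : (id+1)*hidden]
		for j, v := range row {
			out[j] += v
		}
	}
	// Divide by the token count and accumulate the squared norm.
	var sumSq float64
	for j := range out {
		out[j] /= float32(len(tokenIDs))
		sumSq += float64(out[j]) * float64(out[j])
	}
	// Scale to unit length, leaving an all-zero vector untouched.
	if norm := math.Sqrt(sumSq); norm > 0 {
		for j := range out {
			out[j] = float32(float64(out[j]) / norm)
		}
	}
	return out
}

func main() {
	// Toy table with vocabSize=3 and hidden=2.
	table := []float32{1, 0, 0, 1, 3, 4}
	fmt.Println(meanPoolL2(table, 2, []int{0, 1}))
}
```

Because the result has unit length, the dot product of two such vectors equals their cosine similarity.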

func (*Model) EmbeddingWeights

func (m *Model) EmbeddingWeights() ([]float32, int)

EmbeddingWeights returns the flattened token embedding table and the hidden dimension. Returns nil, 0 if embeddings are not available.

func (*Model) FormatMessages added in v1.11.0

func (m *Model) FormatMessages(messages []Message) string

FormatMessages converts messages to the model's chat template format. This is useful when callers need the formatted prompt without running inference, e.g. for streaming paths that call GenerateStream separately.

func (*Model) Generate

func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)

Generate produces text from a prompt. Sessions are pooled to reuse GPU memory addresses, enabling CUDA graph replay across calls. Concurrent Generate calls get separate sessions from the pool.

func (*Model) GenerateBatch

func (m *Model) GenerateBatch(ctx context.Context, prompts []string, opts ...GenerateOption) ([]string, error)

GenerateBatch processes multiple prompts concurrently and returns the generated text for each prompt. Results are returned in the same order as the input prompts. If a prompt fails, the returned error is non-nil.

Concurrency is capped at maxBatchConcurrency (default 8) to prevent resource exhaustion on GPU-backed models.

Implementation note: GenerateBatch runs each prompt in its own goroutine instead of decoding all prompts through a shared PagedKV cache. Full multi-sequence decoding requires a deeper refactor of the Generator.
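
The capped-concurrency, order-preserving pattern described for GenerateBatch can be sketched with a semaphore channel. The gen callback stands in for a single-prompt generation call, and combining per-prompt errors with errors.Join (Go 1.20+) is an assumption about error reporting, not the package's documented behavior.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// generateBatch runs gen for every prompt with at most maxConcurrency calls
// in flight and returns results in input order.
func generateBatch(prompts []string, maxConcurrency int, gen func(string) (string, error)) ([]string, error) {
	if maxConcurrency <= 0 {
		maxConcurrency = 8 // documented default
	}
	results := make([]string, len(prompts))
	errs := make([]error, len(prompts))
	sem := make(chan struct{}, maxConcurrency)
	var wg sync.WaitGroup
	for i, p := range prompts {
		wg.Add(1)
		sem <- struct{}{} // blocks while maxConcurrency calls are running
		go func(i int, p string) {
			defer wg.Done()
			defer func() { <-sem }()
			// Each goroutine writes only its own slot, so no lock is needed.
			results[i], errs[i] = gen(p)
		}(i, p)
	}
	wg.Wait()
	return results, errors.Join(errs...) // nil when every call succeeded
}

func main() {
	out, err := generateBatch([]string{"a", "b", "c"}, 2, func(p string) (string, error) {
		return "echo:" + p, nil
	})
	fmt.Println(out, err)
}
```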

func (*Model) GenerateStream

func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, opts ...GenerateOption) error

GenerateStream delivers tokens one at a time via a callback. Sessions are pooled to preserve GPU memory addresses for CUDA graph replay.

func (*Model) Generator

func (m *Model) Generator() *generate.Generator[float32]

Generator returns the underlying generator.

func (*Model) Info

func (m *Model) Info() *registry.ModelInfo

Info returns the registry info for this model.

func (*Model) SetEmbeddingWeights

func (m *Model) SetEmbeddingWeights(weights []float32, hiddenSize int)

SetEmbeddingWeights sets the token embedding table for Embed(). weights is a flattened [vocabSize, hiddenSize] matrix.

func (*Model) SetMaxBatchConcurrency added in v1.11.0

func (m *Model) SetMaxBatchConcurrency(n int)

SetMaxBatchConcurrency sets the maximum number of concurrent goroutines that GenerateBatch will use. Values <= 0 are ignored.

func (*Model) SpeculativeGenerate

func (m *Model) SpeculativeGenerate(
	ctx context.Context,
	draft *Model,
	prompt string,
	draftLen int,
	opts ...GenerateOption,
) (string, error)

SpeculativeGenerate runs speculative decoding using this model as the target and the draft model for token proposal. draftLen controls how many tokens are proposed per verification step.
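
The propose-and-verify loop can be illustrated with toy greedy models over integer tokens. This is a sketch of the general technique under simplifying assumptions (greedy decoding, one target call per position); it does not reflect how the package batches verification or handles sampling.

```go
package main

import "fmt"

// nextToken returns a model's greedy next token for a non-empty prefix.
// Both models are stand-ins; real models would run a forward pass.
type nextToken func(prefix []int) int

// greedy decodes maxNew tokens using only the target model.
func greedy(target nextToken, prompt []int, maxNew int) []int {
	out := append([]int(nil), prompt...)
	for i := 0; i < maxNew; i++ {
		out = append(out, target(out))
	}
	return out
}

// speculative decodes maxNew tokens: the draft proposes draftLen tokens, the
// target checks them in order, and the first disagreement is replaced by the
// target's own token. The output matches greedy target decoding.
func speculative(target, draft nextToken, prompt []int, draftLen, maxNew int) []int {
	if draftLen < 1 {
		draftLen = 1
	}
	out := append([]int(nil), prompt...)
	limit := len(prompt) + maxNew
	for len(out) < limit {
		// Draft phase: propose draftLen tokens autoregressively.
		ctx := append([]int(nil), out...)
		proposal := make([]int, 0, draftLen)
		for i := 0; i < draftLen; i++ {
			t := draft(ctx)
			proposal = append(proposal, t)
			ctx = append(ctx, t)
		}
		// Verify phase: a real target scores all positions in one batched
		// pass; this sketch calls it once per position for clarity.
		for _, t := range proposal {
			if len(out) >= limit {
				break
			}
			want := target(out)
			out = append(out, want)
			if want != t {
				break // reject the rest of the proposal
			}
		}
	}
	return out
}

func main() {
	target := func(p []int) int { return (p[len(p)-1]*2 + 1) % 7 }
	draft := func(p []int) int {
		if len(p)%3 == 0 {
			return 0 // the draft is occasionally wrong
		}
		return target(p)
	}
	fmt.Println(speculative(target, draft, []int{1}, 4, 10))
	fmt.Println(greedy(target, []int{1}, 10))
}
```

The speed-up in practice comes from verifying several draft tokens with one target forward pass; a larger draftLen helps only while the draft model agrees with the target often enough.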

func (*Model) Tokenizer

func (m *Model) Tokenizer() tokenizer.Tokenizer

Tokenizer returns the model's tokenizer for token counting.

type ModelMetadata

type ModelMetadata struct {
	Architecture          string `json:"architecture"`
	VocabSize             int    `json:"vocab_size"`
	HiddenSize            int    `json:"hidden_size"`
	NumLayers             int    `json:"num_layers"`
	MaxPositionEmbeddings int    `json:"max_position_embeddings"`
	EOSTokenID            int    `json:"eos_token_id"`
	BOSTokenID            int    `json:"bos_token_id"`
	ChatTemplate          string `json:"chat_template"`

	// Extended fields for multi-architecture support.
	IntermediateSize    int                `json:"intermediate_size"`
	NumQueryHeads       int                `json:"num_attention_heads"`
	NumKeyValueHeads    int                `json:"num_key_value_heads"`
	RopeTheta           float64            `json:"rope_theta"`
	RopeScaling         *RopeScalingConfig `json:"rope_scaling,omitempty"`
	TieWordEmbeddings   bool               `json:"tie_word_embeddings"`
	SlidingWindow       int                `json:"sliding_window"`
	AttentionBias       bool               `json:"attention_bias"`
	PartialRotaryFactor float64            `json:"partial_rotary_factor"`

	// Granite-specific fields.
	EmbeddingMultiplier float64 `json:"embedding_multiplier,omitempty"`
	ResidualMultiplier  float64 `json:"residual_multiplier,omitempty"`
	LogitScale          float64 `json:"logit_scale,omitempty"`

	// DeepSeek MLA and MoE fields.
	KVLoRADim          int `json:"kv_lora_rank"`
	QLoRADim           int `json:"q_lora_rank"`
	QKRopeHeadDim      int `json:"qk_rope_head_dim"`
	NumExperts         int `json:"num_experts"`
	NumExpertsPerToken int `json:"num_experts_per_tok"`
	NumSharedExperts   int `json:"n_shared_experts"`
}

ModelMetadata holds model configuration loaded from config.json.

type Option

type Option func(*loadOptions)

Option configures model loading.
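
Option follows Go's functional-options pattern. A minimal sketch of how such options compose, assuming a hypothetical loadOptions struct (the real fields are unexported and not shown in this documentation):

```go
package main

import "fmt"

// loadOptions is a hypothetical stand-in for the package's unexported
// options struct; the field names are assumptions for illustration.
type loadOptions struct {
	device    string
	maxSeqLen int
}

// Option mirrors the documented shape: a function that mutates the options.
type Option func(*loadOptions)

// withDevice is an illustrative counterpart of WithDevice.
func withDevice(d string) Option {
	return func(o *loadOptions) { o.device = d }
}

// withMaxSeqLen is an illustrative counterpart of WithMaxSeqLen.
func withMaxSeqLen(n int) Option {
	return func(o *loadOptions) { o.maxSeqLen = n }
}

// applyOptions starts from defaults and applies each option in order, so
// later options override earlier ones.
func applyOptions(opts ...Option) loadOptions {
	o := loadOptions{device: "cpu"} // assumed default
	for _, opt := range opts {
		opt(&o)
	}
	return o
}

func main() {
	o := applyOptions(withDevice("cuda"), withMaxSeqLen(4096))
	fmt.Println(o.device, o.maxSeqLen)
}
```

The pattern lets the package add new options without changing the signatures of Load and LoadFile.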

func WithBackend

func WithBackend(backend string) Option

WithBackend selects the inference backend. Supported values: "" or "default" for the standard Engine path, "tensorrt" for TensorRT-optimized inference. TensorRT requires the cuda build tag and a CUDA device.

func WithCacheDir

func WithCacheDir(dir string) Option

WithCacheDir sets the model cache directory.

func WithDType

func WithDType(dtype string) Option

WithDType sets the compute precision for the GPU engine. Supported values: "" or "fp32" for full precision, "fp16" for FP16 compute. FP16 mode converts activations F32->FP16 before GPU kernels and back after. Has no effect on CPU engines.

func WithDevice

func WithDevice(device string) Option

WithDevice sets the compute device, for example "cpu" or "cuda". The Load Options list in the overview also names "cuda:N", "rocm", and "opencl".

func WithKVDtype

func WithKVDtype(dtype string) Option

WithKVDtype sets the KV cache storage dtype. Supported: "fp32" (default), "fp16". FP16 halves KV cache bandwidth by storing keys/values in half precision.
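
The halving follows directly from the element size. A back-of-the-envelope sketch, using the standard KV cache size formula and illustrative model dimensions that are not taken from any specific model:

```go
package main

import "fmt"

// kvCacheBytes estimates KV cache size: keys and values (factor 2) for every
// layer, KV head, head dimension, and position.
func kvCacheBytes(layers, kvHeads, headDim, seqLen, bytesPerElem int) int64 {
	return 2 * int64(layers) * int64(kvHeads) * int64(headDim) * int64(seqLen) * int64(bytesPerElem)
}

func main() {
	// Hypothetical dimensions: 32 layers, 8 KV heads, head dim 128, 4096 tokens.
	fp32 := kvCacheBytes(32, 8, 128, 4096, 4)
	fp16 := kvCacheBytes(32, 8, 128, 4096, 2)
	fmt.Printf("fp32: %d MiB, fp16: %d MiB\n", fp32>>20, fp16>>20)
}
```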

func WithMaxBatchConcurrency added in v1.11.0

func WithMaxBatchConcurrency(n int) Option

WithMaxBatchConcurrency sets the maximum number of concurrent goroutines that GenerateBatch will use. Values <= 0 are ignored (the default of 8 is used).

func WithMaxSeqLen

func WithMaxSeqLen(n int) Option

WithMaxSeqLen overrides the model's default max sequence length.

func WithMmap

func WithMmap(enabled bool) Option

WithMmap enables memory-mapped model loading. When true, the ZMF file is mapped into memory using syscall.Mmap instead of os.ReadFile, avoiding heap allocation for model weights. Only supported on unix platforms.

func WithPrecision

func WithPrecision(precision string) Option

WithPrecision sets the compute precision for the TensorRT backend. Supported values: "" or "fp32" for full precision, "fp16" for half precision. Has no effect when the backend is not "tensorrt".

func WithRegistry

func WithRegistry(r registry.ModelRegistry) Option

WithRegistry provides a custom model registry.

func WithSessionPoolSize added in v1.12.0

func WithSessionPoolSize(n int) Option

WithSessionPoolSize sets the session pool capacity. The pool buffers inference sessions for reuse so that CUDA graph-captured GPU pointers remain valid across calls. Minimum value is 1; values below 1 are clamped to 1.
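
A bounded pool of this kind is often built on a buffered channel. The sketch below shows the clamping rule and session reuse; the session type and the drop-when-full policy are assumptions, not the package's implementation.

```go
package main

import "fmt"

// session is a hypothetical stand-in for an inference session that owns
// GPU buffers which must stay at stable addresses.
type session struct{ buf []float32 }

// sessionPool keeps up to cap(idle) sessions for reuse.
type sessionPool struct{ idle chan *session }

// newSessionPool clamps n to a minimum of 1.
func newSessionPool(n int) *sessionPool {
	if n < 1 {
		n = 1
	}
	return &sessionPool{idle: make(chan *session, n)}
}

// get reuses an idle session if one is available, else creates a new one.
func (p *sessionPool) get() *session {
	select {
	case s := <-p.idle:
		return s
	default:
		return &session{buf: make([]float32, 4)}
	}
}

// put returns a session to the pool, dropping it if the pool is full.
func (p *sessionPool) put(s *session) {
	select {
	case p.idle <- s:
	default:
	}
}

func main() {
	p := newSessionPool(0) // clamped to 1
	s := p.get()
	p.put(s)
	fmt.Println(cap(p.idle), p.get() == s)
}
```

Reusing the same session object is what keeps previously captured buffer addresses valid across calls.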

type QwenVLConfig added in v1.8.0

type QwenVLConfig struct {
	// Vision encoder config.
	ImageSize       int
	PatchSize       int
	VisionHiddenDim int
	VisionNumHeads  int
	VisionNumLayers int
	NumChannels     int

	// Multi-modal projector config.
	ProjectorType string // "linear" or "mlp"
}

QwenVLConfig holds Qwen-VL-specific model configuration.

func QwenVLConfigFromGGUF added in v1.8.0

func QwenVLConfigFromGGUF(cfg *gguf.ModelConfig) QwenVLConfig

QwenVLConfigFromGGUF extracts Qwen-VL configuration from GGUF ModelConfig.

type RWKVConfig added in v1.7.0

type RWKVConfig struct {
	NumLayers    int
	HiddenSize   int
	VocabSize    int
	HeadSize     int // WKV head size (default 64)
	NumHeads     int // HiddenSize / HeadSize
	LayerNormEps float32
}

RWKVConfig holds RWKV-specific model configuration.

func RWKVConfigFromGGUF added in v1.7.0

func RWKVConfigFromGGUF(cfg *gguf.ModelConfig) RWKVConfig

RWKVConfigFromGGUF extracts RWKV configuration from GGUF ModelConfig.

type ResidualConfig added in v1.9.0

type ResidualConfig struct {
	Mode      string // "standard" (default), "attnres", or "block_attnres"
	NumBlocks int    // block count for "block_attnres" mode (default 8)
}

ResidualConfig controls the residual connection strategy used by architecture graph builders. The default mode ("standard" or "") preserves existing behaviour. "attnres" and "block_attnres" enable attention-weighted residual connections from the layers/residual package (arXiv:2603.15031).

GGUF metadata convention

Models opt into attention residuals via two GGUF general-metadata keys:

  • general.residual_mode (string): one of "standard" (default), "attnres", or "block_attnres". When absent or empty, standard additive residuals are used and no extra memory is allocated.

  • general.attnres_blocks (uint32): number of blocks for the block_attnres variant. Ignored when residual_mode is not "block_attnres". Defaults to 8 when unset, which recovers most of the benefit of full AttnRes.

These keys follow the GGUF general metadata namespace convention (see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md). ResidualConfigFromGGUF parses these values into a ResidualConfig.
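The convention above can be sketched as a lookup over a metadata map. Here residualConfig is a local mirror of ResidualConfig, and residualFromMetadata with its map[string]any input is a hypothetical shape chosen for illustration; the real loader reads these keys through the gguf package. Falling back to "standard" for unrecognized mode strings is an assumption.

```go
package main

// residualConfig mirrors the documented ResidualConfig fields.
type residualConfig struct {
	Mode      string
	NumBlocks int
}

// residualFromMetadata is a hypothetical helper that applies the documented
// defaults for general.residual_mode and general.attnres_blocks.
func residualFromMetadata(meta map[string]any) residualConfig {
	cfg := residualConfig{Mode: "standard"}

	mode, _ := meta["general.residual_mode"].(string)
	switch mode {
	case "attnres", "block_attnres":
		cfg.Mode = mode
	default:
		// Absent, empty, or (by assumption) unrecognized: standard residuals.
		return cfg
	}

	// attnres_blocks is only consulted for the block variant.
	if cfg.Mode == "block_attnres" {
		cfg.NumBlocks = 8 // documented default when unset
		if n, ok := meta["general.attnres_blocks"].(uint32); ok && n > 0 {
			cfg.NumBlocks = int(n)
		}
	}
	return cfg
}
```

With this sketch, a model that sets only general.residual_mode to "block_attnres" ends up with 8 blocks, and a model with no residual keys at all keeps standard residuals.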

func DefaultResidualConfig added in v1.9.0

func DefaultResidualConfig() ResidualConfig

DefaultResidualConfig returns a ResidualConfig that selects standard additive residuals, leaving existing behaviour unchanged.

func ResidualConfigFromGGUF added in v1.9.0

func ResidualConfigFromGGUF(mode string, numBlocks int) ResidualConfig

ResidualConfigFromGGUF builds a ResidualConfig from the values of the general.residual_mode and general.attnres_blocks GGUF metadata keys. Missing keys produce the backward-compatible "standard" default.

type Response

type Response struct {
	Content          string
	TokensUsed       int
	PromptTokens     int
	CompletionTokens int
}

Response holds the result of a chat completion.

type RopeScalingConfig

type RopeScalingConfig struct {
	Type                          string  `json:"type"`
	Factor                        float64 `json:"factor"`
	OriginalMaxPositionEmbeddings int     `json:"original_max_position_embeddings"`
}

RopeScalingConfig holds configuration for RoPE scaling methods (e.g., YaRN).
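Because the struct carries JSON tags, it can be decoded with encoding/json. The sketch below uses a local mirror of the struct with the same tags, and parseRopeScaling is a hypothetical helper. With the real package, the target type would be inference.RopeScalingConfig.

```go
package main

import "encoding/json"

// ropeScalingConfig mirrors the documented struct, including its JSON tags.
type ropeScalingConfig struct {
	Type                          string  `json:"type"`
	Factor                        float64 `json:"factor"`
	OriginalMaxPositionEmbeddings int     `json:"original_max_position_embeddings"`
}

// parseRopeScaling is a hypothetical helper that decodes a JSON object
// into the mirrored config.
func parseRopeScaling(data []byte) (ropeScalingConfig, error) {
	var cfg ropeScalingConfig
	err := json.Unmarshal(data, &cfg)
	return cfg, err
}
```

An input such as {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 8192} populates all three fields.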

type ShapeRange

type ShapeRange struct {
	Min []int32
	Opt []int32
	Max []int32
}

ShapeRange defines min/opt/max dimensions for a single input tensor. Used with DynamicShapeConfig to support variable-size inputs.
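A range like this is typically only usable when all three shapes have the same rank and each dimension satisfies min <= opt <= max, which is the rule TensorRT optimization profiles impose. Whether this package performs the check itself is not stated, so the following is a sketch with a local shapeRange mirror and a hypothetical validateShapeRange helper.

```go
package main

import "fmt"

// shapeRange mirrors the documented ShapeRange fields.
type shapeRange struct {
	Min []int32
	Opt []int32
	Max []int32
}

// validateShapeRange is a hypothetical check: equal rank across the three
// shapes and min <= opt <= max for every dimension.
func validateShapeRange(r shapeRange) error {
	if len(r.Min) != len(r.Opt) || len(r.Opt) != len(r.Max) {
		return fmt.Errorf("rank mismatch: min=%d opt=%d max=%d",
			len(r.Min), len(r.Opt), len(r.Max))
	}
	for i := range r.Min {
		if r.Min[i] > r.Opt[i] || r.Opt[i] > r.Max[i] {
			return fmt.Errorf("dim %d: want min <= opt <= max, got %d, %d, %d",
				i, r.Min[i], r.Opt[i], r.Max[i])
		}
	}
	return nil
}
```

For a [batch, sequence] input, a range such as Min {1, 1}, Opt {1, 128}, Max {8, 2048} passes this check.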

type TRTInferenceEngine

type TRTInferenceEngine struct {
	// contains filtered or unexported fields
}

TRTInferenceEngine holds a TensorRT engine and execution context for inference. It wraps the serialized engine, providing a Forward method that mirrors the graph forward pass but runs through TensorRT.

func (*TRTInferenceEngine) Close

func (e *TRTInferenceEngine) Close() error

Close releases all TensorRT resources.

func (*TRTInferenceEngine) Forward

func (e *TRTInferenceEngine) Forward(inputs []*tensor.TensorNumeric[float32], outputSize int) (*tensor.TensorNumeric[float32], error)

Forward runs inference through TensorRT with the given input tensors. Input tensors must already be on GPU.

type UnsupportedOpError

type UnsupportedOpError struct {
	Ops []string
}

UnsupportedOpError lists the operations that cannot be converted to TensorRT.

func (*UnsupportedOpError) Error

func (e *UnsupportedOpError) Error() string
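Since the error is a pointer type with an exported Ops field, callers can recover the list with errors.As, for instance to fall back to another backend. The sketch below uses a local mirror type; its message format, the convert function, and the "CustomOp" name are all hypothetical and only serve to make the example self-contained.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// unsupportedOpError mirrors the documented type. The message format
// here is an assumption.
type unsupportedOpError struct {
	Ops []string
}

func (e *unsupportedOpError) Error() string {
	return "unsupported ops: " + strings.Join(e.Ops, ", ")
}

// convert is a hypothetical stand-in for a conversion step that rejects
// an op named "CustomOp" and wraps the typed error with context.
func convert(ops []string) error {
	var bad []string
	for _, op := range ops {
		if op == "CustomOp" {
			bad = append(bad, op)
		}
	}
	if len(bad) > 0 {
		return fmt.Errorf("convert graph: %w", &unsupportedOpError{Ops: bad})
	}
	return nil
}

// unsupportedOps extracts the op list from err, or returns nil if err
// does not wrap an unsupportedOpError.
func unsupportedOps(err error) []string {
	var uerr *unsupportedOpError
	if errors.As(err, &uerr) {
		return uerr.Ops
	}
	return nil
}
```

With the real package, the errors.As target would be a *inference.UnsupportedOpError.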

type WhisperConfig added in v1.5.0

type WhisperConfig struct {
	NumMels    int
	HiddenDim  int
	NumHeads   int
	NumLayers  int
	KernelSize int
}

WhisperConfig holds Whisper-specific model configuration.

func WhisperConfigFromGGUF added in v1.5.0

func WhisperConfigFromGGUF(cfg *gguf.ModelConfig) WhisperConfig

WhisperConfigFromGGUF extracts Whisper configuration from GGUF ModelConfig. Fields are mapped as: HiddenSize -> HiddenDim, NumHeads -> NumHeads, NumLayers -> NumLayers. NumMels defaults to 80, KernelSize defaults to 3.
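The mapping described above can be sketched as follows. The modelConfig type is a stand-in containing only the three gguf.ModelConfig fields named in the description, and whisperFromModel is a hypothetical function name. The sketch always applies the documented defaults for NumMels and KernelSize, which may differ from the real function if it also reads those values from metadata.

```go
package main

// modelConfig is a stand-in for gguf.ModelConfig, limited to the fields
// named in the mapping.
type modelConfig struct {
	HiddenSize int
	NumHeads   int
	NumLayers  int
}

// whisperConfig mirrors the documented WhisperConfig fields.
type whisperConfig struct {
	NumMels    int
	HiddenDim  int
	NumHeads   int
	NumLayers  int
	KernelSize int
}

// whisperFromModel is a hypothetical sketch of the documented mapping.
func whisperFromModel(cfg modelConfig) whisperConfig {
	return whisperConfig{
		NumMels:    80, // documented default
		HiddenDim:  cfg.HiddenSize,
		NumHeads:   cfg.NumHeads,
		NumLayers:  cfg.NumLayers,
		KernelSize: 3, // documented default
	}
}
```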

Directories

Path Synopsis
Package guardian implements prompt template rendering for IBM Granite Guardian safety risk evaluation across 13 pre-defined risk categories.
Package multimodal provides audio preprocessing for audio-language model inference.
Package parallel provides tensor and pipeline parallelism for distributing inference across multiple GPUs.
Package sentiment provides a high-level sentiment classification pipeline that wraps encoder model loading and inference.
Package timeseries implements time-series model builders.
features
Package features provides a feature store for the Wolf time-series ML platform.
